MinRegret / deluca
Showing 3 of 4 files from the diff.

@@ -87,7 +87,7 @@
         # Model Parameters
         # initial linear policy / perturbation contributions / bias
         # TODO: need to address problem of LQR with jax.lax.scan
-        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K 
+        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
 
         self.M = jnp.zeros((H, d_action, d_state))
 
@@ -156,7 +156,7 @@
         lr = self.lr_scale
         lr *= (1/ (self.t+1)) if self.decay else 1
         self.M -= lr * delta_M
-        self.M -= lr * delta_M
+        self.bias -= lr * delta_bias
 
         # update state
         self.state = state

--- a/deluca/agents/__init__.py
+++ b/deluca/agents/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from deluca.agents._bpc import BPC
 from deluca.agents._gpc import GPC
 from deluca.agents._hinf import Hinf
 from deluca.agents._ilqr import ILQR
@@ -21,4 +22,4 @@
 from deluca.agents._adaptive import Adaptive
 from deluca.agents._deep import Deep
 
-__all__ = ["LQR", "PID", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
+__all__ = ["LQR", "PID", "BPC", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]

--- /dev/null
+++ b/deluca/agents/_bpc.py
@@ -0,0 +1,161 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""deluca.agents._bpc"""
+from numbers import Real
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpy.random as random
+from jax import grad
+from jax import jit
+
+from deluca.agents._lqr import LQR
+from deluca.agents.core import Agent
+
+def generate_uniform(shape, norm=1.00):
+    v = random.normal(size=shape)
+    v = norm * v / np.linalg.norm(v)
+    v = np.array(v)
+    return v
+
+class BPC(Agent):
+    def __init__(
+        self,
+        A: jnp.ndarray,
+        B: jnp.ndarray,
+        Q: jnp.ndarray = None,
+        R: jnp.ndarray = None,
+        K: jnp.ndarray = None,
+        start_time: int = 0,
+        H: int = 5,
+        lr_scale: Real = 0.005,
+        decay: bool = False,
+        delta: Real = 0.01,
+    ) -> None:
+        """
+        Description: Initialize the dynamics of the model.
+
+        Args:
+            A (jnp.ndarray): system dynamics
+            B (jnp.ndarray): system dynamics
+            Q (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
+            R (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
+            K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
+            start_time (int): initial time step
+            H (positive int): history length of the controller
+            lr_scale (Real): scale of the learning rate
+            decay (boolean): whether to decay the learning rate over time
+        """
+
+        self.d_state, self.d_action = B.shape  # State & Action Dimensions
+
+        self.A, self.B = A, B  # System Dynamics
+
+        self.t = 0  # Time Counter (for decaying learning rate)
+
+        self.H = H
+
+        self.lr_scale, self.decay = lr_scale, decay
+
+        self.delta = delta
+
+        # Model Parameters
+        # initial linear policy / perturbation contributions / bias
+        # TODO: need to address problem of LQR with jax.lax.scan
+        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
+
+        self.M = self.delta * generate_uniform((H, self.d_action, self.d_state))
+
+        # Past H noises ordered increasing in time
+        self.noise_history = jnp.zeros((H, self.d_state, 1))
+
+        # past state and past action
+        self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros((self.d_action, 1))
+
+        self.eps = generate_uniform((H, H, self.d_action, self.d_state))
+        self.eps_bias = generate_uniform((H, self.d_action, 1))
+
+        def grad(M, noise_history, cost):
+            return cost * jnp.sum(self.eps, axis=0)
+
+        self.grad = grad
+
+    def __call__(self,
+                state: jnp.ndarray,
+                cost: Real
+                ) -> jnp.ndarray:
+        """
+        Description: Return the action based on current state and internal parameters.
+
+        Args:
+            state (jnp.ndarray): current state
+
+        Returns:
+            jnp.ndarray: action to take
+        """
+
+        action = self.get_action(state)
+        self.update(state, action, cost)
+        return action
+
+    def update(self,
+            state: jnp.ndarray,
+            action: jnp.ndarray,
+            cost: Real
+            ) -> None:
+        """
+        Description: update agent internal state.
+
+        Args:
+            state (jnp.ndarray): current state
+            action (jnp.ndarray): action taken
+            cost (Real): scalar cost received
+
+        Returns:
+            None
+        """
+        noise = state - self.A @ self.state - self.B @ action
+        self.noise_history = jax.ops.index_update(self.noise_history, 0, noise)
+        self.noise_history = jnp.roll(self.noise_history, -1, axis=0)
+
+        lr = self.lr_scale
+        lr *= (1 / (self.t ** (3 / 4) + 1)) if self.decay else 1
+
+        delta_M = self.grad(self.M, self.noise_history, cost)
+        self.M -= lr * delta_M
+
+        self.eps = jax.ops.index_update(self.eps, 0,
+                        generate_uniform((self.H, self.d_action, self.d_state)))
+        self.eps = jnp.roll(self.eps, -1, axis=0)
+
+        self.M += self.delta * self.eps[-1]
+
+        # update state
+        self.state = state
+
+        self.t += 1
+
+    def get_action(self, state: jnp.ndarray) -> jnp.ndarray:
+        """
+        Description: get action from state.
+
+        Args:
+            state (jnp.ndarray): current state
+
+        Returns:
+            jnp.ndarray: action to take
+        """
+        return -self.K @ state + jnp.tensordot(self.M, self.noise_history, axes=([0, 2], [0, 1]))
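
In get_action, the tensordot contracts M (shape (H, d_action, d_state)) with the stored noises (shape (H, d_state, 1)) over the history and state axes, giving the control law u_t = -K x_t + sum_i M_i w_{t-i}. The grad closure is the bandit part: with only a scalar cost as feedback, the gradient with respect to M is estimated from that cost and the stored random perturbation directions, not by differentiation. A standalone sketch of the underlying one-point gradient estimator (simplified scaling; f, x, and delta here are illustrative, not this file's API):

import numpy as np

def one_point_gradient(f, x, delta=0.01):
    # Draw a uniform direction on the unit sphere, as generate_uniform does.
    u = np.random.normal(size=x.shape)
    u /= np.linalg.norm(u)
    # One function evaluation at a perturbed point; the scaled value times
    # the direction estimates the gradient of a smoothed version of f.
    return (x.size / delta) * f(x + delta * u) * u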
Files                        Coverage
deluca                         34.90%
tests                         100.00%
Project Totals (40 files)      38.51%