deluca/agents/_gpc.py changed.
deluca/agents/__init__.py changed.
Newly tracked file
deluca/agents/_bpc.py created.
Other files ignored by Codecov
examples/agents/BPC test.ipynb is new.
deluca/agents/_gpc.py

87 | 87 |         # Model Parameters
88 | 88 |         # initial linear policy / perturbation contributions / bias
89 | 89 |         # TODO: need to address problem of LQR with jax.lax.scan
90 | - |         self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
90 | + |         self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
91 | 91 |
92 | 92 |         self.M = jnp.zeros((H, d_action, d_state))
93 | 93 |

156 | 156 |         lr = self.lr_scale
157 | 157 |         lr *= (1/ (self.t+1)) if self.decay else 1
158 | 158 |         self.M -= lr * delta_M
159 | - |         self.M -= lr * delta_M
159 | + |         self.bias -= lr * delta_bias
160 | 160 |
161 | 161 |         # update state
162 | 162 |         self.state = state
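The second hunk corrects the parameter update in GPC.update: the old line 159 repeated the self.M step, so self.bias never received a gradient step. A rough sketch of the resulting block, reconstructed from the hunk above (the computation of delta_M and delta_bias earlier in the method is assumed, not shown here):

    lr = self.lr_scale
    lr *= (1 / (self.t + 1)) if self.decay else 1
    self.M -= lr * delta_M        # step on the perturbation-contribution matrices
    self.bias -= lr * delta_bias  # bias now gets its own gradient step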
deluca/agents/__init__.py

11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | 12 | # See the License for the specific language governing permissions and
13 | 13 | # limitations under the License.
14 | + | from deluca.agents._bpc import BPC
14 | 15 | from deluca.agents._gpc import GPC
15 | 16 | from deluca.agents._hinf import Hinf
16 | 17 | from deluca.agents._ilqr import ILQR

21 | 22 | from deluca.agents._adaptive import Adaptive
22 | 23 | from deluca.agents._deep import Deep
23 | 24 |
24 | - | __all__ = ["LQR", "PID", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
25 | + | __all__ = ["LQR", "PID", "BPC", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
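With the import and __all__ entry above, the new agent is exposed at the package level; for example (assuming the package is installed from this branch):

    from deluca.agents import BPC, GPC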
deluca/agents/_bpc.py

1 | + | # Copyright 2020 Google LLC
2 | + | #
3 | + | # Licensed under the Apache License, Version 2.0 (the "License");
4 | + | # you may not use this file except in compliance with the License.
5 | + | # You may obtain a copy of the License at
6 | + | #
7 | + | #     https://www.apache.org/licenses/LICENSE-2.0
8 | + | #
9 | + | # Unless required by applicable law or agreed to in writing, software
10 | + | # distributed under the License is distributed on an "AS IS" BASIS,
11 | + | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | + | # See the License for the specific language governing permissions and
13 | + | # limitations under the License.
14 | + | """deluca.agents._bpc"""
15 | + | from numbers import Real
16 | + | from typing import Callable
17 | + |
18 | + | import jax
19 | + | import jax.numpy as jnp
20 | + | import numpy as np
21 | + | import numpy.random as random
22 | + | from jax import grad
23 | + | from jax import jit
24 | + |
25 | + | from deluca.agents._lqr import LQR
26 | + | from deluca.agents.core import Agent
27 | + |
28 | + | def generate_uniform(shape, norm=1.00):
29 | + |     v = random.normal(size=shape)
30 | + |     v = norm * v / np.linalg.norm(v)
31 | + |     v = np.array(v)
32 | + |     return v
33 | + |
34 | + | class BPC(Agent):
35 | + |     def __init__(
36 | + |         self,
37 | + |         A: jnp.ndarray,
38 | + |         B: jnp.ndarray,
39 | + |         Q: jnp.ndarray = None,
40 | + |         R: jnp.ndarray = None,
41 | + |         K: jnp.ndarray = None,
42 | + |         start_time: int = 0,
43 | + |         H: int = 5,
44 | + |         lr_scale: Real = 0.005,
45 | + |         decay: bool = False,
46 | + |         delta: Real = 0.01
47 | + |     ) -> None:
48 | + |         """
49 | + |         Description: Initialize the dynamics of the model.
50 | + |
51 | + |         Args:
52 | + |             A (jnp.ndarray): system dynamics
53 | + |             B (jnp.ndarray): system dynamics
54 | + |             Q (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
55 | + |             R (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
56 | + |             K (jnp.ndarray): Starting policy (optional). Defaults to LQR gain.
57 | + |             start_time (int): starting time step
58 | + |             H (positive int): history length of the controller
59 | + |             lr_scale (Real): learning rate scale
60 | + |             decay (boolean): whether to decay the learning rate over time
61 | + |         """
62 | + |
63 | + |         self.d_state, self.d_action = B.shape # State & Action Dimensions
64 | + |
65 | + |         self.A, self.B = A, B # System Dynamics
66 | + |
67 | + |         self.t = 0 # Time Counter (for decaying learning rate)
68 | + |
69 | + |         self.H = H
70 | + |
71 | + |         self.lr_scale, self.decay = lr_scale, decay
72 | + |
73 | + |         self.delta = delta
74 | + |
75 | + |         # Model Parameters
76 | + |         # initial linear policy / perturbation contributions / bias
77 | + |         # TODO: need to address problem of LQR with jax.lax.scan
78 | + |         self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
79 | + |
80 | + |         self.M = self.delta * generate_uniform((H, self.d_action, self.d_state))
81 | + |
82 | + |         # Past H noises ordered increasing in time
83 | + |         self.noise_history = jnp.zeros((H, self.d_state, 1))
84 | + |
85 | + |         # past state and past action
86 | + |         self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros((self.d_action, 1))
87 | + |
88 | + |         self.eps = generate_uniform((H, H, self.d_action, self.d_state))
89 | + |         self.eps_bias = generate_uniform((H, self.d_action, 1))
90 | + |
91 | + |         def grad(M, noise_history, cost):
92 | + |             return cost * jnp.sum(self.eps, axis = 0)
93 | + |
94 | + |         self.grad = grad
96 | + |     def __call__(self,
97 | + |                  state: jnp.ndarray,
98 | + |                  cost: Real
99 | + |                  ) -> jnp.ndarray:
100 | + |         """
101 | + |         Description: Return the action based on current state and internal parameters.
102 | + |
103 | + |         Args:
104 | + |             state (jnp.ndarray): current state
105 | + |
106 | + |         Returns:
107 | + |             jnp.ndarray: action to take
108 | + |         """
109 | + |
110 | + |         action = self.get_action(state)
111 | + |         self.update(state, action, cost)
112 | + |         return action
113 | + |
114 | + |     def update(self,
115 | + |                state: jnp.ndarray,
116 | + |                action: jnp.ndarray,
117 | + |                cost: Real
118 | + |                ) -> None:
119 | + |         """
120 | + |         Description: update agent internal state.
121 | + |
122 | + |         Args:
123 | + |             state (jnp.ndarray): current state
124 | + |             action (jnp.ndarray): action taken
125 | + |             cost (Real): scalar cost received
126 | + |
127 | + |         Returns:
128 | + |             None
129 | + |         """
130 | + |         noise = state - self.A @ self.state - self.B @ action
131 | + |         self.noise_history = jax.ops.index_update(self.noise_history, 0, noise)
132 | + |         self.noise_history = jnp.roll(self.noise_history, -1, axis=0)
133 | + |
134 | + |         lr = self.lr_scale
135 | + |         lr *= (1/ (self.t**(3/4)+1)) if self.decay else 1
136 | + |
137 | + |         delta_M = self.grad(self.M, self.noise_history, cost)
138 | + |         self.M -= lr * delta_M
139 | + |
140 | + |         self.eps = jax.ops.index_update(self.eps, 0, \
141 | + |             generate_uniform((self.H, self.d_action, self.d_state)))
142 | + |         self.eps = np.roll(self.eps, -1, axis = 0)
143 | + |
144 | + |         self.M += self.delta * self.eps[-1]
145 | + |
146 | + |         # update state
147 | + |         self.state = state
148 | + |
149 | + |         self.t += 1
150 | + |
151 | + |     def get_action(self, state: jnp.ndarray) -> jnp.ndarray:
152 | + |         """
153 | + |         Description: get action from state.
154 | + |
155 | + |         Args:
156 | + |             state (jnp.ndarray): current state
157 | + |
158 | + |         Returns:
159 | + |             jnp.ndarray: action to take
160 | + |         """
161 | + |         return -self.K @ state + jnp.tensordot(self.M, self.noise_history, axes=([0, 2], [0, 1]))
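For orientation, below is a minimal sketch of how the new BPC agent might be driven in a closed loop. Everything in it is illustrative rather than part of the PR: the double-integrator A and B, the identity Q and R (passed through to deluca's LQR to obtain the gain K), the sinusoidal disturbance, and the rollout length are all assumptions.

    import jax.numpy as jnp
    from deluca.agents import BPC

    # Hypothetical double-integrator dynamics and quadratic cost weights.
    A = jnp.array([[1.0, 1.0], [0.0, 1.0]])
    B = jnp.array([[0.0], [1.0]])
    Q, R = jnp.eye(2), jnp.eye(1)

    agent = BPC(A, B, Q, R, H=5, lr_scale=0.005, decay=True)

    state = jnp.zeros((2, 1))
    action = jnp.zeros((1, 1))
    for t in range(100):
        # Scalar cost of the previous (state, action) pair; this is the only
        # feedback the bandit controller sees.
        cost = float(state.T @ Q @ state + action.T @ R @ action)
        # __call__ returns the next action and performs the internal M update.
        action = agent(state, cost)
        # Illustrative bounded disturbance; any small w_t would do.
        w = 0.01 * jnp.sin(t) * jnp.ones((2, 1))
        state = A @ state + B @ action + w

Note that the agent reconstructs the disturbance internally from consecutive (state, action) pairs, so the loop only needs to report states and scalar costs.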