1
|
|
# Organization-wise, this file is actually "a part of" solvers.py.
|
2
|
|
# However, since this is only compatible with nengo>=2.5.0,
|
3
|
|
# it simplifies life to move this into its own file.
|
4
|
|
|
5
|
1
|
import numpy as np
|
6
|
|
|
7
|
1
|
from nengo import Ensemble
|
8
|
1
|
from nengo.builder import Builder
|
9
|
1
|
from nengo.builder.neurons import SimNeurons
|
10
|
1
|
from nengo.builder.signal import SignalDict
|
11
|
1
|
from nengo.config import SupportDefaultsMixin
|
12
|
1
|
from nengo.params import Default
|
13
|
1
|
from nengo.solvers import Solver, LstsqL2, SolverParam
|
14
|
1
|
from nengo.synapses import SynapseParam, Lowpass
|
15
|
|
|
16
|
|
|
17
|
1
|
class Temporal(Solver, SupportDefaultsMixin):
    """Solves for connection weights by accounting for the neural dynamics.

    This allows the optimization procedure to potentially harness any
    correlations in spike-timing between neurons, and/or the adaptive
    dynamics of more detailed neuron models, given the dynamics
    of the desired function with respect to the evaluation points.
    This works by explicitly simulating the neurons given the stimulus, and
    then learning to decode the desired function in the time-domain.

    To use this method, pass it to the ``solver`` parameter for a
    :class:`nengo.Connection`. The ``pre`` object on this connection should be
    a :class:`nengo.Ensemble` that uses some dynamic neuron model.

    Parameters
    ----------
    synapse : :class:`nengo.synapses.Synapse`, optional
        The :class:`nengo.synapses.Synapse` model used to filter the
        pre-synaptic activities of the neurons before being passed to the
        underlying solver. A value of ``None`` will bypass any filtering.
        Defaults to a :class:`nengo.Lowpass` filter with a time-constant of
        5 ms.
    solver : :class:`nengo.solvers.Solver`, optional
        The underlying :class:`nengo.solvers.Solver` used to solve the problem
        ``AD = Y``, where ``A`` are the (potentially filtered) neural
        activities (in response to the evaluation points, over time), ``D``
        are the Nengo decoders, and ``Y`` are the corresponding targets given
        by the ``function`` supplied to the connection.
        Defaults to :class:`nengo.solvers.LstsqL2`.

    See Also
    --------
    :class:`.RLS`
    :class:`nengo.Connection`
    :class:`nengo.solvers.Solver`
    :mod:`.synapses`

    Notes
    -----
    Requires ``nengo>=2.5.0``
    (specifically, `PR #1313 <https://github.com/nengo/nengo/pull/1313>`_).

    If the neuron model for the pre-synaptic population includes some
    internal state that varies over time (which it should, otherwise there is
    little point in using this solver), then the order of the given evaluation
    points will matter. You will likely want to supply them as an array, rather
    than as a distribution. Likewise, you may want to filter your desired
    output, and specify the function as an array on the connection (see example
    below).

    The effect of the solver's regularization has a very different
    interpretation in this context (due to the filtered spiking error having
    its own statistics), and so you may also wish to instantiate the solver
    yourself with some value other than the default regularization.

    Examples
    --------
    Below we use the temporal solver to learn a filtered communication-channel
    (the identity function) using 100 low-threshold spiking (LTS) Izhikevich
    neurons. The training and test data are sampled independently from the
    same band-limited white-noise process.

    >>> from nengolib import Temporal, Network
    >>> import nengo
    >>> neuron_type = nengo.Izhikevich(coupling=0.25)
    >>> tau = 0.005
    >>> process = nengo.processes.WhiteSignal(period=5, high=5, y0=0, rms=0.3)
    >>> eval_points = process.run_steps(5000)
    >>> with Network() as model:
    ...     stim = nengo.Node(output=process)
    ...     x = nengo.Ensemble(100, 1, neuron_type=neuron_type)
    ...     out = nengo.Node(size_in=1)
    ...     nengo.Connection(stim, x, synapse=None)
    ...     nengo.Connection(x, out, synapse=None,
    ...                      eval_points=eval_points,
    ...                      function=nengo.Lowpass(tau).filt(eval_points),
    ...                      solver=Temporal(synapse=tau))
    ...     p_actual = nengo.Probe(out, synapse=tau)
    ...     p_ideal = nengo.Probe(stim, synapse=tau)
    >>> with nengo.Simulator(model) as sim:
    ...     sim.run(5)

    >>> import matplotlib.pyplot as plt
    >>> plt.plot(sim.trange(), sim.data[p_actual], label="Actual")
    >>> plt.plot(sim.trange(), sim.data[p_ideal], label="Ideal")
    >>> plt.xlabel("Time (s)")
    >>> plt.legend()
    >>> plt.show()
    """

    # Both parameters are readonly: they are set exactly once, in __init__.
    synapse = SynapseParam('synapse', default=Lowpass(tau=0.005),
                           readonly=True)
    solver = SolverParam('solver', default=LstsqL2(), readonly=True)

    def __init__(self, synapse=Default, solver=Default):
        """Store the filtering synapse and the wrapped sub-solver.

        ``Default`` arguments are resolved through the class-level params
        (5 ms ``Lowpass`` and ``LstsqL2()`` respectively).
        """
        # We can't use super here: the defaults mixin must be initialized
        # first so that self.solver (and hence self.solver.weights) is
        # resolved before Solver.__init__ is called.
        SupportDefaultsMixin.__init__(self)
        self.synapse = synapse
        self.solver = solver
        # Inherit the ``weights`` flag from the wrapped sub-solver.
        Solver.__init__(self, weights=self.solver.weights)

    def __call__(self, A, Y, __hack__=None, **kwargs):
        """Solve ``AD = Y`` by delegating to the wrapped sub-solver.

        ``A`` holds the activities computed during the build. The builder
        registered for this class swaps in time-domain (optionally filtered)
        activities in place of the usual rates. All keyword arguments are
        forwarded unchanged.
        """
        assert __hack__ is None
        # __hack__ is necessary prior to nengo PR #1359 (<2.6.1)
        # and following nengo PR #1507 (>2.8.0)

        # Note: mul_encoders is never called directly on self.
        # It is invoked on the sub-solver through the following call.
        return self.solver(A, Y, **kwargs)
|
127
|
|
|
128
|
|
|
129
|
1
|
@Builder.register(Temporal)
def build_temporal_solver(model, solver, conn, rng, transform=None):
    """Build decoders for ``conn`` from time-domain neural activities.

    The connection's pre-synaptic neurons are stepped with the ensemble's own
    ``SimNeurons`` operator, taking one time-step per evaluation point (in
    order). The resulting activities are optionally filtered by
    ``solver.synapse``. They are then given to the wrapped sub-solver in
    place of the usual rate-based activities. This is done by temporarily
    replacing ``neuron_type.rates`` while the underlying solver builder runs.

    Parameters
    ----------
    model : nengo.builder.Model
        The model being built. It must already contain exactly one
        ``SimNeurons`` operator driving the pre-synaptic neurons.
    solver : Temporal
        The solver attached to ``conn``.
    conn : nengo.Connection
        The connection being built; ``conn.pre_obj`` must be an Ensemble.
    rng : numpy.random.RandomState
        Passed to the neuron stepper and to the underlying solver builder.
    transform : array_like, optional
        Forwarded to the underlying solver builder only when given. nengo
        versions after 2.8.0 drop this parameter.

    Returns
    -------
    Whatever the builder registered for ``solver.solver`` returns.
    """
    # Unpack the relevant variables from the connection.
    assert isinstance(conn.pre_obj, Ensemble), (
        "Temporal solver requires an Ensemble as the pre object, got %r"
        % (conn.pre_obj,))
    ensemble = conn.pre_obj
    neurons = ensemble.neurons
    neuron_type = ensemble.neuron_type

    # Find the operator that simulates the neurons.
    # We do it this way (instead of using the step_math method)
    # because we don't know the number of state parameters or their shapes.
    ops = [op for op in model.operators
           if isinstance(op, SimNeurons) and
           op.J is model.sig[neurons]['in']]
    if len(ops) != 1:  # pragma: no cover
        raise RuntimeError("Expected exactly one operator for simulating "
                           "neurons (%s), found: %s" % (neurons, ops))
    op = ops[0]

    # Create a private stepper for the neuron model. Its signals (including
    # any neuron state) live in this local SignalDict, not the simulator's.
    signals = SignalDict()
    op.init_signals(signals)
    step_simneurons = op.make_step(signals, model.dt, rng)

    # Create custom rates method that uses the built neurons.
    def override_rates_method(x, gain, bias):
        n_eval_points, n_neurons = x.shape
        assert ensemble.n_neurons == n_neurons

        # Compute the input currents for all evaluation points at once.
        # Passing the full 2-D array (rather than one 1-D row at a time)
        # hoists this out of the loop and keeps the shape unambiguous.
        # NOTE(review): newer nengo versions may treat a 1-D argument to
        # ``current`` as a column of scalar samples -- confirm per version.
        currents = neuron_type.current(x, gain, bias)

        a = np.empty((n_eval_points, n_neurons))
        for i, J_t in enumerate(currents):
            signals[op.J][...] = J_t
            step_simneurons()
            a[i, :] = signals[op.output]

        if solver.synapse is None:
            return a
        return solver.synapse.filt(a, axis=0, y0=0, dt=model.dt)

    # Hot-swap the rates method while calling the underlying solver.
    # The solver will then call this temporarily created rates method
    # to process each evaluation point.
    # Record whether ``rates`` was already an instance attribute so it can be
    # restored exactly. Typically it is only defined on the class. In that
    # case we delete our override instead of leaving a stale bound method
    # shadowing the class method.
    had_own_rates = 'rates' in getattr(neuron_type, '__dict__', {})
    save_rates_method = neuron_type.rates
    neuron_type.rates = override_rates_method
    try:
        # Note: passing solver.solver doesn't actually cause solver.solver
        # to be built. It will still use conn.solver. This is because
        # the function decorated with @Builder.register(Solver) actually
        # ignores the solver and considers only the conn. The only point of
        # passing solver.solver here is to invoke its corresponding builder
        # function in case something custom happens to be registered.
        # Note: in nengo>2.8.0 the transform parameter is dropped, so only
        # forward it when the caller supplied one.
        if transform is None:
            return model.build(solver.solver, conn, rng)
        return model.build(solver.solver, conn, rng, transform)

    finally:
        if had_own_rates:
            neuron_type.rates = save_rates_method
        else:
            del neuron_type.rates
|