# Organization-wise, this file is actually "a part of" solvers.py.
# However, due to the fact that this is only compatible with nengo>=2.5.0
# it simplifies life to move this into its own file.
import numpy as np

from nengo import Ensemble
from nengo.builder import Builder
from nengo.builder.neurons import SimNeurons
from nengo.builder.signal import SignalDict
from nengo.config import SupportDefaultsMixin
from nengo.params import Default
from nengo.solvers import Solver, LstsqL2, SolverParam
from nengo.synapses import SynapseParam, Lowpass

class Temporal(Solver, SupportDefaultsMixin):
    """Solves for connection weights by accounting for the neural dynamics.

    This allows the optimization procedure to potentially harness any
    correlations in spike-timing between neurons, and/or the adaptative
    dynamics of more detailed neuron models, given the dynamics
    of the desired function with respect to the evaluation points.
    This works by explicitly simulating the neurons given the stimulus, and
    then learning to decode the desired function in the time-domain.

    To use this method, pass it to the ``solver`` parameter for a
    :class:`nengo.Connection`. The ``pre`` object on this connection should be
    a :class:`nengo.Ensemble` that uses some dynamic neuron model.

    Parameters
    ----------
    synapse : :class:`nengo.synapses.Synapse`, optional
        The :class:`nengo.synapses.Synapse` model used to filter the
        pre-synaptic activities of the neurons before being passed to the
        underlying solver. A value of ``None`` will bypass any filtering.
        Defaults to a :class:`nengo.Lowpass` filter with a time-constant of
        5 ms.
    solver : :class:`nengo.solvers.Solver`, optional
        The underlying :class:`nengo.solvers.Solver` used to solve the problem
        ``AD = Y``, where ``A`` are the (potentially filtered) neural
        activities (in response to the evaluation points, over time), ``D``
        are the Nengo decoders, and ``Y`` are the corresponding targets given
        by the ``function`` supplied to the connection.
        Defaults to :class:`nengo.solvers.LstsqL2`.

    See Also
    --------
    :class:`.RLS`
    :class:`nengo.Connection`
    :class:`nengo.solvers.Solver`
    :mod:`.synapses`

    Notes
    -----
    Requires ``nengo>=2.5.0``
    (specifically, `PR #1313 <https://github.com/nengo/nengo/pull/1313>`_).

    If the neuron model for the pre-synaptic population includes some
    internal state that varies over time (which it should, otherwise there is
    little point in using this solver), then the order of the given evaluation
    points will matter. You will likely want to supply them as an array, rather
    than as a distribution. Likewise, you may want to filter your desired
    output, and specify the function as an array on the connection (see example
    below).

    The effect of the solver's regularization has a very different
    interpretation in this context (due to the filtered spiking error having
    its own statistics), and so you may also wish to instantiate the solver
    yourself with some value other than the default regularization.

    Examples
    --------
    Below we use the temporal solver to learn a filtered communication-channel
    (the identity function) using 100 low-threshold spiking (LTS) Izhikevich
    neurons. The training and test data are sampled independently from the
    same band-limited white-noise process.

    >>> from nengolib import Temporal, Network
    >>> import nengo
    >>> neuron_type = nengo.Izhikevich(coupling=0.25)
    >>> tau = 0.005
    >>> process = nengo.processes.WhiteSignal(period=5, high=5, y0=0, rms=0.3)
    >>> eval_points = process.run_steps(5000)
    >>> with Network() as model:
    >>>     stim = nengo.Node(output=process)
    >>>     x = nengo.Ensemble(100, 1, neuron_type=neuron_type)
    >>>     out = nengo.Node(size_in=1)
    >>>     nengo.Connection(stim, x, synapse=None)
    >>>     nengo.Connection(x, out, synapse=None,
    >>>                      eval_points=eval_points,
    >>>                      function=nengo.Lowpass(tau).filt(eval_points),
    >>>                      solver=Temporal(synapse=tau))
    >>>     p_actual = nengo.Probe(out, synapse=tau)
    >>>     p_ideal = nengo.Probe(stim, synapse=tau)
    >>> with nengo.Simulator(model) as sim:
    >>>     sim.run(5)

    >>> import matplotlib.pyplot as plt
    >>> plt.plot(sim.trange(), sim.data[p_actual], label="Actual")
    >>> plt.plot(sim.trange(), sim.data[p_ideal], label="Ideal")
    >>> plt.xlabel("Time (s)")
    >>> plt.legend()
    >>> plt.show()
    """

    # Pre-synaptic activity filter; None bypasses filtering.
    synapse = SynapseParam('synapse', default=Lowpass(tau=0.005),
                           readonly=True)
    # Underlying least-squares solver applied to the simulated activities.
    solver = SolverParam('solver', default=LstsqL2(), readonly=True)

    def __init__(self, synapse=Default, solver=Default):
        # We can't use super here because we need the defaults mixin
        # in order to determine self.solver.weights.
        SupportDefaultsMixin.__init__(self)
        self.synapse = synapse
        self.solver = solver
        Solver.__init__(self, weights=self.solver.weights)

    def __call__(self, A, Y, __hack__=None, **kwargs):
        """Delegate solving ``AD = Y`` to the wrapped sub-solver."""
        assert __hack__ is None
        # __hack__ is necessary prior to nengo PR #1359 (<2.6.1)
        # and following nengo PR #1507 (>2.8.0)

        # Note: mul_encoders is never called directly on self.
        # It is invoked on the sub-solver through the following call.
        return self.solver.__call__(A, Y, **kwargs)
@Builder.register(Temporal)
def build_temporal_solver(model, solver, conn, rng, transform=None):
    """Build a :class:`Temporal` solver by simulating the pre-synaptic neurons.

    Temporarily overrides the neuron type's ``rates`` method so that the
    underlying sub-solver sees time-domain (optionally filtered) activities
    instead of static rates, then delegates to the normal solver builder.

    Parameters
    ----------
    model : built nengo model being constructed
    solver : the :class:`Temporal` instance on the connection
    conn : the :class:`nengo.Connection` being built (``pre`` must be an
        :class:`nengo.Ensemble`)
    rng : random number generator used by the builder
    transform : optional transform (dropped in nengo>2.8.0)
    """
    # Unpack the relevant variables from the connection.
    assert isinstance(conn.pre_obj, Ensemble)
    ensemble = conn.pre_obj
    neurons = ensemble.neurons
    neuron_type = ensemble.neuron_type

    # Find the operator that simulates the neurons.
    # We do it this way (instead of using the step_math method)
    # because we don't know the number of state parameters or their shapes.
    ops = [op for op in model.operators
           if isinstance(op, SimNeurons) and
           op.J is model.sig[neurons]['in']]
    if len(ops) != 1:  # pragma: no cover
        raise RuntimeError("Expected exactly one operator for simulating "
                           "neurons (%s), found: %s" % (neurons, ops))
    op = ops[0]

    # Create stepper for the neuron model.
    signals = SignalDict()
    op.init_signals(signals)
    step_simneurons = op.make_step(signals, model.dt, rng)

    # Create custom rates method that uses the built neurons.
    def override_rates_method(x, gain, bias):
        n_eval_points, n_neurons = x.shape
        assert ensemble.n_neurons == n_neurons

        # Simulate one timestep per evaluation point, in order, collecting
        # the resulting neural outputs row by row.
        a = np.empty((n_eval_points, n_neurons))
        for i, x_t in enumerate(x):
            signals[op.J][...] = neuron_type.current(x_t, gain, bias)
            step_simneurons()
            a[i, :] = signals[op.output]

        if solver.synapse is None:
            return a
        return solver.synapse.filt(a, axis=0, y0=0, dt=model.dt)

    # Hot-swap the rates method while calling the underlying solver.
    # The solver will then call this temporarily created rates method
    # to process each evaluation point.
    save_rates_method = neuron_type.rates
    neuron_type.rates = override_rates_method
    try:
        # Note: passing solver.solver doesn't actually cause solver.solver
        # to be built. It will still use conn.solver. This is because
        # the function decorated with @Builder.register(Solver) actually
        # ignores the solver and considers only the conn. The only point of
        # passing solver.solver here is to invoke its corresponding builder
        # function in case something custom happens to be registered.
        # Note: in nengo>2.8.0 the transform parameter is dropped
        return model.build(solver.solver, conn, rng, transform)

    finally:
        # Always restore the original rates method, even if the build fails.
        neuron_type.rates = save_rates_method
