1


# Organization-wise, this file is actually "a part of" solvers.py.

2


# However, because this is only compatible with nengo>=2.5.0

3


# it simplifies life to move this into its own file.

4



5

1

import numpy as np

6



7

1

from nengo import Ensemble

8

1

from nengo.builder import Builder

9

1

from nengo.builder.neurons import SimNeurons

10

1

from nengo.builder.signal import SignalDict

11

1

from nengo.config import SupportDefaultsMixin

12

1

from nengo.params import Default

13

1

from nengo.solvers import Solver, LstsqL2, SolverParam

14

1

from nengo.synapses import SynapseParam, Lowpass

15



16



17

1

class Temporal(Solver, SupportDefaultsMixin):
    """Solves for connection weights by accounting for the neural dynamics.

    This allows the optimization procedure to potentially harness any
    correlations in spike-timing between neurons, and/or the adaptive
    dynamics of more detailed neuron models, given the dynamics
    of the desired function with respect to the evaluation points.
    This works by explicitly simulating the neurons given the stimulus, and
    then learning to decode the desired function in the time-domain.

    To use this method, pass it to the ``solver`` parameter for a
    :class:`nengo.Connection`. The ``pre`` object on this connection should be
    a :class:`nengo.Ensemble` that uses some dynamic neuron model.

    Parameters
    ----------
    synapse : :class:`nengo.synapses.Synapse`, optional
        The :class:`nengo.synapses.Synapse` model used to filter the
        presynaptic activities of the neurons before being passed to the
        underlying solver. A value of ``None`` will bypass any filtering.
        Defaults to a :class:`nengo.Lowpass` filter with a time-constant of
        5 ms.
    solver : :class:`nengo.solvers.Solver`, optional
        The underlying :class:`nengo.solvers.Solver` used to solve the problem
        ``AD = Y``, where ``A`` are the (potentially filtered) neural
        activities (in response to the evaluation points, over time), ``D``
        are the Nengo decoders, and ``Y`` are the corresponding targets given
        by the ``function`` supplied to the connection.
        Defaults to :class:`nengo.solvers.LstsqL2`.

    See Also
    --------
    :class:`.RLS`
    :class:`nengo.Connection`
    :class:`nengo.solvers.Solver`
    :mod:`.synapses`

    Notes
    -----
    Requires ``nengo>=2.5.0``
    (specifically, `PR #1313 <https://github.com/nengo/nengo/pull/1313>`_).

    If the neuron model for the presynaptic population includes some
    internal state that varies over time (which it should, otherwise there is
    little point in using this solver), then the order of the given evaluation
    points will matter. You will likely want to supply them as an array, rather
    than as a distribution. Likewise, you may want to filter your desired
    output, and specify the function as an array on the connection (see example
    below).

    The effect of the solver's regularization has a very different
    interpretation in this context (due to the filtered spiking error having
    its own statistics), and so you may also wish to instantiate the solver
    yourself with some value other than the default regularization.

    Examples
    --------
    Below we use the temporal solver to learn a filtered communication-channel
    (the identity function) using 100 low-threshold spiking (LTS) Izhikevich
    neurons. The training and test data are sampled independently from the
    same band-limited white-noise process.

    >>> from nengolib import Temporal, Network
    >>> import nengo
    >>> neuron_type = nengo.Izhikevich(coupling=0.25)
    >>> tau = 0.005
    >>> process = nengo.processes.WhiteSignal(period=5, high=5, y0=0, rms=0.3)
    >>> eval_points = process.run_steps(5000)
    >>> with Network() as model:
    ...     stim = nengo.Node(output=process)
    ...     x = nengo.Ensemble(100, 1, neuron_type=neuron_type)
    ...     out = nengo.Node(size_in=1)
    ...     nengo.Connection(stim, x, synapse=None)
    ...     nengo.Connection(x, out, synapse=None,
    ...                      eval_points=eval_points,
    ...                      function=nengo.Lowpass(tau).filt(eval_points),
    ...                      solver=Temporal(synapse=tau))
    ...     p_actual = nengo.Probe(out, synapse=tau)
    ...     p_ideal = nengo.Probe(stim, synapse=tau)
    >>> with nengo.Simulator(model) as sim:
    ...     sim.run(5)

    >>> import matplotlib.pyplot as plt
    >>> plt.plot(sim.trange(), sim.data[p_actual], label="Actual")
    >>> plt.plot(sim.trange(), sim.data[p_ideal], label="Ideal")
    >>> plt.xlabel("Time (s)")
    >>> plt.legend()
    >>> plt.show()
    """

    synapse = SynapseParam('synapse', default=Lowpass(tau=0.005),
                           readonly=True)
    solver = SolverParam('solver', default=LstsqL2(), readonly=True)

    def __init__(self, synapse=Default, solver=Default):
        """Store the filtering synapse and the wrapped solver.

        See the class docstring for the meaning of ``synapse`` and ``solver``.
        The ``weights`` flag of this solver mirrors that of the wrapped one.
        """
        # We can't use super here because we need the defaults mixin
        # in order to determine self.solver.weights.
        SupportDefaultsMixin.__init__(self)
        self.synapse = synapse
        self.solver = solver
        Solver.__init__(self, weights=self.solver.weights)

    def __call__(self, A, Y, __hack__=None, **kwargs):
        """Solve ``AD = Y`` by delegating to the wrapped ``solver``.

        When built through the registered ``Temporal`` builder, ``A`` holds
        the simulated (and possibly filtered) activities rather than static
        rate responses. Any extra keyword arguments are forwarded unchanged.
        """
        assert __hack__ is None
        # __hack__ is necessary prior to nengo PR #1359 (<2.6.1)
        # and following nengo PR #1507 (>2.8.0)

        # Note: mul_encoders is never called directly on self.
        # It is invoked on the subsolver through the following call.
        return self.solver(A, Y, **kwargs)

127



128



129

1

@Builder.register(Temporal)
def build_temporal_solver(model, solver, conn, rng, transform=None):
    """Build the decoders of ``conn`` using a :class:`Temporal` solver.

    The presynaptic neurons are simulated, one evaluation point per time-step
    of length ``model.dt`` and in the order given. Their outputs are
    optionally filtered by ``solver.synapse``. The result replaces the static
    rate responses normally handed to the underlying ``solver.solver``.

    Parameters
    ----------
    model : nengo.builder.Model
        The model being built. It must already contain exactly one
        ``SimNeurons`` operator driven by the ensemble's input signal.
    solver : Temporal
        The temporal solver attached to ``conn``.
    conn : nengo.Connection
        The connection being built; its ``pre_obj`` must be an
        :class:`nengo.Ensemble`.
    rng
        Random number generator, passed to the neuron stepper and to the
        underlying solver's builder.
    transform : optional
        Only supplied by ``nengo<=2.8.0``. It is forwarded to the underlying
        solver's builder when given and omitted otherwise, since newer nengo
        builders no longer accept it.

    Returns
    -------
    Whatever the builder registered for ``solver.solver`` returns.

    Raises
    ------
    RuntimeError
        If the number of ``SimNeurons`` operators simulating the
        ensemble's neurons is not exactly one.
    """
    # Unpack the relevant variables from the connection.
    assert isinstance(conn.pre_obj, Ensemble)
    ensemble = conn.pre_obj
    neurons = ensemble.neurons
    neuron_type = ensemble.neuron_type

    # Find the operator that simulates the neurons.
    # We do it this way (instead of using the step_math method)
    # because we don't know the number of state parameters or their shapes.
    ops = [candidate for candidate in model.operators
           if isinstance(candidate, SimNeurons) and
           candidate.J is model.sig[neurons]['in']]
    if len(ops) != 1:  # pragma: no cover
        raise RuntimeError("Expected exactly one operator for simulating "
                           "neurons (%s), found: %s" % (neurons, ops))
    op = ops[0]

    # Create stepper for the neuron model. These signals are private to this
    # build, so the neuron state persists across successive rate evaluations
    # below without touching the real simulator's signals.
    signals = SignalDict()
    op.init_signals(signals)
    step_simneurons = op.make_step(signals, model.dt, rng)

    # Create custom rates method that uses the built neurons.
    def override_rates_method(x, gain, bias):
        # x has one row per evaluation point and one column per neuron.
        n_eval_points, n_neurons = x.shape
        assert ensemble.n_neurons == n_neurons

        a = np.empty((n_eval_points, n_neurons))
        for i, x_t in enumerate(x):
            # Each evaluation point drives one time-step of the neuron model.
            signals[op.J][...] = neuron_type.current(x_t, gain, bias)
            step_simneurons()
            a[i, :] = signals[op.output]

        if solver.synapse is None:
            return a
        return solver.synapse.filt(a, axis=0, y0=0, dt=model.dt)

    # Hotswap the rates method while calling the underlying solver.
    # The solver will then call this temporarily created rates method
    # to process each evaluation point.
    # Remember whether ``rates`` was already an instance attribute, so the
    # object can be restored exactly (rather than leaving a bound method of
    # itself stored on the instance).
    had_instance_rates = 'rates' in getattr(neuron_type, '__dict__', {})
    saved_rates_method = neuron_type.rates
    neuron_type.rates = override_rates_method
    try:
        # Note: passing solver.solver doesn't actually cause solver.solver
        # to be built. It will still use conn.solver. This is because
        # the function decorated with @Builder.register(Solver) actually
        # ignores the solver and considers only the conn. The only point of
        # passing solver.solver here is to invoke its corresponding builder
        # function in case something custom happens to be registered.
        # Note: in nengo>2.8.0 the transform parameter is dropped, so it is
        # only forwarded when the caller supplied one.
        build_args = (solver.solver, conn, rng)
        if transform is not None:
            build_args += (transform,)
        return model.build(*build_args)

    finally:
        if had_instance_rates:
            neuron_type.rates = saved_rates_method
        else:
            del neuron_type.rates
