1

1

import warnings

2



3

1

import numpy as np

4



5

1

from nengo.solvers import Solver, LstsqL2

6

1

from nengo.version import version_info

7



8


# The Temporal solver relies on nengo PR #1313, first available in 2.5.0;
# on older nengo releases the name is still exported but bound to None.
if version_info >= (2, 5, 0):
    from nengolib.temporal import Temporal
else:  # pragma: no cover
    Temporal = None  # not supported (requires nengo PR #1313)

__all__ = ['Temporal']

14



15



16

1

class BiasedSolver(Solver):
    """Wraps a solver with a bias neuron, and extracts its weights.

    The wrapped solver is handed the activity matrix augmented with one
    extra constant column (an always-on "bias neuron"). The decoder row
    solved for that extra column is stored on ``self.bias`` (and in
    ``solver_info['bias']``) and is stripped from the decoders returned to
    the caller.

    This is set up correctly by nengolib.Connection; not to be used directly.
    """

    # NOTE(review): the default ``LstsqL2()`` is created once at definition
    # time and shared by every BiasedSolver built without an explicit solver;
    # this is presumably safe because nengo solvers act as immutable
    # configuration objects -- confirm before adding state to them.
    def __init__(self, solver=LstsqL2()):
        """Wrap ``solver``; ``self.bias`` stays ``None`` until the first call.

        Parameters
        ----------
        solver : nengo.solvers.Solver
            The solver that actually computes the decoders. Its ``weights``
            flag is mirrored onto this wrapper.
        """
        self.solver = solver
        self.bias = None
        try:
            # parent class changed in Nengo 2.1.1
            # need to do this because self.weights is readonly
            super(BiasedSolver, self).__init__(weights=solver.weights)
        except TypeError:  # pragma: no cover
            # older Solver.__init__ takes no ``weights`` argument
            super(BiasedSolver, self).__init__()
            self.weights = solver.weights

    def __call__(self, A, Y, __hack__=None, **kwargs):
        """Solve for decoders of ``A`` plus a bias column, then split them.

        Parameters
        ----------
        A : ndarray, 2-D
            Activity matrix; one extra constant column is appended to it, so
            each column is treated as one neuron (rows presumably evaluation
            points -- confirm against the caller).
        Y : ndarray
            Targets, passed through unchanged to the wrapped solver.
        __hack__ : None
            Must be ``None``; exists only for call-signature compatibility.
        **kwargs
            Forwarded unchanged to the wrapped solver.

        Returns
        -------
        X : ndarray
            Decoders for the original columns of ``A`` (bias row removed).
        solver_info : dict
            The wrapped solver's info dict, with an added ``'bias'`` entry
            equal to the (rescaled) bias decoder row.
        """
        assert __hack__ is None
        # __hack__ is necessary prior to nengo PR #1359 (<2.6.1)
        # and following nengo PR #1507 (>2.8.0)

        if self.bias is not None:
            # this is okay if due to multiple builds of the same network (#99)
            warnings.warn("%s called twice; ensure not being shared between "
                          "multiple connections" % type(self).__name__,
                          UserWarning)

        # The bias column is filled with A.max() rather than 1 so that its
        # magnitude matches the other columns, keeping any regularization
        # applied by the wrapped solver consistent across all columns.
        scale = A.max()
        AB = np.empty((A.shape[0], A.shape[1] + 1))
        AB[:, :-1] = A      # original activities in all but the last column
        AB[:, -1] = scale   # constant "bias neuron" in the appended column

        XB, solver_info = self.solver(AB, Y, **kwargs)

        # The last decoder row belongs to the bias column; multiplying by
        # ``scale`` undoes the column scaling so the bias is in target units.
        solver_info['bias'] = self.bias = XB[-1] * scale
        return XB[:-1], solver_info

    def bias_function(self, size):
        """Returns the function for the presynaptic bias node.

        ``self.bias`` is read each time the returned function is called, not
        when this method runs, so the function yields zeros of length
        ``size`` until ``__call__`` has computed the bias, and the computed
        bias afterwards. The function's single argument is ignored.
        """
        return lambda _: np.zeros(size) if self.bias is None else self.bias
