Putting the Neural back into Networks

Part 3: I got 99 problems but a spike ain't one

13th January, 2020

Ocean scene

Photo by @tgerz on Unsplash

In the last post, we saw that spike generation can cause major problems for gradient descent, by zeroing the gradients almost everywhere. We also learned how to get around this issue by standing in a smooth surrogate for the spiking neuron output.

Part 2: More spikes, more problems

In this installment we’ll see how to build and train a feed-forward spiking neural network to solve a temporal signal transformation task.

Spiking Neural Networks in Rockpool

Rockpool is an open-source Python package for building and training spiking (and non-spiking) neural networks. It lets you construct networks at a high level, offering a large number of layer types with differing neuron dynamics, and it hides most of the implementation detail, with the goal of being a modular, high-level package for neuromorphic machine learning.

So let’s build and train a very simple network in Rockpool, to generate a complex output time series from a spiking input.

The network architecture will be as shown here: a number of spiking input channels will be provided via input weights \(W_i\) to a layer of \(N\) spiking neurons. The output of the spiking neurons will converge via output weights \(W_o\) to a single output channel.

We'll begin by defining the input and target signals required for our task. We need a randomly-chosen set of Poisson input spikes, which we can generate by thresholding uniform noise. We also need to generate a chirp signal (a sinusoid with increasing frequency over time), which we will use as our training target.

The frozen Poisson input spikes act as the input to the network (blue dots). The network should learn to transform this input into a chirp signal (orange).
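As a concrete sketch, the two signals can be generated with NumPy as below. The durations, firing rate and frequency range here are illustrative placeholders, not necessarily the values used for the figures.

```python
import numpy as np

# -- Simulation parameters (illustrative values)
dt = 1e-3          # time step, in seconds
duration = 10.0    # signal duration, in seconds
num_channels = 10  # number of spiking input channels
rate = 20.0        # mean firing rate per channel, in Hz

time_base = np.arange(0, duration, dt)

# -- Frozen Poisson input: threshold uniform noise so that each time bin
#    spikes independently with probability rate * dt
spike_prob = rate * dt
input_spikes = np.random.rand(time_base.size, num_channels) < spike_prob

# -- Chirp target: a sinusoid whose frequency increases linearly over time
f_start, f_end = 0.1, 2.0                       # start and end frequencies, in Hz
inst_freq = f_start + (f_end - f_start) * time_base / duration
phase = 2 * np.pi * np.cumsum(inst_freq) * dt   # integrate frequency to get phase
target_chirp = np.sin(phase)
```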

Now that we’ve got our input and target signals, we need to construct a layer of spiking neurons. We can initialise the weight parameters randomly — we’ll train these anyway. We’ll make sensible choices for initial biases \(b\), and initial time constants \(\tau_{syn}\) and \(\tau_m\).
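As a sketch of the layer construction, the snippet below uses the v1-era FFLIFJax_IO class that this series was written against. The class name, argument names and parameter values are assumptions based on that API generation; check the Rockpool documentation for your installed version before copying them.

```python
from rockpool.layers import FFLIFJax_IO   # JAX-backed feed-forward LIF layer (v1-era name)

num_neurons = 50   # N spiking neurons in the layer

# -- Random initial weights W_i and W_o; these will be trained anyway
w_in = np.random.randn(num_channels, num_neurons) / num_channels
w_out = np.random.randn(num_neurons, 1) / num_neurons

# -- Sensible initial biases and time constants (illustrative values)
lyrIO = FFLIFJax_IO(
    w_in=w_in,       # input weights W_i
    w_out=w_out,     # output weights W_o
    tau_mem=50e-3,   # membrane time constant tau_m, in seconds
    tau_syn=100e-3,  # synaptic time constant tau_syn, in seconds
    bias=0.0,        # bias b
    dt=dt,           # simulation time step from the sketch above
)
```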

We can then use the .evolve() method of lyrIO to simulate the layer with the random spike input, and look at the activity of a single neuron.

The neuron state (\(V_m\), blue) and output spikes (\(S_o\), orange) emitted by a single neuron
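For intuition about what .evolve() computes, here is a simplified pure-NumPy version of the leaky integrate-and-fire update such a layer performs at each time step. This is a forward-Euler sketch with a subtractive reset, not Rockpool's actual implementation.

```python
def simulate_lif(input_spikes, w_in, bias, tau_mem, tau_syn, dt, v_thresh=1.0):
    """Simplified LIF simulation: returns membrane traces and output spikes."""
    num_steps = input_spikes.shape[0]
    num_neurons = w_in.shape[1]
    v_mem = np.zeros(num_neurons)                          # membrane potentials V_m
    i_syn = np.zeros(num_neurons)                          # synaptic currents
    v_rec = np.zeros((num_steps, num_neurons))
    spikes = np.zeros((num_steps, num_neurons), dtype=bool)

    for t in range(num_steps):
        # Exponential synapses driven by the input spikes through W_i
        i_syn += dt / tau_syn * (-i_syn) + input_spikes[t] @ w_in
        # Leaky integration of synaptic current and bias
        v_mem += dt / tau_mem * (-v_mem + i_syn + bias)
        # Spike on threshold crossing, then subtract the threshold
        spikes[t] = v_mem > v_thresh
        v_mem = np.where(spikes[t], v_mem - v_thresh, v_mem)
        v_rec[t] = v_mem

    return v_rec, spikes
```

Simulating with the Poisson raster from above and plotting one column of the returned membrane trace gives the kind of picture shown here.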

Training the network

We can now use Rockpool to perform gradient descent optimization of the layer, attempting to match the output of the layer with our target chirp signal.

Behind the scenes, Rockpool is using JAX [6] to simulate the neuron dynamics and compute a loss (i.e. a measure of the error) between the current output and the target signal. By default, Rockpool gives us a regularised mean-squared-error loss, and uses the ADAM stochastic gradient descent optimizer [7]. These are both configurable; you can learn how in the Rockpool documentation!
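To make that concrete, here is roughly what a regularised mean-squared-error loss plus ADAM updates look like in raw JAX. This is a conceptual sketch, not Rockpool's internal code: the fixed random "activity" and linear readout below stand in for the full differentiable (surrogate-gradient) SNN simulation, and the regularisation strength and learning rate are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

key = jax.random.PRNGKey(0)
activity = jax.random.normal(key, (1000, 50))    # stand-in for filtered neuron activity
target = jnp.sin(jnp.linspace(0.0, 10.0, 1000))  # stand-in for the chirp target

def loss_fn(w_out, reg=1e-4):
    """Mean-squared error between readout and target, plus L2 weight regularisation."""
    output = activity @ w_out
    mse = jnp.mean((output - target) ** 2)
    return mse + reg * jnp.sum(w_out ** 2)

# ADAM optimiser from JAX's bundled example optimisers
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(jnp.zeros(50))

@jax.jit
def train_step(step, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
    return loss, opt_update(step, grads, opt_state)

for step in range(100):
    loss, opt_state = train_step(step, opt_state)
```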

Let’s see how our network performs now!

Output of the trained network (blue), and the target signal (orange).

Selected neuron states \(V_m\).

Selected neuron spiking activity.

Not too shabby! All this and more can be yours with Rockpool. 😉

Reconstructing a smooth function by summing deltas isn’t easy, but gradient descent has found a fairly good solution with only 50 spiking neurons. In fact the task can be solved with as few as two neurons, with a corresponding reduction in solution quality of course.

Taking things further

In practice, there’s only so much you can do with a single feed-forward layer. In the DNN world, you build bigger networks by stacking layers on top of each other. That works with spiking neurons too, but you need to be very smart about propagating errors or setting output targets for each layer [1].

One alternative is to use highly recurrent architectures, which are common in spiking networks [e.g. 2]. Recurrent spiking networks extend the temporal dynamics of the neurons and synapses by adding recurrent dynamics. Essentially a single recurrent layer of neurons acts as an infinite stack of feed-forward layers, but with fading memory as you propagate up the stack. If you configure the recurrent networks well, this approach lets you build systems which can analyse long stretches of time series data.
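Schematically, the only change relative to the feed-forward sketch earlier is an extra input term: each neuron also receives the previous time step's spikes through recurrent weights. This is a minimal illustration, not Rockpool's implementation.

```python
# Recurrent weights between the N neurons (random initialisation as a placeholder)
w_rec = np.random.randn(num_neurons, num_neurons) / num_neurons

# Feed-forward synaptic input, as in simulate_lif() above:
#   i_syn += dt / tau_syn * (-i_syn) + input_spikes[t] @ w_in
# Recurrent version adds the neurons' own previous spikes, fed back through W_rec:
#   i_syn += dt / tau_syn * (-i_syn) + input_spikes[t] @ w_in + spikes[t - 1] @ w_rec
```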

Rockpool is being actively developed — it’s what we use in-house at SynSense to build sub-mW signal processing applications. The library also provides direct interfaces to spike-based neuromorphic computation hardware from SynSense.

We've used Rockpool to build demos for low-power audio and bio-signal processing, using spiking recurrent NNs. These temporal tasks are a good fit for spiking NNs, since they can take advantage of the direct representation of time in SNNs. When coupled with low-power asynchronous neuromorphic inference hardware, using SNNs confers a huge advantage in energy efficiency over GPUs and CPUs.

References

[1]: Neftci, Mostafa & Zenke 2019. Surrogate Gradient Learning in Spiking Neural Networks. arXiv 1901.09948.
[2]: Maass, Natschläger & Markram 2002. Real-time computing without stable states: a new framework for neural computation based on perturbations. Neural Computation 14 (11): 2531–60. doi:10.1162/089976602760407955.
[6]: Bradbury et al. 2018. JAX: composable transformations of Python+NumPy programs. github.com/google/jax.
[7]: Kingma & Ba 2014. Adam: A Method for Stochastic Optimization. arXiv 1412.6980.