Putting the Neural back into Networks
Part 2: More spikes, more problems
12th January, 2020
In the last post, we learned how a spiking neuron differs from a common-or-garden Artificial Neuron. TL;DR: Spiking neurons understand time and have internal dynamics, Artificial Neurons don’t.
Spiking neural networks (SNNs) have an edge in energy efficiency, especially when run on special-purpose hardware. But ANNs and DNNs have a huge potential benefit: since they are simple, differentiable blobs of linear algebra, it’s easy to use optimisation techniques such as gradient descent to find a good network. Gradient-based optimisation methods have been highly successful at tuning deep network weights to solve an enormous range of tasks.
With SNNs we are a little bit out in the cold, because SNNs contain several non-differentiable mechanisms, which makes gradient-based approaches difficult to apply.
For example, shown above is the internal state (\(V_m\)) of a spiking neuron as it evolves over time, along with the spikes generated whenever \(V_m\) rises above the threshold \(V_{th} = 0\).
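The post doesn't commit to a particular neuron model, so purely as an illustration, here is a minimal leaky integrate-and-fire (LIF) sketch in plain Python that produces the same kind of trace: \(V_m\) leaks towards a resting value, integrates its input, and emits a spike (then resets) whenever it rises above \(V_{th} = 0\). The LIF model, the time constant, the resting and reset potentials, and the constant input are all hypothetical choices, not values taken from the figure.

```python
import math

def simulate_lif(inputs, tau=20.0, dt=1.0, v_th=0.0, v_rest=-1.0, v_reset=-1.0):
    """Simulate one leaky integrate-and-fire neuron (hypothetical parameters).

    inputs: input drive per time step.
    Returns (V_m trace, spike train); the trace stores V_m after any reset.
    """
    alpha = math.exp(-dt / tau)              # per-step leak factor
    v_m = v_rest
    trace, spikes = [], []
    for i_t in inputs:
        # Leak towards rest, then add this step's input.
        v_m = v_rest + alpha * (v_m - v_rest) + i_t
        spike = 1 if v_m > v_th else 0       # threshold crossing -> spike
        if spike:
            v_m = v_reset                    # reset after each spike
        trace.append(v_m)
        spikes.append(spike)
    return trace, spikes

trace, spikes = simulate_lif([0.2] * 30)
print([t for t, s in enumerate(spikes) if s])  # -> [5, 11, 17, 23, 29]
```

Note that the spike train is all zeros except at isolated time steps, which is the discrete-time version of the sequence of delta functions described next.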
In contrast to a ReLU or \(tanh\) neuron, the output of each spiking neuron is highly discontinuous — essentially a sequence of delta functions.
As a function of \(V_m\), the spike generation function is a Heaviside step function: \(S_o(t)=H(V_m(t))\). Unfortunately, the derivative of the Heaviside function is zero everywhere except at 0, where it is undefined (see figure below). When propagating the error backwards through a spiking neuron, the spike generation function therefore zeros the gradient. No gradient means no weight updates, leading to frustrated and unhappy ML engineers.
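Written out with the chain rule, the problem is easy to see. Here \(w\) is an illustrative input weight and \(L\) an illustrative loss computed on the spike output; neither symbol is defined elsewhere in the post.

\[
\frac{\partial L}{\partial w} = \frac{\partial L}{\partial S_o}\,\frac{\partial S_o}{\partial V_m}\,\frac{\partial V_m}{\partial w},
\qquad
\frac{\partial S_o}{\partial V_m} = \frac{dH(V_m)}{dV_m} =
\begin{cases}
0 & V_m \neq 0 \\
\text{undefined} & V_m = 0
\end{cases}
\]

Whatever the other two factors are, the middle factor is zero for almost every value of \(V_m\), so the whole weight gradient vanishes.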
Recent work on SNNs proposes using surrogate gradients — essentially, simulating a nonlinear spiking neuron in a forward pass, then using a neuron with similar but differentiable dynamics in a backward pass when gradients are needed [e.g. 1, 2, 3, 4, 5].
For example, we could generate a surrogate spike signal as a simple function of the neuron state, as shown below, using \(S^*=\max(0, V_m+1/2)\). This makes the surrogate output look like a ReLU neuron with a slope of 1, and the derivative calculation becomes trivial: \(dS^*/dV_m\) is 1 for \(V_m > -1/2\) and 0 otherwise.
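A minimal Python sketch of the real and surrogate spike functions side by side, assuming the threshold of 0 and the ReLU-shaped surrogate with slope 1 (taken here as \(\max(0, V_m + 1/2)\)). The function names are hypothetical. The surrogate has a kink at \(V_m = -1/2\); the sketch arbitrarily uses slope 0 exactly there.

```python
def heaviside(v_m):
    """Real spike generation: 1 if V_m is above the threshold of 0, else 0."""
    return 1.0 if v_m > 0.0 else 0.0

def heaviside_grad(v_m):
    """Derivative of H: 0 wherever it is defined (returned as 0 even at V_m = 0,
    where it is strictly undefined)."""
    return 0.0

def surrogate(v_m):
    """Surrogate spike signal S*(V_m) = max(0, V_m + 1/2): a shifted ReLU."""
    return max(0.0, v_m + 0.5)

def surrogate_grad(v_m):
    """dS*/dV_m: slope 1 where the ReLU is active, 0 elsewhere (0 at the kink)."""
    return 1.0 if v_m > -0.5 else 0.0

for v_m in (-1.0, -0.25, 0.25, 1.0):
    print(v_m, heaviside(v_m), heaviside_grad(v_m), surrogate(v_m), surrogate_grad(v_m))
```

For every sample value the Heaviside derivative is 0, while the surrogate slope is 1 across the whole region \(V_m > -1/2\), including states just below threshold.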
If we define a spiking neuron in this way, we can propagate an error backwards through the surrogate instead of through the spike generation function. As long as we choose the surrogate well, so that it approximates the behaviour of the spiking neuron to some extent, reducing errors via the surrogate should also reduce errors in the spiking neuron. We therefore end up with an approximate gradient descent optimisation process.
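To see why this matters, here is a toy sketch in plain Python (an illustration, not code from any of the cited papers): a single-step "neuron" with \(V_m = w x\), trained by gradient descent on a squared error. The forward pass always uses the real Heaviside spikes; only the backward pass swaps in the slope of the ReLU-shaped surrogate \(\max(0, V_m + 1/2)\). The data, learning rate, loss and function names are hypothetical.

```python
def heaviside(v_m):
    """Real spike generation with threshold 0."""
    return 1.0 if v_m > 0.0 else 0.0

def surrogate_grad(v_m):
    """Slope of the surrogate S*(V_m) = max(0, V_m + 1/2)."""
    return 1.0 if v_m > -0.5 else 0.0

def loss(w, data):
    """Mean squared error between real spikes and targets."""
    return sum((heaviside(w * x) - target) ** 2 for x, target in data) / len(data)

def train(w, data, lr=0.1, steps=10, use_surrogate=True):
    """Gradient descent on w for a one-step neuron with V_m = w * x."""
    for _ in range(steps):
        grad = 0.0
        for x, target in data:
            v_m = w * x
            s = heaviside(v_m)                        # forward: real spike
            dloss_ds = 2.0 * (s - target)
            # backward: surrogate slope, or the true Heaviside derivative (0)
            ds_dv = surrogate_grad(v_m) if use_surrogate else 0.0
            grad += dloss_ds * ds_dv * x              # chain rule, dV_m/dw = x
        w -= lr * grad / len(data)
    return w

# Hypothetical toy task: spike for x = +1, stay silent for x = -1.
data = [(1.0, 1.0), (-1.0, 0.0)]
w_sur = train(-0.3, data, use_surrogate=True)
w_true = train(-0.3, data, use_surrogate=False)
print(loss(-0.3, data), loss(w_sur, data), loss(w_true, data))  # -> 1.0 0.0 1.0
```

With the surrogate slope, \(w\) moves from -0.3 to about 0.1 in two steps and the real spiking output then matches both targets. With the true Heaviside derivative, every gradient is zero and \(w\) never changes.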
With this approach we can compute error gradients for our weights. In fact, using the powerful automatic differentiation packages available (e.g. PyTorch and JAX [6]), we can also compute error gradients for other neuron parameters such as time constants.
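Automatic differentiation does this bookkeeping for us, but the idea can be sketched by hand. The plain-Python example below assumes a hypothetical, simplified leaky neuron with no reset, \(V_m[t] = \alpha V_m[t-1] + w\,x[t]\) with \(\alpha = e^{-dt/\tau}\), and carries the derivatives of \(V_m\) with respect to both the weight \(w\) and the time constant \(\tau\) forward through time, using the surrogate slope in place of the spike derivative. The inputs, parameter values and function names are illustrative assumptions; the last lines check the hand-derived gradients against finite differences of the surrogate spike count.

```python
import math

def surrogate(v_m):
    """Surrogate spike signal S*(V_m) = max(0, V_m + 1/2)."""
    return max(0.0, v_m + 0.5)

def surrogate_grad(v_m):
    """dS*/dV_m: 1 where the surrogate is active, else 0."""
    return 1.0 if v_m > -0.5 else 0.0

def run_neuron(w, tau, inputs, dt=1.0):
    """Hypothetical leaky neuron without reset: V_m[t] = alpha*V_m[t-1] + w*x[t].

    Returns (real spike count, surrogate spike count, d/dw, d/dtau), where the
    derivatives are of the surrogate count, accumulated in forward mode.
    """
    alpha = math.exp(-dt / tau)
    dalpha_dtau = alpha * dt / tau ** 2      # d/dtau of exp(-dt / tau)
    v_m = dv_dw = dv_dtau = 0.0
    n_spikes, s_count, grad_w, grad_tau = 0, 0.0, 0.0, 0.0
    for x in inputs:
        # Update the sensitivities first: both need the previous V_m.
        dv_dtau = dalpha_dtau * v_m + alpha * dv_dtau
        dv_dw = alpha * dv_dw + x
        v_m = alpha * v_m + w * x
        n_spikes += 1 if v_m > 0.0 else 0   # forward pass: real spikes
        s_count += surrogate(v_m)            # differentiable stand-in
        slope = surrogate_grad(v_m)          # surrogate dS/dV_m
        grad_w += slope * dv_dw
        grad_tau += slope * dv_dtau
    return n_spikes, s_count, grad_w, grad_tau

# Hypothetical input train and parameters.
inputs = [1.0, 0.0, 1.0, 0.5, 0.0, 1.0]
w, tau = 0.3, 5.0
n, s, g_w, g_tau = run_neuron(w, tau, inputs)

# Check against central finite differences of the surrogate count.
eps = 1e-6
fd_w = (run_neuron(w + eps, tau, inputs)[1] - run_neuron(w - eps, tau, inputs)[1]) / (2 * eps)
fd_tau = (run_neuron(w, tau + eps, inputs)[1] - run_neuron(w, tau - eps, inputs)[1]) / (2 * eps)
print(n, g_w, fd_w, g_tau, fd_tau)
```

A package such as PyTorch or JAX would instead compute these derivatives automatically, with the spike function's backward pass overridden to use the surrogate slope.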
So how do we build a spiking network and apply a surrogate gradient descent approach to train it? In the next post, we’ll see how to use an open-source package “Rockpool” to build and train a spiking neural network.
Part 3: I got 99 problems but a spike ain't one
References
[1]: Lee et al. 2016. Training Deep Spiking Neural Networks Using Backpropagation. Front. Neurosci. DOI: 10.3389/fnins.2016.00508.
[2]: Zenke & Ganguli 2018. SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Comput. DOI: 10.1162/neco_a_01086.
[3]: Kaiser et al. 2018. Synaptic Plasticity Dynamics for Deep Continuous Local Learning (DECOLLE). arXiv 1811.10766.
[4]: Neftçi et al. 2019. Surrogate Gradient Learning in Spiking Neural Networks. arXiv 1901.09948.
[5]: Zimmer et al. 2019. Technical report: supervised training of convolutional spiking neural networks with PyTorch. arXiv 1911.10124.
[6]: JAX. https://github.com/google/jax
[7]: Kingma & Ba 2014. Adam: A Method for Stochastic Optimization. arXiv 1412.6980.