Wednesday, 13 October 2021

More fun with Spiking Neural Networks and JAX

I'm currently exploring Spiking Neural Networks.

SNNs (Spiking Neural Networks) try to model the brain more accurately than most of the Artificial Neural Networks used in Deep Learning.

There are some SNN implementations available in TensorFlow and PyTorch but I'm keen to explore them using pure Python. I find that Python code gives me confidence in my understanding.

But there's a problem.

SNNs need a lot of computing power. Even if I use numpy, large-scale simulations can run slowly.

[Image: Spike generation - code below]

So I'm using JAX.

JAX code runs in a traditional Python environment. JAX has array processing modules that are closely based on numpy's syntax. It also has a JIT (Just-in-time) compiler that lets you transparently deploy your code to GPUs.

Jitting imposes some minor restrictions on your code but it leads to dramatic speed-ups.
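
As a rough illustration of what jitting looks like, here is a minimal sketch. The function `scaled_sum` and its inputs are made up for this example; they are not part of any SNN code. The main restrictions are that array shapes must be known at compile time and that ordinary Python `if`/`while` statements cannot branch on traced array values.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, w):
    """Weighted sum written with numpy-style array operations."""
    return jnp.sum(x * w)


# jax.jit compiles the function on its first call for the available
# backend (CPU, GPU or TPU); later calls reuse the compiled version.
fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(4.0)      # [0., 1., 2., 3.]
w = jnp.full(4, 0.5)     # [0.5, 0.5, 0.5, 0.5]
result = fast_scaled_sum(x, w)
```

The jitted function is called exactly like the original one, which is why the switch from numpy-style code is mostly painless.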

JAX and robotics

As I mentioned in an earlier post, you can run JAX on NVIDIA's Jetson family. You get excellent performance on inexpensive hardware with a low power budget.

If, like me, you're a robotics experimenter, this is very exciting!

A helpful and knowledgeable JAX community

To get the best from JAX you need to think in JAX. I've a head start in that I've been using APL for over five decades. APL, like JAX and numpy, encourages you to think and code functionally, applying your functions uniformly to tensors.

There are a lot of JAX skills I still have to master, and yesterday I hit a problem I couldn't solve. I needed to create the outer product of two arrays using the mod function. I know how to do that using APL or numpy but I couldn't see how to do it using JAX.

I asked for help in JAX's GitHub discussion forum, and quickly got a simple, elegant solution. The solution showed me that I'd missed some fundamental aspects of the way that indexing and broadcasting work in numpy and JAX. 
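
To show the kind of broadcasting involved, here is a small sketch of a mod "outer product" built by adding an axis to one of the arrays. The arrays `a` and `b` are made-up examples, and this is not necessarily the exact solution posted in the forum.

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5])   # divisors, shape (3,)
b = jnp.arange(1, 7)       # values 1..6, shape (6,)

# a[:, None] has shape (3, 1). Broadcasting it against b's shape (6,)
# yields a (3, 6) table with table[i, j] == b[j] % a[i].
table = jnp.mod(b, a[:, None])
```

Indexing with `None` inserts an axis of length one, and broadcasting then stretches both operands to a common shape, so no explicit loop or dedicated outer-product function is needed.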

The advice came from Jake Vanderplas - an Astronomer turned Google Researcher who has given some excellent talks about Python in Science, including tips on performance.

Generating input spikes with JAX

You may be interested to see the SNN problem I was trying to solve and the code I ended up with.

I wanted to generate multiple regular spike trains, with each train having its own periodicity.

I've seen Python code to do that but it's been truly awful: hard to read and slow to execute.

I knew there had to be a simple way to do it.

In APL I'd just do a mod outer product of the periods and a vector of time indices. A zero in the result would indicate that the time was a multiple of the period, so that neuron should fire.

Here's the APL code:

      a ← 1 5 10 ⍝ firing period
      b ← 10 ⍝ time steps
      spikes ← 0 = a ∘.| 1+⍳b ⍝ time mod period; rows are neurons, columns are time steps

So far so good, but I could not see how to do the same thing in JAX.

Here's the code in JAX
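
What follows is a minimal sketch of how the mod outer product described above could be written in JAX using broadcasting. The function name `spike_trains` and its parameters are assumptions made for this illustration, and the listing is not guaranteed to match the code from the original post.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `steps` determines the output shape, so it is marked static for jit.
@partial(jax.jit, static_argnames=("steps",))
def spike_trains(periods, steps):
    """Return a boolean array of shape (len(periods), steps).

    Entry [i, t - 1] is True when time step t (counted from 1)
    is a multiple of periods[i], i.e. neuron i fires at step t.
    """
    times = jnp.arange(1, steps + 1)                  # 1, 2, ..., steps
    # periods[:, None] has shape (n, 1); broadcasting gives (n, steps).
    return jnp.mod(times[None, :], periods[:, None]) == 0


periods = jnp.array([1, 5, 10])   # firing period of each neuron
spikes = spike_trains(periods, steps=10)
```

With these inputs the first neuron fires on every step, the second on steps 5 and 10, and the third only on step 10.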

I'm gradually building up a library of Spiking Neural Network code using JAX. If you're interested, let me know on Twitter: @rareblog.
