Wednesday, 13 October 2021

More fun with Spiking Neural Networks and JAX

I'm currently exploring Spiking Neural Networks.

SNNs (Spiking Neural Networks) try to model the brain more accurately than most of the Artificial Neural Networks used in Deep Learning.

There are some SNN implementations available in TensorFlow and PyTorch but I'm keen to explore them using pure Python. I find that Python code gives me confidence in my understanding.

But there's a problem.

SNNs need a lot of computing power. Even if I use numpy, large-scale simulations can run slowly.
Spike generation - code below

So I'm using JAX.

JAX code runs in a traditional Python environment. JAX has array processing modules that are closely based on numpy's syntax. It also has a JIT (Just-in-time) compiler that lets you transparently deploy your code to GPUs.

Jitting imposes some minor restrictions on your code but it leads to dramatic speed-ups.
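To give a flavour of what jitting looks like in practice, here's a tiny illustrative example (the function and its names are mine, not from any library) of the kind of update step that crops up in SNN simulators:

```python
import jax
import jax.numpy as jnp

@jax.jit
def lif_step(v, i):
    # Toy leaky-integrate step: decay the membrane potentials and add input.
    # jit traces this once, then runs the compiled version on CPU or GPU.
    return 0.9 * v + i

v = lif_step(jnp.zeros(4), jnp.ones(4))  # behaves just like the un-jitted call
```

The restrictions mostly amount to keeping functions pure: no side effects, and no Python control flow that depends on array values.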

JAX and robotics

As I mentioned in an earlier post, you can run JAX on NVIDIA's Jetson family. You get excellent performance on inexpensive hardware with a low power budget.

If, like me, you're a robotics experimenter, this is very exciting!

A helpful and knowledgeable JAX community

To get the best from JAX you need to think in JAX. I've a head start in that I've been using APL for over five decades. APL, like JAX and numpy, encourages you to think and code functionally, applying your functions uniformly to tensors.

There are a lot of JAX skills I still have to master, and yesterday I hit a problem I couldn't solve. I needed to create the outer product of two arrays using the mod function. I know how to do that using APL or numpy but I couldn't see how to do it using JAX.

I asked for help in JAX's GitHub discussion forum, and quickly got a simple, elegant solution. The solution showed me that I'd missed some fundamental aspects of the way that indexing and broadcasting work in numpy and JAX. 

The advice came from Jake Vanderplas - an astronomer turned Google researcher who has given some excellent talks about Python in science, including tips on performance.

Generating input spikes with JAX

You may be interested to see the SNN problem I was trying to solve and the code I ended up with.

I wanted to generate multiple regular spike trains, with each train having its own periodicity.

I've seen Python code to do that but it's been truly awful: hard to read and slow to execute.

I knew there had to be a simple way to do it.

In APL I'd just do a mod outer product of the periods and a vector of time indices. A zero in the result would indicate that the time was a multiple of the period, so that neuron should fire.

Here's the APL code:

      a ← 1 5 10 ⍝ firing period
      b ← 10 ⍝ time steps
      spikes ← 0 = (1+⍳b) ∘.| a

So far so good, but I could not see how to do the same thing in JAX.

Here's the code in JAX
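The code listing from the original post hasn't survived here, but Jake's broadcasting approach boils down to something like this (my reconstruction, mirroring the APL line above):

```python
import jax.numpy as jnp

a = jnp.array([1, 5, 10])   # firing periods
b = 10                      # number of time steps
t = jnp.arange(1, b + 1)    # time indices 1..b

# t[:, None] has shape (10, 1); broadcasting it against a (shape (3,))
# yields the (10, 3) outer product under mod - the equivalent of APL's ∘.|
spikes = (t[:, None] % a) == 0
```

A True in row i, column j means the neuron with period a[j] fires at time i + 1.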

I'm gradually building up a library of Spiking Neural Network code using JAX. If you're interested, let me know on Twitter: @rareblog.

Thursday, 23 September 2021

Installing Jax on the Jetson Nano

Yesterday I got Jax running on a Jetson Nano. There's been some online interest and I promised to describe the installation. I'll cover the process below but first I'll briefly explain

What's Jax, and why do I want it?

tl;dr: the cool kids have moved from TensorFlow to Jax.

If you're interested in AI and Artificial Neural Networks (ANNs) you'll have heard of Google's TensorFlow.

TensorFlow with Keras makes it easy to
  • build an ANN from standard software components
  • train the network
  • test it and
  • deploy it into production
Even better: TensorFlow can take advantage of a GPU or TPU if they are available, which allows a dramatic speedup of the training process. That's a big deal because training a complex network from scratch might take weeks of computer time.

However, I find it hard to get TensorFlow and its competitors to use novel network architectures. I am currently exploring types of network that I can't implement easily in TensorFlow.

I could code them in Python and Numpy but I want to be able to run them fast, and that means using a GPU.

Hence Jax.

You can read about why DeepMind uses Jax here

Jax implements many of Numpy's primitives on a GPU, and it can also JIT-compile functional, side-effect-free Python code for faster execution.

It also supports automatic differentiation. I'm not going to be using back propagation in the work I'm doing, so that's not a big deal for me. If you are, you can see online examples that build and train networks from scratch in a few lines of Python code.
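If you do want gradients, the core API is a one-liner. Here's a minimal sketch (the toy loss function is mine, purely for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w):
    # A toy quadratic loss, just to show jax.grad at work.
    return jnp.sum((2.0 * w - 1.0) ** 2)

grad_loss = jax.grad(loss)            # a new function computing d(loss)/dw
g = grad_loss(jnp.array([1.0, 2.0]))  # analytically: 4 * (2w - 1)
```

That derivative function is itself ordinary JAX code, so you can jit it or differentiate it again.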

I'll give some numbers for the speed-ups you get later in the article. For now, I'll just say that Jax code runs faster than Numpy on the Nano, and faster still when compared with Numpy on the Raspberry Pi.

Jax on the Jetson Nano

Jetson Nano with SSD

There are several ways you can run Jax-based applications; it's not hard to install on a Linux system, and you can run it on Google's wonderful, free Colab service. My interest lies in using neural networks with mobile robots, so I want to run it on an Edge Computing device with GPU hardware built-in.

The Jetson Nano is ideal for that.

I recently came across an article on NVIDIA's developer website that described how to install Jax on the Nano. It included a warning that the installation was slow and required a large (10GB) swap file.

I decided to re-configure one of my Nanos from scratch and set it up to boot from a USB SSD. I did the whole process in three stages using links which I have listed below, along with the gotcha I hit and my work-around.

  1. I set up the Nano in headless mode using an SD card.
  2. I configured the Nano to boot from a USB SSD.
  3. I installed Jax from source.

Set up the Nano in headless mode using an SD card

NVIDIA have really simplified the setup process since I started exploring the Nano a couple of years ago.

The NVIDIA Getting Started guide now gives two ways to set up your Nano.

One requires an HDMI screen, a mouse and keyboard. 

The other needs an Ethernet connection, a suitable power supply and a jumper to configure the Nano to power itself from its barrel connector rather than from USB.

I want to run my Nano in fast, full-power mode and I have very limited desk space. I decided to go for a Headless Setup. I followed these three steps in the NVIDIA guide:

  1. Prepare for Setup*
  2. Write Image to the microSD Card
  3. Setup and First Boot**
* There are alternative power supplies to the Adafruit supply that NVIDIA mentioned. I purchased this one in the UK.

** Make sure you scroll down to the section Initial Setup Headless Mode

Configure the Nano to boot from USB SSD

I followed this excellent guide from JetsonHacks.

I found that the Nano did not boot from the SSD first time but it's booted reliably ever since.

Installing Jax from source

I followed the process given in this post on the NVIDIA developer website. I hit one snag, which was easy to solve: after I had set up a Python 3.9 virtual environment, I needed to run

sudo apt install python3.9-distutils

before I could run the stages that used pip.

First time through, I failed to realise that the heading How to Add Swap Space on Ubuntu 18.04 is actually a link to a web page that tells you how to do that. I also set the swap file size to 16 GB rather than the recommended 10 GB.
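For reference, the swap-space recipe in guides of that kind boils down to a few commands (the 16G size is my choice; follow the linked page for the full walk-through):

```shell
# Create and enable a 16 GB swap file before starting the build.
sudo fallocate -l 16G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
swapon --show   # confirm the swap file is active
```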

The installation guide says that the compilation process takes 12 hours. Mine took about 6 hours, possibly because I was running the Nano in Max Power mode.

Once the compilation process has finished, the remaining steps take a minute or less!

First Timing Tests

It looks as if Jax is six to seven times faster than numpy on the Nano when dot-multiplying a 4000x4000 element matrix with itself, and 25 times faster than a Pi 4 running numpy. That's worth the effort of installation :)
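Those numbers came from simple timing runs; this sketch shows the shape of such a test (the details are my reconstruction, not the original script). Note block_until_ready, which forces JAX's asynchronous dispatch to finish before you stop the clock:

```python
import time
import numpy as np
import jax.numpy as jnp

n = 1000  # I used 4000x4000 on the Nano; smaller here to keep the run short
x_np = np.ones((n, n), dtype=np.float32)
x_jx = jnp.asarray(x_np)

start = time.perf_counter()
x_np.dot(x_np)
numpy_secs = time.perf_counter() - start

jnp.dot(x_jx, x_jx).block_until_ready()  # warm-up run triggers compilation
start = time.perf_counter()
jnp.dot(x_jx, x_jx).block_until_ready()  # timed run waits for the result
jax_secs = time.perf_counter() - start
```

Without block_until_ready, JAX would return before the multiply finished and the timing would be meaningless.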


Monday, 15 March 2021

MicroPlot on the PyPortal - progress and frustration

Update: I got a fix within minutes of posting the problem on the Adafruit Discord channel!

Here's the correct bitmap:

The problem is now solved.

MicroPlot now runs well on the Adafruit PyPortal and Clue as well as the Pimoroni Pico Explorer base, but I've been tearing my hair out trying to solve a problem saving bitmaps.

The bitmap problem

Here's a screenshot of a display on the PyPortal together with the bitmap file which should show what's on the screen. I could not work out what was going wrong.

The code creates the plot and then uses the screenshot code that Adafruit provides.

import math
from plotter import Plotter
from plots import LinePlot
import board
import digitalio
import busio
import adafruit_sdcard
import storage
from adafruit_bitmapsaver import save_pixels

def plot():
    sines = list(math.sin(math.radians(x)) for x in range(0, 361, 4))
    lineplot = LinePlot([sines], 'MicroPlot line')
    plotter = Plotter()
    lineplot.plot(plotter)

def save():
    spi = busio.SPI(board.SCK, MOSI=board.MOSI, MISO=board.MISO)
    cs = digitalio.DigitalInOut(board.SD_CS)
    sdcard = adafruit_sdcard.SDCard(spi, cs)
    vfs = storage.VfsFat(sdcard)
    storage.mount(vfs, "/sd")
    save_pixels("/sd/screenshot.bmp")

plot()
save()
print('done')

What works, what doesn't?

When I first wrote about MicroPlot it ran on the Pico. The code worked but I knew it would need refactoring as I extended it to cover other plot types and other devices.

Past experience has taught me the value of automated tests for embedded projects. I wondered if I could capture MicroPlot's output in bitmap files and use them in tests. I didn't know much about the format of Bitmap files so I started to look for code I could borrow and adapt.

I soon discovered that bitmap files could use different ways of encoding colour, and the simplest, smallest format created monochrome bitmaps.  I soon got a Bitmap file-saver working reliably on the Pico. As you can see, it works, and you can find the code in the MicroPlot project on GitHub.
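The saver itself lives in the MicroPlot repository; as a rough sketch of what writing a minimal monochrome (1 bit per pixel) BMP involves, this is the sort of thing I mean (my illustration here, not the MicroPlot code):

```python
import struct

def save_mono_bmp(path, pixels):
    """Save a 2-D list of 0/1 pixel values as a 1-bit-per-pixel BMP."""
    height = len(pixels)
    width = len(pixels[0])
    row_bytes = ((width + 31) // 32) * 4   # rows padded to 4-byte boundaries
    image_size = row_bytes * height
    offset = 14 + 40 + 8                   # file header + DIB header + palette
    with open(path, 'wb') as f:
        # BITMAPFILEHEADER: magic, file size, two reserved words, pixel offset
        f.write(struct.pack('<2sIHHI', b'BM', offset + image_size, 0, 0, offset))
        # BITMAPINFOHEADER: 1 plane, 1 bit per pixel, uncompressed, 2 colours
        f.write(struct.pack('<IiiHHIIiiII', 40, width, height, 1, 1, 0,
                            image_size, 2835, 2835, 2, 0))
        # Two-entry palette: index 0 = black, index 1 = white
        f.write(struct.pack('<II', 0x00000000, 0x00FFFFFF))
        for row in reversed(pixels):       # BMP stores rows bottom-up
            packed = bytearray(row_bytes)
            for x, bit in enumerate(row):
                if bit:
                    packed[x // 8] |= 0x80 >> (x % 8)  # MSB-first bit packing
            f.write(bytes(packed))
```

The appeal of this format is exactly what made it good for automated tests: the files are tiny and the layout is simple enough to generate and compare byte-for-byte.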

Soon after that I discovered some Adafruit colour bitmap saver code and adapted it to run on the Pico. As you can see, it too worked well.


When I ran the unmodified Adafruit code on the PyPortal, it scrambled the image, as shown earlier.

I've checked that I am using the latest production versions of the Adafruit code. Can anyone suggest what I'm doing wrong? The code (mine and Adafruit's) looks sensible, but it seems to corrupt every screenshot bitmap that I try to take.

If you can spot the problem, let me know in the comments, tweet to @rareblog or respond to my cry for help on Discord.