Installing Jax on the Jetson Nano
Yesterday I got Jax running on a Jetson Nano. There's been some interest online, and I promised to describe the installation. I'll cover the process below, but first, some background.
What's Jax, and why do I want it?
TL;DR: The cool kids have moved from TensorFlow to Jax.
If you're interested in AI and Artificial Neural Networks (ANNs) you'll have heard of Google's TensorFlow.
TensorFlow with Keras makes it easy to
- build an ANN from standard software components
- train the network
- test it, and
- deploy it into production.
Even better: TensorFlow can take advantage of a GPU or TPU if one is available, which can dramatically speed up training. That's a big deal because training a complex network from scratch can take weeks of computer time.
However, I find it hard to implement novel network architectures in TensorFlow and its competitors. I'm currently exploring types of network that I can't implement easily in TensorFlow.
I could code them in Python and Numpy, but I want them to run fast, and that means using a GPU.
You can read about why DeepMind uses Jax here.
Jax implements many of Numpy's primitives on the GPU, and it can also compile functional, side-effect-free Python code using the XLA compiler.
It also supports automatic differentiation. I'm not going to use backpropagation in my current work, so that's not a big deal for me, but if you are, you'll find online examples that build and train networks from scratch in a few lines of Python.
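Here's a minimal sketch of those three features: Numpy-style arrays via `jax.numpy`, compilation with `jax.jit`, and automatic differentiation with `jax.grad`. It assumes a working Jax install; the loss function and the tiny data set are made up purely for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model. It is pure (no side effects),
    so Jax can trace and compile it."""
    predictions = x @ w
    return jnp.mean((predictions - y) ** 2)


# jax.jit compiles the function with XLA on the first call for a given
# input shape and dtype; later calls with matching inputs reuse that code.
fast_loss = jax.jit(loss)

# jax.grad builds a new function returning d(loss)/d(w), the first argument.
grad_loss = jax.jit(jax.grad(loss))

# Illustrative data (hypothetical values).
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

print(float(fast_loss(w, x, y)))  # 2.5, i.e. mean([1, 4])
# Analytically the gradient is (2 / len(y)) * x.T @ (x @ w - y) = [-7, -10].
print(grad_loss(w, x, y))
```

Because `jax.jit` traces the function once and then runs the compiled code, a Python side effect inside it (a `print`, say) only happens while tracing. That's why Jax asks for functional, side-effect-free code.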
I'll give some numbers for the speed-ups you get later in the article. For now, I'll just say that Jax code runs faster than Numpy on the Nano, and faster still when compared with Numpy on the Raspberry Pi.
Jax on the Jetson Nano
[Photo: Jetson Nano with SSD]
There are several ways to run Jax-based applications: it's not hard to install on a Linux system, and you can run it on Google's wonderful, free Colab service. My interest lies in using neural networks with mobile robots, so I want to run Jax on an edge computing device with built-in GPU hardware.
The Jetson Nano is ideal for that.
I recently came across an article on NVIDIA's developer website that described how to install Jax on the Nano. It included a warning that the installation was slow and required a large (10GB) swap file.
I decided to reconfigure one of my Nanos from scratch and set it up to boot from a USB SSD. I did the whole process in three stages, following the guides linked below; I've also described the one gotcha I hit and my workaround.
- I set up the Nano in headless mode using an SD card.
- I configured the Nano to boot from a USB SSD.
- I installed Jax from source.
Set up the Nano in headless mode using an SD card
NVIDIA have really simplified the setup process since I started exploring the Nano a couple of years ago.
The NVIDIA Getting Started guide now gives two ways to set up your Nano.
One requires an HDMI screen, a mouse and keyboard.
The other needs an Ethernet connection, a suitable power supply and a jumper to configure the Nano to power itself from its barrel connector rather than from USB.
I want to run my Nano in fast, full-power mode, and I have very limited desk space, so I decided to go for a headless setup. I followed these three steps in the NVIDIA guide:
* There are alternatives to the Adafruit power supply that NVIDIA mentions. I bought this one in the UK.
** Make sure you scroll down to the section Initial Setup Headless Mode.
Configure the Nano to boot from a USB SSD
I followed this excellent guide from JetsonHacks.
The Nano did not boot from the SSD the first time, but it has booted reliably ever since.
Installing Jax from source
I followed the process given in this post on the NVIDIA developer website. I hit one snag, which was easy to solve: after I had set up a Python 3.9 virtual environment, I needed to run
sudo apt install python3.9-distutils
before I could run the stages that used pip.
The first time through, I didn't realise that the heading How to Add Swap Space on Ubuntu 18.04 is actually a link to a web page that explains how to do it. I also set the swap file size to 16 GB rather than the recommended 10 GB.
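Given how long the build takes, it's worth confirming the swap space is actually active before you start. `free -h` or `swapon --show` will tell you at a glance; the sketch below does the same check in Python by reading `SwapTotal` from `/proc/meminfo`. The helper names and the 10 GiB threshold (chosen to match the guide's 10 GB recommendation) are my own assumptions, not part of the NVIDIA instructions.

```python
from pathlib import Path

# /proc/meminfo labels its values "kB", but they are really KiB.
KIB_PER_GIB = 1024 * 1024


def parse_meminfo(text):
    """Parse /proc/meminfo-style text into {field name: numeric value}."""
    fields = {}
    for line in text.splitlines():
        name, sep, rest = line.partition(":")
        if not sep:
            continue  # not a "Name: value" line
        parts = rest.split()
        if parts and parts[0].isdigit():
            fields[name.strip()] = int(parts[0])
    return fields


def swap_total_gib(text):
    """Return SwapTotal in GiB, or 0.0 if the field is missing."""
    return parse_meminfo(text).get("SwapTotal", 0) / KIB_PER_GIB


def enough_swap(text, required_gib=10.0):
    """True if the configured swap meets the (assumed) threshold."""
    return swap_total_gib(text) >= required_gib


if __name__ == "__main__":
    meminfo = Path("/proc/meminfo")
    if meminfo.exists():
        text = meminfo.read_text()
        print(f"SwapTotal: {swap_total_gib(text):.1f} GiB")
        if not enough_swap(text):
            print("Less swap than the guide recommends; add more before building.")
```

A 16 GB swap file shows up as slightly under 16 GiB, because `mkswap` reserves a page for its header.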
The installation guide says that the compilation process takes 12 hours. Mine took about 6 hours, possibly because I was running the Nano in Max Power mode.
Once the compilation process has finished, the remaining steps take a minute or less!
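A quick way to confirm the install worked is to ask Jax which devices it can see. When a GPU backend is available, `jax.devices()` returns the GPU devices, so a successful CUDA build on the Nano should report the `gpu` platform. This is a hypothetical check script, not part of the NVIDIA post, and `jax_report` is a name I've made up.

```python
import importlib.util


def jax_report():
    """Summarise the local Jax install: version and device platforms."""
    if importlib.util.find_spec("jax") is None:
        return {"installed": False, "version": None, "platforms": []}
    import jax

    return {
        "installed": True,
        "version": jax.__version__,
        # Each device reports a platform string such as "cpu" or "gpu".
        "platforms": sorted({d.platform for d in jax.devices()}),
    }


if __name__ == "__main__":
    report = jax_report()
    print(report)
    if report["installed"] and "gpu" not in report["platforms"]:
        print("Jax is installed but only sees the CPU; check the CUDA build.")
```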
First Timing Tests
It looks as if Jax is six to seven times faster than Numpy on the Nano when computing the dot product of a 4000x4000 matrix with itself, and 25 times faster than a Raspberry Pi 4 running Numpy. That's worth the effort of installation :)
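Here's a sketch of how a comparison like this might be timed; it isn't the script behind the figures above. It assumes float32 matrices for both libraries (Jax defaults to float32 and Numpy to float64, so mixing them would skew the result), takes the best of a few runs, excludes Jax's one-off compilation with a warm-up call, and calls `block_until_ready()` because Jax dispatches work asynchronously. The function names are hypothetical.

```python
import importlib.util
import time

import numpy as np


def time_numpy_matmul(n, repeats=3):
    """Best wall-clock time (s) for np.dot of an n x n float32 matrix with itself."""
    rng = np.random.default_rng(0)
    a = rng.standard_normal((n, n), dtype=np.float32)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        np.dot(a, a)
        best = min(best, time.perf_counter() - start)
    return best


def time_jax_matmul(n, repeats=3):
    """Same measurement with jax.numpy; returns None if Jax isn't installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    a = jax.random.normal(jax.random.PRNGKey(0), (n, n), dtype=jnp.float32)
    matmul = jax.jit(jnp.dot)
    # Warm-up call: triggers XLA compilation, which we don't want to time.
    matmul(a, a).block_until_ready()
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Wait for the asynchronous computation to actually finish.
        matmul(a, a).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best
```

On the Nano you would call both functions with `n=4000` and compare the two times; on a Pi without Jax, `time_jax_matmul` simply returns `None`, so the same script gives you the Numpy baseline there.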