More fun with Spiking Neural Networks and JAX
I'm currently exploring Spiking Neural Networks (SNNs). SNNs try to model the brain more accurately than most of the Artificial Neural Networks used in Deep Learning. There are some SNN implementations available in TensorFlow and PyTorch, but I'm keen to explore them using pure Python. I find that writing the Python code myself gives me confidence in my understanding.

But there's a problem. SNNs need a lot of computing power, and even with numpy, large-scale simulations can run slowly.

Spike generation - code below

So I'm using JAX. JAX code runs in a traditional Python environment, and its array-processing modules closely follow numpy's syntax. It also has a JIT (just-in-time) compiler that lets you transparently deploy your code to GPUs. Jitting imposes some minor restrictions on your code, but it leads to dramatic speed-ups.

JAX and robotics

As I mentioned in an earlier post, you can run JAX on NVIDIA's Jetson family. You get excellent performance on...
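To give a flavour of the spike generation mentioned above, here's a minimal sketch assuming a leaky integrate-and-fire (LIF) neuron population driven by Poisson input spikes, written with jax.numpy, jax.random and jax.jit. The function names and constants (poisson_spikes, lif_step, tau, v_thresh, w) are illustrative choices for this sketch, not the exact code from the original post.

```python
# Illustrative sketch: Poisson spike generation plus a jitted LIF update.
# Names and constants here are assumptions, not the post's original code.
import jax
import jax.numpy as jnp
from jax import random

def poisson_spikes(key, rate_hz, dt, n_steps, n_neurons):
    """Generate a (n_steps, n_neurons) array of 0/1 spikes from a Poisson process."""
    p = rate_hz * dt  # probability of a spike in each time step
    return (random.uniform(key, (n_steps, n_neurons)) < p).astype(jnp.float32)

@jax.jit
def lif_step(v, spikes_in, tau=20e-3, dt=1e-3, v_thresh=1.0, w=0.1):
    """One Euler step of a leaky integrate-and-fire neuron population."""
    v = v + dt / tau * (-v) + w * spikes_in   # leak plus weighted input
    fired = (v >= v_thresh).astype(jnp.float32)
    v = jnp.where(fired > 0, 0.0, v)          # reset neurons that spiked
    return v, fired

key = random.PRNGKey(0)
n_steps, n_neurons = 1000, 100
inputs = poisson_spikes(key, rate_hz=50.0, dt=1e-3,
                        n_steps=n_steps, n_neurons=n_neurons)

# jax.lax.scan runs the jitted update over all time steps, on CPU or GPU
v0 = jnp.zeros(n_neurons)
_, spikes_out = jax.lax.scan(lambda v, s: lif_step(v, s), v0, inputs)
print(spikes_out.sum(), "output spikes")
```

The same code runs unchanged on a GPU if one is available, which is the point of using JAX here: the jitted update step is where the speed-up over plain numpy comes from.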