SNNAX -- Spiking Neural Networks in JAX

Jamie Lohoff, Jan Finkbeiner, Emre Neftci
2024-09-05
Abstract: Spiking Neural Network (SNN) simulators are essential tools to prototype biologically inspired models and neuromorphic hardware architectures and predict their performance. For such a tool, ease of use and flexibility are critical, but so is simulation speed, especially given the complexity inherent to simulating SNNs. Here, we present SNNAX, a JAX-based framework for simulating and training such models with PyTorch-like intuitiveness and JAX-like execution speed. SNNAX models are easily extended and customized to fit the desired model specifications and target neuromorphic hardware. Additionally, SNNAX offers key features for optimizing the training and deployment of SNNs such as flexible automatic differentiation and just-in-time compilation. We evaluate and compare SNNAX to other machine learning (ML) frameworks commonly used for programming SNNs. We provide key performance metrics, best practices, documented examples for simulating SNNs in SNNAX, and implement several benchmarks used in the literature.
Neural and Evolutionary Computing, Machine Learning
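To make the abstract's points about simulation speed and just-in-time compilation concrete, here is a minimal plain-JAX sketch of one layer of leaky integrate-and-fire (LIF) neurons. It is not SNNAX's own API: it assumes a discrete-time LIF model with a soft reset, and the function names (`lif_step`, `simulate_lif`) and parameters (`beta`, `threshold`) are hypothetical. The time loop is written with `jax.lax.scan` so that `jax.jit` can compile the whole simulation into a single XLA program instead of dispatching one Python operation per time step.

```python
import jax
import jax.numpy as jnp


def lif_step(v, input_current, beta=0.9, threshold=1.0):
    """One time step of a hypothetical discrete-time LIF layer.

    v:             membrane potentials, shape (n_neurons,)
    input_current: synaptic input at this step, shape (n_neurons,)
    """
    v = beta * v + input_current               # leaky integration
    spikes = (v >= threshold).astype(v.dtype)  # spike where the threshold is reached
    v = v - spikes * threshold                 # soft reset by subtraction
    return v, spikes


@jax.jit
def simulate_lif(weights, inputs):
    """Run one LIF layer over time.

    weights: (n_in, n_out) synaptic weight matrix
    inputs:  (T, n_in) input spike trains
    returns: (T, n_out) output spike trains
    """
    currents = inputs @ weights                       # (T, n_out) input currents
    v0 = jnp.zeros(weights.shape[1])                  # initial membrane potentials
    _, spikes = jax.lax.scan(lif_step, v0, currents)  # time loop compiled by XLA
    return spikes


# Usage: random weights and Bernoulli input spikes
key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
weights = 0.5 * jax.random.normal(k_w, (4, 3))
inputs = (jax.random.uniform(k_x, (20, 4)) < 0.3).astype(jnp.float32)
out = simulate_lif(weights, inputs)
print(out.shape)  # (20, 3)
```

Using `lax.scan` rather than a Python `for` loop means the step function is traced once, so compile time does not grow with the number of time steps.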
What problem does this paper attempt to address?
The main goal of this paper is to introduce a new library, SNNAX (Spiking Neural Networks in JAX), which aims to address the shortcomings of current spiking neural network (SNN) simulators in flexibility, ease of use, and execution speed. Specifically:

1. **Improving Simulation and Training Efficiency**: By building on the JAX framework, SNNAX achieves efficient simulation and training of SNN models without sacrificing flexibility. It combines the intuitive interface of PyTorch with the high-performance execution provided by JAX.
2. **Enhancing Algorithm Exploration**: SNNAX supports flexible automatic differentiation and just-in-time compilation, making it easier for researchers to develop and test new learning algorithms, especially those based on synaptic plasticity.
3. **Compatibility with Multiple Hardware Accelerators**: The library not only runs on modern hardware accelerators such as GPUs and TPUs but also supports mapping trained models to specific neuromorphic hardware, broadening its cross-platform applicability.
4. **Simplifying the User Interface**: SNNAX adopts a concise, PyTorch-like API and leverages JAX's function transformations to streamline model training. This lets users focus on model design and experimentation rather than on low-level technical details.

In summary, SNNAX aims to be an easy-to-use, extensible, and high-performance tool for research on spiking neural networks.