BlackJAX: Composable Bayesian inference in JAX
Alberto Cabezas, Adrien Corenflos, Junpeng Lao, Rémi Louf, Antoine Carnec, Kaustubh Chaudhari, Reuben Cohn-Gordon, Jeremie Coullon, Wei Deng, Sam Duffield, Gerardo Durán-Martín, Marcin Elantkowski, Dan Foreman-Mackey, Michele Gregori, Carlos Iguaran, Ravin Kumar, Martin Lysy, Kevin Murphy, Juan Camilo Orduz, Karm Patel, Xi Wang, Rob Zinkov
2024-02-22
Abstract: BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but it also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these methods work.
Mathematical Software, Machine Learning, Computation
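The abstract describes two design ideas: samplers act directly on an un-normalized log density, and they are built functionally so that pieces can be composed and compiled with JAX. The block below is a minimal sketch of what that style can look like in plain JAX. It is not BlackJAX's API. `RWState`, `init`, `build_kernel`, and `sample` are hypothetical names chosen for illustration, and random-walk Metropolis stands in for the library's more advanced samplers.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


# Hypothetical sampler state (illustration only, not a BlackJAX type).
class RWState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array


def init(position, logdensity_fn):
    # Cache the log density at the starting point so it is not recomputed.
    return RWState(position, logdensity_fn(position))


def build_kernel(logdensity_fn, scale):
    # Returns a pure function (rng_key, state) -> (new_state, accepted).
    def step(rng_key, state):
        key_prop, key_accept = jax.random.split(rng_key)
        noise = jax.random.normal(key_prop, state.position.shape)
        proposal = state.position + scale * noise
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis test on the log-density difference; any normalizing
        # constant cancels, so an un-normalized density is enough.
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accept, proposal, state.position)
        new_logdensity = jnp.where(accept, proposal_logdensity, state.logdensity)
        return RWState(new_position, new_logdensity), accept

    return step


def sample(rng_key, kernel, initial_state, num_steps):
    # Iterate the kernel with lax.scan, collecting positions and accept flags.
    def one_step(state, key):
        state, accept = kernel(key, state)
        return state, (state.position, accept)

    keys = jax.random.split(rng_key, num_steps)
    _, (positions, accepts) = jax.lax.scan(one_step, initial_state, keys)
    return positions, accepts


# Un-normalized log density of a 2-D standard Gaussian (constant dropped).
def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


kernel = jax.jit(build_kernel(logdensity_fn, scale=1.0))
state = init(jnp.zeros(2), logdensity_fn)
positions, accepts = sample(jax.random.PRNGKey(0), kernel, state, 5000)

# The same kernel, unchanged, run as 4 independent chains via vmap.
chain_keys = jax.random.split(jax.random.PRNGKey(1), 4)
initial_states = jax.vmap(lambda p: init(p, logdensity_fn))(jnp.zeros((4, 2)))
chains, _ = jax.vmap(lambda k, s: sample(k, kernel, s, 1000))(chain_keys, initial_states)
```

Because the kernel is a pure function of a random key and a state, it can be wrapped in `jax.jit`, iterated with `jax.lax.scan`, or mapped over chains with `jax.vmap` without modifying the kernel itself. This is the kind of composability the abstract refers to.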