Bayesian polynomial neural networks and polynomial neural ordinary differential equations

Colby Fronk, Jaewoong Yun, Prashant Singh, Linda Petzold
DOI: https://doi.org/10.1371/journal.pcbi.1012414
2024-10-11
PLoS Computational Biology
Abstract: Symbolic regression with polynomial neural networks and polynomial neural ordinary differential equations (ODEs) are two recent and powerful approaches for equation recovery of many science and engineering problems. However, these methods provide point estimates for the model parameters and are currently unable to accommodate noisy data. We address this challenge by developing and validating the following Bayesian inference methods: the Laplace approximation, Markov Chain Monte Carlo (MCMC) sampling methods, and variational inference. We have found the Laplace approximation to be the best method for this class of problems. Our work can be easily extended to the broader class of symbolic neural networks to which the polynomial neural network belongs. Polynomial neural ordinary differential equations (ODEs) are a recent approach for symbolic regression of dynamical systems governed by polynomials. However, they are limited in that they provide maximum likelihood point estimates of the model parameters. The domain expert using system identification often desires a specified level of confidence or range of parameter values that best fit the data. In this work, we use Bayesian inference to provide posterior probability distributions of the parameters in polynomial neural ODEs. To date, there are no studies that attempt to identify the best Bayesian inference method for neural ODEs and symbolic neural ODEs. To address this need, we explore and compare three different approaches for estimating the posterior distributions of weights and biases of the polynomial neural network: the Laplace approximation, Markov Chain Monte Carlo (MCMC) sampling, and variational inference. We have found the Laplace approximation to be the best method for this class of problems. We have also developed lightweight JAX code to estimate posterior probability distributions using the Laplace approximation.
biochemical research methods, mathematical & computational biology
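The Laplace approximation named in the abstract can be illustrated with a minimal JAX sketch. This is not the authors' code: it assumes a plain quadratic polynomial regression rather than a polynomial neural ODE, a Gaussian likelihood with known noise level, and a zero-mean Gaussian prior. The names `features`, `neg_log_posterior`, `laplace_approximation`, `noise_std`, and `prior_std` are hypothetical and chosen only for illustration. The idea is to find the maximum a posteriori (MAP) parameters, then approximate the posterior by a Gaussian whose covariance is the inverse Hessian of the negative log posterior at that point.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the Hessian solve well conditioned.
jax.config.update("jax_enable_x64", True)


def features(x):
    # Hypothetical polynomial basis [1, x, x^2] for a scalar input.
    return jnp.stack([jnp.ones_like(x), x, x**2], axis=-1)


def neg_log_posterior(theta, x, y, noise_std=0.1, prior_std=10.0):
    # Gaussian likelihood plus zero-mean Gaussian prior (both assumed here),
    # up to additive constants that do not depend on theta.
    resid = y - features(x) @ theta
    neg_log_lik = 0.5 * jnp.sum(resid**2) / noise_std**2
    neg_log_prior = 0.5 * jnp.sum(theta**2) / prior_std**2
    return neg_log_lik + neg_log_prior


def laplace_approximation(x, y, steps=5):
    # Newton iterations to locate the MAP estimate. The objective in this
    # sketch is quadratic, so a single step already reaches the optimum.
    grad_fn = jax.grad(neg_log_posterior)
    hess_fn = jax.hessian(neg_log_posterior)
    theta = jnp.zeros(3)
    for _ in range(steps):
        step = jnp.linalg.solve(hess_fn(theta, x, y), grad_fn(theta, x, y))
        theta = theta - step
    # Gaussian posterior approximation: covariance = inverse Hessian at the MAP.
    cov = jnp.linalg.inv(hess_fn(theta, x, y))
    return theta, cov


# Synthetic noisy data from an assumed ground-truth polynomial.
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 50)
true_theta = jnp.array([0.5, -1.0, 2.0])
y = features(x) @ true_theta + 0.1 * jax.random.normal(key, x.shape)

theta_map, cov = laplace_approximation(x, y)
posterior_std = jnp.sqrt(jnp.diag(cov))
print("MAP estimate:", theta_map)
print("Posterior std:", posterior_std)
```

The square roots of the diagonal of `cov` give approximate one-standard-deviation uncertainties for each coefficient. For a model that is nonlinear in its parameters, the Newton loop would typically be replaced by a gradient-based optimizer, and the Hessian would be evaluated once at the converged parameters.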