Vector Faculty Member David Duvenaud and collaborators have published a new paper, “Scalable Gradients for Stochastic Differential Equations,” at the International Conference on Artificial Intelligence and Statistics (AISTATS). The paper uses backpropagation to fit stochastic continuous-time models and could make it possible to build more complex prediction models in fields like physics, finance, and human genetics.
Stochastic differential equations (SDEs), which are differential equations that account for uncertainty due to unseen interactions, have a long history in fields like finance, where they help forecast how stock prices might evolve over time. However, these models have typically been limited to fitting a small number of parameters (generally 10 or 20) at a time, and they did not scale to the large neural networks with millions of parameters that are used to fit data in other domains.
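To make the idea of an SDE concrete, here is a minimal sketch of simulating one path with the Euler-Maruyama scheme, the simplest standard discretization. It assumes geometric Brownian motion, a textbook stock-price model, with illustrative values for the drift `mu` and volatility `sigma` and 252 steps (roughly one per trading day); the helper name `euler_maruyama` is hypothetical and not taken from the paper. Each step adds a deterministic drift term plus a random term scaled by a Gaussian increment whose variance equals the step size.

```python
import math
import random
from typing import Callable, List


def euler_maruyama(
    drift: Callable[[float, float], float],
    diffusion: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    n_steps: int,
    rng: random.Random,
) -> List[float]:
    """Simulate one path of dX = drift(X, t) dt + diffusion(X, t) dW on a uniform grid."""
    h = (t1 - t0) / n_steps
    xs = [x0]
    x, t = x0, t0
    for _ in range(n_steps):
        # A Brownian increment over a step of length h is Normal(0, h).
        dw = rng.gauss(0.0, math.sqrt(h))
        x = x + drift(x, t) * h + diffusion(x, t) * dw
        t += h
        xs.append(x)
    return xs


# Geometric Brownian motion (illustrative parameters): dS = mu*S dt + sigma*S dW.
mu, sigma = 0.05, 0.2
path = euler_maruyama(
    drift=lambda s, t: mu * s,
    diffusion=lambda s, t: sigma * s,
    x0=100.0,
    t0=0.0,
    t1=1.0,
    n_steps=252,
    rng=random.Random(0),  # fixed seed so the path is reproducible
)
print(f"simulated price after one year: {path[-1]:.2f}")
```

Setting the diffusion to zero turns the loop into explicit Euler integration of an ordinary differential equation, which is one way to see SDEs as a larger family that contains ODEs.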
The paper specifies the dynamics of these models with neural networks and trains them with gradient-based optimization.
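In standard notation (the symbols here are illustrative, not taken from the article), such a model writes the drift $f_\theta$ and diffusion $g_\theta$ as neural networks with parameters $\theta$, driven by Brownian motion $W_t$ from an initial state $z_0$. Training then adjusts $\theta$ by gradient descent with step size $\eta$ on the expected value of some loss $\mathcal{L}$ evaluated at the end time $t_1$, where the expectation is estimated from sampled paths:

```latex
\mathrm{d}Z_t = f_\theta(Z_t, t)\,\mathrm{d}t + g_\theta(Z_t, t)\,\mathrm{d}W_t,
\qquad Z_{t_0} = z_0,
\qquad
\theta \leftarrow \theta - \eta\,\nabla_\theta\,\mathbb{E}\!\left[\mathcal{L}(Z_{t_1})\right]
```

The hard part is computing $\nabla_\theta$ efficiently when $f_\theta$ and $g_\theta$ have millions of parameters, which is the problem the paper addresses.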
Duvenaud and his co-authors Ricky T. Q. Chen, Ting-Kam Leonard Wong, and Xuechen Li combine these processes with deep neural networks. Duvenaud had previously worked on the idea of switching from discrete time (data sampled at regular intervals) to continuous time (data that can be sampled at any point in time) in the paper “Neural Ordinary Differential Equations,” which won a Best Paper Award at NeurIPS 2018. In the new paper, the authors generalize the mathematics used to train neural ordinary differential equations (ODEs) so that it applies to SDEs, a much larger family of models.
Continuous-time backpropagation already existed for neural ODEs, but no comparable reverse-mode method existed for SDEs. The resulting algorithm turned out to be a straightforward extension of the ODE method with the noise held fixed, a kind of continuous-time reparameterization trick.
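Holding the noise fixed is what makes a sampled path a deterministic, differentiable function of the parameters. The sketch that follows illustrates that idea in discrete time only. It is not the paper's continuous-time stochastic adjoint: it draws the Brownian increments once, simulates geometric Brownian motion with Euler-Maruyama, and then runs reverse-mode differentiation by hand through the stored path. The model, the loss (simply the final value X_N), and the helper names `simulate_gbm`, `reverse_mode_grad`, and `finite_difference` are all hypothetical choices made for illustration. A central finite difference that reuses the same noise serves as a check.

```python
import math
import random
from typing import List, Tuple


def simulate_gbm(theta: Tuple[float, float], x0: float, h: float,
                 noise: List[float]) -> List[float]:
    """Euler-Maruyama path of dX = mu*X dt + sigma*X dW, reusing pre-sampled increments."""
    mu, sigma = theta
    xs = [x0]
    for dw in noise:
        x = xs[-1]
        xs.append(x + mu * x * h + sigma * x * dw)
    return xs


def reverse_mode_grad(theta: Tuple[float, float], x0: float, h: float,
                      noise: List[float]) -> Tuple[float, float]:
    """Gradient of the assumed loss L = X_N w.r.t. (mu, sigma), backpropagated through the stored path."""
    mu, sigma = theta
    xs = simulate_gbm(theta, x0, h, noise)
    adj = 1.0  # dL/dX_N, since L = X_N
    g_mu, g_sigma = 0.0, 0.0
    for k in reversed(range(len(noise))):
        x_k, dw = xs[k], noise[k]
        # Forward step was X_{k+1} = X_k + mu*X_k*h + sigma*X_k*dw.
        g_mu += adj * x_k * h         # adj * dX_{k+1}/dmu
        g_sigma += adj * x_k * dw     # adj * dX_{k+1}/dsigma
        adj *= 1.0 + mu * h + sigma * dw  # dL/dX_k = dL/dX_{k+1} * dX_{k+1}/dX_k
    return g_mu, g_sigma


def finite_difference(theta: Tuple[float, float], x0: float, h: float,
                      noise: List[float], i: int, eps: float = 1e-6) -> float:
    """Central-difference estimate of dX_N/dtheta[i], using the same fixed noise."""
    plus, minus = list(theta), list(theta)
    plus[i] += eps
    minus[i] -= eps
    return (simulate_gbm((plus[0], plus[1]), x0, h, noise)[-1]
            - simulate_gbm((minus[0], minus[1]), x0, h, noise)[-1]) / (2 * eps)


# Sample the Brownian increments once; every later evaluation reuses them.
rng = random.Random(0)
n_steps, h = 100, 0.01
noise = [rng.gauss(0.0, math.sqrt(h)) for _ in range(n_steps)]
theta = (0.05, 0.2)

grad = reverse_mode_grad(theta, 1.0, h, noise)
check = (finite_difference(theta, 1.0, h, noise, 0),
         finite_difference(theta, 1.0, h, noise, 1))
print("reverse mode:", grad)
print("finite diff: ", check)
```

This discrete version stores every intermediate state, so its memory cost grows with the number of steps. The stochastic adjoint method described in the paper is designed to avoid storing the forward path, which is what keeps its memory cost constant.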
These continuous-time models offer a more fine-grained and flexible way to incorporate or sample time series data. They could help to better model medical data, predict prices across the stock market, or track the evolution of populations over long periods of time.