Chapter 3: Neural Ordinary Differential Equations

If we want to build a continuous-time or continuous-depth model, differential equation solvers are a useful tool. But how exactly can we treat odeint as a layer for building deep models? The previous chapter showed how to compute its gradients, so the only thing missing is to give it some parameters. This chapter will show how and why to do so.

In this chapter we won’t be using any deep learning frameworks. Instead, we’ll build everything from scratch using the differentiable NumPy operations available through JAX.

Preliminaries: Training a residual network

As a warm-up, we can define a simple deep neural network in only a few lines:

import jax.numpy as jnp

def mlp(params, inputs):
  # A multi-layer perceptron, i.e. a fully-connected neural network.
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b  # Linear transform
    inputs = jnp.tanh(outputs)        # Nonlinearity
  return outputs

mlp is simply a composition of linear and nonlinear layers. Its parameters params are a list of weight matrices and bias vectors.
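
For example, with the layer sizes [1, 20, 1] that we’ll use below, params would be a list of two (weight, bias) pairs. The placeholder arrays here (our own illustration, not part of the training code) are only meant to show the shapes involved:

# Placeholder parameters for layer sizes [1, 20, 1], shown only for their shapes.
example_params = [(jnp.zeros((1, 20)), jnp.zeros(20)),   # input  -> hidden
                  (jnp.zeros((20, 1)), jnp.zeros(1))]    # hidden -> output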

To make larger models, we can always compose layers. As a standard example, chaining together copies of a smaller network, such as our mlp, and adding each block’s input to its output gives a residual network:

def resnet(params, inputs, depth):
  # Apply the same mlp block `depth` times, adding each block's input to its output.
  for i in range(depth):
    inputs = mlp(params, inputs) + inputs
  return inputs

To fit this model to data, we also need a loss, an initializer, and an optimizer:

import numpy.random as npr
from jax import jit, grad

resnet_depth = 3
def resnet_squared_loss(params, inputs, targets):
  preds = resnet(params, inputs, resnet_depth)
  return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  # One (weight, bias) pair per layer, drawn from a scaled Gaussian.
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# A simple gradient-descent optimizer.
@jit
def resnet_update(params, inputs, targets):
  grads = grad(resnet_squared_loss)(params, inputs, targets)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

As a sanity check, let’s fit our resnet to a toy 1D dataset (green circles) and plot the predictions of the trained model (blue curve):

# Toy 1D dataset.
inputs = jnp.reshape(jnp.linspace(-2.0, 2.0, 10), (10, 1))
targets = inputs**3 + 0.1 * inputs

# Hyperparameters.
layer_sizes = [1, 20, 1]
param_scale = 1.0
step_size = 0.01
train_iters = 1000

# Initialize and train.
resnet_params = init_random_params(param_scale, layer_sizes)
for i in range(train_iters):
  resnet_params = resnet_update(resnet_params, inputs, targets)

# Plot results.
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color='green')
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth), lw=0.5, color='blue')
ax.set_xlabel('input')
ax.set_ylabel('output')

Building a neural ODE

Similar to a residual network, a neural ODE (or ODE-Net) takes a simple layer as a building block and chains many copies of it together to build a bigger model. In particular, our “base layer” is going to specify the dynamics of an ODE, and we’re going to chain the outputs of these base layers together according to the logic of an ODE solver.

Specifying the dynamics layer

What kind of layer do we need to specify the dynamics of an ODE? Recall that an ODE initial value problem has the form:

\[\dot y(t) = f(y(t), t, \theta), \qquad y(0) = y_0,\]

where the initial value $y_0 \in \mathbb{R}^n$.

We’ve added parameters $\theta$ to the dynamics, so the dynamics function has the dimensions $f : \mathbb{R}^{n} \times \mathbb{R} \times \mathbb{R}^{|\theta|} \to \mathbb{R}^n$, where $|\theta|$ is the number of parameters we’ve added to $f$.

In plain English, we need the dynamics function to take in the current state $y(t)$ of the ODE, the current time $t$, and some parameters $\theta$, and output the time derivative $\dot y(t)$, which has the same shape as $y(t)$.

We can easily build such a function by simply concatenating the state and current time, and sending that as the input to mlp:

def nn_dynamics(state, time, params):
  state_and_time = jnp.hstack([state, jnp.array(time)])
  return mlp(params, state_and_time)

The remaining piece of our model is how to combine evaluations of this dynamics layer. We could use any ODE solver; JAX’s odeint function implements the standard adaptive-step Dormand-Prince method.

from jax.experimental.ode import odeint

def odenet(params, input):
  start_and_end_times = jnp.array([0.0, 1.0])
  init_state, final_state = odeint(nn_dynamics, input, start_and_end_times, params)
  return final_state

Without loss of generality, we can make the integration time go from 0 to 1.
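
One way to see why this is no restriction: to integrate over any other interval $[0, T]$, define the rescaled state $\tilde y(s) = y(sT)$. By the chain rule,

\[\dot{\tilde y}(s) = T f(\tilde y(s), sT, \theta), \qquad \tilde y(0) = y_0,\]

so integrating these rescaled dynamics from $s = 0$ to $s = 1$ gives $\tilde y(1) = y(T)$. The constant $T$ and the time rescaling can simply be absorbed into the learned dynamics function.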

That’s it! We’ve defined an ODE net. Below, we’ll talk a bit more about what’s happening inside of odeint, but for now, let’s hook it up to an optimizer and see if we can fit it to data!

Batching an ODE Net

To support batching (evaluating the ODE-Net on more than one training example) we can simply use JAX’s vmap function, which automatically adds batching dimensions. This transformation is non-trivial, since odeint contains while loops and control flow, but JAX handles it automatically. The vmapped odeint effectively runs an independent adaptive solver for each batch element, waiting for the last one to finish before returning all the final states together. But it still combines the calls to the dynamics function into one efficient vectorized call shared across all batch elements.

In environments that don’t have vmap, typically what’s done is to create one giant ODE that combines the dynamics of every example in the batch, solve it all in one call to odeint, and then split up the results across the batch.
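
As a rough sketch of that manual approach (the helper batched_odenet_manual below is our own illustration, not code we’ll use), the combined ODE’s state is just the whole batch stacked together, and its dynamics apply nn_dynamics to each row:

def batched_odenet_manual(params, inputs):
  # One giant ODE whose state is the entire batch of examples stacked together.
  def batched_dynamics(states, time, params):
    # Apply the per-example dynamics to each row of the batch.
    return jnp.stack([nn_dynamics(states[i], time, params)
                      for i in range(states.shape[0])])
  start_and_end_times = jnp.array([0.0, 1.0])
  init_states, final_states = odeint(batched_dynamics, inputs,
                                     start_and_end_times, params)
  return final_states

With vmap available, though, batching is a one-liner: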

from jax import vmap
batched_odenet = vmap(odenet, in_axes=(None, 0))

What remains is simply to initialize the parameters, hook up the model to the loss function, and train the ODE-Net:

# The dynamics function's input is the state concatenated with the time, so its first layer takes 2 inputs.
odenet_layer_sizes = [2, 20, 1]

def odenet_loss(params, inputs, targets):
  preds = batched_odenet(params, inputs)
  return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

@jit
def odenet_update(params, inputs, targets):
  grads = grad(odenet_loss)(params, inputs, targets)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

# Initialize and train ODE-Net.
odenet_params = init_random_params(param_scale, odenet_layer_sizes)

for i in range(train_iters):
  odenet_params = odenet_update(odenet_params, inputs, targets)

# Plot resulting model.
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color='green', label='data')
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth),
        lw=0.5, color='blue', label='Resnet predictions')
ax.plot(fine_inputs, batched_odenet(odenet_params, fine_inputs),
        lw=0.5, color='red', label='ODE Net predictions')
ax.set_xlabel('input')
ax.set_ylabel('output')
ax.legend()

The two regression methods both match the data, but extrapolate slightly differently.

Activation trajectories

In a deep residual network, we can examine the activations between each block. In an ODE-Net, we can instead examine the activation trajectories as a function of depth:

fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()

@jit
def odenet_times(params, input, times):
  def dynamics_func(state, time, params):
    return mlp(params, jnp.hstack([state, jnp.array(time)]))
  return odeint(dynamics_func, input, times, params)

times = jnp.linspace(0.0, 1.0, 200)

for i in fine_inputs:
  ax.plot(odenet_times(odenet_params, i, times), times, lw=0.5)

ax.set_xlabel('input / output')
ax.set_ylabel('time / depth')

In this toy setting where there is only one hidden unit, the trajectories can never cross each other, limiting the classes of functions that can be learned. However, this limitation can be overcome (if desired) by adding auxiliary dimensions to the network’s input that are discarded at the network output.
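
As a sketch of that trick (our own illustration, using hypothetical names): pad the input state with extra zero-valued dimensions, solve the ODE in the higher-dimensional space, and drop the auxiliary dimensions at the output. The dynamics network would then need correspondingly larger layer sizes, e.g. [4, 20, 3] for the two auxiliary dimensions used here.

def augmented_odenet(params, input, num_aug_dims=2):
  # Pad the state with extra zero-valued dimensions so that trajectories
  # are free to pass one another in the higher-dimensional space.
  augmented_input = jnp.concatenate([input, jnp.zeros(num_aug_dims)])
  start_and_end_times = jnp.array([0.0, 1.0])
  init_state, final_state = odeint(nn_dynamics, augmented_input,
                                   start_and_end_times, params)
  # Discard the auxiliary dimensions at the output.
  return final_state[:input.shape[0]]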

What form can the dynamics take?

There are a few restrictions needed to make the ODE solution well-defined and unique, which we will discuss later, but in general the dynamics can be almost any tractable, differentiable, parametric function. In other words, odeint is a layer that takes in another layer to specify its dynamics function. This inner layer can be a fully-connected net, a convnet, a U-net, or even some kinds of transformer!

Where can we use odeint layers?

The short answer is: anywhere you can use a residual net, you can use an ODE-Net. Both require the input and the output to have the same size.

Computational advantages of neural ODEs

Why would we want to introduce all this extra complexity into our network architecture? As with Deep Equilibrium Models, there are some computational advantages to implicitly defining the output of our model and leaving it up to an adaptive solver to approximate it:

We would like our models to spend their compute resources wisely, and only think hard about difficult problems whose answer is important. Adaptive ODE solvers, developed over the last 120 years or so, achieve this in a limited way.

The standard approach to building adaptive ODE solvers is to monitor the difference between the trajectories predicted by two different extrapolation methods. If this difference grows large, it suggests that at least one of the extrapolation methods is making bad predictions. The solver then attempts to recover by re-taking the step and predicting less far ahead (i.e. taking smaller steps through time).
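
As a cartoon of this idea (a minimal sketch, not how any production solver is implemented), one adaptive step might compare a single Euler extrapolation against a more accurate midpoint extrapolation, shrinking the step size until the two agree. Here the dynamics function f is assumed to take just a state and a time:

def adaptive_step(f, y, t, dt, tol=1e-3):
  # Keep shrinking the step size until two extrapolations of different
  # accuracy agree to within the tolerance.
  while True:
    euler_pred    = y + dt * f(y, t)
    half_point    = y + 0.5 * dt * f(y, t)
    midpoint_pred = y + dt * f(half_point, t + 0.5 * dt)
    if jnp.max(jnp.abs(midpoint_pred - euler_pred)) <= tol:
      return midpoint_pred, t + dt
    dt = 0.5 * dt  # Disagreement: retry with a smaller step.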

Different solvers can handle different sorts of dynamics more or less easily, but generally speaking, the simpler the dynamics, the fewer steps an adaptive solver will need to approximate the answer to a given accuracy.

Most adaptive ODE solvers require the user to specify an error tolerance (both relative and absolute) that the solver will try to meet. For most real systems, we can’t guarantee that any particular error target will be met. But even in that case, the error tolerance is a way to trade off compute time against the precision of the answer. This is a more flexible approach in some ways than weight pruning or quantization, since tolerances can be adjusted throughout training, or even after the model has been deployed.
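
For example, jax.experimental.ode.odeint accepts rtol and atol keyword arguments, so a cheaper, less precise version of our model is just a matter of loosening them (a sketch; the tolerance values below are arbitrary):

def odenet_coarse(params, input):
  # Looser tolerances let the adaptive solver take fewer, larger steps,
  # trading accuracy for speed.
  start_and_end_times = jnp.array([0.0, 1.0])
  init_state, final_state = odeint(nn_dynamics, input, start_and_end_times,
                                   params, rtol=1e-3, atol=1e-3)
  return final_state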

Modeling advantages of neural ODEs

Besides having a different computational tradeoff from fixed-depth networks, exact solutions of ODE-based networks also form a different model class than standard neural nets, with a few different properties:

Modeling disadvantages of neural ODEs

Computational disadvantages of neural ODEs

Software

Here are a few of the more comprehensive toolkits that let one fit neural ODEs:

Stochastic and Partial Differential Equations

Besides ordinary differential equations, there are many other variants of differential equations that can be fit with gradient-based training, and developing new model classes based on differential equations is an active research area. The solution of almost any type of differential equation can be seen as a layer!

Here are some pointers to recent work in these areas: