Neural ODE from scratch: revisiting backward propagation

Fei Cheung
10 min read · Sep 16, 2020


Neural Networks are powerful tools: they are universal function approximators and play a vital role in Deep Learning. To boost performance on various tasks, researchers add prior knowledge to the network, for example, using Convolutional Neural Networks for image recognition and Recurrent Neural Networks for sequence models. At NIPS 2018, researchers from the University of Toronto proposed a new neural network architecture:

Neural Ordinary Differential Equations

and received the best paper award.

I spent some time trying to understand the nitty-gritty details of Neural ODEs and reproduce the results. I did it as an exercise and decided to write a post about it (full code here). This blog goes through the details, intuition, and tricks to implement it with PyTorch. Most of the code is adapted from Mikhail Surtsukov’s excellent blog post and the original paper.

Prerequisites

I strongly advise readers to go through Parts 1 and 2 and Appendices A, B, and C of the original paper first, since those are the parts this blog elaborates on. You will need the following background to follow the material.

  1. Basic understanding of calculus and differential equations.
  2. Basic understanding of Neural Networks and backward propagation.
  3. Basic understanding of Python and PyTorch.

Why ODE (Ordinary Differential Equations)?

Many dynamical systems are more naturally described using ODEs: for example, population growth, epidemics (like COVID-19), and especially systems in physics and engineering. It is often more natural to learn the dynamics instead of the observables, which can be hard to understand and describe (especially in physics and control theory). For example, consider simple harmonic motion:

Source: wiki

We can write down the Equation of Motion (EOM) as an ODE.

ẍ(t) = −ω² x(t)
Equation of motion (EOM) of the Simple Harmonic Oscillator

x(t) = A cos(ωt + φ)
The solution to the EOM; notice the simplicity of the ODE

We can see that the differential equation is just a linear equation, while its solution takes a more complex form. It is therefore very natural to model the system of interest in terms of differential equations.
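As a quick check, differentiating the solution twice recovers the EOM:

ẋ(t) = −Aω sin(ωt + φ)
ẍ(t) = −Aω² cos(ωt + φ) = −ω² x(t)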

General Idea of training the Neural ODE

Fig.1 Visualisation of the training process

The idea is simple and works just like ordinary supervised learning:

  1. Solve the Neural ODE from some initial value.
  2. Compare the resulting path with the true labels and compute the loss.
  3. Minimise the loss.

Let’s start by solving the Neural ODE

We can solve the following ODE

dz(t)/dt = f(z(t), t, θ)

notice here z is a vector, and θ denotes the parameters of the function

by integrating along the path

z(t_1) = z(t_0) + ∫_{t_0}^{t_1} f(z(t), t, θ) dt

start from t_0 and follow the slope to generate z at later time steps

In practice, we usually cannot compute the analytic solution of the ODE (it rarely exists in closed form), so instead of integrating the function symbolically, we obtain z(t) with a numerical approximation.

There are many numerical methods for solving differential equations, and we use the simplest one here: the Euler method. The idea is simple: instead of integrating in infinitesimal steps, we approximate the solution with small discrete time steps, z(t + h) ≈ z(t) + h · f(z(t), t).

Figure 2. Source: wiki

Now we can write our first Python function to solve the ODE: given the initial value and the ODE function, it finds the state at a later time step.

import math

def ode_solve(z0, t0, t1, f):
    """
    Simple Euler ODE initial value solver
    :param z0: initial value
    :param t0: initial time stamp
    :param t1: target time stamp
    :param f: ODE function
    :return: z1, the state at t1
    """
    h_max = 0.05  # maximum step size
    n_steps = math.ceil((abs(t1 - t0) / h_max).max().item())

    h = (t1 - t0) / n_steps
    t = t0
    z = z0

    # follow the local slope for one small step at a time
    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z
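To sanity-check the solver, here is a small usage sketch (my own example, not from the original code) that integrates dz/dt = −z from t = 0 to t = 1 and compares the result against the exact solution e⁻¹:

import torch

# toy ODE: dz/dt = -z, with exact solution z(t) = z0 * exp(-t)
f = lambda z, t: -z

z0 = torch.tensor([[1.0]])  # initial value
t0 = torch.tensor(0.0)
t1 = torch.tensor(1.0)

z1 = ode_solve(z0, t0, t1, f)
print(z1.item())  # ~0.3585 with h = 0.05; the exact answer is exp(-1) ~ 0.3679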

Next we need to compute the loss.

Compute the loss function

Computing the loss is easy. Looking at figure 1, we can compute the loss as the distance between the outputs and the labels, for example the mean squared error. In PyTorch we can use the following function.

loss = F.mse_loss(outputs, labels.detach())

After that, we can use a standard optimisation technique to minimise the loss. The standard one for differentiable functions is gradient descent, and we can compute the gradient with backward propagation. A minimal training loop is sketched below.
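Here is a minimal sketch of the outer training loop, assuming model wraps the Neural ODE forward pass described above (model, z0, t, labels, and n_epochs are illustrative names, not from the original code):

import torch
import torch.nn.functional as F

# `model`, `z0`, `t`, `labels`, `n_epochs` are illustrative placeholders
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(n_epochs):
    optimizer.zero_grad()
    outputs = model(z0, t)                       # 1. solve the Neural ODE
    loss = F.mse_loss(outputs, labels.detach())  # 2. compute the loss
    loss.backward()                              # 3. backpropagate
    optimizer.step()                             # gradient descent step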

How do we do backward propagation?

For an ordinary neural network, we know how to backpropagate through each neuron and parameter.

Figure 3. Source: 3b1b

We could certainly backpropagate through the numerical integrator directly, but that introduces a high memory cost and accumulates numerical error. Instead, we can assume we are working on the “correct path” with infinitesimal steps, and compute the gradient irrespective of the integrator we use.

But how do we do backpropagation on a continuous path? The gradient of the loss w.r.t. the parameters now receives contributions from a continuum of points; how do we compute it?

Figure 4. The black arrows represent the forward path, the green arrows show that every point on the curve depends on the parameters, and the blue arrows represent the backward propagation path

The original authors use the adjoint sensitivity method, which solves this problem, and provide proofs in the appendix. I will elaborate a little more on the details and the proof, following the authors’ approach.

Continuous Backward Propagation

First, imagine we treat two nearby points as two separate nodes with shared parameters:

Figure 5: Red arrows show the forward path and blue arrows the backward path

Now we can do backward propagation using the chain rule:

∂L/∂z(t) = ∂L/∂z(t+ε) · ∂z(t+ε)/∂z(t)

For ease of notation, we write the partial derivative of the loss as the adjoint state

a(t) = ∂L/∂z(t)

Notice that a(t) and z(t) have the same dimension, and ∂z(t+ε)/∂z(t) is a Jacobian matrix.

Instead of computing the partial derivatives directly, we find the “change in the partial derivatives” by taking the limit from the discrete case to the continuous case,

following the proof from the original paper:

Original proof from the paper (Appendix B)
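In condensed form, the argument runs as follows (my transcription of the Appendix B proof): expand z(t+ε) to first order, substitute it into the chain rule, and take the limit.

z(t+ε) = z(t) + ε f(z(t), t, θ) + O(ε²)

a(t)ᵀ = a(t+ε)ᵀ ∂z(t+ε)/∂z(t) = a(t+ε)ᵀ (I + ε ∂f(z(t), t, θ)/∂z + O(ε²))

da(t)/dt = lim_{ε→0⁺} (a(t+ε) − a(t))/ε = −a(t)ᵀ ∂f(z(t), t, θ)/∂z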

Notice that by taking the limit ε → 0⁺, we can use the definition of the derivative to obtain another ODE, da(t)/dt = −a(t)ᵀ ∂f/∂z, which describes how the error is backpropagated w.r.t. z(t). This is the adjoint differential equation.

Important implementation detail

The above equation describes the “error dynamics” between two points on the curve. To account for all pairs of observations, we need to adjust the error by computing the total change (notice why there’s 0⁺ in the limit?). If this is not obvious to you (it was not to me), compare with the computational graph in the discrete case.

Figure 5: Notice the corresponding path for backward propagation. The discrete jump is represented by adding the partial derivative of the loss directly at that point.

Using this method, we can compute the partial derivative of the loss w.r.t. every point on the curve.

Now we need the gradient w.r.t. the parameters

This puzzled me for a long time, since every point of the curve still contributes to the derivative w.r.t. the parameters.

Figure 4.

My first intuition was to compute the total derivative directly by summing ∂L/∂z(t) · ∂z(t)/∂θ over every point of the curve,

which is hard to compute without the explicit form of z(t).

The original paper proposes a brilliant method to tackle the problem. Instead of treating the parameters as constants external to the dynamics/ODE, we treat them as part of the dynamics! We construct an augmented ODE:

d/dt [z(t), θ(t), t(t)] = f_aug([z, θ, t]) = [f(z(t), t, θ), 0, 1]
the augmented dynamics from the original paper

which treat θ and t themselves as states, with the trivial dynamics dθ(t)/dt = 0 and dt(t)/dt = 1.

I know it is confusing; I will explain more about it in a later section.

Now, following the above proof again, replacing z with the augmented state z_aug = [z, θ, t] and the adjoint with a_aug = [a, a_θ, a_t], where a_θ(t) = ∂L/∂θ(t) and a_t(t) = ∂L/∂t(t), we have

da_aug(t)/dt = −a_aug(t)ᵀ ∂f_aug/∂[z, θ, t]

Parts of the Jacobian are known, since we know the “dynamics” of θ and t:

∂f_aug/∂[z, θ, t] = [[∂f/∂z, ∂f/∂θ, ∂f/∂t], [0, 0, 0], [0, 0, 0]]

notice the augmented dynamics for θ and t are constant, hence the corresponding partial derivatives are 0

Now we can obtain the adjoint differential equations for all variables of the differential equation:

da(t)/dt = −a(t)ᵀ ∂f/∂z,  da_θ(t)/dt = −a(t)ᵀ ∂f/∂θ,  da_t(t)/dt = −a(t)ᵀ ∂f/∂t
the final “error” dynamics

Now all we need are the initial values, and we can use any ODE solver to solve these equations backwards in time.

So what do these equations mean?

I found the partial derivatives above, especially ∂L/∂θ(t), fuzzy to understand.

I spent some time digesting their meaning, and I found the following way easier for understanding and picturing the equations.

Figure 6. Backward propagation in the discrete case

Unlike figure 4, in which every z(t) depends on θ, here we treat the parameters as a constant “function” of time.

Following the same argument as before, we get the same adjoint state for the parameters.

Now observe from figure 6 (blue arrows) that

a_θ(t) = ∂L/∂θ(t)

means the total derivative accumulated back to time t. When we backpropagate through time, we accumulate the “total error” from the final time step back to the current time t. If we go back all the way to the starting point, we obtain the total derivative.

In the continuous case, we therefore only need to find

a_θ(t_0), the value at the first time step of the ODEs,

and we obtain the total partial derivative of the loss function w.r.t. the parameters.

How do we find the initial value?

To find the initial values for the adjoint differential equations, we can use the chain rule, starting from the loss function:

a(t_1) = ∂L/∂z(t_1),  a_θ(t_1) = 0,  a_t(t_1) = a(t_1)ᵀ f(z(t_1), t_1, θ)

I cannot find a very satisfying justification for setting the initial value of the parameters’ adjoint to zero; nevertheless, here is how I picture it.

  1. θ(t_1) can only contribute an infinitesimal amount to the loss, so we can set its adjoint to zero.
  2. By solving the initial value problem, the initial value contributes only a constant shift to the final output, so we do not need to account for it in the gradient.

If anyone has a better explanation, please let me know!

Backward Propagation summary:

the adjoint differential equations: da(t)/dt = −a(t)ᵀ ∂f/∂z, da_θ(t)/dt = −a(t)ᵀ ∂f/∂θ, da_t(t)/dt = −a(t)ᵀ ∂f/∂t
the initial values: a(t_1) = ∂L/∂z(t_1), a_θ(t_1) = 0, a_t(t_1) = a(t_1)ᵀ f(z(t_1), t_1, θ)
solve the IVP backwards from t_1 to t_0 to obtain a(t_0) and a_θ(t_0) = ∂L/∂θ

By solving the adjoint differential equation using any ODE solver, we can perform continuous backward propagation.
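A quick way to validate such an implementation (my suggestion, not from the post’s code) is to compare the adjoint-computed gradient against a central finite-difference estimate on a tiny problem; loss_fn here stands for any scalar-valued function of the parameters:

import torch

def finite_diff_grad(loss_fn, theta, eps=1e-4):
    """Central finite-difference gradient of a scalar loss w.r.t. theta."""
    grad = torch.zeros_like(theta)
    for i in range(theta.numel()):
        e = torch.zeros_like(theta).view(-1)
        e[i] = eps
        e = e.view_as(theta)
        # (L(theta + e) - L(theta - e)) / (2 * eps)
        grad.view(-1)[i] = (loss_fn(theta + e) - loss_fn(theta - e)) / (2 * eps)
    return grad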

Important methods in PyTorch

I am not going into every detail of the code; interested readers can find it here. There are a few important parts, which relate to the theory described above, that I will cover.

First, how to compute the adjoint differential equation at time t.

Since the adjoint differential equation requires partial derivatives of the ODE function, the easiest way is to use autograd to find them.

the final “error” dynamics

out = self.forward(z, t)
# direction for autograd: the adjoint state a(t)
a = grad_outputs

# one call gives the vector-Jacobian products a^T df/dz, a^T df/dt, a^T df/dp
adfdz, adfdt, *adfdp = torch.autograd.grad(
    (out,),
    (z, t) + tuple(self.parameters()),
    grad_outputs=(a,),
    allow_unused=True,
    retain_graph=True,
)

Notice that torch.autograd.grad() has a grad_outputs option to choose the direction of the derivative. If we use the adjoint state as the direction, PyTorch automatically gives us exactly the vector-Jacobian products that the augmented dynamics need.
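Here is a toy example (my own, not from the original code) showing what grad_outputs does: instead of the full Jacobian, autograd returns the vector-Jacobian product aᵀ ∂out/∂z.

import torch

z = torch.randn(3, requires_grad=True)
out = z ** 2                       # stand-in for f(z, t); its Jacobian is diag(2z)
a = torch.tensor([1.0, 0.0, 2.0])  # stand-in for the adjoint state

# vector-Jacobian product: a^T (d out / d z), here elementwise a * 2z
(vjp,) = torch.autograd.grad(out, z, grad_outputs=a)
print(vjp)  # equals 2 * z * a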

Custom PyTorch function to perform forward and backward propagation

For autograd to work its magic by just calling

loss.backward()

we need a way to automate the forward pass and backward pass irrespective of the ODE solver.

We can define a function with custom forward and backward passes using

torch.autograd.Function

which takes in the Neural ODE model

func

In the forward pass, this simply means solving the Neural ODE initial value problem.

@staticmethod
def forward(ctx, z0, t, flat_parameters, func):
    bs, *z_shape = z0.size()  # bs = batch size
    time_len = t.size(0)

    with torch.no_grad():
        z = torch.zeros(time_len, bs, *z_shape).to(z0)
        z[0] = z0
        # solve segment by segment to record the whole trajectory
        for i_t in range(time_len - 1):
            z0 = ode_solve(z0, t[i_t], t[i_t + 1], func)
            z[i_t + 1] = z0

    ctx.func = func
    ctx.save_for_backward(t, z.clone(), flat_parameters)

    return z

We can see there is a for loop that generates the trajectory using ode_solve with the Neural ODE (func). ctx is an object that saves useful information for the backward pass.

Now for the custom backward function:

@staticmethod
def backward(ctx, dLdz):
    func = ctx.func
    t, z, flat_parameters = ctx.saved_tensors
    time_len, bs, *z_shape = z.size()
    n_dim = np.prod(z_shape)
    n_params = flat_parameters.size(0)

    def augmented_dynamics(aug_z_i, t_i):
        ...
        return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

    dLdz = dLdz.view(time_len, bs, n_dim)

    with torch.no_grad():
        adj_z = torch.zeros(bs, n_dim).to(dLdz)
        adj_p = torch.zeros(bs, n_params).to(dLdz)
        adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

        # walk backwards over the observation times
        for i_t in range(time_len - 1, 0, -1):
            z_i = z[i_t]
            t_i = t[i_t]
            f_i = func(z_i, t_i).view(bs, n_dim)

            # direct gradients from the loss at this observation
            dLdz_i = dLdz[i_t]
            a_t = torch.transpose(dLdz_i.unsqueeze(-1), 1, 2)
            dLdt_i = -torch.bmm(a_t, f_i.unsqueeze(-1))[:, 0]

            # adjust the adjoints with the direct gradients
            adj_z += dLdz_i
            adj_t[i_t] = adj_t[i_t] + dLdt_i

            # pack the augmented state and solve the augmented ODE backwards
            aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)
            aug_ans = ode_solve(aug_z, t_i, t[i_t - 1], augmented_dynamics)

            # unpack the solved adjoints
            adj_z[:] = aug_ans[:, n_dim:2 * n_dim]
            adj_p[:] += aug_ans[:, 2 * n_dim:2 * n_dim + n_params]
            adj_t[i_t - 1] = aug_ans[:, 2 * n_dim + n_params:]

            del aug_z, aug_ans

        ## Adjust 0 time adjoint with direct gradients
        # Compute direct gradients
        dLdz_0 = dLdz[0]
        dLdt_0 = -torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

        # Adjust adjoints
        adj_z += dLdz_0
        adj_t[0] = adj_t[0] + dLdt_0

    # forward signature: (z0, t, parameters, func)
    return adj_z.view(bs, *z_shape), adj_t, adj_p, None
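To wire everything together, we call the Function’s .apply in the model’s forward pass. A minimal sketch (the subclass name ODEAdjoint follows Surtsukov’s code; the variable names are illustrative):

# hypothetical wiring, assuming the torch.autograd.Function subclass
# above is named ODEAdjoint
z = ODEAdjoint.apply(z0, t, flat_parameters, func)  # forward: solve the IVP
loss = F.mse_loss(z, labels.detach())
loss.backward()  # routes through our custom backward (the adjoint method)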

I strongly advise interested readers to go through the code themselves, since the dimensions of the tensors are not shown in the code, and it is essential to get those details right for it to work. Things I want to highlight in the code:

  1. Remember to adjust the gradient at each data point: the gradient receives contributions from both the path and the loss at that point; refer to the figures above.
  2. The augmented dynamics and the adjoint state computation follow the algorithm from Appendix C of the paper.

Training Example on Toy Model

Figure 7. Training the network on a Spiral ODE

Conclusion

I started this exercise because I wanted to understand the details of Neural ODEs, and the best way to do that is to code them up! This post focuses on the derivation and implementation details and leaves out a lot from the original paper: for example, using a Neural ODE as a residual block for MNIST classification, generative models, and various experimental results. I strongly advise checking the original paper and the authors’ presentation for more details.

References

Mikhail Surtsukov, Neural ODEs (excellent blog post)

Chen, Rubanova, Bettencourt, Duvenaud. Neural Ordinary Differential Equations. NeurIPS 2018 (the original paper)
