Controlling Choices of Adjoints
DiffEqFlux is capable of training neural networks embedded inside of differential equations with many different techniques. For all of the details, see the DifferentialEquations.jl local sensitivity analysis documentation. Here we will summarize these methodologies in the context of neural differential equations and scientific machine learning.
Choosing a sensealg in a Nutshell
By default, a stable adjoint with an auto-adapting vjp choice is used. In many cases, tuning this choice lets the gradient be computed more than an order of magnitude faster than the default. However, given the vast space of options, use the following decision tree to help guide the choice (a worked sketch follows the list):
- If you have 100 parameters or less, consider using forward-mode sensitivities. If the `f` function is not ForwardDiff-compatible, use `ForwardSensitivity`; otherwise use `ForwardDiffSensitivity`, as it is more efficient.
- For larger equations, give `BacksolveAdjoint` and `InterpolatingAdjoint` a try. If the gradient of `BacksolveAdjoint` is correct, it is often the faster choice, so choose it (but it's not always faster!). If your equation is stiff or a DAE, skip this step, as `BacksolveAdjoint` is almost certainly unstable.
- If your equation does not use much memory and you're using a stiff solver, consider using `QuadratureAdjoint`, as it is asymptotically more computationally efficient at the price of a higher memory cost.
- If the other methods are all unstable (check the gradients against each other!), then `ReverseDiffAdjoint` is a good fallback on CPU, while `TrackerAdjoint` is a good fallback on GPUs.
- After choosing a general sensealg, if the choice is `InterpolatingAdjoint`, `QuadratureAdjoint`, or `BacksolveAdjoint`, then optimize the choice of vjp calculation next:
  - If your function has no branching (no if statements), use `ReverseDiffVJP(true)`.
  - If you're on the CPU and your function is very scalarized in operations but has branches, choose `ReverseDiffVJP()`.
  - If you're on the CPU or GPU and your function is very vectorized, choose `ZygoteVJP()`.
  - Otherwise fall back to `TrackerVJP()` if Zygote does not support the function.
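As a concrete starting point, the following minimal sketch compares two adjoint choices against each other, as suggested above. The Lotka-Volterra problem, loss, and `saveat` grid are purely illustrative choices (not recommendations), and the sensitivity types are assumed to be available from DiffEqSensitivity (they are also re-exported by DiffEqFlux):

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote

# A small illustrative test problem (Lotka-Volterra).
function lv!(du, u, p, t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end
u0 = [1.0, 1.0]
p  = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lv!, u0, (0.0, 10.0), p)

# Loss parameterized by the sensealg so different adjoints can be compared.
loss(p, sensealg) = sum(abs2, Array(solve(prob, Tsit5(), p = p,
                                          saveat = 0.1, sensealg = sensealg)))

g_back   = Zygote.gradient(p -> loss(p, BacksolveAdjoint()), p)[1]
g_interp = Zygote.gradient(p -> loss(p, InterpolatingAdjoint()), p)[1]

# A large discrepancy here suggests BacksolveAdjoint is unstable on this problem.
maximum(abs, g_back .- g_interp)
```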
Additional Details
A sensitivity analysis method can be passed to a solver via the `sensealg` keyword argument. For example:
solve(prob,Tsit5(),sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()))
sets the adjoint sensitivity analysis so that, when this call is encountered in the gradient calculation of any of the Julia reverse-mode AD frameworks, the differentiation will be replaced with the `BacksolveAdjoint` method, where internal vector-Jacobian products are performed using Zygote.jl. From the DifferentialEquations.jl local sensitivity analysis page, we note that the following choices for `sensealg` exist (a usage sketch follows the list):

- `BacksolveAdjoint`
- `InterpolatingAdjoint` (with checkpoints)
- `QuadratureAdjoint`
- `TrackerAdjoint`
- `ReverseDiffAdjoint` (currently requires `using DistributionsAD`)
- `ZygoteAdjoint` (currently limited to special solvers)
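For instance, reusing `prob` and `p` from the sketch above, the `sensealg` given to `solve` is what takes effect once that call appears inside a loss differentiated by a reverse-mode framework (the loss here is arbitrary):

```julia
loss(p) = sum(abs2, Array(solve(prob, Tsit5(), p = p, saveat = 0.1,
                                sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))))

# Zygote hits the solve call and hands differentiation off to BacksolveAdjoint,
# which uses Zygote.jl internally only for the vector-Jacobian products.
dp = Zygote.gradient(loss, p)[1]
```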
Additionally, there are methodologies for forward sensitivity analysis:

- `ForwardSensitivity`
- `ForwardDiffSensitivity`
These methods have very low overhead compared to adjoint methods but scale poorly with the number of parameters. Our benchmarks demonstrate a cutoff of around 100 parameters: below that, these techniques are more efficient than the adjoint methods, while above it (as with neural ODEs) they are less efficient.
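For a model with only a handful of parameters, such as the four-parameter problem in the earlier sketch, a forward-mode setup might look like the following (the 100-parameter cutoff is a rule of thumb from the benchmarks, not a hard limit):

```julia
# Four parameters, so forward-mode is typically the efficient choice here.
loss_fwd(p) = sum(abs2, Array(solve(prob, Tsit5(), p = p, saveat = 0.1,
                                    sensealg = ForwardDiffSensitivity())))
dp_fwd = Zygote.gradient(loss_fwd, p)[1]

# If the right-hand side were not ForwardDiff-compatible, ForwardSensitivity()
# would be the alternative.
```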
Choices of Vector-Jacobian Products (autojacvec)
With each of these solvers, `autojacvec` can be used to choose how the internal vector-Jacobian products of the `f` function are computed. The choices are (see the configuration sketch after this list):

- `ReverseDiffVJP(compile::Bool)`: Usually the fastest when scalarized operations exist in the `f` function (as in scientific machine learning applications like universal differential equations). The compiled form (`ReverseDiffVJP(true)`) is the absolute fastest, but requires that the `f` function of the ODE/DAE/SDE/DDE has no branching. Does not support GPUs.
- `TrackerVJP`: Not as efficient as `ReverseDiffVJP`, but supports GPUs.
- `ZygoteVJP`: Tends to be the fastest VJP method if the ODE/DAE/SDE/DDE is written with mostly vectorized functions (like neural networks and other layers from Flux.jl). Bear in mind that Zygote does not allow mutation, which makes the solve more memory expensive and therefore slower.
- `nothing`: Default choice, selected based on the characteristics of the types in your model.
- `true`: Forward-mode AD Jacobian-vector products. Should only be used on sufficiently small equations.
- `false`: Numerical Jacobian-vector products. Should only be used if the `f` function is not differentiable (e.g. a Fortran code).
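These choices are passed to the adjoint's constructor. A brief configuration sketch follows; which adjoint to combine each VJP with depends on the decision tree above, so the pairings here are only examples:

```julia
# Compiled ReverseDiff VJPs: fastest for scalarized, branch-free f (CPU only).
QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))

# Zygote VJPs: typically fastest for vectorized / neural-network f, GPU-capable.
InterpolatingAdjoint(autojacvec = ZygoteVJP())

# Tracker VJPs: a fallback when Zygote does not support the function.
InterpolatingAdjoint(autojacvec = TrackerVJP())

# Forward-mode J*v (small equations only) and numerical J*v (non-differentiable f):
InterpolatingAdjoint(autojacvec = true)
InterpolatingAdjoint(autojacvec = false)
```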
As other vector-Jacobian product systems become available in Julia they will be added to this system so that no user code changes are required to interface with these methodologies.
Manual VJPs
Note that, when defining your differential equation, the vjp can be manually overwritten by providing a `vjp(u,p,t)` that returns a tuple `f(u,p,t), v->J*v` in the form of ChainRules.jl. When this is done, the choice of `ZygoteVJP` will utilize your VJP function during the internal steps of the adjoint. This is useful for models where automatic differentiation may have trouble producing optimal code. This can be paired with ModelingToolkit.jl for producing hyper-optimized, sparse, and parallel VJP functions utilizing the automated symbolic conversions.
Optimize-then-Discretize
The original neural ODE paper popularized optimize-then-discretize with O(1) adjoints via backsolve; this is the methodology of `BacksolveAdjoint`. When training non-stiff neural ODEs, `BacksolveAdjoint` with `ZygoteVJP` is generally the fastest method. Additionally, this method does not require storing the values of any intermediate points and is thus the most memory efficient. However, `BacksolveAdjoint` is prone to instabilities whenever the Lipschitz constant is sufficiently large, like in stiff equations, PDE discretizations, and many other contexts, so it is not used by default. When training a neural ODE for machine learning applications, the user should try `BacksolveAdjoint` and see if it is sufficiently accurate on their problem.
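As a sketch of what this looks like for a small non-stiff neural ODE (the network size, time span, `saveat` grid, and loss below are arbitrary illustrative choices):

```julia
using DiffEqFlux, OrdinaryDiffEq, Zygote

dudt = FastChain(FastDense(2, 16, tanh), FastDense(16, 2))
p = initial_params(dudt)
node = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(), saveat = 0.1f0,
                 sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))

u0 = Float32[2.0, 0.0]
loss_node(p) = sum(abs2, Array(node(u0, p)))
dp = Zygote.gradient(loss_node, p)[1]

# If this gradient disagrees with, e.g., InterpolatingAdjoint on the same loss,
# fall back to the more stable adjoint.
```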
Note that DiffEqFlux's implementation of `BacksolveAdjoint` includes an extra feature, `BacksolveAdjoint(checkpointing=true)`, which mixes checkpointing with `BacksolveAdjoint`. With this method, values from the forward pass are saved at the `saveat` points. Since the reverse solve should numerically be the same as the forward pass, issues with divergence of the reverse pass are mitigated by restarting the reverse pass at the `saveat` values from the forward pass. This reduces the divergence and can lead to better gradients, at the cost of higher memory usage from saving some values of the forward pass. This can stabilize the adjoint in some applications, but for highly stiff applications the divergence can be too fast for this to work in practice.
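Enabling this is a one-keyword change; a sketch using the earlier problem (the `saveat` grid determines where the checkpoints are taken):

```julia
loss_ck(p) = sum(abs2, Array(solve(prob, Tsit5(), p = p, saveat = 0.1,
    sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(), checkpointing = true))))
dp_ck = Zygote.gradient(loss_ck, p)[1]
```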
To avoid the issues of backwards solving the ODE, `InterpolatingAdjoint` and `QuadratureAdjoint` utilize information from the forward pass. By default, these methods utilize the continuous solution provided by DifferentialEquations.jl in the calculations of the adjoint pass. `QuadratureAdjoint` uses this to build a continuous function for the solution of the adjoint equation and then performs an adaptive quadrature via Quadrature.jl, while `InterpolatingAdjoint` appends the integrand to the ODE so that it is computed simultaneously with the Lagrange multiplier. When memory is not an issue, we find that the `QuadratureAdjoint` approach tends to be the most efficient, as it has a significantly smaller adjoint differential equation and the quadrature converges very quickly; however, this form requires holding the full continuous solution of the adjoint, which can be a significant burden for problems with many parameters. The `InterpolatingAdjoint` is thus a compromise between memory efficiency and compute efficiency, in the same spirit as CVODES.
However, if the memory cost of the `InterpolatingAdjoint` is too high, checkpointing can be used via `InterpolatingAdjoint(checkpointing=true)`. When this is used, the checkpoints default to `sol.t` of the forward pass (i.e. the saved time points, usually set by `saveat`). Then, in the adjoint, the intervals from `sol.t[i-1]` to `sol.t[i]` are re-solved in order to obtain a short interpolation which can be utilized in the adjoint. This results in at most two full solves of the forward pass, but dramatically reduces the computational cost while being a low-memory format. This is the preferred method for highly stiff equations when memory is an issue, i.e. stiff PDEs or large neural DAEs.
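A sketch of configuring these two adjoints on the earlier problem with a stiff solver (the solver and VJP pairings are illustrative; Lotka-Volterra itself is not stiff):

```julia
# Memory is available and the solver is stiff: quadrature-based adjoint.
quad = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))

# Memory is tight (e.g. a stiff PDE discretization): checkpointed interpolation,
# with checkpoints defaulting to the saved time points of the forward pass.
interp_ck = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true),
                                 checkpointing = true)

loss_stiff(p, sensealg) = sum(abs2, Array(solve(prob, Rosenbrock23(), p = p,
                                                saveat = 0.1, sensealg = sensealg)))
Zygote.gradient(p -> loss_stiff(p, quad), p)[1]
Zygote.gradient(p -> loss_stiff(p, interp_ck), p)[1]
```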
For forward-mode, `ForwardSensitivity` is the version that performs the optimize-then-discretize approach. In this case, `autojacvec` corresponds to the method for computing `J*v` within the forward sensitivity equations, and is either `true` or `false` for Jacobian-free forward-mode AD (via ForwardDiff.jl) or Jacobian-free numerical differentiation, respectively.
Discretize-then-Optimize
In this approach, the discretization is done first and then optimization is performed on the discretized system. While traditionally this is done with discrete sensitivity analysis, it can equivalently be done by automatic differentiation of the solver itself. `ReverseDiffAdjoint` performs reverse-mode automatic differentiation of the solver via ReverseDiff.jl, `ZygoteAdjoint` performs reverse-mode automatic differentiation of the solver via Zygote.jl, and `TrackerAdjoint` performs reverse-mode automatic differentiation of the solver via Tracker.jl. In addition, `ForwardDiffSensitivity` performs forward-mode automatic differentiation of the solver via ForwardDiff.jl.
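A sketch of requesting direct AD of the solver on the earlier problem (per the list above, `ReverseDiffAdjoint` may currently require `using DistributionsAD`):

```julia
# Reverse-mode AD of the solver's steps via ReverseDiff.jl (a CPU fallback):
loss_rd(p) = sum(abs2, Array(solve(prob, Tsit5(), p = p, saveat = 0.1,
                                   sensealg = ReverseDiffAdjoint())))
Zygote.gradient(loss_rd, p)[1]

# Tracker.jl-based version, a reasonable fallback on GPUs:
loss_tr(p) = sum(abs2, Array(solve(prob, Tsit5(), p = p, saveat = 0.1,
                                   sensealg = TrackerAdjoint())))
Zygote.gradient(loss_tr, p)[1]
```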
We note that many studies have suggested that this approach produces more accurate gradients than the optimize-then-discretize approach.
Special Notes on Equation Types
While all of the choices are compatible with ordinary differential equations, specific notices apply to other forms:
Differential-Algebraic Equations
We note that, while all three continuous adjoint methods are compatible with index-1 DAEs via the derivation in the universal differential equations paper (note the reinitialization), we do not recommend `BacksolveAdjoint` on DAEs, because the stiffness inherent in these problems tends to cause major difficulties with the accuracy of the backwards solution due to reinitialization of the algebraic variables.
Stochastic Differential Equations
We note that all of the adjoints except `QuadratureAdjoint` are applicable to stochastic differential equations.
Delay Differential Equations
We note that only the discretize-then-optimize methods are applicable to delay differential equations. Constant lag and variable lag delay differential equation parameters can be estimated, but the lag times themselves are unable to be estimated through these automatic differentiation techniques.