Controlling Choices of Adjoints

DiffEqFlux is capable of training neural networks embedded inside of differential equations with many different techniques. For all of the details, see the DifferentialEquations.jl local sensitivity analysis documentation. Here we will summarize these methodologies in the context of neural differential equations and scientific machine learning.

Choosing a sensealg in a Nutshell

By default, a stable adjoint with an auto-adapting vjp choice is used. In many cases, a user can optimize the choice to compute more than an order of magnitude faster than the default. However, given the vast space to explore, use the following decision tree to help guide the choice:

  • If you have 100 parameters or less, consider using forward-mode sensititivites. If the f function is not ForwardDiff-compatible, use ForwardSensitivty, otherwise use ForwardDiffSensitivty as its more efficient.
  • For larger equations, give BacksolveAdjoint and InterpolatingAdjoint a try. If the gradient of BacksolveAdjoint is correct, many times it's the faster choice so choose that (but it's not always faster!). If your equation is stiff or a DAE, skip this step as BacksolveAdjoint is almost certainly unstable.
  • If your equation does not use much memory and you're using a stiff solver, consider using QuadratureAdjoint as it is asymtopically more computationally efficient by trading off memory cost.
  • If the other methods are all unstable (check the gradients against each other!), then ReverseDiffAdjoint is a good fallback on CPU, while TrackerAdjoint is a good fallback on GPUs.
  • After choosing a general sensealg, if the choice is InterpolatingAdjoint, QuadratureAdjoint, or BacksolveAdjoint, then optimize the choice of vjp calculation next:
    • If your function has no branching (no if statements), use ReverseDiffVJP(true).
    • If you're on the CPU and your function is very scalarized in operations but has branches, choose ReverseDiffVJP().
    • If your on the CPU or GPU and your function is very vectorized, choose ZygoteVJP().
    • Else fallback to TrackerVJP() if Zygote does not support the function.

Additional Details

A sensitivity analysis method can be passed to a solver via the sensealg keyword argument. For example:


sets the adjoint sensitivity analysis so that, when this call is encountered in the gradient calculation of any of the Julia reverse-mode AD frameworks, the differentiation will be replaced with the BacksolveAdjoint method where internal vector-Jacobian products are performed using Zygote.jl. From the DifferentialEquations.jl local sensitivity analysis page, we note that the following choices for sensealg exist:

  • BacksolveAdjoint
  • InterpolatingAdjoint (with checkpoints)
  • QuadratureAdjoint
  • TrackerAdjoint
  • ReverseDiffAdjoint (currently requires using DistributionsAD)
  • ZygoteAdjoint (currently limited to special solvers)

Additionally, there are methodologies for forward sensitivity analysis:

  • ForwardSensitivty
  • ForwardDiffSensitivty

These methods have very low overhead compared to adjoint methods but have poor scaling with respect to increased numbers of parameters. Our benchmarks demonstrate a cutoff of around 100 parameters, where for models with less than 100 parameters these techniques are more efficient, but when there are more than 100 parameters (like in neural ODEs) these methods are less efficient than the adjoint methods.

Choices of Vector-Jacobian Products (autojacvec)

With each of these solvers, autojacvec can be utilized to choose how the internal vector-Jacobian products of the f function are computed. The choices are:

  • ReverseDiffVJP(compile::Bool)
  • TrackerVJP
  • ZygoteVJP
  • nothing (a default choice given characteristics of the types in your model)
  • true (Forward-mode AD Jacobian-vector products)
  • false (Numerical Jacobian-vector products)

As a quick summary of the VJP choices, ReverseDiffVJP is usually the fastest when scalarized operations exist in the f function (like in scientific machine learning applications like Universal Differential Equations) and the boolean compile (i.e. ReverseDiffVJP(true)) is the absolute fastest but requires that the f function of the ODE/DAE/SDE/DDE has no branching. However, if the ODE/DAE/SDE/DDE is written with mostly vectorized functions (like neural networks and other layers from Flux.jl), then ZygoteVJP tends to be the fastest VJP method. Note that ReverseDiffVJP does not support GPUs, while TrackerAdjoint is not as efficient but supports GPUs. As other vector-Jacobian product systems become available in Julia they will be added to this system so that no user code changes are required to interface with these methodologies. Forward-mode AD should only be used on sufficiently small equations, and numerical autojacvecs should only be used if the f function is not differentiable (i.e. is a Fortran code).

Manual VJPs

Note that when defining your differential equation the vjp can be manually overwritten by providing a vjp(u,p,t) that returns a tuple f(u,p,t),v->J*v in the form of ChainRules.jl. When this is done, the choice of ZygoteVJP will utilize your VJP function during the internal steps of the adjoint. This is useful for models where automatic differentiation may have trouble producing optimal code. This can be paired with ModelingToolkit.jl for producing hyper-optimized, sparse, and parallel VJP functions utilizing the automated symbolic conversions.


The original neural ODE paper popularized optimize-then-discretize with O(1) adjoints via backsolve. This is the methodology BacksolveAdjoint When training non-stiff neural ODEs, BacksolveAdjoint with ZygoteVJP is generally the fastest method. Additionally, this method does not require storing the values of any intermediate points and is thus the most memory efficient. However, BacksolveAdjoint is prone to instabilities whenever the Lipschitz constant is sufficiently large, like in stiff equations, PDE discretizations, and many other contexts, so it is not used by default. When training a neural ODE for machine learning applications, the user should try BacksolveAdjoint and see if it is sufficiently accurate on their problem.

Note that DiffEqFlux's implementation of BacksolveAdjoint includes an extra feature BacksolveAdjoint(checkpointing=true) which mixes checkpointing with BacksolveAdjoint. What this method does is that, at saveat points, values from the forward pass are saved. Since the reverse solve should numerically be the same as the forward pass, issues with divergence of the reverse pass are mitigated by restarting the reverse pass at the saveat value from the forward pass. This reduces the divergence and can lead to better gradients at the cost of higher memory usage due to having to save some values of the forward pass. This can stabilize the adjoint in some applications, but for highly stiff applications the divergence can be too fast for this to work in practice.

To avoid the issues of backwards solving the ODE, InterpolatingAdjoint and QuadratureAdjoint utilize information from the forward pass. By default these methods utilize the continuous solution provided by DifferentialEquations.jl in the calculations of the adjoint pass. QuadratureAdjoint uses this to build a continuous function for the solution of adjoint equation and then performs an adaptive quadrature via Quadrature.jl, while InterpolatingAdjoint appends the integrand to the ODE so it's computed simultaneously to the Lagrange multiplier. When memory is not an issue, we find that the QuadratureAdjoint approach tends to be the most efficient as it has a significantly smaller adjoint differential equation and the quadrature converges very fast, but this form requires holding the full continuous solution of the adjoint which can be a significant burden for large parameter problems. The InterpolatingAdjoint is thus a compromise between memory efficiency and compute efficiency, and is in the same spirit as CVODES.

However, if the memory cost of the InterpolatingAdjoint is too high, checkpointing can be used via InterpolatingAdjoint(checkpointing=true). When this is used, the checkpoints default to sol.t of the forward pass (i.e. the saved timepoints usually set by saveat). Then in the adjoint, intervals of sol.t[i-1] to sol.t[i] are re-solved in order to obtain a short interpolation which can be utilized in the adjoints. This at most results in two full solves of the forward pass, but dramatically reduces the computational cost while being a low-memory format. This is the preferred method for highly stiff equations when memory is an issue, i.e. stiff PDEs or large neural DAEs.

For forward-mode, the ForwardSensitivty is the version that performs the optimize-then-discretize approach. In this case, autojacvec corresponds to the method for computing J*v within the forward sensitivity equations, which is either true or false for whether to use Jacobian-free forward-mode AD (via ForwardDiff.jl) or Jacobian-free numerical differentiation.


In this approach the discretization is done first and then optimization is done on the discretized system. While traditionally this can be done discrete sensitivity analysis, this is can be equivalently done by automatic differentiation on the solver itself. ReverseDiffAdjoint performs reverse-mode automatic differentiation on the solver via ReverseDiff.jl, ZygoteAdjoint performs reverse-mode automatic differentiation on the solver via Zygote.jl, and TrackerAdjoint performs reverse-mode automatic differentiation on the solver via Tracker.jl. In addition, ForwardDiffSensitivty performs forward-mode automatic differentiation on the solver via ForwardDiff.jl.

We note that many studies have suggested that this approach produces more accurate gradients than the optimize-than-discretize approach

Special Notes on Equation Types

While all of the choices are compatible with ordinary differential equations, specific notices apply to other forms:

Differential-Algebraic Equations

We note that while all 3 are compatible with index-1 DAEs via the derivation in the universal differential equations paper (note the reinitialization), we do not recommend BacksolveAdjoint one DAEs because the stiffness inherent in these problems tends to cause major difficulties with the accuracy of the backwards solution due to reinitialization of the algebraic variables.

Stochastic Differential Equations

We note that all of the adjoints except QuadratureAdjoint are applicable to stochastic differential equations.

Delay Differential Equations

We note that only the discretize-then-optimize methods are applicable to delay differential equations. Constant lag and variable lag delay differential equation parameters can be estimated, but the lag times themselves are unable to be estimated through these automatic differentiation techniques.