# Neural Stochastic Differential Equations With Method of Moments

With neural stochastic differential equations, there is once again a helper form neural_dmsde which can be used for the multiplicative noise case (consult the layers API documentation, or this full example using the layer function).

However, since there are far too many possible combinations for the API to support, in many cases you will want to performantly define neural differential equations for non-ODE systems from scratch. For these systems, it is generally best to use TrackerAdjoint with non-mutating (out-of-place) forms. For example, the following defines a neural SDE with neural networks for both the drift and diffusion terms:

dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)
prob = SDEProblem(dudt, g, x, tspan, nothing)

where model and model2 are different neural networks. The same approach applies to a neural delay differential equation, whose out-of-place formulation is f(u,h,p,t). For example, to define a neural delay differential equation that uses the history value p.tau in the past, we can write:

# Out-of-place form, so no trailing `!`; the name matches the one passed to DDEProblem
dudt_(u, h, p, t) = model([u; h(t - p.tau)])
prob = DDEProblem(dudt_, u0, h, tspan, nothing)
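To make the roles of the out-of-place drift and diffusion functions concrete, here is a minimal sketch of an Euler-Maruyama integrator for diagonal noise, written with only the standard library. It is not how the SDE solvers used in this tutorial work internally; the names `drift`, `diffusion`, and `euler_maruyama` are hypothetical, and the two functions are simple closed-form stand-ins for the neural networks `model` and `model2`.

```julia
using Random

# Hypothetical stand-ins for the neural networks `model` and `model2`.
A = Float32[-0.1 2.0; -2.0 -0.1]
drift(u, p, t) = A' * (u .^ 3)      # out-of-place drift f(u, p, t)
diffusion(u, p, t) = p .* u         # out-of-place diagonal diffusion g(u, p, t)

# One Euler-Maruyama trajectory:
#   u_{k+1} = u_k + f(u_k) * dt + g(u_k) .* sqrt(dt) .* ξ,   ξ ~ N(0, I)
function euler_maruyama(f, g, u0, tspan, p; dt = 1f-3, rng = Random.default_rng())
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    t = t0
    for _ in 1:nsteps
        ξ = randn(rng, eltype(u), length(u))
        u = u .+ f(u, p, t) .* dt .+ g(u, p, t) .* sqrt(dt) .* ξ
        t += dt
    end
    return u
end

u_end = euler_maruyama(drift, diffusion, Float32[2.0, 0.0], (0f0, 1f0),
                       Float32[0.2, 0.2]; rng = MersenneTwister(1))
println(u_end)
```

Because both functions return new arrays instead of mutating their inputs, the same pattern works unchanged when they are replaced by neural networks that an automatic differentiation tool has to trace.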

First let's build training data from the same example as the neural ODE:

using Plots, Statistics
using Flux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)  # 0.0f0:0.03448276f0:1.0f0
function trueSDEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    du .= mp.*u
end

prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)
# SDEProblem with uType Vector{Float32} and tType Float32. In-place: true
# timespan: (0.0f0, 1.0f0)
# u0: 2-element Vector{Float32}:
#  2.0
#  0.0

For our dataset we will use DifferentialEquations.jl's parallel ensemble interface to generate data from the average of 10,000 runs of the SDE:

# Take a typical sample from the mean
ensemble_prob = EnsembleProblem(prob_truesde)
ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)

sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps))
# (Float32[2.0 1.9204059 … 0.061854772 0.13894309; 0.0 0.5330523 … -0.6823729 -0.63669026],
#  Float32[0.0 0.07600243 … 1.3129256 1.287724; 0.0 0.057330403 … 1.006299 1.0356529])
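The method of moments relies on per-time-point means and variances taken across many trajectories. As a rough, self-contained illustration of that idea, the sketch below simulates a toy scalar SDE (geometric Brownian motion, chosen as an assumption because its exact mean is known) with a plain Euler-Maruyama loop and compares the ensemble mean with the analytic value. It does not use the ensemble interface; all names here are hypothetical.

```julia
using Random, Statistics

# Toy scalar SDE (geometric Brownian motion), an assumption for illustration:
#   dX = μ X dt + σ X dW,  with known mean E[X_t] = X0 * exp(μ t)
μ, σ, X0 = -0.5, 0.2, 2.0
dt, nsteps, ntraj = 0.01, 100, 10_000
rng = MersenneTwister(42)

# paths[k, j] holds trajectory j at time (k - 1) * dt
paths = Matrix{Float64}(undef, nsteps + 1, ntraj)
paths[1, :] .= X0
for k in 1:nsteps
    x = paths[k, :]
    paths[k + 1, :] .= x .+ μ .* x .* dt .+ σ .* x .* sqrt(dt) .* randn(rng, ntraj)
end

# Moments across the ensemble at every saved time point
ens_mean = vec(mean(paths, dims = 2))
ens_var  = vec(var(paths, dims = 2))

exact_mean_end = X0 * exp(μ * nsteps * dt)
println("ensemble mean at t = 1: ", ens_mean[end], "  (exact: ", exact_mean_end, ")")
```

The two vectors play the same role as `sde_data` and `sde_data_vars`: they summarize the distribution of the process at each time point without keeping any single trajectory.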

Now we build a neural SDE. For simplicity we will use the NeuralDSDE layer function, which defines a neural SDE with diagonal noise:

drift_dudt = Flux.Chain(x -> x.^3,
                        Flux.Dense(2, 50, tanh),
                        Flux.Dense(50, 2))
p1, re1 = Flux.destructure(drift_dudt)

diffusion_dudt = Flux.Chain(Flux.Dense(2, 2))
p2, re2 = Flux.destructure(diffusion_dudt)

neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                       saveat = tsteps, reltol = 1e-1, abstol = 1e-1)

Let's see what that looks like:

# Get the prediction using the correct initial condition
prediction0 = neuralsde(u0)

drift_(u, p, t) = re1(p[1:neuralsde.len])(u)
diffusion_(u, p, t) = re2(p[neuralsde.len+1:end])(u)

prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), neuralsde.p)

ensemble_nprob = EnsembleProblem(prob_neuralsde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
                      saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)

plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data', lw = 3)

scatter(tsteps, sde_data[1,:], label = "data")
scatter!(tsteps, prediction0[1,:], label = "prediction")

Now, just as with the neural ODE, we define a loss function. It computes the mean and variance from n runs at each time point and measures their squared distance from the data values:

function predict_neuralsde(p, u = u0)
    return Array(neuralsde(u, p))
end

function loss_neuralsde(p; n = 100)
    u = repeat(reshape(u0, :, 1), 1, n)
    samples = predict_neuralsde(p, u)
    means = mean(samples, dims = 2)
    vars = var(samples, dims = 2, mean = means)[:, 1, :]
    means = means[:, 1, :]
    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
    return loss, means, vars
end
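The array bookkeeping in the loss is easy to get wrong, so here is a small shape check using only the standard library. It assumes the prediction is laid out as (state, trajectory, time), which is what stacking n initial conditions as columns produces; `samples`, `data_mean`, and `data_var` are hypothetical random stand-ins for the prediction, `sde_data`, and `sde_data_vars`.

```julia
using Statistics, Random

rng = MersenneTwister(0)
nstate, n, ntime = 2, 100, 30

# Hypothetical stand-ins: `samples` plays the role of predict_neuralsde(p, u),
# laid out as (state, trajectory, time).
samples   = randn(rng, Float32, nstate, n, ntime)
data_mean = zeros(Float32, nstate, ntime)
data_var  = ones(Float32, nstate, ntime)

means3 = mean(samples, dims = 2)                          # size (2, 1, 30)
vars   = var(samples, dims = 2, mean = means3)[:, 1, :]   # size (2, 30)
means  = means3[:, 1, :]                                  # size (2, 30)

# Squared distance between predicted and target first and second moments
loss = sum(abs2, data_mean .- means) + sum(abs2, data_var .- vars)
println(size(means), " ", size(vars), " ", loss)
```

Reducing over the second dimension averages across trajectories, and indexing with `[:, 1, :]` drops the singleton dimension so the result lines up with the (state, time) layout of the data.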
list_plots = []
iter = 0

# Callback function to observe training
callback = function (p, loss, means, vars; doplot = false)
    global list_plots, iter

    if iter == 0
        list_plots = []
    end
    iter += 1

    # loss against current data
    display(loss)

    # plot current prediction against data
    plt = Plots.scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:],
                        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction")
    push!(list_plots, plt)

    if doplot
        display(plt)
    end
    return false
end

Now we train using this loss function. We can pre-train a little using a smaller n and then increase it once the model has had some time to adjust towards the right mean behavior:

opt = ADAM(0.025)

# Automatic differentiation backend used to compute gradients of the loss
adtype = Optimization.AutoZygote()

# First round of training with n = 10
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 10), adtype)
optprob = Optimization.OptimizationProblem(optf, neuralsde.p)
result1 = Optimization.solve(optprob, opt,
                             callback = callback, maxiters = 100)
# result1.u: 258-element Vector{Float32} of optimized parameters (values omitted)
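Evaluating the loss with few trajectories is cheap but noisy, which is the trade-off behind training in stages. The following sketch, a simplified illustration and not part of the training code, measures how the spread of a sample-mean estimate shrinks as n grows. The standard-normal draws are an assumed stand-in for one state of the SDE at one time point, and `estimator_spread` is a hypothetical helper.

```julia
using Random, Statistics

rng = MersenneTwister(7)

# Spread of the sample mean of n standard-normal draws, measured over many repeats.
# The draws are a stand-in (assumption) for one state of the SDE at one time point.
estimator_spread(n; repeats = 2_000) = std([mean(randn(rng, n)) for _ in 1:repeats])

spread_small = estimator_spread(10)    # theory: about 1/sqrt(10)
spread_large = estimator_spread(100)   # theory: about 1/sqrt(100)
println("n = 10: ", spread_small, "   n = 100: ", spread_large)
```

A tenfold increase in n reduces the noise in the estimated mean by only about a factor of sqrt(10), while the cost of each loss evaluation grows with n, so a coarse first pass followed by a finer one is a reasonable compromise.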

We resume the training with a larger n. (WARNING: this step takes a couple of orders of magnitude longer than the previous one.)

optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 100), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
                             callback = callback, maxiters = 100)
# result2.u: 258-element Vector{Float32} of optimized parameters (values omitted)

And now we plot the solution to an ensemble of the trained neural SDE:

_, means, vars = loss_neuralsde(result2.u, n = 1000)

plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
                     label = "data", title = "Neural SDE: After Training",
                     xlabel = "Time")
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")

plt = plot(plt1, plt2, layout = (2, 1))
savefig(plt, "NN_sde_combined.png")

Try this with GPUs as well!