Neural Ordinary Differential Equations with Flux.train!
The following is the same neural ODE example as before, but now using Flux.jl directly with `Flux.train!`. Notice that the only difference is that we have to define the neural network as a `Chain` and use Flux.jl's implicit parameter system, `Flux.params`.
using DiffEqFlux, DifferentialEquations, Plots
# Problem setup. Everything is kept in Float32.
u0 = Float32[2.0, 0.0]      # initial state (two components)
datasize = 30               # number of sample points used below for `length=`
tspan = (0.0f0, 1.5f0)      # (start, stop) of the integration interval
# In-place right-hand side of the reference ("true") dynamics.
#
# Writes `((u.^3)' * true_A)'` into `du` and returns `du`.
# `p` and `t` are part of the signature but are not read in the body.
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3                  # element-wise cube of the state
    du .= (cubed' * true_A)'        # row-vector * matrix, turned back into a column
    return du
end
# Sample times: `datasize` evenly spaced points from the start to the end of `tspan`.
t = range(first(tspan), last(tspan); length=datasize)
# Reference problem built on the known dynamics `trueODEfunc`.
prob = ODEProblem(trueODEfunc, u0, tspan)
# Solve with Tsit5, saving at the sample times, and convert the solution with `Array`.
ode_data = solve(prob, Tsit5(); saveat=t) |> Array
# Network used as the learned right-hand side: cube the input element-wise,
# then a 2 -> 50 dense layer with tanh followed by a 50 -> 2 dense layer.
dudt2 = Chain(
    u -> u .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2),
)

# Split the network into its parameters `p` (used as the initial parameter
# values) and `re`, which rebuilds a network from such parameters.
p, re = Flux.destructure(dudt2)

# Out-of-place ODE right-hand side. The network is rebuilt from `p` on every
# call (needed for backprop); `t` is not read.
function dudt(u, p, t)
    net = re(p)
    return net(u)
end

# Note: this rebinds `prob`, replacing the reference problem defined earlier
# (its solution has already been stored in `ode_data`).
prob = ODEProblem(dudt, u0, tspan)
# Solve `prob` with the current global `u0` and `p`, saving at the sample
# times `t`, and return the solution converted with `Array`.
function predict_n_ode()
    sol = solve(prob, Tsit5(); u0=u0, p=p, saveat=t)
    return Array(sol)
end
# Sum of squared differences between `ode_data` and the current prediction.
# Takes no arguments: it reads the globals through `predict_n_ode`.
function loss_n_ode()
    return sum(abs2, ode_data .- predict_n_ode())
end
loss_n_ode() # evaluate the loss once with the initial parameters held in `p`
# Training callback: shows the current loss and a scatter plot of the first
# state component of the data against the current prediction.
# NOTE(review): the `doplot` keyword is accepted but never read in the body,
# so the plot is produced on every call regardless of its value.
cb = function (; doplot=false)
    current = predict_n_ode()
    residual = ode_data .- current
    display(sum(abs2, residual))
    # first row of data vs. first row of the prediction, both against `t`
    fig = scatter(t, ode_data[1, :]; label="data")
    scatter!(fig, t, current[1, :]; label="prediction")
    display(plot(fig))
    return false
end
# Show the loss and the fit once, before any training step has been taken.
cb()

# 1000 empty tuples: the loss takes no arguments, so each element only
# triggers one more training step.
data = Iterators.repeated((), 1000)
Flux.train!(
    loss_n_ode,
    Flux.params(u0, p),   # both `u0` and `p` are handed over as parameters
    data,
    ADAM(0.05);
    cb=cb,
)