In Multiple Shooting, the training data is split into overlapping intervals, and the solver is trained on each interval individually. If the end conditions of every interval coincide with the initial conditions of the next interval, then the combined solution is the same as solving on the whole dataset without splitting.
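The splitting step can be illustrated with a small sketch that builds overlapping index ranges, where the last point of each group is also the first point of the next. `overlapping_groups` is a hypothetical helper written for this explanation, not a DiffEqFlux function. It assumes groups of at least two points, and it follows the `step = grp_size - 1` arithmetic used by the plotting helper in the demo, so trailing points that cannot fill a whole group are dropped.

```julia
# Hypothetical helper (not part of DiffEqFlux): split the indices 1:datasize into
# overlapping groups of `grp_size` points, where consecutive groups share one point.
function overlapping_groups(datasize::Integer, grp_size::Integer)
    grp_size >= 2 || throw(ArgumentError("grp_size must be at least 2 for groups to overlap"))
    step = grp_size - 1   # the last index of one group is the first index of the next
    return [i:(i + step) for i in 1:step:(datasize - step)]
end

groups = overlapping_groups(30, 3)   # 1:3, 3:5, 5:7, ..., 27:29
println(length(groups))              # 14
```

With `grp_size` equal to the number of data points, the helper returns a single range covering everything, which is the single-shooting case.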
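To ensure that the overlapping points of two consecutive intervals coincide, we add a penalty term to the loss for each pair of consecutive groups $i$ and $i+1$:

$$\text{continuity\_strength} \cdot \left| \hat{u}_{i,\text{last}} - \hat{u}_{i+1,\text{first}} \right|$$

where $\hat{u}_{i,\text{last}}$ is the prediction at the last point of group $i$ and $\hat{u}_{i+1,\text{first}}$ is the prediction at the first point of group $i+1$.

Note that continuity_strength (`loss_multiplier_param` in the demo below) should be a large positive value, so that the solver is heavily penalized whenever it predicts discontinuous values.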
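The penalty above can be sketched as a plain Julia function. `continuity_penalty` is a hypothetical helper, not DiffEqFlux's implementation. It assumes each group's prediction is stored as a (states × time points) matrix and that the absolute differences are summed over the state components.

```julia
# Hypothetical helper: continuity penalty between consecutive group predictions.
# `preds` holds one (n_states × n_points) matrix per group.
function continuity_penalty(preds::AbstractVector{<:AbstractMatrix}, continuity_strength::Real)
    penalty = 0.0
    for i in 1:(length(preds) - 1)
        # last predicted point of group i vs. first predicted point of group i+1
        penalty += continuity_strength * sum(abs, preds[i][:, end] .- preds[i+1][:, 1])
    end
    return penalty
end

continuous = [[1.0 2.0; 0.0 1.0], [2.0 3.0; 1.0 0.5]]   # group 2 starts where group 1 ends
jumpy      = [[1.0 2.0; 0.0 1.0], [2.5 3.0; 0.5 0.5]]   # group 2 starts 0.5 away in each state
println(continuity_penalty(continuous, 100))   # 0.0
println(continuity_penalty(jumpy, 100))        # 100.0
```

A jump of only 1.0 in total absolute difference already costs 100 here, which is why a large continuity_strength pushes the groups to join up.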
The following is a working demo using Multiple Shooting:
```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

# Define initial conditions and timesteps
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Get the data
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Define the Neural Network
dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 16, tanh),
                  FastDense(16, 2))
prob_neuralode = NeuralODE(dudt2, (0.0, 5.0), Tsit5(), saveat = tsteps)

function plot_function_for_multiple_shoot(plt, pred, grp_size)
    step = 1
    if (grp_size != 1)
        step = grp_size - 1
    end
    if (grp_size == datasize)
        # a single group covers the whole interval (single shooting)
        scatter!(plt, tsteps, pred[1][1, :], label = "pred")
    else
        for i in 1:step:datasize-grp_size
            # The term `trunc(Integer,(i-1)/(grp_size-1)+1)` goes from 1, 2, ..., N, where N is the
            # total number of groups that can be formed from `ode_data`
            # (in other words, N = trunc(Integer, (datasize-1)/(grp_size-1)))
            scatter!(plt, tsteps[i:i+step], pred[trunc(Integer, (i-1)/step+1)][1, :],
                     label = "grp" * string(trunc(Integer, (i-1)/step+1)))
        end
    end
end

callback = function (p, l, pred, predictions; doplot = true)
    display(l)
    if doplot
        # plot the original data
        plt = scatter(tsteps[1:size(pred, 2)], ode_data[1, 1:size(pred, 2)], label = "data")

        # plot the different predictions for individual shoot
        plot_function_for_multiple_shoot(plt, predictions, grp_size_param)

        # plot a single shooting performance of our multiple shooting training
        # (this is what the solver predicts after the training is done)
        # scatter!(plt, tsteps[1:size(pred, 2)], pred[1, :], label = "pred")

        display(plot(plt))
    end
    return false
end

# Define parameters for Multiple Shooting
grp_size_param = 1
loss_multiplier_param = 100

neural_ode_f(u, p, t) = dudt2(u, p)
prob_param = ODEProblem(neural_ode_f, u0, tspan, initial_params(dudt2))

function loss_function_param(ode_data, pred)::Float32
    return sum(abs2, (ode_data .- pred))^2
end

function loss_neuralode(p)
    return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param,
                          Tsit5(), grp_size_param, loss_multiplier_param)
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          ADAM(0.05), cb = callback,
                                          maxiters = 300)
callback(result_neuralode.minimizer, loss_neuralode(result_neuralode.minimizer)...; doplot = true)

result_neuralode_2 = DiffEqFlux.sciml_train(loss_neuralode, result_neuralode.minimizer,
                                            BFGS(), cb = callback,
                                            maxiters = 100, allow_f_increases = true)
callback(result_neuralode_2.minimizer, loss_neuralode(result_neuralode_2.minimizer)...; doplot = true)
```
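To see how the pieces of a multiple-shooting loss fit together without depending on DiffEqFlux internals, here is a toy version. It is a sketch under several assumptions and is not necessarily how `multiple_shoot` is implemented: explicit Euler stands in for `Tsit5`, a simple parametric right-hand side `f(u, p)` stands in for the neural network, each group is integrated from the observed data point at its start, and the per-group data loss is a plain sum of squared errors. All function names (`euler_trajectory`, `multiple_shoot_loss`, `decay`) are hypothetical.

```julia
# Hypothetical sketch: explicit Euler replaces Tsit5 and `f(u, p)` replaces the neural network.

# Integrate du/dt = f(u, p) with explicit Euler, storing every step as a column.
function euler_trajectory(f, u0, p, dt, npoints)
    traj = zeros(length(u0), npoints)
    traj[:, 1] = u0
    for k in 2:npoints
        traj[:, k] = traj[:, k-1] .+ dt .* f(traj[:, k-1], p)
    end
    return traj
end

# Multiple-shooting loss: per-group data loss plus a continuity penalty between groups.
function multiple_shoot_loss(f, p, data, dt, grp_size, continuity_strength)
    step = grp_size - 1
    # overlapping groups: the last index of one group is the first index of the next
    groups = [i:(i + step) for i in 1:step:(size(data, 2) - step)]
    # each group is integrated from the observed state at its first index
    preds = [euler_trajectory(f, data[:, g[1]], p, dt, length(g)) for g in groups]
    loss = 0.0
    for (g, pred) in zip(groups, preds)
        loss += sum(abs2, data[:, g] .- pred)   # fit to the data inside this group
    end
    for i in 1:(length(preds) - 1)
        # penalize a jump between the end of group i and the start of group i+1
        loss += continuity_strength * sum(abs, preds[i][:, end] .- preds[i+1][:, 1])
    end
    return loss, preds
end

# Usage: linear decay du/dt = -p*u, with data generated by the same Euler scheme.
decay(u, p) = -p .* u
data = euler_trajectory(decay, [2.0, 1.0], 0.5, 0.1, 30)
loss_true, preds = multiple_shoot_loss(decay, 0.5, data, 0.1, 3, 100.0)   # true parameter
loss_wrong, _ = multiple_shoot_loss(decay, 1.0, data, 0.1, 3, 100.0)      # wrong parameter
println((loss_true, loss_wrong))
```

With the true parameter both terms are exactly zero, because every group reproduces its slice of the data; with a wrong parameter each group drifts away from the data and from the start of the next group, so both terms become positive.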
Here are the plots we get from the code above.
The picture on the left shows how well our neural network does on a single shoot after training it with multiple_shoot. The picture on the right shows the predictions of each group. Notice that there are overlapping points as well: these are the points we are trying to make coincide.
Here is an output with grp_size_param = 30, which is the same as solving on the whole interval without splitting (also called single shooting).
As the picture shows, single shooting does not perform well on this ODE problem and gets stuck in a local minimum.