# Multiple Shooting

In Multiple Shooting, the training data is split into overlapping intervals, and the model is trained by solving the ODE on each interval separately. If the end condition of every interval coincides with the initial condition of the next interval, then the joined solution is the same as solving on the whole dataset without splitting.
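The splitting step can be sketched in plain Julia. This is a minimal sketch that assumes each group shares exactly one boundary point with the next. `overlapping_groups` is a hypothetical helper, not a DiffEqFlux function. Keeping a shorter trailing group when the points do not divide evenly is also an assumption of this sketch.

```julia
# Hypothetical helper (not a DiffEqFlux API): split the indices 1:datasize into
# consecutive groups of up to `grp_size` points. Each group starts at the last
# point of the previous one, so neighbouring groups overlap by one point.
function overlapping_groups(datasize::Integer, grp_size::Integer)
    grp_size >= 2 || throw(ArgumentError("grp_size must be at least 2"))
    step = grp_size - 1
    # The final group is truncated at `datasize` if the points do not divide evenly.
    return [i:min(i + step, datasize) for i in 1:step:datasize-1]
end

groups = overlapping_groups(10, 4)   # == [1:4, 4:7, 7:10]

# Dropping the duplicated first point of every group after the first
# recovers the full index range, so the groups tile the whole dataset.
joined = vcat(collect(groups[1]), [collect(g[2:end]) for g in groups[2:end]]...)
@assert joined == collect(1:10)
```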

To make the overlapping points of two consecutive intervals coincide, we add a penalty term to the loss: `continuity_strength * abs(prediction of the last point of group i - prediction of the first point of group i+1)`. In the demo below, `continuity_strength` is passed to `multiple_shoot` as `loss_multiplier_param`.
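The objective can also be written as a formula. This assumes that the per-group data losses are summed, which is our reading of the description rather than something it states. For $N$ groups with data $u_i$ and predictions $\hat u_i$, the objective is:

$$
\mathcal{L}(\theta) = \sum_{i=1}^{N} \ell\big(u_i, \hat u_i(\theta)\big) + \lambda \sum_{i=1}^{N-1} \big\lVert \hat u_i(t_i^{\text{end}}) - \hat u_{i+1}(t_{i+1}^{\text{start}}) \big\rVert_1
$$

In this formula:

- $\lambda$ is `continuity_strength`.
- $\ell$ is the per-group loss.
- $t_i^{\text{end}} = t_{i+1}^{\text{start}}$ is the time point shared by groups $i$ and $i+1$.
- $\lVert\cdot\rVert_1$ sums the absolute differences over the state components.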

Note that `continuity_strength` should be a large positive value, so that discontinuous predictions are heavily penalized.
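The effect of the weight can be seen with a small example that uses only Base Julia. `group_loss`, `continuity_penalty` and `multiple_shooting_loss` are hypothetical names, not DiffEqFlux functions. The per-group loss here is a plain sum of squared errors, whereas the demo's `loss_function_param` additionally squares that sum.

```julia
# Sum of squared errors between one group's data and its prediction.
group_loss(data, pred) = sum(abs2, data .- pred)

# Sum over group boundaries of |last point of group i - first point of group i+1|,
# summed over all state components (rows).
function continuity_penalty(pred_groups)
    penalty = 0.0
    for i in 1:length(pred_groups)-1
        penalty += sum(abs, pred_groups[i][:, end] .- pred_groups[i+1][:, 1])
    end
    return penalty
end

# Data term summed over groups plus the weighted continuity term.
function multiple_shooting_loss(data_groups, pred_groups, continuity_strength)
    data_term = sum(group_loss(d, p) for (d, p) in zip(data_groups, pred_groups))
    return data_term + continuity_strength * continuity_penalty(pred_groups)
end

# Two groups of a 1-D state (1 x 2 matrices) sharing one time point.
# The prediction jumps by 0.5 where group 2 should start at group 1's end.
data_groups = [[0.0 1.0], [1.0 2.0]]
pred_groups = [[0.0 1.0], [1.5 2.0]]

multiple_shooting_loss(data_groups, pred_groups, 1)    # 0.75
multiple_shooting_loss(data_groups, pred_groups, 100)  # 50.25
```

With a weight of 1, the jump contributes about as much as the fitting error. With a weight of 100, it dominates the loss, so the optimizer is pushed to close the gap first.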

The following is a working demo using Multiple Shooting:

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

# Define initial conditions and timesteps
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Get the data: solve the "true" cubic ODE and sample it at tsteps
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Define the Neural Network
dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 16, tanh),
                  FastDense(16, 2))

# Use the same Float32 tspan as the data
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

# Scatter the first state component of every group's prediction onto `plt`
function plot_function_for_multiple_shoot(plt, pred, grp_size)
    step = 1
    if grp_size != 1
        step = grp_size - 1
    end
    if grp_size == datasize
        scatter!(plt, tsteps, pred[1][1, :], label = "pred")
    else
        for i in 1:step:datasize-grp_size
            # The term trunc(Integer, (i-1)/(grp_size-1) + 1) goes from 1, 2, ..., N, where N is the
            # total number of groups formed from ode_data: N = trunc(Integer, (datasize-1)/(grp_size-1))
            scatter!(plt, tsteps[i:i+step], pred[trunc(Integer, (i-1)/step + 1)][1, :],
                     label = "grp" * string(trunc(Integer, (i-1)/step + 1)))
        end
    end
end

callback = function (p, l, pred, predictions; doplot = true)
    display(l)
    if doplot
        # plot the original data
        plt = scatter(tsteps[1:size(pred, 2)], ode_data[1, 1:size(pred, 2)], label = "data")

        # plot the different predictions for individual shoot
        plot_function_for_multiple_shoot(plt, predictions, grp_size_param)

        # plot a single shooting performance of our multiple shooting training
        # (this is what the solver predicts after the training is done)
        # scatter!(plt, tsteps[1:size(pred, 2)], pred[1, :], label = "pred")

        display(plot(plt))
    end
    return false
end

# Define parameters for Multiple Shooting
grp_size_param = 1            # group size passed to multiple_shoot
loss_multiplier_param = 100   # continuity_strength: weight of the continuity penalty

# The network as a plain ODE right-hand side, so multiple_shoot can solve it per group
neural_ode_f(u, p, t) = dudt2(u, p)
prob_param = ODEProblem(neural_ode_f, u0, tspan, initial_params(dudt2))

# Data loss between ode_data and the prediction
function loss_function_param(ode_data, pred)::Float32
    return sum(abs2, (ode_data .- pred))^2
end

function loss_neuralode(p)
    return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param, Tsit5(),
                          grp_size_param, loss_multiplier_param)
end

# First stage: train with sciml_train's default optimizer
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          maxiters = 300)
callback(result_neuralode.minimizer, loss_neuralode(result_neuralode.minimizer)...; doplot = true)

# Second stage: refine the result with BFGS
result_neuralode_2 = DiffEqFlux.sciml_train(loss_neuralode, result_neuralode.minimizer,
                                            BFGS(), cb = callback,
                                            maxiters = 100, allow_f_increases = true)
callback(result_neuralode_2.minimizer, loss_neuralode(result_neuralode_2.minimizer)...; doplot = true)
```
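The index arithmetic in `plot_function_for_multiple_shoot` can be checked in isolation. The minimal sketch below covers `grp_size >= 2`, where `step == grp_size - 1`. The function names are hypothetical and chosen so they do not clobber `datasize` or Base's `step` in a running session.

```julia
# Reproduces the loop bounds and group numbering of plot_function_for_multiple_shoot
# for grp_size >= 2 (so the step is grp_size - 1).
function plotted_group_ids(n_points::Integer, gsize::Integer)
    gstep = gsize - 1
    starts = 1:gstep:n_points-gsize
    return [trunc(Integer, (i - 1) / gstep + 1) for i in starts]
end

# Number of groups N according to the comment in the demo
claimed_group_count(n_points, gsize) = trunc(Integer, (n_points - 1) / (gsize - 1))

ids = plotted_group_ids(30, 3)                        # 1, 2, ..., 14
@assert ids == collect(1:claimed_group_count(30, 3))  # N = 14
```

For `datasize = 30`, the loop agrees with the comment for every `grp_size` from 3 to 29. With `grp_size = 2`, the loop stops one group short: it plots 28 groups, while N = 29. This happens because the loop's upper bound is `datasize-grp_size`.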


Here are the plots we get from the code above.

The picture on the left shows how well our neural network does on a single shoot after training it with `multiple_shoot`. The picture on the right shows the predictions of each group. Notice that some points overlap: these are the points we are trying to make coincide.
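It can help to measure how closely the overlapping points coincide after training. This minimal sketch assumes that each group's prediction is a state × time matrix whose first column repeats the previous group's last time point. The indexing `pred[k][1, :]` in `plot_function_for_multiple_shoot` suggests this layout. `boundary_gaps` is a hypothetical helper.

```julia
# Hypothetical helper: largest absolute jump between the last column of group i
# and the first column of group i+1, one value per group boundary.
function boundary_gaps(predictions)
    return [maximum(abs.(predictions[i][:, end] .- predictions[i+1][:, 1]))
            for i in 1:length(predictions)-1]
end

# Three groups of a 2-D state; only the first boundary is discontinuous.
preds = [[1.0 2.0; 0.0 1.0], [2.0 3.0; 1.25 0.5], [3.0 4.0; 0.5 0.0]]
boundary_gaps(preds)   # [0.25, 0.0]
```

In the demo's `callback`, calling `boundary_gaps(predictions)` should report one value per boundary. Values near zero mean the joined groups behave like a single trajectory.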

Here is the output with `grp_size = 30`. This is the same as solving on the whole interval without splitting, which is also called single shooting.

As the picture above shows, a single shoot does not perform well on this ODE problem and gets stuck in a local minimum.