Augmented Neural Ordinary Differential Equations

Copy-Pasteable Code

using DiffEqFlux, OrdinaryDiffEq, Statistics, LinearAlgebra, Plots, LuxCUDA, Random
using MLUtils, ComponentArrays
using Optimization, OptimizationOptimisers, IterTools

const cdev = cpu_device()
const gdev = gpu_device()

function random_point_in_sphere(dim, min_radius, max_radius)
    distance = (max_radius - min_radius) .* (rand(Float32, 1) .^ (1.0f0 / dim)) .+
               min_radius
    direction = randn(Float32, dim)
    unit_direction = direction ./ norm(direction)
    return distance .* unit_direction
end

function concentric_sphere(dim, inner_radius_range, outer_radius_range,
        num_samples_inner, num_samples_outer; batch_size = 64)
    data = []
    labels = []
    for _ in 1:num_samples_inner
        push!(data, reshape(random_point_in_sphere(dim, inner_radius_range...), :, 1))
        push!(labels, ones(1, 1))
    end
    for _ in 1:num_samples_outer
        push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1))
        push!(labels, -ones(1, 1))
    end
    data = cat(data...; dims = 2)
    labels = cat(labels...; dims = 2)
    return DataLoader((data |> gdev, labels |> gdev); batchsize = batch_size,
        shuffle = true, partial = false)
end

diffeqarray_to_array(x) = gdev(x.u[1])

function construct_model(out_dim, input_dim, hidden_dim, augment_dim)
    input_dim = input_dim + augment_dim
    node = NeuralODE(
        Chain(Dense(input_dim, hidden_dim, relu),
            Dense(hidden_dim, hidden_dim, relu), Dense(hidden_dim, input_dim)),
        (0.0f0, 1.0f0),
        Tsit5();
        save_everystep = false,
        reltol = 1.0f-3,
        abstol = 1.0f-3,
        save_start = false)
    node = augment_dim == 0 ? node : AugmentedNDELayer(node, augment_dim)
    model = Chain(node, diffeqarray_to_array, Dense(input_dim, out_dim))
    ps, st = Lux.setup(Xoshiro(0), model)
    return model, ps |> gdev, st |> gdev
end

function plot_contour(model, ps, st, npoints = 300)
    grid_points = zeros(Float32, 2, npoints^2)
    idx = 1
    x = range(-4.0f0, 4.0f0; length = npoints)
    y = range(-4.0f0, 4.0f0; length = npoints)
    for x1 in x, x2 in y
        grid_points[:, idx] .= [x1, x2]
        idx += 1
    end
    sol = reshape(model(grid_points |> gdev, ps, st)[1], npoints, npoints) |> cdev

    return contour(x, y, sol; fill = true, linewidth = 0.0)
end

loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)

dataloader = concentric_sphere(
    2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)

iter = 0
cb = function (state, l)
    global iter
    iter += 1
    if iter % 10 == 0
        @info "Augmented Neural ODE" iter=iter loss=l
    end
    return false
end

model, ps, st = construct_model(1, 2, 64, 0)
opt = OptimizationOptimisers.Adam(0.005)

loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st)

println("Training Neural ODE")

optfunc = OptimizationFunction(
    (x, data) -> loss_node(model, data, x, st),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 100)

plt_node = plot_contour(model, res.u, st)

model, ps, st = construct_model(1, 2, 64, 1)
opt = OptimizationOptimisers.Adam(0.005)

println()
println("Training Augmented Neural ODE")

optfunc = OptimizationFunction(
    (x, data) -> loss_node(model, data, x, st),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 100)

plot_contour(model, res.u, st)

Step-by-Step Explanation

Loading required packages

using DiffEqFlux, OrdinaryDiffEq, Statistics, LinearAlgebra, Plots, LuxCUDA, Random
using MLUtils, ComponentArrays
using Optimization, OptimizationOptimisers, IterTools

const cdev = cpu_device()
const gdev = gpu_device()
(::CUDADevice{Nothing}) (generic function with 1 method)

Generating a toy dataset

In this example, we will use data sampled uniformly from two concentric circles and train our Neural ODEs to do regression on those values. We assign a label of 1 to any point which lies inside the inner circle, and -1 to any point which lies between the inner and outer circle. Our first function, random_point_in_sphere, samples points uniformly between two concentric circles/spheres of radii min_radius and max_radius respectively.

function random_point_in_sphere(dim, min_radius, max_radius)
    distance = (max_radius - min_radius) .* (rand(Float32, 1) .^ (1.0f0 / dim)) .+
               min_radius
    direction = randn(Float32, dim)
    unit_direction = direction ./ norm(direction)
    return distance .* unit_direction
end
random_point_in_sphere (generic function with 1 method)
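
As a quick, purely illustrative sanity check (using the function defined above), every sampled point should land between the two requested radii:

# Illustrative check: all sampled radii fall inside the requested range
pts = [random_point_in_sphere(2, 1.0f0, 2.0f0) for _ in 1:1000]
all(1.0f0 .<= norm.(pts) .<= 2.0f0)  # true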

Next, we will construct a dataset of these points and use MLUtils.jl's DataLoader to automatically minibatch and shuffle the data.

function concentric_sphere(dim, inner_radius_range, outer_radius_range,
        num_samples_inner, num_samples_outer; batch_size = 64)
    data = []
    labels = []
    for _ in 1:num_samples_inner
        push!(data, reshape(random_point_in_sphere(dim, inner_radius_range...), :, 1))
        push!(labels, ones(1, 1))
    end
    for _ in 1:num_samples_outer
        push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1))
        push!(labels, -ones(1, 1))
    end
    data = cat(data...; dims = 2)
    labels = cat(labels...; dims = 2)
    return DataLoader((data |> gdev, labels |> gdev); batchsize = batch_size,
        shuffle = true, partial = false)
end
concentric_sphere (generic function with 1 method)
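
Iterating the returned DataLoader yields (data, label) minibatches of the requested size. For example (an illustrative call with a small sample count):

dl = concentric_sphere(2, (0.0f0, 1.0f0), (2.0f0, 3.0f0), 100, 100; batch_size = 32)
x, y = first(dl)       # one shuffled minibatch
size(x), size(y)       # ((2, 32), (1, 32))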

Models

We consider two models in this tutorial. The first is a simple Neural ODE, which is described in detail in the Neural ODE tutorial. The second is an Augmented Neural ODE [1]. The idea behind this layer is very simple: it augments the input to the Neural DE layer by appending zeros. So, in order to use any arbitrary DE layer in combination with this layer, simply assume that the input to the DE layer is of size size(x, 1) + augment_dim instead of size(x, 1) and construct that layer accordingly.
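
Conceptually, the augmentation just concatenates augment_dim rows of zeros to each input before it reaches the wrapped layer. A minimal sketch of that idea (our own illustration, not the AugmentedNDELayer source):

# Sketch of the augmentation idea: append zero features to every sample
augment(x, augment_dim) = vcat(x, zeros(eltype(x), augment_dim, size(x, 2)))

x = rand(Float32, 2, 4)      # 4 samples with 2 features each
size(augment(x, 1))          # (3, 4): one extra zero row per sample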

To run the model on the GPU, we transfer the parameters and states returned by Lux.setup to the GPU with gdev. The diffeqarray_to_array helper pulls the single saved state out of the ODE solution (only the final time point is stored, since save_everystep = false and save_start = false) and keeps it on the GPU, so that the final Dense layer of the Chain receives a plain array.

diffeqarray_to_array(x) = gdev(x.u[1])

function construct_model(out_dim, input_dim, hidden_dim, augment_dim)
    input_dim = input_dim + augment_dim
    node = NeuralODE(
        Chain(Dense(input_dim, hidden_dim, relu),
            Dense(hidden_dim, hidden_dim, relu), Dense(hidden_dim, input_dim)),
        (0.0f0, 1.0f0),
        Tsit5();
        save_everystep = false,
        reltol = 1.0f-3,
        abstol = 1.0f-3,
        save_start = false)
    node = augment_dim == 0 ? node : AugmentedNDELayer(node, augment_dim)
    model = Chain(node, diffeqarray_to_array, Dense(input_dim, out_dim))
    ps, st = Lux.setup(Xoshiro(0), model)
    return model, ps |> gdev, st |> gdev
end
construct_model (generic function with 1 method)
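
As a quick shape check (illustrative; it reuses the gdev defined above), a batch of 2-dimensional points is mapped to one scalar prediction per point:

m, p, s = construct_model(1, 2, 64, 1)
xb = gdev(rand(Float32, 2, 5))   # 5 random input points
yb, _ = m(xb, p, s)              # Lux models return (output, state)
size(yb)                         # (1, 5)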

Plotting the Results

Here, we define a utility to plot our model's regression results as a filled contour plot (effectively a heatmap over the 2D input space).

function plot_contour(model, ps, st, npoints = 300)
    grid_points = zeros(Float32, 2, npoints^2)
    idx = 1
    x = range(-4.0f0, 4.0f0; length = npoints)
    y = range(-4.0f0, 4.0f0; length = npoints)
    for x1 in x, x2 in y
        grid_points[:, idx] .= [x1, x2]
        idx += 1
    end
    sol = reshape(model(grid_points |> gdev, ps, st)[1], npoints, npoints) |> cdev

    return contour(x, y, sol; fill = true, linewidth = 0.0)
end
plot_contour (generic function with 2 methods)
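
Once a model has been trained (see the training sections below), the resulting plot can be saved to disk with Plots.savefig; the filename here is just an example:

plt = plot_contour(model, res.u, st)   # `res` comes from the training runs below
savefig(plt, "anode_contour.png")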

Training Parameters

Loss Functions

We use the L2 distance between the model prediction model(x) and the target y as the optimization objective.

loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
loss_node (generic function with 1 method)
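
For a concrete sense of the objective, here is the same mean-squared-error computation on two hypothetical predictions:

ŷ = Float32[0.9 0.2]     # hypothetical model outputs for two points
y = Float32[1 -1]        # their targets
mean((ŷ .- y) .^ 2)      # ≈ 0.725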

Dataset

Next, we generate the dataset. We restrict ourselves to 2 dimensions, since they are easy to visualize. We sample a total of 4000 data points.

dataloader = concentric_sphere(
    2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)
15-element DataLoader(::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}}, shuffle=true, batchsize=256, partial=false)
  with first element:
  (2×256 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, 1×256 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory},)

Callback Function

Additionally, we define a callback function which logs the loss every 10 iterations.

iter = 0
cb = function (state, l)
    global iter
    iter += 1
    if iter % 10 == 0
        @info "Augmented Neural ODE" iter=iter loss=l
    end
    return false
end
#2 (generic function with 1 method)

Optimizer

We use Adam as the optimizer with a learning rate of 0.005.

opt = OptimizationOptimisers.Adam(5.0f-3)
Adam(0.005, (0.9, 0.999), 1.0e-8)

Training the Neural ODE

To train our Neural ODE, we pass the learnable parameters returned by construct_model as the initial point of the OptimizationProblem, converting them to a ComponentArray (and moving them to the GPU) so the optimizer can treat them as a flat vector. We then train our model for 100 epochs.

model, ps, st = construct_model(1, 2, 64, 0)

optfunc = OptimizationFunction(
    (x, data) -> loss_node(model, data, x, st),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 100)

plot_contour(model, res.u, st)

Here is what the contour plot should look like for the Neural ODE. Notice that the regression is not perfect, due to the thin artifact which connects the two circles.

Training the Augmented Neural ODE

Our training configuration will be the same as that of the Neural ODE. The only difference is that the input is augmented with a single zero. This makes the problem 3-dimensional, and in that higher-dimensional space the ODE flow can separate the two regions, so a suitable function becomes expressible by the Neural ODE. For more details and proofs, please refer to [1].

model, ps, st = construct_model(1, 2, 64, 1)

optfunc = OptimizationFunction(
    (x, data) -> loss_node(model, data, x, st),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 100)

plot_contour(model, res.u, st)

References

[1] Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented neural ODEs." In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pp. 3140-3150. 2019.