GPU-based MNIST Neural ODE Classifier

Training a classifier for MNIST using a neural ordinary differential equation (NeuralODE) on GPUs with minibatching.

(Step-by-step description below)

using DiffEqFlux, CUDA, Zygote, NNlib, OrdinaryDiffEq, Lux, Statistics, ComponentArrays,
      Random, Optimization, OptimizationOptimisers, LuxCUDA, MLUtils, OneHotArrays
using MLDatasets: MNIST

CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

const cdev = cpu_device()
const gdev = gpu_device()

logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function loadmnist(batchsize)
    # Load MNIST
    dataset = MNIST(; split = :train)[1:2000] # Partial load for demonstration
    imgs = dataset.features
    labels_raw = dataset.targets

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)

    return DataLoader(mapobs(gdev, (x_data, y_data)); batchsize, shuffle = true)
end

dataloader = loadmnist(128)

down = Chain(FlattenLayer(), Dense(784, 20, tanh))
nn = Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh))
fc = Dense(20, 10)

nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
    reltol = 1e-3, abstol = 1e-3, save_start = false)

solution_to_array(sol) = sol.u[end]

# Build our over-all model topology
m = Chain(; down, nn_ode, convert = WrappedFunction(solution_to_array), fc)
ps, st = Lux.setup(Xoshiro(0), m);
ps = ComponentArray(ps) |> gdev;
st = st |> gdev;

# We can also build the model topology without a NN-ODE
m_no_ode = Chain(; down, nn, fc)
ps_no_ode, st_no_ode = Lux.setup(Xoshiro(0), m_no_ode);
ps_no_ode = ComponentArray(ps_no_ode) |> gdev;
st_no_ode = st_no_ode |> gdev;

x_train1, y_train1 = first(dataloader)

# To understand the intermediate NN-ODE layer, we can examine its dimensionality
x_d = first(down(x_train1, ps.down, st.down))

# Compute the forward pass through the model topology featuring the NN-ODE layer.
x_m = first(m(x_train1, ps, st))
# Or without the NN-ODE layer.
x_m = first(m_no_ode(x_train1, ps_no_ode, st_no_ode))

classify(x) = argmax.(eachcol(x))

function accuracy(model, data, ps, st; n_batches = 100)
    total_correct = 0
    total = 0
    st = Lux.testmode(st)
    for (x, y) in collect(data)[1:min(n_batches, length(data))]
        target_class = classify(cdev(y))
        predicted_class = classify(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy

function loss_function(ps, data)
    (x, y) = data
    pred, st_ = m(x, ps, st)
    return logitcrossentropy(pred, y)
end

loss_function(ps, (x_train1, y_train1)) # burn in loss

opt = OptimizationOptimisers.Adam(0.05)
iter = 0

opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(state, l)
    global iter += 1
    iter % 10 == 0 &&
        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
    return false
end

# Train the NN-ODE, monitoring the training accuracy through the callback.
res = Optimization.solve(opt_prob, opt; callback, epochs = 5)
accuracy(m, dataloader, res.u, st)
0.8925

Step-by-Step Description

Load Packages

using DiffEqFlux, CUDA, Zygote, NNlib, OrdinaryDiffEq, Lux, Statistics, ComponentArrays,
      Random, Optimization, OptimizationOptimisers, LuxCUDA, MLUtils, OneHotArrays
using MLDatasets: MNIST

GPU

A good trick used here:

CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

const cdev = cpu_device()
const gdev = gpu_device()
(::CUDADevice{Nothing}) (generic function with 5 methods)

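CUDA.allowscalar(false) ensures that only optimized kernels are called when using the GPU: scalar indexing into a GPU array, which is extremely slow, raises an error instead of running silently. Setting ENV["DATADEPS_ALWAYS_ACCEPT"] = true lets MLDatasets download the dataset without an interactive confirmation prompt. Additionally, gpu_device() returns a device object (gdev) that moves models and data to the GPU, while cpu_device() (cdev) moves them back to the CPU. Note that gpu_device is CPU-safe: if the GPU is disabled or unavailable, this code falls back to the CPU.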

Load MNIST Dataset into Minibatches

The MNIST dataset is split into 60,000 training and 10,000 test images, with the ten digit labels roughly balanced. For demonstration, this tutorial loads only the first 2,000 training images.

The preprocessing is done in loadmnist, where the raw MNIST data is split into features x and labels y. Features are reshaped into the format [Height, Width, Color, Samples]; for the full training set this is [28, 28, 1, 60000], and for the 2,000-image subset loaded here it is [28, 28, 1, 2000]. Using OneHotArrays's onehotbatch function, the labels (digits 0 to 9) are one-hot encoded, resulting in a 10×60000 OneHotMatrix (10×2000 for the subset).

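Features and labels are then passed to MLUtils's DataLoader, after mapobs(gdev, ...) wraps them so that each minibatch is moved to the GPU when it is accessed. The DataLoader automatically shuffles and minibatches the images and labels using the specified batchsize, so every minibatch contains 128 single-channel 28×28 images, except the last one: 2,000 images give 15 full batches plus a final batch of 80, i.e. 16 batches in total.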

logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function loadmnist(batchsize)
    # Load MNIST
    dataset = MNIST(; split = :train)[1:2000] # Partial load for demonstration
    imgs = dataset.features
    labels_raw = dataset.targets

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)

    return DataLoader(mapobs(gdev, (x_data, y_data)); batchsize, shuffle = true)
end
loadmnist (generic function with 1 method)

The data loader is then created with a batch size of 128:

dataloader = loadmnist(128)
16-element DataLoader(::MLUtils.MappedData{:auto, CUDADevice{Nothing}, Tuple{Array{Float32, 4}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}}, shuffle=true, batchsize=128)
  with first element:
  (28×28×1×128 CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, 10×128 OneHotMatrix(::CUDA.CuArray{UInt32, 1, CUDA.DeviceMemory}) with eltype Bool,)
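To make the preprocessing concrete, here is a minimal stdlib-only sketch of the same steps: adding the channel dimension, one-hot encoding the labels, and splitting shuffled indices into minibatches. It assumes random arrays in place of the real MNIST data, and the helpers onehot_sketch and minibatch_ranges are hypothetical; they only illustrate what onehotbatch and DataLoader do and are not their implementations.

```julia
# Stdlib-only sketch of the loadmnist preprocessing (assumption: random arrays
# stand in for MNIST; onehot_sketch and minibatch_ranges are hypothetical helpers).
using Random

n, H, W = 2000, 28, 28
imgs = rand(Xoshiro(0), Float32, H, W, n)   # stand-in for dataset.features, size (H, W, N)
labels = rand(Xoshiro(1), 0:9, n)           # stand-in for dataset.targets

# Insert a singleton colour channel: (H, W, N) -> (H, W, C, N) with C = 1
x = Float32.(reshape(imgs, H, W, 1, n))

# One-hot encode: column j has a single `true` in row (label + 1)
function onehot_sketch(ys, classes)
    Y = falses(length(classes), length(ys))
    for (j, yj) in enumerate(ys)
        Y[findfirst(==(yj), classes), j] = true
    end
    return Y
end
y = onehot_sketch(labels, 0:9)

# Consecutive index ranges of length bs; the last range may be shorter
minibatch_ranges(n, bs) = [i:min(i + bs - 1, n) for i in 1:bs:n]

# Shuffle the sample indices once, then cut them into minibatches of 128
perm = randperm(Xoshiro(2), n)
batches = [perm[r] for r in minibatch_ranges(n, 128)]

println(size(x), " ", size(y), " ", length(batches), " ", length(last(batches)))
```

With 2,000 samples and a batch size of 128, the sketch yields 16 batches, the last of which holds 80 samples, matching the 16-element DataLoader above.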

Layers

The neural network passes its inputs sequentially through multiple layers. We use Chain, which feeds the output of each layer into the next one. Four different sets of layers are used here:

down = Chain(FlattenLayer(), Dense(784, 20, tanh))
nn = Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh))
fc = Dense(20, 10)
Dense(20 => 10)     # 210 parameters

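down: This layer downsamples our images into a 20-dimensional feature vector. It takes a 28×28 image, flattens it into a 784-element vector, and then passes it through a fully connected layer with tanh activation.

nn: A three-layer Chain of Dense layers with tanh activation, used to model the right-hand side of our differential equation (the learned dynamics dz/dt = nn(z)).

nn_ode: The ODE solver layer, which integrates the dynamics defined by nn from t = 0 to t = 1 with the Tsit5 solver.

fc: The final fully connected layer, which maps our learned 20-dimensional feature vector to 10 scores (logits), one per class. The softmax that turns these scores into class probabilities is applied inside the loss function.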
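As a quick check on model size, the sketch below counts parameters under the assumption that each Dense(in, out) layer holds an out×in weight matrix plus out biases, that FlattenLayer has none, and that the NeuralODE wrapper adds no parameters beyond those of nn. The fc count reproduces the 210 parameters printed above.

```julia
# Assumption: Dense(in, out) stores an (out, in) weight matrix and `out` biases.
dense_params(in, out) = in * out + out

down_params = dense_params(784, 20)          # FlattenLayer contributes no parameters
nn_params = dense_params(20, 10) + dense_params(10, 10) + dense_params(10, 20)
fc_params = dense_params(20, 10)             # matches "Dense(20 => 10)  # 210 parameters"

# Assumption: nn_ode's parameters are exactly those of its inner network nn
total = down_params + nn_params + fc_params

println((down = down_params, nn = nn_params, fc = fc_params, total = total))
```

Almost all parameters sit in down; the ODE dynamics themselves are described by only 540 numbers.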

Array Conversion

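The NeuralODE layer returns an ODESolution rather than a plain array. Viewed as an array, the solution has dimensions 20×batchsize×(number of saved time points); solution_to_array takes the state at the final time point, sol.u[end], which reduces it from 3 to 2 dimensions and yields a 20×batchsize Matrix (a CuArray on the GPU) that the next layer can consume. Because save_everystep = false and save_start = false, only this final state at t = 1 is saved.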

nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
    reltol = 1e-3, abstol = 1e-3, save_start = false)

solution_to_array(sol) = sol.u[end]
solution_to_array (generic function with 1 method)
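For intuition about what nn_ode computes, here is a stdlib-only sketch that integrates dz/dt = f(z) from t = 0 to t = 1 and keeps only the final state, as solution_to_array does. It is an illustration under stated assumptions: a fixed-step classical Runge-Kutta (RK4) scheme replaces the adaptive Tsit5 solver and its tolerances, and f is a small random tanh network standing in for nn.

```julia
using Random

# Assumption: a random 20 -> 10 -> 20 tanh network stands in for the trained `nn`
rng = Xoshiro(0)
W1, b1 = 0.1f0 .* randn(rng, Float32, 10, 20), zeros(Float32, 10)
W2, b2 = 0.1f0 .* randn(rng, Float32, 20, 10), zeros(Float32, 20)
f(z) = tanh.(W2 * tanh.(W1 * z .+ b1) .+ b2)   # maps a 20×B batch to 20×B

# Fixed-step RK4 (assumption: used instead of adaptive Tsit5); returns only the
# final state, analogous to sol.u[end] with save_everystep = false, save_start = false
function rk4_final(f, z0, t0, t1; nsteps = 20)
    h = (t1 - t0) / nsteps
    z = z0
    for _ in 1:nsteps
        k1 = f(z)
        k2 = f(z .+ (h / 2) .* k1)
        k3 = f(z .+ (h / 2) .* k2)
        k4 = f(z .+ h .* k3)
        z = z .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return z
end

z0 = randn(rng, Float32, 20, 128)   # stands in for the output of `down`
z1 = rk4_final(f, z0, 0.0f0, 1.0f0)
println(size(z1))
```

Like the NeuralODE layer, this maps a 20×128 input to a 20×128 output, which is why fc can follow it directly.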

For CPU: if gdev does not automatically fall back to the CPU when no GPU is present, we can replace gdev(x) with Array(x).

Build Topology

Next, we connect all layers together in a single chain:

# Build our over-all model topology
m = Chain(; down, nn_ode, convert = WrappedFunction(solution_to_array), fc)
ps, st = Lux.setup(Xoshiro(0), m);
ps = ComponentArray(ps) |> gdev;
st = st |> gdev;
(down = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), nn_ode = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), convert = NamedTuple(), fc = NamedTuple())
# We can also build the model topology without a NN-ODE
m_no_ode = Chain(; down, nn, fc)
ps_no_ode, st_no_ode = Lux.setup(Xoshiro(0), m_no_ode);
ps_no_ode = ComponentArray(ps_no_ode) |> gdev;
st_no_ode = st_no_ode |> gdev;

x_train1, y_train1 = first(dataloader)

# To understand the intermediate NN-ODE layer, we can examine its dimensionality
x_d = first(down(x_train1, ps.down, st.down));
@show size(x_d)

# Compute the forward pass through the model topology featuring the NN-ODE layer.
x_m = first(m(x_train1, ps, st));
@show size(x_m)

# Or without the NN-ODE layer.
x_m = first(m_no_ode(x_train1, ps_no_ode, st_no_ode));
@show size(x_m)
size(x_d) = (20, 128)
size(x_m) = (10, 128)
size(x_m) = (10, 128)

Prediction

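To convert the network output back into readable class labels, we use classify, which takes the argmax of each column of the minibatch. The result is a class index from 1 to 10, so index k corresponds to digit k - 1. Applied to the one-hot labels, the same function recovers the target indices, so predictions and targets can be compared directly: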

classify(x) = argmax.(eachcol(x))
classify (generic function with 1 method)
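A small worked example of classify on hypothetical 10×3 scores shows the index-to-digit offset and how predictions are compared against one-hot targets:

```julia
classify(x) = argmax.(eachcol(x))   # same definition as in the tutorial

# Hypothetical 10×3 logits: column j scores sample j, row k scores digit k - 1
logits = zeros(Float32, 10, 3)
logits[4, 1] = 2.0f0    # sample 1: highest score for digit 3
logits[1, 2] = 1.5f0    # sample 2: highest score for digit 0
logits[10, 3] = 3.0f0   # sample 3: highest score for digit 9

# Hypothetical one-hot targets as a Bool matrix: digits 3, 0 and 2
targets = falses(10, 3)
targets[4, 1] = true
targets[1, 2] = true
targets[3, 3] = true

pred_idx = classify(logits)                        # class indices in 1:10
pred_digits = pred_idx .- 1                        # convert indices to digits 0:9
n_correct = sum(pred_idx .== classify(targets))    # compare index to index
println(pred_digits, " ", n_correct)
```

Comparing indices with indices avoids any off-by-one issue; the subtraction is only needed when printing digits.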

Accuracy

We then evaluate the accuracy of the entire network on up to n_batches minibatches, with the model states switched to test mode by Lux.testmode:

function accuracy(model, data, ps, st; n_batches = 100)
    total_correct = 0
    total = 0
    st = Lux.testmode(st)
    for (x, y) in collect(data)[1:min(n_batches, length(data))]
        target_class = classify(cdev(y))
        predicted_class = classify(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy
0.1171875
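The accuracy function pools correct predictions over all evaluated batches instead of averaging per-batch accuracies, which matters when the last batch is partial. The counts in this sketch are hypothetical and chosen only to show the difference; the last line checks that the burn-in value above is exactly 15 correct out of 128, close to the 10% expected from random guessing.

```julia
# Hypothetical per-batch counts: correct predictions and batch sizes
correct = [100, 110, 40]
sizes = [128, 128, 80]     # the last batch is partial, as with the 2,000-image loader

pooled = sum(correct) / sum(sizes)                  # what accuracy() computes
batch_mean = sum(correct ./ sizes) / length(sizes)  # naive mean of per-batch accuracies

# The burn-in accuracy on the first (full) batch: 15 of 128 images correct
burn_in = 15 / 128
println((pooled = pooled, batch_mean = batch_mean, burn_in = burn_in))
```

The naive mean gives the 80-image batch the same weight as a 128-image batch, so it can differ noticeably from the pooled value.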

Training Parameters

Once we have our model, we can train our neural network by backpropagation using Optimization.solve from Optimization.jl. This requires a loss function, an optimizer, and a callback function.

Loss

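The loss computed here is cross-entropy. Because logitcrossentropy is constructed with logits = Val(true), it applies a (log-)softmax to the raw model outputs internally, so the model itself needs no softmax layer. loss_function takes the prediction pred from our model m(x, ps, st), using the global state st defined above, and compares it to the one-hot labels y: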

function loss_function(ps, data)
    (x, y) = data
    pred, st_ = m(x, ps, st)
    return logitcrossentropy(pred, y)
end

loss_function(ps, (x_train1, y_train1)) # burn in loss
2.5208788f0
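To see what logitcrossentropy computes, here is a stdlib-only sketch of cross-entropy on logits with a numerically stable log-softmax. It assumes mean aggregation over the batch, which is our assumption about Lux's default rather than something stated in this tutorial. A uniform prediction over 10 classes gives a loss of log(10) ≈ 2.30, so a burn-in loss of about 2.52 is in the range one would expect from an untrained classifier.

```julia
# Numerically stable log-softmax over each column (classes along dimension 1)
function logsoftmax_cols(z)
    m = maximum(z; dims = 1)                 # subtract the column max before exp
    return z .- (m .+ log.(sum(exp.(z .- m); dims = 1)))
end

# Cross-entropy on logits: -sum_k y_k * logsoftmax(z)_k per sample,
# averaged over the batch (assumption: mean aggregation)
logit_ce(z, y) = sum(-sum(y .* logsoftmax_cols(z); dims = 1)) / size(z, 2)

# Hypothetical one-hot targets for 4 samples (digits 0, 4, 7, 9)
y = falses(10, 4)
for (j, k) in enumerate((1, 5, 8, 10))
    y[k, j] = true
end

loss_uniform = logit_ce(zeros(Float32, 10, 4), y)   # every class equally likely
loss_large = logit_ce(fill(1000.0f0, 10, 4), y)     # huge logits: no overflow
z_confident = zeros(Float32, 10, 4)
z_confident[y] .= 20.0f0                            # strongly favour the right class
loss_confident = logit_ce(z_confident, y)
println((loss_uniform, loss_large, loss_confident))
```

Subtracting the column maximum is what keeps the large-logit case finite; a naive exp(1000) would overflow.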

Optimizer

Adam is specified here as our optimizer with a learning rate of 0.05:

opt = OptimizationOptimisers.Adam(0.05)
Adam(0.05, (0.9, 0.999), 1.0e-8)
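The printed triple is the learning rate, the decay rates (β1, β2) and ϵ. A sketch of the textbook Adam update with these values follows; it is an illustration, and Optimisers.jl may differ in details such as where ϵ enters. Because of the bias correction, the very first step moves each parameter by roughly the learning rate 0.05, whatever the gradient's scale.

```julia
# Textbook Adam step with bias correction (illustration, not Optimisers.jl's code).
# `state` holds the first-moment m, second-moment v and the step counter t.
function adam_step(θ, g, state; η = 0.05, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    t = state.t + 1
    m = β1 .* state.m .+ (1 - β1) .* g
    v = β2 .* state.v .+ (1 - β2) .* g .^ 2
    m̂ = m ./ (1 - β1^t)            # bias-corrected first moment
    v̂ = v ./ (1 - β2^t)            # bias-corrected second moment
    θ_new = θ .- η .* m̂ ./ (sqrt.(v̂) .+ ϵ)
    return θ_new, (m = m, v = v, t = t)
end

θ0 = [1.0, -2.0]
state0 = (m = zeros(2), v = zeros(2), t = 0)
# Gradients of very different magnitudes still give steps of about η
θ1, state1 = adam_step(θ0, [2.0, -0.001], state0)
println(θ1)
```

This scale invariance is why a learning rate of 0.05 directly bounds how far each weight can move per iteration early in training.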

Callback

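This callback function prints the accuracy on the training data every 10 training iterations; returning false tells the optimizer to continue. Before defining it, the loss is wrapped in an OptimizationFunction that uses Zygote for automatic differentiation (Optimization.AutoZygote()), and the OptimizationProblem receives the initial parameters ps together with dataloader, so that each iteration processes one minibatch: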

iter = 0

opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(state, l)
    global iter += 1
    iter % 10 == 0 &&
        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
    return false
end
callback (generic function with 1 method)
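The callback protocol can be illustrated with a hypothetical mock solver (not Optimization.jl) that calls callback(state, l) after each iteration and stops once the callback returns true. The sketch also keeps the iteration counter in a closure instead of the global iter, which is one way to avoid global state; the losses are synthetic.

```julia
# Hypothetical stand-in for the solver: calls callback(state, l) after each
# iteration and halts as soon as the callback returns true.
function mock_solve(callback, losses)
    for (i, l) in enumerate(losses)
        callback((; iter = i), l) && return i   # number of iterations actually run
    end
    return length(losses)
end

# Closure-based counter (instead of a global `iter`) that records every
# `every`-th iteration and requests early stopping once the loss < threshold.
function make_callback(logged; every = 10, threshold = 0.5)
    n = Ref(0)
    return function (state, l)
        n[] += 1
        n[] % every == 0 && push!(logged, n[])
        return l < threshold
    end
end

logged = Int[]
losses = [2.5 - 0.05i for i in 1:80]    # synthetic, steadily decreasing losses
stopped_at = mock_solve(make_callback(logged; threshold = 0.32), losses)
println((stopped_at, logged))
```

The tutorial's callback always returns false, so training runs for all requested epochs.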

Train

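To train our model, we optimize all trainable parameters of the network. Lux stores these explicitly in the ComponentArray ps, with one entry per named layer (ps.down, ps.nn_ode and ps.fc), so backpropagation runs through down, nn_ode and fc together; the parameters of the neural ODE are those of its inner network nn. We train for 5 epochs, i.e. five passes over the data loader: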

# Train the NN-ODE, monitoring the training accuracy through the callback.
res = Optimization.solve(opt_prob, opt; callback, epochs = 5)
accuracy(m, dataloader, res.u, st)
0.806
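As a rough sanity check on training length, assuming one optimizer step per minibatch and one callback call per step, the arithmetic below gives the number of iterations and accuracy messages for this run:

```julia
n_samples, batchsize, epochs = 2000, 128, 5

batches_per_epoch = cld(n_samples, batchsize)            # ceiling division: 16 batches
last_batch = n_samples - (batches_per_epoch - 1) * batchsize   # size of the partial batch
# Assumption: one optimizer step per minibatch
total_iterations = epochs * batches_per_epoch
# Assumption: the callback runs once per step and logs when iter % 10 == 0
callback_logs = count(i -> i % 10 == 0, 1:total_iterations)

println((batches_per_epoch, last_batch, total_iterations, callback_logs))
```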