Mixture of Experts

Paul Bruffett
Feb 15, 2022

Outrageously Large Neural Networks in PyTorch

Scaling neural networks has proven very challenging, with significant bottlenecks introduced as models are parallelized across GPUs and machines. In dense networks, the cost and latency of shuffling data and synchronizing parameters explode as the model is scaled.

The Mixture of Experts (MoE) architecture introduces sparse connections between its sub-models (the experts), dramatically reducing the parameters that must be synchronized across instances.

A Mixture of Experts layer consists of:

  • A number of experts (feed-forward neural networks)
  • A trainable gating network that selects a few experts for each input

In this implementation the experts are identical networks replicated multiple times, one instance per expert, though there is no reason their architectures couldn’t differ.
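For intuition, the two components can be sketched in a few lines of plain PyTorch. This is a toy, dense version in which every expert runs on every input; it is not Deepspeed's implementation, and the class name DenseMoE and all sizes are illustrative. The sparse, top-k selection is what the gating section below describes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseMoE(nn.Module):
    """Toy MoE: identical feed-forward experts plus a trainable linear gate.

    No sparsity yet: all experts run, and the gate's softmax weights decide
    how much each expert's output contributes to the result.
    """

    def __init__(self, dim: int, hidden: int, num_experts: int):
        super().__init__()
        # Identical expert architectures, each with its own weights.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
            for _ in range(num_experts)
        ])
        # The gating network: one logit per expert for each input.
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = F.softmax(self.gate(x), dim=-1)                    # (batch, E)
        outputs = torch.stack([e(x) for e in self.experts], dim=1)  # (batch, E, dim)
        return (weights.unsqueeze(-1) * outputs).sum(dim=1)          # (batch, dim)


moe = DenseMoE(dim=16, hidden=32, num_experts=4)
y = moe(torch.randn(8, 16))
print(y.shape)  # torch.Size([8, 16])
```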

Noisy Gating

The gating network is the novel part of the architecture. It acts as the router for the experts: it sends each input to the selected experts, collects their outputs, and combines them as a weighted sum.

As implemented here, the gate keeps the top-k gate values across the pool of experts and discards the rest (setting them to 0), so only k experts run for each input. This sparsity saves computation, and the gate also adds a noise term that helps with load balancing.
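A minimal sketch of the top-k step, assuming the gate has already produced one logit per expert (the function name top_k_gating is hypothetical). Taking a softmax over only the k largest logits and scattering the results back is equivalent to setting the other logits to negative infinity before the softmax.

```python
import torch
import torch.nn.functional as F


def top_k_gating(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest gate logits per row, zero the rest, renormalize.

    logits: (batch, num_experts). Returns gate weights of the same shape in
    which exactly k entries per row are non-zero and each row sums to 1.
    """
    top_vals, top_idx = logits.topk(k, dim=-1)
    # Softmax over the selected experts only...
    top_weights = F.softmax(top_vals, dim=-1)
    # ...then scatter them back; unselected experts get a gate value of 0.
    return torch.zeros_like(logits).scatter(-1, top_idx, top_weights)


logits = torch.tensor([[2.0, 0.5, 1.0, -1.0],
                       [0.1, 0.2, 3.0, 0.3]])
gates = top_k_gating(logits, k=2)
print(gates)
```

Only the experts with non-zero gate values need to be evaluated, which is where the computational saving comes from.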

Noisy gating counteracts the tendency of a Mixture of Experts to over-rely on a subset of experts, which leaves many experts as dead weight, rarely or never selected by the top-k mechanism. It does this by adding noise to the expert selection at training time. The gate can also compute an auxiliary (load-balancing) loss which, scaled by a loss coefficient, penalizes over-reliance on a few experts.
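Noisy gating and a balancing penalty can be sketched as follows. This follows the noisy top-k gate and the importance loss (the squared coefficient of variation of per-expert gate totals) from Shazeer et al.'s 2017 paper "Outrageously Large Neural Networks"; Deepspeed's own gate computes its auxiliary loss differently, so treat this as intuition rather than a description of the library. The class name NoisyTopKGate and the 0.01 coefficient are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyTopKGate(nn.Module):
    """Noisy top-k gate in the style of Shazeer et al. (2017).

    During training, Gaussian noise with a learned, input-dependent scale is
    added to the gate logits before top-k selection, so experts that are
    slightly behind still get picked occasionally.
    """

    def __init__(self, dim: int, num_experts: int, k: int):
        super().__init__()
        self.k = k
        self.w_gate = nn.Linear(dim, num_experts, bias=False)
        self.w_noise = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        logits = self.w_gate(x)
        if self.training:  # noise only at training time
            noise_std = F.softplus(self.w_noise(x))
            logits = logits + torch.randn_like(logits) * noise_std
        top_vals, top_idx = logits.topk(self.k, dim=-1)
        gates = torch.zeros_like(logits).scatter(-1, top_idx, F.softmax(top_vals, dim=-1))
        # Importance: total gate weight each expert received over the batch.
        importance = gates.sum(dim=0)
        # Squared coefficient of variation: near 0 when experts are used evenly,
        # larger when a few experts dominate.
        aux_loss = importance.var() / (importance.mean() ** 2 + 1e-10)
        return gates, aux_loss


torch.manual_seed(0)
gate = NoisyTopKGate(dim=16, num_experts=8, k=2)
x = torch.randn(32, 16)
gates, aux_loss = gate(x)
penalty = 0.01 * aux_loss  # illustrative coefficient; add to the task loss
```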

MoE in PyTorch

Successfully building an MoE model requires not only the logic for the layers but also an efficient distributed training framework. Microsoft has built a library for PyTorch that handles these challenges effectively: Deepspeed. I’ll use Deepspeed to train a Mixture of Experts image classifier on the CIFAR-10 dataset.

I’m using AzureML because it made it easy to get a 4-GPU machine, the STANDARD_NC_24 machine type.

After provisioning an Azure ML Workspace in the portal, I created a compute instance and used the notebooks in this GitHub repo.

Baseline CIFAR

First, let’s get a working model running with Deepspeed, using an MoE layer inside a generic CNN architecture, to demonstrate the library and its functionality.

This implementation is designed to be used in an Azure ML Workspace; if you’re not using one, simply remove the azureml.core references and use your own MLflow for tracking, or remove tracking entirely.

Data preparation

Nothing too out of the ordinary here, except for deepspeed.init_distributed(), which sets up the distributed backend so we can call get_rank later. We use the rank so that a single process downloads the data, rather than several redundant processes doing it in parallel, and later to avoid logging repeated or duplicate messages.
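A sketch of the rank-aware download pattern, assuming deepspeed.init_distributed() has initialized the torch.distributed process group. Rank 0 downloads while the other ranks wait at a barrier, then everyone reads the local copy. The helper names are hypothetical, and they fall back to single-process behavior when no process group exists, so the same code runs in a plain notebook.

```python
import torch.distributed as dist


def _distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


def is_rank_zero() -> bool:
    """True on the process that should download data and write logs."""
    return not _distributed() or dist.get_rank() == 0


def load_once(download_fn):
    """Let rank 0 download first; other ranks wait, then read the cached copy."""
    if _distributed() and not is_rank_zero():
        dist.barrier()            # non-zero ranks wait for rank 0's download
    data = download_fn()          # rank 0 downloads; the others hit the local files
    if _distributed() and is_rank_zero():
        dist.barrier()            # release the waiting ranks
    return data
```

In the notebook this might be used as trainset = load_once(lambda: torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)), and is_rank_zero() can also guard print and logging calls.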

Parameters

Next, we set a number of parameters. Deepspeed splits these into two groups:

  1. A config JSON file or dictionary (here, ds_config) containing Deepspeed-specific settings such as the optimizer and CUDA options; in our case, this includes the 16-bit floating point (fp16) configuration.
  2. Parameters or flags passed when invoking the training script. Because we’re using a notebook, they’re hard-coded arguments parsed in a function. These are general model settings that are not Deepspeed-specific.
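The two groups might look like this. The config keys (train_batch_size, steps_per_print, optimizer, fp16) are standard Deepspeed config fields, but every value here is an illustrative assumption rather than the notebook's setting. Of the flags, only --num-experts and --ep-world-size come from this article; --epochs and --top-k are hypothetical.

```python
import argparse
import json

# Group 1: Deepspeed-specific settings (illustrative values).
ds_config = {
    "train_batch_size": 64,
    "steps_per_print": 200,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-3, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 3e-7},
    },
    "fp16": {"enabled": True, "initial_scale_power": 15},
}


# Group 2: general model flags. In a notebook we parse a hard-coded list
# instead of sys.argv.
def parse_args(argv):
    parser = argparse.ArgumentParser(description="CIFAR-10 MoE")
    parser.add_argument("--epochs", type=int, default=30)         # hypothetical
    parser.add_argument("--num-experts", type=int, default=8)
    parser.add_argument("--ep-world-size", type=int, default=1)
    parser.add_argument("--top-k", type=int, default=1)           # hypothetical
    return parser.parse_args(argv)


args = parse_args(["--num-experts", "8"])
print(json.dumps(ds_config["fp16"]))      # {"enabled": true, "initial_scale_power": 15}
print(args.num_experts, args.ep_world_size)  # 8 1
```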

Basic Network

This neural network is unremarkable except that it uses a deepspeed.moe.layer.MoE layer. The module passed to it is our expert, which is duplicated the number of times specified by the “--num-experts” setting (8 in this notebook).

The hidden_size argument doesn’t govern how many parameters each expert has; rather, it is the contract for the size of the inputs each expert receives and the outputs it produces. In this case only fc3 is duplicated 8 times; the rest of the network prepares the data, collects the output and maps it to classes.
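A sketch of such a network. The layer sizes follow the classic PyTorch CIFAR-10 tutorial CNN, so the hidden size of 84 is an assumption, not taken from the notebook. In the real notebook fc3 would be something like deepspeed.moe.layer.MoE(hidden_size=84, expert=nn.Linear(84, 84), num_experts=8), which returns (output, gate_loss, expert_counts). That layer needs an initialized Deepspeed job, so to keep the sketch runnable on CPU the MoE layer is injected, and a hypothetical StandInMoE with the same return contract is used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """CIFAR-10 CNN whose fc3 is an MoE layer.

    moe_layer must map (batch, hidden) -> (batch, hidden) and return the
    3-tuple (output, gate_loss, expert_counts).
    """

    def __init__(self, moe_layer: nn.Module, hidden: int = 84):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, hidden)
        self.fc3 = moe_layer              # the only part replicated per expert
        self.fc4 = nn.Linear(hidden, 10)  # maps the MoE output to 10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (N, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 16, 5, 5)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x, gate_loss, _expert_counts = self.fc3(x)
        return self.fc4(x), gate_loss         # keep gate_loss for logging


class StandInMoE(nn.Module):
    """Hypothetical CPU stand-in with the same return contract (testing only)."""

    def __init__(self, hidden: int, num_experts: int):
        super().__init__()
        self.expert = nn.Linear(hidden, hidden)
        self.num_experts = num_experts

    def forward(self, x):
        counts = torch.zeros(self.num_experts, dtype=torch.long)
        counts[0] = x.shape[0]  # pretend every input was routed to expert 0
        return self.expert(x), torch.tensor(0.0), counts


net = Net(StandInMoE(hidden=84, num_experts=8))
logits, gate_loss = net(torch.randn(4, 3, 32, 32))
```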

The forward function is also straightforward, except that fc3 returns three outputs: the result, which is fed into the next layer; the gate loss; and the expert counts.

We’re capturing and logging the gate loss along with other run metrics and losses.

Deepspeed initialization

Initializing Deepspeed sets up a model_engine, an optimizer and a trainloader. This is where both sets of configuration options are used, and the resulting objects can be used in a distributed training environment with an otherwise normal training loop (we’ll see that below).
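A sketch of the initialization call, assuming the optimizer is declared in ds_config and the code runs in a process started by the deepspeed launcher (or after deepspeed.init_distributed()). deepspeed.initialize returns the engine, optimizer, training data loader and LR scheduler; passing training_data makes Deepspeed build a distributed data loader for us. The wrapper name build_engine is hypothetical.

```python
def build_engine(model, trainset, ds_config):
    """Wrap a model, its optimizer and a data loader for distributed training."""
    import deepspeed  # imported lazily so this module loads without Deepspeed

    model_engine, optimizer, trainloader, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        training_data=trainset,  # Deepspeed builds the (distributed) DataLoader
        config=ds_config,        # a config dict, or a path to a JSON file
    )
    return model_engine, optimizer, trainloader, lr_scheduler
```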

Training

The training loop is standard, but it uses the Deepspeed components we initialized earlier, which let us train in a distributed fashion across machines or local GPUs.

The optimizer and learning-rate schedule are bundled into the model engine, which is why the engine is used both to produce predictions (the forward pass) and to run the backward pass and optimization step.
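A sketch of one training epoch, assuming the model's forward returns (logits, gate_loss) and that gate_loss is added to the task loss with a small, hypothetical coefficient (whether the notebook adds it or only logs it isn't shown). model_engine.backward() and model_engine.step() take the place of the usual loss.backward(), optimizer.step() and zero_grad() calls. The function name and defaults are illustrative.

```python
import torch
import torch.nn as nn


def train_epoch(model_engine, trainloader, device, fp16=False,
                gate_loss_coef=0.01, log_every=200):
    """Run one epoch and return the mean loss."""
    criterion = nn.CrossEntropyLoss()
    running, steps = 0.0, 0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        if fp16:
            inputs = inputs.half()  # match the engine's fp16 weights
        logits, gate_loss = model_engine(inputs)
        # Task loss plus a small penalty against routing everything to a few experts.
        loss = criterion(logits, labels) + gate_loss_coef * gate_loss
        model_engine.backward(loss)  # replaces loss.backward()
        model_engine.step()          # optimizer and LR-scheduler step
        running += loss.item()
        steps += 1
        if (i + 1) % log_every == 0:
            print(f"step {i + 1}: loss {running / steps:.3f}")
    return running / max(steps, 1)
```

With fp16 enabled in ds_config, a call might look like train_epoch(model_engine, trainloader, device=model_engine.local_rank, fp16=True); in a multi-GPU run, guard the print with a rank check so each message appears once.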

Finally, we assess accuracy, which for this model is not very good. Now that we have a basic implementation that trains fast and isn’t unorthodox or verbose, we can apply the same approach to a more sophisticated architecture, a ResNet for CIFAR, and see what carries over and what changes.
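A minimal accuracy check, assuming the same (logits, gate_loss) forward contract; evaluate is a hypothetical helper. With fp16 enabled, the inputs would also need .half(), as in training.

```python
import torch


@torch.no_grad()
def evaluate(model, testloader, device="cpu"):
    """Top-1 accuracy over a data loader."""
    model.eval()
    correct = total = 0
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits, _gate_loss = model(inputs)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    model.train()  # restore training mode for the next epoch
    return correct / total
```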

MoE ResNet

In this notebook we take the same implementation but apply the ResNet architecture to the problem. Much of the code stays the same; what changes is the model architecture.

Two notes. First, the MoE layer we’re using fixes the input and output sizes of its experts. Second, I want to experiment with experts that receive minimal pre-processing, so we need to flatten our images and set the hidden_size (input and output dimension) of our MoE layer to the image feature length (channels * height * width).

This is a quick and dirty way to avoid hard-coding the image attributes.
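One way to read the dimensions from the data instead of hard-coding them, assuming a torchvision-style dataset whose items are (image, label) pairs with images shaped (C, H, W); image_dims is a hypothetical helper and the one-item list stands in for CIFAR-10. nn.Flatten then turns each (N, C, H, W) batch into (N, C*H*W).

```python
import torch
import torch.nn as nn


def image_dims(dataset):
    """Return (channels, height, width, feature_length) from the first sample."""
    image, _label = dataset[0]
    channels, height, width = image.shape
    return channels, height, width, channels * height * width


# Stand-in for CIFAR-10 after transforms.ToTensor(): (3, 32, 32) images.
fake_dataset = [(torch.zeros(3, 32, 32), 0)]
channels, height, width, hidden_size = image_dims(fake_dataset)  # 3, 32, 32, 3072

flatten = nn.Flatten()  # (N, C, H, W) -> (N, C*H*W)
flat = flatten(torch.zeros(16, channels, height, width))
```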

Now we’ll implement our ResNet, using the residual block defined earlier in the script to build a ResNet34 implementation. This model is mostly unremarkable except that it reshapes the input at the beginning of the forward pass, so that it is an image tensor rather than a flattened one:

x = torch.reshape(x, (-1, self.channels, self.height, self.width))

The reshape is needed because this ResNet is itself a module (an expert) within our larger network, and the MoE layer hands it flattened inputs.

Here the ResNet is passed to the MoE layer as its expert: it receives a flattened view of the images and, from its final dense layer, outputs a flattened tensor of the same size. That output is connected to a fully connected layer that produces our 10 classes.
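A heavily scaled-down sketch of the idea: a two-block residual network standing in for ResNet34 (class names, widths and depth are illustrative). It restores the image shape from the flat input, runs the residual layers, and projects back to a flat vector of the same length, so it satisfies the MoE layer's hidden_size contract; a final nn.Linear then maps to the 10 classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two-conv residual block (stride 1, same channel count)."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # skip connection


class ResNetExpert(nn.Module):
    """Expert that accepts and returns flattened images of length C*H*W."""

    def __init__(self, channels=3, height=32, width=32, filters=16, blocks=2):
        super().__init__()
        self.channels, self.height, self.width = channels, height, width
        self.stem = nn.Conv2d(channels, filters, 3, padding=1, bias=False)
        self.layers = nn.Sequential(*[BasicBlock(filters) for _ in range(blocks)])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out = nn.Linear(filters, channels * height * width)

    def forward(self, x):
        # The MoE layer passes flat vectors; restore the image shape first.
        x = torch.reshape(x, (-1, self.channels, self.height, self.width))
        x = self.layers(F.relu(self.stem(x)))
        x = torch.flatten(self.pool(x), 1)
        return self.out(x)  # flat again, same length as the input


expert = ResNetExpert()
flat_out = expert(torch.randn(4, 3 * 32 * 32))
head = nn.Linear(3 * 32 * 32, 10)  # outer layer mapping to the 10 classes
logits = head(flat_out)
```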

Most of the notebook is similar to the simpler implementation. The key difference is that here most of the model's parameters are duplicated across experts, whereas previously only about one layer was.

So this network is instantiated once per expert, as many times as we’ve set in our configuration.
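To see why this matters for memory, here is a sketch of the replication using copy.deepcopy (to my understanding, Deepspeed's MoE layer builds its local experts from the module you pass in the same way). Each copy has its own weights, so the parameter count grows linearly with the number of experts. The small two-layer expert is illustrative.

```python
import copy

import torch
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


expert = nn.Sequential(nn.Linear(84, 84), nn.ReLU(), nn.Linear(84, 84))
# One independent copy per expert; deepcopy also copies the values, so every
# expert starts from the same initial weights in this sketch.
experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(8)])

per_expert = count_params(expert)  # 2 * (84*84 + 84) = 14280
total = count_params(experts)      # 8 * 14280 = 114240
```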

Parallel Networks

Here, we’re making the ResNet architecture into a script we can launch from the command line.

So far we haven’t taken much advantage of the Mixture of Experts architecture; all the training has taken place on a single GPU in the machine. If we really scale the number of experts, say to 40, we’ll run out of GPU memory.

Now we’ll use Deepspeed to launch the training, with a script very similar to our ResNet notebook. An NC24 has 4 GPUs, so we can scale to 40 experts, with 10 running on each GPU.

This is governed by three parameters: dist_world_size, the number of GPU processes, which is discovered by the Deepspeed runtime; “--ep-world-size”, the expert-parallel size, i.e. how many GPUs the experts are split across; and “--num-experts”, the total number of experts. Each GPU then hosts num-experts / ep-world-size experts.
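Under Deepspeed's convention, as I understand it, the ep_size argument of deepspeed.moe.layer.MoE is the number of ranks the experts are sharded across, each rank hosts num_experts // ep_size experts, and both the world size and the expert count must be divisible by ep_size. The hypothetical helper below encodes that arithmetic.

```python
def experts_per_gpu(num_experts: int, ep_world_size: int, world_size: int) -> int:
    """How many experts each GPU hosts under expert parallelism."""
    if world_size % ep_world_size:
        raise ValueError("world size must be divisible by the expert-parallel size")
    if num_experts % ep_world_size:
        raise ValueError("num_experts must be divisible by the expert-parallel size")
    return num_experts // ep_world_size


print(experts_per_gpu(num_experts=40, ep_world_size=4, world_size=4))  # 10
```

With 4 GPUs and an expert-parallel size of 4, 40 experts works out to 10 per GPU, matching the setup described here.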

For this run we have 4 GPUs and 40 experts configured, so each GPU hosts 10 experts.

To launch a distributed run, first activate the conda environment behind the notebook kernel we’ve been using:

source activate azureml_py38_PT_TF

then run:

deepspeed cifar_distributed_resnet.py

This kicks off the run, which logs to an AzureML experiment so we can track metrics across the ranks. With no further arguments, the deepspeed launcher starts one process per GPU visible on the local machine.

This network converges, though it takes several hours to train.

The accuracy is still terrible, but I suppose we’ve demonstrated the approach.


Paul Bruffett

Enterprise Architect specializing in data and analytics.