Hi everybody, this is Julien from Arcee, and in this video I would like to introduce model parallelism in Amazon SageMaker. In the previous video, we discussed data parallelism, where we split the dataset across a cluster of training instances to speed up training. Model parallelism is different: here, we split the model itself and distribute the training computation across a cluster of training instances.
Why would we want to split the model? Simply because some models are so large and complex that they don't fit in GPU memory. Think about very large natural language processing models like T5 and its variants or very complex computer vision models that work on 3D images or 3D videos, especially at high resolution. These models are huge, with many parameters, and they can't fit inside the memory of a single GPU. Model parallelism will split the model into different parts. The first layers of the model will run on a specific GPU, let's say GPU 0, and then the next layers will run on GPU 1, the next layers on GPU 2, and so on.
Now, you may wonder how these GPUs collaborate. Forward propagation has to go through all the layers, and the same goes for backward propagation, so the GPUs need to coordinate. This is exactly what model parallelism in SageMaker handles: it partitions the model efficiently and keeps all GPUs busy at all times.
Imagine we have two GPUs and we're partitioning the model across them. Partition one is the first half of the model and lives on GPU 1; partition two is the second half and lives on GPU 2. When training data is sent to GPU 1, the batch is forward propagated through the first partition and then handed over to the second partition on GPU 2.
However, while that batch is being propagated on GPU 1, what's happening on GPU 2? It would be inefficient for GPU 2 to sit idle. So we split the training data into micro-batches: each mini-batch is further divided, and different micro-batches are fed into the pipeline. For example, micro-batch 1 starts on GPU 1 and is then forward propagated on GPU 2. While GPU 2 is busy forward propagating micro-batch 1, GPU 1 can start forward propagating micro-batch 2. This way, all GPUs stay busy during forward propagation.
What about backward propagation? Once GPU 2 is done forward propagating micro-batch 1, it needs to backward propagate it, and the same then happens on GPU 1. However, GPU 1 could still be forward propagating the next micro-batch, creating a conflict. To solve this, we replicate the partitions: GPU 1 holds two copies of partition 1, and GPU 2 holds two copies of partition 2. This is a parameter we can set, and we need at least two copies.
Now, the first copy of partition 1 and the first copy of partition 2 always handle forward propagation, while the other copies always handle backward propagation. This creates a two-way pipeline through which micro-batches move efficiently: a micro-batch leaves the training set, is forward propagated through the first partition, then through the second, is backward propagated for that same batch, and the result travels back. This pipeline of partitions keeps all GPUs busy on a sequence of micro-batches.
In this example, micro-batch n is almost done: it has gone through the three other partitions and is being backward propagated. The next one has gone through the first two partitions and is on the third, the next is already through the first, and so on. This is how you keep all GPUs busy at all times, thanks to micro-batches and this pipeline architecture, called the interleaved pipeline.
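To make the scheduling idea more concrete, here is a toy sketch in plain Python. This is not SageMaker code and all names are illustrative; it just shows how an interleaved schedule can keep two GPUs busy when each GPU prefers pending backward work and otherwise forward propagates the next micro-batch.

```python
from collections import deque

# Toy illustration of interleaved pipelining (not SageMaker code): two GPUs,
# two partitions, and a mini-batch split into four micro-batches.
MICROBATCHES = 4

forward_gpu1 = deque(range(MICROBATCHES))  # micro-batches waiting to enter partition 1
forward_gpu2 = deque()                     # forward work handed over to partition 2
backward_gpu2 = deque()                    # backward work starting at the last layers
backward_gpu1 = deque()                    # backward work finishing at the first layers

step, done = 0, 0
while done < MICROBATCHES:
    step += 1
    handoffs = []  # (destination queue, micro-batch), applied after both GPUs act

    # GPU 1: prefer pending backward work, otherwise forward the next micro-batch.
    if backward_gpu1:
        mb = backward_gpu1.popleft()
        done += 1
        gpu1 = f"backward mb{mb}"
    elif forward_gpu1:
        mb = forward_gpu1.popleft()
        handoffs.append((forward_gpu2, mb))
        gpu1 = f"forward  mb{mb}"
    else:
        gpu1 = "idle"

    # GPU 2: same policy on the second half of the model.
    if backward_gpu2:
        mb = backward_gpu2.popleft()
        handoffs.append((backward_gpu1, mb))
        gpu2 = f"backward mb{mb}"
    elif forward_gpu2:
        mb = forward_gpu2.popleft()
        handoffs.append((backward_gpu2, mb))  # loss reached, turn around
        gpu2 = f"forward  mb{mb}"
    else:
        gpu2 = "idle"

    for queue, mb in handoffs:
        queue.append(mb)

    print(f"step {step}: GPU1 {gpu1:13} | GPU2 {gpu2}")
```

Running it prints a step-by-step schedule where each GPU alternates between forward and backward work on different micro-batches, which is exactly the effect the interleaved pipeline is after.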
Now, let's take a look at an example. The examples are in this repo under `training/distributed training`, and there are TensorFlow and PyTorch versions. The documentation is available as well, and I will put all the URLs in the description. We're running a simple example based on TensorFlow. Make sure you have a recent SageMaker SDK, version 2.19 or later. We import the necessary packages; this notebook also uses SageMaker Experiments to track the training job, although that's not essential for the example.
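For reference, the setup step in the notebook typically looks like the sketch below. The version pin is simply the one mentioned in the video, and the variable names are mine.

```python
# Upgrade the SageMaker SDK inside the notebook (version mentioned in the video):
# !pip install --upgrade "sagemaker>=2.19.0"

import sagemaker

print(sagemaker.__version__)           # expect 2.19.0 or later
session = sagemaker.Session()          # SageMaker session used by the notebook
role = sagemaker.get_execution_role()  # IAM role the training job will assume
```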
Here is our TensorFlow script. You need to import the model parallelism package and initialize it. We load the dataset, build our TensorFlow dataset object, and build the model with the usual Keras API. The only difference is that our model class extends the library's distributed model class; everything else is vanilla Keras. We instantiate the model, select a loss function, define metrics, and handle checkpointing, which is good practice for long-running jobs.
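As a rough sketch of what that looks like in the training script (layer sizes and class names are placeholders, not the exact notebook code):

```python
import tensorflow as tf
import smdistributed.modelparallel.tensorflow as smp

# Initialize the model parallelism library before building the model.
smp.init()

# A Keras subclassed model; the only change is extending smp.DistributedModel
# instead of tf.keras.Model. Layer sizes here are placeholders.
class SimpleClassifier(smp.DistributedModel):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation="relu")
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, x, training=False):
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

model = SimpleClassifier()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Metrics tracked during training.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
```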
The key part is the step function, which you annotate with the SMP step annotation so the library can forward propagate the micro-batches through the pipeline and return the gradients. We compute the loss, the gradients, and the predictions, and return those tensors. This code is standard: if you already have it, you just need to return the tensors and add the annotation.
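Continuing the sketch above (it reuses `model`, `loss_fn`, and `optimizer`), the annotated step function might look like this, following the library's documented TF2 pattern; the function name is mine.

```python
# smp.step runs the body once per micro-batch and pipelines the work across
# the model partitions; the returned values are per-micro-batch outputs.
@smp.step
def get_grads(images, labels):
    predictions = model(images, training=True)
    loss = loss_fn(labels, predictions)
    grads = optimizer.get_gradients(loss, model.trainable_variables)
    return grads, loss, predictions
```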
We also have a standard training step function that accumulates and applies the gradients returned by the SMP step function, and an evaluation function that computes predictions, loss, and accuracy, again by calling an SMP step function. The training loop itself is standard, with a barrier call at the end of each epoch to synchronize all GPUs.
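The surrounding training step and loop could then be sketched like this, again following the documented pattern. `train_ds` (the `tf.data.Dataset` built earlier) and the epoch count are assumptions, and the evaluation function would follow the same structure as `get_grads`.

```python
# The outputs of an smp.step function are StepOutput objects: gradients are
# accumulated across micro-batches, losses averaged, predictions merged.
@tf.function
def train_step(images, labels):
    gradients, loss, predictions = get_grads(images, labels)
    gradients = [g.accumulate() for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss.reduce_mean())
    train_accuracy(labels, predictions.merge())

# A standard training loop, with a barrier so all processes finish the epoch
# together before moving on.
for epoch in range(5):
    train_loss.reset_states()
    train_accuracy.reset_states()
    for images, labels in train_ds:
        train_step(images, labels)
    print(f"epoch {epoch}: loss={train_loss.result():.4f}, "
          f"accuracy={train_accuracy.result():.4f}")
    smp.barrier()
```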
Summing up, import the package, initialize it, ensure your model extends the distributed model class, handle checkpointing, define the SMP step function for forward and backward propagation, and the evaluation function for validation. This is a simple example, and the changes to your code are minimal. The same philosophy applies to PyTorch, with different APIs.
Once we've done this, we can create our estimator, passing our training script, TensorFlow 2.3.1, Python 3.7, and the distribution configuration that enables model parallelism. We set the parameters to two micro-batches, two copies of each partition, and two partitions, using the interleaved pipeline mode. We could also combine this with data parallelism using Horovod, but I'm not covering that today. The key is to understand these settings.
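Put together, the estimator configuration could look roughly like this. The script name, instance type, and exact parameter values are assumptions based on what's described above; the distribution keys follow the SageMaker `smdistributed` model parallelism format.

```python
from sagemaker.tensorflow import TensorFlow

# Model parallelism settings discussed in the video (values are illustrative).
smp_options = {
    "enabled": True,
    "parameters": {
        "microbatches": 2,          # split each mini-batch into 2 micro-batches
        "partitions": 2,            # split the model into 2 partitions
        "pipeline": "interleaved",  # the interleaved pipeline mode discussed above
        "optimize": "speed",
    },
}

# MPI launches one process per GPU on the instance.
mpi_options = {
    "enabled": True,
    "processes_per_host": 2,
}

estimator = TensorFlow(
    entry_point="tf2_smp_example.py",  # hypothetical script name
    role=role,
    instance_count=1,
    instance_type="ml.p3.8xlarge",     # multi-GPU instance (assumption)
    framework_version="2.3.1",
    py_version="py37",
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        "mpi": mpi_options,
    },
)
```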
Finally, we call `fit` and optionally pass the experiment configuration for SageMaker Experiments. The rest is a normal training job: the model is automatically profiled, split according to the number of partitions, and training begins. You just have to wait for completion and follow the training log as usual.
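The launch itself is then a single `fit` call; the experiment and trial names below are placeholders.

```python
# Launch the training job. experiment_config is optional and only used to log
# the job into SageMaker Experiments (names are placeholders). Pass input
# channels to fit() as usual if your script reads training data from S3.
estimator.fit(
    experiment_config={
        "ExperimentName": "smp-demo",
        "TrialName": "smp-demo-trial-1",
        "TrialComponentDisplayName": "Training",
    }
)
```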
That's a short intro to model parallelism. I'm sure you'll work on bigger and more complex examples, so feel free to get in touch if you have questions or feedback. I'm always happy to help, and until then, keep rocking!