In this deep dive video, we zoom in on model distillation, an advanced technique to build high-performance small language models at a reasonable cost. First, we explain what model distillation is. Then, we introduce two popular strategies for distillation, logits distillation and hidden states distillation. We study in detail how they work and how they're implemented in the Arcee DistillKit open-source library. Finally, we look at two Arcee models built with distillation, Arcee SuperNova 70B and Arcee SuperNova Medius 14B.
Note: my calculation at 18:45 is wrong. It's 2.3 Tera tokens, not 2.3 Peta tokens. Sorry about that 🤡
If you’d like to understand how Arcee AI can help your organization build scalable and cost-efficient AI solutions, please get in touch at sales@arcee.ai or by booking a demo at https://www.arcee.ai/book-a-demo.
⭐️⭐️⭐️ Don't forget to subscribe to be notified of future videos. You can also follow me on Medium at https://julsimon.medium.com or Substack at https://julsimon.substack.com. ⭐️⭐️⭐️
* Slides: https://fr.slideshare.net/slideshow/deep-dive-model-distillation-with-distillkit/274619548
* DistillKit: https://github.com/arcee-ai/DistillKit
00:00 Introduction
00:30 What is model distillation?
04:55 Model distillation with DistillKit
11:20 Logits distillation
20:10 Logits distillation with DistillKit
26:10 Hidden states distillation
31:35 Hidden states distillation with DistillKit
36:00 Pros and cons
40:32 Distillation example: Arcee SuperNova 70B
42:50 Distillation example: Arcee SuperNova Medius 14B
44:40 Conclusion
Transcript
Hi everybody, this is Julien from Arcee. In this video, we're going to dive deep into model distillation, an important technique to build very high-quality small language models starting from much larger pre-trained models. We're going to look at the different types of model distillation and then we'll see how they're implemented in our very own library for model distillation called DistillKit. So be ready to learn. Let's get started.
Let's start with defining model distillation. Model distillation is a technique where we train a small student model to mimic the inference behavior of a much larger pre-trained teacher model. The intuition here is that we have huge models, with hundreds of billions of parameters, and they have very high performance, but they're bulky and impractical for most use cases. Can we train a smaller model, maybe 10x smaller, or even more, to predict as closely as possible to the large model? In a way, we leverage the investment that someone else made in building the very large model. We try to shrink it down while keeping as much of the original goodness as possible. The goal is to get as much performance as we can from a much smaller model, which should give us lower compute costs and good performance. This should increase the ROI for our AI application.
The trick is really how we teach that smaller model to mimic the behavior of the large model, and that's what we're going to cover today. The first benefit of distillation is resource efficiency: distilling a large model into a small one, although still a fairly heavy task, is not as heavy as training a model completely from scratch. We reuse an existing model and save on training time and compute costs. The second is performance, which, if the distillation is done right, should be close to that of the large model. In many cases, the performance will actually be higher than that of models in the same size range. I'll show you a benchmark from our technical report, so distillation is also a way to get more accuracy for the same number of parameters. Knowledge transfer is really what distillation is all about. The large teacher model has already been trained, so we try to keep as much of that knowledge as we can using clever techniques, instead of training from scratch all over again. Hopefully, we can carry the knowledge of the big model over into the smaller one, helping it handle the variety of use cases the large model could already handle.
Scalability is another obvious benefit. Working with smaller models is always desirable. Small is beautiful, small is fast, small is cheaper, and it's easier to scale. You don't need fancy, expensive GPU instances to run them. There are a lot of benefits in working with small models, especially if they can inherit the qualities of much larger models. Working with smaller models is also more energy efficient, and distilling models is less compute-intensive than training from scratch, which is an added bonus.
Now, let's talk a little bit about DistillKit. DistillKit is an open-source library by Arcee. You can find it on GitHub, and my colleagues also wrote a nice tech report to tell you about the library and some experiments they ran. I strongly encourage you to read this. Here's one of the examples they share in the report. They distilled Qwen 2, a 7 billion parameter model, into a Qwen 2 1.5 billion parameter model, which is 5x smaller. They used logit distillation, which we'll cover, and then they ran benchmarks on the original model (green bars), the distilled model (purple bars), and the original Qwen 2 1.5b instruct model released by the Qwen team (yellow bars).
Immediately, we can see that the distilled model strongly outperforms the original 1.5b model. Distilled models tend to outperform models in the same range, and the scores are impressive across the board. The reason for this is that instead of training a 1.5B model from scratch, we distill it from a larger version that was trained on more data, for longer, and with more parameters. There's just more goodness in the larger model, and we managed to keep more of it through distillation. The comparison between the student model and the teacher model (purple and green bars) shows that the student model cannot be as performant as the teacher, and on some benchmarks like BBH, Big Bench Hard, there is still a huge difference. However, on MUSR, they're pretty close. If your business problem looks more like the MUSR benchmark, you could consider using the 1.5B student model, which is 5x smaller, faster, and easier to deploy and scale.
Benchmarks are benchmarks, and we could talk about them for days and never agree, but they still give you a reference point. It's important to understand how your business problem relates to those benchmarks so you can focus on the right ones and ignore those that don't relate to your problem. Please read the tech report to learn more about DistillKit.
DistillKit implements two distillation techniques: logit distillation and hidden states distillation. Logits are the raw output scores a model produces before the final activation function. At the very last step of inference, these raw scores are run through the activation function (usually softmax) to become probabilities. When we do logits distillation, we train the student to match the teacher's output logits. We don't care how the teacher predicts or what's happening inside the teacher model; we just look at the output scores. This is a black-box technique.
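As a tiny illustration (plain PyTorch with made-up numbers, not DistillKit code), here is how raw logits relate to the probabilities produced by softmax:

```python
import torch
import torch.nn.functional as F

# Made-up raw scores over a tiny 4-token vocabulary for one output position.
logits = torch.tensor([2.0, -1.0, 0.5, 3.0])

# The final activation (softmax) turns the raw scores into a probability distribution.
probs = F.softmax(logits, dim=-1)

print(probs)           # ≈ tensor([0.2505, 0.0125, 0.0559, 0.6811]) -- sums to 1
print(probs.argmax())  # tensor(3): the most likely token at this position
```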
The other technique implemented in DistillKit is Hidden States distillation. Unlike Logit distillation, with Hidden States distillation, we train the student to mimic the internal state of the teacher model. We look at how each layer in the teacher model predicts and ask the student to do the same. It's a layer-wise training, a white-box technique. We need to understand what's happening inside the teacher and ask the student to do the same. This is a finer-grained technique and more complicated, as we'll see.
Now, let's look at these two techniques in more detail, starting with logit distillation. It's very easy to understand. We have a teacher model, and we start by running inference with it on a distillation dataset, which raises the question: how do I pick the distillation dataset? My advice would be to pick a dataset that looks like your business problem. You want to capture how the model behaves on Q&A pairs that are close to yours; if you're interested in healthcare, don't take cooking recipes. We predict this dataset with the teacher model and get the logits. These logits become a training set, and we train the student model to mimic them. We also use the distillation dataset itself and inject some of the ground truth present in it.
You can run this in two ways: offline distillation, where we predict the distillation dataset completely with the teacher, store all the logits, and then train the student model on this dataset. The problem is that logits are bulky, and this could be petabytes of storage. There's also online distillation, where we run the teacher and student models in parallel. We predict a batch, keep the logits, and use that batch to train the student. Storage isn't a problem, but we need more compute because we run both models in parallel. The student model is usually pre-trained, so we don't start from a blank model. The models need to share the same vocabulary because the token IDs need to be identical.
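To make the online setup concrete, here is a rough sketch of such a training loop in plain PyTorch, not DistillKit's actual implementation. It assumes Hugging Face-style causal language models and a dataloader yielding dicts with input_ids, attention_mask, and labels; distillation_loss is a hypothetical helper, and one possible way to write it is sketched a bit further below.

```python
import torch

def online_distillation_epoch(teacher, student, dataloader, optimizer, distillation_loss):
    """One pass of online distillation: teacher and student run on the same batch."""
    teacher.eval()
    student.train()
    for batch in dataloader:
        # The teacher only does inference; no gradients are needed for it.
        with torch.no_grad():
            teacher_logits = teacher(**batch).logits
        # Student forward pass on the same batch.
        student_out = student(**batch)
        # Blend a "match the teacher" term with a "match the labels" term.
        loss = distillation_loss(student_out.logits, teacher_logits, batch["labels"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```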
Let's double-click on what those logits are. The teacher model predicts the distillation dataset, and we get the logits. The teacher's output logits form a matrix: each row holds the raw scores for every token in the model vocabulary, and there is one row per position in the output sequence. The input sequence is a series of tokens. We run inference and get logits, with as many columns as there are tokens in the vocabulary. If we're generating M new tokens, we have M rows, one per position in the generated output sequence.
The first row holds the scores for all the tokens in the vocabulary at the first generated position. One of those values is the highest, and that's the token we pick. The next row holds the scores for all tokens in the vocabulary at position two in the output sequence, and so on: each row gives the score of every vocabulary token at one position. This is what we train the student on.
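To visualize the shape of that matrix, here is a small sketch with simulated values; the vocabulary size and sequence length match the storage example that follows:

```python
import torch

vocab_size = 100_000   # tokens in the shared vocabulary
num_generated = 512    # positions in the generated output sequence

# Simulated logits for one prompt: one row per generated position,
# one column per vocabulary token (random values, just to show the shape).
logits = torch.randn(num_generated, vocab_size)

print(logits.shape)        # torch.Size([512, 100000])
# Greedy decoding picks the highest-scoring column in each row.
next_tokens = logits.argmax(dim=-1)
print(next_tokens.shape)   # torch.Size([512]): one token id per position
```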
Let's look at the storage requirements. If we have 50,000 prompts in the distillation dataset, 100,000 tokens in the vocabulary, and we're generating up to 512 tokens in the output sequence, we multiply 50,000 by 100,000 by 512 and get over two trillion (tera, not peta) logit values, which is still terabytes of raw data. Storing this in S3 would be expensive, so we can reduce it by keeping only the top-K most likely tokens per position or by doing online distillation.
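Here is that back-of-the-envelope calculation as a few lines of Python, using the numbers above and assuming two bytes per value in float16; the top-K figure at the end is purely illustrative:

```python
prompts = 50_000        # prompts in the distillation dataset
vocab_size = 100_000    # tokens in the vocabulary
max_new_tokens = 512    # generated positions per prompt

values = prompts * vocab_size * max_new_tokens
print(f"{values:.2e} logit values")            # 2.56e+12, i.e. ~2.6 tera values

bytes_fp16 = values * 2                        # 2 bytes per value in float16
print(f"{bytes_fp16 / 1e12:.1f} TB in fp16")   # ~5.1 TB of raw storage

# Keeping only the top-K most likely tokens per position shrinks this dramatically
# (K=64 is just an illustrative choice).
top_k = 64
print(f"{prompts * top_k * max_new_tokens * 2 / 1e9:.1f} GB with top-{top_k}")  # ~3.3 GB
```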
Now, let's look at how this works with DistillKit. The purpose is for the student to mimic the teacher. The logits matrix needs to be as close as possible between the two models. We predict with the student and the teacher, get logits for both, and measure the difference using the Kullback-Leibler (KL) divergence to score the two probability distributions. We also factor in some cross-entropy loss from the labels. The KL divergence measures the distance between the token score distributions of the student and the teacher. We compute the KL divergence for each position, sum the values, and divide by the batch size.
Let's look at the code. We take a batch, forward it through the student model, get the outputs, forward it through the teacher, and get the outputs. We compute the actual loss by padding the logits to have the same length, scaling them with temperature, and applying softmax. We compute the KL divergence between the student and teacher logits, and return a weighted average between the KL loss and the cross-entropy loss.
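Here is a hedged sketch of what such a loss can look like in plain PyTorch, following the steps just described and matching the hypothetical distillation_loss helper used in the loop sketch earlier. The temperature, the weighting factor alpha, and the zero-padding step are illustrative choices, not DistillKit's exact implementation.

```python
import torch
import torch.nn.functional as F

def logit_distillation_loss(student_logits, teacher_logits, labels,
                            temperature=2.0, alpha=0.5):
    """Weighted average of a KL term (match the teacher) and a CE term (match the labels).
    Shapes are (batch, seq, vocab); labels are assumed to be already aligned with the
    logits for next-token prediction, with -100 marking ignored positions."""
    # Align vocabulary sizes if they differ slightly; zero-padding is one simple option.
    # Both models are assumed to share the same tokenizer, so the extra columns matter little.
    if student_logits.size(-1) < teacher_logits.size(-1):
        pad = teacher_logits.size(-1) - student_logits.size(-1)
        student_logits = F.pad(student_logits, (0, pad), value=0.0)

    # Soften both distributions with the temperature, then compare them with KL divergence.
    # "batchmean" sums over positions and vocabulary and divides by the batch size.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2

    # Standard next-token cross-entropy against the ground-truth labels.
    ce = F.cross_entropy(student_logits.reshape(-1, student_logits.size(-1)),
                         labels.reshape(-1), ignore_index=-100)

    # Weighted average of the distillation term and the label term.
    return alpha * kl + (1.0 - alpha) * ce
```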
Now, let's explain hidden states distillation. Logit distillation focuses on mimicking the output, while hidden states distillation focuses on mimicking internal states, the output values at each layer in the teacher model. The high-level process is the same: predict with the teacher, save some stats, and feed that to the student. The difference is in the nature of the data coming out of the teacher. We can do offline distillation, which takes more storage, or online distillation, which is more compute-intensive but doesn't require storage.
The student model is much smaller than the teacher, which creates problems: the student has fewer layers, and those layers are narrower. For example, Qwen 2 1.5B has 28 layers with a hidden size of about 1.5K, while the larger teacher model has 48 layers with a hidden size of about 5K. We need to match the data coming out of the teacher and the student. We select which teacher layers to use for training, often by picking one in N layers. We also need to extend the student's hidden states to match the size of the teacher's by adding trainable parameters.
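To illustrate the layer mapping and the dimension matching, here is a small sketch in plain PyTorch; the layer counts, hidden sizes, and the linear adapters are illustrative assumptions, not the exact mechanism used by DistillKit.

```python
import torch
import torch.nn as nn

# Illustrative dimensions (not tied to a specific model pair).
student_layers, student_hidden = 28, 1536
teacher_layers, teacher_hidden = 48, 5120

# Map each student layer to one teacher layer by spreading the picks evenly,
# which amounts to keeping roughly one in N teacher layers.
layer_map = [round(i * (teacher_layers - 1) / (student_layers - 1))
             for i in range(student_layers)]

# One trainable projection per mapped layer, lifting the student hidden states
# from student_hidden to teacher_hidden so the two can be compared.
adapters = nn.ModuleList([nn.Linear(student_hidden, teacher_hidden)
                          for _ in range(student_layers)])

# Example: project the hidden state of student layer 10 and compare it
# with the mapped teacher layer.
h_student = torch.randn(1, 512, student_hidden)   # (batch, seq, hidden)
h_projected = adapters[10](h_student)             # (1, 512, teacher_hidden)
print(layer_map[10], h_projected.shape)           # 17 torch.Size([1, 512, 5120])
```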
In DistillKit, we predict a batch with the teacher and the student, including the adaptation layers. We compute the loss between the hidden states of the teacher and the student for each mapped layer. We use KL divergence and add the losses for all mapped layers, returning a weighted average between the KL loss and the cross-entropy loss.
The code starts the same way: predict a batch with the student and the teacher, grab the hidden states, and run the actual distillation loss. We iterate over the mapped teacher layers, grab the hidden states, and compute the KL divergence. We add each layer loss to the total, scale it, and return it. This is more compute-intensive, but the high-level process is the same.
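Here is a hedged sketch of that per-layer loss in plain PyTorch, reusing the layer_map and adapters from the sketch above; the weights and the softened comparison over the hidden dimension are illustrative choices rather than DistillKit's exact code.

```python
import torch
import torch.nn.functional as F

def hidden_states_distillation_loss(student_hidden, teacher_hidden, layer_map, adapters,
                                    ce_loss, temperature=2.0, alpha=0.5):
    """Sum a KL term over every mapped (student layer, teacher layer) pair,
    then blend it with the usual cross-entropy loss. Hidden states are lists of
    (batch, seq, hidden) tensors, one per layer; all weights are illustrative."""
    kl_total = 0.0
    for student_idx, teacher_idx in enumerate(layer_map):
        # Project the student hidden state up to the teacher's hidden dimension.
        s = adapters[student_idx](student_hidden[student_idx])   # (batch, seq, teacher_hidden)
        t = teacher_hidden[teacher_idx]                          # (batch, seq, teacher_hidden)
        # Compare the two states as softened distributions over the hidden dimension.
        s_log = F.log_softmax(s / temperature, dim=-1)
        t_prob = F.softmax(t / temperature, dim=-1)
        kl_total = kl_total + F.kl_div(s_log, t_prob, reduction="batchmean") * temperature ** 2

    # Scale by the number of mapped layers, then mix with the label loss.
    kl_total = kl_total / len(layer_map)
    return alpha * kl_total + (1.0 - alpha) * ce_loss
```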
Now, let's compare the two techniques. Logit distillation is simple to implement and very effective for classification problems because we're just looking for the most likely token output. We can work with different architectures as long as they share the same vocabulary. The cons are that it's a shallow way to transfer knowledge, sensitive to the distribution of input data, and the distilled model could be less adaptable.
Hidden states distillation offers richer knowledge transfer, generalizes better, and is usually a better option for downstream tasks and further fine-tuning. It's more complex to implement, more compute-intensive, and has a risk of overfitting if the dataset is too small.
To close, let me show you two models we built with distillation. The first is SuperNova, a 70 billion parameter model. The end model is a Llama 3.1 70B model, and an important step was distilling the Llama 3.1 405B model into it. That distillation took five days on 32 H100s, but it was much more cost-effective than training a 70 billion parameter model from scratch. We merged the distilled model with two additional 70 billion parameter models: one trained on synthetic data using Spectrum and another aligned with DPO on our own datasets. The result was a 70 billion parameter model that outperformed the 405 billion parameter model on some benchmarks and, at the time, outperformed some of the biggest closed models, like Claude 3.5 Sonnet and GPT-4o.
The other model is SuperNova Medius, a 14 billion parameter model available on Hugging Face. It's based on the Qwen 2.5 14B architecture and is a merge of three models: one distilled from the large Llama 3.1 405B model, another distilled from a 72 billion parameter Qwen model, and another Qwen 2.5 14B model trained internally. We used a MergeKit tool called tokensurgeon to replace the vocabulary in the Qwen 2.5 model with the Llama vocabulary, enabling cross-architecture distillation. When we launched it, it was the best 14 billion parameter model available, with performance close to 70 billion parameter models. We received great feedback from the community, which is always a good sign.
That's what I wanted to tell you. This was a bit hardcore, but hopefully, it helps you understand the technology and magic behind model distillation. That's it for today. Go and relax, and I'll see you soon with more content. Until next time, keep rocking.