Open-source LLMs are great for conversational applications, but they can be difficult to scale in production, and their out-of-the-box latency and throughput may be incompatible with your cost-performance objectives.
In this video, we zoom in on optimizing LLM inference, and study key mechanisms that help reduce latency and increase throughput: the KV cache, continuous batching, and speculative decoding, including the state-of-the-art Medusa approach.
Slides: https://fr.slideshare.net/slideshow/julien-simon-deep-dive-optimizing-llm-inference-69d3/270921961
⭐️⭐️⭐️ Don't forget to subscribe to be notified of future videos. Follow me on Medium at https://julsimon.medium.com or Substack at https://julsimon.substack.com. ⭐️⭐️⭐️
00:00 Introduction
01:15 Decoder-only inference
06:05 The KV cache
11:15 Continuous batching
16:17 Speculative decoding
25:28 Speculative decoding: small off-the-shelf model
26:40 Speculative decoding: n-grams
30:25 Speculative decoding: Medusa
Transcript
Hi everybody, this is Julien from Hugging Face. In the last few months, we've all experimented with great models like Llama 2, Mistral, and many others, and we've seen how amazing they were for chatbot applications. However, when the time comes to deploy those models to production, it's not very easy to get the low latency and high throughput that we expect. In previous videos, we've discussed different techniques for optimizing those models, like better attention layers, quantization, compilation, hardware acceleration, etc. In this video, we're going to dive into the actual inference process for those decoder-only models. We're going to look at how techniques like the KV cache, continuous batching, and speculative decoding can help increase performance. Alright? Sounds good? Let's get started. If you enjoyed this video, please give it a thumbs up, consider subscribing to my YouTube channel, and don't forget to enable notifications so that you won't miss anything in the future. Also, why not share the video on your social networks or with your colleagues, because if you enjoyed this, it's quite likely someone else will. Thank you very much for your support.
Before we optimize inference, we need to understand how it works. So let's start with how decoder-only inference works. Decoder-only is the architecture of those GPT-like models like Llama, Vicuna, Mistral, etc. Because they are decoder-only, they are a little bit different from the traditional transformer architecture. If we look at the reference architecture from "Attention Is All You Need," we won't need an encoder because here the input is basically just a prompt. We're not doing sequence-to-sequence, which we would do for translation or Q&A. So we don't need to train a model on an input sequence, encode it, and match it with the ground truth output sequence. There is no such thing. We don't need the encoder, and we don't need the encoder-decoder multi-head attention either, because there is no encoder output to attend to. Our inputs are really the tokenized prompt, and then we embed it, encode it positionally, and run it through the decoder to generate tokens. That's the basic architecture for GPT-like decoder-only models. A bit simpler, you could say.
So, inputs are processed in the following way. When we say inputs, we mean the tokenized prompt. The tokens are embedded, positionally encoded, and then we run multi-head attention to compute the keys and values for each of the input tokens. These K and V values will be used to generate the next token. This is highly parallel, essentially one large matrix multiplication, which is exactly what AI accelerators were built for. We can do this very efficiently, and we see pretty high utilization of the hardware accelerator. While there's always room for optimization, this part generally works well out of the box and isn't the number one problem we want to solve. The real problem is that once we've done that and want to generate the output, we do it one token at a time. Based on the input prompt and the K and V values we computed, we generate the next token, append it to the previous input, and repeat the process. This is a highly sequential process, and we repeat it until we've generated the maximum number of allowed tokens or until we generate an end-of-sequence token. If we're going to generate 500 tokens, we're going to do this 500 times: embed, encode, compute K and V, generate, append, repeat. The problem is that this is sequential, meaning we can't parallelize much, leading to low utilization of the hardware accelerator and low throughput overall. Unless we start optimizing, we won't get great performance and cost-efficiency out of that LLM.
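To make that loop concrete, here's a minimal sketch of naive token-by-token greedy decoding with the Transformers library. The model name is just an example, and there's no caching yet, so K and V are recomputed for the full sequence at every step, which is exactly the waste we'll fix next.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("I love", return_tensors="pt").to(model.device)
ids = inputs["input_ids"]

with torch.no_grad():
    for _ in range(20):                                   # generate up to 20 tokens
        logits = model(input_ids=ids).logits               # forward pass over the whole sequence
        next_id = logits[:, -1, :].argmax(dim=-1)          # greedy pick of the next token
        ids = torch.cat([ids, next_id[:, None]], dim=-1)   # append and repeat
        if next_id.item() == tokenizer.eos_token_id:        # stop at end-of-sequence
            break

print(tokenizer.decode(ids[0]))
```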
The first obvious thing we can do is avoid recomputing the K and V values again and again for the same input tokens. Remember, we do it for the prompt. We compute K and V, generate a token, append it, and now we have a sequence of length plus one. We compute K and V again, but all the original values will be the same because they are the same dot products. We only need to compute K and V for the token we just appended, the one we generated. Everything else is unchanged, so recomputing it is a massive waste of time. Even though it's highly parallel, it's still a waste of time. We can speed up this part using the KV cache. The KV cache stores the keys and values for all the tokens we've already processed. For example, if the first token is "I," the second token is "love," and the third token is "Trainium," we only need to compute the keys and values for "Trainium." The rest we already have: we've already seen "I" and "love," and we've already computed their keys and values. We don't need to compute them again. If you want to see the actual matrix operation, this is what it looks like. When we add a new token, we extend the sequence length. We have a new K and a new V, and that is the only thing we need to compute. The gray areas, Kprev and Vprev, are cached. The purpose of the KV cache is to save us from running those operations again and again. The longer the sequence, the more impact this has. The first token takes longer to generate because the cache is empty and we have to compute K and V for the entire prompt. That's why benchmarks report the time to first token, which measures how efficiently you handle K and V for the prompt, and the average time per output token, which measures the generation process itself.
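Here's the same loop sketched with the KV cache: after the prefill pass over the prompt, we only feed the newly generated token and pass the cache back in. This is roughly what `model.generate()` does for you automatically; the model name is again just an example.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("I love", return_tensors="pt").to(model.device)
ids = inputs["input_ids"]

with torch.no_grad():
    out = model(input_ids=ids, use_cache=True)       # prefill: compute K and V for the whole prompt
    past = out.past_key_values
    next_id = out.logits[:, -1, :].argmax(dim=-1)
    generated = [next_id]
    for _ in range(19):
        # only the new token goes in; K and V for everything else come from the cache
        out = model(input_ids=next_id[:, None], past_key_values=past, use_cache=True)
        past = out.past_key_values                    # the cache grows by one K/V entry per layer
        next_id = out.logits[:, -1, :].argmax(dim=-1)
        generated.append(next_id)
        if next_id.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(torch.cat([ids[0], torch.cat(generated)])))
```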
We need to understand how big the KV cache is. It lives in accelerator RAM, and there's never enough of that, so we need to be careful. For FP16 models, the cache size grows linearly with sequence length, the number of attention layers, the embedding size, and batch size. We need to multiply by 2 because we have K and V, and again by 2 because FP16 means 2 bytes per value. For a 7-billion-parameter model, this can be over 2 gigabytes. The bigger the model, the more layers, the larger the embeddings, the more memory you'll need. This comes at the expense of batch size: if the KV cache is huge, you don't have much memory left to increase the batch size and parallelize things further. More recent attention layers, which I covered in a previous video, work on shrinking the KV cache in different ways. You can learn about multi-query attention, grouped-query attention, etc., which are used in newer models like Llama 2 and Mistral, and see why they make a big difference in shrinking the cache so that we can increase the batch size again. It's a never-ending battle.
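As a quick back-of-the-envelope check, here's that formula in code, using approximate Llama-2-7B dimensions (32 layers, 4096 embedding size, 4096-token sequence) as an example. It assumes standard multi-head attention in FP16; grouped-query or multi-query attention would shrink the per-layer K and V accordingly.

```python
def kv_cache_bytes(num_layers, hidden_size, seq_len, batch_size, bytes_per_value=2):
    # 2 tensors (K and V) x 2 bytes per FP16 value x layers x embedding size x tokens x batch
    return 2 * bytes_per_value * num_layers * hidden_size * seq_len * batch_size

# approximate Llama-2-7B dimensions
size = kv_cache_bytes(num_layers=32, hidden_size=4096, seq_len=4096, batch_size=1)
print(f"{size / 1024**3:.1f} GiB")  # -> 2.0 GiB for a single 4096-token sequence
```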
Let's move on and see how we can improve the parallelism of inference with continuous batching. We're all familiar with batching: processing mini-batches on accelerators to increase throughput and leverage the thousands of cores we have. We could do the same with decoder-only models, but they are harder to batch than traditional transformer tasks like translation or Q&A. The reason is that the input and output lengths of your conversations can vary a lot. One request might have a very short prompt, another a very long prompt, a yes-or-no answer, or a full-page answer. In terms of prefill and generation times, you will see very different things. This variability is the problem. Let's look at an example. Here are different requests, and we see the time steps T1, T2, T3, etc. The yellow squares are input tokens, the blue squares are generated tokens, and the red squares are end-of-sequence tokens. Initially, we load four inference requests on the accelerator. It starts generating at T3, T4, T5, all good. A little while later, the third request, which was pretty short, ends at T5. The first one ends at T6, and the second one, which is much longer, ends at T8. Because it drags on, we have to wait until it completes to run another batch. Immediately, you see the problem: there's only one inference request being processed during the last few time steps, and everything else has completed, wasting hardware resources. Traditional static batching, processing, say, batches of 4 or 8 and waiting for all of them to complete, is not a good idea here. Instead, we want to do continuous batching. It starts the same, but as soon as an inference request has completed, we evict it and start another one. For example, the first request ends at T6, and immediately at T7, we have something else running. We're not waiting for the second request to complete at T8. The same goes for the third and fourth requests. As soon as a request is over, we feed another one to the GPU. That's why it's called continuous batching. We never stop; we try to keep the hardware as busy as possible.
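Here's a toy sketch of the scheduling idea, not TGI's actual implementation: a fixed number of slots on the accelerator, one decoding step per iteration, and finished requests are replaced immediately instead of waiting for the whole batch to drain.

```python
from collections import deque

class Request:
    def __init__(self, name, tokens_to_generate):
        self.name = name
        self.remaining = tokens_to_generate

def continuous_batching(requests, num_slots=4):
    waiting, running, step = deque(requests), [], 0
    while waiting or running:
        # refill free slots from the waiting queue (this is where prefill would run)
        while waiting and len(running) < num_slots:
            running.append(waiting.popleft())
        # one decoding step: every running request generates one token
        step += 1
        for req in running:
            req.remaining -= 1
        # evict finished requests right away so new ones can take their place
        for req in [r for r in running if r.remaining == 0]:
            print(f"step {step}: {req.name} finished, slot freed")
        running = [r for r in running if r.remaining > 0]

requests = [Request(f"req{i}", n) for i, n in enumerate([3, 6, 2, 5, 4, 4])]
continuous_batching(requests)
```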
Obviously, we're only looking at the generation process, but what about pre-fill? Remember, we need to embed, encode, and compute or retrieve KV. Can we pause the generation process to run that input processing, the pre-fill process, for the new queries we're going to run in the future? Yes, we need to do that. Otherwise, we'll run out of queries to process and stall the whole thing. Depending on the inference server you use, you may have a parameter to control the ratio of waiting input queries versus running generation queries. In our TGI library, where this is implemented, there's a parameter called waiting_served_ratio. You can control how often you want to pause the generation process to run pre-fill for the next ones. Continuous batching is a very good technique, making a very large difference.
Now, let's talk about the third technique I wanted to discuss today, which is more and more popular: speculative decoding. If you're completely new to this, there's a really cool blog post by my colleagues. I highly encourage you to read it. We said the generation process only outputs one token at a time, not making good use of the highly parallel hardware on AI accelerators. The whole process is not compute-bound; a lot of hardware is just doing nothing. It's memory-bound because of the KV cache we need to load again and again. We have compute resources available, so the basic idea is to use that idle compute to do something smart. For example, could we use a smaller model, which would be a decent approximation of the large model, to predict several completions, multi-token completions, in parallel? In a way, it would tell us, "Hey, I know you're generating a token right now, but I think the next four, five, six tokens could look something like this." Here are five, six, seven different ways to do it. Looking ahead with a smaller model that is more nimble and can leverage the idle compute we have. At some point, we need to look at those potential completions and see which one is the best, which one has the largest number of correct tokens, and pick that one. We ask the large model to evaluate those potential completions and pick the one that is closest to what the large model would have generated sequentially. At each iteration, we still get one valid token generated by the large model in the usual way. If the small model did its job well, we could have way more. One of those completions with, let's say, five, six, seven tokens could be great, and then maybe we have six additional tokens in one go, saving six runs through the large model. That's where the speedup comes from. Let's look at an example. We start from "the quick brown." We ask the small model to come up with a completion. Let's say one of them is "fox jumps into the." We take the full sequence, send it to the large model, and ask, "Is this what you would have generated?" It says, "Well, 'fox jumps' is fine. 'Into the' is not what I would have generated. The next token would have been 'over.'" We discard everything after "jumps" and append "over." We do that again. We feed that to the small model, generate a completion, and ask the large language model to validate those tokens, invalidate some of them, and add the one it generated after that. That's pretty cool and works very well.
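Here's a simplified sketch of that draft-and-verify loop, using greedy matching rather than the exact rejection-sampling scheme from the paper; the Pythia model names are just an example of a small/large pair sharing the same tokenizer.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

draft_id, target_id = "EleutherAI/pythia-160m", "EleutherAI/pythia-1.4b"  # example pair
tokenizer = AutoTokenizer.from_pretrained(target_id)
draft = AutoModelForCausalLM.from_pretrained(draft_id)
target = AutoModelForCausalLM.from_pretrained(target_id)

ids = tokenizer("The quick brown", return_tensors="pt").input_ids
gamma = 5  # number of tokens the draft model speculates per round

with torch.no_grad():
    for _ in range(10):  # a few speculation rounds
        # 1. the small model proposes up to gamma tokens
        draft_ids = draft.generate(ids, max_new_tokens=gamma, do_sample=False)
        proposed = draft_ids[0, ids.shape[1]:]
        k = proposed.shape[0]

        # 2. one forward pass of the large model over prompt + proposal
        logits = target(draft_ids).logits[0]
        # the large model's greedy prediction at each proposal position
        preds = logits[ids.shape[1] - 1 : draft_ids.shape[1] - 1].argmax(dim=-1)

        # 3. accept the longest matching prefix, then append one token from the large model
        n_accept = 0
        while n_accept < k and proposed[n_accept] == preds[n_accept]:
            n_accept += 1
        bonus = preds[n_accept] if n_accept < k else logits[-1].argmax()
        accepted = torch.cat([proposed[:n_accept], bonus.view(1)])
        ids = torch.cat([ids, accepted.unsqueeze(0)], dim=1)

print(tokenizer.decode(ids[0]))
```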
In the bottom left corner, you see the relationship between how well the small model approximates the large model, a parameter called alpha. If alpha is close to one, the small model is a really good approximation of the large model. If it's lower, the small model makes a lot of mistakes. On the y-axis, you see the expected number of tokens we can validate per iteration. If alpha is 0.9 or higher, the expected number of tokens will be very close to the number of tokens the small model generates. The different colors correspond to different values of gamma, the number of tokens the small model speculates. If you move left, the small model is not a great approximation, and the expected number of tokens is quite low compared to what the small model generates, because some of them will be discarded at the validation step. We can see what that means in terms of speedup. For example, with alpha at 0.8 and gamma at 5, so a decent approximation generating five additional tokens per completion, the speedup can be up to 3x. If alpha goes up to 0.9, we can do even better, especially if we increase the number of speculated tokens. But that means you have a really great small model, which might not always be the case. In the paper, they test this with T5-XXL, which is 11 billion parameters, on two tasks: EnDe (English-to-German translation) and CNN/DM (summarization). With seven predicted tokens, T5-small gives a 3.4x speedup. T5-small is a much smaller model, and the paper recommends going with models that are at least one order of magnitude smaller, maybe even two orders of magnitude. It's amazing to see those small models doing really well and delivering real-life speedups of 2x, 2.5x, and sometimes over 3x.
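If you want to play with those numbers, the paper gives a closed-form expression for the expected number of tokens produced per large-model pass, (1 - alpha^(gamma+1)) / (1 - alpha); here it is as a quick calculation. Keep in mind that the end-to-end speedup is lower than this number, because each iteration also pays for gamma runs of the draft model.

```python
def expected_tokens(alpha, gamma):
    # expected tokens accepted per iteration: (1 - alpha^(gamma + 1)) / (1 - alpha)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.6, 0.8, 0.9):
    print(f"alpha={alpha}, gamma=5 -> {expected_tokens(alpha, 5):.2f} tokens per large-model pass")
```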
Now, the big question is: how do we build that small model? With the T5 example, it's simple because T5 comes in many different sizes, but that might not always be the case. Let's look at the different options. The first one is using different sizes of the same model: try the smaller ones and figure it out. Another way is to use an n-gram approach, a clever trick where we reuse tokens found in the prompt. We'll look at that. We can also fine-tune a small model based on the large model, which stays frozen, or fine-tune the small model and the large model together for maximum performance. Let's quickly look at those four options. The first one, the small off-the-shelf model, is the T5 scenario. We can see the Pythia models here, 1.4 billion and 160 million parameters, so about a 10x ratio. We load both, and when we generate with the large model, we just pass the small model as a parameter. The whole process of generating potential completions and validating them happens under the hood. If you have models in different sizes, this is a cool thing to try because you could get a 2x or 3x speedup out of the box without training anything. The restriction is that the two models need to have the same tokenizer, so in practice you would look at two models from the same family in different sizes. It's a simple, elegant solution.
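In Transformers, this is the assisted generation feature: you pass the small model as `assistant_model` to `generate()`, and the draft-and-verify loop runs under the hood. A minimal example, with the Pythia pair used purely for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")  # same tokenizer for both models
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b")
assistant = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# the small model drafts tokens, the large model validates them under the hood
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```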
Option two is n-grams. This comes from the observation that for some tasks like summarization, Q&A, short discussions, and code editing, there's a strong relationship between the tokens in the input and the tokens in the output. Words found in the prompt are very likely to be found in the generated answer. It's more than words; it's n-grams, short sequences of related words. For input-grounded tasks, where the output is strongly related to the input, we can use the strings, words, or word pairs in the input to accelerate the completion of the output. This n-gram speculative decoding works very well for those tasks, with speedups of 2x to 4x, no model modification, and no impact on output quality. It's implemented in Transformers and works with any model. It wouldn't work well if you had a very short question and asked for 5,000 tokens because you have so few input tokens to pick from to generate a long answer. But for shorter tasks, it works well. Here's an example. We have a prompt about a French football player, extract n-grams, ask the large language model to validate, and it will validate some, discard some, and add the token it generated. The exact same scenario as before, except here there is no model predicting; it's using n-grams to find and concatenate them. This is how it works in the Transformers library. You just need to add the parameter `prompt_lookup_num_tokens`, and that's it. It's a great solution worth trying and benchmarking to see if you can accelerate your transformers and generation process.
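Here's what that looks like in practice; the checkpoint and prompt are just examples, and `prompt_lookup_num_tokens` is the only thing you need to add:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-1.4b"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Summarize this article: Antoine Griezmann is a French footballer who plays as a forward..."
inputs = tokenizer(prompt, return_tensors="pt")
# draft tokens are looked up as n-grams in the prompt instead of coming from a second model
outputs = model.generate(**inputs, prompt_lookup_num_tokens=10, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```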
The last speculative decoding technique I wanted to cover today is called Medusa. In our four options, we saw that we could fine-tune a model either on top of or together with the large language model. Medusa is about adding decoding heads to the LLM to predict multiple tokens in parallel. We start from, let's say, your Llama or Vicuna model and add additional decoding heads so that instead of generating one output, it can generate multiple candidate outputs in parallel, and we verify them at each decoding step, selecting the one that works best. On the left, you see the original model, the decoder-only architecture. We plug in the Medusa heads, each generating a new potential completion, and at the decoding stage, we pick the completion that works best, keeping the validated tokens, discarding the invalid ones, and adding the one the original model generated. The whole process is based on a technique called tree-based attention, which is nicely explained in the paper. We won't go into the details, but there's a specific algorithm for verification and selection at the decoding stage.
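To give a rough idea of what a Medusa head looks like, here's a minimal sketch: each extra head is a small residual block on top of the base model's last hidden state, followed by its own vocabulary projection, with head k guessing the token k+1 steps ahead. This is an illustration of the idea, not the reference implementation, and the dimensions are arbitrary examples.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head: a residual feed-forward block plus its own vocabulary projection."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states):
        return self.lm_head(hidden_states + self.act(self.proj(hidden_states)))

hidden_size, vocab_size, num_heads = 4096, 32000, 4   # arbitrary example dimensions
heads = nn.ModuleList(MedusaHead(hidden_size, vocab_size) for _ in range(num_heads))

# stand-in for the last hidden state produced by the frozen base model
last_hidden = torch.randn(1, hidden_size)
candidates = [head(last_hidden).argmax(dim=-1) for head in heads]  # head k guesses token t+k+1
```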
The next question is: how do we train those Medusa heads? There are two different ways. The first is MEDUSA-1, where we leave the original model alone and only fine-tune the Medusa heads. This is efficient because they use parameter-efficient fine-tuning with QLoRA, so LoRA plus quantization. The paper mentions Vicuna 7B and 60K samples, which took only five hours on a single A100 GPU. The second technique, MEDUSA-2, is joint fine-tuning, where we fine-tune the heads and the model together. If you were going to fine-tune your LLM anyway, this is the preferred option because the heads will be tightly coupled to the model, and you can expect better performance. If you already have your LLM, don't want to fine-tune it, and just want to accelerate it, then MEDUSA-1 is a good way to do it. Here's a cool animation from the repo; you can see this is clearly 2x faster. The numbers show MEDUSA-1 is over 2x faster, and MEDUSA-2 is almost 3x faster. This is because if we fine-tune the heads and the model together, they perform better, generate better potential completions, and fewer tokens are discarded, saving more iterations through the larger model and speeding things up further.
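For intuition, here's a heavily simplified sketch of the MEDUSA-1 training objective: the base model stays frozen, and each head gets a cross-entropy loss against the token it is supposed to predict (head k targets the token k+1 positions ahead). The paper also weights each head's loss, which is omitted here; treat this as a sketch, not the official training code.

```python
import torch
import torch.nn.functional as F

def medusa1_loss(heads, last_hidden_states, input_ids):
    # last_hidden_states: (batch, seq_len, hidden) from the frozen base model
    # input_ids:          (batch, seq_len) ground-truth token ids
    loss = 0.0
    for k, head in enumerate(heads, start=1):
        logits = head(last_hidden_states)              # (batch, seq_len, vocab)
        shift = k + 1                                  # head k at position t targets token t + k + 1
        pred = logits[:, :-shift, :].reshape(-1, logits.size(-1))
        target = input_ids[:, shift:].reshape(-1)
        loss = loss + F.cross_entropy(pred, target)    # only the heads receive gradients
    return loss
```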
As you can see, decoder-only inference is a very interesting topic, and there are many ways to accelerate it: the KV cache, which stops us from computing those K and V values again and again; continuous batching, which keeps pumping requests to the accelerator as soon as it has capacity; and speculative decoding, which saves us from generating every single token through that large, slow LLM by using smaller, hopefully accurate approximations. This is a very active field, and I'm sure we'll see more techniques. Feel free to try them in TGI, as all of this is implemented there, making it a good sandbox for you to experiment. If you enjoyed this video, please give it a thumbs up; it helps me with the YouTube algorithm. I appreciate your support, and until next time, keep rocking.