Decoder-only inference: a step-by-step deep dive

January 10, 2025
In this deep dive video, we explore the step-by-step process of transformer inference for text generation, with a focus on decoder-only architectures like those used in GPT models. The step-by-step breakdown covers self-attention, KV caching, and multi-head attention (MHA), culminating in an in-depth look at the advanced Multi-Head Latent Attention (MLA). You'll learn how MLA improves efficiency by reducing memory usage and accelerating inference without compromising accuracy. Finally, you'll learn how attention outputs are used for token generation.
We delve into the mechanics behind these models, starting with an analysis of the self-attention mechanism, which serves as their foundational building block. The video begins by explaining how self-attention is computed, including the role of queries, keys, and values in capturing contextual relationships within a sequence of tokens. We then examine the significance of the KV cache in optimizing performance by avoiding redundant computations during token generation.
The discussion progresses to multi-head attention (MHA), a key innovation in transformers that enables the model to capture diverse patterns in data through parallel attention heads. We address the memory bottlenecks associated with MHA and the techniques employed to mitigate them. We also introduce multi-head latent attention (MLA), a cutting-edge alternative to traditional MHA. MLA significantly reduces memory usage by caching a low-rank representation of key and value matrices, enabling faster and more efficient inference. This breakthrough is explained in detail, alongside comparisons to MHA in terms of performance and accuracy.
Finally, the video walks through the process of translating attention outputs into coherent text generation. This includes the role of projection layers, softmax normalization, and decoding strategies like greedy search and top-k/top-p sampling. This comprehensive exploration provides a detailed understanding of the inference process, emphasizing practical challenges and the state-of-the-art solutions that address them. Whether you're a researcher, engineer, or AI enthusiast, this video offers valuable insights into the mechanics of generative language models.
* Slides: https://fr.slideshare.net/slideshow/deep_dive_multihead_latent_attention-pdf/274778590
* "Deep dive: better Attention Layers": https://youtu.be/2TT384U4vQg
00:00 Introduction
01:20 The architecture of decoder-only transformers
05:10 The self-attention formula
05:51 Computing self-attention step-by-step
14:50 The role of the KV cache
18:25 Multi-head attention (MHA)
20:40 Computing multi-head attention step-by-step
23:20 The memory bottleneck in multi-head attention
25:15 Multi-head latent attention (MLA)
28:30 Computing multi-head latent attention step-by-step
35:00 From attention outputs to text generation
41:00 Conclusion

Transcript

Hi everybody, this is Julien from Arcee. Attention layers are at the core of transformer models, but unfortunately, the attention mechanism is pretty computationally heavy, which slows down transformer operations and particularly inference. Over time, a number of techniques have been invented to try to solve that problem, and the latest one is called multi-head latent attention, and that's the one we're going to study today. First, we're going to look at vanilla attention. We're going to run a step-by-step example showing you every operation with the tensor dimensions. Then we'll move on to multi-head attention, again going step-by-step and understanding the differences. We'll talk about the KV cache and why it's so important. And then we'll talk about multi-head latent attention. We'll put all the pieces together, and you'll see why this is such an important development in transformer architectures. As a bonus, we'll go all the way to text generation, showing you how we turn attention outputs into generated tokens. I promise there's not going to be any heavy math. The only thing you need to know is matrix multiplication. So I think you'll be fine. Okay, let's get started. Before we dive into attention layers, let's take a quick look at the architecture of text generation models and where attention layers live. As you probably know, these models are decoder-only models. So compared to reference transformers visible in the original paper, they don't have an encoder and obviously no connection between the encoder and the decoder. So just the right-hand part here, just the decoding part. Decoding works in two steps. We have the input processing step, or the prompt processing step, and we have the text generation step. So let's look at the input processing first. We start from the input sequence, which has been tokenized. We retrieve the embeddings for those tokens. We encode them with their positions. And then we run the attention calculation. 
Along the way, we compute those so important K and V, keys and values, which are central to the attention scores. This is really just a series of large matrix multiplications. And it's a good thing because we have thousands of cores on AI accelerators, so we can easily parallelize and scale that. In fact, compute is not the problem here. So when we're trying to optimize the attention, at this point, we're not really trying to save compute, although obviously, the less we use, the better. But generally, we have enough. The problem we have is memory bandwidth between the cores, right, between, let's say, the GPU cores and the external memory where keys and values and generally everything else is stored because K and V, as we'll see, are very large. Although GPUs have a ton of bandwidth, it's never enough. And this process is really memory bound. So it's really limited by how fast you can read and write those values into the GPU cores. That's one thing we're going to try and optimize later. OK, keep that in mind. Once we've processed the prompts, we can move on to the second step of the process, which is generating new tokens. Unlike input processing, output generation is a sequential process. So we generate one token at a time, meaning if you're generating up to, let's say, 512 tokens, you may run that text generation loop up to 512 times. Newer models are able to generate multiple tokens at a time, but I would say generally most models still do one token at a time. So that's the one I'm going to stick to today. So we'll see how that process works. When we have a new token, it becomes part of the input. In fact, the token we just generated is going to be used as the input to generate the next one. So what's the most likely token that could follow the one that has been generated? We do this again and again until we get an end-of-sequence token or until we hit the maximum number of tokens we said we would generate, maybe five. So that's what we're going to look at. 
We're going to try and crack those boxes open and understand what's going on in there. So here's that famous attention equation that you can see in the paper. Multiplying Q by K transpose, scaling them, applying softmax, multiplying by values, right? You've seen this one a million times, but do you really understand what's going on in there? It took me a while to figure that out. And I'm guessing if you're watching this video, you want to understand what is really happening here and turn that equation into actual steps that the model performs. And that's exactly what we're going to do now. So let's break down that attention calculation step by step, looking at input tensors and their dimensions, giving you a numerical example, and explaining why we're doing every one of those steps. First, we need an input sequence. So my input sequence is "the quick brown fox jumps over the lazy dog." That's nine words, and I want to generate the next tokens. That's the problem I want to solve today. The first step is to turn those words into tokens. Because models don't care for words; they need numbers, and that's what the tokenization process is. It will turn each one of those words into one or potentially several tokens and assign to each token a numerical ID, which is the unique ID of that token in the model vocabulary. These are the actual token IDs for the BERT base models. You may wonder why we end up with 11 tokens because we started from nine words. In this example, the first token is the start of sentence token, and the last token is an end of sentence token. Again, some complex words or verbs can be broken down into several tokens, but here that's not happening; we're just adding a start and end token. So that's our input dimension: 11. Next, we need to turn each token into its embedding. Embeddings are those high-dimensional vectors that encapsulate concepts, relationships, and generally knowledge, and of course, they are learned at model training time. 
For example, let's say our model works with 512-dimensional embeddings. This means that each token in the input sequence corresponds to a 512-dimensional vector. That's the purpose of the embedding layers: to map a token ID to one of those vectors. This size is also called the hidden dimension, and it can change from one model to the next, 768 or 1536, etc. It's one of the design parameters of the model. I skipped the positional encoding mechanism because it doesn't really contribute to this discussion. So now we're almost ready to start computing attention scores. For this, we're going to need three weight matrices in the model. Again, these have been learned during the training process. They all have the same dimensions, which is hidden by hidden, so 512 by 512. In my example, we have WQ, WK, WV, and we'll see where those fit. Just keep in mind they have the same sizes. And later on, we also use WO, which again has the same dimension. These are model weights that we learned. What are they good for? They're used to compute Q, K, and V, so queries, keys, and values, which are the ones we saw in the previous slide. With those matrices, we can start computing our attention scores. Q, queries, K, keys, V, values. The math is not a problem here. We're just multiplying our inputs by WQ, WK, WV, right? And we get three matrices with the same dimensions, 11 by 512. What's a little more complicated is understanding the purpose of Q, K, and V. And why are those things called queries, keys, and values? So attention is trying to find relationships between words or tokens in the input sequence. One way to understand this is to realize that Q is a query. Each token needs to know something from other tokens. The key, K, is the opposite; K is what each token brings in terms of knowledge to all the tokens. So it's like, "Okay, these are keywords. These are things. These are concepts. These are relationships that I bring to the table." And V is what the token actually shares.
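Here is a minimal NumPy sketch of the shapes discussed so far: 11 tokens, a 512-dimensional embedding, and three hidden-by-hidden weight matrices producing Q, K, and V. The token IDs, the vocabulary size of 1,000, and the random weights are placeholders chosen for illustration; a real model would use its tokenizer output and learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size = 1_000   # hypothetical, kept small so the sketch runs quickly
n_tokens = 11        # nine words plus the start and end tokens
hidden = 512         # embedding (hidden) dimension from the example

# Placeholder token IDs; a real tokenizer would produce these.
token_ids = rng.integers(0, vocab_size, size=n_tokens)

# Embedding table: one 512-dimensional row per token in the vocabulary.
embedding_table = rng.normal(size=(vocab_size, hidden))
X = embedding_table[token_ids]                  # (11, 512)

# Three hidden-by-hidden weight matrices (random here, learned in a real model).
W_q, W_k, W_v = (rng.normal(size=(hidden, hidden)) for _ in range(3))

Q = X @ W_q   # queries, (11, 512)
K = X @ W_k   # keys,    (11, 512)
V = X @ W_v   # values,  (11, 512)

print(X.shape, Q.shape, K.shape, V.shape)
```

Each row of Q, K, and V belongs to one token of the input sequence, which is why all three keep the 11 by 512 shape of the input.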
The actual knowledge, the value that the token shares when it's attended to. I know that sounds a little theoretical, but let's look at the calculation, and I'm sure you'll understand that. The first thing we do is multiply Q by K transpose. The result is a matrix N by N, so 11 by 11. What it tells us is how well each token in the input sequence matches what each other token in the input sequence is looking for. Each token wants to find knowledge in other tokens, and each token is bringing some knowledge to the sequence. Multiplying Q by K transpose gives us the result of that. We divide by a scaling factor (the square root of the key dimension in the original paper) for numerical stability, and we take softmax, skipping those steps, which are just numerical steps. Then we multiply by V. What we do here is, based on those similarity scores, based on how well the knowledge of each token matches what each token is looking for, we add the appropriate amount of value. So if two tokens have a very high similarity score, we'll inject a lot of value. If the similarity score is very low, we'll inject just a tiny bit of V, and that won't matter much. The result of this calculation is the attention output. Feel free to pause the video and read that stuff again. As you can see, the math here is not the problem. What is a bit tricky is understanding what Q, K, and V are all about. So once again, for each token, we're really trying to figure out what tokens in the input sequence give me what I need, that match my search query with their keywords. If they match well, the dot product will be fairly high, and so I want to bring a lot of the value for that token into my attention output. At the end, what we have is this N by hidden matrix, so 11 by 512. The last step is to multiply those attention outputs by WO, which is a feed-forward (linear) weight matrix. The purpose here is just to give the model a chance to capture additional interactions across the sequence outside of attention calculations. That's what this formula is all about.
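The steps above (Q times K transpose, scaling, softmax, multiplying by V, then the final weight matrix) can be sketched in a few lines of NumPy. This is an illustrative sketch with random weights, not any particular model's implementation. It also leaves out the causal mask that decoder-only models apply so that a token cannot attend to later positions, since the walk-through does not cover it. The names `probs` and `attn_out` are chosen here for clarity.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum first so the exponentials cannot overflow.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v, W_o):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v          # each (n, hidden)
    scores = Q @ K.T / np.sqrt(K.shape[-1])      # (n, n) similarity scores
    probs = softmax(scores, axis=-1)             # each row sums to 1
    attn_out = probs @ V                         # (n, hidden) weighted mix of values
    return attn_out @ W_o, probs                 # final linear projection

rng = np.random.default_rng(0)
n, hidden = 11, 512
X = rng.normal(size=(n, hidden))
W_q, W_k, W_v, W_o = (rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)
                      for _ in range(4))

out, probs = self_attention(X, W_q, W_k, W_v, W_o)
print(out.shape, probs.shape)
```

Row i of `probs` says how much of each token's value goes into the output for token i, which is the "inject a lot or a little of V" idea in numeric form.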
Let's move on to the KV cache. We just saw that each round of text generation adds a new token to the input sequence. Meaning that if we remember the previous slide, we would have to run the same multiplications again on K and V, just with a longer input sequence, plus one, then plus one, and then plus one. The problem is, if we do that in a naive way, we keep recomputing the same values again and again. For example, if you have 10 tokens in the input sequence, you've computed K and V for those 10 tokens. Now you have an 11th token. What you really want is how this 11th token relates to the 10 previous ones. But the relationship between the first and the second, the second and the third, etc., they haven't changed. So why would we compute those scores again and again? The matrix keeps growing, but most of the cells in the matrix are going to be the same. It's just a waste of compute. We can see that in this example: if we have two tokens and a third one, we've already computed the relationship between the first and the second, so we don't need to do that again. We need to compute the relationship between the first and the third and the second and the third. So we're just expanding the KV values but don't want to recompute everything again, and that's the purpose of the KV cache. A lot of the K matrix is already there. A lot of the V matrix is already there. We just need to compute the dot products for the new token with respect to all the previous ones. That's a huge optimization here. We need to cache that stuff, meaning we need to put it somewhere in memory where we can retrieve it instead of computing it every single time. Obviously, this doesn't work for the first token that we generate, which is why it takes longer to generate and that's why you see that time to first token measure in model benchmarks because it's a measure of how fast you can pre-fill the KV cache. 
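To make the caching idea concrete, here is a small single-layer, single-head sketch. It computes the attention output for a newly arrived 11th token in two ways: naively, by recomputing K and V for the whole sequence, and with a cache that only computes the new token's key and value. The `KVCache` class and the function names are hypothetical helpers written for this example, and the weights are random.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d = 64
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

def last_token_naive(X):
    """Recompute K and V for every token, then attend from the last token."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q[-1] @ K.T / np.sqrt(d)
    return softmax(scores) @ V

class KVCache:
    """Stores the keys and values already computed for earlier tokens."""
    def __init__(self):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def append(self, x_new):
        # Only the new rows are projected; older rows are reused as-is.
        self.K = np.vstack([self.K, x_new @ W_k])
        self.V = np.vstack([self.V, x_new @ W_v])

def last_token_cached(x_new, cache):
    cache.append(x_new)                          # x_new has shape (1, d)
    q = x_new @ W_q
    scores = q @ cache.K.T / np.sqrt(d)          # (1, tokens so far)
    return (softmax(scores) @ cache.V)[0]

X = rng.normal(size=(11, d))
cache = KVCache()
cache.append(X[:10])                             # prefill with the 10-token prompt
cached = last_token_cached(X[10:11], cache)      # the 11th token arrives
naive = last_token_naive(X)
print(np.allclose(cached, naive), cache.K.shape)
```

Both paths give the same output for the new token, but the cached path projects one row instead of eleven, which is the saving the cache buys at every generation step.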
The KV cache is super important but is obviously pretty large because it grows linearly with the number of layers, it grows linearly with the size of the embeddings, and it grows linearly with the input size. Very quickly, it's going to be gigabytes and gigabytes of accelerator RAM that you're going to be using here. As models get bigger and as sequence lengths get bigger, context sizes become a problem, and this is precisely what multi-head latent attention is trying to fix.
Before we look at multi-head latent attention, we need to look at multi-head attention. It works pretty much the same as the attention mechanism we just saw, except we split the attention computation across several parallel units, which we call heads. Instead of doing that calculation in one big step, like we saw before, we're going to parallelize it across several heads. All heads get the full input sequence, but they only work with a subset of the embedding dimensions. That subset is driven by how many heads we have. So we take the size of the embeddings, divide by the number of heads, and that tells us how many dimensions each head is going to see. Each head has its own set of query, key, and value weight matrices. They all work in parallel. The end result is pretty much the same, but now instead of doing one big matrix operation, we're going to do several of them on smaller matrices in parallel. The main benefit is not speeding things up. You could think, "Well, I'm parallelizing here, right? So I can leverage more cores or whatnot." But not really. It's not the point. The point is that each head can work on a subset of the embedding space. So it can, in a way, focus on a specific part of the embedding space and learn specific patterns, specific relationships in that subspace. When we train the model, we're not training the model on 512 dimensions. We're just telling each head to look at relationships in its subspace, which may be 64 dimensions.
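To put rough numbers on the linear growth of the KV cache described in this passage, here is a back-of-the-envelope estimate. The configuration (32 layers, hidden size 4096, 4096 tokens, 16-bit values) is hypothetical and not taken from the video, and the formula assumes plain multi-head attention where K and V each store one hidden-sized vector per token and per layer.

```python
def kv_cache_bytes(n_layers, hidden, seq_len, batch_size=1, bytes_per_value=2):
    # Factor 2: one K vector and one V vector per token, per layer.
    return 2 * n_layers * seq_len * hidden * batch_size * bytes_per_value

size = kv_cache_bytes(n_layers=32, hidden=4096, seq_len=4096)
print(f"{size / 1024**3:.1f} GiB for one sequence")

# Linear growth: doubling the sequence length doubles the cache.
doubled = kv_cache_bytes(n_layers=32, hidden=4096, seq_len=8192)
print(doubled / size)
```

With these assumed values a single 4096-token sequence already needs about 2 GiB, before counting the model weights or any batching.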
And apparently, this yields better models. Let's go and look at the computation again. The steps are exactly the same, except we work in parallel on smaller subspaces. It starts the same: input sequence, tokenized, embeddings. None of that stuff is different. Where it changes is that we have several attention heads. So let's say here we have eight. Instead of working on 512 dimensions, each head is going to work on 64 dimensions. So I'll have eight attention heads, each one of them looking at a slice of the embedding space, only 64 dimensions, and each one of them having their own weight matrices for Q, K, V. Computation is the same. We compute Qi, Ki, Vi, again, on the subset of the dimensions. We store Ki and Vi in the KV cache, and all that stuff runs in parallel. So in this case, we have eight rounds of matrix multiplications instead of one big round as before. Attention score computation is the same. Qi multiplied by Ki transpose. All that stuff in parallel. Attention outputs, same story. All run in parallel, and at the end, we just concatenate the results. So we take the attention outputs coming from each of the heads and just concatenate them. And that's really all there is to it. Then, just like before, we multiply the concatenated outputs by WO to capture potential interactions that happen across the sequence. So exactly the same, except we run many times in parallel smaller matrix multiplications on a smaller piece of the embedding space.
Now we've done multi-head attention. Now let's talk about that KV cache again. By now, you understand we need to read those Qs, Ks, and Vs and write them, and all that traffic goes to memory outside the GPU cores, usually called high bandwidth memory (HBM). That's where we read and write all those intermediate results. The problem is that these are large, and the number of HBM accesses we need to do is quadratic with respect to sequence length. So if sequence length doubles, we need to do 4x the number of HBM accesses.
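The multi-head computation walked through above can be sketched as follows, with eight heads of 64 dimensions each. One assumption to flag: each head here projects the full 512-dimensional input down to its own 64 dimensions with a 512 by 64 matrix, which is the common way to implement per-head weights. The heads run in a Python loop for readability, whereas real implementations batch them into one tensor operation. Weights are random placeholders.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o):
    """W_q, W_k, W_v: (n_heads, hidden, d_head); W_o: (hidden, hidden)."""
    head_outputs = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        Q_i, K_i, V_i = X @ Wq_i, X @ Wk_i, X @ Wv_i      # each (n, d_head)
        scores = Q_i @ K_i.T / np.sqrt(K_i.shape[-1])     # (n, n) per head
        head_outputs.append(softmax(scores) @ V_i)        # (n, d_head)
    concat = np.concatenate(head_outputs, axis=-1)        # (n, hidden)
    return concat @ W_o

rng = np.random.default_rng(0)
n, hidden, n_heads = 11, 512, 8
d_head = hidden // n_heads                                # 64

X = rng.normal(size=(n, hidden))
W_q, W_k, W_v = (rng.normal(size=(n_heads, hidden, d_head)) / np.sqrt(hidden)
                 for _ in range(3))
W_o = rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)

out = multi_head_attention(X, W_q, W_k, W_v, W_o)
print(out.shape)
```

Concatenating eight 11 by 64 head outputs restores the 11 by 512 shape, so the rest of the model does not need to know how many heads were used.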
As context sizes increase and as we like to do RAG and all that good stuff, we tend to work with longer and longer sequences, and we end up reading and writing a lot in HBM. Even though HBM is blazing fast, if you're doing too much of that, it takes time, and that's the limitation. That's really the bottleneck for inference speed. Inference is not compute-bound; it's memory-bound. Lots of techniques have been invented to try and reduce the amount of KV data we have to read and write. I actually did a previous video on that, looking at multi-query attention, group query attention, flash attention, all those good techniques. The latest one is called multi-head latent attention, and it's quite different. Let's take a look at the high-level idea of multi-head latent attention. And then once again, we'll run those step-by-step computations to see how things work. Multi-head latent attention was introduced a few months ago in the DeepSeek V2 model. It's also part of the DeepSeek V3 model that just came out. This one is quite different. Again, I won't cover GQA and MQA in detail again, but the idea of GQA and MQA was just to use fewer keys and values matrices and share them across queries so that we can save on the KV cache. Certainly, there was a memory improvement, but unfortunately, there was an accuracy degradation. Again, go on and look at that video for all the details on that. What multi-head latent attention does is avoid caching full-size K and V altogether. Instead, we're going to cache smaller versions of K and V, and to go from the original large K and V to the much smaller K and V, we're going to learn during training a projection matrix that helps us shrink K and V to smaller matrices and then another matrix to restore the small K and V matrices to their original sizes. If you're familiar with LoRA, for example, this is exactly the same idea. 
It's really shrinking to a lower rank, so smaller size, to save memory in this particular case and optimize memory usage. Sounds a little bit theoretical, but we're going to look at the calculation. The result is that we can save over 90% of the KV cache. So that's huge savings. It's divided by 10, maybe even more. The paper reports 96% savings, so dividing by over 20. If we're talking gigabytes and gigabytes, that's huge, huge savings. And that's freeing a ton of memory on the GPU, allowing us to maybe use larger batch sizes or just load larger models. Because we're working with smaller matrices too, there is a very nice inference speedup. There's just less stuff to multiply and add. And believe it or not, you also get higher accuracy than the vanilla multi-head attention. If you want to look at all the details, I recommend reading appendix D1 in the DeepSeek V2 paper, where they show that for small models and large models, accuracy is indeed higher than with vanilla multi-head attention. So what's not to like? Now let's look at the computation in detail. It starts the same: we have input and embeddings, and we have our multi-head query, key, and value weight matrices. I'm still assuming eight heads, so each head is going to work on 64 dimensions. No change. Now we introduce those projection matrices. So we have a down-projection matrix which will help us shrink from, let's say, 512 dimensions to something much smaller. And I have an up-projection matrix which will, in the end, help me resize my output to the original 512 dimensions. How small is small? If we want to save a lot of KV cache memory, we need to shrink that 512 dimension here, that's the one that's hurting us. Let's try to down-project from 512 to 32. That's huge savings; it's 16 times less. Very significant savings. But because we are doing multi-head attention, we will run down-projection across the heads as well. So we need to divide that latent dimension 32 by the number of heads.
Still assuming I've got eight heads, it's going to be four. And that's why my down-projection matrix per head is 64 by 4. We'll talk about the up-projection matrix in the end. So for now, just keep in mind it starts the same, and per head, we have a small matrix that helps us shrink the computation to much smaller dimensions. In fact, it's what we do here. Q isn't changed; we don't touch Q because it's not cached. We apply the down-projection matrix to K and V. So it starts the same: X multiplied by WKi, and then we apply Wi down. With that down-projection matrix, we can compute Ki and Vi, and there's just an extra step. We multiply X by WKi, and then by the down-projection matrix, and same for Vi. So now, because we have applied that down-projection, Ki and Vi are 11 by 4, not 11 by 64, as they were in our multi-head attention example. So this means we have shrunk Ki and Vi by a factor of 16, which is huge, right? Huge memory savings. There is a tiny trade-off because now we're doing two multiplications instead of one. We're multiplying X by WKi and then by Wi down, and same for Vi. So we save a ton of memory at the expense of a little bit of compute. But remember, we said we're not really compute-bound here; we certainly have enough computing power to absorb this small multiplication. What we're really after is saving a ton of memory. To keep things understandable, I've used the same down-projection matrix for K and V. In fact, if you read the DeepSeek paper, you will see they use different down-projection matrices. So they actually have Wi down K and Wi down V. It's not so important. What's important is understanding how we shrink those matrices. Next, we can compute the attention scores. Here's a bit of a trick because we didn't shrink Qi, so the dimensions don't match. Meaning we have to multiply Qi by the down-projection matrix and then by Ki transpose.
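Here is a sketch of the simplified scheme described in this walk-through: eight heads, a 64 by 4 down-projection per head shared by K, V, and the query side of the score computation, and a single up-projection back to 512 dimensions at the end. It follows the video's simplification, not the exact formulation in the DeepSeek V2 paper. The 512 by 64 per-head weights, the scaling by the square root of the latent head size, and the random values are assumptions made for illustration.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

n, hidden, n_heads = 11, 512, 8
d_head = hidden // n_heads            # 64 dimensions per head
d_latent = 32                         # total latent size
d_latent_head = d_latent // n_heads   # 4 latent dimensions per head

rng = np.random.default_rng(0)
X = rng.normal(size=(n, hidden))
W_q, W_k, W_v = (rng.normal(size=(n_heads, hidden, d_head)) / np.sqrt(hidden)
                 for _ in range(3))
W_down = rng.normal(size=(n_heads, d_head, d_latent_head)) / np.sqrt(d_head)
W_up = rng.normal(size=(d_latent, hidden)) / np.sqrt(d_latent)

cache = []          # what would be stored in the KV cache
head_outputs = []
for i in range(n_heads):
    K_i = X @ W_k[i] @ W_down[i]                      # (11, 4) instead of (11, 64)
    V_i = X @ W_v[i] @ W_down[i]                      # (11, 4)
    cache.append((K_i, V_i))
    # Q is down-projected only so the dimensions match; it is never cached.
    Q_i = X @ W_q[i] @ W_down[i]                      # (11, 4)
    scores = Q_i @ K_i.T / np.sqrt(d_latent_head)     # still (11, 11)
    head_outputs.append(softmax(scores) @ V_i)        # (11, 4)

concat = np.concatenate(head_outputs, axis=-1)        # (11, 32)
output = concat @ W_up                                # back to (11, 512)

mha_cached_values = 2 * n * hidden                    # full-size K and V
mla_cached_values = sum(K.size + V.size for K, V in cache)
print(output.shape, mha_cached_values // mla_cached_values)
```

The cached tensors hold 16 times fewer values than full-size K and V for the same 11 tokens, at the cost of one extra small matrix multiplication per projection.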
Again, it's possible that we use a dedicated matrix for Q, so maybe we have Wi down Q as well for this specific calculation. Keep in mind we're just doing this to match dimensions; we are not shrinking Q per se. It's just for the attention scores, and at the end, we still get our 11 by 11 matrix. We're not redoing Q by K transpose; we're doing Qi by Wi down by Ki transpose. So the formula here is really those scores on the previous lines. Finally, we concatenate the outputs, and because we do want to have the original dimensions back at some point, and you'll see why when we talk about the text generation step, we need to multiply those small dimension outputs by the up-projection matrix to take them back to the initial dimensions. So we started with 11 by 512, we end up with 11 by 512. So it looks the same, but in the middle, we actually worked with much smaller matrices and only cached those much smaller Ki and Vi matrices. That's multi-head latent attention. It's pretty clever.
Now that we understand how those attention outputs are computed, let's see how we use them for text generation. Our starting point is the outputs that we computed. It doesn't matter if they come out of multi-head attention or multi-head latent attention; regardless of that, we have those values. The prompt has been processed, we filled the KV cache, etc. Remember what we said earlier: we generate one token at a time, and the starting point to generate a token is, in fact, the attention output for the last token in the sequence. So assuming we are generating the first token in the output, we will start from the attention output for the last token in the prompt. That linear layer, also called the projection layer, is what we use to compute the scores for all the tokens in the vocabulary. So assuming my model has 100,000 words, 100,000 tokens in the vocabulary, this is a rather large layer, as you can see.
What we do is very simply take the attention output for the last token we processed and multiply it by the transpose of that linear layer. So it would be 512 by 100,000, and the result is going to be a vector of 100,000 raw scores. Those scores represent the score for each token in the vocabulary. The higher the score, the more likely it is that we should use that token as our next generated token. We call those scores the logits. And because we prefer to work with probabilities instead of raw numerical values, we apply the softmax function to that large vector. So now we have a vector of 100,000 probabilities, all between 0 and 1 and all adding up to 1. That's what the softmax function does. So now we know which token has the highest probability. How do we select the token that we're going to output and add to our sequence? That's the final decoding step, which really happens outside the model. The model has done its job; it's generated the probabilities. And so now we can use different strategies to evaluate those probabilities and select the appropriate token. The simplest strategy is called greedy decoding. We just pick the token with the highest probability. Whatever has the highest probability in those 100,000 probabilities is the one we select. Another strategy is called top-k sampling. That explains what that top-k inference parameter is. Top-k sampling means we select randomly a token from the k most likely tokens. So let's say k is equal to 10. We take the 10 tokens with the highest probabilities and pick one randomly. Obviously, that's a way to introduce some randomness into the generation process and maybe help the model generate different words, get more creative. Another strategy is called top-P decoding, and that explains what that top-P parameter is. Top-P works this way: we select the smallest subset of tokens for which the cumulative probabilities add up to at least P.
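The projection to logits, the softmax, and the three decoding strategies described here can be sketched as below. The vocabulary is shrunk to a hypothetical 1,000 tokens to keep the example light, and the weights are random. One assumption to flag: the top-k and top-p helpers sample in proportion to the renormalized probabilities of the kept tokens, which is a common choice, while the video only says a token is picked randomly.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def greedy(probs):
    # Always pick the single most likely token.
    return int(np.argmax(probs))

def top_k_sample(probs, k, rng):
    top = np.argsort(probs)[-k:]                  # indices of the k most likely tokens
    p = probs[top] / probs[top].sum()             # renormalize over those k tokens
    return int(rng.choice(top, p=p))

def top_p_sample(probs, p, rng):
    order = np.argsort(probs)[::-1]               # most likely first
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, p)) + 1   # smallest prefix reaching p
    nucleus = order[:cutoff]
    q = probs[nucleus] / probs[nucleus].sum()
    return int(rng.choice(nucleus, p=q))

rng = np.random.default_rng(0)
hidden, vocab_size = 512, 1_000                   # vocabulary size is hypothetical
last_token_output = rng.normal(size=hidden)       # attention output for the last token
W_vocab = rng.normal(size=(vocab_size, hidden))   # projection (linear) layer

logits = last_token_output @ W_vocab.T            # one raw score per vocabulary token
probs = softmax(logits)

print(greedy(probs), top_k_sample(probs, 10, rng), top_p_sample(probs, 0.5, rng))
```

Greedy decoding is deterministic for a given set of probabilities, while the two sampling helpers can return a different token on each call, which is where the extra variety comes from.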
So if P is, let's say, 0.5, we take the most likely tokens in decreasing order of probability until their cumulative probability reaches 0.5, and we'll pick a token from those. Again, this is a way to maybe expand how the model generates and selects tokens, get the model to be a little more creative and generate different answers. So that's what those top-k and top-P parameters are at inference time. And so if you tweak them, this is what's happening under the hood. Regardless of the strategy you used, you have a new token. You can output that to the user, and this new token becomes the next input. And we repeat the process until we meet a stopping condition. Again, it could be the end-of-sequence token or it could be the maximum number of output tokens. So that's text generation. You can see actually that text generation is very easy; there's no difficulty there. All the magic comes from those attention outputs. And we spent a fair amount of time explaining them. All right, my friends, we did dive pretty deep into attention and went quite far, breaking down that attention formula into small steps. We looked at vanilla attention, what multi-head attention is and why it's useful, focusing on smaller parts of the embedding space. We discussed the memory bottleneck problem that attention has and how the KV cache can help with that and also how we can shrink the KV cache using fancy techniques like multi-head latent attention. And again, regardless of the attention method we use for computation, we saw how we could use those attention outputs to actually generate text. So that's it for today. I'm pretty sure we'll keep talking about attention. And until next time, you know what to do. Keep rocking.

Tags

Transformer Models, Attention Mechanism, Multi-Head Latent Attention