Deep dive: Better Attention layers for Transformer models
February 12, 2024
The self-attention mechanism is at the core of transformer models. As amazing as it is, it requires a significant amount of computing and memory bandwidth, leading to scalability issues as models get more complex and context length increases.
In this video, we'll quickly review the computation involved in the self-attention mechanism and its multi-head variant. Then, we'll discuss newer attention implementations focused on compute and memory optimizations, namely Multi-Query Attention, Grouped-Query Attention, Sliding Window Attention, Flash Attention v1 and v2, and Paged Attention.
Slides: https://fr.slideshare.net/slideshow/julien-simon-deep-dive-accelerating-models-with-better-attention-layers/270921899
⭐️⭐️⭐️ Don't forget to subscribe to be notified of future videos. Follow me on Medium at https://julsimon.medium.com or Substack at https://julsimon.substack.com. ⭐️⭐️⭐️
00:00 Introduction
03:00 Self-attention
07:20 Multi-Head Attention (MHA)
12:32 Multi-Query Attention (MQA)
18:45 Grouped-Query Attention (GQA)
22:47 Sliding Window Attention (SWA)
26:17 Flash Attention
31:28 Flash Attention v2
34:36 Paged Attention
39:00 The Hugging Face LLM performance leaderboard
Transcript
Hi everybody, this is Julien from Arcee. As we all know, the self-attention mechanism is at the core of transformer models. But as amazing as it is, it's fairly heavy in terms of compute and memory requirements. Over the last year, a number of really amazing optimizations have been invented to solve those problems, and they've been implemented in state-of-the-art models available on Hugging Face. So in this video, we're going to start with a quick review of the self-attention mechanism and then look at two different ways it can be optimized. First, reducing the memory bandwidth requirement with techniques like multi-query attention and grouped-query attention. And then inventing and implementing faster, smarter attention layers, where of course we'll discuss flash attention, flash attention v2, and paged attention. I've read all those papers and tried to summarize them in the simplest possible way, so hopefully this is your chance to understand all those concepts in plain English.
If you find this video useful, please give it a thumbs up. Consider subscribing to my channel and don't forget to enable notifications so you won't miss anything in the future. Also, why not share the video on your social networks or with your colleagues? If you find it useful, others may find it useful too. Thank you very much.
Before we dive into the attention layers, I want to take a few steps back. There are really many ways to accelerate models. Today, we'll look at what I call new attention layers and faster attention layers. These are technology improvements on the model side of things. But obviously, there are other ways. In quite a few videos, I've discussed hardware acceleration and using hardware features to accelerate models. We won't go into that today, but there will be more content on this later. There's another side of the discussion, which is framework features: model compilation, PyTorch 2.0, quantization, and so on. Again, we won't go into that today, and that's coming later too. I just want to give you the big picture of the different techniques available to accelerate models. So today, we'll just focus on the attention layer.
Let's start with what I call new attention layers, which are really evolutions of the self-attention mechanism. Why don't we start by reviewing this? As mentioned, the self-attention mechanism is at the core of transformer models. It's what makes them great. It's the major building block for these models. And as we all know by now, this was revealed to the world in a famous paper called "Attention is All You Need," published in mid-2017. I would encourage you to read it, even if you're not a big fan of research papers. It's actually not that ugly. Yes, there's a bit of math, but the authors do a really good job explaining the attention mechanism and the key elements that make it powerful. Don't be afraid; don't censor yourself. Even if you don't have a strong background in math and computer science, why don't you dive into it? You'll see; you can figure it out. There are a ton of really good detailed videos on YouTube on the attention mechanism and coding it from scratch, etc. I would highly recommend that you go through that stuff as well.
Today, we want to understand what's a little bit wrong with this self-attention mechanism. From a compute perspective, we are multiplying matrices that have pretty large dimensions because the sequence length is involved. That could be hundreds if not thousands of tokens, and there's the embedding dimension, which could be hundreds or more. So, without focusing too much on that equation, the main problem here is that we are multiplying very, very large matrices every time we encode or decode an input sequence. Whether we're training or running inference, we are multiplying those matrices. The problem is that we end up with quadratic complexity for compute and memory access. Quadratic is just a fancy way to say that complexity grows with the square of the sequence length. So if you double the sequence length, compute and memory access increase 4x. If you multiply the sequence length by 4, they increase 16x, etc.
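To make the quadratic growth concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention. It is an illustration only: the learned query, key, and value projections are skipped (the same random matrix is used for all three), and the sizes are arbitrary assumptions.

```python
import numpy as np

def self_attention(q, k, v):
    """Scaled dot-product attention for a single head.

    q, k, v: arrays of shape (seq_len, d).
    Returns the attention output and the attention weights.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                  # (seq_len, seq_len): the quadratic part
    scores -= scores.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True) # softmax over the keys
    return weights @ v, weights

rng = np.random.default_rng(0)
for seq_len in (128, 256, 512):
    x = rng.standard_normal((seq_len, 64))
    out, weights = self_attention(x, x, x)
    # weights.size is seq_len * seq_len
    print(seq_len, out.shape, weights.size)
```

The score matrix has one entry per pair of tokens, so its size quadruples each time the sequence length doubles, while the output keeps the shape of the input.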
The original transformers had a reasonably short sequence length. BERT was 512, etc. But now we're seeing large language models with sequence lengths of 2k, 4k, 8k tokens and more. Particularly for inference, this becomes a problem. As the context grows bigger, especially with the popularity of retrieval-augmented generation, inference becomes very expensive. That's the base problem we want to solve: reducing the amount of compute and memory accesses required to compute those self-attention scores. There are really two ways to do this. One focuses on reducing memory access, and the other reduces algorithmic complexity and the number of operations involved. We'll cover both.
Let's start by looking at how self-attention is implemented in our favorite models. The paper describes self-attention, but models like BERT actually implement what is called multi-head attention. What happens here is we split the attention calculation across the dimensions of the embeddings and across a number of attention heads. We don't have a single self-attention operation; we have multiple. Each head looks at a fraction of the embedding space and can learn different relationships than the other heads. This is not a compute or memory optimization; it's just how attention is implemented in the original transformers because we want each head to focus on a fraction of the embedding space.
If you have, let's say, 1024 embedding dimensions and eight heads, each head will look at 128 dimensions in the embedding space. To understand what's wrong with this, we need to understand how it's implemented. The standard implementation involves three matrices: the keys (K), the values (V), and the queries (Q). We load them from memory, multiply them, and each head does its part of the calculation. Then we write the results back to memory. The problem is that GPUs don't have a lot of on-chip memory. They have what is called HBM (High Bandwidth Memory), which is off-chip. HBM is not the fastest memory the GPU can access. When we load huge matrices, whose dimensions grow with sequence length and embedding size, we have to load them from memory that doesn't sit on the GPU chip. Even though HBM is very fast, this still takes quite a bit of time.
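The head split described above can be sketched in NumPy as follows. This is a simplified illustration and not the BERT code from the transformers library: there are no biases, no masking, and no dropout, and the random weight matrices and the sequence length of 16 are assumptions chosen only to show the shapes.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    # Every head gets its own slice of the projected queries, keys, and values.
    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (num_heads, seq_len, seq_len)
    context = softmax(scores) @ v                           # (num_heads, seq_len, head_dim)
    # Concatenate the heads back into (seq_len, d_model) and mix them.
    merged = context.transpose(1, 0, 2).reshape(seq_len, d_model)
    return merged @ w_o

rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 1024, 8, 16
x = rng.standard_normal((seq_len, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) * 0.02 for _ in range(4))
out = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(d_model // num_heads, out.shape)
```

With 1024 embedding dimensions and eight heads, each head works on a 128-dimensional slice, and the concatenated result has the same shape as the input.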
This became a major problem when transformer models started scaling, especially with large language models (LLMs) scaling sequence length and context length. The cost of loading those matrices for every inference or training step became too much. Training time and, more importantly, inference time became a problem. The number of tokens per second dropped, and LLMs became way too slow. Memory became a bottleneck, and that's the problem that needs to be solved. If you want to look at the implementation for multi-head attention, I referenced the file in the transformers library, the code for BERT. Believe it or not, this is rather straightforward to read. You can see the algorithm implemented in the code, with K, Q, and V being multiplied and in action. If you're curious about how this is implemented, I encourage you to check out the BERT implementation.
Now that we understand how attention is problematic in terms of memory, let's look at the first optimization that was invented. On the left, we see multi-head attention, and on the right, we see multi-query attention (MQA). The difference is in the orange box of multi-head attention, where V_i and K_i are the values and keys used by each head. Each head has its own set of values and keys, which are different and need to be loaded. MQA uses the same set of values and keys across heads. For example, if we have 32 heads, instead of having 32 V_i and K_i matrices, we have just one V matrix and one K matrix shared across all heads. The benefit is significant: if we have 32 heads, we don't have to load 32 value tensors or 32 key tensors. We can just load one value tensor and one key tensor. This is implemented in Falcon 7B. You can see the code and compare it to BERT, which I found really interesting. We end up storing much less data, loading less data from HBM, and caching fewer intermediate results during decoding. We use less memory on the GPU and increase decoding speed: MQA is reported to be up to 12x faster than MHA.
The cons are a small accuracy drop because we use fewer parameters. Each head has a single K and V instead of dedicated keys and values, so we have fewer parameters and can learn a little less. It's a compromise between accuracy and speed. Another issue is that you have to train models with MQA. You can't take a model trained with multi-head attention and run it with multi-query attention. If you have a model you like, you have to retrain it with MQA. Additionally, if you want to use techniques like tensor parallelism, you can't split K and V across GPUs because they are unique and need to be present on each node of the distributed cluster. However, MQA is a good step forward in reducing cache size, optimizing memory usage, and accelerating inference.
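A back-of-the-envelope calculation shows why sharing keys and values shrinks the cache so much. The configuration below (32 layers, 32 heads, 128 dimensions per head, 4,096 tokens, 16-bit values) is a hypothetical example and not the exact configuration of any particular model.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Size of the key/value cache for one batch of sequences."""
    # Factor 2: one tensor for the keys and one for the values.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

config = dict(num_layers=32, head_dim=128, seq_len=4096)
mha = kv_cache_bytes(num_kv_heads=32, **config)  # one K/V pair per head
mqa = kv_cache_bytes(num_kv_heads=1, **config)   # one K/V pair shared by all heads
print(f"MHA: {mha / 2**20:.0f} MiB, MQA: {mqa / 2**20:.0f} MiB, ratio: {mha // mqa}x")
```

Under these assumptions the cache drops from 2 GiB to 64 MiB per sequence, a reduction equal to the number of heads.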
Next, we have grouped-query attention (GQA). On the top graph, multi-head attention is on the left, with one keys tensor and one values tensor per head. On the right, we have multi-query attention, with a single keys tensor and a single values tensor for all heads. GQA is a middle ground. Instead of one per head or one for all heads, we can group them. For example, we could have one values tensor and one keys tensor for two, four, or eight heads. This becomes a hyperparameter you can set. The paper ran experiments on T5-XXL and found that GQA could get almost the best of both worlds: almost the same performance as multi-head and almost as fast as multi-query. The sweet spot is between four and eight heads, which is almost as fast as multi-query and still a very nice optimization in terms of memory. This is implemented in Llama 2 and Mistral. Models can be uptrained from multi-head attention to GQA, and tensor parallelism is a better fit because you can split multiple value and key tensors across GPUs.
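The grouping can be sketched with a single function where the number of key/value heads is the hyperparameter: setting it to the number of query heads gives multi-head attention, and setting it to 1 gives multi-query attention. This is a minimal NumPy illustration with made-up sizes. It copies the shared tensors with `np.repeat` for readability, whereas an optimized implementation would avoid materializing those copies.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (num_heads, seq_len, head_dim); k, v: (num_kv_heads, seq_len, head_dim)."""
    num_heads, _, head_dim = q.shape
    num_kv_heads = k.shape[0]
    assert num_heads % num_kv_heads == 0, "query heads must split evenly into groups"
    group_size = num_heads // num_kv_heads
    # Each K/V head is reused by group_size consecutive query heads.
    k = np.repeat(k, group_size, axis=0)
    v = np.repeat(v, group_size, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
num_heads, seq_len, head_dim = 8, 16, 32
q = rng.standard_normal((num_heads, seq_len, head_dim))
outputs = {}
for num_kv_heads in (8, 2, 1):  # MHA, GQA with groups of 4 heads, MQA
    k = rng.standard_normal((num_kv_heads, seq_len, head_dim))
    v = rng.standard_normal((num_kv_heads, seq_len, head_dim))
    outputs[num_kv_heads] = grouped_query_attention(q, k, v)
    print(num_kv_heads, k.shape, outputs[num_kv_heads].shape)
```

The output shape is the same in all three cases. Only the number of key and value tensors that must be stored and loaded changes.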
Sliding window attention works differently. In traditional vanilla attention, we compute attention scores for all token pairs. At inference time, we mask future tokens, so we have a triangle-shaped attention mask. Sliding window attention limits the self-attention computation to a fixed window, which for Mistral is 4k tokens. This means a token can't see more than 4k tokens back in the previous layer. For very short sequences, this makes no difference. As the sequence length scales, the window starts applying. The maximum context size is the window size multiplied by the number of layers, which for Mistral is about 131k tokens. This reduces attention complexity from quadratic to linear, speeding up inference. We're not changing the number of queries and keys; we're shortening the attention span and propagating it across layers.
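The windowed mask and the context-size arithmetic can be sketched as follows. The tiny sequence length and window are for display only. The 4,096-token window matches the 4k figure above, the 32 layers are an assumption that reproduces the 131k figure, and whether the window includes the current token is a detail that may differ between implementations.

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    """True where query position i may attend to key position j."""
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    # Causal (no future tokens) and limited to the last `window` positions.
    return (j <= i) & (j > i - window)

mask = sliding_window_mask(seq_len=8, window=3)
print(mask.astype(int))
print(int(mask.sum()))  # allowed pairs with the window; a full causal mask has 36

window, num_layers = 4096, 32  # assumed Mistral-like values
print(window * num_layers)     # theoretical reach after stacking the layers
```

The triangle becomes a diagonal band, so the number of scored pairs grows linearly with sequence length once the sequence is longer than the window.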
The next group of techniques focuses on rewriting the attention algorithm to make it faster. The first one is flash attention. The main problem is that high bandwidth memory is still too slow compared to on-chip GPU memory. Flash attention runs the self-attention computation on the GPU itself with minimal back and forth to HBM. It loads the matrices once and applies a clever tiling algorithm to compute the full matrix operations incrementally in static RAM (SRAM), which is on the GPU chip and much faster. Once the final result is computed, it writes it back to HBM. This limits HBM memory accesses to a minimum and parallelizes everything over batch size and number of heads, leveraging GPU cores for significant speedup.
For example, if N is the sequence length, d is the embedding length, and M is the size of the SRAM cache, flash attention requires O(N^2 * d^2 / M) HBM memory accesses. If the SRAM cache is equal to the sequence length, we cancel one of the Ns, and the memory complexity becomes linear with respect to sequence length. Flash attention is 2 to 4x faster in terms of inference and saves 10 to 20x memory, allowing for increased batch size. It optimizes both the forward and backward passes, accelerating training. Flash attention is available in our text generation inference server.
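The tiling idea can be illustrated on the CPU with NumPy. The sketch below processes keys and values one block at a time and keeps a running maximum and a running softmax denominator, so the full N x N score matrix is never built. It is a simplification, not the flash attention kernel itself: the real implementation is a fused GPU kernel that also tiles the queries and keeps the blocks in SRAM, and the block size here is an arbitrary assumption.

```python
import numpy as np

def reference_attention(q, k, v):
    """Standard attention that materializes the full score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block_size=32):
    """Same result, computed block by block with an online softmax."""
    seq_len, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((seq_len, v.shape[-1]))
    row_max = np.full((seq_len, 1), -np.inf)  # running max of the scores per query
    row_sum = np.zeros((seq_len, 1))          # running softmax denominator per query
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = q @ k_blk.T * scale          # only (seq_len, block_size) at a time
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescale what was accumulated so far
        p = np.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((100, 16)) for _ in range(3))
ref_out = reference_attention(q, k, v)
tiled_out = tiled_attention(q, k, v, block_size=32)
print(np.allclose(ref_out, tiled_out))
```

Both functions should agree up to floating-point tolerance, which is the key property: tiling changes how the result is computed, not what is computed.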
Flash attention 2 is another round of optimizing flash attention. It minimizes operations that are not matrix multiplications, reducing scalar and vector operations. It's rewritten to reduce the number of operations that can't be fully parallelized and accelerated by the GPU. It also optimizes for MQA and GQA, reducing the amount of keys and values that need to be processed. Flash attention 2 is 2x faster than the previous version and up to 9x faster than standard attention. It's available in TGI.
Paged attention takes a different approach. The KV cache, which stores intermediate keys and values, grows and shrinks dynamically, leading to memory fragmentation. Paged attention solves this by chunking memory into fixed-size blocks called pages, reducing external and internal memory fragmentation. This allows for efficient memory allocation and deallocation, similar to virtual memory systems in operating systems. This was introduced in the vLLM project and is available in Hugging Face TGI. Managing memory well on a GPU is critical for increasing batch size and accelerating operations.
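The paging idea can be sketched with a toy allocator in plain Python. It is a conceptual illustration only and does not reflect the vLLM or TGI code: the class name `PagedKVCache`, its methods, and the page size of 16 tokens are all hypothetical, and only page bookkeeping is modeled, not the actual key and value tensors.

```python
class PagedKVCache:
    """Toy allocator: each sequence maps to a list of fixed-size pages."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.page_tables = {}  # sequence id -> list of physical page ids
        self.lengths = {}      # sequence id -> number of cached tokens

    def append_token(self, seq_id):
        length = self.lengths.get(seq_id, 0)
        if length % self.page_size == 0:  # no page yet, or the last page is full
            if not self.free_pages:
                raise MemoryError("no free pages left")
            self.page_tables.setdefault(seq_id, []).append(self.free_pages.pop(0))
        self.lengths[seq_id] = length + 1

    def free(self, seq_id):
        # Pages go back to the pool and can be reused by any other sequence.
        self.free_pages.extend(self.page_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_pages=8, page_size=16)
for _ in range(40):
    cache.append_token("a")  # 40 tokens need 3 pages of 16
for _ in range(10):
    cache.append_token("b")  # 10 tokens need 1 page
print(cache.page_tables)
cache.free("a")
print(cache.free_pages)
```

Because pages have a fixed size and don't need to be contiguous, a finished sequence returns its pages to the pool without leaving unusable gaps, and at most one partially filled page is wasted per sequence.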
One more thing: we have an amazing LLM performance leaderboard on the Hugging Face hub. They added useful graphs showing the time to generate tokens versus model performance, memory requirements, latency, throughput, and generation quality. Studying this will help you see the impact of the techniques discussed, from optimizing model architectures to managing memory with paged attention.
That's it for today. If you enjoyed the video, please give it a thumbs up, subscribe, enable notifications, share the video with your colleagues, and I'll see you soon with more content. Thank you and keep rocking.