手撕经典算法 #1 Attention篇
本文最后更新于:2024年7月8日 中午
本文对常见的几种注意力机制进行了简单的实现和注释,便于理解。包括:
- 缩放点积注意力(Scaled Dot-Product Attention)
- 2014 年《Neural Machine Translation by Jointly Learning to Align and Translate》提出的单头注意力,输入的 Query、Key 和 Value 矩阵都是完整的张量。
- 多头注意力(Multi Head Attention,MHA)
- 2017 年开山之作《Attention is all you need》所提出的一种 Attention 形式,可以说它是当前主流 LLM 的基础工作。每个头有自己单独的 Query、Key 和 Value 矩阵。
- 在自回归 LLM 中通过 Mask 可以实现 Causal Attention,而在 Next Token Prediction 时,新预测的第 \(t+1\) 个 token 不会影响到已经算好的 \(k_{\le t},v_{\le t}\),因此这部分结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的 KV Cache。
- 多查询注意力(Multi Query Attention,MQA)
- 围绕「如何减少 KV Cache 同时尽可能地保证效果」这个主题发展而来的产物。只有一组 key-value 对,由《Fast Transformer Decoding: One Write-Head is All You Need》在 2019 年提出。
- 与 MHA 不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。使用 MQA 的模型包括 PaLM、StarCoder、Gemini 等。
- 分组查询注意力(Grouped Query Attention,GQA)
- 有人担心 MQA 对 KV Cache 的压缩太严重,于是提出了一个折中版本,出自 2023 年论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》。
- 其思想是将将所有 Head 分为 \(g\) 个组(\(g\) 可以整除 \(h\)),每组共享同一对 Key 和 Value 矩阵。当 \(g=h\) 时就是 MHA, \(g=1\) 时就是 MQA,当 \(1<g<h\) 时,它只将 KV Cache 压缩到 \(g / h\) ,压缩率不如 MQA,但同时也提供了更大的自由度,效果上更有保证。
- GQA 最知名的使用者,大概是 Meta 开源的 Llama-2-70B,以及 Llama-3 全系列。在 Llama-2/3-70B 中,\(g=8\),可以部署到一台机器的 8 张卡上,每张卡负责计算一组 K、V 对应的 Attention Head,减少卡间通信。
- 多头隐注意力(Multi-head Latent Attention,MLA)
- 2024 年在 DeepSeek-V2 技术报告 中提到的新技术,用更一般的线性变换来替代了之前的操作,使得 \(k,v\) 都不需要被完整存储,进一步压缩了 KV Cache。
缩放点积注意力(SDPA)
缩放点积注意力早于 Transformer 被提出,受到的关注并不多,其内部只实现了 \(q,k,v\) 的注意力计算。
- 输入是 query 和 key-value,注意力机制首先计算 query 与每个 key 的关联性
- 每个关联性作为每个 value 的权重 (weight),各个权重与 value 的乘积相加得到输出。
- SDPA 可以被认为是 MHA 的中间步骤!
1 |
|
多头注意力(MHA)
多头注意力机制是 Transformer 模型中的核心组件。在其设计中,「多头」意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的「视角」捕获输入的不同信息。具体步骤如下:
为输入序列中计算 \(Q, K, V\) ,这是通过将输入词向量与三个权重矩阵相乘实现的: \[ \begin{aligned} & Q = X W_q \\ & K = X W_k \\ & V = X W_v \end{aligned} \]
计算 \(Q, K\) 注意力得分,其中, \(d_k\) 是 \(k\) 的维度: \[ \operatorname{score}(Q, K) = \frac{Q \cdot K^T}{\sqrt{d_k}} \]
使用 Softmax 得到注意力权重: \[ \operatorname{Attention}(Q, K) = \operatorname{softmax}(\operatorname{score}(Q, K))=\operatorname{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \]
使用注意力权重和 \(V\),计算输出: \[ \text{Output} = \operatorname{Attention}(Q, K) \cdot V = \operatorname{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V \]
拼接多头输出,乘以 \(W_O\),得到最终输出: \[ \text{MultiHeadOutput} = \text{Concat} (\text{Output}^1, \text{Output}^2, \ldots, \text{Output}^H) W_O \]
实现代码如下:
1 |
|
多查询注意力(MQA)
1 |
|
分组查询注意力(GQA)
1 |
|
多头隐注意力(MLA)
1 |
|