Hand-Coding Classic Algorithms #1: Attention

This article was last updated on July 8, 2024, at noon.

This article provides simple, annotated implementations of several common attention mechanisms to make them easier to understand, covering:

  • Scaled Dot-Product Attention (SDPA)
  • Multi-Head Attention (MHA)
    • The attention form proposed in the seminal 2017 paper《Attention is all you need》, which can fairly be called the foundation of today's mainstream LLMs. Each head has its own Query, Key, and Value matrices.
    • In autoregressive LLMs, causal attention is implemented via a mask. During next-token prediction, the newly predicted \(t+1\)-th token does not affect the already computed \(k_{\le t},v_{\le t}\), so these results can be cached and reused by later generation steps, avoiding unnecessary recomputation. This is the so-called KV Cache (a minimal sketch follows this list).
  • Multi-Query Attention (MQA)
    • A product of the line of work on reducing the KV Cache while preserving quality as much as possible. It keeps only a single set of key-value heads, and was proposed in 2019 in《Fast Transformer Decoding: One Write-Head is All You Need》.
    • Unlike MHA, MQA shares a single Key matrix and a single Value matrix across all heads, while each head keeps its own Query parameters, greatly reducing the Key and Value parameter count. Models that use MQA include PaLM, StarCoder, and Gemini.
  • Grouped-Query Attention (GQA)
    • Out of concern that MQA compresses the KV Cache too aggressively, a compromise was proposed in the 2023 paper《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》.
    • The idea is to split all heads into \(g\) groups (where \(g\) divides \(h\)), with each group sharing one pair of Key and Value matrices. When \(g=h\) this is exactly MHA, and when \(g=1\) it is MQA; when \(1<g<h\), it compresses the KV Cache only to \(g/h\) of its original size, a weaker compression ratio than MQA's, but it also offers more freedom and therefore better quality guarantees.
    • The best-known users of GQA are probably Meta's open-source Llama-2-70B and the entire Llama-3 series. In Llama-2/3-70B, \(g=8\), so the model can be deployed on the 8 GPUs of a single machine, with each GPU handling the attention heads of one K/V group, reducing inter-GPU communication.
  • Multi-head Latent Attention (MLA)
    • A new technique described in the 2024 DeepSeek-V2 technical report. It replaces the earlier projections with more general linear transformations so that \(k, v\) no longer need to be stored in full, further compressing the KV Cache.
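
To make the KV Cache idea above concrete, here is a minimal sketch of step-by-step decoding with a cache. It is a sketch only: the `decode_step` helper and the dict-based cache layout are illustrative assumptions, not any particular framework's API.

import torch

# Minimal KV Cache sketch: at each decoding step we only project the newest
# token and append its k/v to the cache, instead of recomputing k_{<=t}, v_{<=t}.
def decode_step(x_t, w_q, w_k, w_v, kv_cache):
    # x_t: (batch, 1, hidden) -- only the newly generated token
    q_t = x_t @ w_q
    k_t = x_t @ w_k
    v_t = x_t @ w_v

    # Append the new key/value to the cache
    kv_cache["k"] = torch.cat([kv_cache["k"], k_t], dim=1)  # (batch, t+1, hidden)
    kv_cache["v"] = torch.cat([kv_cache["v"], v_t], dim=1)  # (batch, t+1, hidden)

    # Attend from the single new query over all cached keys/values
    d_k = q_t.size(-1)
    scores = q_t @ kv_cache["k"].transpose(-1, -2) / d_k ** 0.5  # (batch, 1, t+1)
    probs = torch.softmax(scores, dim=-1)
    return probs @ kv_cache["v"]  # (batch, 1, hidden)

if __name__ == "__main__":
    batch, hidden = 2, 16
    w_q, w_k, w_v = (torch.randn(hidden, hidden) for _ in range(3))
    cache = {"k": torch.zeros(batch, 0, hidden), "v": torch.zeros(batch, 0, hidden)}
    for _ in range(4):
        out = decode_step(torch.randn(batch, 1, hidden), w_q, w_k, w_v, cache)
    print(out.shape, cache["k"].shape)  # torch.Size([2, 1, 16]) torch.Size([2, 4, 16])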

Scaled Dot-Product Attention (SDPA)

Dot-product attention was proposed before the Transformer (《Attention is all you need》added the \(\sqrt{d_k}\) scaling) and received relatively little attention on its own; it implements only the attention computation over \(q, k, v\).

  • The inputs are a query and key-value pairs; the attention mechanism first computes the relevance between the query and each key.
  • Each relevance score is used as the weight of the corresponding value, and the weighted sum of the values gives the output.
  • SDPA can be viewed as an intermediate step inside MHA.

import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, attention_mask=None):
        # query, key, value shapes: (batch_size, seq_len, hidden_size)

        # Compute attention scores
        # key.transpose(-1, -2) swaps the last two dimensions for the dot product
        # attention_scores shape: (batch_size, seq_len, seq_len)
        d_k = query.size(-1)  # hidden_size
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

        # Apply the attention mask (seq_len, seq_len); masked positions (1) get negative infinity
        if attention_mask is not None:
            attention_scores += attention_mask * -1e9

        # Normalize the scores into attention probabilities
        attention_probs = torch.softmax(attention_scores, dim=-1)  # (batch_size, seq_len, seq_len)

        # Compute the output as the probability-weighted sum of the values
        attention_output = torch.matmul(attention_probs, value)  # (batch_size, seq_len, hidden_size)

        return attention_output

def test_attn():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024

    query = torch.randn(batch_size, seq_len, hidden_size)  # (batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)    # (batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)  # (batch_size, seq_len, hidden_size)

    sdpa = ScaledDotProductAttention()
    output = sdpa(query, key, value)

    print("Query shape:", query.shape)
    print("Key shape:", key.shape)
    print("Value shape:", value.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    test_attn()

Multi-Head Attention (MHA)

Multi-head attention is the core component of the Transformer. "Multi-head" means that instead of computing a single set of attention weights, the mechanism computes several in parallel, each capturing different information about the input from a different "view". The steps are as follows:

  1. Compute \(Q, K, V\) for the input sequence by multiplying the input embeddings with three weight matrices: \[ \begin{aligned} & Q = X W_q \\ & K = X W_k \\ & V = X W_v \end{aligned} \]

  2. Compute the attention scores between \(Q\) and \(K\), where \(d_k\) is the dimension of the key vectors: \[ \operatorname{score}(Q, K) = \frac{Q \cdot K^T}{\sqrt{d_k}} \]

  3. Apply Softmax to obtain the attention weights: \[ \operatorname{Attention}(Q, K) = \operatorname{softmax}(\operatorname{score}(Q, K))=\operatorname{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \]

  4. Use the attention weights and \(V\) to compute the output: \[ \text{Output} = \operatorname{Attention}(Q, K) \cdot V = \operatorname{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V \]

  5. Concatenate the outputs of all heads and multiply by \(W_O\) to obtain the final output: \[ \text{MultiHeadOutput} = \text{Concat} (\text{Output}^1, \text{Output}^2, \ldots, \text{Output}^H) W_O \]

The implementation is as follows:

import torch
from torch import nn

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension per head; hidden_size must be divisible by num_heads

        # Projection matrices that map the input embeddings to Q, K, V; dimensions are kept the same
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)

        # Output projection that maps the concatenated multi-head output to the desired dimension (kept the same here)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        # hidden_state shape: (batch_size, seq_len, hidden_size)
        batch_size = hidden_state.size(0)  # batch size

        # Compute Q, K, V via linear projections
        query = self.q_linear(hidden_state)  # (batch_size, seq_len, hidden_size)
        key = self.k_linear(hidden_state)    # (batch_size, seq_len, hidden_size)
        value = self.v_linear(hidden_state)  # (batch_size, seq_len, hidden_size)

        # Split into heads, exposing the per-head dimension
        query = self.split_head(query)  # (batch_size, num_heads, seq_len, head_dim)
        key = self.split_head(key)      # (batch_size, num_heads, seq_len, head_dim)
        value = self.split_head(value)  # (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores with scaled dot-product attention
        # attention_scores shape: (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        # Apply the attention mask (seq_len, seq_len); masked positions (1) get negative infinity
        if attention_mask is not None:
            attention_scores += attention_mask * -1e9

        # Normalize the scores into attention probabilities
        attention_probs = torch.softmax(attention_scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)

        # Compute the attention output as the probability-weighted sum of the values
        output = torch.matmul(attention_probs, value)  # (batch_size, num_heads, seq_len, head_dim)

        # Concatenate the multi-head outputs
        # output.transpose(1, 2) swaps the num_heads and seq_len dimensions,
        # then reshape to (batch_size, seq_len, hidden_size)
        output = output.transpose(1, 2).reshape(batch_size, -1, self.head_dim * self.num_heads)

        # Project the concatenated output to the desired output dimension
        output = self.o_linear(output)  # (batch_size, seq_len, hidden_size)

        return output

    def split_head(self, x):
        batch_size = x.size(0)  # batch size
        # x shape: (batch_size, seq_len, hidden_size)
        # split hidden_size into num_heads and head_dim
        return x.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # returned shape: (batch_size, num_heads, seq_len, head_dim)

def test_MHA():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    num_heads = 8

    # Random input
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)  # (batch_size, seq_len, hidden_size)

    # Build the multi-head attention module
    mha = MultiHeadAttention(hidden_size, num_heads)

    # Compute the multi-head attention output
    output = mha(hidden_state)

    print("Input shape:", hidden_state.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    test_MHA()

Multi-Query Attention (MQA)

import torch
from torch import nn

class MultiQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Projection matrices for Q, K, V; note that K and V are much smaller than in MHA
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.head_dim)

        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size(0)

        query = self.q_linear(hidden_state)  # (batch_size, seq_len, hidden_size)
        key = self.k_linear(hidden_state)    # (batch_size, seq_len, head_dim)
        value = self.v_linear(hidden_state)  # (batch_size, seq_len, head_dim)

        # Split into heads; K and V also get a (singleton) head dimension
        query = self.split_head(query)     # (batch_size, num_heads, seq_len, head_dim)
        key = self.split_head(key, 1)      # (batch_size, 1, seq_len, head_dim)
        value = self.split_head(value, 1)  # (batch_size, 1, seq_len, head_dim)

        # Compute attention scores; broadcasting shares K across heads, (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        if attention_mask is not None:
            attention_scores += attention_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)

        output = torch.matmul(attention_probs, value)  # (batch_size, num_heads, seq_len, head_dim)

        # Concatenate the per-head outputs, (batch_size, seq_len, hidden_size)
        output = output.transpose(1, 2).reshape(batch_size, -1, self.head_dim * self.num_heads)

        output = self.o_linear(output)  # (batch_size, seq_len, hidden_size)

        return output

    def split_head(self, x, head_num=None):
        batch_size = x.size(0)  # batch size
        if head_num is None:
            head_num = self.num_heads  # default to the module's num_heads

        # returned shape: (batch_size, head_num, seq_len, head_dim)
        return x.reshape(batch_size, -1, head_num, self.head_dim).transpose(1, 2)
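
For a quick shape check in the same style as `test_MHA` above, a small test like the following can be used (the `test_MQA` name is just an illustrative choice):

def test_MQA():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    num_heads = 8

    # Random input
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)  # (batch_size, seq_len, hidden_size)

    # Build the multi-query attention module and run a forward pass
    mqa = MultiQueryAttention(hidden_size, num_heads)
    output = mqa(hidden_state)

    print("Input shape:", hidden_state.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    test_MQA()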

Grouped-Query Attention (GQA)

import torch
from torch import nn

class GroupQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, group_num):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.group_num = group_num  # number of K/V groups

        # Projection matrices for Q, K, V; K and V are a compromise between MHA and MQA
        self.q_linear = nn.Linear(hidden_size, hidden_size)                     # (hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)  # (hidden_size, group_num * head_dim)
        self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)  # (hidden_size, group_num * head_dim)

        self.o_linear = nn.Linear(hidden_size, hidden_size)                     # (hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size(0)

        query = self.q_linear(hidden_state)  # (batch_size, seq_len, hidden_size)
        key = self.k_linear(hidden_state)    # (batch_size, seq_len, group_num * head_dim)
        value = self.v_linear(hidden_state)  # (batch_size, seq_len, group_num * head_dim)

        # Split into heads, exposing the per-head dimension
        query = self.split_head(query)                  # (batch_size, num_heads, seq_len, head_dim)
        key = self.split_head(key, self.group_num)      # (batch_size, num_heads, seq_len, head_dim)
        value = self.split_head(value, self.group_num)  # (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores, (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        if attention_mask is not None:
            attention_scores += attention_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)

        output = torch.matmul(attention_probs, value)  # (batch_size, num_heads, seq_len, head_dim)

        # Concatenate the per-head outputs, shape: (batch_size, seq_len, hidden_size)
        output = output.transpose(1, 2).reshape(batch_size, -1, self.head_dim * self.num_heads)

        # Project the concatenated output to the desired output dimension
        output = self.o_linear(output)  # (batch_size, seq_len, hidden_size)

        return output

    def split_head(self, x, group_num=None):
        batch_size, seq_len = x.size()[:2]  # batch size and sequence length

        if group_num is None:
            return x.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            # split into group_num groups of head_dim
            x = x.reshape(batch_size, -1, group_num, self.head_dim).transpose(1, 2)
            # then expand each group to cover its num_heads // group_num query heads
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
            return x  # shape: (batch_size, num_heads, seq_len, head_dim)
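
Likewise, a quick shape check for GQA (the `test_GQA` name and the choice of `group_num = 4` are illustrative; any divisor of `num_heads` works):

def test_GQA():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    num_heads = 8
    group_num = 4  # 1 < g < h: a middle ground between MQA and MHA

    # Random input
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)  # (batch_size, seq_len, hidden_size)

    # Build the grouped-query attention module and run a forward pass
    gqa = GroupQueryAttention(hidden_size, num_heads, group_num)
    output = gqa(hidden_state)

    print("Input shape:", hidden_state.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    test_GQA()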

Multi-head Latent Attention (MLA)

TODO
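
Until this section is filled in, here is a minimal sketch of MLA's core idea under simplifying assumptions: the hidden state is down-projected to a small latent \(c_{kv}\), which is what would actually be cached, and the per-head keys and values are reconstructed from it with up-projections. The class name `SimplifiedMLA`, the `kv_latent_dim` parameter, and the omission of DeepSeek-V2's query compression and decoupled RoPE branch are all illustrative simplifications, not the official implementation.

import torch
from torch import nn

class SimplifiedMLA(nn.Module):
    def __init__(self, hidden_size, num_heads, kv_latent_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        # Down-project the hidden state to a small latent c_kv; in a real MLA
        # implementation this latent (not the full k/v) is what gets cached
        self.kv_down = nn.Linear(hidden_size, kv_latent_dim)
        # Up-project the latent back to per-head keys and values
        self.k_up = nn.Linear(kv_latent_dim, hidden_size)
        self.v_up = nn.Linear(kv_latent_dim, hidden_size)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size(0)

        query = self.split_head(self.q_linear(hidden_state))  # (batch_size, num_heads, seq_len, head_dim)
        c_kv = self.kv_down(hidden_state)                      # (batch_size, seq_len, kv_latent_dim)
        key = self.split_head(self.k_up(c_kv))                 # (batch_size, num_heads, seq_len, head_dim)
        value = self.split_head(self.v_up(c_kv))               # (batch_size, num_heads, seq_len, head_dim)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.head_dim ** 0.5
        if attention_mask is not None:
            attention_scores += attention_mask * -1e9
        attention_probs = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probs, value)  # (batch_size, num_heads, seq_len, head_dim)

        output = output.transpose(1, 2).reshape(batch_size, -1, self.head_dim * self.num_heads)
        return self.o_linear(output)  # (batch_size, seq_len, hidden_size)

    def split_head(self, x):
        batch_size = x.size(0)
        return x.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

if __name__ == "__main__":
    x = torch.randn(2, 16, 1024)
    mla = SimplifiedMLA(hidden_size=1024, num_heads=8, kv_latent_dim=128)
    print(mla(x).shape)  # torch.Size([2, 16, 1024])

In this sketch, caching the per-token latent of size `kv_latent_dim` (128 here) instead of the full keys and values (2 × 1024 per token) is where the additional KV Cache savings would come from.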
