Hand-Coding Classic Algorithms #3: Transformer

Last updated: July 8, 2024 (noon)

Building on the previous two chapters, this post gives a simple, annotated implementation of the Transformer model, covering:

  • Embedding layer
  • Encoder layer
  • Decoder layer
  • Stacked Encoder
  • Stacked Decoder
  • Full Transformer

Embedding Layer

One of the basic building blocks of the Transformer is the embedding layer. The token embedding maps each input word or token to a vector representation, while the positional embedding adds position information to every position in the sequence so that the model can make sense of word order.

Here is the implementation of the token embedding:

import torch
from torch import nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)  # embedding lookup table

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        embedded = self.embedding(x)  # output shape: (batch_size, seq_len, hidden_size)
        return embedded

def test_token_embedding():
    vocab_size = 10000   # vocabulary size
    hidden_size = 512    # embedding dimension
    batch_size = 2
    seq_len = 4

    # Randomly generate input data of shape (batch_size, seq_len)
    x = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Create the TokenEmbedding module
    token_embedding = TokenEmbedding(vocab_size, hidden_size)

    # Compute the embedding output
    output = token_embedding(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    test_token_embedding()

In the Transformer, a positional encoding is added to the input embeddings. The encoding is usually computed from fixed functions such as sines and cosines; these functions guarantee that different positions receive different encodings while retaining a certain periodicity and symmetry.

Concretely, each element of the positional encoding matrix \(\mathbf{PE}\) is computed as:

\[ \mathbf{PE}_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right) \] \[ \mathbf{PE}_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right) \]

where \(pos\) is the position index, \(i\) is the dimension index, and \(d_{\text{model}}\) is the embedding dimension.

These formulas give every position a unique encoding built from sine and cosine components of different frequencies. For example, with \(i = 0\) the exponent is zero, so the first two dimensions of position \(pos\) are simply \(\sin(pos)\) and \(\cos(pos)\), while higher dimensions oscillate more and more slowly. Here is the implementation of the positional embedding:

import math
import torch
from torch import nn

class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Build the positional encoding table of shape (max_len, hidden_size)
        # position: (max_len, 1), the position indices, e.g. [[0.], [1.], [2.], ...]
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # div_term: (hidden_size / 2,), the denominators of the encoding
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))

        # Initialize the positional encoding matrix pe with zeros, shape (max_len, hidden_size)
        pe = torch.zeros(max_len, hidden_size)

        # Fill in the table; broadcasting expands position * div_term to (max_len, hidden_size / 2)
        # Even-indexed columns use sin
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd-indexed columns use cos
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register the encoding as a buffer so it is not updated during training
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, hidden_size)
        seq_len = x.size(1)

        # Add the positional encoding to the input tensor
        # self.pe[:seq_len, :] has shape (seq_len, hidden_size)
        # unsqueeze(0) turns it into (1, seq_len, hidden_size) so it broadcasts over the batch
        x = x + self.pe[:seq_len, :].unsqueeze(0)

        # Return the tensor with positional information added
        return x

# Test function for PositionalEmbedding
def test_positional_embedding():
    max_len = 5000       # maximum sequence length
    hidden_size = 512    # embedding dimension
    batch_size = 2
    seq_len = 4

    # Randomly generate input data of shape (batch_size, seq_len, hidden_size)
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Create the PositionalEmbedding module
    positional_embedding = PositionalEmbedding(max_len, hidden_size)

    # Compute the output with positional encodings added
    output = positional_embedding(x)

    # Print the input and output shapes
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)

# Run the test when this file is executed as a script
if __name__ == "__main__":
    test_positional_embedding()
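
In the rest of the post the two embeddings are always used together: token IDs are first mapped to vectors by TokenEmbedding, and PositionalEmbedding then adds the positional encodings on top. A minimal usage sketch, assuming the two classes defined above (the concrete sizes are placeholders chosen only for illustration):

import torch

# Hypothetical sizes, for illustration only
vocab_size, hidden_size, max_len = 10000, 512, 5000
token_embedding = TokenEmbedding(vocab_size, hidden_size)
positional_embedding = PositionalEmbedding(max_len, hidden_size)

tokens = torch.randint(0, vocab_size, (2, 4))       # (batch_size, seq_len) token IDs
x = positional_embedding(token_embedding(tokens))   # (batch_size, seq_len, hidden_size)
print(x.shape)  # torch.Size([2, 4, 512])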

Encoder Layer

A Transformer encoder layer is built from several sub-layers: multi-head attention (Multi-Head Attention), a feed-forward network (Feed Forward Neural Network), and layer normalization (Layer Normalization), with residual connections around each sub-layer.

[Figure: Transformer architecture diagram]
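
The EncoderLayer and DecoderLayer below reuse the MultiHeadAttention module implemented in the previous chapter, so its code is not repeated here. For reference, the following is only a minimal sketch of an interface compatible with the calls in this post; the signature forward(query, key_value=None, attention_mask=None) and the convention that mask entries equal to 0 / False block attention are assumptions, not the previous chapter's exact code. To keep the calls unambiguous under this interface, the self-attention masks below are passed by keyword.

import math
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)    # query projection
        self.k_proj = nn.Linear(hidden_size, hidden_size)    # key projection
        self.v_proj = nn.Linear(hidden_size, hidden_size)    # value projection
        self.out_proj = nn.Linear(hidden_size, hidden_size)  # output projection

    def forward(self, query, key_value=None, attention_mask=None):
        # Self-attention when no separate key/value tensor is given
        if key_value is None:
            key_value = query
        batch_size, q_len, hidden_size = query.shape
        kv_len = key_value.size(1)

        # Project and split into heads: (batch, heads, len, head_dim)
        q = self.q_proj(query).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key_value).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(key_value).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, heads, q_len, kv_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Positions where the mask is 0 / False are blocked
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        # Weighted sum over values, then merge the heads back: (batch, q_len, hidden_size)
        context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, q_len, hidden_size)
        return self.out_proj(context)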

The implementation is as follows:

import torch
from torch import nn

class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, ff_size, dropout_prob=0.1):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(hidden_size, num_heads)  # multi-head self-attention
        self.dropout1 = nn.Dropout(dropout_prob)      # dropout
        self.layer_norm1 = nn.LayerNorm(hidden_size)  # LayerNorm

        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, ff_size),  # feed-forward layer 1
            nn.ReLU(),                        # activation
            nn.Linear(ff_size, hidden_size)   # feed-forward layer 2
        )
        self.dropout2 = nn.Dropout(dropout_prob)      # dropout
        self.layer_norm2 = nn.LayerNorm(hidden_size)  # LayerNorm

    def forward(self, x, attention_mask=None):
        # Multi-head self-attention sub-layer (mask passed by keyword so it is not mistaken for key/value input)
        attn_output = self.multi_head_attention(x, attention_mask=attention_mask)  # (batch_size, seq_len, hidden_size)
        attn_output = self.dropout1(attn_output)   # dropout
        out1 = self.layer_norm1(x + attn_output)   # residual connection + LayerNorm

        # Feed-forward sub-layer
        ff_output = self.feed_forward(out1)        # (batch_size, seq_len, hidden_size)
        ff_output = self.dropout2(ff_output)       # dropout
        out2 = self.layer_norm2(out1 + ff_output)  # residual connection + LayerNorm

        return out2

# main function to test EncoderLayer
def main():
    batch_size = 2
    seq_len = 4
    hidden_size = 512
    num_heads = 8
    ff_size = 2048

    # Randomly generate input data of shape (batch_size, seq_len, hidden_size)
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Create the EncoderLayer module
    encoder_layer = EncoderLayer(hidden_size, num_heads, ff_size)

    # Compute the EncoderLayer output
    output = encoder_layer(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    main()

Decoder Layer

A Transformer decoder layer has a structure similar to the encoder layer, but besides multi-head self-attention and the feed-forward network it adds a multi-head attention sub-layer for encoder-decoder (cross) attention. This lets the decoder attend both to the context of the output sequence generated so far and to the encoded input sequence. Multiple decoder layers are stacked to form the full Decoder module, and the decoder's self-attention mask is typically a causal mask, as sketched below.
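
For reference, a minimal sketch of building such a causal (subsequent-position) mask, so that position i can only attend to positions up to i. The shape convention assumes the MultiHeadAttention sketch above, where entries equal to 0 / False are blocked and the mask broadcasts over attention scores of shape (batch, heads, q_len, kv_len):

import torch

def make_causal_mask(seq_len):
    # Lower-triangular boolean matrix: True where attention is allowed
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Add batch and head dimensions so the mask broadcasts against the attention scores
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)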

The implementation is as follows:

import torch
from torch import nn

class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, ff_size, dropout_prob=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(hidden_size, num_heads)  # masked self-attention
        self.dropout1 = nn.Dropout(dropout_prob)      # dropout
        self.layer_norm1 = nn.LayerNorm(hidden_size)  # LayerNorm

        self.encoder_decoder_attention = MultiHeadAttention(hidden_size, num_heads)  # encoder-decoder (cross) attention
        self.dropout2 = nn.Dropout(dropout_prob)      # dropout
        self.layer_norm2 = nn.LayerNorm(hidden_size)  # LayerNorm

        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, ff_size),  # feed-forward layer 1
            nn.ReLU(),                        # activation
            nn.Linear(ff_size, hidden_size)   # feed-forward layer 2
        )
        self.dropout3 = nn.Dropout(dropout_prob)      # dropout
        self.layer_norm3 = nn.LayerNorm(hidden_size)  # LayerNorm

    def forward(self, x, encoder_output, self_attention_mask=None, encoder_attention_mask=None):
        # Self-attention sub-layer (mask passed by keyword, as in the encoder)
        self_attn_output = self.self_attention(x, attention_mask=self_attention_mask)  # (batch_size, seq_len, hidden_size)
        self_attn_output = self.dropout1(self_attn_output)  # dropout
        out1 = self.layer_norm1(x + self_attn_output)        # residual connection + LayerNorm

        # Encoder-decoder attention sub-layer: queries from the decoder, keys/values from the encoder output
        enc_dec_attn_output = self.encoder_decoder_attention(out1, encoder_output, encoder_attention_mask)  # (batch_size, seq_len, hidden_size)
        enc_dec_attn_output = self.dropout2(enc_dec_attn_output)  # dropout
        out2 = self.layer_norm2(out1 + enc_dec_attn_output)        # residual connection + LayerNorm

        # Feed-forward sub-layer
        ff_output = self.feed_forward(out2)        # (batch_size, seq_len, hidden_size)
        ff_output = self.dropout3(ff_output)       # dropout
        out3 = self.layer_norm3(out2 + ff_output)  # residual connection + LayerNorm

        return out3

# main function to test DecoderLayer
def main():
    batch_size = 2
    seq_len = 4
    hidden_size = 512
    num_heads = 8
    ff_size = 2048

    # Randomly generate input data of shape (batch_size, seq_len, hidden_size)
    x = torch.randn(batch_size, seq_len, hidden_size)
    encoder_output = torch.randn(batch_size, seq_len, hidden_size)

    # Create the DecoderLayer module
    decoder_layer = DecoderLayer(hidden_size, num_heads, ff_size)

    # Compute the DecoderLayer output
    output = decoder_layer(x, encoder_output)

    print("Input shape:", x.shape)
    print("Encoder output shape:", encoder_output.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    main()

Stacked Encoder

Multiple encoder layers are stacked together to form the full Encoder module.

The implementation is as follows:

import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, hidden_size, num_heads, ff_size, num_layers, dropout_prob=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(hidden_size, num_heads, ff_size, dropout_prob)
            for _ in range(num_layers)
        ])  # stack of EncoderLayer modules

    def forward(self, x, attention_mask=None):
        for layer in self.layers:
            x = layer(x, attention_mask)  # pass the input through the layers one by one

        return x

# main function to test Encoder
def main():
    batch_size = 2
    seq_len = 4
    hidden_size = 512
    num_heads = 8
    ff_size = 2048
    num_layers = 6

    # Randomly generate input data of shape (batch_size, seq_len, hidden_size)
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Create the Encoder module
    encoder = Encoder(hidden_size, num_heads, ff_size, num_layers)

    # Compute the Encoder output
    output = encoder(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    main()

Stacked Decoder

Likewise, multiple decoder layers are stacked together to form the full Decoder module. The implementation is as follows:

import torch
from torch import nn

class Decoder(nn.Module):
    def __init__(self, hidden_size, num_heads, ff_size, num_layers, dropout_prob=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(hidden_size, num_heads, ff_size, dropout_prob)
            for _ in range(num_layers)
        ])  # stack of DecoderLayer modules

    def forward(self, x, encoder_output, self_attention_mask=None, encoder_attention_mask=None):
        for layer in self.layers:
            x = layer(x, encoder_output, self_attention_mask, encoder_attention_mask)  # pass the input through the layers one by one

        return x

# main function to test Decoder
def main():
    batch_size = 2
    seq_len = 4
    hidden_size = 512
    num_heads = 8
    ff_size = 2048
    num_layers = 6

    # Randomly generate input data of shape (batch_size, seq_len, hidden_size)
    x = torch.randn(batch_size, seq_len, hidden_size)
    encoder_output = torch.randn(batch_size, seq_len, hidden_size)

    # Create the Decoder module
    decoder = Decoder(hidden_size, num_heads, ff_size, num_layers)

    # Compute the Decoder output
    output = decoder(x, encoder_output)

    print("Input shape:", x.shape)
    print("Encoder output shape:", encoder_output.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    main()

Transformer

The Transformer consists of three main parts: the input embeddings (token embedding plus positional embedding), the encoder stack, and the decoder stack. Below, these pieces are assembled into a complete Transformer class.

import torch
from torch import nn

class Transformer(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_heads, ff_size, num_layers, max_seq_len, dropout_prob=0.1):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, hidden_size)              # token embedding
        self.positional_embedding = PositionalEmbedding(max_seq_len, hidden_size)   # positional embedding (note: max_len comes first)

        self.encoder = Encoder(hidden_size, num_heads, ff_size, num_layers, dropout_prob)  # encoder stack
        self.decoder = Decoder(hidden_size, num_heads, ff_size, num_layers, dropout_prob)  # decoder stack

        self.output_linear = nn.Linear(hidden_size, vocab_size)  # output projection

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_tgt_mask=None):
        # Embed the source and target token IDs, then add positional encodings
        src_emb = self.positional_embedding(self.token_embedding(src))  # (batch_size, src_seq_len, hidden_size)
        tgt_emb = self.positional_embedding(self.token_embedding(tgt))  # (batch_size, tgt_seq_len, hidden_size)

        # Run the encoder
        encoder_output = self.encoder(src_emb, src_mask)  # (batch_size, src_seq_len, hidden_size)

        # Run the decoder
        decoder_output = self.decoder(tgt_emb, encoder_output, tgt_mask, src_tgt_mask)  # (batch_size, tgt_seq_len, hidden_size)

        # Project to vocabulary size
        output = self.output_linear(decoder_output)  # (batch_size, tgt_seq_len, vocab_size)

        return output

# main function to test the Transformer
def main():
    vocab_size = 10000
    hidden_size = 512
    num_heads = 8
    ff_size = 2048
    num_layers = 6
    max_seq_len = 100
    dropout_prob = 0.1

    batch_size = 2
    src_seq_len = 10
    tgt_seq_len = 10

    # Randomly generate source and target sequences of shape (batch_size, seq_len)
    src = torch.randint(0, vocab_size, (batch_size, src_seq_len))
    tgt = torch.randint(0, vocab_size, (batch_size, tgt_seq_len))

    # Create the Transformer module
    transformer = Transformer(vocab_size, hidden_size, num_heads, ff_size, num_layers, max_seq_len, dropout_prob)

    # Compute the Transformer output
    output = transformer(src, tgt)

    print("Source shape:", src.shape)
    print("Target shape:", tgt.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    main()
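
The test above only checks shapes. For completeness, here is a hypothetical sketch of a single training step with teacher forcing, reusing transformer, src, tgt, and vocab_size from main() above; pad_id, the Adam hyperparameters, the shift-by-one target convention, and the make_causal_mask helper from the Decoder section are all assumptions made for illustration, not part of the original post:

pad_id = 0  # hypothetical padding token id
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4)

decoder_input = tgt[:, :-1]  # the decoder sees tokens up to position n-1
labels = tgt[:, 1:]          # and is trained to predict the next token

# Causal mask for decoder self-attention (see the make_causal_mask sketch in the Decoder section)
tgt_mask = make_causal_mask(decoder_input.size(1))

logits = transformer(src, decoder_input, tgt_mask=tgt_mask)  # (batch_size, tgt_seq_len - 1, vocab_size)
loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))

loss.backward()
optimizer.step()
optimizer.zero_grad()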
