Last time I covered four ways to implement self-attention; this post follows up with multi-head attention. The key difference between multi-head attention and self-attention is this: for an input X of shape (batch_size, seq_size, emb_size), self-attention keeps Q, K, and V at (batch_size, seq_size, emb_size), whereas multi-head attention splits emb_size into num_head heads of size head_dim each.


Multi-head attention runs several independent attention mechanisms in parallel, so that each head attends over a different subspace of the input sequence. This lets the model capture a wider range of the semantic relationships latent in the sequence.
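
As a minimal sketch of the shape manipulation described above (the variable names and sizes here are illustrative only, not part of the implementation below), the split from emb_size into (num_head, head_dim) is just a view followed by a permute:

import torch

batch_size, seq_size, emb_size = 2, 3, 4   # illustrative sizes
num_head = 2
head_dim = emb_size // num_head            # 4 // 2 = 2

X = torch.rand(batch_size, seq_size, emb_size)
# split the embedding dimension into (num_head, head_dim), then move the head axis forward
X_heads = X.view(batch_size, seq_size, num_head, head_dim).permute(0, 2, 1, 3)
print(X_heads.size())  # torch.Size([2, 2, 3, 2]) -> (batch_size, num_head, seq_size, head_dim)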

The full implementation is shown below:

import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, dropout_rate, head_num):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.head_num = head_num
        self.head_dim = self.hidden_dim // self.head_num
        # separate projections for Q, K, V; head_num * head_dim == hidden_dim
        self.query_proj = nn.Linear(self.hidden_dim, self.head_num * self.head_dim)
        self.key_proj = nn.Linear(self.hidden_dim, self.head_num * self.head_dim)
        self.value_proj = nn.Linear(self.hidden_dim, self.head_num * self.head_dim)
        self.output_proj = nn.Linear(self.head_num * self.head_dim, self.head_num * self.head_dim)
        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, X, masked=None):
        # X: (batch_size, seq_size, hidden_dim)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        batch_size, seq_size, _ = X.size()
        # (batch_size, seq_size, hidden_dim) -> (batch_size, head_num, seq_size, head_dim)
        Q = Q.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)

        # K is transposed to (batch_size, head_num, head_dim, seq_size) so that
        # Q @ K^T has shape (batch_size, head_num, seq_size, seq_size)
        attention_weight = torch.matmul(
            Q, K.permute(0, 1, 3, 2)
        ) / math.sqrt(self.head_dim)
        print(attention_weight.size())
        if masked is not None:
            # positions where the mask is 0 are excluded from attention
            attention_weight = attention_weight.masked_fill(
                masked == 0,
                float('-inf')
            )
        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.dropout(attention_weight)
        attention_value = torch.matmul(attention_weight, V)  # (batch_size, head_num, seq_size, head_dim)
        print(attention_value.size())
        # merge the heads back into (batch_size, seq_size, head_num * head_dim)
        attention_value = attention_value.permute(0, 2, 1, 3).contiguous()
        attention_value = attention_value.view(batch_size, seq_size, self.head_num * self.head_dim)
        attention_value = self.output_proj(attention_value)
        return attention_value


# test
n = torch.rand(2, 3, 4)
s = MultiHeadAttention(4, dropout_rate=0, head_num=2)

mask = torch.tensor(
    [
        [1, 1, 1],
        [1, 0, 1],
        [1, 0, 1],
    ]
)
# broadcast the (seq_size, seq_size) mask to (batch_size, head_num, seq_size, seq_size)
mask = mask.unsqueeze(0).unsqueeze(0).expand(2, 2, 3, 3)
print(mask.size())  # torch.Size([2, 2, 3, 3])
s(n, mask)
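
Running this test, the two print calls inside forward report the attention weights at torch.Size([2, 2, 3, 3]) and the per-head attention values at torch.Size([2, 2, 3, 2]); the final output returned by s(n, mask) has shape (2, 3, 4), the same as the input.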