With the recent DeepSeek wave, LLMs have been pushed to new heights. While studying them I have developed my own take on attention, so I'm taking this opportunity to try hand-writing attention from scratch again.

Seeing the Mountain as a Mountain: Attention

The attention formula, familiar to everyone from the paper "Attention Is All You Need", is:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V

Reproducing this is fairly straightforward: we can build an Attention class that follows the formula directly.

import torch
import torch.nn as nn
import math

class AttentionClass(nn.Module):
    def __init__(self, hidden_dim: int = 768) -> None:
        super().__init__()
        # initialize the Q, K, V linear projections
        self.hidden_dim = hidden_dim
        # X: (batch_size, seq_len, hidden_dim)
        self.query_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.key_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.value_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, X):
        # X.shape is (batch_size, seq_len, hidden_dim)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # (batch_size, seq_len, seq_len) attention scores
        attention_value = torch.matmul(Q, K.permute(0, 2, 1))
        attention_weight = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim),
            dim=2
        )

        # (batch_size, seq_len, hidden_dim)
        attention = torch.matmul(attention_weight, V)
        return attention

# test
n = torch.rand(2, 3, 4)
s = AttentionClass(4)  # hidden_dim must match the input's last dimension
s(n)
'''
tensor([[[ 0.1558, -0.5328,  0.5522, -0.0879],
         [ 0.1612, -0.5310,  0.5441, -0.0869],
         [ 0.1572, -0.5307,  0.5501, -0.0876]],

        [[ 0.1985, -0.6046,  0.5730,  0.0433],
         [ 0.1986, -0.6053,  0.5728,  0.0442],
         [ 0.1987, -0.6045,  0.5731,  0.0433]]], grad_fn=<UnsafeViewBackward0>)
'''

The first level is simply implementing the formula as code. The point to watch is dividing the attention scores by the square root of hidden_dim before the softmax: without this scaling the dot products grow with the dimension, the softmax saturates, and the gradients flowing through attention_weight either blow up or vanish.
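To make the effect of the scaling concrete, here is a minimal sketch comparing the softmax of random dot-product scores with and without dividing by sqrt(d_k); with unit-variance Q and K, each score is a sum of d_k terms, so the unscaled scores have variance around d_k and push the softmax into its saturated, near-one-hot regime.

import torch
import math

torch.manual_seed(0)
d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(8, d_k)           # 8 key positions

scores = q @ k.T                  # shape (1, 8)
unscaled = torch.softmax(scores, dim=-1)
scaled = torch.softmax(scores / math.sqrt(d_k), dim=-1)

print(unscaled.max())  # close to 1.0: nearly one-hot, gradients almost vanish
print(scaled.max())    # much flatter distribution, gradients stay healthy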

Mountains Give Way to Rivers: Efficiency Optimization

The point of this optimization is to make the code more concise and to cut down on separate internal operations. Notice that in the class above, Q, K, and V are each computed with their own Linear layer; this can be simplified by fusing them into a single Linear that produces all three at once and then splitting the result with torch.split().
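As a quick sketch of the splitting step: torch.split with a split size of emb_dim along the last dimension turns one (batch, seq, 3 * emb_dim) tensor into three (batch, seq, emb_dim) tensors.

import torch

emb_dim = 4
fused = torch.rand(2, 3, emb_dim * 3)          # output of the fused QKV Linear
Q, K, V = torch.split(fused, emb_dim, dim=-1)  # three chunks along the last dim
print(Q.shape, K.shape, V.shape)               # each is torch.Size([2, 3, 4])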

import torch
import math
import torch.nn as nn

class SelfAttention_V2(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        # one fused projection produces Q, K and V in a single matmul
        self.QKV = nn.Linear(self.emb_dim, self.emb_dim * 3)

    def forward(self, X):
        # X: (batch_size, seq_len, emb_dim)
        QKV = self.QKV(X)
        # split the last dimension back into Q, K, V
        Q, K, V = torch.split(QKV, self.emb_dim, dim=-1)
        attention_weight = torch.softmax(
            torch.matmul(Q, K.permute(0, 2, 1)) / math.sqrt(self.emb_dim),
            dim=-1
        )
        attention = torch.matmul(attention_weight, V)
        return attention

# test
n = torch.rand(2, 3, 4)
s = SelfAttention_V2(4)
s(n)
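Note that the fused projection has exactly the same number of parameters as three separate Linear layers; the saving is doing one matrix multiplication instead of three. A small check of that claim, as a sketch:

import torch.nn as nn

emb_dim = 4
separate = [nn.Linear(emb_dim, emb_dim) for _ in range(3)]
fused = nn.Linear(emb_dim, emb_dim * 3)

n_separate = sum(p.numel() for layer in separate for p in layer.parameters())
n_fused = sum(p.numel() for p in fused.parameters())
print(n_separate, n_fused)  # both are 3 * (emb_dim * emb_dim + emb_dim) = 60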

Careful Polishing

In the two levels above we reproduced and simplified self-attention, but several details have not yet been polished in:
1. Dropout. Self-attention is an extremely strong fitter and therefore overfits easily, so dropout is an important ingredient for improving the generalization of the whole model.
2. Mask. Masking is an essential part of self-attention; its main job is to block out the positions that should not be attended to (see the causal-mask sketch after the code below).
3. Output projection. After the self-attention computation, an extra linear layer is applied to the attention result to project it.

import torch
import torch.nn as nn
import math

class SelfAttentionV3(nn.Module):
    def __init__(self, dim, dropout_rate=0.1, output_proj=None, **kwargs) -> None:
        super().__init__()
        self.dim = dim

        # fused QKV projection
        self.proj = nn.Linear(self.dim, self.dim * 3)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None else None
        self.output_proj = nn.Linear(dim, dim) if output_proj is not None else None

    def forward(self, X, masked=None):
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, -1)

        # (batch, seq, seq) attention scores
        attention_score = torch.matmul(Q, K.permute(0, 2, 1)) / math.sqrt(self.dim)
        # mask before the softmax so that blocked positions get exactly zero weight
        if masked is not None:
            attention_score = attention_score.masked_fill(
                masked == 0,
                float('-inf')
            )
        attention_weight = torch.softmax(attention_score, -1)
        if self.dropout is not None:
            attention_weight = self.dropout(attention_weight)

        attention_result = torch.matmul(attention_weight, V)
        if self.output_proj is not None:
            attention_result = self.output_proj(attention_result)
        return attention_result

# test
n = torch.rand(2, 3, 4)
s = SelfAttentionV3(4, output_proj=True)
mask = torch.tensor(
    [
        [0, 1, 1],
        [0, 0, 1],
    ]
)
# expand the (batch, seq) padding mask to (batch, seq, seq)
mask = mask.unsqueeze(1).repeat(1, 3, 1)

s(n, mask)
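The test above uses a padding-style mask. In decoder-style models the same masked_fill pattern is usually driven by a causal (lower-triangular) mask instead, so each position can only attend to itself and earlier positions. A minimal sketch of building one with torch.tril, assuming the SelfAttentionV3 class above:

import torch

seq_len = 3
# lower-triangular matrix: position i may attend to positions 0..i
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))
print(causal)
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]])

# broadcast to (batch, seq, seq) for SelfAttentionV3
causal_mask = causal.unsqueeze(0).repeat(2, 1, 1)
out = SelfAttentionV3(4)(torch.rand(2, 3, 4), causal_mask)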

Interviews Still Want the Textbook Version

The self-attention you are asked to hand-write in interviews usually looks like this:

import torch
import torch.nn as nn

class SelfAttentionInterview(nn.Module):
    def __init__(self, emb_dim, dropout_rate=None, **kwargs):
        super().__init__()
        self.emb_dim = emb_dim
        self.query_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.key_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.value_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.dropout_rate = dropout_rate
        if self.dropout_rate is not None:
            self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, X, masked=None):
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # (batch, seq, seq) scaled dot-product scores
        attention_weight = Q @ K.permute(0, 2, 1) / torch.sqrt(torch.tensor(self.emb_dim))

        # mask before the softmax so that masked positions get zero weight
        if masked is not None:
            attention_weight = attention_weight.masked_fill(
                masked == 0,
                float('-inf')
            )
        attention_weight = torch.softmax(attention_weight, -1)

        if self.dropout_rate:
            attention_weight = self.dropout(attention_weight)

        attention_result = attention_weight @ V
        return attention_result

# test
n = torch.rand(2, 3, 4)
s = SelfAttentionInterview(4)
mask = torch.tensor(
    [
        [0, 1, 1],
        [0, 0, 1],
    ]
)
mask = mask.unsqueeze(1).repeat(1, 3, 1)

s(n, mask)
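As a final sanity check, assuming PyTorch 2.x, the scaled dot-product core that all of the classes above share can be compared against the built-in torch.nn.functional.scaled_dot_product_attention (no mask, no dropout); the two should agree up to floating-point tolerance.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.rand(2, 3, 4)
K = torch.rand(2, 3, 4)
V = torch.rand(2, 3, 4)

# hand-written scaled dot-product attention
weight = torch.softmax(Q @ K.permute(0, 2, 1) / torch.sqrt(torch.tensor(4.0)), -1)
manual = weight @ V

# PyTorch 2.x built-in, which also scales by 1/sqrt(d_k) internally
builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True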