The Four Stages of Attention
With the recent DeepSeek wave, LLMs have been pushed to new heights. While studying them, my own understanding of attention has kept evolving, so I'm taking this opportunity to hand-write attention from scratch again.
Seeing the Mountain as a Mountain: Attention
The attention formula, familiar to everyone from the paper Attention Is All You Need, is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
Reproducing it is fairly easy: we can build an Attention class that follows the formula directly.
```python
import torch
import torch.nn as nn
import math


class AttentionClass(nn.Module):
    def __init__(self, hidden_dim: int = 768) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # separate projections for query, key and value
        self.query_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.key_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.value_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, X):
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        # raw scores: (batch, seq_len, seq_len)
        attention_value = torch.matmul(Q, K.permute(0, 2, 1))
        # scale by sqrt(hidden_dim) before the softmax over the key dimension
        attention_weight = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim), dim=-1
        )
        attention = torch.matmul(attention_weight, V)
        return attention


n = torch.rand(2, 3, 4)
s = AttentionClass(4)
s(n)
'''
tensor([[[ 0.1558, -0.5328,  0.5522, -0.0879],
         [ 0.1612, -0.5310,  0.5441, -0.0869],
         [ 0.1572, -0.5307,  0.5501, -0.0876]],

        [[ 0.1985, -0.6046,  0.5730,  0.0433],
         [ 0.1986, -0.6053,  0.5728,  0.0442],
         [ 0.1987, -0.6045,  0.5731,  0.0433]]], grad_fn=<UnsafeViewBackward0>)
'''
```
The first stage is simply turning the formula into code. The one point worth noting is the division by the square root of hidden_dim when computing attention_weight: as the dimension grows, the raw dot products become large, the softmax saturates, and the gradients flowing through attention_weight either become too large or vanish.
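To make the effect of the scaling concrete, here is a minimal sketch of my own (the dimension 512 and the eight random keys are arbitrary choices, not from the original example): without dividing by the square root of the dimension, the dot products grow with d_k and the softmax collapses toward a one-hot distribution, which is exactly the regime where its gradients vanish.

```python
import torch
import math

torch.manual_seed(0)
d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(8, d_k)

scores = q @ k.T                      # raw dot products, std roughly sqrt(d_k)
scaled = scores / math.sqrt(d_k)      # scaled dot products, std roughly 1

print(torch.softmax(scores, dim=-1))  # typically close to one-hot
print(torch.softmax(scaled, dim=-1))  # noticeably smoother distribution
```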
Mountains Become Rivers: Efficiency Optimization
The efficiency pass is mainly about making the code more concise and trimming redundant computation. Notice that in the class above, Q, K and V are each produced by a separate Linear layer; this can be fused into a single projection, and the three tensors can then be recovered with torch.split().
```python
import torch
import math
import torch.nn as nn


class SelfAttention_V2(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        # one fused projection produces Q, K and V in a single matmul
        self.QKV = nn.Linear(self.emb_dim, self.emb_dim * 3)

    def forward(self, X):
        QKV = self.QKV(X)
        # split the last dimension back into Q, K, V
        Q, K, V = torch.split(QKV, self.emb_dim, dim=-1)
        attention_weight = torch.softmax(
            torch.matmul(Q, K.permute(0, 2, 1)) / math.sqrt(self.emb_dim), dim=-1
        )
        attention = torch.matmul(attention_weight, V)
        return attention


n = torch.rand(2, 3, 4)
s = SelfAttention_V2(4)
s(n)
```
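A quick sanity check of my own (it assumes the SelfAttention_V2 class above is in scope): the fused projection is mathematically identical to three separate Linear layers, because the first emb_dim rows of QKV.weight act as the query weight matrix, the next emb_dim rows as the key weights, and so on. Splitting the fused output therefore recovers exactly what a standalone query projection would have produced.

```python
import torch

emb_dim = 4
x = torch.rand(2, 3, emb_dim)
attn = SelfAttention_V2(emb_dim)

# QKV.weight has shape (3 * emb_dim, emb_dim); its first emb_dim rows are W_q
W, b = attn.QKV.weight, attn.QKV.bias
Q_fused, _, _ = torch.split(attn.QKV(x), emb_dim, dim=-1)
Q_manual = x @ W[:emb_dim].T + b[:emb_dim]

print(torch.allclose(Q_fused, Q_manual, atol=1e-6))  # True
```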
Polishing the Details
In the two stages above we reproduced and then simplified self-attention, but several details are still missing:

1. Dropout. Self-attention is an extremely strong function approximator and therefore overfits easily, so dropout is an important ingredient for improving the model's generalization.
2. Mask. The mask is a crucial part of self-attention; its job is to block positions that should not be attended to, such as padding tokens or future tokens in a decoder (a small causal-mask sketch follows this list).
3. Output projection. After the self-attention computation, a linear layer is applied to the attention result as an output projection.
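To make point 2 concrete, here is a small sketch of my own (it assumes the common decoder-style causal mask built with torch.tril, which is not the mask used in the example below) showing how filling blocked positions with -inf before the softmax drives their attention weights to exactly zero:

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)                   # raw attention scores
causal_mask = torch.tril(torch.ones(seq_len, seq_len))   # 1 = keep, 0 = block future

masked_scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)
print(weights)  # strictly upper-triangular entries are 0: no attention to the future
```

With these three pieces in place, the refined implementation looks like this: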
```python
import torch
import torch.nn as nn
import math


class SelfAttentionV3(nn.Module):
    def __init__(self, dim, dropout_rate=0.1, output_proj=None, **kwargs) -> None:
        super().__init__()
        self.dim = dim
        # fused QKV projection
        self.proj = nn.Linear(self.dim, self.dim * 3)
        # define both attributes unconditionally so forward() can test them safely
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None else None
        self.output_proj = nn.Linear(dim, dim) if output_proj is not None else None

    def forward(self, X, masked=None):
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)
        attention_score = torch.matmul(Q, K.permute(0, 2, 1)) / math.sqrt(self.dim)
        if masked is not None:
            # mask before the softmax so blocked positions end up with zero weight
            attention_score = attention_score.masked_fill(masked == 0, float('-inf'))
        attention_weight = torch.softmax(attention_score, dim=-1)
        if self.dropout is not None:
            attention_weight = self.dropout(attention_weight)
        attention_result = torch.matmul(attention_weight, V)
        if self.output_proj is not None:
            attention_result = self.output_proj(attention_result)
        return attention_result


n = torch.rand(2, 3, 4)
s = SelfAttentionV3(4, output_proj=True)
mask = torch.tensor(
    [
        [0, 1, 1],
        [0, 0, 1],
    ]
)
mask = mask.unsqueeze(1).repeat(1, 3, 1)  # (batch, seq_len, seq_len)
s(n, mask)
```
Interviews Still Demand the Boilerplate
The self-attention you are asked to hand-write in interviews usually looks like this:
```python
import math
import torch
import torch.nn as nn


class SelfAttentionInterview(nn.Module):
    def __init__(self, emb_dim, dropout_rate=None, **kwargs):
        super().__init__()
        self.emb_dim = emb_dim
        self.query_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.key_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.value_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.dropout_rate = dropout_rate
        if self.dropout_rate is not None:
            self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, X, masked=None):
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        # scaled dot-product scores: (batch, seq_len, seq_len)
        attention_weight = Q @ K.permute(0, 2, 1) / math.sqrt(self.emb_dim)
        if masked is not None:
            # -inf before softmax gives exactly zero weight on blocked positions
            attention_weight = attention_weight.masked_fill(masked == 0, float('-inf'))
        attention_weight = torch.softmax(attention_weight, dim=-1)
        if self.dropout_rate:
            attention_weight = self.dropout(attention_weight)
        return attention_weight @ V


n = torch.rand(2, 3, 4)
s = SelfAttentionInterview(4)
mask = torch.tensor(
    [
        [0, 1, 1],
        [0, 0, 1],
    ]
)
mask = mask.unsqueeze(1).repeat(1, 3, 1)
s(n, mask)
```
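Finally, a sanity check of my own that is not part of the usual interview answer: PyTorch 2.0 and later ship torch.nn.functional.scaled_dot_product_attention, so the hand-written core can be verified against the built-in kernel on the same Q, K and V (the shapes below are arbitrary).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)
K = torch.randn(2, 3, 4)
V = torch.randn(2, 3, 4)

# hand-written scaled dot-product attention
weights = torch.softmax(Q @ K.permute(0, 2, 1) / math.sqrt(Q.shape[-1]), dim=-1)
manual = weights @ V

# built-in reference implementation (requires PyTorch >= 2.0)
reference = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, reference, atol=1e-6))  # True
```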