import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, dropout_rate, head_num):
        super().__init__()
        assert hidden_dim % head_num == 0, "hidden_dim must be divisible by head_num"
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # One projection per role; each maps hidden_dim -> head_num * head_dim.
        self.query_proj = nn.Linear(hidden_dim, head_num * self.head_dim)
        self.key_proj = nn.Linear(hidden_dim, head_num * self.head_dim)
        self.value_proj = nn.Linear(hidden_dim, head_num * self.head_dim)
        self.output_proj = nn.Linear(head_num * self.head_dim, hidden_dim)

        # The original stored dropout_rate but never used it; apply it to the
        # attention weights, as is conventional.
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, X, masked=None):
        # Bug fix: the original projected Q, K, and V all through query_proj.
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        batch_size, seq_size, _ = X.size()

        # (batch, seq, hidden) -> (batch, head_num, seq, head_dim)
        Q = Q.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, seq_size, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.
        # attention_weight: (batch, head_num, seq, seq)
        attention_weight = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.head_dim)
        if masked is not None:
            # Positions where the mask is 0 are excluded from attention.
            attention_weight = attention_weight.masked_fill(masked == 0, float('-inf'))
        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.dropout(attention_weight)

        # attention_value: (batch, head_num, seq, head_dim)
        attention_value = torch.matmul(attention_weight, V)
        # -> (batch, seq, head_num * head_dim)
        attention_value = attention_value.permute(0, 2, 1, 3).contiguous()
        attention_value = attention_value.view(batch_size, seq_size, self.head_num * self.head_dim)
        # Bug fix: the original called self.output_proj() without an argument.
        return self.output_proj(attention_value)
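
# --- Added example (not in the original snippet): building a causal mask. ---
# The forward pass above drops positions where the mask is 0, so a causal
# (autoregressive) mask is just a lower-triangular matrix of ones, expanded to
# the (batch, head_num, seq, seq) shape of the attention scores. The helper
# name is hypothetical.
def make_causal_mask(batch_size, head_num, seq_len):
    causal = torch.tril(torch.ones(seq_len, seq_len))
    return causal.unsqueeze(0).unsqueeze(0).expand(batch_size, head_num, seq_len, seq_len)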

n = torch.rand(2, 3, 4)  # (batch=2, seq=3, hidden=4)
s = MultiHeadAttention(4, dropout_rate=0, head_num=2)
# 1 = attend, 0 = masked out; expand to (batch, head_num, seq, seq).
mask = torch.tensor(
    [
        [1, 1, 1],
        [1, 0, 1],
        [1, 0, 1],
    ]
)
mask = mask.unsqueeze(0).unsqueeze(0).expand(2, 2, 3, 3)
print(mask.size())        # torch.Size([2, 2, 3, 3])
print(s(n, mask).size())  # torch.Size([2, 3, 4])
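
# --- Added sanity check (not in the original snippet; assumes PyTorch >= 2.0). ---
# F.scaled_dot_product_attention computes the same softmax(QK^T / sqrt(d)) V,
# so reusing the module's own projection weights should reproduce its output.
import torch.nn.functional as F

with torch.no_grad():
    Qr = s.query_proj(n).view(2, 3, 2, 2).permute(0, 2, 1, 3)
    Kr = s.key_proj(n).view(2, 3, 2, 2).permute(0, 2, 1, 3)
    Vr = s.value_proj(n).view(2, 3, 2, 2).permute(0, 2, 1, 3)
    # A boolean attn_mask keeps the positions that are True.
    ref = F.scaled_dot_product_attention(Qr, Kr, Vr, attn_mask=(mask == 1))
    ref = s.output_proj(ref.permute(0, 2, 1, 3).reshape(2, 3, 4))
    print(torch.allclose(ref, s(n, mask), atol=1e-5))  # expect True (dropout_rate=0)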