module.py
import torch
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward sublayer with residual connection and post-layer-norm."""

    def __init__(self, d_model, d_inner, dropout):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_inner)
        self.w_2 = nn.Linear(d_inner, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        self.activate = nn.ReLU()

    def forward(self, x):
        residual = x
        # FFN(x) = W_2(ReLU(W_1(x))), with dropout, a residual add, and layer norm.
        x = self.dropout(self.w_2(self.activate(self.w_1(x))))
        return self.layer_norm(residual + x)


class SelfAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / temperature + mask) V."""

    def __init__(self, temperature, dropout):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask):
        # Scores scaled by the temperature (sqrt(d_k) in MultiHeadedAttention).
        attn = torch.matmul(query, key.transpose(-2, -1)) / self.temperature
        # The mask is additive: 0 keeps a position, a large negative value blocks it.
        attn = attn + mask
        p_attn = self.dropout(self.softmax(attn))
        return torch.matmul(p_attn, value), p_attn
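

# Illustrative helper (an assumption added for clarity; not referenced by the classes
# in this module): builds the kind of additive mask SelfAttention expects, here a
# causal mask in which allowed positions are 0 and blocked positions are -inf.
def example_causal_additive_mask(seq_len, device=None):
    # Strictly upper-triangular positions (future tokens) are blocked.
    blocked = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
    mask = torch.zeros(seq_len, seq_len, device=device).masked_fill(blocked, float("-inf"))
    # Shape (1, 1, seq_len, seq_len) broadcasts over batch and heads inside SelfAttention.
    return mask.unsqueeze(0).unsqueeze(0)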


class MultiHeadedAttention(nn.Module):
    """Multi-head attention with a residual connection and post-layer-norm."""

    def __init__(self, n_heads, d_model, dropout):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.d_v = self.d_k
        self.w_Q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.w_K = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.w_V = nn.Linear(d_model, n_heads * self.d_v, bias=False)
        self.fc = nn.Linear(n_heads * self.d_v, d_model, bias=False)
        self.self_attention = SelfAttention(temperature=self.d_k ** 0.5, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, query, key, value, mask):
        sz_b, len_q, len_k, len_v = query.size(0), query.size(1), key.size(1), value.size(1)
        residual = query
        # Project and split into heads: (batch, len, d_model) -> (batch, n_heads, len, d_k).
        q = self.w_Q(query).view(sz_b, len_q, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_K(key).view(sz_b, len_k, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_V(value).view(sz_b, len_v, self.n_heads, self.d_v).transpose(1, 2)
        x, attn = self.self_attention(q, k, v, mask=mask)
        # Merge heads back: (batch, n_heads, len_q, d_v) -> (batch, len_q, d_model).
        x = x.transpose(1, 2).contiguous().view(sz_b, len_q, self.d_model)
        x = self.dropout(self.fc(x))
        return self.layer_norm(residual + x)


class TransformerBlock(nn.Module):
    """One encoder layer: multi-head self-attention followed by the feed-forward sublayer."""

    def __init__(self, d_model, n_heads, d_inner, dropout):
        super().__init__()
        self.multi_head_attention = MultiHeadedAttention(n_heads=n_heads, d_model=d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=d_model, d_inner=d_inner, dropout=dropout)

    def forward(self, block_input, mask):
        output = self.multi_head_attention(block_input, block_input, block_input, mask)
        return self.feed_forward(output)


class TransformerEncoder(torch.nn.Module):
    """Stack of TransformerBlocks over pre-computed input embeddings plus learned position embeddings."""

    def __init__(self, n_vocab, n_position, d_model, n_heads, dropout, n_layers):
        super(TransformerEncoder, self).__init__()
        # n_vocab is currently unused: the word embedding below is commented out and
        # forward() receives pre-computed input embeddings instead of token ids.
        # self.word_embedding = nn.Embedding(n_vocab + 1, d_model, padding_idx=0)
        self.position_embedding = nn.Embedding(n_position, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model=d_model, n_heads=n_heads, d_inner=d_model * 4, dropout=dropout)
             for _ in range(n_layers)])

    def forward(self, input_embs, log_mask, att_mask):
        # Position ids 0..L-1, expanded to the batch shape of log_mask.
        position_ids = torch.arange(log_mask.size(1), dtype=torch.long, device=log_mask.device)
        position_ids = position_ids.unsqueeze(0).expand_as(log_mask)
        output = self.layer_norm(input_embs + self.position_embedding(position_ids))
        output = self.dropout(output)
        for transformer in self.transformer_blocks:
            output = transformer(output, att_mask)
        return output
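

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the shapes and hyper-parameters below
    # are assumptions chosen for the example, not values from the original code.
    batch_size, seq_len, d_model, n_heads = 2, 10, 64, 4

    encoder = TransformerEncoder(
        n_vocab=1000,        # unused while word_embedding stays commented out
        n_position=seq_len,
        d_model=d_model,
        n_heads=n_heads,
        dropout=0.1,
        n_layers=2,
    )

    # Pre-computed item embeddings and a padding log_mask (1 = real position, 0 = pad).
    input_embs = torch.randn(batch_size, seq_len, d_model)
    log_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # Causal additive mask from the illustrative helper above; it broadcasts to
    # (batch, n_heads, seq_len, seq_len) inside SelfAttention.
    att_mask = example_causal_additive_mask(seq_len)

    out = encoder(input_embs, log_mask, att_mask)
    print(out.shape)  # torch.Size([2, 10, 64])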