海南城乡住房建设厅网站,海南省建设人力资源网站,wordpress首页弹窗你,学seo优化本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上#xff0c;早在一年前博主就已经分别介绍了两者#xff1a;各种注意力机制的PyTorch实现、从零开始手写一个Transformer#xff0c;但当时的实现大部分是基于d2l教程的#xff0c;这次将基于PyTorch…本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上早在一年前博主就已经分别介绍了两者各种注意力机制的PyTorch实现、从零开始手写一个Transformer但当时的实现大部分是基于d2l教程的这次将基于PyTorch源码重新实现一遍。 目录 1. MultiheadAttention1.1 思路1.2 源码1.3 极简版MHA面试用 2. Transformer3. QA1. MHA的参数量时间复杂度FLOPs  1. MultiheadAttention 
1.1 思路 
回顾多头注意力其公式如下 MHA ( Q , K , V )  Concat ( head 1 , ⋯ , head h ) W O head i  Attn ( Q W i Q , K W i K , V W i V ) \text{MHA}(Q,K,V)\text{Concat}(\text{head}_1,\cdots,\text{head}_h)W^O \\ \text{head}_i\text{Attn}(QW_i^Q,KW_i^K,VW_i^V) MHA(Q,K,V)Concat(head1,⋯,headh)WOheadiAttn(QWiQ,KWiK,VWiV) 
其中  W i Q ∈ R d m o d e l × d k W_i^Q\in \mathbb{R}^{d_{model}\times d_k} WiQ∈Rdmodel×dk W i K ∈ R d m o d e l × d k W_i^K\in \mathbb{R}^{d_{model}\times d_k} WiK∈Rdmodel×dk W i V ∈ R d m o d e l × d v W_i^V\in \mathbb{R}^{d_{model}\times d_v} WiV∈Rdmodel×dv W O ∈ R h d v × d m o d e l W^O\in \mathbb{R}^{hd_v\times d_{model}} WO∈Rhdv×dmodel且  d k  d v  d m o d e l / h d_kd_vd_{model}/h dkdvdmodel/h。 
如果记  d h e a d  d m o d e l / h d_{head}d_{model}/h dheaddmodel/h则  W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV 的形状均为  ( d m o d e l , d h e a d ) (d_{model},d_{head}) (dmodel,dhead) W O W^O WO 的形状为  ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel)。 
先不考虑batch和mask的情形在只有一个头的情况下 h  1 h1 h1MHA的计算方式为 
class MHA(nn.Module):def __init__(self, d_model):super().__init__()self.w_q  nn.Parameter(torch.empty(d_model, d_model))self.w_k  nn.Parameter(torch.empty(d_model, d_model))self.w_v  nn.Parameter(torch.empty(d_model, d_model))self.w_o  nn.Parameter(torch.empty(d_model, d_model))self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim()  1:nn.init.xavier_uniform_(p)def forward(self, query, key, value):Args:query: (n, d_model)n是query的个数m是key-value的个数key: (m, d_model)value: (m, d_model)q  query  self.w_qk  key  self.w_kv  value  self.w_vattn_logits  q  k.transpose(0, 1) / math.sqrt(q.size(1))  # attn_logits: (n, m)attn_probs  F.softmax(attn_logits, dim-1)attn_output  attn_probs  v  # attn_output: (n, d_model)return attn_output, attn_probs现在考虑  h  2 h2 h2 的情形此时一共需要  3 ⋅ 2  1  7 3\cdot217 3⋅217 个参数矩阵 
class MHA(nn.Module):def __init__(self, d_model):super().__init__()self.w_q_1  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_k_1  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_v_1  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_q_2  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_k_2  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_v_2  nn.Parameter(torch.empty(d_model, d_model // 2))self.w_o  nn.Parameter(torch.empty(d_model, d_model))self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim()  1:nn.init.xavier_uniform_(p)def forward(self, query, key, value):Args:query: (n, d_model)n是query的个数m是key-value的个数key: (m, d_model)value: (m, d_model)q_1  query  self.w_q_1k_1  key  self.w_k_1v_1  value  self.w_v_1q_2  query  self.w_q_2k_2  key  self.w_k_2v_2  value  self.w_v_2attn_logits_1  q_1  k_1.transpose(0, 1) / math.sqrt(q_1.size(1))attn_probs_1  F.softmax(attn_logits_1, dim-1)attn_output_1  attn_probs_1  v_1attn_logits_2  q_2  k_2.transpose(0, 1) / math.sqrt(q_2.size(1))attn_probs_2  F.softmax(attn_logits_2, dim-1)attn_output_2  attn_probs_2  v_2attn_output  torch.cat([attn_output_1, attn_output_2], dim-1)  self.w_o  # attn_output: (n, d_model)attn_probs  torch.stack([attn_probs_1, attn_probs_2], dim0)  # attn_probs: (2, n, m)其中2是头数return attn_output, attn_probs可以看到代码量已经增加了不少如果扩展到  h h h 个头的情形则需要  3 h  1 3h1 3h1 个参数矩阵。手动去一个个声明显然不现实因为  h h h 是动态变化的而用for循环创建又略显笨拙有没有更简便的方法呢 
在上面的代码中我们用小写  q q q 来代表查询  Q Q Q 经过投影后的结果 k , v k,v k,v 同理即 q i  Q W i Q , i  1 , 2 , ⋯ , h q_iQW_i^Q,\quad i 1,2,\cdots,h qiQWiQ,i1,2,⋯,h 
其中  Q Q Q 的形状为  ( n , d m o d e l ) (n,d_{model}) (n,dmodel) q i q_i qi 的形状为  ( n , d h e a d ) (n,d_{head}) (n,dhead)且有 h e a d i  softmax ( q i k i T d h e a d ) v i head_i\text{softmax}\left(\frac{q_ik_i^{T}}{\sqrt{d_{head}}}\right)v_i headisoftmax(dhead   qikiT)vi 
注意到 [ q 1 , q 2 , ⋯ , q h ]  Q [ W 1 Q , W 2 Q , ⋯ , W h Q ] (1) [q_1,q_2,\cdots,q_h]Q[W_1^Q,W_2^Q,\cdots,W_h^Q]\tag{1} [q1,q2,⋯,qh]Q[W1Q,W2Q,⋯,WhQ](1) 
如果记  q ≜ [ q 1 , q 2 , ⋯ , q h ] q\triangleq [q_1,q_2,\cdots,q_h] q≜[q1,q2,⋯,qh] W Q ≜ [ W 1 Q , W 2 Q , ⋯ , W h Q ] W^Q\triangleq [W_1^Q,W_2^Q,\cdots,W_h^Q] WQ≜[W1Q,W2Q,⋯,WhQ]则  W Q W^Q WQ 的形状为  ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel)与  h h h 无关 q q q 的形状为  ( n , d m o d e l ) (n,d_{model}) (n,dmodel)。这样一来我们就不需要一个个声明  W i Q W_i^Q WiQ 了并且可以一次性存储所有的  q i q_i qi。 
要计算  h e a d 1 head_1 head1我们需要能够从  q q q 中取出  q 1 q_1 q1 k , v k,v k,v 同理所以我们期望  q q q 的形状是  ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead)从而  q [ 1 ] q[1] q[1] 就是  q 1 q_1 q1这里下标从  1 1 1 开始。 当然也可以是  ( n , h , d h e a d ) (n,h,d_{head}) (n,h,dhead) 等形状但必须要确保形状里含且只含这三个数字。之所以把  h h h 放在第一个维度是为了方便索引和后续计算。 同理可知  k , v k,v k,v 的形状均为  ( h , m , d h e a d ) (h,m,d_{head}) (h,m,dhead)。我们可以视  h h h 所在的维度为批量维从而可以执行批量乘法 torch.bmm 来一次性算出  h h h 个头的结果。 
q  torch.randn(h, n, d_head)
k  torch.randn(h, m, d_head)
v  torch.randn(h, m, d_head)# 和torch.bmm的效果相同但写法更简洁
attn_logits  q  k.transpose(1, 2) / math.sqrt(q.size(2))
attn_probs  F.softmax(attn_logits, dim-1)
attn_output  attn_probs  v  # attn_output: (h, n, d_head)h h h 个头的结果存储在形状为  ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead) 的张量中那我们如何把这  h h h 个结果concat在一起呢注意到我们实际上是将  h h h 个形状为  ( n , d h e a d ) (n,d_{head}) (n,dhead) 的张量横向concat为一个形状为  ( n , d m o d e l ) (n,d_{model}) (n,dmodel) 的张量因此只需执行如下的形状变换 ( h , n , d h e a d ) → ( n , h , d h e a d ) → ( n , h ⋅ d h e a d )  ( n , d m o d e l ) (2) (h,n,d_{head})\to(n,h,d_{head})\to(n,h\cdot d_{head})(n,d_{model}) \tag{2} (h,n,dhead)→(n,h,dhead)→(n,h⋅dhead)(n,dmodel)(2) 
n  attn_output.size(1)
attn_output  attn_output.transpose(0, 1).reshape(n, -1)⚠️ 注意切勿直接将  ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead) reshape成  ( n , d m o d e l ) (n,d_{model}) (n,dmodel)。 之前我们只讨论了  q q q 的形状应当是  ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead)但并没有讨论它是如何变换得来的。这是因为 Q Q Q 在经过投影后得到的  q q q 只具有  ( n , d m o d e l ) (n,d_{model}) (n,dmodel) 的形状要进行形状变换一种做法是对  q q q 沿纵向切  h h h 刀再堆叠起来这样从直观上来看也比较符合公式  ( 1 ) (1) (1) 
q  torch.randn(n, d_model)
q  torch.stack(torch.split(q, d_head, dim-1), dim0)但由于  W Q W^Q WQ 初始时是随机的所以我们不需要严格按照公式  ( 1 ) (1) (1) 那样操作直接执行  ( 2 ) (2) (2) 的逆变换即可 ( n , d m o d e l )  ( n , h ⋅ d h e a d ) → ( n , h , d h e a d ) → ( h , n , d h e a d ) (n,d_{model})(n,h\cdot d_{head})\to(n,h,d_{head})\to(h,n,d_{head}) (n,dmodel)(n,h⋅dhead)→(n,h,dhead)→(h,n,dhead) 
现考虑有batch的情形设批量大小为  b b b则  Q Q Q 的形状为  ( b , n , d m o d e l ) (b,n,d_{model}) (b,n,dmodel) 或  ( n , b , d m o d e l ) (n,b,d_{model}) (n,b,dmodel)具体是哪一个要看 batch_first 是否为 True。接下来均假设 batch_first  False。 
在以上的假设下 q q q 的形状也为  ( n , b , d m o d e l ) (n,b,d_{model}) (n,b,dmodel)我们将  b b b 和  h h h 看成同一维度都是批量维从而  ( 2 ) (2) (2) 式改写为 ( n , b , d m o d e l ) → ( n , b , h , d h e a d ) → ( n , b ⋅ h , d h e a d ) → ( b ⋅ h , n , d h e a d ) (n,b,d_{model})\to(n,b,h,d_{head})\to(n,b\cdot h,d_{head})\to(b\cdot h,n,d_{head}) (n,b,dmodel)→(n,b,h,dhead)→(n,b⋅h,dhead)→(b⋅h,n,dhead) 
关于 key_padding_mask 和 attn_mask 这里不再介绍如有需要可阅读博主之前的文章这里主要讲解如何合并两种mask。 
前者的形状为  ( b , m ) (b,m) (b,m)用来mask掉key中的 [PAD]防止query注意到它。而后者的形状可以是  ( n , m ) (n,m) (n,m) 也可以是  ( b ⋅ h , n , m ) (b\cdot h,n,m) (b⋅h,n,m)。在实际合并两种mask的时候我们均需要按照  ( b ⋅ h , n , m ) (b\cdot h,n,m) (b⋅h,n,m) 这个形状去计算。也就是说如果是 key_padding_mask我们需要进行形状变换  ( b , m ) → ( b , 1 , 1 , m ) → ( b , h , 1 , m ) → ( b ⋅ h , 1 , m ) (b,m)\to(b,1,1,m)\to(b,h,1,m)\to(b\cdot h,1,m) (b,m)→(b,1,1,m)→(b,h,1,m)→(b⋅h,1,m)如果是 attn_mask我们需要进行形状变换  ( n , m ) → ( 1 , n , m ) (n,m)\to(1,n,m) (n,m)→(1,n,m)。 
1.2 源码 
本节将遵循以下记号 
记号说明 b b bbatch size h h hnum heads d d dhead dim n n nnum queries m m mnum key-value pairs 
首先实现一个MHA的基类 
class MultiheadAttentionBase_(nn.Module):def __init__(self, embed_dim, num_heads, dropout0., biasTrue):super().__init__()self.embed_dim  embed_dimself.num_heads  num_headsself.dropout  dropoutself.head_dim  embed_dim // num_headsassert self.head_dim * num_heads  embed_dimself.in_proj_weight  nn.Parameter(torch.empty(3 * embed_dim, embed_dim))if bias:self.in_proj_bias  nn.Parameter(torch.empty(3 * embed_dim))else:self.register_parameter(in_proj_bias, None)self.out_proj  nn.Linear(embed_dim, embed_dim, biasbias)self._reset_parameters()def _reset_parameters(self):nn.init.xavier_uniform_(self.in_proj_weight)if self.in_proj_bias is not None:nn.init.constant_(self.in_proj_bias, 0.)nn.init.constant_(self.out_proj.bias, 0.)def forward(self,query,key,value,key_padding_mask,attn_mask,need_weightsTrue,):Args:query: (n, b, h * d)key: (m, b, h * d)value: (m, b, h * d)key_padding_mask: (b, m), bool typeattn_mask: (n, m) or (b * h, n, m), bool typeReturns:attn_output: (n, b, h * d)attn_weights: (b, h, n, m)w_q, w_k, w_v  self.in_proj_weight.chunk(3)if self.in_proj_bias is not None:b_q, b_k, b_v  self.in_proj_bias.chunk(3)else:b_q  b_k  b_v  Noneq  F.linear(query, w_q, b_q)k  F.linear(key, w_k, b_k)v  F.linear(value, w_v, b_v)b, h, d  q.size(1), self.num_heads, self.head_dimq, k, v  map(lambda x: x.reshape(-1, b, h, d), [q, k, v])attn_mask  self.merge_masks(key_padding_mask, attn_mask, q)attn_output, attn_weights  self.attention(q, k, v, attn_mask, out_projself.out_proj, dropoutself.dropout, trainingself.training)if not need_weights:attn_weights  Nonereturn attn_output, attn_weightsdef merge_masks(self, key_padding_mask, attn_mask, q):Args:key_padding_mask: (b, m), bool typeattn_mask: (n, m) or (b * h, n, m), bool typeq: only used to confirm the dtype of attn_maskReturns:attn_mask: (b * h, n, m), float typeassert key_padding_mask is not None and key_padding_mask.dtype  torch.boolb, m  key_padding_mask.size()key_padding_mask  key_padding_mask.view(b, 1, 1, m).expand(-1, self.num_heads, -1, -1).reshape(b * self.num_heads, 1, m)if attn_mask is not None:assert attn_mask.dtype  torch.boolif attn_mask.dim()  2:attn_mask  attn_mask.unsqueeze(0)attn_mask  attn_mask.logical_or(key_padding_mask)else:attn_mask  key_padding_maskattn_mask  torch.zeros_like(attn_mask, dtypeq.dtype).masked_fill_(attn_mask, -1e28)return attn_maskdef attention(self, q, k, v, attn_mask, out_proj, dropout, training):Args:q: (n, b, h, d)k: (m, b, h, d)v: (m, b, h, d)attn_mask: (b * h, n, m), float typeout_proj: nn.Linear(h * d, h * d)Returns:attn_output: (n, b, h * d), is the result of concating h heads.attn_weights: (b, h, n, m)raise NotImplementedError接下来只需要重写 attention 方法就可以实现普通版的MHA了 
class MultiheadAttention(MultiheadAttentionBase_):def attention(self, q, k, v, attn_mask, out_proj, dropout, training):if not training:dropout  0n, b, h, d  q.size()q, k, v  map(lambda x: x.reshape(-1, b * h, d).transpose(0, 1), [q, k, v])attn_logits  q  k.transpose(-2, -1) / math.sqrt(d)  attn_maskattn_probs  F.softmax(attn_logits, dim-1)attn_weights  F.dropout(attn_probs, pdropout)attn_output  attn_weights  vattn_output  attn_output.transpose(0, 1).reshape(n, b, h * d)attn_output  out_proj(attn_output)return attn_output, attn_weights1.3 极简版MHA面试用 
不少面试会让现场手写MHA这里提供了一份模版略去了很多细节。 
相比原版极简版做了如下改动 
略去了参数初始化。去掉了mask 
class MultiheadAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout0., biasTrue):super().__init__()self.embed_dim  embed_dimself.num_heads  num_headsself.dropout  nn.Dropout(dropout)self.head_dim  embed_dim // num_headsassert self.head_dim * num_heads  embed_dimself.in_proj_weight  nn.Parameter(torch.empty(3 * embed_dim, embed_dim))if bias:self.in_proj_bias  nn.Parameter(torch.empty(3 * embed_dim))else:self.register_parameter(in_proj_bias, None)self.out_proj  nn.Linear(embed_dim, embed_dim, biasbias)def forward(self, query, key, value):Args:query: (n, b, h * d)key: (m, b, h * d)value: (m, b, h * d)w_q, w_k, w_v  self.in_proj_weight.chunk(3)if self.in_proj_bias is not None:b_q, b_k, b_v  self.in_proj_bias.chunk(3)else:b_q  b_k  b_v  Noneq, k, v  F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)b, h, d  q.size(1), self.num_heads, self.head_dimq, k, v  map(lambda x: x.reshape(-1, b * h, d).transpose(0, 1), [q, k, v])attn_logits  q  k.transpose(-2, -1) / math.sqrt(d)attn_probs  F.softmax(attn_logits, dim-1)attn_weights  self.dropout(attn_probs)attn_output  attn_weights  vattn_output  attn_output.transpose(0, 1).reshape(-1, b, h * d)attn_output  self.out_proj(attn_output)return attn_output, attn_weights注意如果尝试直接输出的话会得到一堆 nan这是因为没有xavier初始化需要 _reset_parameters() 一下。 
具体需要哪种mask可根据面试官的要求去实现。 
2. Transformer 
接下来基于PyTorch官方的MHA来实现Transformer。 
首先需要实现一个基础函数它可以用来复制一个 Module N次。 
def _get_clones(module, n):return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])EncoderLayer的实现 
class TransformerEncoderLayer(nn.Module):def __init__(self,d_model,n_head,d_ffn,dropout0.1,activationF.relu,norm_firstFalse,):super().__init__()self.self_attn  nn.MultiheadAttention(embed_dimd_model, num_headsn_head, dropoutdropout)self.dropout1  nn.Dropout(dropout)self.linear1  nn.Linear(d_model, d_ffn)self.activation  activationself.dropout2  nn.Dropout(dropout)self.linear2  nn.Linear(d_ffn, d_model)self.dropout3  nn.Dropout(dropout)self.norm1  nn.LayerNorm(d_model)self.norm2  nn.LayerNorm(d_model)self.norm_first  norm_firstdef forward(self, src, src_mask, src_key_padding_mask):x  srcif self.norm_first:x  x  self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)x  x  self._ff_block(self.norm2(x))else:x  self.norm1(x  self._sa_block(x, src_mask, src_key_padding_mask))x  self.norm2(x  self._ff_block(x))return xdef _sa_block(self, x, attn_mask, key_padding_mask):x  self.self_attn(x, x, x, attn_maskattn_mask, key_padding_maskkey_padding_mask, need_weightsFalse)[0]return self.dropout1(x)def _ff_block(self, x):x  self.linear2(self.dropout2(self.activation(self.linear1(x))))return self.dropout3(x)这里的 norm_first 用来决定是Pre-LN还是Post-LN如下图所示 DecoderLayer的实现 
class TransformerDecoderLayer(nn.Module):def __init__(self,d_model,n_head,d_ffn,dropout0.1,activationF.relu,norm_firstFalse,):super().__init__()self.self_attn  nn.MultiheadAttention(embed_dimd_model, num_headsn_head, dropoutdropout)self.dropout1  nn.Dropout(dropout)self.cross_attn  nn.MultiheadAttention(embed_dimd_model, num_headsn_head, dropoutdropout)self.dropout2  nn.Dropout(dropout)self.linear1  nn.Linear(d_model, d_ffn)self.activation  activationself.dropout3  nn.Dropout(dropout)self.linear2  nn.Linear(d_ffn, d_model)self.dropout4  nn.Dropout(dropout)self.norm1  nn.LayerNorm(d_model)self.norm2  nn.LayerNorm(d_model)self.norm3  nn.LayerNorm(d_model)self.norm_first  norm_firstdef forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask):x  tgtif self.norm_first:x  x  self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)x  x  self._ca_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)x  x  self._ff_block(self.norm3(x))else:x  self.norm1(x  self._sa_block(x, tgt_mask, tgt_key_padding_mask))x  self.norm2(x  self._ca_block(x, memory, memory_mask, memory_key_padding_mask))x  self.norm3(x  self._ff_block(x))return xdef _sa_block(self, x, attn_mask, key_padding_mask):x  self.self_attn(x, x, x, attn_maskattn_mask, key_padding_maskkey_padding_mask, need_weightsFalse)[0]return self.dropout1(x)def _ca_block(self, x, mem, attn_mask, key_padding_mask):x  self.cross_attn(x, mem, mem, attn_maskattn_mask, key_padding_maskkey_padding_mask, need_weightsFalse)[0]return self.dropout2(x)def _ff_block(self, x):x  self.linear2(self.dropout3(self.activation(self.linear1(x))))return self.dropout4(x)根据EncoderLayer搭建Encoder。需要注意的是PyTorch源码中还提供了 encoder_norm 这一参数即决定是否在Encoder最后放一个LN。 
class TransformerEncoder(nn.Module):def __init__(self, encoder_layer, num_layers, encoder_normNone):super().__init__()self.layers  _get_clones(encoder_layer, num_layers)self.num_layers  num_layersself.encoder_norm  encoder_normdef forward(self, src, src_mask, src_key_padding_mask):output  srcfor mod in self.layers:output  mod(output, src_mask, src_key_padding_mask)if self.encoder_norm is not None:output  self.encoder_norm(output)return outputDecoderLayer同理 
class TransformerDecoder(nn.Module):def __init__(self, decoder_layer, num_layers, decoder_normNone):super().__init__()self.layers  _get_clones(decoder_layer, num_layers)self.num_layers  num_layersself.decoder_norm  decoder_normdef forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask):output  tgtfor mod in self.layers:output  mod(output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)if self.decoder_norm is not None:output  self.decoder_norm(output)return outputPyTorch官方的Transformer默认添加 encoder_norm 和 decoder_norm然而这对于Post-LN的情形无疑是多余的所以这里我们做个简单修改即如果是Post-LN情形就不在最后添加LN了。 
class Transformer(nn.Module):def __init__(self,d_model512,n_head8,num_encoder_layers6,num_decoder_layers6,d_ffn2048,dropout0.1,activationF.relu,norm_firstFalse,):super().__init__()if norm_first:encoder_norm, decoder_norm  nn.LayerNorm(d_model), nn.LayerNorm(d_model)else:encoder_norm  decoder_norm  Noneencoder_layer  TransformerEncoderLayer(d_model, n_head, d_ffn, dropout, activation, norm_first)self.encoder  TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer  TransformerDecoderLayer(d_model, n_head, d_ffn, dropout, activation, norm_first)self.decoder  TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim()  1:nn.init.xavier_uniform_(p)def forward(self,src,tgt,src_maskNone,tgt_maskNone,memory_maskNone,src_key_padding_maskNone,tgt_key_padding_maskNone,memory_key_padding_maskNone,):memory  self.encoder(src, src_mask, src_key_padding_mask)output  self.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)return output截止到目前我们实现的Transfomer并不是完整的还缺少embedding层和Decoder后面的Linear层这里只介绍前者因为后者仅仅是简单的 nn.Linear(d_model, tgt_vocab_size)。 
Transformer的embedding层分为token embedding和Positional Encoding前者是可学习的 nn.Embedding后者是固定的Sinusoidal编码。 
PE的公式为 P [ i , 2 j ]  sin  ( i 1000 0 2 j / d m o d e l ) P [ i , 2 j  1 ]  cos  ( i 1000 0 2 j / d m o d e l ) 0 ≤ i  m a x _ l e n , 0 ≤ j  d m o d e l P[i,2j]\sin\left(\frac{i}{10000^{2j/d_{model}}}\right)\\ P[i,2j1]\cos\left(\frac{i}{10000^{2j/d_{model}}}\right) \\ 0\leq i  max\_len,\;0\leq jd_{model} P[i,2j]sin(100002j/dmodeli)P[i,2j1]cos(100002j/dmodeli)0≤imax_len,0≤jdmodel 
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout0.1, max_len5000):super().__init__()self.dropout  nn.Dropout(dropout)position  torch.arange(max_len).unsqueeze(1)div_term  torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))pe  torch.zeros(max_len, 1, d_model)  # 1是batch size维度pe[:, 0, 0::2]  torch.sin(position * div_term)pe[:, 0, 1::2]  torch.cos(position * div_term)self.register_buffer(pe, pe)def forward(self, x):x  x  self.pe[:x.size(0)]return self.dropout(x)3. QA 
1. MHA的参数量时间复杂度FLOPs 
只考虑自注意力情形。为简便起见令  h ≜ d m o d e l h\triangleq d_{model} h≜dmodel。 
MHA模块一共包含四个参数矩阵 W Q , W K , W V , W O W^Q,W^K,W^V,W^O WQ,WK,WV,WO形状均为  ( h , h ) (h,h) (h,h)因此weight部分的参数量是  4 ⋅ h 2 4\cdot h^2 4⋅h2。每个参数矩阵都会带有一个长度为  h h h 的bias因此总共的参数量为  4 h 2  4 h 4h^24h 4h24h。 注意FLOPs和FLOPS的含义不同。前者是floating point operations指浮点运算数可以理解为计算量用来衡量模型/算法的复杂度后者是floating point operations per second指每秒浮点运算次数可以理解为计算速度用来衡量衡量硬件的性能。 在计算形状为  ( m , n ) (m,n) (m,n) 和  ( n , k ) (n,k) (n,k) 矩阵的乘积时每计算一次内积都要执行  n n n 次乘法和  n n n 次加法而最终输出矩阵的形状为  ( m , k ) (m,k) (m,k)所以总共的浮点运算次数为  ( n  n ) ⋅ m ⋅ k  2 m n k (nn)\cdot m\cdot k2mnk (nn)⋅m⋅k2mnk。 
回到MHA只考虑矩阵乘法 
首先会对形状为  ( l , b , h ) (l,b,h) (l,b,h) 的embedding进行投影执行的矩阵乘法为  ( l , b , h ) × ( h , h ) → ( l , b , h ) (l,b,h)\times (h, h)\to(l,b,h) (l,b,h)×(h,h)→(l,b,h)这一步的计算量为  2 l b h 2 2lbh^2 2lbh2。由于会分别投影到  Q , K , V Q,K,V Q,K,V 三个矩阵因此这一步的总计算量为  6 l b h 2 6lbh^2 6lbh2。接下来是  Q K T QK^T QKT 相乘执行的矩阵乘法为  ( b ⋅ n h , l , h d ) × ( b ⋅ n h , h d , l ) → ( b ⋅ n h , l , l ) (b\cdot nh,l,hd)\times(b\cdot nh,hd,l)\to(b\cdot nh,l,l) (b⋅nh,l,hd)×(b⋅nh,hd,l)→(b⋅nh,l,l)其中  n h nh nh 代表 num_heads h d hd hd 代表 head_dim。计算量为  2 l 2 b h 2l^2bh 2l2bh。然后是对  V V V 进行加权执行的矩阵乘法为  ( b ⋅ n h , l , l ) × ( b ⋅ n h , l , h d ) → ( b ⋅ n h , l , h d ) (b\cdot nh,l,l)\times(b\cdot nh,l,hd)\to(b\cdot nh,l,hd) (b⋅nh,l,l)×(b⋅nh,l,hd)→(b⋅nh,l,hd)计算量为  2 l 2 b h 2l^2bh 2l2bh。最后的投影中执行的矩阵乘法为  ( l , b , h ) × ( h , h ) → ( l , b , h ) (l,b,h)\times(h,h)\to(l,b,h) (l,b,h)×(h,h)→(l,b,h)计算量为  2 l b h 2 2lbh^2 2lbh2。 
由上述步骤可知MHA的FLOPs约为  6 l b h 2  2 l 2 b h  2 l 2 b h  2 l b h 2  4 l b h ( 2 h  l ) 6lbh^22l^2bh2l^2bh2lbh^24lbh(2hl) 6lbh22l2bh2l2bh2lbh24lbh(2hl)。 
再来看MHA的复杂度依然只考虑矩阵乘法。在计算形状为  ( m , n ) (m,n) (m,n) 和  ( n , k ) (n,k) (n,k) 矩阵的乘积时计算内积的时间复杂度为  O ( n ) O(n) O(n)而输出矩阵的形状为  ( m , k ) (m,k) (m,k)填满这个矩阵所需要的时间为  O ( m k ) O(mk) O(mk)所以总时间复杂度为  O ( m n k ) O(mnk) O(mnk)。 
可以发现一个不严谨的等式仅针对矩阵乘法场景 时间复杂度  O ( FLOPs 2 ) 时间复杂度O\left(\frac{\text{FLOPs}}{2}\right) 时间复杂度O(2FLOPs) 
由此可得到MHA的时间复杂度为  O ( 2 l b h ( 2 h  l ) )  O ( l b h 2  l 2 b h ) O(2lbh(2hl))O(lbh^2l^2bh) O(2lbh(2hl))O(lbh2l2bh)。特别地当  b  1 b1 b1 且  h ≪ l h\ll l h≪l 时MHA的复杂度退化为  O ( l 2 h ) O(l^2h) O(l2h)这就是Transformer那篇论文里提到的复杂度。