BERT Source Code Reading Notes
BERT Components
BertEmbeddings
embeddings:
- word_embeddings: nn.Embedding, $vocab_size \times hidden_size$
- position_embeddings: nn.Embedding, $max_position_embeddings \times hidden_size$
- token_type_embeddings: nn.Embedding, $type_vocab_size \times hidden_size$
Note: when position_embedding_type is absolute, the output embedding is the sum of the three embeddings above; when it is one of the relative types (relative_key / relative_key_query), position_embeddings are not added here and relative positions are handled inside self-attention instead. The LayerNorm and dropout below are then applied on top (see the sketch after the list below).
layernorm & dropout:
- LayerNorm: nn.LayerNorm, $hidden_size$
- dropout: nn.Dropout, $hidden_dropout_prob$
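The lookup-and-sum logic can be sketched as follows. This is a simplified illustration, not the exact Hugging Face implementation; the constructor defaults mirror the standard bert-base config values.

```python
import torch
import torch.nn as nn

class SimpleBertEmbeddings(nn.Module):
    """Simplified sketch of BertEmbeddings (absolute position embeddings)."""
    def __init__(self, vocab_size=30522, hidden_size=768,
                 max_position_embeddings=512, type_vocab_size=2, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # Sum the three embeddings (position only when position_embedding_type == "absolute")
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(position_ids)
                      + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(embeddings))
```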
BertSelfAttention
- num_attention_heads: number of attention heads
- attention_head_size: $hidden_size / num_attention_heads$
- all_head_size: $attention_head_size \times num_attention_heads$
attention structure:
- query: nn.Linear, $hidden_size \times all_head_size$
- key: nn.Linear, $hidden_size \times all_head_size$
- value: nn.Linear, $hidden_size \times all_head_size$
dropout:
- dropout: nn.Dropout, $attention_probs_dropout_prob$
distance_embedding:
- distance_embedding: nn.Embedding, $(2*max_position_embeddings-1) \times attention_head_size$
Core code logic
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
transpose_for_scores reshapes hidden_states: the input size is $batch_size \times seq_length \times hidden_size$, and the output size is $batch_size \times num_heads \times seq_length \times attn_size$ (where attn_size is attention_head_size).
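A quick shape check of this reshaping, using bert-base-like dimensions (hidden_size=768, 12 heads, so attention_head_size=64); the standalone function below mirrors transpose_for_scores:

```python
import torch

def transpose_for_scores(x, num_attention_heads=12, attention_head_size=64):
    # (batch, seq, hidden) -> (batch, seq, heads, head_size) -> (batch, heads, seq, head_size)
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    return x.view(*new_x_shape).permute(0, 2, 1, 3)

hidden_states = torch.randn(2, 128, 768)          # batch_size=2, seq_length=128, hidden_size=768
print(transpose_for_scores(hidden_states).shape)  # torch.Size([2, 12, 128, 64])
```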
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# ... some code omitted here
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# ... decoder-related code omitted here
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
# ... position-related code omitted here
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
First look at how query/key/value are computed, taking query as an example:
mixed_query_layer = self.query(hidden_states)
# ...
query_layer = self.transpose_for_scores(mixed_query_layer)
The output mixed_query_layer has size $batch_size \times seq_length \times all_head_size$; after transpose_for_scores, query_layer has size $batch_size \times num_heads \times seq_length \times attn_size$. key and value are computed the same way. Once query/key/value are available, attention_scores can be computed:
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
attention_scores holds the attention weight of every token against every other token in the sequence; its size is $batch_size \times num_heads \times seq_length \times seq_length$. It is normalized with softmax to get attention_probs, which passes through a dropout layer and is then multiplied by each token's value, giving the attention output context_layer of size $batch_size \times num_heads \times seq_length \times attn_size$. To match the layout of the original hidden_states, a permute first rearranges it to $batch_size \times seq_length \times num_heads \times attn_size$, and a view then collapses the last two dimensions to $batch_size \times seq_length \times all_head_size$.
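The same pipeline, reduced to a few standalone lines so the shapes are explicit (toy tensors with the bert-base dimensions used above):

```python
import math
import torch

batch_size, num_heads, seq_length, attn_size = 2, 12, 128, 64
query_layer = torch.randn(batch_size, num_heads, seq_length, attn_size)
key_layer = torch.randn(batch_size, num_heads, seq_length, attn_size)
value_layer = torch.randn(batch_size, num_heads, seq_length, attn_size)

scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(attn_size)
probs = scores.softmax(dim=-1)                      # (2, 12, 128, 128)
context = torch.matmul(probs, value_layer)          # (2, 12, 128, 64)
context = context.permute(0, 2, 1, 3).contiguous()  # (2, 128, 12, 64)
context = context.view(batch_size, seq_length, num_heads * attn_size)  # (2, 128, 768)
```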
The code also contains a block that handles relative positions:
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
When position_embedding_type is relative_key or relative_key_query, query- and/or key-dependent relative position terms are added to attention_scores. distance holds the positional offset between every pair of tokens and has size $seq_length \times seq_length$; after the distance_embedding lookup its size is $seq_length \times seq_length \times attn_size$, and einsum folds it into attention_scores.
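A sketch of the relative_key branch on its own, with toy dimensions, so the distance matrix and the einsum are easier to follow (max_position_embeddings is assumed to be at least seq_length):

```python
import torch
import torch.nn as nn

seq_length, num_heads, attn_size, max_position_embeddings = 6, 2, 8, 16
distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, attn_size)
query_layer = torch.randn(1, num_heads, seq_length, attn_size)

position_ids_l = torch.arange(seq_length).view(-1, 1)
position_ids_r = torch.arange(seq_length).view(1, -1)
distance = position_ids_l - position_ids_r            # (seq, seq), values in [-(seq-1), seq-1]
# Shift so the indices are non-negative before the embedding lookup
positional_embedding = distance_embedding(distance + max_position_embeddings - 1)  # (seq, seq, attn_size)

# b=batch, h=head, l=query position, r=key position, d=head dim
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(relative_position_scores.shape)  # torch.Size([1, 2, 6, 6])
```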
The output of BertSelfAttention is (context_layer, [optional] attention_probs, [optional] past_key_value). past_key_value is itself a tuple (key_layer, value_layer) and is only set when is_decoder is True (a sketch of how the cache is reused follows the snippet below):
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
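For reference, in the uni-directional decoder case (omitted above) the cached key/value are concatenated with the newly projected key/value along the sequence dimension. A minimal sketch of that idea, not the verbatim source:

```python
import torch

# Cached key/value from previous decoding steps: (batch, heads, past_len, head_size)
past_key, past_value = torch.randn(1, 12, 10, 64), torch.randn(1, 12, 10, 64)
# Newly projected key/value for the current step: (batch, heads, 1, head_size)
new_key, new_value = torch.randn(1, 12, 1, 64), torch.randn(1, 12, 1, 64)

# Concatenate along the sequence dimension (dim=2) so attention sees the full history
key_layer = torch.cat([past_key, new_key], dim=2)        # (1, 12, 11, 64)
value_layer = torch.cat([past_value, new_value], dim=2)  # (1, 12, 11, 64)
```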
BertSelfOutput
- dense: nn.Linear, $hidden_size \times hidden_size$
- dropout: nn.Dropout, $hidden_dropout_prob$
- LayerNorm: nn.LayerNorm, $hidden_size$
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
The input to the LayerNorm layer is the dropout output plus input_tensor, i.e. the residual connection is applied first, followed by LayerNorm.
BertAttention
- self: BertSelfAttention
- output: BertSelfOutput
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
This component is fairly simple: it applies BertSelfOutput to the context_layer from the BertSelfAttention output. The final output is (attention_output, [optional] attention_probs, [optional] past_key_value).
BertIntermediate
- dense: nn.Linear, $hidden_size \times intermediate_size$
- intermediate_act_fn: activation function (gelu by default)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
BertOutput
- dense: nn.Linear, $intermediate_size \times hidden_size$
- dropout: nn.Dropout, $hidden_dropout_prob$
- LayerNorm: nn.LayerNorm, $hidden_size$
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
Same as in BertSelfOutput: the residual connection is applied first, followed by LayerNorm.
BertLayer
- attention: BertAttention
- intermediate: BertIntermediate
- output: BertOutput
If we only look at the encoder path, BertLayer is fairly simple:
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
# ... decoder-related code omitted
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# ... decoder-related code omitted
return outputs
The attention output is taken directly and transformed by the feed_forward_chunk method (a sketch of apply_chunking_to_forward follows the snippet below):
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
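apply_chunking_to_forward is a memory-saving utility: when chunk_size_feed_forward > 0 it splits attention_output into chunks along the sequence dimension, runs feed_forward_chunk on each chunk, and concatenates the results; with chunk size 0 (the default) it simply calls the function. A simplified stand-in for that behavior, not the actual transformers utility:

```python
import torch

def chunked_forward(forward_fn, chunk_size, chunk_dim, input_tensor):
    """Simplified stand-in for apply_chunking_to_forward."""
    if chunk_size == 0:
        return forward_fn(input_tensor)
    chunks = input_tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)

# Example: apply a feed-forward function over (batch, seq, hidden) in chunks of 32 tokens
ff = torch.nn.Linear(768, 768)
out = chunked_forward(ff, chunk_size=32, chunk_dim=1, input_tensor=torch.randn(2, 128, 768))
print(out.shape)  # torch.Size([2, 128, 768])
```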
The output of BertLayer is (layer_output, [optional] attention_probs, [optional] present_key_value).
BertEncoder
- layer: nn.ModuleList, [BertLayer] * num_hidden_layers
BertEncoder consists of num_hidden_layers stacked BertLayer modules. For each layer it computes the output and feeds it into the next, optionally collecting the intermediate results; the code is straightforward, so it is not reproduced here. A minimal sketch of the loop follows.
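The sketch below shows only the core loop, omitting gradient checkpointing, caching, and the optional attention outputs; ToyLayer is a hypothetical stand-in for BertLayer that returns a tuple like the real layer does.

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Stand-in for BertLayer: returns a tuple, element 0 is the new hidden_states."""
    def __init__(self, hidden_size=768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        return (self.dense(hidden_states),)

class SimpleBertEncoder(nn.Module):
    """Simplified sketch of BertEncoder: a stack of num_hidden_layers layers."""
    def __init__(self, num_hidden_layers=12, hidden_size=768):
        super().__init__()
        self.layer = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_hidden_layers))

    def forward(self, hidden_states, attention_mask=None):
        all_hidden_states = (hidden_states,)
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)[0]
            all_hidden_states = all_hidden_states + (hidden_states,)
        return hidden_states, all_hidden_states

last_hidden, all_hidden = SimpleBertEncoder()(torch.randn(2, 128, 768))
print(last_hidden.shape, len(all_hidden))  # torch.Size([2, 128, 768]) 13
```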
BertPooler
- dense: nn.Linear, $hidden_size \times hidden_size$
- activation: nn.Tanh
BertPooler takes the hidden state of the first token ([CLS]) and passes it through a dense layer with a Tanh activation:
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
Bert Head
The head classes are thin wrappers over the model output for specific tasks; the code is simple.
BertPredictionHeadTransform
- dense: nn.Linear, $hidden_size \times hidden_size$
- transform_act_fn
- LayerNorm: nn.LayerNorm, $hidden_size$
BertLMPredictionHead
- transform: BertPredictionHeadTransform
- decoder: nn.Linear, $hidden_size \times vocab_size$
- bias: nn.Parameter, $vocab_size$
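Putting these pieces together, the MLM head maps hidden states to vocabulary logits: transform (dense + activation + LayerNorm), then decoder ($hidden_size \times vocab_size$) plus a bias. A simplified sketch (in the real model the decoder weights are tied to the input word embeddings):

```python
import torch
import torch.nn as nn

class SimpleLMPredictionHead(nn.Module):
    """Simplified sketch of BertLMPredictionHead."""
    def __init__(self, hidden_size=768, vocab_size=30522):
        super().__init__()
        # BertPredictionHeadTransform: dense + activation + LayerNorm
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.transform_act_fn = nn.GELU()
        self.LayerNorm = nn.LayerNorm(hidden_size)
        # Projection to the vocabulary plus a separate bias parameter
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
        return self.decoder(hidden_states) + self.bias

logits = SimpleLMPredictionHead()(torch.randn(2, 128, 768))
print(logits.shape)  # torch.Size([2, 128, 30522])
```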
BertOnlyMLMHead
- predictions: BertLMPredictionHead
BertPreTrainingHeads
- predictions: BertLMPredictionHead
- seq_relationship: nn.Linear, $hidden_size \times 2$
BertOnlyNSPHead
- seq_relationship: nn.Linear, $hidden_size \times 2$
Bert
The model classes wrap BERT for different tasks. BertModel is the base class, and every task-specific model contains a BertModel (a usage sketch follows the list below).
BertModel
- embeddings: BertEmbeddings
- encoder: BertEncoder
- pooler: BertPooler
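A usage sketch of the base model; the checkpoint name is illustrative and this assumes the transformers library is installed:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Reading the BERT source code.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, seq_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size), from BertPooler
```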
BertForPreTraining
- bert: BertModel
- cls: BertPreTrainingHeads
BertLMHeadModel
- bert: BertModel
- cls: BertOnlyMLMHead
BertForMaskedLM
- bert: BertModel
- cls: BertOnlyMLMHead
BertForNextSentencePrediction
- bert: BertModel
- cls: BertOnlyNSPHead
BertForSequenceClassification
- bert: BertModel
- dropout: nn.Dropout, $hidden_dropout_prob$
- classifier: nn.Linear, $hidden_size \times num_labels$
BertForMultipleChoice
- bert: BertModel
- dropout: nn.Dropout, $hidden_dropout_prob$
- classifier: nn.Linear, $hidden_size \times 1$
BertForTokenClassification
- bert: BertModel
- dropout: nn.Dropout, $hidden_dropout_prob$
- classifier: nn.Linear, $hidden_size \times num_labels$
BertForQuestionAnswering
- bert: BertModel
- qa_outputs: nn.Linear, $hidden_size \times num_labels$
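As an end-to-end example of one of these task wrappers, here is a sketch using BertForSequenceClassification; the checkpoint name and label count are illustrative:

```python
import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, num_labels)
print(logits.softmax(dim=-1))
```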
