Bert Source Code Reading Notes

Bert Components

BertEmbeddings

embeddings:

  • word_embeddings: nn.Embedding, $vocab_size \times hidden_size$
  • position_embeddings: nn.Embedding, $max_position_embeddings \times hidden_size$
  • token_type_embeddings: nn.Embedding, $type_vocab_size \times hidden_size$

Note: when position_embedding_type is absolute, the output embedding is the sum of the three embeddings above (followed by the LayerNorm and dropout listed below); when it is relative (relative_key or relative_key_query), position_embeddings are not added here and relative positions are handled inside self-attention instead. A minimal sketch of the absolute case follows the lists below.

layernorm & dropout:

  • LayerNorm: nn.LayerNorm, $hidden_size$
  • dropout: nn.Dropout, $hidden_dropout_prob$
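
A minimal sketch of the absolute-position case, using illustrative sizes rather than a real config (the three nn.Embedding modules and the LayerNorm/dropout correspond to the ones listed above):

    # Sketch of BertEmbeddings for position_embedding_type == "absolute".
    import torch
    import torch.nn as nn

    vocab_size, hidden_size, max_position_embeddings, type_vocab_size = 30522, 768, 512, 2
    word_embeddings = nn.Embedding(vocab_size, hidden_size)
    position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
    token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
    layer_norm = nn.LayerNorm(hidden_size)
    dropout = nn.Dropout(0.1)

    input_ids = torch.randint(0, vocab_size, (2, 16))   # batch_size x seq_length
    token_type_ids = torch.zeros_like(input_ids)        # single-segment input
    position_ids = torch.arange(16).unsqueeze(0)        # 1 x seq_length, broadcast over batch

    embeddings = (
        word_embeddings(input_ids)
        + token_type_embeddings(token_type_ids)
        + position_embeddings(position_ids)
    )
    embeddings = dropout(layer_norm(embeddings))        # batch_size x seq_length x hidden_size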

BertSelfAttention

  • num_attention_heads: number of attention heads
  • attention_head_size: $hidden_size / num_attention_heads$
  • all_head_size: $attention_head_size \times num_attention_heads$

attention structure:

  • query: nn.Linear, $hidden_size \times all_head_size$
  • key: nn.Linear, $hidden_size \times all_head_size$
  • value: nn.Linear, $hidden_size \times all_head_size$

dropout:

  • dropout: nn.Dropout, $attention_probs_dropout_prob$

distance_embedding:

  • distance_embedding: nn.Embedding, $(2*max_position_embeddings-1) \times attention_head_size$

Core code logic
def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

transpose_for_scores reshapes its input: a tensor of size $batch_size \times seq_length \times hidden_size$ (here hidden_size equals all_head_size) becomes a tensor of size $batch_size \times num_heads \times seq_length \times attn_size$.
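
A quick shape check with made-up sizes (hidden_size=768, 12 heads of size 64; not from any particular config):

    # Illustrative shape check of the transpose_for_scores reshape.
    import torch

    batch_size, seq_length, hidden_size = 2, 128, 768
    num_attention_heads, attention_head_size = 12, 64

    hidden_states = torch.randn(batch_size, seq_length, hidden_size)
    x = hidden_states.view(batch_size, seq_length, num_attention_heads, attention_head_size)
    x = x.permute(0, 2, 1, 3)
    print(x.shape)  # torch.Size([2, 12, 128, 64])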

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    mixed_query_layer = self.query(hidden_states)

    # ... several branches omitted (cross-attention / cached key-value cases)
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

    query_layer = self.transpose_for_scores(mixed_query_layer)

    # ... decoder-related code omitted here

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    # ... relative-position code omitted here (shown below)

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs

First, look at how query/key/value are computed, taking query as the example:

    mixed_query_layer = self.query(hidden_states)
    # ...
    query_layer = self.transpose_for_scores(mixed_query_layer)

mixed_query_layer has size $batch_size \times seq_length \times all_head_size$; after transpose_for_scores, query_layer has size $batch_size \times num_heads \times seq_length \times attn_size$. key and value are computed the same way. With query/key/value available, the attention scores can be computed:

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_probs = nn.Softmax(dim=-1)(attention_scores)
    attention_probs = self.dropout(attention_probs)
    context_layer = torch.matmul(attention_probs, value_layer)

attention_scores holds the attention weight of every token against every other token in the sequence; its size is $batch_size \times num_heads \times seq_length \times seq_length$. It is normalized into attention_probs, which then passes through a dropout layer and is finally multiplied by each token's value to produce the attention output context_layer, of size $batch_size \times num_heads \times seq_length \times attn_size$. To match the layout of the original hidden_states, a permute first rearranges it to $batch_size \times seq_length \times num_heads \times attn_size$, and view then reshapes it to $batch_size \times seq_length \times all_head_size$.
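
The same chain written out with shape comments (illustrative sizes continuing the example above; a sketch rather than the library code):

    # query/key/value_layer: batch_size=2, num_heads=12, seq_length=128, attn_size=64
    import math
    import torch
    import torch.nn as nn

    query_layer = torch.randn(2, 12, 128, 64)
    key_layer = torch.randn(2, 12, 128, 64)
    value_layer = torch.randn(2, 12, 128, 64)

    scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(64)
    print(scores.shape)                                 # torch.Size([2, 12, 128, 128])

    probs = nn.Softmax(dim=-1)(scores)
    context = torch.matmul(probs, value_layer)
    print(context.shape)                                # torch.Size([2, 12, 128, 64])

    context = context.permute(0, 2, 1, 3).contiguous()  # 2 x 128 x 12 x 64
    context = context.view(2, 128, 12 * 64)
    print(context.shape)                                # torch.Size([2, 128, 768])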

The forward also contains a block handling relative position embeddings:

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

When position_embedding_type is relative_key or relative_key_query, query- and/or key-related relative position terms are added to attention_scores. distance holds the positional offset between every pair of tokens and has size $seq_length \times seq_length$; after the embedding lookup it has size $seq_length \times seq_length \times attn_size$, and einsum folds it into attention_scores.
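
The distance tensor is easy to inspect in isolation; with a small illustrative seq_length (the numbers below are made up):

    # Relative-position indices that feed distance_embedding, with seq_length=4
    # and max_position_embeddings=512 (indices are shifted by 511 to be non-negative).
    import torch

    seq_length, max_position_embeddings = 4, 512
    position_ids_l = torch.arange(seq_length).view(-1, 1)
    position_ids_r = torch.arange(seq_length).view(1, -1)
    distance = position_ids_l - position_ids_r
    print(distance)
    # tensor([[ 0, -1, -2, -3],
    #         [ 1,  0, -1, -2],
    #         [ 2,  1,  0, -1],
    #         [ 3,  2,  1,  0]])
    shifted = distance + max_position_embeddings - 1
    print(shifted.min().item(), shifted.max().item())  # 508 514, valid rows of the (2*512-1)-entry table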

BertSelfAttention outputs (context_layer, [optional] attention_probs, [optional] past_key_value). past_key_value is itself a tuple (key_layer, value_layer) and is only set when is_decoder is True:

    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_layer, value_layer)
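
For completeness, the branch elided in the forward above that consumes this cache (decoder self-attention with a past_key_value) roughly does the following: project the new tokens, then concatenate with the cached key/value along the sequence dimension.

    # Sketch of the elided decoder self-attention branch that reuses the cache:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)      # concat along seq_length
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)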

BertSelfOutput

  • dense: nn.Linear, $hidden_size \times hidden_size$
  • dropout: nn.Dropout, $hidden_dropout_prob$
  • LayerNorm: nn.LayerNorm, $hidden_size$
def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

The LayerNorm input is the dropout output plus input_tensor, i.e. the residual connection comes first and LayerNorm is applied afterwards.

BertAttention

  • self: BertSelfAttention
  • output: BertSelfOutput
def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
    self_outputs = self.self(
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
    )
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs

This component is straightforward: it uses BertSelfOutput to transform the context_layer from the BertSelfAttention output. The final output is (attention_output, [optional] attention_probs, [optional] past_key_value).

BertIntermediate

  • dense: nn.Linear, $hidden_size \times intermediate_size$
  • intermediate_act_fn: activation function (config.hidden_act, gelu by default)
def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states

BertOutput

  • dense: nn.Linear, $intermediate_size \times hidden_size$
  • dropout: nn.Dropout, $hidden_dropout_prob$
  • LayerNorm: nn.LayerNorm, $hidden_size$
def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

Just like BertSelfOutput, the residual connection is applied first and then LayerNorm.

BertLayer

  • attention: BertAttention
  • intermediate: BertIntermediate
  • output: BertOutput

Looking only at the encoder path, BertLayer is fairly simple:

def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
    self_attention_outputs = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=self_attn_past_key_value,
    )
    attention_output = self_attention_outputs[0]

    # ... decoder-related code omitted; for the encoder case, outputs = self_attention_outputs[1:]

    layer_output = apply_chunking_to_forward(
        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
    )
    outputs = (layer_output,) + outputs

    # ... decoder-related code omitted

    return outputs

It takes the attention output directly and transforms it with the feed_forward_chunk method:

def feed_forward_chunk(self, attention_output):
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output)
    return layer_output
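
apply_chunking_to_forward only changes behaviour when chunk_size_feed_forward > 0: it splits attention_output into chunks along the sequence dimension, applies feed_forward_chunk to each chunk, and concatenates the results, which lowers the peak memory used by the intermediate activations. A rough single-tensor stand-in (not the library implementation):

    # Rough stand-in for apply_chunking_to_forward with one input tensor;
    # the library version also validates shapes and accepts multiple tensors.
    import torch

    def chunked_feed_forward(forward_fn, chunk_size, chunk_dim, input_tensor):
        if chunk_size == 0:
            return forward_fn(input_tensor)
        num_chunks = input_tensor.shape[chunk_dim] // chunk_size
        chunks = input_tensor.chunk(num_chunks, dim=chunk_dim)
        return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)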

The output of BertLayer is (layer_output, [optional] attention_probs, [optional] present_key_value).

BertEncoder

  • layer: nn.ModuleList, [BertLayer] * num_hidden_layers

BertEncoder is a stack of num_hidden_layers BertLayer modules. For each layer it computes the output and collects the results; the code is simple, so it is not reproduced here.
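
The core of the loop is roughly the following (a sketch that ignores head_mask, use_cache, gradient checkpointing and the output containers):

    # Each BertLayer consumes the previous layer's hidden_states and produces the next ones.
    for layer_module in self.layer:
        layer_outputs = layer_module(hidden_states, attention_mask)
        hidden_states = layer_outputs[0]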

BertPooler

  • dense: nn.Linear, $hidden_size \times hidden_size$
  • activation: nn.Tanh

BertPooler takes the hidden state of the first token ([CLS]) and applies a linear layer followed by Tanh:

def forward(self, hidden_states):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    first_token_tensor = hidden_states[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output

Bert Head

The Head classes are thin task-specific wrappers around the model output; the code is straightforward.

BertPredictionHeadTransform

  • dense: nn.Linear, $hidden_size \times hidden_size$
  • transform_act_fn: activation function (config.hidden_act)
  • LayerNorm: nn.LayerNorm, $hidden_size$

BertLMPredictionHead

  • transform: BertPredictionHeadTransform
  • decoder: nn.Linear, $hidden_size \times vocab_size$
  • bias: nn.Parameter, $vocab_size$
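
BertLMPredictionHead.forward just runs the transform and then projects to vocabulary logits (bias is tied to decoder.bias); roughly:

    # Roughly what BertLMPredictionHead.forward does:
    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)  # batch_size x seq_length x hidden_size
        hidden_states = self.decoder(hidden_states)    # batch_size x seq_length x vocab_size (token logits)
        return hidden_states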

BertOnlyMLMHead

  • predictions: BertLMPredictionHead

BertPreTrainingHeads

  • predictions: BertLMPredictionHead
  • seq_relationship: nn.Linear, $hidden_size \times 2$

BertOnlyNSPHead

  • seq_relationship: nn.Linear, $hidden_size \times 2$

Bert

The model classes wrap the components above for different tasks. BertModel is the base class, and every task-specific model contains a BertModel.

BertModel

  • embeddings: BertEmbeddings
  • encoder: BertEncoder
  • pooler: BertPooler
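
To see how these three components fit together end to end, here is a minimal usage example with the public transformers API (the checkpoint name is just an example):

    # Minimal usage sketch of BertModel.
    import torch
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    inputs = tokenizer("Hello, BERT!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    print(outputs.last_hidden_state.shape)  # batch_size x seq_length x hidden_size
    print(outputs.pooler_output.shape)      # batch_size x hidden_size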

BertForPreTraining

  • bert: BertModel
  • cls: BertPreTrainingHeads

BertLMHeadModel

  • bert: BertModel
  • cls: BertOnlyMLMHead

BertForMaskedLM

  • bert: BertModel
  • cls: BertOnlyMLMHead

BertForNextSentencePrediction

  • bert: BertModel
  • cls: BertOnlyNSPHead

BertForSequenceClassification

  • bert: BertModel
  • dropout: nn.Dropout, $hidden_dropout_prob$
  • classifier: nn.Linear, $hidden_size \times num_labels$

BertForMultipleChoice

  • bert: BertModel
  • dropout: nn.Dropout, $hidden_dropout_prob$
  • classifier: nn.Linear, $hidden_size \times 1$

BertForTokenClassification

  • bert: BertModel
  • dropout: nn.Dropout, $hidden_dropout_prob$
  • classifier: nn.Linear, $hidden_size \times num_labels$

BertForQuestionAnswering

  • bert: BertModel
  • qa_outputs: nn.Linear, $hidden_size \times num_labels$
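
For question answering, num_labels is 2 and the two channels of qa_outputs are split into start/end position logits; roughly:

    # Illustrative: how the 2 output channels become span logits in BertForQuestionAnswering
    # (sequence_output is the encoder output, batch_size x seq_length x hidden_size).
    logits = self.qa_outputs(sequence_output)            # batch_size x seq_length x 2
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)              # batch_size x seq_length
    end_logits = end_logits.squeeze(-1)                  # batch_size x seq_length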
Written on April 22, 2021