圖解BERT

圖解 Bert

大家可以訪問 圖解Bert 獲取更加優質的閱讀體驗。
在這里插入圖片描述
圖解BERT一文還在持續更新中。


環境搭建

按序執行以下命令完成環境搭建:

git clone https://github.com/DA-southampton/Read_Bert_Code.git
cd Read_Bert_Code
conda create -n Read_Bert_Code python=3.9.22
conda activate Read_Bert_Code

本文使用的是谷歌的中文預訓練模型:chinese_L-12_H-768_A-12.zip,模型有點大,我就不上傳了,如果本地不存在,就點擊這里直接下載,或者直接命令行運行

wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

預訓練模型下載下來之后,進行解壓,然后將tf模型轉為對應的pytorch版本即可。對應代碼如下:

export BERT_BASE_DIR=/Users/zhandaohong/Read_Bert_Code/chinese_L-12_H-768_A-12python convert_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \--bert_config_file $BERT_BASE_DIR/bert_config.json \--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin

轉化成功之后,將模型放入到倉庫對應位置:

Read_Bert_Code/bert_read_step_to_step/prev_trained_model/

并重新命名為:

 bert-base-chinese

其次是準備訓練數據,這里我準備做一個文本分類任務,使用的是Tnews數據集,這個數據集來源是這里,分為訓練,測試和開發集,我已經上傳到了倉庫中,具體位置在

Read_Bert_Code/bert_read_step_to_step/chineseGLUEdatasets/tnews

需要注意的一點是,因為我只是為了了解內部代碼情況,所以準確度不是在我的考慮范圍之內,所以我只是取其中的一部分數據,其中訓練數據使用1k,測試數據使用1k,開發數據1k。

準備就緒,使用pycharm導入項目,準備調試,我的調試文件是run_classifier.py文件,對應的參數為

--model_type=bert --model_name_or_path=prev_trained_model/bert-base-chinese --task_name="tnews" --do_train --do_eval --do_lower_case --data_dir=./chineseGLUEdatasets/tnews --max_seq_length=128 --per_gpu_train_batch_size=16 --per_gpu_eval_batch_size=16 --learning_rate=2e-5 --num_train_epochs=4.0 --logging_steps=100 --save_steps=100 --output_dir=./outputs/tnews_output/ --overwrite_output_dir

然后啟動 run_classifier.py 文件進行調試即可 , 所參考源倉庫未提供requirements.txt文件,因此需要大家自行完成運行時缺失依賴包的安裝。

數據預處理

  1. 輸入數據格式
{"guid": "train-0","label": "104",              // 文本分類任務: 文本對應的標簽"text_a": "股票中的突破形態", "text_b": null               // NSP任務: 用于判斷給出的兩個句子是否連續
}

NSP (Next Sentence Prediction)

  1. 文本分詞 & 借助字典映射為word id
"股票中的突破形態" --> ['股', '票', '中', '的', '突', '破', '形', '態'] --> [5500, 4873, 704, 4638, 4960, 4788, 2501, 2578]

對于字典中不存在的詞 , 用 [UNK] 表示, 對應的id為 100

  1. 過長截斷策略 (待補充)

  2. 添加特殊Token標記

在這里插入圖片描述

[101, 5500, 4873, 704, 4638, 4960, 4788, 2501, 2578, 102]

BertTokenizer中的特殊token id:

  • [CLS]: 101
  • [SEP]: 102
  • [MASK]: 103
  • [UNK]: 100
  • [PAD]: 0
    # BertTokenizerdef build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):if token_ids_1 is None:return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]cls = [self.cls_token_id]sep = [self.sep_token_id]return cls + token_ids_0 + sep + token_ids_1 + sep
  1. 創建句子辨識列表,用以區分不同的句子

在這里插入圖片描述

     # BertTokenizerdef create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):"""Creates a mask from the two sequences passed to be used in a sequence-pair classification task.A BERT sequence pair mask has the following format:0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1| first sequence    | second sequenceif token_ids_1 is None, only returns the first portion of the mask (0's)."""sep = [self.sep_token_id]cls = [self.cls_token_id]if token_ids_1 is None:return len(cls + token_ids_0 + sep) * [0]return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
  1. 創建用以區分special tokens部分的mask列表

在這里插入圖片描述

    # BertTokenizerdef get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):if token_ids_1 is not None:return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]return [1] + ([0] * len(token_ids_0)) + [1]
  1. 超長截斷
       # PreTrainedTokenizerif max_length and len(encoded_inputs["input_ids"]) > max_length:encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
  1. 生成padding部分的mask列表

在這里插入圖片描述

        # 生成注意力掩碼,真實token對應1,填充token對應0attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
  1. 所有序列都填充到max_length長度,不足長度用padding填充

在這里插入圖片描述

        # 記錄輸入長度input_len = len(input_ids)# 計算需要填充的長度 --- 所有輸入序列等長,都等于max_lengthpadding_length = max_length - len(input_ids)# 右填充input_ids = input_ids + ([pad_token] * padding_length)attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
  1. 數據集中每一個樣本最終都會解析得到一個InputFeatures

在這里插入圖片描述

features.append(InputFeatures(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,label=label,input_len=input_len))

label 是當前文本對應的類別標簽
input_len 是序列實際長度(含special tokens)

  1. 數據集預處理完后,將InputFeatures List列表組裝起來得到需要的DataSet
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_lens,all_labels)

模型架構

DataLoader

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,collate_fn=collate_fn)

DataLoader 設置的回調方法cllote_fn負責對返回的一個batch,在返回前進行預處理:

def collate_fn(batch):all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))max_len = max(all_lens).item() # 計算當前批次中所有序列的實際最大長度all_input_ids = all_input_ids[:, :max_len] # 按照本批次序列中最大長度進行截斷: max_length --> max_lenall_attention_mask = all_attention_mask[:, :max_len]all_token_type_ids = all_token_type_ids[:, :max_len]return all_input_ids, all_attention_mask, all_token_type_ids, all_labels

BertEmbeddings

在這里插入圖片描述

class BertEmbeddings(nn.Module):def __init__(self, config):super(BertEmbeddings, self).__init__()self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, input_ids, token_type_ids=None, position_ids=None):seq_length = input_ids.size(1)if position_ids is None: # 為當前批次中的每個序列樣本生成一個位置序列: (1,2,3,4,5,...) , 構成一個位置序列矩陣position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)position_ids = position_ids.unsqueeze(0).expand_as(input_ids)if token_type_ids is None:token_type_ids = torch.zeros_like(input_ids)words_embeddings = self.word_embeddings(input_ids)position_embeddings = self.position_embeddings(position_ids) # 位置編碼為可學習的矩陣token_type_embeddings = self.token_type_embeddings(token_type_ids) # 讓模型自己學會區分不同的句子embeddings = words_embeddings + position_embeddings + token_type_embeddingsembeddings = self.LayerNorm(embeddings)embeddings = self.dropout(embeddings)return embeddings

在這里插入圖片描述

BertEncoder

BertLayer

在這里插入圖片描述

class BertIntermediate(nn.Module):def __init__(self, config):super(BertIntermediate, self).__init__()self.dense = nn.Linear(config.hidden_size, config.intermediate_size) # (768,3072)# 激活函數 - GLEUif isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):self.intermediate_act_fn = ACT2FN[config.hidden_act]else:self.intermediate_act_fn = config.hidden_actdef forward(self, hidden_states):hidden_states = self.dense(hidden_states)hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函數 - GLEUreturn hidden_statesclass BertOutput(nn.Module):def __init__(self, config):super(BertOutput, self).__init__()self.dense = nn.Linear(config.intermediate_size, config.hidden_size) # (3072,768)self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, hidden_states, input_tensor):hidden_states = self.dense(hidden_states)hidden_states = self.dropout(hidden_states)hidden_states = self.LayerNorm(hidden_states + input_tensor)return hidden_statesclass BertLayer(nn.Module):def __init__(self, config):super(BertLayer, self).__init__()self.attention = BertAttention(config)self.intermediate = BertIntermediate(config)self.output = BertOutput(config)def forward(self, hidden_states, attention_mask=None):attention_output = self.attention(hidden_states, attention_mask)intermediate_output = self.intermediate(attention_output)layer_output = self.output(intermediate_output, attention_output)return layer_output
BertEncoder

在這里插入圖片描述

class BertEncoder(nn.Module):def __init__(self, config):super(BertEncoder, self).__init__()self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])def forward(self, hidden_states, attention_mask=None, head_mask=None):for i, layer_module in enumerate(self.layer):hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])return hidden_states

BertPooler

在這里插入圖片描述

class BertPooler(nn.Module):def __init__(self, config):super(BertPooler, self).__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.activation = nn.Tanh()def forward(self, hidden_states):# We "pool" the model by simply taking the hidden state corresponding# to the first token.first_token_tensor = hidden_states[:, 0] # CLS Token Context Embeddingspooled_output = self.dense(first_token_tensor)pooled_output = self.activation(pooled_output)return pooled_output

BertModel

在這里插入圖片描述

class BertModel(BertPreTrainedModel):def __init__(self, config):super(BertModel, self).__init__(config)self.embeddings = BertEmbeddings(config)self.encoder = BertEncoder(config)self.pooler = BertPooler(config)self.init_weights()def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibilityextended_attention_mask = (1.0 - extended_attention_mask) * -10000.0embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)sequence_output = self.encoder(embedding_output,extended_attention_mask, # padding mask)pooled_output = self.pooler(sequence_output)outputs = (sequence_output, pooled_output,)return outputs

BertForSequenceClassification

在這里插入圖片描述

class BertForSequenceClassification(BertPreTrainedModel):def __init__(self, config):super(BertForSequenceClassification, self).__init__(config)self.num_labels = config.num_labelsself.bert = BertModel(config)self.dropout = nn.Dropout(config.hidden_dropout_prob)self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)self.init_weights()def forward(self, input_ids, attention_mask=None, token_type_ids=None,position_ids=None, head_mask=None, labels=None):outputs = self.bert(input_ids,attention_mask=attention_mask, # padding masktoken_type_ids=token_type_ids,position_ids=position_ids, head_mask=head_mask) # None ?pooled_output = outputs[1] # 對于分類任務來說,只需要去除CLS Token用于分類任務即可pooled_output = self.dropout(pooled_output)logits = self.classifier(pooled_output)outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are hereif labels is not None:if self.num_labels == 1:#  We are doing regressionloss_fct = MSELoss()loss = loss_fct(logits.view(-1), labels.view(-1))else:loss_fct = CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))outputs = (loss,) + outputsreturn outputs  # (loss), logits, (hidden_states), (attentions)

BertAttention

BertSelfAttention

在這里插入圖片描述

class BertSelfAttention(nn.Module):def __init__(self, config):super(BertSelfAttention, self).__init__()self.output_attentions = config.output_attentionsself.num_attention_heads = config.num_attention_headsself.attention_head_size = int(config.hidden_size / config.num_attention_heads)self.all_head_size = self.num_attention_heads * self.attention_head_sizeself.query = nn.Linear(config.hidden_size, self.all_head_size)self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)self.dropout = nn.Dropout(config.attention_probs_dropout_prob)def transpose_for_scores(self, x):new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(*new_x_shape)return x.permute(0, 2, 1, 3)def forward(self, hidden_states, attention_mask=None, head_mask=None):mixed_query_layer = self.query(hidden_states)mixed_key_layer = self.key(hidden_states)mixed_value_layer = self.value(hidden_states)# view 成多頭格式: (batch,heads,seq_len,d_k)query_layer = self.transpose_for_scores(mixed_query_layer)key_layer = self.transpose_for_scores(mixed_key_layer)value_layer = self.transpose_for_scores(mixed_value_layer)# Take the dot product between "query" and "key" to get the raw attention scores.attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (batch,heads,d_k,seq_len)attention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:# Apply the attention mask is (precomputed for all layers in BertModel forward() function)attention_scores = attention_scores + attention_mask# Normalize the attention scores to probabilities.attention_probs = nn.Softmax(dim=-1)(attention_scores)# This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs = self.dropout(attention_probs)context_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape) # 合并頭結果return context_layer
BertSelfOutput

在這里插入圖片描述

class BertSelfOutput(nn.Module):def __init__(self, config):super(BertSelfOutput, self).__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)# 殘差鏈接 + 層歸一化def forward(self, hidden_states, input_tensor):hidden_states = self.dense(hidden_states)hidden_states = self.dropout(hidden_states)hidden_states = self.LayerNorm(hidden_states + input_tensor)return hidden_states
BertAttention

在這里插入圖片描述

class BertAttention(nn.Module):def __init__(self, config):super(BertAttention, self).__init__()self.self = BertSelfAttention(config)self.output = BertSelfOutput(config)def forward(self, input_tensor, attention_mask=None):self_outputs = self.self(input_tensor, attention_mask) # 多頭自注意力機制attention_output = self.output(self_outputs, input_tensor)return attention_output

預訓練

在這里插入圖片描述

BertPredictionHeadTransform

在這里插入圖片描述

class BertPredictionHeadTransform(nn.Module):def __init__(self, config):super(BertPredictionHeadTransform, self).__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):self.transform_act_fn = ACT2FN[config.hidden_act]else:self.transform_act_fn = config.hidden_actself.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)def forward(self, hidden_states):hidden_states = self.dense(hidden_states)hidden_states = self.transform_act_fn(hidden_states)hidden_states = self.LayerNorm(hidden_states)return hidden_states

BertLMPredictionHead

在這里插入圖片描述

class BertLMPredictionHead(nn.Module):def __init__(self, config):super(BertLMPredictionHead, self).__init__()self.transform = BertPredictionHeadTransform(config)# The output weights are the same as the input embeddings, but there is# an output-only bias for each token.self.decoder = nn.Linear(config.hidden_size,config.vocab_size,bias=False)self.bias = nn.Parameter(torch.zeros(config.vocab_size))def forward(self, hidden_states):hidden_states = self.transform(hidden_states)hidden_states = self.decoder(hidden_states) + self.biasreturn hidden_states

BertPreTrainingHeads

在這里插入圖片描述

class BertPreTrainingHeads(nn.Module):def __init__(self, config):super(BertPreTrainingHeads, self).__init__()self.predictions = BertLMPredictionHead(config)self.seq_relationship = nn.Linear(config.hidden_size, 2)def forward(self, sequence_output, pooled_output):prediction_scores = self.predictions(sequence_output) #seq_relationship_score = self.seq_relationship(pooled_output) # 兩個句子是否為上下句關系return prediction_scores, seq_relationship_score

BertForPreTraining

在這里插入圖片描述

class BertForPreTraining(BertPreTrainedModel):def __init__(self, config):super(BertForPreTraining, self).__init__(config)self.bert = BertModel(config)self.cls = BertPreTrainingHeads(config)def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,masked_lm_labels=None, next_sentence_label=None):outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids, head_mask=head_mask)sequence_output, pooled_output = outputs[:2] # 隱藏層輸出,CLS Token Embeddingsprediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)outputs = (prediction_scores, seq_relationship_score,)# 計算掩碼語言損失 和 下一個句子預測損失if masked_lm_labels is not None and next_sentence_label is not None:loss_fct = CrossEntropyLoss(ignore_index=-1)masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))total_loss = masked_lm_loss + next_sentence_lossoutputs = (total_loss,) + outputsreturn outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)

其他下游任務

在這里插入圖片描述

問答任務

在 BERT 的問答任務中,典型的輸入是一個包含 問題(Question)上下文(Context) 的文本對。例如:

問題: “誰寫了《哈姆雷特》?”
上下文: “莎士比亞是英國文學史上最偉大的作家之一,他寫了包括《哈姆雷特》、《麥克白》等著名悲劇。”

  1. 輸入格式(Tokenization 后的形式),在使用 BertTokenizer 編碼后,輸入會變成如下結構:
[CLS] 問題 tokens [SEP] 上下文 tokens [SEP]
  1. BERT 的輸出(Outputs),通過調用 self.bert(...),你將得到一個包含多個元素的 tuple 輸出:
outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

返回值形如:

(sequence_output,          # (batch_size, seq_length, hidden_size)pooled_output,            # (batch_size, hidden_size)
)

主要輸出項解釋:

? sequence_output: 最終每個 token 的表示

  • 形狀:(batch_size, seq_length, hidden_size)
  • 是模型最后一層所有 token(包括問題和上下文)的隱藏狀態。
  • 在問答任務中,我們主要使用它來預測答案的起始和結束位置。

? pooled_output: 句子級別表示(不常用)

  • 形狀:(batch_size, hidden_size)
  • [CLS] token 經過一層全連接后的輸出。
  • 在分類任務中更有用,在問答任務中一般不會使用這個輸出。
  1. 如何利用 BERT 輸出做問答預測?

BertForQuestionAnswering 中,使用了如下邏輯:

logits = self.qa_outputs(sequence_output)  # (batch_size, seq_length, 2)
start_logits, end_logits = logits.split(1, dim=-1)  # split into start and end
start_logits = start_logits.squeeze(-1)  # (batch_size, seq_length)
end_logits = end_logits.squeeze(-1)

qa_outputs 層的作用:

  • 是一個線性層:nn.Linear(config.hidden_size, 2)
  • 將每個 token 的 hidden_size 向量映射成兩個分數:一個是該 token 作為答案開始的可能性,另一個是作為答案結束的可能性。

輸出解釋:

  • start_logits: 每個 token 是答案起點的得分(未歸一化)。
  • end_logits: 每個 token 是答案終點的得分。

比如對于一個長度為 128 的序列,每個 token 都有一個對應的 start/end 分數:

start_scores = torch.softmax(start_logits, dim=-1)  # softmax 得到概率
end_scores = torch.softmax(end_logits, dim=-1)# 找出最可能是 start 和 end 的位置
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)

如果 start_index <= end_index,那么可以組合這兩個索引得到答案 span。

代碼實現
class BertForQuestionAnswering(BertPreTrainedModel):def __init__(self, config):super(BertForQuestionAnswering, self).__init__(config)self.num_labels = config.num_labels # 通常是 2,即 start 和 endself.bert = BertModel(config)self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,start_positions=None, end_positions=None):outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids)sequence_output = outputs[0]# (batch,seq_len,hidden_size) ---> (batch,seq_len,2)logits = self.qa_outputs(sequence_output)start_logits, end_logits = logits.split(1, dim=-1)start_logits = start_logits.squeeze(-1) # (batch,seq_len)end_logits = end_logits.squeeze(-1)outputs = (start_logits, end_logits,)# 計算交叉熵損失if start_positions is not None and end_positions is not None:# sometimes the start/end positions are outside our model inputs, we ignore these terms# ignored_index = seq_lenignored_index = start_logits.size(1)# clamp_ 是 PyTorch 中的一個方法,用于將張量中的值限制在指定的范圍內。# 它的語法是 tensor.clamp_(min, max) ,表示將張量中的值限制在 min 和 max 之間。# 如果值小于 min ,則將其設置為 min ;如果值大于 max ,則將其設置為 max 。start_positions.clamp_(0, ignored_index)end_positions.clamp_(0, ignored_index)# ignore_index: 用于指定在計算損失時忽略的標簽索引。 loss_fct = CrossEntropyLoss(ignore_index=ignored_index)# 分別計算答案起始下標和結束下標預測得到的交叉熵損失start_loss = loss_fct(start_logits, start_positions)end_loss = loss_fct(end_logits, end_positions)total_loss = (start_loss + end_loss) / 2outputs = (total_loss,) + outputsreturn outputs  # (loss), start_logits, end_logits
易混淆

BERT 是一個 基于上下文編碼(Contextual Encoder) 的模型,不是自回歸生成器。它不會“生成”新的文本,而是對輸入文本中每個 token 的角色進行分類(如判斷哪個是答案的開始、結束)。所以最終的答案只能來自原始輸入文本中的某一段子串。

📚 詳細解釋

  1. ? BERT 是一個 Encoder-only 模型
  • BERT 只包含 Transformer 的 encoder 部分。

  • 它的作用是給定一個完整的句子(或兩個句子),對每個 token 生成一個上下文相關的表示(contextualized representation)。

  • 不具有生成能力,不能像 GPT 這樣的 decoder-only 模型那樣逐詞生成新內容。


  1. 🔍 QA 任務的本質:定位答案 span 而非生成答案

在 SQuAD 這類抽取式問答任務中:

  • 答案必須是原文中的連續片段(span)。

  • 所以模型的任務是:

    • 給出問題和上下文;

    • 在上下文中找到最可能的答案起始位置和結束位置;

    • 最終答案就是上下文中這兩個位置之間的字符串。

BERT 做的就是這個定位任務,而不是重新生成一個新的答案。


  1. 🧩 輸入與輸出的關系
answer_tokens = input_ids[0][start_index : end_index + 1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

這段代碼的意思是:

  • start_indexend_index 是模型預測出的答案的起始和結束位置。

  • 我們從原始輸入的 input_ids 中取出對應的 token ID 子序列。

  • 使用 tokenizer 把這些 token ID 解碼成自然語言文本。

  • 得到的就是答案。

這其實就是在說:

“根據你的理解,答案應該在這段文字中的第 X 到第 Y 個詞之間,請把這部分原文告訴我。”


  1. 🧪 舉個例子

假設原始上下文是:

The capital of France is Paris.

經過 Tokenizer 編碼后可能是:

[CLS] the capital of france is paris [SEP]

如果模型預測 start_index=5,end_index=5,那么對應的就是單詞 "paris",這就是答案。


?? 注意事項

  1. 不能超出上下文范圍

    • start/end positions 必須落在上下文部分(即 token_type_id == 1 的區域)。
    • 否則答案可能不合理(比如取到了問題部分的內容)。
  2. 特殊 token 不計入答案

    • [CLS], [SEP] 等會被 skip_special_tokens=True 自動跳過。
  3. 無法處理不在原文中的答案

    • 如果正確答案沒有出現在上下文中,BERT 無法“編造”出來。
    • 這是抽取式問答模型的局限性。

💡 對比:生成式 vs 抽取式問答

類型模型代表是否能生成新文本答案是否必須在原文中示例
抽取式BERT??答案是原文中的一段
生成式T5 / BART / GPT??答案可以是任意文本

如果你希望模型能“自己寫答案”,那就需要使用生成式模型。


? 總結

問題回答
為什么答案來自 input_ids因為 BERT 是編碼器模型,只做抽取式問答,答案必須是原文中的一段文本。
BERT 能不能自己生成答案?不能,BERT 不具備生成能力,只能對輸入文本中的 token 做分類。
如何獲取答案?根據預測的 start/end index,從 input_ids 中提取 token,并用 tokenizer 解碼成自然語言。

Token分類任務

Token 分類任務是指對輸入文本中的每個 token 進行分類,常見的應用場景包括:

  • 命名實體識別 (NER)
  • 詞性標注 (POS)
  • 語義角色標注 (SRL)
class BertForTokenClassification(BertPreTrainedModel):def __init__(self, config):super(BertForTokenClassification, self).__init__(config)self.num_labels = config.num_labelsself.bert = BertModel(config)self.dropout = nn.Dropout(config.hidden_dropout_prob)self.classifier = nn.Linear(config.hidden_size, config.num_labels)def forward(self, input_ids, attention_mask=None, token_type_ids=None,position_ids=None, head_mask=None, labels=None):outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids, head_mask=head_mask)sequence_output = outputs[0] # (batch,seq_len,hidden_size)sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output) # (batch,seq_len,num_labels)outputs = (logits,)if labels is not None:loss_fct = CrossEntropyLoss()# Only keep active parts of the lossif attention_mask is not None:active_loss = attention_mask.view(-1) == 1active_logits = logits.view(-1, self.num_labels)[active_loss]active_labels = labels.view(-1)[active_loss]loss = loss_fct(active_logits, active_labels)else:loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))outputs = (loss,) + outputsreturn outputs  # (loss), scores

多項選擇任務

多項選擇任務是指給定一個問題和多個候選答案,模型需要從中選擇最合適的答案。常見的應用場景包括:

  • 閱讀理解任務

  • 問答系統中的候選答案選擇

  • 對話系統中的候選回復選擇

在 多項選擇題(Multiple Choice) 任務中,BERT 的輸入組織形式與普通分類或問答任務略有不同。你需要為每個選項分別構造一個完整的 BERT 輸入序列,并將它們組合成一個批次進行處理。

? 假設你有一個問題 + 4 個選項:

問題:誰寫了《哈姆雷特》?
A. 雨果
B. 歌德
C. 莎士比亞
D. 托爾斯泰

對于這樣的多選問題,BERT 的輸入方式是:

對每一個選項,都單獨構造一個 [CLS] + 問題 + [SEP] + 選項內容 + [SEP] 的輸入序列。

也就是說,模型會對每個選項分別編碼 ,然后從中選出最合適的那個。

class BertForMultipleChoice(BertPreTrainedModel):def __init__(self, config):super(BertForMultipleChoice, self).__init__(config)self.bert = BertModel(config)self.dropout = nn.Dropout(config.hidden_dropout_prob)self.classifier = nn.Linear(config.hidden_size, 1)def forward(self, input_ids, attention_mask=None, token_type_ids=None,position_ids=None, head_mask=None, labels=None):# 獲取選項個數        num_choices = input_ids.shape[1] # (batch_size, num_choices, seq_length)# 將選項展平,以便一起處理: (batch_size * num_choices, seq_length)input_ids = input_ids.view(-1, input_ids.size(-1))attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else Nonetoken_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else Noneposition_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else Noneoutputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask)pooled_output = outputs[1] # (batch_size * num_choices, hidden_size)pooled_output = self.dropout(pooled_output)logits = self.classifier(pooled_output) # (batch_size * num_choices, 1)reshaped_logits = logits.view(-1, num_choices) # (batch_size , num_choices, 1)outputs = (reshaped_logits,)if labels is not None:loss_fct = CrossEntropyLoss()loss = loss_fct(reshaped_logits, labels)outputs = (loss,) + outputsreturn outputs  # (loss), reshaped_logits, (hidden_states), (attentions)

在前向傳播中,會將這些輸入展平,變成:

input_ids.view(-1, seq_length)  # (batch_size * num_choices, seq_length)

這樣就能讓 BERT 對每個選項分別進行編碼。

BERT 輸出后,再對每個選項做分類打分,最后重新 reshape 成 (batch_size, num_choices) 形式,用于計算交叉熵損失。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82377.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82377.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82377.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【HarmonyOS 5】鴻蒙中的UIAbility詳解(一)

【HarmonyOS 5】鴻蒙中的UIAbility詳解&#xff08;一&#xff09; 一、UIAbility是什么&#xff1f; Stage模型中的組件類型名&#xff0c;即UIAbility組件&#xff0c;包含UI&#xff0c;提供展示UI的能力&#xff0c;主要用于和用戶交互。 UIAbility類似于傳統移動開發An…

Transformer預訓練模型微調技術全解析

引言:Transformer預訓練模型與微調的浪潮 近年來,人工智能領域取得了令人矚目的成就,特別是在自然語言處理(NLP)方面。引領這場變革的核心技術之一便是Transformer架構。自2017年 Vaswani 等人在論文 "Attention Is All You Need" 中提出以來,Transformer憑借…

《算法筆記》12.2小節——字符串專題->KMP算法 問題 C: 剪花布條

題目描述 一塊花布條&#xff0c;里面有些圖案&#xff0c;另有一塊直接可用的小飾條&#xff0c;里面也有一些圖案。對于給定的花布條和小飾條&#xff0c;計算一下能從花布條中盡可能剪出幾塊小飾條來呢&#xff1f; 輸入 輸入中含有一些數據&#xff0c;分別是成對出現的…

實現一個前端動態模塊組件(Vite+原生JS)

1. 引言 在前面的文章《使用Vite創建一個動態網頁的前端項目》中我們實現了一個動態網頁。不過這個動態網頁的實用價值并不高&#xff0c;在真正實際的項目中我們希望的是能實現一個動態的模塊組件。具體來說&#xff0c;就是有一個頁面控件同時在多個頁面中使用&#xff0c;那…

NTFS0x90屬性和0xa0屬性和0xb0屬性的一一對應關系是index_entry中的index_node中VCN和runlist和bitmap

第一部分&#xff1a; 0: kd> dt _FILE_RECORD_SEGMENT_HEADER 0xc1241400 Ntfs!_FILE_RECORD_SEGMENT_HEADER 0x000 MultiSectorHeader : _MULTI_SECTOR_HEADER 0x008 Lsn : _LARGE_INTEGER 0x80e74aa 0x010 SequenceNumber : 5 0x012 Referen…

PCB 通孔是電容性的,但不一定是電容器

哼&#xff1f;……這是什么意思&#xff1f;…… 多年來&#xff0c;流行的觀點是 PCB 通孔本質上是電容性的&#xff0c;因此可以用集總電容器進行建模。雖然當信號的上升時間大于或等于過孔不連續性延遲的 3 倍時&#xff0c;這可能是正確的&#xff0c;但我將向您展示為什…

Flutter 3.32 新特性

2天前&#xff0c;Flutter發布了最新版本3.32&#xff0c;我們來一起看下29到32有哪些變化。 簡介 歡迎來到Flutter 3.32&#xff01;此版本包含了旨在加速開發和增強應用程序的功能。準備好在網絡上進行熱加載&#xff0c;令人驚嘆的原生保真Cupertino&#xff0c;以及與Fir…

漢諾塔超級計算機數據區結構和源代碼詳細設計

### 數據區結構與源代碼詳細設計 基于"滿秩二叉樹"存儲模型的設計理念&#xff0c;我設計了以下數據區結構和實現方案&#xff1a; #### 1. 滿秩二叉樹存儲模型 **數據結構設計**&#xff1a; python class TreeNode: """二叉樹節點結構&#xff0c…

GitHub Copilot 現已支持 AI Coding Agent

VS Code 開始越來越像 Cursor 和 WindSurf 了。 這周,GitHub 發布了一個新的編程代理,直接嵌入到 GitHub 中。當你將 GitHub 問題分配給 Copilot 或在 VS Code 中提示它時,該代理會啟動一個由 GitHub Actions 驅動的安全且完全可定制的開發環境。 這一公告來自微軟首席執行…

【辰輝創聚生物】FGF信號通路相關蛋白:解碼生命調控的關鍵樞紐

在生命科學的探索旅程中&#xff0c;成纖維細胞生長因子&#xff08;Fibroblast Growth Factor&#xff0c;FGF&#xff09;信號通路猶如精密儀器中的核心齒輪&#xff0c;驅動著眾多生命活動的有序進行。FGF 信號通路相關蛋白作為該通路的重要組成部分&#xff0c;其結構與功能…

算法的學習筆記— 構建乘積數組(牛客JZ66)

構建乘積數組 1. 問題背景與描述 1.1 題目來源與鏈接 本題來源于NowCoder在線編程平臺&#xff0c;是劍指Offer系列面試題中的經典問題。題目鏈接為&#xff1a;NowCoder。該問題在算法面試中出現頻率較高&#xff0c;主要考察數組操作和數學思維。 1.2 問題描述與要求 給…

SpringBoot+ELK 搭建日志監控平臺

ELK 簡介 ELK&#xff08;Elasticsearch, Logstash, Kibana&#xff09;是一個目前主流的開源日志監控平臺。由三個主要組件組成的&#xff1a; Elasticsearch&#xff1a; 是一個開源的分布式搜索和分析引擎&#xff0c;可以用于全文檢索、結構化檢索和分析&#xff0c;它構建…

python36

仔細回顧一下神經網絡到目前的內容&#xff0c;沒跟上進度的同學補一下進度。 作業&#xff1a;對之前的信貸項目&#xff0c;利用神經網絡訓練下&#xff0c;嘗試用到目前的知識點讓代碼更加規范和美觀。 # 先運行之前預處理好的代碼 import pandas as pd import pandas as pd…

SGlang 推理模型優化(PD架構分離)

一、技術背景 隨著大型語言模型&#xff08;LLM&#xff09;廣泛應用于搜索、內容生成、AI助手等領域&#xff0c;對模型推理服務的并發能力、響應延遲和資源利用效率提出了前所未有的高要求。與模型訓練相比&#xff0c;推理是一個持續進行、資源消耗巨大的任務&#xff0c;尤…

模型實戰(28)之 yolov5分類模型 訓練自己的數據集

模型實戰(28)之 yolov5分類模型 訓練自己的數據集 本文以手寫數字數據集為例總結YOLO分類模型如何訓練自己的數據集,關于數據集的預處理可以看這篇:https://blog.csdn.net/yohnyang/article/details/148209978?spm=1001.2014.3001.5502 yolov5曾是在 2021-2023 年十分流行…

醫學寫作人才管理策略

1. 人才選擇:精準定位核心能力 1.1 人才篩選標準 1.1.1 硬性要求 初創生物制藥公司醫學寫作崗位對專業背景要求嚴格,候選人需具備醫學、藥學或生物學碩士及以上學歷,博士優先。同時,熟悉ICH、FDA/EMA等法規指南是必備條件,且至少有1-3年醫學寫作經驗,或相關領域如臨床研…

Axure酒店管理系統原型

酒店管理系統通常被設計為包含多個模塊或界面&#xff0c;以支持酒店運營的不同方面和參與者。其中&#xff0c;管理端和商戶端是兩個核心組成部分&#xff0c;它們各自承擔著不同的職責和功能。 軟件版本&#xff1a;Axure RP 9 預覽地址&#xff1a;https://556i1e.axshare.…

云原生安全之HTTP協議:從基礎到實戰的安全指南

&#x1f525;「炎碼工坊」技術彈藥已裝填&#xff01; 點擊關注 → 解鎖工業級干貨【工具實測|項目避坑|源碼燃燒指南】 一、基礎概念&#xff1a;HTTP協議的核心要素 HTTP&#xff08;HyperText Transfer Protocol&#xff09;是云原生應用中客戶端與服務器通信的基礎協議&a…

怎樣解決photoshop閃退問題

檢查系統資源&#xff1a;在啟動 Photoshop 之前&#xff0c;打開任務管理器檢查 CPU 和內存的使用情況。如果發現資源占用過高&#xff0c;嘗試關閉不必要的程序或重啟計算機以釋放資源。更新 Photoshop 版本&#xff1a;確保 Photoshop 是最新版本。Adobe 經常發布更新以修復…

修復ubuntu server筆記本合蓋導致的無線網卡故障

下班回到家發現走時還好的局域網 ubuntu server 24 連不上了&#xff0c;趕緊打開筆記本查看下原因&#xff0c;發現控制臺出了一堆看不懂的內容&#xff1a; 根據搜索結果&#xff0c;筆記本合蓋導致無線網卡故障可能與電源管理設置和系統休眠策略有關&#xff0c;以下是具體…