model.classifier:分類頭
分類頭(model.classifier)含義
在基于Transformer架構的模型(如BERT、GPT等 )用于分類任務時,“分類頭(model.classifier)” 是模型的一個重要組成部分。以Hugging Face的Transformers庫為例,許多預訓練模型在完成通用的預訓練任務(如語言建模 )后,為適配具體的分類任務(如情感分析、主題分類 ),會在模型的基礎上添加一個全連接層,這個全連接層就被稱為分類頭。
具體來說,在分類任務中,Transformer模型首先通過編碼器(如BERT的多層雙向Transformer編碼器 )對輸入文本進行特征提取,將輸入序列編碼為特征向量。然后,這些特征向量會被輸入到分類頭(model.classifier ),分類頭再將這些特征映射到不同類別的概率上。比如在二分類的情感分析任務中,分類頭會輸出文本屬于“正面”和“負面”情感的概率,通過比較這兩個概率來確定文本的情感傾向。
舉個例子,使用BERT進行影評情感分析,代碼如下: