在嘗試復現極客專欄《PyTorch 深度學習實戰|24 | 文本分類:如何使用BERT構建文本分類模型?》時候,構建模型這一步驟專欄老師一筆帶過,對于新手有些不友好,經過一陣摸索,終于調通了,現在總結一下整體流程。
1. 獲取必要腳本文件
首先,我們需要從 Transformers 的 GitHub 倉庫中找到相關文件:
# 克隆 Transformers 倉庫
git clone https://github.com/huggingface/transformers.git
cd transformers
在倉庫中,我們需要找到以下關鍵文件:
src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py
(用于 TF1.x 模型)src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py
(用于 TF2.x 模型)src/transformers/models/bert/modeling_bert.py
(BERT 的 PyTorch 實現)
2. 下載預訓練模型
接下來,我們需要下載 Google 提供的預訓練 BERT 模型。根據你的需求,我們選擇"BERT-Base, Multilingual Cased"版本,它支持104種語言。
訪問 Google 的 BERT GitHub 頁面:https://github.com/google-research/bert
在該頁面中找到"BERT-Base, Multilingual Cased"的下載鏈接,或直接使用以下命令下載:
mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下載模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip
解壓后,你會得到以下文件:
bert_model.ckpt.data-00000-of-00001
bert_model.ckpt.index
bert_model.ckpt.meta
bert_config.json
vocab.txt
3. 模型轉換
現在,我們使用之前找到的轉換腳本將 TensorFlow 模型轉換為 PyTorch 格式:
# 回到 transformers 目錄
cd ../transformers# 執行轉換腳本(針對 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin
如果你下載的是 TF1.x 模型,則使用:
python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin
注意,此處需要安裝tensorflow。
4. 準備完整的 PyTorch 模型目錄
轉換完成后,我們需要確保模型目錄包含所有必要文件:
cd ../bert-base-multilingual-cased# 復制 bert_config.json 為 config.json(Transformers 庫需要)
cp bert_config.json config.json
現在,你的模型目錄應該包含以下三個關鍵文件:
config.json
:模型配置文件,包含了所有用于訓練的參數設置pytorch_model.bin
:轉換后的 PyTorch 模型權重文件vocab.txt
:詞表文件,用于識別模型支持的各種語言的字符
5. 驗證模型轉換成功
為了驗證模型轉換是否成功,我們可以編寫一個簡單的腳本來加載模型并進行測試:
from transformers import BertTokenizer, BertModel# 加載模型和分詞器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 測試多語言能力
texts = ["Hello, how are you?", # 英語"你好,最近怎么樣?", # 中文"Hola, ?cómo estás?" # 西班牙語
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")
6. 使用模型進行下游任務
現在你可以使用這個轉換好的模型進行各種下游任務,如文本分類、命名實體識別等:
from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加載模型和分詞器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分類模型(假設有2個類別)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 準備輸入
text = "這是一個測試文本"
inputs = tokenizer(text, return_tensors="pt")# 前向傳播
outputs = model(**inputs)
logits = outputs.logits# 獲取預測結果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"預測類別: {predicted_class}")
注意事項
-
模型文件大小:BERT-Base 模型文件通常較大(約400MB+),請確保有足夠的磁盤空間和內存。
-
路徑問題:在執行轉換腳本時,確保正確指定了所有文件的路徑。
-
命名約定:Transformers 庫期望配置文件名為
config.json
,而不是bert_config.json
,所以需要進行復制或重命名。 -
TensorFlow 版本:根據你下載的模型版本(TF1.x 或 TF2.x),選擇正確的轉換腳本。
-
checkpoint 文件:轉換腳本中的
--tf_checkpoint_path
參數應該指向不帶后綴的 checkpoint 文件名(如bert_model.ckpt
),而不是具體的.index
或.data
文件。
通過以上步驟,你就可以成功地將 Google 預訓練的 BERT 模型轉換為 PyTorch 格式,并在你的項目中使用它了。這個多語言版本的 BERT 模型支持 104 種語言,非常適合多語言自然語言處理任務。