BERT 模型準備與轉換詳細操作流程

在嘗試復現極客專欄《PyTorch 深度學習實戰|24 | 文本分類:如何使用BERT構建文本分類模型?》時候,構建模型這一步驟專欄老師一筆帶過,對于新手有些不友好,經過一陣摸索,終于調通了,現在總結一下整體流程。

1. 獲取必要腳本文件

首先,我們需要從 Transformers 的 GitHub 倉庫中找到相關文件:

# 克隆 Transformers 倉庫
git clone https://github.com/huggingface/transformers.git
cd transformers

在倉庫中,我們需要找到以下關鍵文件:

  • src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py(用于 TF1.x 模型)
  • src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py(用于 TF2.x 模型)
  • src/transformers/models/bert/modeling_bert.py(BERT 的 PyTorch 實現)

2. 下載預訓練模型

接下來,我們需要下載 Google 提供的預訓練 BERT 模型。根據你的需求,我們選擇"BERT-Base, Multilingual Cased"版本,它支持104種語言。

訪問 Google 的 BERT GitHub 頁面:https://github.com/google-research/bert

在該頁面中找到"BERT-Base, Multilingual Cased"的下載鏈接,或直接使用以下命令下載:

mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下載模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip

解壓后,你會得到以下文件:

  • bert_model.ckpt.data-00000-of-00001
  • bert_model.ckpt.index
  • bert_model.ckpt.meta
  • bert_config.json
  • vocab.txt

3. 模型轉換

現在,我們使用之前找到的轉換腳本將 TensorFlow 模型轉換為 PyTorch 格式:

# 回到 transformers 目錄
cd ../transformers# 執行轉換腳本(針對 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin

如果你下載的是 TF1.x 模型,則使用:

python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin

注意,此處需要安裝tensorflow。

4. 準備完整的 PyTorch 模型目錄

轉換完成后,我們需要確保模型目錄包含所有必要文件:

cd ../bert-base-multilingual-cased# 復制 bert_config.json 為 config.json(Transformers 庫需要)
cp bert_config.json config.json

現在,你的模型目錄應該包含以下三個關鍵文件:

  1. config.json:模型配置文件,包含了所有用于訓練的參數設置
  2. pytorch_model.bin:轉換后的 PyTorch 模型權重文件
  3. vocab.txt:詞表文件,用于識別模型支持的各種語言的字符

5. 驗證模型轉換成功

為了驗證模型轉換是否成功,我們可以編寫一個簡單的腳本來加載模型并進行測試:

from transformers import BertTokenizer, BertModel# 加載模型和分詞器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 測試多語言能力
texts = ["Hello, how are you?",  # 英語"你好,最近怎么樣?",    # 中文"Hola, ?cómo estás?"   # 西班牙語
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")

6. 使用模型進行下游任務

現在你可以使用這個轉換好的模型進行各種下游任務,如文本分類、命名實體識別等:

from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加載模型和分詞器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分類模型(假設有2個類別)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 準備輸入
text = "這是一個測試文本"
inputs = tokenizer(text, return_tensors="pt")# 前向傳播
outputs = model(**inputs)
logits = outputs.logits# 獲取預測結果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"預測類別: {predicted_class}")

注意事項

  1. 模型文件大小:BERT-Base 模型文件通常較大(約400MB+),請確保有足夠的磁盤空間和內存。

  2. 路徑問題:在執行轉換腳本時,確保正確指定了所有文件的路徑。

  3. 命名約定:Transformers 庫期望配置文件名為 config.json,而不是 bert_config.json,所以需要進行復制或重命名。

  4. TensorFlow 版本:根據你下載的模型版本(TF1.x 或 TF2.x),選擇正確的轉換腳本。

  5. checkpoint 文件:轉換腳本中的 --tf_checkpoint_path 參數應該指向不帶后綴的 checkpoint 文件名(如 bert_model.ckpt),而不是具體的 .index.data 文件。

通過以上步驟,你就可以成功地將 Google 預訓練的 BERT 模型轉換為 PyTorch 格式,并在你的項目中使用它了。這個多語言版本的 BERT 模型支持 104 種語言,非常適合多語言自然語言處理任務。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/86195.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/86195.shtml
英文地址,請注明出處:http://en.pswp.cn/web/86195.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

doris 和StarRocks 導入導出數據配置

一、StarRocks 導數據到hdfs EXPORT TABLE database.table TO “hdfs://namenode/tmp/demo/table” WITH BROKER ( “username”“username”, “password”“password” ); 二、StarRocks 導數據到oss EXPORT TABLE database.table TO “oss://broke/aa/” WITH BROKER ( “…

【HTTP】取消已發送的請求

場景 在頁面中,可能會因為某些操作多次觸發某個請求,如多次點擊某按鈕觸發請求,實際上我們只需要最后一次請求的返回值,但是由于請求的耗時不一,請求未必會按發送的順序返回,導致我們最終獲取到的值 ≠ 最后…

JSON框架轉化isSuccess()為sucess字段

在您的描述中,BankInfoVO子類返回的JSON中出現了"success": true字段,但類本身沒有定義這個字段。這通常是由以下原因之一造成的: 原因分析及解決方案 序列化框架的Getter自動推導 Java序列化框架(如Jackson/Gson&…

Ragflow 源碼:task_executor.py

目錄 介紹主要功能核心組件 流程圖核心代碼解釋1. 系統架構與核心組件2. 核心處理流程3. 高級處理能力4. 關鍵創新點5. 容錯與監控機制6. 性能優化技巧 介紹 task_executor.py 是RAGFlow系統中的任務執行器(Task Executor)核心部分,主要負責文檔的解析、分塊(chunk…

創客匠人聯盟生態:重構家庭教育知識變現的底層邏輯

在《家庭教育促進法》推動行業剛需化的背景下,單一個體 IP 的增長天花板日益明顯。創客匠人提出的 “聯盟生態思維”,正推動家庭教育行業從 “單打獨斗” 轉向 “矩陣作戰”,其核心在于通過工具整合資源,將 “同行競爭” 轉化為 “…

【Docker基礎】Docker容器管理:docker stop詳解

目錄 1 Docker容器生命周期概述 2 docker stop命令深度解析 2.1 命令基本語法 2.2 命令執行流程 2.3 stop與kill的區別 3 docker stop的工作原理 3.1 工作流程 3.2 詳細工作流程 3.3 信號處理機制 4 docker stop的使用場景與最佳實踐 4.1 典型使用場景 場景1&#…

rules寫成動態

拖拽排序和必填校驗聯動(rules寫到computed里) computed: {rules() {const rules {};this.form.feedList.forEach((item, idx) > {rules[feedList.${idx}] [{ required: true, message: 路線評價動態${idx 1}待填寫,請填寫完畢提交, trigger: change }];});re…

The Open Group開放流程自動化? 論壇(OPAF)發布組織最新進展報告

除埃克森美孚(ExxonMobil)的成就外,開放流程自動化? 論壇(OPAF)的最新論壇報告顯示,該組織其他成員也在多個領域取得進展。 “我們祝賀埃克森美孚,因為他們證明了在前線、創收的工藝操作中部署…

線程的基本控制

線程終止 exit是危險的 如果進程中的任意一個線程調用了exit,那么整個進程終止。 不終止進程的退出方式 普通單個線程的退出方法,以下方法退出不會導致進程終止: (1)從啟動例程中返回,返回值是線程的退出…

DeepSeek+WinForm串口通訊實戰

前言 在現代軟件開發中,串口通訊仍然是工業自動化、物聯網設備和嵌入式系統的重要通信方式。隨著.NET技術的發展,特別是.NET 5/.NET 6的跨平臺能力,傳統的WinForm應用現在可以通過現代UI框架實現真正的跨平臺串口通訊。本文將深入探討三種主…

針對數據倉庫方向的大數據算法工程師面試經驗總結

?? 一、技術核心考察點 數據建模能力 星型 vs 雪花模型:面試官常要求對比兩種模型。星型模型(事實表冗余維度表)查詢性能高但存儲冗余;雪花模型(規范化維度表)減少冗余但增加JOIN復雜度。需結合場景選擇&…

Nuxt3 Cannot read properties of undefined (reading ‘createElement‘)

你遇到的 TypeError: Cannot read properties of undefined (reading createElement) 這個報錯,通常是由于在 Nuxt3 或 Vue3 項目中,某些地方嘗試訪問 document.createElement 或類似 DOM API,但此時 document 還未定義(比如在服務…

正則表達式匹配實現

直接上代碼 using Microsoft.AspNetCore.Mvc; using System.Text.RegularExpressions;namespace SaaS.OfficialWebSite.Web.Controllers {public class RegController : Controller{public IActionResult Index(){return View();}[HttpPost]public IActionResult TestRegex([F…

API測試工具Parasoft SOAtest:應對API變化,優化測試執行

API頻繁變更給測試工作帶來諸多挑戰,如手動排查變更影響耗時費力、測試用例維護繁瑣易出錯等。Parasoft SOAtest作為一款企業級API測試工具,通過自動掃描API接口、智能分析變更影響、優化測試,執行以及支持測試用例共享與版本控制等功能&…

mysql 數據庫連接 -h localhost 和 -h 127.0.0.1 區別是什么

對于 mysql 數據庫, 在 my.conf 中指定的client 端口是 3358,實際的mysql server 的端口監聽在 3306, mysql -h localhost 可以居然可以連接成功; mysql -h 127.0.0.1 連接失敗提示Can’t connect to MySQL server on 127.0.0.1&a…

Educational Codeforces Round 180 (Rated for Div. 2) A-D

A.Race 題目大意 給你兩個x,y,終點會在二點之間隨機出現,alice在點a,假設alice和bob有相同的速度(距離更短者用時更少),問對于bob是否存在一點,無論終點是x還是y,他都能比alice更快到達 思路 如果alice在…

python requests post請求

在Python中,使用requests庫進行POST請求是一種常見的操作,用于向服務器發送數據。下面是如何使用requests庫進行POST請求的步驟: 安裝requests庫 如果你還沒有安裝requests庫,可以通過pip安裝: pip install requests…

Postman中設置定時自動運行接口測試

?創建測試集合? 將需每日運行的接口組織到Collection中,并配置好測試腳本和斷言。 ?配置定時運行? 打開目標Collection → 點擊 ?Run? 按鈕在Collection Runner頁面底部選擇 ?Schedule runs?關鍵配置: Frequency: Daily // 選擇每日執行 Time…

multiprocessing.pool和multiprocessing.Process

在CPU密集型任務中,Python的multiprocessing模塊是突破GIL限制的關鍵工具。multiprocessing.Pool(進程池)和multiprocessing.Process(獨立進程)是最常用的兩種并行化方案,但其設計思想和適用場景截然不同。…

容器技術技術入門與 Docker 環境部署

目錄 一:Docker概述 1、 Docker的優勢: (1)環境一致性 (2)隔離性 (3)資源高效 (4)便捷性和可擴展性 2、Docker容器與傳統虛擬機的區別 3、Docker的應用…