BERT 模型微調與傳統機器學習的對比

BERT 微調與傳統機器學習的區別和聯系:

傳統機器學習流程

傳統機器學習處理文本分類通常包含以下步驟:

  1. 特征工程:手動設計特征(如 TF-IDF、詞袋模型)
  2. 模型訓練:使用分類器(如 SVM、隨機森林、邏輯回歸)
  3. 特征和模型調優:反復調整特征和超參數

BERT 微調流程

BERT 微調的典型流程:

  1. 預訓練:使用大規模無標注數據預訓練 BERT 模型
  2. 數據準備:將文本轉換為 BERT 輸入格式(tokenize、添加特殊標記)
  3. 模型微調:凍結大部分 BERT 層,只訓練分類頭(或少量 BERT 層)
  4. 評估與部署:在驗證集上評估,保存模型

兩者的主要區別

對比項傳統機器學習BERT 微調
特征表示手動設計特征(如 TF-IDF)自動學習上下文相關表示
模型復雜度簡單到中等(如 SVM、RF)非常復雜(Transformer 架構)
數據依賴需要大量標注數據可以用較少數據達到好效果
領域適應性遷移到新領域需要重新設計特征可以快速適應新領域(通過微調)
計算資源通常較低需要 GPU/TPU

您代碼中的 BERT 微調關鍵點

  1. 數據預處理

    • 使用BertTokenizer將文本轉換為 token IDs
    • 添加特殊標記([CLS]、[SEP])
    • 填充和截斷到固定長度
  2. 模型架構

    • 基礎模型:BERT 預訓練模型(bert-base-chinese)
    • 分類頭:在 BERT 頂部添加一個全連接層(num_labels=6
    • 微調策略:更新整個模型的權重
  3. 訓練優化

    • 使用AdamW優化器(帶權重衰減的 Adam)
    • 小學習率(2e-5)避免災難性遺忘
    • 批量訓練(batch_size=2)處理長序列
  4. 優勢

    • 利用預訓練模型捕獲的語言知識
    • 自動學習文本的上下文表示
    • 對訓練數據量要求較低
    • 遷移到新領域更容易

何時選擇 BERT 微調而非傳統方法?

  1. 當您有足夠的計算資源時
  2. 當任務數據量有限時
  3. 當需要處理復雜語義理解時
  4. 當需要快速適應新領域時

BERT 微調入門示例,展示了如何將預訓練語言模型應用于特定的分類任務。隨著 Transformer 架構的普及,這種方法已經成為 NLP 任務的主流解決方案。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd# 1. 準備數據
data = {'text': ["我想預訂明天的機票", "查詢今天的天氣", "幫我設置鬧鐘","播放周杰倫的歌曲", "今天有什么新聞", "推薦幾部科幻電影"],'label': [0, 1, 2, 3, 4, 5]  # 0:訂票, 1:天氣, 2:鬧鐘, 3:音樂, 4:新聞, 5:電影
}
df = pd.DataFrame(data)# 2. 數據集劃分
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)# 3. 創建數據集類
class TextClassificationDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = str(self.texts[idx])label = self.labels[idx]encoding = self.tokenizer(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt')return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 4. 初始化tokenizer和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=6
)# 5. 創建數據加載器
train_dataset = TextClassificationDataset(train_df['text'].values,train_df['label'].values,tokenizer
)
val_dataset = TextClassificationDataset(val_df['text'].values,val_df['label'].values,tokenizer
)train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)# 6. 訓練模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 3for epoch in range(epochs):model.train()train_loss = 0for batch in train_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)optimizer.zero_grad()outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstrain_loss += loss.item()loss.backward()optimizer.step()avg_train_loss = train_loss / len(train_loader)print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss:.4f}")# 7. 評估模型
model.eval()
predictions = []
true_labels = []with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask)preds = torch.argmax(outputs.logits, dim=1)predictions.extend(preds.cpu().numpy())true_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(true_labels, predictions)
print(f"Validation Accuracy: {accuracy:.4f}")# 8. 保存模型
model.save_pretrained('./intent_classifier')
tokenizer.save_pretrained('./intent_classifier')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/909171.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/909171.shtml
英文地址,請注明出處:http://en.pswp.cn/news/909171.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

(12)-Fiddler抓包-Fiddler設置IOS手機抓包

1.簡介 Fiddler不但能截獲各種瀏覽器發出的 HTTP 請求,也可以截獲各種智能手機發出的HTTP/ HTTPS 請求。 Fiddler 能捕獲Android 和 Windows Phone 等設備發出的 HTTP/HTTPS 請求。同理也可以截獲iOS設備發出的請求,比如 iPhone、iPad 和 MacBook 等蘋…

芯科科技Tech Talks技術培訓重磅回歸:賦能物聯網創新,共筑智能互聯未來

聚焦于Matter、藍牙、Wi-Fi、LPWAN、AI/ML五大熱門無線協議與技術 為年度盛會Works With大會賦能先行 隨著物聯網(IoT)和人工智能(AI)技術的飛速發展,越來越多的企業和個人開發者都非常關注最新的無線連接技術和應用…

docker-compose容器單機編排

docker-compose容器單機編排 開篇前言 隨著網站架構的升級,容器的使用也越來越頻繁,應用服務和容器之間的關系也越發的復雜。 這個就要求研發人員能更好的方法去管理數量較多的服務器,而不能手動挨個管理。 例如一個LNMP 架構,就…

LeetCode--29.兩數相除

解題思路: 1.獲取信息: 給定兩個整數,一個除數,一個被除數,要求返回商(商取整數) 限制條件:(1)不能使用乘法,除法和取余運算 (2&#…

中山大學GaussianFusion:首個將高斯表示引入端到端自動駕駛多傳感器融合的新框架

摘要 近年來由于端到端自動駕駛極大簡化了原有傳統自動駕駛模塊化的流程,吸引了來自工業界和學術界的廣泛關注。然而,現有的端到端智駕算法通常采用單一傳感器,使其在處理復雜多樣和具有挑戰性的駕駛場景中受到了限制。而多傳感器融合可以很…

《哈希算法》題集

1、模板題集 滿足差值的數字對 2、課內題集 字符統計 字符串統計 優質數對 3、課后題集 2006 Equations k倍區間 可結合的元素對 滿足差值的數字對 異常頻率 神秘數對 費里的語言 連連看 本題集為作者(英雄哪里出來)在抖音的獨家課程《英雄C入門到精…

Cordova移動應用對云端服務器數據庫的跨域訪問

Cordova移動應用對云端服務器數據庫的跨域訪問 當基于類似 Cordova這樣的跨平臺開發框架進行移動應用的跨平臺開發時,往往需要訪問部署在公網云端服務器上的數據庫,這時就涉及到了跨域數據訪問的問題。 文章目錄 Cordova移動應用對云端服務器數據庫的跨…

mysql知識點3--創建和使用數據庫

mysql知識點3–創建數據庫 創建數據庫 在MySQL中創建數據庫使用CREATE DATABASE語句。語法如下: CREATE DATABASE database_name;其中database_name為自定義的數據庫名稱。例如創建名為test_db的數據庫: CREATE DATABASE test_db;可以添加字符集和排…

林業資源多元監測技術守護綠水青山

在云南高黎貢山的密林中,無人機群正以毫米級精度掃描古樹年輪;福建武夷山保護區,衛星遙感數據實時追蹤著珍稀動植物的棲息地變化;海南熱帶雨林里,AI算法正從億萬條數據中預測下一場山火的風險……這些科幻場景&#xf…

一階/二階Nomoto模型(野本模型)為何“看不到”船速對回轉角速度/角加速度的影響?

提問 圖中的公式反映的是舵角和力矩之間的關系, 其中可以看到力矩(可以理解為角加速度)以及相應導致的回轉角速度和當前的舵速(主要由船速貢獻)有關,那么為什么一階Nomoto模型(一階野本&#xf…

深入剖析 C++ 默認函數:拷貝構造與賦值運算符重載

目錄 1. 簡單認識C 類的默認函數 1.1 默認構造函數 1.2 析構函數 1.3 拷貝構造函數 2. 拷貝構造函數的深入理解 拷貝構造的特點: 實際運用 3. 賦值運算符重載的深入理解 3.1.運算符重載 3.2樣例 1.比較運算符重載 2.算術運算符重載 3.自增和自減運算符重載 4.輸…

板凳-------Mysql cookbook學習 (十--3)

5.16 用短語來進行fulltext查詢 mysql> select count(*) from kjv where match(vtext) against(God); ---------- | count(*) | ---------- | 0 | ---------- 1 row in set (0.00 sec)mysql> select count(*) from kjv where match(vtext) against(sin); -------…

python爬蟲ip封禁應對辦法

目錄 一、背景現象 二、準備工作 三、代碼實現 一、背景現象 最近在做爬蟲項目時,爬取的網站,如果發送請求太頻繁的話,對方網站會先是響應緩慢,最后是封禁一段時間。一直是拒絕連接,導致程序無法正常預期的爬取數據…

【AIGC】Qwen3-Embedding:Embedding與Rerank模型新標桿

Qwen3-Embedding:Embedding與Rerank模型新標桿 一、引言二、技術架構與核心創新1. 模型結構與訓練策略(1)多階段訓練流程(2)高效推理設計(3)多語言與長上下文支持 2. 與經典模型的性能對比 三、…

算法競賽階段二-數據結構(32)數據結構簡單介紹

數據結構的基本概念 數據結構是計算機存儲、組織數據的方式,旨在高效地訪問和修改數據。它是算法設計的基礎,直接影響程序的性能。數據結構可分為線性結構和非線性結構兩大類。 線性數據結構 線性結構中,數據元素按順序排列,每…

Windows桌面圖標修復

新建文本文件,粘入以下代碼,保存為.bat文件,管理員運行這個文件 duecho off taskkill /f /im explorer.exe CD /d %userprofile%\AppData\Local DEL IconCache.db /a start explorer.exe echo 執行完成上面代碼作用是刪除桌面圖標緩存庫&…

13.react與next.js的特性和原理

🟡 一句話總結 React 專注于構建組件,而 Next.js 是基于 React 的全棧框架,提供了頁面路由、服務端渲染和全棧能力,讓你能快速開發現代 Web 應用。 React focuses on building UI components, while Next.js is a full-stack fra…

全棧監控系統架構

全棧監控系統架構 可觀測性從數據層面可分為三類: 指標度量(Metrics):記錄系統的總體運行狀態。事件日志(Logs):記錄系統運行期間發生的離散事件。鏈路追蹤(Tracing):記錄一個請求接入到結束的處理過程,主要用于排查…

云服務運行安全創新標桿:阿里云飛天洛神云網絡子系統“齊天”再次斬獲獎項

引言 為認真落實工信部《工業和信息化部辦公廳關于印發信息通信網絡運行安全管理年實施方案的通知》,2025年5月30日中國信息通信研究院于浙江杭州舉辦了“云服務運行安全高質量發展交流會”,推動正向引導,鞏固云服務安全專項治理成果。會上&a…

刀客doc:WPP走下神壇

一、至暗時刻? 6月11日,快消巨頭瑪氏公司宣布其價值17 億美元,在全球70個市場的廣告業務交給陽獅集團,這其中包括M&Ms、士力架、寶路等知名品牌。 此前,瑪氏公司一直是WPP的大客戶。早在今年3月,WPP就…