【機器學習深度學習】模型微調的基本概念與流程

目錄

前言

一、什么是模型微調(Fine-tuning)?

二、預訓練 vs 微調:什么關系?

三、微調的基本流程(以BERT為例)

1?? 準備數據

2?? 加載預訓練模型和分詞器

3?? 數據編碼與加載

4?? 定義優化器

5?? 開始訓練

6?? 評估與保存模型

四、是否要凍結 BERT 層?

?五、完整訓練示例代碼

5.1 環境依賴

5.2 執行代碼

總結:微調的優勢


前言

在自然語言處理(NLP)快速發展的今天,預訓練模型如 BERT 成為了眾多任務的基礎。但光有預訓練模型并不能解決所有問題,模型微調 的技術應運而生。它讓通用模型具備了“專才”的能力,使其能更好地服務于特定任務,如情感分析、問答系統、命名實體識別等。

本文將帶你快速理解——什么是模型微調,它的基本流程又是怎樣的?


一、什么是模型微調(Fine-tuning)?

? 概念通俗解釋:

微調,就是在別人學得很好的“通用知識”上,加入你自己的“專業訓練”。

具體來說,像 BERT 這樣的預訓練語言模型已經通過大規模語料學習了大量語言規律,比如語法結構、詞語搭配等。我們不需要從頭訓練它,而是在此基礎上繼續用小規模、特定領域的數據進行訓練,讓模型更好地完成某個具體任務。


二、預訓練 vs 微調:什么關系?

階段目標數據類型舉例
預訓練(Pre-training)學習通用語言知識大規模通用語料維基百科、圖書館語料
微調(Fine-tuning)適應特定任務少量任務特定數據產品評論、醫療文本、法律文書

一句話總結:
🔁 預訓練是“打基礎”,微調是“練專業”。


三、微調的基本流程(以BERT為例)

讓我們以“使用 BERT 進行情感分析”為例,梳理整個微調流程:

1?? 準備數據

我們需要將文本和標簽準備好,通常是一個 CSV 文件,比如:

評論內容情感標簽
這部電影太好看了!正面
爛片,浪費時間。負面

我們會將“正面”轉換為 1,負面為 0,方便模型學習。


2?? 加載預訓練模型和分詞器

from transformers import BertTokenizer, BertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)

?此時模型的主體結構已經包含了 BERT 和一個分類頭(Classification Head)。


3?? 數據編碼與加載

使用分詞器將文本轉為模型輸入格式:

tokens = tokenizer("這部電影太好看了!", padding='max_length', truncation=True, return_tensors="pt")

你還需要構建自定義數據集類(Dataset),并使用 DataLoader 加載:

from torch.utils.data import DataLoadertrain_loader = DataLoader(my_dataset, batch_size=16, shuffle=True)

4?? 定義優化器

from transformers import AdamWoptimizer = AdamW(model.parameters(), lr=5e-5)

?▲優化器的作用是:根據損失函數的值,自動調整模型的參數,使模型表現越來越好。

▲通俗理解

優化器就像你走路的“策略”:
它告訴你“往哪邊走,走多快,怎么避開障礙”,最終盡可能走到山底。

?▲優化器做了什么?

神經網絡訓練時,每一輪都會:

  1. 計算當前模型的預測誤差(損失函數 loss)

  2. 反向傳播得到每個參數的梯度(方向)

  3. 👉 優化器根據梯度,更新參數的值

就像你爬山時,不斷踩點 → 看地形 → 決定下一個落腳點。

組件類比作用
損失函數地圖高度告訴你你離目標有多遠
梯度當前坡度告訴你往哪里走
優化器走路策略告訴你怎么調整步伐走得更快更穩

5?? 開始訓練

model.train()
for batch in train_loader:outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()optimizer.zero_grad()

通常我們會訓練幾個 epoch,讓模型逐漸學會如何從文本中識別情感。


6?? 評估與保存模型

訓練完成后,我們可以在驗證集上評估準確率,并保存模型:

torch.save(model.state_dict(), "bert_sentiment.pth")

四、是否要凍結 BERT 層?

微調過程中,有兩種策略:

  • 全模型微調(默認): 所有 BERT 層和分類頭都參與訓練。效果通常更好,但對顯存要求高。

  • 凍結 BERT,僅訓練分類頭: 保持 BERT 權重不變,只訓練新加的分類層。適合數據量小或設備受限的場景。

凍結代碼示例:

for param in model.bert.parameters():param.requires_grad = False

?【兩種微調策略對比】

策略是否凍結BERT層?訓練內容優點缺點
? 全部微調? 不凍結BERT + 分類層一起訓練效果最好,能深入適配任務訓練慢,顯存占用大
🚫 只微調分類層? 凍結 BERT只訓練分類層快速,適合小數據、低配置表現可能略遜一籌

【舉個通俗類比】

想象你雇了一個精通語文的老師(BERT),但你只想讓他教學生寫作文(分類任務):

  • 全模型微調:老師重新備課、重新學習學生情況,全面參與教學(耗時但有效)。

  • 只調分類層:老師照搬舊知識,只教作文技巧,不深入了解學生(快速但效果一般)。


【什么時候選擇“凍結 BERT 層”?】

  • 數據量很小(如只有幾百個樣本)

  • 硬件資源有限(顯存小、設備性能弱)

  • 快速原型驗證,先試試看效果


【小結一句話】

凍結 BERT 層就是讓預訓練好的 BERT 不再學習,只訓練新增的部分;
不凍結則是讓整個 BERT 跟著任務數據一起再“進修”一輪,效果更強,但代價更高。

你也可以先凍結一部分,再逐步解凍,稱為 層級微調(layer-wise unfreezing),也是一個進階策略。


?五、完整訓練示例代碼

5.1 環境依賴

?1、安裝Pytorch

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

注意:

▲安裝pytorch前先確定自己電腦是否有GPU,沒有請安裝cpu版本的;

pip3 install torch torchvision torchaudio

▲確保CUDA 12.6版本可以兼容

確定是否兼容可參考該文章對應內容:【CUDA&cuDNN安裝】深度學習基礎環境搭建_cudnn安裝教程-CSDN博客


2、安裝transformers

pip install transformers

3、安裝scikit-learn

pip install scikit-learn

scikit-learn 是一個專注于傳統機器學習的工具箱,涵蓋從模型訓練、評估到數據處理的一整套流程。


5.2 執行代碼

import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
import pandas as pd# 1. 自定義數據集類
class SentimentDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):inputs = self.tokenizer(self.texts[idx],truncation=True,padding='max_length',max_length=self.max_len,return_tensors='pt')return {'input_ids': inputs['input_ids'].squeeze(0),'attention_mask': inputs['attention_mask'].squeeze(0),'labels': torch.tensor(self.labels[idx], dtype=torch.long)}# 2. 加載數據
df = pd.read_csv("data.csv")  # 假設 CSV 有 'text' 和 'label' 列
train_texts, train_labels = df['text'].tolist(), df['label'].tolist()# 3. 初始化 tokenizer 和 model
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)# 4. 構建 DataLoader
train_dataset = SentimentDataset(train_texts, train_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)# 5. 設置優化器
optimizer = AdamW(model.parameters(), lr=5e-5)# 6. 訓練過程
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()for epoch in range(3):  # 可改為你想要的輪數total_loss = 0preds, targets = [], []for batch in train_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losslogits = outputs.logitsoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()preds += torch.argmax(logits, dim=1).tolist()targets += labels.tolist()acc = accuracy_score(targets, preds)print(f"Epoch {epoch+1} | Loss: {total_loss:.4f} | Accuracy: {acc:.4f}")# 7. 保存模型
torch.save(model.state_dict(), "bert_finetuned.pth")

?


總結:微調的優勢

? 少量數據就能訓練出效果不錯的模型
? 遷移學習加速開發,節省計算資源
? 靈活應對不同領域任務,如醫學、法律、金融等

模型微調是現代 AI 應用的關鍵技能之一。如果說預訓練模型是“萬能工具箱”,那么微調就是“選對合適的工具并精修”。掌握這項技術,你就能迅速把通用模型打造成特定任務的專家。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/87515.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/87515.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/87515.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大語言模型預訓練數據——數據采樣方法介紹以GPT3為例

大語言模型預訓練數據——數據采樣方法介紹以GPT3為例一、數據采樣核心邏輯二、各列數據含義一、數據采樣核心邏輯 這是 GPT - 3 訓練時的數據集配置,核心是非等比例采樣——不按數據集原始大小分配訓練占比,而是人工設定不同數據集在訓練中被抽取的概率…

針對同一臺電腦,為使用不同 SSH Key 的不同用戶分別設置 Git 遠程倉庫憑據的操作指南

一、準備工作 生成多對 SSH Key 為每個用戶(如“個人”、“公司”)生成一對獨立的 SSH Key。 示例(在 Git Bash 或 Linux 終端中執行): # 個人 ssh-keygen -t rsa -b 4096 -C "personalexample.com" -f ~/.…

【V5.0 - 視覺篇】AI的“火眼金睛”:用OpenCV量化“第一眼緣”,并用SHAP驗證它的“審美”

系列回顧: 在上一篇 《給AI裝上“寫輪眼”:用SHAP看穿模型決策的每一個細節》 中,我們成功地為AI裝上了“透視眼鏡”,看穿了它基于數字決策的內心世界。 但一個巨大的問題暴露了:它的世界里,還只有數字。 它…

Open3D 基于最大團(MAC)的點云粗配準

MAC 一、算法原理1、原理概述2、實現流程3、總結二、代碼實現三、結果展示博客長期更新,本文最新更新時間為:2025年7月1日。 一、算法原理 1、原理概述 最大團(Maximal Cliques, MAC)法在點云配準中的應用,是近年來解決高離群值(outlier)和低重疊場景下配準問題的重要…

Science Robotics發表 | 20m/s自主飛行+避開2.5mm電線的微型無人機!

從山火搜救到災后勘察,時間常常意味著生命。分秒必爭的任務要求無人機在陌生狹窄環境中既要飛得快、又要飛得穩。香港大學機械工程系張富教授團隊在Science Robotics(2025)發表論文“Safety-assured High-speed Navigation for MAVs”提出了微型無人機的安全高速導航…

【數據分析】如何在PyCharm中高效配置和使用SQL

PyCharm 作為 Python 開發者的首選 IDE,其 Professional 版本提供了強大的數據庫集成功能,讓開發者無需切換工具即可完成數據庫操作。本文將手把手教你配置和使用 PyCharm 的 SQL 功能。 一、安裝和配置 PyCharm 老生常談,第一步自然是安裝并…

OpenShift AI - 使用 NVIDIA Triton Runtime 運行模型

《OpenShift / RHEL / DevSecOps 匯總目錄》 說明:本文已經在 OpenShift 4.18 OpenShift AI 2.19 的環境中驗證 文章目錄 準備 Triton Runtime 環境添加 Triton Serving Runtime運行基于 Triton Runtime 的 Model Server 在 Triton Runtime 中運行模型準備模型運行…

物聯網數據安全區塊鏈服務

物聯網數據安全區塊鏈服務 下面是一個專為物聯網數據安全設計的區塊鏈服務實現,使用Python編寫并封裝為RESTful API。該服務確保物聯網設備數據的不可篡改性、可追溯性和安全性。 import hashlib import json import time from datetime import datetime from uui…

數據集-目標檢測系列- 卡車 數據集 truck >> DataBall

數據集-目標檢測系列- 卡車 數據集 truck >> DataBall貴在堅持!* 相關項目1)數據集可視化項目:gitcode: https://gitcode.com/DataBall/DataBall-detections-100s/overview2)數據集訓練、推理相關項目&…

vue/微信小程序/h5 實現react的boundary

ErrorBoundary react的boundary實現核心邏輯無法處理的情況包含函數詳細介紹getDerivedStateFromError和componentDidCatch作用為什么分開調用 代碼實現(補充其他異常捕捉)函數組件與useErrorBoundary(需自定義Hook) vue的boundar…

Day113 切換Node.js版本、多數據源配置

切換Node.js版本 1.nvm簡介nvm(Node Version Manager),在Windows上管理Node.js版本,可以在同一臺電腦上輕松管理和切換多個Node.js版本 nvm下載地址:https://github.com/coreybutler/nvm-windows/2.配置nvm安裝之后檢查nvm是否已經安裝好了&a…

應急響應靶機-linux2-知攻善防實驗室

題目: 1.提交攻擊者IP2.提交攻擊者修改的管理員密碼(明文)3.提交第一次Webshell的連接URL(http://xxx.xxx.xxx.xx/abcdefg?abcdefg只需要提交abcdefg?abcdefg)4.提交Webshell連接密碼5.提交數據包的flag16.提交攻擊者使用的后續上傳的木馬文件名稱7.提交攻擊者隱藏…

新手前端使用Git(常用命令和規范)

發一篇文章來說一下前端在開發項目的時候常用的一些git命令 注:這篇文章只說最常用的,最下面有全面的 一:從git倉庫拉取項目到本地 1:新建文件夾存放項目代碼 2:在git上復制一下項目路徑(看那個順眼復制…

【面試題】常用Git命令

【面試題】常用Git命令1. 常用Git命令1. 常用Git命令 1.git clone git clone https://gitee.com/Blue_Pepsi_Cola/straw.git 2.使用-v選項,可以參看遠程主機的網址 git remote -v origin https://ccc.ddd.com/1-java/a-admin-api.git (fetch) origin https://ccc.…

Webpack構建工具

構建工具系列 Gulp構建工具Grunt構建工具Webpack構建工具Vite構建工具 Webpack構建工具 構建工具系列前言一、安裝打包配置webpack安裝樣式加載器devtoolwebpack devtool 配置詳解常見 devtool 值及適用場景選擇建議性能影響注意事項 module處理流程module.rulesmodule.usemod…

重學前端002 --響應式網頁設計 CSS

文章目錄 css 樣式特殊說明 根據在這里 Freecodecamp 實踐,調整順序后做的總結。 css 樣式 body {background-color: red; # 跟background-image 不同時使用background-image: url(https://cdn.freecodecamp.org/curriculum/css-cafe/beans.jpg);font-family: san…

RabbitMQ簡單消息監聽和確認

如何監聽RabbitMQ隊列 簡單代碼實現RabbitMQ消息監聽 需要的依賴 <!--rabbitmq--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId><version>x.x.x</version>&l…

Docker學習筆記:Docker網絡

本文是自己的學習筆記 1、Linux中的namespace1.1、創建namespace1.2、兩個namespace互相通信2、Docker中的namespace2.1 容器中的默認Bridge3、容器的三種網絡模式1、Linux中的namespace Docker中使用了虛擬網絡技術&#xff0c;讓各個容器的網絡隔離。好像每個容器從網卡到端…

用自定義注解解決excel動態表頭導出的問題

導入的excel有固定表頭動態表頭如何解決 自定義注解&#xff1a; import java.lang.annotation.*;/*** 自定義注解&#xff0c;用于動態生成excel表頭*/ Target(ElementType.FIELD) Retention(RetentionPolicy.RUNTIME) public interface FieldLabel {// 字段中文String label(…

Android-EDLA 解決 GtsMediaRouterTestCases 存在 fail

問題描述&#xff1a;[原因]R10套件新增模塊&#xff0c;getRemoteDevice獲取遠程藍牙設備時&#xff0c;藍牙MAC為空 [對策]實際藍牙MAC非空;測試時繞過處理 1.release/ebsw_skg/skg/frameworks/base/packages/SettingsLib/src/com/android/settingslib/media/InfoMediaManage…