一起學Hugging Face Transformers(13)- 模型微調之自定義訓練循環

文章目錄

  • 前言
  • 一、什么是訓練循環
    • 1. 訓練循環的關鍵步驟
    • 2. 示例
    • 3. 訓練循環的重要性
  • 二、使用 Hugging Face Transformers 庫實現自定義訓練循環
    • 1. 前期準備
      • 1)安裝依賴
      • 2)導入必要的庫
    • 2. 加載數據和模型
      • 1) 加載數據集
      • 2) 加載預訓練模型和分詞器
      • 3) 預處理數據
      • 4) 創建數據加載器
    • 3. 自定義訓練循環
      • 1) 定義優化器和學習率調度器
      • 2) 定義訓練和評估函數
      • 3) 運行訓練和評估
  • 總結


前言

Hugging Face Transformers 庫為 NLP 模型的預訓練和微調提供了豐富的工具和簡便的方法。雖然 Trainer API 簡化了許多常見任務,但有時我們需要更多的控制權和靈活性,這時可以實現自定義訓練循環。本文將介紹什么是訓練循環以及如何使用 Hugging Face Transformers 庫實現自定義訓練循環。


一、什么是訓練循環

在模型微調過程中,訓練循環是指模型訓練的核心過程,通過多次迭代數據集來調整模型的參數,使其在特定任務上表現更好。訓練循環包含以下幾個關鍵步驟:

1. 訓練循環的關鍵步驟

1) 前向傳播(Forward Pass)

  • 模型接收輸入數據并通過網絡進行計算,生成預測輸出。這一步是將輸入數據通過模型的各層逐步傳遞,計算出最終的預測結果。

2) 計算損失(Compute Loss)

  • 將模型的預測輸出與真實標簽進行比較,計算損失函數的值。損失函數是一個衡量預測結果與真實值之間差距的指標,常用的損失函數有交叉熵損失(用于分類任務)和均方誤差(用于回歸任務)。

3) 反向傳播(Backward Pass)

  • 根據損失函數的值,計算每個參數對損失的貢獻,得到梯度。反向傳播使用鏈式法則,將損失對每個參數的梯度計算出來。

4) 參數更新(Parameter Update)

  • 使用優化算法(如梯度下降、Adam 等)根據計算出的梯度調整模型的參數。優化算法會更新每個參數,使損失函數的值逐步減小,模型的預測性能逐步提高。

5) 重復以上步驟

  • 以上過程在整個數據集上進行多次(多個epoch),每次遍歷數據集被稱為一個epoch。隨著訓練的進行,模型的性能會不斷提升。

2. 示例

假設你在微調一個BERT模型用于情感分析任務,訓練循環的步驟如下:

1) 前向傳播

  • 輸入一條文本評論,模型通過各層網絡計算,生成預測的情感標簽(如正面或負面)。

2) 計算損失

  • 將模型的預測標簽與實際標簽進行比較,計算交叉熵損失。

3) 反向傳播

  • 計算損失對每個模型參數的梯度,確定每個參數需要調整的方向和幅度。

4) 參數更新

  • 使用Adam優化器,根據計算出的梯度調整模型的參數。

5) 重復以上步驟

  • 在整個訓練數據集上進行多次迭代,不斷調整參數,使模型的預測精度逐步提高。

3. 訓練循環的重要性

訓練循環是模型微調的核心,通過多次迭代和參數更新,使模型能夠從數據中學習,逐步提高在特定任務上的性能。理解訓練循環的各個步驟和原理,有助于更好地調試和優化模型,獲得更好的結果。

在實際應用中,訓練循環可能會包含一些額外的步驟和技術,例如:

  • 批量訓練(Mini-Batch Training):將數據集分成小批量,每次訓練一個批量,降低計算資源的需求。
  • 學習率調度(Learning Rate Scheduling):動態調整學習率,以提高訓練效率和模型性能。
  • 正則化技術(Regularization Techniques):如Dropout、權重衰減等,防止模型過擬合。

這些技術和方法結合使用,可以進一步提升模型微調的效果和性能。

二、使用 Hugging Face Transformers 庫實現自定義訓練循環

1. 前期準備

1)安裝依賴

首先,確保已經安裝了必要的庫:

pip install transformers datasets torch

2)導入必要的庫

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm

2. 加載數據和模型

1) 加載數據集

這里我們以 IMDb 電影評論數據集為例:

dataset = load_dataset("imdb")

2) 加載預訓練模型和分詞器

我們將使用 distilbert-base-uncased 作為基礎模型:

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

3) 預處理數據

定義一個預處理函數,并將其應用到數據集:

def preprocess_function(examples):return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

4) 創建數據加載器

train_dataloader = DataLoader(encoded_dataset["train"], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(encoded_dataset["test"], batch_size=8)

3. 自定義訓練循環

1) 定義優化器和學習率調度器

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

2) 定義訓練和評估函數

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)def train_loop():model.train()for batch in tqdm(train_dataloader):batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()def eval_loop():model.eval()total_loss = 0correct_predictions = 0with torch.no_grad():for batch in tqdm(eval_dataloader):batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.losslogits = outputs.logitstotal_loss += loss.item()predictions = torch.argmax(logits, dim=-1)correct_predictions += (predictions == batch["labels"]).sum().item()avg_loss = total_loss / len(eval_dataloader)accuracy = correct_predictions / len(eval_dataloader.dataset)return avg_loss, accuracy

3) 運行訓練和評估

for epoch in range(num_epochs):print(f"Epoch {epoch + 1}/{num_epochs}")train_loop()avg_loss, accuracy = eval_loop()print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

總結

通過上述步驟,我們實現了使用 Hugging Face Transformers 庫的自定義訓練循環。這種方法提供了更大的靈活性,可以根據具體需求調整訓練過程。無論是優化器、學習率調度器,還是其他訓練策略,都可以根據需要進行定制。希望這篇文章能幫助你更好地理解和實現自定義訓練循環,為你的 NLP 項目提供更強大的支持。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/41666.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/41666.shtml
英文地址,請注明出處:http://en.pswp.cn/web/41666.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

玉石風能否接棒黏土風?一探AI繪畫新風尚

在數字藝術的浪潮中,AI繪畫平臺以其獨特的創造力和便捷性,正在逐步改變我們對藝術的傳統認知。從黏土風的溫暖質感到琉璃玉石的細膩光澤,每一次風格的轉變都引領著新的潮流。今天,我們將聚焦玉石風,探討它是否能成為下一個流行的藝術濾鏡,并提供一種在線體驗的方式,讓你…

Python | Leetcode Python題解之第221題最大正方形

題目: 題解: class Solution:def maximalSquare(self, matrix: List[List[str]]) -> int:if len(matrix) 0 or len(matrix[0]) 0:return 0maxSide 0rows, columns len(matrix), len(matrix[0])dp [[0] * columns for _ in range(rows)]for i in…

使用Python實現深度學習模型:模型監控與性能優化

在深度學習模型的實際應用中,模型的性能監控與優化是確保其穩定性和高效性的關鍵步驟。本文將介紹如何使用Python實現深度學習模型的監控與性能優化,涵蓋數據準備、模型訓練、監控工具和優化策略等內容。 目錄 引言模型監控概述性能優化概述實現步驟數據準備模型訓練模型監控…

梧桐數據庫:語法分析模塊概述

語法分析模塊是數據庫系統的重要組成部分,它負責將用戶輸入的 SQL 語句轉換為內部表示形式,以便后續的處理和執行。在數據庫系統中,語法分析模塊是連接用戶與數據庫的橋梁。它的主要任務是將用戶輸入的 SQL 語句進行解析,檢查語法…

Kafka(一)基礎介紹

一,Kafka集群 一個典型的 Kafka 體系架構包括若Producer、Broker、Consumer,以及一個ZooKeeper集群,如圖所示。 ZooKeeper:Kafka負責集群元數據的管理、控制器的選舉等操作的; Producer:將消息發送到Broker…

隨著云計算和容器技術的廣泛應用,如何在這些環境中有效地運用 Shell 進行自動化部署和管理?

在云計算和容器技術的環境中,Shell 腳本可以被用于自動化部署和管理任務。下面是一些在這些環境中有效使用 Shell 進行自動化部署和管理的方法: 在云環境中,使用云服務提供商的 API 進行自動化管理。例如,使用命令行工具或 SDK 來…

14 - Python網絡應用開發

網絡應用開發 發送電子郵件 在即時通信軟件如此發達的今天,電子郵件仍然是互聯網上使用最為廣泛的應用之一,公司向應聘者發出錄用通知、網站向用戶發送一個激活賬號的鏈接、銀行向客戶推廣它們的理財產品等幾乎都是通過電子郵件來完成的,而…

[AI 大模型] OpenAI ChatGPT

文章目錄 ChatGPT 簡介ChatGPT 的模型架構ChatGPT的發展歷史節點爆發元年AI倫理和安全 ChatGPT 新技術1. 技術進步2. 應用領域3. 代碼示例4. 對話示例 ChatGPT 簡介 ChatGPT 是由 OpenAI 開發的一個大型語言模型,基于GPT-4架構。它能夠理解和生成自然語言文本&…

學習筆記——動態路由——OSPF(特殊區域)

十、OSPF特殊區域 1、技術背景 早期路由器靠CPU計算轉發,由于硬件技術限制問題,因此資源不是特別充足,因此是要節省資源使用,規劃是非常必要的。 OSPF路由器需要同時維護域內路由、域間路由、外部路由信息數據庫。當網絡規模不…

電腦會議錄音轉文字工具哪個好?5個轉文字工具簡化工作流程

在如今忙碌的生活中,我們常常需要記錄和回顧重要的對話和討論。手寫筆記可能跟不上速度,而錄音則以其便捷性成為了捕捉信息的有力工具。但錄音文件的后續處理,往往讓人頭疼不已。想象一下,如果能夠瞬間將這些聲音轉化為文字&#…

spring-16

Spring 對 DAO 的支持 Spring 對 DAO 的支持是通過 Spring 框架的 JDBC 模塊實現的,它提供了一系列的工具和類來簡化數據訪問對象(DAO)的開發和管理。 首先,我們需要在 Spring 配置文件中配置數據源和事務管理器: &l…

Java筆試|面試 —— 子類對象實例化全過程 (熟悉)

子類對象實例化全過程 (熟悉) (1)從結果的角度來看:體現為繼承性 當創建子類對象后,子類對象就獲取了其父類中聲明的所有的屬性和方法,在權限允許的情況下,可以直接調用。 (2)從過…

iptables實現端口轉發ssh

iptables實現端口轉發 實現使用防火墻9898端口訪問內網front主機的22端口(ssh連接) 1. 防火墻配置(lb01) # 配置iptables # 這條命令的作用是將所有目的地為192.168.100.155且目標端口為19898的TCP數據包的目標IP地址改為10.0.0.148,并將目標…

Java策略模式在動態數據驗證中的應用

在軟件開發中,數據驗證是一項至關重要的任務,它確保了數據的完整性和準確性,為后續的業務邏輯處理奠定了堅實的基礎。然而,不同的數據來源往往需要不同的驗證規則,如何在不破壞代碼的整潔性和可維護性的同時&#xff0…

無向圖中尋找指定路徑:深度優先遍歷算法

刷題記錄 1. 節點依賴 背景: 類似于無向圖中, 尋找從 起始節點 --> 目標節點 的 線路. 需求: 現在需要從 起始節點 A, 找到所有到 終點 H 的所有路徑 A – B : 路徑由一個對象構成 public class NodeAssociation {private String leftNodeName;private Stri…

數據編碼的藝術:sklearn中的數據轉換秘籍

數據編碼的藝術:sklearn中的數據轉換秘籍 在機器學習中,數據預處理是一個至關重要的步驟,它直接影響到模型的性能和結果的準確性。數據編碼轉換是數據預處理的一部分,它涉及將原始數據轉換成適合模型訓練的格式。scikit-learn&am…

Python 爬蟲 tiktok關鍵詞搜索用戶數據信息 api接口

Tiktok APP API接口 Python 爬蟲采集Tiktok數據 采集結果頁面如下圖: https://www.tiktok.com/search?qwwe&t1706679918408 請求API http://api.xxx.com/tt/search/user?keywordwwe&count10&offset0&tokentest 請求參數 返回示例 聯系我們&…

178 折線圖-柱形圖-餅狀圖

1.折線圖 1、QChart 類繼承自 QGraphicsWidget,用于管理圖表、圖例和軸。2、QValueAxis 類專門用來自定義圖表中 X 和 Y 坐標軸。3、QLineSeries 類專門用于折線圖(曲線)的形式展示數據 //.pro QT core gui charts#ifndef WIDGET_H #defi…

探索鄰近奧秘:SKlearn中K-近鄰(KNN)算法的應用

探索鄰近奧秘:SKlearn中K-近鄰(KNN)算法的應用 在機器學習的世界里,K-近鄰(K-Nearest Neighbors,簡稱KNN)算法以其簡單直觀而著稱。KNN是一種基本的分類和回歸方法,它的工作原理非常…

Error in onLoad hook: “SyntaxError: Unexpected token u in JSON at position 0“

1.接收頁面報錯 Error in onLoad hook: "SyntaxError: Unexpected token u in JSON at position 0" Unexpected token u in JSON at position 0 at JSON.parse (<anonymous>) 2.發送頁面 &#xff0c;JSON.stringify(item) &#xff0c;將對象轉換為 JSO…