【深度學習踩坑實錄】從 Checkpoint 報錯到 TrainingArguments 精通:QNLI 任務微調全流程復盤

作為一名深度學習初學者,最近在基于 Hugging Face Transformers 微調 BERT 模型做 QNLI 任務時,被Checkpoint 保存TrainingArguments 配置這兩個知識點卡了整整兩天。從磁盤爆滿、權重文件加載報錯,到不知道如何控制 Checkpoint 數量,每一個問題都讓我一度想放棄。好在最終逐一解決,特此整理成博客,希望能幫到同樣踩坑的朋友。

一、核心背景:我在做什么?

本次任務是基于 GLUE 數據集的 QNLI(Question Natural Language Inference,問題自然語言推理)任務,用 Hugging Face 的run_glue.py腳本微調bert-base-cased模型。核心需求很簡單:

  1. 順利完成模型微調,避免中途中斷;
  2. 控制 Checkpoint(模型快照)的保存數量,防止磁盤爆滿;
  3. 后續能正常加載 Checkpoint,用于后續的 TRAK 貢獻度分析。

但實際操作中,光是 “Checkpoint 保存” 這一個環節,就暴露出我對TrainingArguments(訓練參數配置類)的認知盲區。

二、先搞懂:TrainingArguments 是什么?為什么它很重要?

在解決問題前,必須先理清TrainingArguments的核心作用 —— 它是 Hugging Face Transformers 庫中控制訓練全流程的 “總開關”,幾乎所有與訓練相關的配置(如批次大小、學習率、Checkpoint 保存策略)都通過它定義。

1. TrainingArguments 的本質

TrainingArguments是一個數據類(dataclass)?,它將訓練過程中需要的所有參數(從優化器設置到日志保存)封裝成結構化對象,再傳遞給Trainer(訓練器)實例。無需手動編寫訓練循環,只需配置好TrainingArgumentsTrainer就能自動完成訓練、驗證、Checkpoint 保存等操作。

2. 常用核心參數(按功能分類)

我整理了本次任務中最常用的參數,按 “訓練基礎配置”“Checkpoint 控制”“日志與驗證” 三類劃分,新手直接套用即可:

類別參數名作用說明常用值示例
訓練基礎配置output_dir訓練結果(Checkpoint、日志、指標)的保存根路徑/root/autodl-tmp/bert_qnli
per_device_train_batch_size單設備訓練批次大小(GPU 內存不足就調小)8/16/32
learning_rate學習率(BERT 類模型微調常用 5e-5/2e-5)5e-5
num_train_epochs訓練總輪次(QNLI 任務 3-5 輪足夠)3.0
fp16是否啟用混合精度訓練(GPU 支持時可加速,減少顯存占用)true
Checkpoint 控制save_strategyCheckpoint 保存時機(核心!)"epoch"(按輪次)/"steps"(按步數)
save_steps按步數保存時,每多少步保存一次(需配合save_strategy="steps"2000/5000
save_total_limit最多保存多少個 Checkpoint(超過自動刪除最舊的,防磁盤爆滿)2/3
save_only_model是否只保存模型權重(不保存優化器、調度器狀態,減小文件體積)true
overwrite_output_dir是否覆蓋已存在的output_dir(避免 “目錄非空” 報錯)true
日志與驗證do_eval是否在訓練中執行驗證(判斷模型性能)true
eval_strategy驗證時機(建議與save_strategy一致)"epoch"/"steps"
logging_dir日志保存路徑(TensorBoard 可視化用)/root/autodl-tmp/bert_qnli/logs
logging_steps每多少步記錄一次日志(查看訓練進度)100/200

3. TrainingArguments 的配置方式

TrainingArguments不支持在代碼中硬編碼(除非修改腳本),常用兩種配置方式,新手推薦第二種:

方式 1:命令行參數(快速調試)

運行run_glue.py時,通過--參數名 參數值的格式傳遞,示例:

python run_glue.py \--model_name_or_path bert-base-cased \--task_name qnli \--output_dir /root/autodl-tmp/bert_qnli \--do_train \--do_eval \--per_device_train_batch_size 8 \--learning_rate 5e-5 \--num_train_epochs 3 \--save_strategy epoch \--save_total_limit 3 \--overwrite_output_dir
方式 2:JSON 配置文件(固定復用)

將所有參數寫入 JSON 文件(如qnli_train_config.json),運行時直接指定文件,適合參數較多或多任務復用:

{"model_name_or_path": "bert-base-cased","task_name": "qnli","do_train": true,"do_eval": true,"max_seq_length": 128,"per_device_train_batch_size": 8,"learning_rate": 5e-5,"num_train_epochs": 3.0,"output_dir": "/root/autodl-tmp/bert_qnli_new","save_strategy": "epoch","save_total_limit": 3,"overwrite_output_dir": true,"logging_dir": "/root/autodl-tmp/bert_qnli_new/logs","logging_steps": 100
}
python run_glue.py qnli_train_config.json

三、我的踩坑實錄:3 個經典問題與解決方案

接下來重點復盤我遇到的 3 個核心問題,每個問題都附 “報錯現象→原因分析→解決步驟”,新手可直接對號入座。

問題 1:訓練中途磁盤爆滿,被迫中斷

報錯現象

訓練到約 3 萬步時,服務器提示 “磁盤空間不足”,查看output_dir發現有 10 多個 Checkpoint 文件夾,每個文件夾占用數百 MB,累計占用超過 20GB。

原因分析

默認情況下,TrainingArgumentssave_strategy"steps"(每 500 步保存一次),且save_total_limit未設置(不限制保存數量)。QNLI 任務 1 個 epoch 約 1.3 萬步,3 個 epoch 會生成 6-8 個 Checkpoint,加上優化器狀態文件(optimizer.pt),很容易撐爆磁盤。

解決步驟
  1. TrainingArguments中添加save_total_limit: 3(最多保存 3 個 Checkpoint,超過自動刪除最舊的);
  2. 選擇合適的save_strategy:若追求穩定,用"epoch"(每輪保存一次,3 個 epoch 僅 3 個 Checkpoint);若需中途恢復,用"steps"并設置較大的save_steps(如 5000 步);
  3. 可選添加save_only_model: true(只保存模型權重,不保存優化器狀態,每個 Checkpoint 體積從 500MB 縮減到 300MB 左右)。

問題 2:加載 Checkpoint 時提示 “_pickle.UnpicklingError: invalid load key, '\xe0'”

報錯現象

訓練中斷后,嘗試加載已保存的 Checkpoint(路徑/root/autodl-tmp/bert_qnli/checkpoint-31000),運行代碼:

model.load_state_dict(torch.load(os.path.join(checkpoint, "model.safetensors"), map_location=DEVICE))

報錯:_pickle.UnpicklingError: invalid load key, '\xe0'

原因分析
  • model.safetensorsSafetensors 格式的權重文件(更安全,但需專用方法加載);
  • torch.load()是 PyTorch 原生加載函數,更適合加載pytorch_model.bin(PyTorch 二進制格式),用它加載 Safetensors 格式會因 “格式不兼容” 報錯。
解決步驟

有兩種方案,按需選擇:

方案 1:改用 Safetensors 專用加載函數(需先安裝safetensors庫)

pip install safetensors
from safetensors.torch import load_file
# 用load_file()替代torch.load()
model.load_state_dict(load_file(os.path.join(checkpoint, "model.safetensors"), device=DEVICE))

方案 2:讓 Checkpoint 默認保存為pytorch_model.bin格式
TrainingArguments中添加save_safetensors: false,后續生成的 Checkpoint 會默認保存為pytorch_model.bin,直接用torch.load()加載即可:

model.load_state_dict(torch.load(os.path.join(checkpoint, "pytorch_model.bin"), map_location=DEVICE))

問題 3:配置 TrainingArguments 后,Checkpoint 遲遲不生成

報錯現象

設置save_strategy: "epoch"后,訓練到 4000 步仍未生成任何 Checkpoint,懷疑配置未生效。

原因分析

save_strategy: "epoch"表示每輪訓練結束后才保存 Checkpoint,而 QNLI 任務 1 個 epoch 約 1.3 萬步(訓練集約 10 萬樣本,batch_size=8時:100000÷8=12500 步)。4000 步僅完成第一個 epoch 的 1/3,未到保存時機,屬于正常現象。

解決步驟
  1. 若想快速驗證配置是否生效,臨時改用save_strategy: "steps"并設置較小的save_steps(如 2000 步),訓練到 2000 步時會自動生成checkpoint-2000
  2. 若堅持按 epoch 保存,耐心等待第一個 epoch 結束(約 1.3 萬步),日志會打印Saving model checkpoint to xxx,此時output_dir下會出現第一個 Checkpoint;
  3. 查看日志確認配置:搜索save_strategysave_total_limit,確認日志中顯示的參數與 JSON 配置一致(避免 JSON 文件未被正確讀取)。

四、總結:TrainingArguments 配置 “避坑指南”

經過這次踩坑,我總結出 3 條新手必看的 “避坑原則”,幫你少走彎路:

  1. 優先用 JSON 配置文件:命令行參數容易遺漏,JSON 文件可固化配置,后續復用或修改時更清晰;
  2. Checkpoint 配置 “三要素”:每次訓練前必確認save_strategy(保存時機)、save_total_limit(保存數量)、output_dir(保存路徑),這三個參數直接決定是否會出現磁盤爆滿或 Checkpoint 丟失;
  3. 加載 Checkpoint 前先看格式:先查看 Checkpoint 文件夾中的權重文件名(是model.safetensors還是pytorch_model.bin),再選擇對應的加載函數,避免格式不兼容報錯。

最后想說,深度學習中的 “環境配置” 和 “參數調試” 雖然繁瑣,但每一次踩坑都是對知識點的深化。這次從 “完全不懂 TrainingArguments” 到 “能靈活控制 Checkpoint”,雖然花了兩天時間,但后續再做其他 GLUE 任務(如 SST-2、MRPC)時,直接復用配置就能快速上手 —— 這大概就是踩坑的價值吧。

如果你也在做類似任務,歡迎在評論區交流更多踩坑經驗~

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/96951.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/96951.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/96951.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java面試小冊(3)

21【Q】: 什么是Java的SPI機制?【A】:SPI 是一種插件機制,用于在運行時動態加載服務的實現。它通過定義接口(服務接口)并提供一種可擴展的方式來讓服務的提供著(實現類)在運行時注入&#xff0c…

P1150 Peter 的煙

記錄20#include <bits/stdc.h> using namespace std; int main(){int n,k;cin>>n>>k;int cnt0;while(n>k){cntk;nn-k1;}cntn;cout<<cnt;return 0; }突破口每吸完一根煙就把煙蒂保存起來&#xff0c;k&#xff08;k>1&#xff09;個煙蒂可以換一個…

Cursor和Hbuilder用5分鐘開發微信小程序

分享一個5分鐘搞定微信小程序開發的技能&#xff0c;需要用到兩個工具&#xff1a;Cursor和Hbuilder。 第1步、下載HBuilder。Hbuilder可以實現一套代碼直接生成安卓、蘋果、鴻蒙各個平臺APP。訪問Hbuilder的官方網站&#xff0c;HBuilderX-高效極客技巧&#xff0c;選擇適合…

k8s的dashboard

找一個裝有docker的機器&#xff0c;在一個rocky linux的虛擬機里弄拉取一個rancher鏡像建立一個目錄&#xff0c;目的&#xff1a;和里面數據做持久化關聯后臺運行&#xff0c;讓他有權限&#xff0c;8080端口和容器80端口映射&#xff0c;443和443做映射查看一下刪掉&#xf…

橋接模式,打造靈活可擴展的日志系統C++

一、為什么用橋接模式在企業開發中&#xff0c;日志系統幾乎是標配。常見需求&#xff1a;日志有多種類型&#xff08;Info、Warning、Error 等&#xff09;&#xff1b;日志需要支持多種輸出方式&#xff08;控制臺輸出、寫文件、遠程上傳、數據庫存儲等&#xff09;。如果把這…

kafka--基礎知識點--5.3--producer事務

1 事務簡介 Kafka事務是Apache Kafka在流處理場景中實現Exactly-Once語義的核心機制。它允許生產者在跨多個分區和主題的操作中&#xff0c;以原子性&#xff08;Atomicity&#xff09;的方式提交或回滾消息&#xff0c;確保數據處理的最終一致性。例如&#xff0c;在流處理中…

利用DeepSeek實現服務器客戶端模式的DuckDB原型

在網上看到韓國公司開發的一款GooseDB&#xff0c;DuckDB? 的功能擴展分支&#xff0c;具有服務器/客戶端、多會話和并發寫入支持&#xff0c;使用 PostgreSQL 有線協議&#xff0c;但它是Freeware而不是開源&#xff0c;所以讓DeepSeek實現之。 首先把readme頁面發給他翻譯&a…

麥當勞APP逆向

版本 V 7.0.17.0反調試 梆梆企業加固 frida反調試部分代碼 headers {"biz_scenario": "500","biz_from": "1004","User-Agent": "mcdonald_Android/7.0.17.0 (Android)","ct": "102","…

大數據畢業設計選題推薦-基于大數據的結核病數據可視化分析系統-Hadoop-Spark-數據可視化-BigData

?作者主頁&#xff1a;IT畢設夢工廠? 個人簡介&#xff1a;曾從事計算機專業培訓教學&#xff0c;擅長Java、Python、PHP、.NET、Node.js、GO、微信小程序、安卓Android等項目實戰。接項目定制開發、代碼講解、答辯教學、文檔編寫、降重等。 ?文末獲取源碼? 精彩專欄推薦?…

Vue3 視頻播放器完整指南 – @videojs-player/vue 從入門到精通

前言 在 Vue 3 生態中&#xff0c;視頻播放功能是許多應用的核心需求。videojs-player/vue 是一個專門為 Vue 3 設計的視頻播放器組件&#xff0c;基于成熟的 Video.js 庫構建&#xff0c;提供了簡單而強大的視頻播放解決方案。 主要特性 Vue 3 組件化&#xff1a;原生 Vue …

【靶場練習】--DVWA第一關Brute Force(暴力破解)全難度分析

注意&#xff0c;這一關必須要使用Burpsuite來抓包 目錄Low1.抓包2.發送到爆破模塊3.選擇爆破模式爆破模式介紹4.添加載荷5.添加字典6.爆破查看查看源碼Medium查看源碼High1.抓包2.在bp的extensions中找到CSRF Token Tracker&#xff0c;并安裝3.構造字典4.成功爆破查看源碼Imp…

Java語言——排序算法

一、基本概念排序&#xff1a;將n個數字按一定順序排列&#xff08;比如&#xff1a;升序&#xff0c;或者降序&#xff09; ^內部排序 &#xff1a;若整個排序過程不需要訪問外存便能完成&#xff0c;則稱此類排序問題為內部排序 ^外部排序&#xff1a;若參加排序的記錄數量很…

【Linux】人事檔案——用戶及組管理

目錄 1 用戶及組管理 2?用戶及用戶組管理命令 2.1 useradd&#xff1a;建立用戶 useradd命令用于建立用戶&#xff0c;該 2.2 passwd&#xff1a;更改用戶密碼 2.3 usermod&#xff1a;更改用戶信息 2.4 groupadd&#xff1a;建立用戶組 2.5 finger&#xff1a;查找并顯…

給定一個有序的正數數組arr和一個正數range,如果可以自由選擇arr中的數字,想累加得 到 1~range 范圍上所有的數,返回arr最少還缺幾個數。

給定一個有序的正數數組arr和一個正數range&#xff0c;如果可以自由選擇arr中的數字&#xff0c;想累加得 到 1~range 范圍上所有的數&#xff0c;返回arr最少還缺幾個數。 #include <iostream> #include <vector>using namespace std;void func1(std::vector<…

BigemapPro快速添加歷史影像(Arcgis衛星地圖歷史地圖)

這是Esri(Arcgis)官方提供的歷史影像數據&#xff0c;可放心使用。https://livingatlas.arcgis.com/wayback如何快速添加到Bigemap Pro軟件里&#xff0c;詳細步驟如下&#xff1a;復制下面的文本保存為 配置.bmmap,然后拖入軟件就可以了{"BmLayerVersion":"1.0…

[免費]基于Python的Django醫院管理系統【論文+源碼+SQL腳本】

大家好&#xff0c;我是python222_小鋒老師&#xff0c;看到一個不錯的基于Python的Django醫院管理系統&#xff0c;分享下哈。 項目視頻演示 https://www.bilibili.com/video/BV1iPH8zmEut/ 項目介紹 隨著人民生活水平日益增長&#xff0c;科技日益發達的今天&#xff0c;…

MyBatis 從入門到精通(第三篇)—— 動態 SQL、關聯查詢與查詢緩存

在前兩篇博客中&#xff0c;我們掌握了 MyBatis 的基礎搭建、核心架構與 Mapper 代理開發&#xff0c;能應對簡單的單表 CRUD 場景。但實際項目中&#xff0c;業務往往更復雜 —— 比如 “多條件動態查詢”“員工與部門的關聯查詢”“高頻查詢的性能優化” 等。本篇將聚焦 MyBa…

Linux內核中IPv4的BEET模式封裝機制解析

引言 在Linux網絡棧中,IPSec提供了網絡層的數據加密和認證服務。傳統的IPSec支持兩種模式:傳輸模式(Transport Mode)和隧道模式(Tunnel Mode)。然而,這兩種模式各有優缺點:傳輸模式開銷小但無法隱藏原始IP頭;隧道模式提供完全封裝但增加了開銷。 BEET(Bound End-to…

設計模式——創建型模式

什么是設計模式&#xff1f;設計模式是軟件工程中解決常見問題的經典方案&#xff0c;它們代表了最佳實踐和經驗總結。通過使用設計模式&#xff0c;開發者可以創建更加靈活、可維護和可擴展的代碼結構。設計模式不是具體的代碼實現&#xff0c;而是針對特定問題的通用解決方案…

我愛學算法之—— 位運算(上)

常見位運算 對于位運算&#xff1a; &&#xff1a;按位與&#xff0c;有0則0。 |&#xff1a;按位或&#xff0c;有1則1。 ^&#xff1a;按位異或&#xff0c;相同為0、不同為1。&#xff08;無進位相加&#xff09; ~&#xff1a;二進制位按位取反。 對于位運算的常見使用…