GPT-2 語言模型 - 模型訓練

本節代碼是一個完整的機器學習工作流程,用于訓練一個基于GPT-2的語言模型。下面是對這段代碼的詳細解釋:

文件目錄如下

1. 初始化和數據準備

  • 設置隨機種子

    random.seed(1002)

    確保結果的可重復性。

  • 定義參數

    test_rate = 0.2
    context_length = 128
    • test_rate:測試集占總數據集的比例。

    • context_length:模型處理的文本長度。

  • 獲取數據文件

    all_files = glob(pathname=os.path.join("data","*"))

    使用 glob 獲取 data 目錄下的所有文件。

  • 劃分數據集

    test_file_list = random.sample(all_files, int(len(all_files) * test_rate))
    train_file_list = [i for i in all_files if i not in test_file_list]

    將數據集隨機劃分為訓練集和測試集。

  • 加載數據集

    raw_datasets = load_dataset("csv", data_files={"train": train_file_list, "vaild": test_file_list}, cache_dir="cache_data")

    使用 datasets 庫加載 CSV 格式的數據集,并緩存到 cache_data 目錄。

2. 數據預處理

  • 初始化分詞器

    tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")
    tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})

    從本地路徑加載預訓練的 BERT 分詞器,并添加自定義的開始和結束標記。

  • 數據預處理

    tokenize_datasets = raw_datasets.map(tokenize, batched=True, remove_columns=raw_datasets["train"].column_names)

    使用 map 方法對數據集進行預處理,將文本轉換為模型可接受的格式。

    • tokenize 函數對文本進行分詞和截斷。

    • batched=True 表示批量處理數據。

    • remove_columns 刪除原始數據集中的列。

3. 模型配置和初始化

  • 模型配置

    config = GPT2Config.from_pretrained("config",vocab_size=len(tokenizer),n_ctx=context_length,bos_token_id=tokenizer.bos_token_id,eos_token_id=tokenizer.eos_token_id,)

    加載預訓練的 GPT-2 配置,并根據分詞器的詞匯表大小和上下文長度進行調整。

  • 初始化模型

    model = GPT2LMHeadModel(config)
    model_size = sum([t.numel() for t in model.parameters()])
    print(f"model_size: {model_size/1000/1000} M")

    根據配置初始化 GPT-2 語言模型,并計算模型參數的總數,打印模型大小(以兆字節為單位)。

4. 訓練設置

  • 數據整理器

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    使用 DataCollatorForLanguageModeling 整理訓練數據,設置 mlm=False 表示不使用掩碼語言模型。

  • 訓練參數

    args = TrainingArguments(learning_rate=1e-5,num_train_epochs=100,per_device_train_batch_size=10,per_device_eval_batch_size=10,eval_steps=2000,logging_steps=2000,gradient_accumulation_steps=5,weight_decay=0.1,warmup_steps=1000,lr_scheduler_type="cosine",save_steps=100,output_dir="model_output",fp16=True,
    )

    配置訓練參數,包括學習率、訓練輪數、批大小、評估間隔等。

  • 初始化訓練器

    trianer = Trainer(model=model,args=args,tokenizer=tokenizer,data_collator=data_collator,train_dataset=tokenize_datasets["train"],eval_dataset=tokenize_datasets["vaild"]
    )
  • 啟動訓練

    trianer.train()

    使用 Trainer 類啟動模型訓練。

需復現完整代碼

from glob import glob
import os
from torch.utils.data import Dataset
from datasets import load_dataset
import random
from transformers import BertTokenizerFast
from transformers import GPT2Config
from transformers import GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer,TrainingArgumentsdef tokenize(element):outputs = tokenizer(element["content"],truncation=True,max_length=context_length,return_overflowing_tokens=True,return_length=True)input_batch = []for length,input_ids in zip(outputs["length"],outputs["input_ids"]):if length == context_length:input_batch.append(input_ids)return {"input_ids":input_batch}if __name__ == "__main__":random.seed(1002)test_rate = 0.2context_length = 128all_files = glob(pathname=os.path.join("data","*"))test_file_list = random.sample(all_files,int(len(all_files)*test_rate))train_file_list = [i for i in all_files if i not in test_file_list]raw_datasets = load_dataset("csv",data_files={"train":train_file_list,"vaild":test_file_list},cache_dir="cache_data")tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})tokenize_datasets = raw_datasets.map(tokenize,batched=True,remove_columns=raw_datasets["train"].column_names)config = GPT2Config.from_pretrained("config",vocab_size=len(tokenizer),n_ctx=context_length,bos_token_id = tokenizer.bos_token_id,eos_token_id = tokenizer.eos_token_id,)model = GPT2LMHeadModel(config)model_size = sum([ t.numel() for t in model.parameters()])print(f"model_size: {model_size/1000/1000} M")data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)args = TrainingArguments(learning_rate=1e-5,num_train_epochs=100,per_device_train_batch_size=10,per_device_eval_batch_size=10,eval_steps=2000,logging_steps=2000,gradient_accumulation_steps=5,weight_decay=0.1,warmup_steps=1000,lr_scheduler_type="cosine",save_steps=100,output_dir="model_output",fp16=True,)trianer = Trainer(model=model,args=args,tokenizer=tokenizer,data_collator=data_collator,train_dataset=tokenize_datasets["train"],eval_dataset=tokenize_datasets["vaild"])trianer.train()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75718.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75718.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75718.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

架構師面試(二十九):TCP Socket 編程

問題 今天考察網絡編程的基礎知識。 在基于 TCP 協議的網絡 【socket 編程】中可能會遇到很多異常,在下面的相關描述中說法正確的有哪幾項呢? A. 在建立連接被拒絕時,有可能是因為網絡不通或地址錯誤或 server 端對應端口未被監聽&#x…

HTTP實現心跳模塊

HTTP實現心跳模塊 使用輕量級的cHTTP庫cpp-httplib重現實現HTTP心跳模塊 頭文件HttplibHeartbeat.h #ifndef HTTPLIB_HEARTBEAT_H #define HTTPLIB_HEARTBEAT_H#include <string> #include <thread> #include <atomic> #include <chrono> #include …

openharmony—release—4.1開發環境搭建(踩坑記錄)

環境開發需要分別在window以及ubuntu下進行相應設置 一、window 1.安裝DevEco Device Tool OpenAtom OpenHarmony 二、ubuntu 1.將Ubuntu Shell環境修改為bash ls -l /bin/sh 2.打開終端工具&#xff0c;執行如下命令&#xff0c;輸入密碼&#xff0c;然后選擇No&#xff0…

Go學習系列文章聲明

本次學習是基于B站的視頻&#xff0c;【Udemy高分熱門付費課程】Golang&#xff1a;完整開發者指南&#xff08;基礎知識和高級特性&#xff09;中英文字幕_嗶哩嗶哩_bilibili 本人會嘗試輸出視頻中的內容&#xff0c;如有錯誤歡迎指出 next page: Go installation process

error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408

在git push時報錯&#xff1a;error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408 原因&#xff1a;可能是推送的文件太大&#xff0c;要么是緩存不夠&#xff0c;要么是網絡不行。 解決方法&#xff1a; 將本地 http.postBuffer 數值調整到500MB&…

Android.bp中添加條件判斷編譯方式

背景&#xff1a; 馬哥學員朋友以前在vip群里&#xff0c;有問道如何在Android.bp中添加條件判斷&#xff0c;在工作中經常需要一套代碼兼容發貨目標版本&#xff0c;即代碼都是公共的一套&#xff0c;但是需要用這一套代碼集成到各個產品設備上 但是這個產品設備可能面臨比…

swift ui基礎

一個樸實無華的目錄 今日學習內容&#xff1a;1.三種布局&#xff08;可以相互包裹&#xff09;1.1 vstack&#xff08;豎直&#xff09;&#xff1a;先寫的在上面1.1 hstack&#xff08;水平&#xff09;&#xff1a;先寫的在左邊1.1 zstack&#xff08;前后&#xff09;&…

第16屆藍橋杯單片機模擬試題Ⅲ

試題 代碼 sys.h #ifndef __SYS_H__ #define __SYS_H__#include <STC15F2K60S2.H> //sys.c extern unsigned char UI; //界面標志(0濕度界面、1參數界面、2時間界面) extern unsigned char time; //時間間隔(1s~10S) extern bit ssflag; //啟動/停止標志…

Node.js中URL模塊詳解

Node.js 中 URL 模塊全部 API 詳解 1. URL 類 const { URL } require(url);// 1. 創建 URL 對象 const url new URL(https://www.example.com:8080/path?queryvalue#hash);// 2. URL 屬性 console.log(協議:, url.protocol); // https: console.log(主機名:, url.hos…

Java接口性能優化面試問題集錦:高頻考點與深度解析

1. 如何定位接口性能瓶頸&#xff1f;常用哪些工具&#xff1f; 考察點&#xff1a;性能分析工具的使用與問題定位能力。 核心答案&#xff1a; 工具&#xff1a;Arthas&#xff08;在線診斷&#xff09;、JProfiler&#xff08;內存與CPU分析&#xff09;、VisualVM、Prometh…

WheatA小麥芽:農業氣象大數據下載器

今天為大家介紹的軟件是WheatA小麥芽&#xff1a;專業純凈的農業氣象大數據系統。下面&#xff0c;我們將從軟件的主要功能、支持的系統、軟件官網等方面對其進行簡單的介紹。 主要內容來源于軟件官網&#xff1a;WheatA小麥芽的官方網站是http://www.wheata.cn/ &#xff0c;…

Python10天突擊--Day 2: 實現觀察者模式

以下是 Python 實現觀察者模式的完整方案&#xff0c;包含同步/異步支持、類型注解、線程安全等特性&#xff1a; 1. 經典觀察者模式實現 from abc import ABC, abstractmethod from typing import List, Anyclass Observer(ABC):"""觀察者抽象基類""…

CST1019.基于Spring Boot+Vue智能洗車管理系統

計算機/JAVA畢業設計 【CST1019.基于Spring BootVue智能洗車管理系統】 【項目介紹】 智能洗車管理系統&#xff0c;基于 Spring Boot Vue 實現&#xff0c;功能豐富、界面精美 【業務模塊】 系統共有三類用戶&#xff0c;分別是&#xff1a;管理員用戶、普通用戶、工人用戶&…

Windows上使用Qt搭建ARM開發環境

在 Windows 上使用 Qt 和 g++-arm-linux-gnueabihf 進行 ARM Linux 交叉編譯(例如針對樹莓派或嵌入式設備),需要配置 交叉編譯工具鏈 和 Qt for ARM Linux。以下是詳細步驟: 1. 安裝工具鏈 方法 1:使用 MSYS2(推薦) MSYS2 提供 mingw-w64 的 ARM Linux 交叉編譯工具鏈…

Python爬蟲教程011:scrapy爬取當當網數據開啟多條管道下載及下載多頁數據

文章目錄 3.6.4 開啟多條管道下載3.6.5 下載多頁數據3.6.6 完整項目下載3.6.4 開啟多條管道下載 在pipelines.py中新建管道類(用來下載圖書封面圖片): # 多條管道開啟 # 要在settings.py中開啟管道 class DangdangDownloadPipeline:def process_item(self, item, spider):…

Mysql -- 基礎

SQL SQL通用語法&#xff1a; SQL分類&#xff1a; DDL: 數據庫操作 查詢&#xff1a; SHOW DATABASES&#xff1b; 創建&#xff1a; CREATE DATABASE[IF NOT EXISTS] 數據庫名 [DEFAULT CHARSET字符集] [COLLATE 排序規則]&#xff1b; 刪除&#xff1a; DROP DATABA…

實操(環境變量)Linux

環境變量概念 我們用語言寫的文件編好后變成了程序&#xff0c;./ 運行的時候他就會變成一個進程被操作系統調度并運行&#xff0c;運行完畢進程相關資源被釋放&#xff0c;因為它是一個bash的子進程&#xff0c;所以它退出之后進入僵尸狀態&#xff0c;bash回收他的退出結果&…

torch.cat和torch.stack的區別

torch.cat 和 torch.stack 是 PyTorch 中用于組合張量的兩個常用函數&#xff0c;它們的核心區別在于輸入張量的維度和輸出張量的維度變化。以下是詳細對比&#xff1a; 1. torch.cat (Concatenate) 作用&#xff1a;沿現有維度拼接多個張量&#xff0c;不創建新維度 輸入要求…

深入解析@Validated注解:Spring 驗證機制的核心工具

一、注解出處與核心定位 1. 注解來源 ? 所屬框架&#xff1a;Validated 是 Spring Framework 提供的注解&#xff08;org.springframework.validation.annotation 包下&#xff09;。 ? 核心定位&#xff1a; 作為 Spring 對 JSR-380&#xff08;Bean Validation 2.0&#…

2025年認證杯數學建模競賽A題完整分析論文(含模型、可運行代碼)(共32頁)

2025年認證杯數學建模競賽A題完整分析論文 目錄 摘要 一、問題分析 二、問題重述 三、模型假設 四、 模型建立與求解 4.1問題1 4.1.1問題1解析 4.1.2問題1模型建立 4.1.3問題1樣例代碼&#xff08;僅供參考&#xff09; 4.1.4問題1求解結果分析&#xff08…