學習記錄:初次學習使用transformers進行大模型微調

初次使用transformers進行大模型微調

環境:

電腦配置:
筆記本電腦:I5(6核12線程) + 16G + RTX3070(8G顯存)
需要自行解決科學上網

Python環境:
python版本:3.8.8
大模型:microsoft/DialoGPT-medium(微軟的對話大模型,模型小,筆記本也能學習微調)
數據集:daily_dialog (日常對話數據集)

其他:
模型及數據集:使用來源于抱抱臉

微調大模型

準備工作:

下載模型:

找到自己想要的模型:

  1. 打開抱抱臉官網——點擊Model:
    在這里插入圖片描述

  2. 輸入要搜索的模型(這里以DialoGPT-medium為例):
    在這里插入圖片描述

  3. 復制名稱到代碼中替換要下載的模型名稱:

在這里插入圖片描述
模型下載:

import os
from transformers import AutoModel, AutoTokenizer# 因為使用了科學上網,需要進行處理
os.environ["HTTP_PROXY"] = "http://127.0.0.1:xxxx"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:xxxx"if __name__ == '__main__':# model_name = 'google-t5/t5-small'  # 要下載的模型名稱model_name = 'microsoft/DialoGPT-medium'  # 要下載的模型名稱 需要到抱抱臉進行復制cache_dir = r'xxxx'  # 模型保存位置# 加載模型時指定下載路徑model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
下載數據集:

找到自己想要的模型:

  1. 打開抱抱臉官網——點擊Datasets:List item
  2. 輸入要搜索的內容,點擊對應數據集進入:
    在這里插入圖片描述
  3. 找到適合用的模型后,點擊復制
    在這里插入圖片描述

開始微調訓練

代碼示例:

# 系統模塊
import os# 第三方庫
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import load_dataset# 設置代理(注意:可能需要根據實際網絡環境調整或移除)
os.environ["HTTP_PROXY"] = "http://127.0.0.1:xxxx"  # HTTP代理設置
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:xxxx"  # HTTPS代理設置if __name__ == '__main__':# 數據準備階段 --------------------------------------------------------------# 加載完整數據集(daily_dialog包含日常對話數據集)full_dataset = load_dataset("daily_dialog", trust_remote_code=True)# 創建子數據集(僅使用訓練集前500條樣本,用于快速實驗)dataset = {"train": full_dataset["train"].select(range(500))  # select保持數據集結構}# 模型加載階段 --------------------------------------------------------------# 模型配置參數model_name = "microsoft/DialoGPT-medium"  # 使用微軟的對話生成預訓練模型cache_dir = r'xxx'  # 本地模型緩存路徑# 加載分詞器(重要:設置填充token與EOS token一致)tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)tokenizer.pad_token = tokenizer.eos_token  # 將填充token設置為與EOS相同# 加載預訓練模型(使用因果語言模型結構)model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)# 數據預處理階段 ------------------------------------------------------------def tokenize_function(examples):"""將對話數據轉換為模型輸入格式的預處理函數"""# 將多輪對話用EOS token連接,并在結尾添加EOSdialogues = [tokenizer.eos_token.join(dialog) + tokenizer.eos_tokenfor dialog in examples["dialog"]]# 對文本進行分詞處理tokenized = tokenizer(dialogues,truncation=True,  # 啟用截斷max_length=512,  # 最大序列長度padding="max_length"  # 填充到最大長度(靜態填充))# 創建標簽(對于因果語言模型,標簽與輸入相同)tokenized["labels"] = tokenized["input_ids"].copy()return tokenized# 應用預處理(保留數據集結構)tokenized_dataset = {"train": dataset["train"].map(tokenize_function,batched=True,  # 批量處理提升效率batch_size=50,  # 每批處理50個樣本remove_columns=["dialog", "act", "emotion"]  # 移除原始文本列)}# 數據驗證(檢查預處理結果)print("Sample keys:", tokenized_dataset["train"][0].keys())  # 應包含input_ids, attention_mask, labelsprint("Input IDs:", tokenized_dataset["train"][0]["input_ids"][:5])  # 檢查前5個token# 訓練配置階段 --------------------------------------------------------------training_args = TrainingArguments(output_dir="./dialo_finetuned",  # 輸出目錄per_device_train_batch_size=2,  # 每個設備的批次大小(根據顯存調整)gradient_accumulation_steps=8,  # 梯度累積步數(模擬更大batch size)learning_rate=1e-5,  # 初始學習率(可調超參數)num_train_epochs=3,  # 訓練輪次(根據需求調整)fp16=True,  # 啟用混合精度訓練(需要GPU支持)logging_steps=10,  # 每10步記錄日志# 可添加的優化參數:# evaluation_strategy="steps",    # 添加驗證策略# save_strategy="epoch",          # 保存策略# warmup_steps=100,               # 學習率預熱步數)# 創建訓練器trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_dataset["train"],  # 訓練數據集# 可擴展功能:# eval_dataset=tokenized_dataset["validation"],  # 添加驗證集# data_collator=...,             # 自定義數據整理器# compute_metrics=...,           # 添加評估指標)# 訓練執行階段 --------------------------------------------------------------trainer.train()  # 啟動訓練# 模型保存階段 --------------------------------------------------------------model.save_pretrained("./dialo_finetuned")  # 保存模型權重tokenizer.save_pretrained("./dialo_finetuned")  # 保存分詞器# 推薦使用以下方式統一保存:trainer.save_model("./dialo_finetuned")       # 官方推薦保存方式

微調后使用

代碼:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import TextStreamer
from collections import deque
import torchdef optimized_generation(text, tokenizer, model):inputs = tokenizer(text, return_tensors="pt").to(model.device)outputs = model.generate(**inputs,max_new_tokens=150,temperature=0.9,  # 越高越有創意 (0-1)top_k=50,  # 限制候選詞數量top_p=0.95,  # 核采樣閾值repetition_penalty=1.2,  # 抑制重復num_beams=3,  # 束搜索寬度early_stopping=True,do_sample=True)return tokenizer.decode(outputs[0], skip_special_tokens=True)# 單輪對話
def simple_chat(model_path, text, max_length=100):"""單輪對話:param text::param max_length::return:"""# 加載模型和分詞器tokenizer = AutoTokenizer.from_pretrained(model_path)model = AutoModelForCausalLM.from_pretrained(model_path)# 確保pad_token設置正確tokenizer.pad_token = tokenizer.eos_token# inputs = tokenizer(text + tokenizer.eos_token, return_tensors="pt")# outputs = model.generate(#     inputs.input_ids,#     max_length=max_length,#     pad_token_id=tokenizer.eos_token_id,#     temperature=0.7,#     do_sample=True# )# response = tokenizer.decode(outputs[0], skip_special_tokens=True)response = optimized_generation(text + tokenizer.eos_token, tokenizer, model)return response[len(text):]  # 去除輸入文本# 多輪對話
class DialogueBot:def __init__(self, model_path, max_history=3):self.tokenizer = AutoTokenizer.from_pretrained(model_path)self.model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")self.max_history = max_historyself.history = deque(maxlen=max_history * 2)  # 每輪包含用戶和機器人各一條# 確保pad_token設置if self.tokenizer.pad_token is None:self.tokenizer.pad_token = self.tokenizer.eos_tokendef generate_response(self, user_input):# 添加用戶輸入(帶EOS)self.history.append(f"User: {user_input}{self.tokenizer.eos_token}")# 構建prompt并編碼prompt = self._build_prompt()inputs = self.tokenizer(prompt,return_tensors="pt",max_length=512,truncation=True).to(self.model.device)# 流式輸出# streamer = TextStreamer(self.tokenizer)# 生成回復outputs = self.model.generate(inputs.input_ids,attention_mask=inputs.attention_mask,max_new_tokens=150,temperature=0.85,top_p=0.95,eos_token_id=self.tokenizer.eos_token_id,pad_token_id=self.tokenizer.eos_token_id,do_sample=True,# streamer=streamer,early_stopping=True)# 解碼并處理回復full_response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:],skip_special_tokens=True)# 清理無效內容(按第一個EOS截斷)clean_response = full_response.split(self.tokenizer.eos_token)[0].strip()# 添加機器人回復到歷史(帶EOS)self.history.append(f"Bot: {clean_response}{self.tokenizer.eos_token}")return clean_responsedef _build_prompt(self):return "".join(self.history)if __name__ == '__main__':# 指定模型路徑model_path = "./dialo_finetuned"# 測試單輪對話print(simple_chat(model_path, "Hello, how are you?"))# 使用示例 多輪對話# bot = DialogueBot(model_path)# while True:#     user_input = input("You: ")#     if user_input.lower() == "exit":#         break#     print("Bot:", bot.generate_response(user_input))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/70834.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/70834.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/70834.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Java學習】Object類與接口

面向對象系列五 一、引用 1.自調傳自與this類型 2.類變量引用 3.重寫時的發生 二、Object類 1.toString 2.equals 3.hashCode 4.clone 三、排序規則接口 1.Comparable 2.Comparator 一、引用 1.自調傳自與this類型 似復刻變量調用里面的非靜態方法時,都…

OpenEuler學習筆記(三十五):搭建代碼托管服務器

以下是主流的代碼托管軟件分類及推薦,涵蓋自托管和云端方案,您可根據團隊規模、功能需求及資源情況選擇: 一、自托管代碼托管平臺(可私有部署) 1. GitLab 簡介: 功能全面的 DevOps 平臺,支持代碼托管、C…

Vscode無法加載文件,因為在此系統上禁止運行腳本

1.在 vscode 終端執行 get-ExecutionPolicy 如果返回是Restricted,說明是禁止狀態。 2.在 vscode 終端執行set-ExecutionPolicy RemoteSigned 爆紅說明沒有設置成功 3.在 vscode 終端執行Set-ExecutionPolicy -Scope CurrentUser RemoteSigned 然后成功后你再在終…

Transformer 架構 理解

大家讀完覺得有幫助記得關注和點贊!!! Transformer 架構:encoder/decoder 內部細節。 的介紹,說明 Transformer 架構相比當時主流的 RNN/CNN 架構的創新之處: 在 transformer 之前,最先進的架構…

事務的4個特性和4個隔離級別

事務的4個特性和4個隔離級別 1. 什么是事務2. 事務的ACID特性2.1 原子性2.2 一致性2.3 持久性2.4 隔離性 3. 事務的創建4. 事務并發時出現的問題4.1 DIRTY READ 臟讀4.2 NON - REPEATABLR READ 不可重復讀4.3 PHANTOM READ 幻讀 5. 事務的隔離級別5.1 READ UNCOMMITTED 讀未提交…

LeetCode熱題100- 字符串解碼【JavaScript講解】

古語有云:“事以密成,語以泄敗”! 關于字符串解碼: 題目:題解:js代碼:代碼中遇到的方法:repeat方法:為什么這里不用this.strstack.push(result)? 題目&#x…

水利工程安全包括哪幾個方面

水利工程安全培訓的內容主要包括以下幾個方面: 基礎知識和技能培訓 : 法律法規 :學習水利工程相關的安全生產法律法規,了解安全生產標準及規范。 事故案例 :通過分析事故案例,了解事故原因和教訓&#x…

淺談新能源汽車充電樁建設問題分析及解決方案

摘要: 在全球倡導低碳減排的大背景下,新能源成為熱門行業在全球范圍內得以開展。汽車尾氣排放會在一定程度上加重溫室效應,并且化石能源的日漸緊缺也迫切對新能源汽車發展提出新要求。現階段的新能源汽車以電力汽車為主,與燃油汽…

05-1基于vs2022的c語言筆記——運算符

目錄 前言 5.運算符和表達式 5-1-1 加減乘除運算符 1.把變量進行加減乘除運算 2.把常量進行加減乘除運算 3.對于比較大的數(往數軸正方向或者負方向),要注意占位符的選取 4.浮點數的加減乘除 5-1-2取余/取模運算符 1.基本規則 2.c語…

ubuntu:換源安裝docker-ce和docker-compose

更新apt源 apt換源:ubuntu:更新阿里云apt源-CSDN博客 安裝docker-ce 1、更新軟件源 sudo apt update2、安裝基本軟件 sudo apt-get install apt-transport-https ca-certificates curl software-properties-common lrzsz -y3、指定使用阿里云鏡像 su…

0—QT ui界面一覽

2025.2.26,感謝gpt4 1.控件盒子 1. Layouts(布局) 布局控件用于組織界面上的控件,確保它們的位置和排列方式合理。 Vertical Layout(垂直布局) :將控件按垂直方向排列。 建議:適…

Apache Doris 索引的全面剖析與使用指南

搞大數據開發的都知道,想要在海量數據里快速查數據,就像在星圖里找一顆特定的星星,賊費勁。不過別慌,數據庫索引就是咱們的 “定位神器”,能讓查詢效率直接起飛!就拿 Apache Doris 這個超火的分析型數據庫來…

docker file中ADD命令的介紹

在 Docker 的世界里,Dockerfile 是一個用于定義鏡像內容和行為的腳本文件。其中,ADD 指令是 Dockerfile 中一個非常重要的命令,用于將文件或目錄從主機文件系統復制到容器的文件系統中。本文將詳細介紹 ADD 指令的作用、使用方式以及一些最佳…

從零到一:如何用阿里云百煉和火山引擎搭建專屬 AI 助手(DeepSeek)?

本文首發:從零到一:如何用阿里云百煉和火山引擎搭建專屬 AI 助手(DeepSeek)? 阿里云百煉和火山引擎都推出了免費的 DeepSeek 模型體驗額度,今天我和大家一起搭建一個本地的專屬 AI 助手。  阿里云百煉為 …

cpp中的繼承

一、繼承概念 在cpp中,封裝、繼承、多態是面向對象的三大特性。這里的繼承就是允許已經存在的類(也就是基類)的基礎上創建新類(派生類或者子類),從而實現代碼的復用。 如上圖所示,Person是基類&…

【QT】QLinearGradient 線性漸變類簡單使用教程

目錄 0.簡介 1)qtDesigner中 2)實際執行 1.功能詳述 3.舉一反三的樣式 0.簡介 QLinearGradient 是 Qt 框架中的一個類,用于定義線性漸變效果(通過樣式表設置)。它可以用來填充形狀、背景或其他圖形元素&#xff0…

前端項目配置 Nginx 全攻略

在前端開發中,項目開發完成后,如何高效、穩定地將其部署到生產環境是至關重要的一步。Nginx 作為一款輕量級、高性能的 Web 服務器和反向代理服務器,憑借其出色的性能和豐富的功能,成為了前端項目部署的首選方案。本文將詳細介紹在…

網絡安全學習-常見web漏洞的滲xxx透以及防護方法

滲XX透測試 弱口令漏洞 漏洞描述 目標網站管理入口(或數據庫等組件的外部連接)使用了容易被猜測的簡單字符口令、或者是默認系統賬號口令。 滲XX透測試 如果不存在驗證碼,則直接使用相對應的弱口令字典使用burpsuite 進行爆破如果存在驗證…

網絡安全 機器學習算法 計算機網絡安全機制

(一)網絡操作系統 安全 網絡操作系統安全是整個網絡系統安全的基礎。操作系統安全機制主要包括訪問控制和隔離控制。 訪問控制系統一般包括主體、客體和安全訪問政策 訪問控制類型: 自主訪問控制強制訪問控制 訪問控制措施: 入…

2025網絡安全等級測評報告,信息安全風險評估報告(Word模板)

一、概述 1.1工作方法 1.2評估依據 1.3評估范圍 1.4評估方法 1.5基本信息 二、資產分析 2.1 信息資產識別概述 2.2 信息資產識別 三、評估說明 3.1無線網絡安全檢查項目評估 3.2無線網絡與系統安全評估 3.3 ip管理與補丁管理 3.4防火墻 四、威脅細類分析 4.1威脅…