使用PEFT庫進行ChatGLM3-6B模型的LORA高效微調

PEFT庫進行ChatGLM3-6B模型LORA高效微調

  • LORA微調ChatGLM3-6B模型
    • 安裝相關庫
    • 使用ChatGLM3-6B
    • 模型GPU顯存占用
    • 準備數據集
    • 加載模型
    • 加載數據集
    • 數據處理
    • 數據集處理
    • 配置LoRA
    • 配置訓練超參數
    • 開始訓練
    • 保存LoRA模型
    • 模型推理
    • 從新加載
    • 合并模型
    • 使用微調后的模型

LORA微調ChatGLM3-6B模型

本文基于transformers、peft等框架,對ChatGLM3-6B模型進行Lora微調。

LORA(Low-Rank Adaptation)是一種高效的模型微調技術,它可以通過在預訓練模型上添加額外的低秩權重矩陣來微調模型,從而僅需更新很少的參數即可獲得良好的微調性能。這相比于全量微調大幅減少了訓練時間和計算資源的消耗。

安裝相關庫

pip install ransformers==4.37.2 peft==0.8.0 accelerate==0.27.0 bitsandbytes

使用ChatGLM3-6B

直接調用ChatGLM3-6B模型來生成對話

from transformers import AutoTokenizer, AutoModelmodel_id = "/root/work/chatglm3-6b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
#model = AutoModel.from_pretrained(model_id, trust_remote_code=True).half().cuda()
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, device='cuda')model = model.eval()
response, history = model.chat(tokenizer, "你好", history=history)
print(response)

在這里插入圖片描述

模型GPU顯存占用

默認情況下,模型以半精度(float16)加載,模型權重需要大概 13GB顯存。

獲取當前模型占用的GPU顯存

memory_bytes = model.get_memory_footprint()
# 轉換為GB
memory_gb = memory_footprint_bytes / (1024 ** 3)  
print(f"{memory_gb :.2f}GB")

注意:與實際進程占用有差異,差值為預留給PyTorch的顯存

準備數據集

準備數據集其實就是指令集構建,LLM的微調一般指指令微調過程。所謂指令微調,就是使用的微調數據格式、形式。

訓練目標是讓模型具有理解并遵循用戶指令的能力。因此在指令集構建時,應該針對目標任務,針對性的構建任務指令集。

這里使用alpaca格式的數據集,格式形式如下:

[{"instruction": "用戶指令(必填)","input": "用戶輸入(選填)","output": "模型回答(必填)",},"system": "系統提示詞(選填)","history": [["第一輪指令(選填)", "第一輪回答(選填)"],["第二輪指令(選填)", "第二輪回答(選填)"]]
]
instruction:用戶指令,要求AI執行的任務或問題input:用戶輸入,是完成用戶指令所必須的輸入內容,就是執行指令所需的具體信息或上下文output:模型回答,根據給定的指令和輸入生成答案

這里根據企業私有文檔數據,生成相關格式的訓練數據集,大概格式如下:

[{"instruction": "內退條件是什么?","input": "","output": "內退條件包括與公司簽訂正式勞動合同并連續工作滿20年及以上,以及距離法定退休年齡不足5年。特殊工種符合國家相關規定可提前退休的也可在退休前5年內提出內退申請。"},
]

加載模型

from transformers import AutoModel, AutoTokenizermodel_id = "/root/work/chatglm3-6b"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

加載數據集

from datasets import load_datasetdata_id="/root/work/jupyterlab/zd.json"
dataset = load_dataset("json", data_files=data_id)
print(dataset["train"])

在這里插入圖片描述

數據處理

Lora訓練數據是需要經過tokenize編碼處理,然后后再輸入模型進行訓練。一般需要將輸入文本編碼為input_ids,將輸出文本編碼為labels,編碼之后的結果都是多維的向量。

需要定義一個預處理函數,這個函數用于對每一個樣本,編碼其輸入、輸出文本并返回一個編碼后的字典。

# tokenize_func 函數
def tokenize_func(example, tokenizer, ignore_label_id=-100):"""對單個數據樣本進行tokenize處理。參數:example (dict): 包含'content'和'summary'鍵的字典,代表訓練數據的一個樣本。tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。ignore_label_id (int, optional): 在label中用于填充的忽略ID,默認為-100。返回:dict: 包含'tokenized_input_ids'和'labels'的字典,用于模型訓練。"""prompt_text = ''                          # 所有數據前的指令文本max_input_length = 512                    # 輸入的最大長度max_output_length = 1536                  # 輸出的最大長度# 構建問題文本question = prompt_text + example['instruction']if example.get('input', None) and example['input'].strip():question += f'\n{example["input"]}'# 構建答案文本answer = example['output']# 對問題和答案文本進行tokenize處理q_ids = tokenizer.encode(text=question, add_special_tokens=False)a_ids = tokenizer.encode(text=answer, add_special_tokens=False)# 如果tokenize后的長度超過最大長度限制,則進行截斷if len(q_ids) > max_input_length - 2:  # 保留空間給gmask和bos標記q_ids = q_ids[:max_input_length - 2]if len(a_ids) > max_output_length - 1:  # 保留空間給eos標記a_ids = a_ids[:max_output_length - 1]# 構建模型的輸入格式input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)question_length = len(q_ids) + 2  # 加上gmask和bos標記# 構建標簽,對于問題部分的輸入使用ignore_label_id進行填充labels = [ignore_label_id] * question_length + input_ids[question_length:]return {'input_ids': input_ids, 'labels': labels}

進行數據映射處理,同時刪除特定列

# 獲取 'train' 部分的列名
column_names = dataset['train'].column_names  # 使用lambda函數調用tokenize_func函數,并傳入example和tokenizer作為參數
tokenized_dataset = dataset['train'].map(lambda example: tokenize_func(example, tokenizer),batched=False,  # 不按批次處理remove_columns=column_names  # 移除特定列(column_names中指定的列)
)

執行print(tokenized_dataset[0]),打印tokenize處理結果
在這里插入圖片描述

數據集處理

還需要使用一個數據收集器,可以使用transformers 中的DataCollatorForSeq2Seq數據收集器

from transformers import DataCollatorForSeq2Seqdata_collator = DataCollatorForSeq2Seq(tokenizer,model=model,label_pad_token_id=-100,pad_to_multiple_of=None,padding=True
)

或者自定義實現一個數據收集器

import torch
from typing import List, Dict, Optional# DataCollatorForChatGLM 類
class DataCollatorForChatGLM:"""用于處理批量數據的DataCollator,尤其是在使用 ChatGLM 模型時。該類負責將多個數據樣本(tokenized input)合并為一個批量,并在必要時進行填充(padding)。屬性:pad_token_id (int): 用于填充(padding)的token ID。max_length (int): 單個批量數據的最大長度限制。ignore_label_id (int): 在標簽中用于填充的ID。"""def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):"""初始化DataCollator。參數:pad_token_id (int): 用于填充(padding)的token ID。max_length (int): 單個批量數據的最大長度限制。ignore_label_id (int): 在標簽中用于填充的ID,默認為-100。"""self.pad_token_id = pad_token_idself.ignore_label_id = ignore_label_idself.max_length = max_lengthdef __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:"""處理批量數據。參數:batch_data (List[Dict[str, List]]): 包含多個樣本的字典列表。返回:Dict[str, torch.Tensor]: 包含處理后的批量數據的字典。"""# 計算批量中每個樣本的長度len_list = [len(d['input_ids']) for d in batch_data]batch_max_len = max(len_list)  # 找到最長的樣本長度input_ids, labels = [], []for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):pad_len = batch_max_len - len_of_d  # 計算需要填充的長度# 添加填充,并確保數據長度不超過最大長度限制ids = d['input_ids'] + [self.pad_token_id] * pad_lenlabel = d['labels'] + [self.ignore_label_id] * pad_lenif batch_max_len > self.max_length:ids = ids[:self.max_length]label = label[:self.max_length]input_ids.append(torch.LongTensor(ids))labels.append(torch.LongTensor(label))# 將處理后的數據堆疊成一個tensorinput_ids = torch.stack(input_ids)labels = torch.stack(labels)return {'input_ids': input_ids, 'labels': labels}
data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)

配置LoRA

在peft中使用LoRA非常簡單。借助PeftModel抽象,可以快速將低秩適配器(LoRA)應用到任意模型中。

在初始化相應的微調配置類(LoraConfig)時,需要顯式指定在哪些層新增適配器(Adapter),并將其設置正確。

ChatGLM3-6B模型通過以下方式獲取需要訓練的模型層的名字

from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPINGtarget_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']

在PEFT庫的 constants.py 文件中定義了不同的 PEFT 方法,在各類大模型上的微調適配模塊。

在這里插入圖片描述
主要是配置LoraConfig類,其中可以設置很多參數,但主要參數只有幾個

# 從peft庫導入LoraConfig和get_peft_model函數
from peft import LoraConfig, get_peft_model, TaskType# 創建一個LoraConfig對象,用于設置LoRA(Low-Rank Adaptation)的配置參數
config = LoraConfig(r=8,  # LoRA的秩,影響LoRA矩陣的大小lora_alpha=32,  # LoRA適應的比例因子# 指定需要訓練的模型層的名字,不同模型對應層的名字不同# target_modules=["query_key_value"],target_modules=target_modules,lora_dropout=0.05,  # 在LoRA模塊中使用的dropout率bias="none",  # 設置bias的使用方式,這里沒有使用bias# task_type="CAUSAL_LM"  # 任務類型,這里設置為因果(自回歸)語言模型task_type=TaskType.CAUSAL_LM
)# 使用get_peft_model函數和給定的配置來獲取一個PEFT模型
model = get_peft_model(model, config)# 打印出模型中可訓練的參數
model.print_trainable_parameters()

在這里插入圖片描述

配置訓練超參數

配置訓練超參數使用TrainingArguments類,可配置參數同樣有很多,但主要參數也是只有幾個

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir=f"{model_id}-lora",  # 指定模型輸出和保存的目錄per_device_train_batch_size=4,  # 每個設備上的訓練批量大小learning_rate=2e-4,  # 學習率fp16=True,  # 啟用混合精度訓練,可以提高訓練速度,同時減少內存使用logging_steps=20,  # 指定日志記錄的步長,用于跟蹤訓練進度save_strategy="steps",   # 模型保存策略save_steps=50,   # 模型保存步數# max_steps=50, # 最大訓練步長num_train_epochs=1  # 訓練的總輪數)

查看添加LoRA模塊后的模型

print(model)

開始訓練

配置model、參數、數據集后就可以進行訓練了

trainer = Trainer(model=model,  # 指定訓練時使用的模型train_dataset=tokenized_dataset,  # 指定訓練數據集args=training_args,data_collator=data_collator,
)model.use_cache = False
# trainer.train() 
with torch.autocast("cuda"): trainer.train()

在這里插入圖片描述

注意:

執行trainer.train() 時出現異常,參考:bitsandbytes的issues

保存LoRA模型

lora_model_path = "lora/chatglm3-6b-int8"
trainer.model.save_pretrained(lora_model_path )
#model.save_pretrained(lora_model_path )

在這里插入圖片描述

模型推理

使用LoRA模型,進行模型推理

lora_model = trainer.model

1.文本補全

text = "人力資源部根據各部門人員"inputs = tokenizer(text, return_tensors="pt").to(0)out = lora_model.generate(**inputs, max_new_tokens=500)
print(tokenizer.decode(out[0], skip_special_tokens=True))

在這里插入圖片描述
2.問答對話

from peft import PeftModelinput_text = '公司的招聘需求是如何提出的?'
model.eval()
response, history = lora_model.chat(tokenizer=tokenizer, query=input_text)
print(f'ChatGLM3-6B 微調后回答: \n{response}')

在這里插入圖片描述

從新加載

加載源model與tokenizer,使用PeftModel合并源model與PEFT微調后的參數,然后進行推理測試。

from peft import PeftModel
from transformers import AutoModel, AutoTokenizermodel_path="/root/work/chatglm3-6b"
peft_model_checkpoint_path="./chatglm3-6b-lora/checkpoint-50"model = AutoModel.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)# 將訓練所得的LoRa權重加載起來
p_model = PeftModel.from_pretrained(model, model_id=peft_model_checkpoint_path) p_model = p_model.cuda()
response, history = p_model.chat(tokenizer, "內退條件是什么?", history=[])
print(response)

合并模型

將lora權重合并到大模型中,將模型參數加載為16位浮點數

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch model_path="/root/work/chatglm3-6b"
peft_model_path="./lora/chatglm3-6b-int8"
save_path = "chatglm3-6b-lora"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_path)
model = model.merge_and_unload()tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)

查看合并文件
在這里插入圖片描述

使用微調后的模型

from transformers import AutoTokenizer, AutoModeltokenizer = AutoTokenizer.from_pretrained("chatglm3-6b-lora", trust_remote_code=True)
model = AutoModel.from_pretrained("chatglm3-6b-lora", trust_remote_code=True, device='cuda')model = model.eval()
response, history = model.chat(tokenizer, "內退條件是什么?", history=[])
print(response)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/35910.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/35910.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/35910.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

6 序列數據和文本的深度學習

6.1 使用文本數據 文本是常用的序列化數據類型之一。文本數據可以看作是一個字符序列或詞的序列。對大多數問題,我們都將文本看作詞序列。深度學習序列模型(如RNN及其變體)能夠從文本數據中學習重要的模式。這些模式可以解決類似以下領域中的問題: 自然…

JVM專題十一:JVM 中的收集器一

上一篇JVM專題十:JVM中的垃圾回收機制專題中,我們主要介紹了Java的垃圾機制,包括垃圾回收基本概念,重點介紹了垃圾回收機制中自動內存管理與垃圾收集算法。如果說收集算法是內存回收的方法論,那么垃圾收集器就是內存回…

【開發者推薦】告別繁瑣:一鍵解鎖國產ETL新貴,Kettle的終結者

在數字化轉型的今天,數據集成的重要性不言而喻。ETL工具作為數據管理的核心,對企業決策和運營至關重要。盡管Kettle廣受歡迎,但國產ETL工具 TASKCTL 以其創新特性和卓越性能,為市場提供了新的選擇。 TASKCTL概述 TASKCTL 是一款免…

wget之Win11中安裝及使用

wget之Win11中安裝及使用 文章目錄 wget之Win11中安裝及使用1. 下載2. 安裝3. 配置環境變量4. 查看及使用1. 查看版本2. 幫助命令3. 基本使用 1. 下載 下載地址:https://eternallybored.org/misc/wget 選擇對應的版本進行下載即可 2. 安裝 將下載后的wget-1.21.4-w…

中醫實訓室:在傳統針灸教學中的應用與創新

中醫實訓室是中醫教育體系中的重要組成部分,尤其在傳統針灸教學中,它扮演著無可替代的角色。這里是理論與實踐的交匯點,是傳統技藝與現代教育理念的碰撞之地。本文將探討中醫實訓室在傳統針灸教學中的應用與創新實踐。 首先,實訓室…

ResultSet的作用和類型

ResultSet的作用: ResultSet在Java中主要用于處理和操作數據庫查詢結果。它是一個接口,提供了一系列方法來訪問和操作數據庫查詢得到的結果集。具體來說,ResultSet的作用包括: 獲取查詢結果:通過ResultSet可以獲取數…

C++中指針的使用方法

基本概念 指針:一個變量,它存儲另一個變量的內存地址。地址運算符 &:用于獲取變量的內存地址。間接運算符 *:用于訪問指針所指向的變量的值。 聲明和初始化 int a 10; // 定義一個整數變量 int *p &a; // 定…

算法導論 總結索引 | 第四部分 第十六章:貪心算法

1、求解最優化問題的算法 通常需要經過一系列的步驟,在每個步驟都面臨多種選擇。對于許多最優化問題,使用動態規劃算法求最優解有些殺雞用牛刀了,可以使用更簡單、更高效的算法 貪心算法(greedy algorithm)就是這樣的算…

Git 學習筆記(超詳細注釋,從0到1)

Git學習筆記 1.1 關鍵詞 Fork、pull requests、pull、fetch、push、diff、merge、commit、add、checkout 1.2 原理(看圖學習) 1.3 Fork別人倉庫到自己倉庫中 記住2個地址 1)上游地址(upstream地址):http…

Nuxt 應用的三種運行模式(五)

Nuxt.js 提供了三種運行模式,分別是: SPA(單頁面應用) Universal(服務端渲染) Static(靜態生成) 每種模式都適用于不同的場景和需求,下面將詳細解析這三種模式的區別&…

【Qt】Qt多線程編程指南:提升應用性能與用戶體驗

文章目錄 前言1. Qt 多線程概述2. QThread 常用 API3. 使用線程4. 多線的使用場景5. 線程安全問題5.1. 加鎖5.2. QReadWriteLocker、QReadLocker、QWriteLocker 6. 條件變量 與 信號量6.1. 條件變量6.2 信號量 總結 前言 在現代軟件開發中,多線程編程已成為一個不可…

C語言類型轉換理解不同的基本類型為什么能夠進行運算

類型轉換 1.類型轉換1.1隱式轉換1.2常用算術轉換1.2強制類型轉換 1.類型轉換 在執行算數運算時,計算機比C語言的限制更多。為了讓計算機執行算術運算,通常要求操作數用相同的大小(即為的數量相同),但是C語言卻允許混合…

Java基礎:常用類(四)

Java基礎:常用類(四) 文章目錄 Java基礎:常用類(四)1. String字符串類1.1 簡介1.2 創建方式1.3 構造方法1.4 連接操作符1.5 常用方法 2. StringBuffer和StringBuilder類2.1 StringBuffer類2.1.1 簡介2.1.2 …

智能電能表如何助力智慧農業

智能電能表作為智能電網數據采集的基本設備之一,不僅具備傳統電能表基本用電量的計量功能,還具備雙向多種費率計量功能、用戶端控制功能、多種數據傳輸模式的雙向數據通信功能以及防竊電功能等智能化的功能。這些功能使得智能電能表在農業領域的應用具有…

基于深度學習的圖像去霧

基于深度學習的圖像去霧 圖像去霧是指從有霧的圖像中恢復清晰圖像的過程。傳統的圖像去霧方法(如暗原色先驗、圖像分層法等)在某些情況下表現良好,但在復雜場景下效果有限。深度學習方法利用大量的數據和強大的模型能力,在圖像去…

【滲透測試】小程序反編譯

前言 在滲透測試時,除了常規的Web滲透,小程序也是我們需要重點關注的地方,微信小程序反編譯后,可以借助微信小程序開發者工具進行調試,搜索敏感關鍵字,或許能夠發現泄露的AccessKey等敏感信息及數據 工具…

【PHP小課堂】PHP中PRGE正則函數的學習

PHP中PRGE正則函數的學習 正則表達式的作用想必不用我多說了,大家在日常的開發中或多或少都會接觸到。特別是對于一些登錄(郵箱、手機號)以及網頁爬蟲來說,正則表達式就是神器一般的存在。在 PHP 中,有兩種處理正則表達…

ChatGPT在用戶交互過程中如何實現自我學習和優化?

ChatGPT的自我學習和優化:深度解析與未來展望 在人工智能領域,ChatGPT的出現標志著自然語言處理技術的一大飛躍。作為一個先進的語言模型,ChatGPT不僅能夠與用戶進行流暢的對話,還能夠通過自我學習和優化來不斷提升其性能。本文將…

【SkiaSharp繪圖11】SKCanvas屬性詳解

文章目錄 SKCanvas構造SKCanvas構造光柵 Surface構造GPU Surface構造PDF文檔構造XPS文檔構造SVG文檔SKNoDrawCanvas 變換剪裁和狀態構造函數相關屬性DeviceClipBounds獲取裁切邊界(設備坐標系)ClipRect修改裁切區域IsClipEmpty當前裁切區域是否為空IsClipRect裁切區域是否為矩形…

JFreeChart 生成Word圖表

文章目錄 1 思路1.1 概述1.2 支持的圖表類型1.3 特性 2 準備模板3 導入依賴4 圖表生成工具類 ChartWithChineseExample步驟 1: 準備字體文件步驟 2: 注冊字體到FontFactory步驟 3: 設置圖表具體位置的字體柱狀圖:餅圖:折線圖:完整代碼&#x…