InternLM 論文分類微調實踐(XTuner 版)

1.環境安裝

我創建開發機選擇鏡像為Cuda12.2-conda,選擇GPU為100%A100的資源配置

Conda 管理環境

conda create -n xtuner_101?python=3.10 -y
conda activate xtuner_101
pip install torch==2.4.0+cu121 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
pip install xtuner timm flash_attn datasets==2.21.0 deepspeed==0.16.1
conda install mpi4py -y
#為了兼容模型,降級transformers版本
pip uninstall transformers -y
pip install transformers==4.48.0 --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple

檢驗環境安裝

xtuner list-cfg

2.數據獲取

數據為sftdata.jsonl,已上傳。

3.訓練

?鏈接模型位置命令

ln -s /root/share/new_models/Shanghai_AI_Laboratory/internlm2_5-7b-chat ./

3.1 微調腳本

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook,DistSamplerSeedHook,IterTimerHook,LoggerHook,ParamSchedulerHook,
)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfigfrom xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook,EvaluateChatHook,VarlenAttnArgsToMessageHubHook,
)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = "./internlm2_5-7b-chat"
use_varlen_attn = False# Data
alpaca_en_path = "/root/xtuner/datasets/train/sftdata.jsonl"#換成自己的數據路徑
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True# parallel
sequence_parallel_size = 1# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = ["請給我介紹五個上海的景點", "Please tell me five scenic spots in Shanghai"]#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(type=AutoTokenizer.from_pretrained,pretrained_model_name_or_path=pretrained_model_name_or_path,trust_remote_code=True,padding_side="right",
)model = dict(type=SupervisedFinetune,use_varlen_attn=use_varlen_attn,llm=dict(type=AutoModelForCausalLM.from_pretrained,pretrained_model_name_or_path=pretrained_model_name_or_path,trust_remote_code=True,torch_dtype=torch.float16,quantization_config=dict(type=BitsAndBytesConfig,load_in_4bit=True,load_in_8bit=False,llm_int8_threshold=6.0,llm_int8_has_fp16_weight=False,bnb_4bit_compute_dtype=torch.float16,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",),),lora=dict(type=LoraConfig,r=64,lora_alpha=16,lora_dropout=0.1,bias="none",task_type="CAUSAL_LM",),
)#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(type=process_hf_dataset,dataset=dict(type=load_dataset, path='json', data_files=alpaca_en_path),tokenizer=tokenizer,max_length=max_length,dataset_map_fn=alpaca_map_fn,template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),remove_unused_columns=True,shuffle_before_pack=True,pack_to_max_length=pack_to_max_length,use_varlen_attn=use_varlen_attn,
)sampler = SequenceParallelSampler if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(batch_size=batch_size,num_workers=dataloader_num_workers,dataset=alpaca_en,sampler=dict(type=sampler, shuffle=True),collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn),
)#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(type=AmpOptimWrapper,optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),accumulative_counts=accumulative_counts,loss_scale="dynamic",dtype="float16",
)# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [dict(type=LinearLR,start_factor=1e-5,by_epoch=True,begin=0,end=warmup_ratio * max_epochs,convert_to_iter_based=True,),dict(type=CosineAnnealingLR,eta_min=0.0,by_epoch=True,begin=warmup_ratio * max_epochs,end=max_epochs,convert_to_iter_based=True,),
]# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer),dict(type=EvaluateChatHook,tokenizer=tokenizer,every_n_iters=evaluation_freq,evaluation_inputs=evaluation_inputs,system=SYSTEM,prompt_template=prompt_template,),
]if use_varlen_attn:custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]# configure default hooks
default_hooks = dict(# record the time of every iteration.timer=dict(type=IterTimerHook),# print log every 10 iterations.logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),# enable the parameter scheduler.param_scheduler=dict(type=ParamSchedulerHook),# save checkpoint per `save_steps`.checkpoint=dict(type=CheckpointHook,by_epoch=False,interval=save_steps,max_keep_ckpts=save_total_limit,),# set sampler seed in distributed evrionment.sampler_seed=dict(type=DistSamplerSeedHook),
)# configure environment
env_cfg = dict(# whether to enable cudnn benchmarkcudnn_benchmark=False,# set multi process parametersmp_cfg=dict(mp_start_method="fork", opencv_num_threads=0),# set distributed parametersdist_cfg=dict(backend="nccl"),
)# set visualizer
visualizer = None# set log level
log_level = "INFO"# load from which checkpoint
load_from = None# whether to resume training from the loaded checkpoint
resume = False# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)# set log processor
log_processor = dict(by_epoch=False)

?將模型和地址改為自己的路徑

3.2 啟動微調

cd /root/101
conda activate xtuner_101
xtuner train internlm2_5_chat_7b_qlora_alpaca_e3_copy.py --deepspeed deepspeed_zero1

3.3 合并

3.3.1??將PTH格式轉換為HuggingFace格式

xtuner convert pth_to_hf internlm2_5_chat_7b_qlora_alpaca_e3_copy.py ./work_dirs/internlm2_5_chat_7b_qlora_alpaca_e3_copy/iter_195.pth ./work_dirs/hf

3.3.2??合并adapter和基礎模型

xtuner convert merge \
/root/internlm2_5-7b-chat \
./work_dirs/hf \
./work_dirs/merged \
--max-shard-size 2GB \

完成這兩個步驟后,合并好的模型將保存在./work_dirs/merged目錄下,你可以直接使用這個模型進行推理了。?

3.4 推理

from transformers import AutoModelForCausalLM, AutoTokenizer
import time# 加載模型和分詞器
# model_path = "./lora_output/merged"
model_path = "./internlm2_5-7b-chat"
print(f"加載模型:{model_path}")start_time = time.time()tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)def classify_paper(title, authors, abstract, additional_info=""):# 構建輸入,包含多選選項prompt = f"Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper. {additional_info}\n\nA. astro-ph\nB. cond-mat.mes-hall\nC. cond-mat.mtrl-sci\nD. cs.CL\nE. cs.CV\nF. cs.LG\nG. gr-qc\nH. hep-ph\nI. hep-th\nJ. quant-ph"# 設置系統信息messages = [{"role": "system", "content": "你是個優秀的論文分類師"},{"role": "user", "content": prompt},]# 應用聊天模板input_text = tokenizer.apply_chat_template(messages, tokenize=False)# 生成回答inputs = tokenizer(input_text, return_tensors="pt").to(model.device)outputs = model.generate(**inputs,max_new_tokens=10,  # 減少生成長度,只需要簡短答案temperature=0.1,  # 降低溫度提高確定性top_p=0.95,repetition_penalty=1.0,)# 解碼輸出response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True).strip()# 如果回答中包含選項標識符,只返回該標識符for option in ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]:if option in response:return option# 如果回答不包含選項,返回完整回答return response# 示例使用
if __name__ == "__main__":title = "Outilex, plate-forme logicielle de traitement de textes 'ecrits"authors = "Olivier Blanc (IGM-LabInfo), Matthieu Constant (IGM-LabInfo), Eric Laporte (IGM-LabInfo)"abstract = "The Outilex software platform, which will be made available to research, development and industry, comprises software components implementing all the fundamental operations of written text processing: processing without lexicons, exploitation of lexicons and grammars, language resource management. All data are structured in XML formats, and also in more compact formats, either readable or binary, whenever necessary; the required format converters are included in the platform; the grammar formats allow for combining statistical approaches with resource-based approaches. Manually constructed lexicons for French and English, originating from the LADL, and of substantial coverage, will be distributed with the platform under LGPL-LR license."result = classify_paper(title, authors, abstract)print(result)# 計算并打印總耗時end_time = time.time()total_time = end_time - start_timeprint(f"程序總耗時:{total_time:.2f}秒")

?推理結果如下:

微調前模型推理

微調后模型推理

3.5 部署

pip install lmdeploy
python -m lmdeploy.pytorch.chat ./work_dirs/merged \
--max_new_tokens 256 \   
--temperture 0.8 \   
--top_p 0.95 \   
--seed 0

4.評測(跳過)

5.上傳模型到魔搭

pip install modelscope

?使用腳本

from modelscope.hub.api import HubApi
YOUR_ACCESS_TOKEN='自己的令牌'
api=HubApi()
api.login(YOUR_ACCESS_TOKEN)from modelscope.hub.constants import Licenses, ModelVisibility
owner_name='Raven10086'
model_name='InternLM-gmz-camp5'
model_id=f"{owner_name}/{model_name}"
api.create_model(model_id,visibility=ModelVisibility.PUBLIC,license=Licenses.APACHE_V2,chinese_name="gmz文本分類微調端側小模型"
)
api.upload_folder(repo_id=f"{owner_name}/{model_name}",folder_path='/root/swift_output/InternLM3-8B-SFT-Lora/v5-20250517-164316/checkpoint-62-merged',commit_message='fast commit',)

?上傳成功截圖

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/906562.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/906562.shtml
英文地址,請注明出處:http://en.pswp.cn/news/906562.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

軟考中級軟件設計師——設計模式篇

一、設計模式核心分類 設計模式分為 3 大類,共 23 種模式(考試常考約 10-15 種): 分類核心模式考試重點創建型模式工廠方法、抽象工廠、單例、生成器、原型單例模式的實現(懶漢、餓漢)、工廠模式的應用場…

小米2025年校招筆試真題手撕(一)

一、題目 小A每天都要吃a,b兩種面包各一個。而他有n個不同的面包機,不同面包機制作面包的時間各不相同。第i臺面包機制作a面包 需要花費ai的時間,制作b面包則需要花費bi的時間。 為能盡快吃到這兩種面包,小A可以選擇兩個不同的面包機x&…

【微信小程序 + 高德地圖API 】鍵入關鍵字搜索地址,獲取經緯度等

前言 又到熟悉的前言,接到個需求,要引入高德地圖api,我就記錄一下,要是有幫助記得點贊、收藏、關注😁。 后續有時間會慢慢完善一些文章:(畫餅時間) map組件自定義氣泡、mark標記點…

uni-app(2):頁面

1 頁面簡介 uni-app項目中,一個頁面就是一個符合Vue SFC規范的 vue 文件。 在 uni-app js 引擎版中,后綴名是.vue文件或.nvue文件。 這些頁面均全平臺支持,差異在于當 uni-app 發行到App平臺時,.vue文件會使用webview進行渲染&…

Axure實戰:智慧水務管理系統原型設計速覽

本原型通過Axure構建覆蓋生產到服務的全流程交互模型,聚焦"數據驅動智能決策"核心價值,助力水務企業實現管理效率提升與運營成本優化。 系統采用"13N"架構: 1個統一入口:集成單點登錄與角色動態權限&#xff…

十二、Linux實現截屏小工具

系列文章目錄 本系列文章記錄在Linux操作系統下,如何在不依賴QT、GTK等開源GUI庫的情況下,基于x11窗口系統(xlib)圖形界面應用程序開發。之所以使用x11進行窗口開發,是在開發一個基于duilib跨平臺的界面庫項目&#x…

藍橋杯分享經驗

系列文章目錄 提示:小白先看系列 第一章 藍橋杯的錢白給嗎 文章目錄 系列文章目錄前言一、自我介紹二、經驗講解:1.基礎知識2.進階知識3.個人觀點 三、總結四、后續 前言 第十六屆藍橋杯已經省賽已經結束了,相信很多小伙伴也已經得到自己的成績了。接下…

XC3588H搭載國產麒麟系統可用于政務/社保一體機嗎?

答案是肯定的。 向成電子XC3588H搭載的國產銀河麒麟系統和國產星光麒麟系統已完成適配,適用于政務服務、社保服務一體機的所有外設,運行穩定流暢。 在數字化政務快速發展的今天,政務服務終端的穩定性、安全性與高效性成為提升群眾辦事體驗的關…

如何排查服務器 CPU 溫度過高的問題并解決?

服務器CPU溫度過高是一個常見的問題,可能導致服務器性能下降、系統穩定性問題甚至硬件損壞。有效排查和解決服務器CPU溫度過高的問題對于確保服務器正常運行和延長硬件壽命至關重要。本文將介紹如何排查服務器CPU溫度過高的問題,并提供解決方法&#xff…

物聯網、云計算技術加持,助推樓宇自控系統實現智能高效管理

在建筑智能化發展的進程中,樓宇自控系統作為實現建筑高效管理的核心載體,正面臨著數據海量復雜、設備協同困難、管理響應遲緩等挑戰。而物聯網與云計算技術的深度融合,為樓宇自控系統的升級提供了全新的解決方案,賦予其智能感知、…

uni-app使用大集

1、手動修改頁面標題 uni.setNavigationBarTitle({title: 修改標題 }); 2、單選 不止有 radio-group&#xff0c;還有 uni-data-checkbox 數據選擇器 <!-- html部分 --> <uni-data-checkbox v-model"sex" :localdata"checkboxList"></u…

(6)python爬蟲--selenium

文章目錄 前言一、初識selenium二、安裝selenium2.1 查看chrome版本并禁止chrome自動更新2.1.1 查看chrome版本2.1.2 禁止chrome更新自動更新 2.2 安裝對應版本的驅動程序2.3安裝selenium包 三、selenium關于瀏覽器的使用3.1 創建瀏覽器、設置、打開3.2 打開/關閉網頁及瀏覽器3…

基于OpenCV的人臉微笑檢測實現

文章目錄 引言一、技術原理二、代碼實現2.1 關鍵代碼解析2.1.1 模型加載2.1.2 圖像翻轉2.1.3 人臉檢測 微笑檢測 2.2 顯示效果 三、參數調優建議四、總結 引言 在計算機視覺領域&#xff0c;人臉檢測和表情識別一直是熱門的研究方向。今天我將分享一個使用Python和OpenCV實現…

Java 大視界 -- 基于 Java 的大數據分布式存儲在視頻會議系統海量視頻數據存儲與回放中的應用(263)

&#x1f496;親愛的朋友們&#xff0c;熱烈歡迎來到 青云交的博客&#xff01;能與諸位在此相逢&#xff0c;我倍感榮幸。在這飛速更迭的時代&#xff0c;我們都渴望一方心靈凈土&#xff0c;而 我的博客 正是這樣溫暖的所在。這里為你呈上趣味與實用兼具的知識&#xff0c;也…

Kotlin 極簡小抄 P9 - 數組(數組的創建、數組元素的訪問與修改、數組遍歷、數組操作、多維數組、數組與可變參數)

Kotlin 概述 Kotlin 由 JetBrains 開發&#xff0c;是一種在 JVM&#xff08;Java 虛擬機&#xff09;上運行的靜態類型編程語言 Kotlin 旨在提高開發者的編碼效率和安全性&#xff0c;同時保持與 Java 的高度互操作性 Kotlin 是 Android 應用開發的首選語言&#xff0c;也可…

gitlab+portainer 實現Ruoyi Vue前端CI/CD

1. 場景 最近整了一個Ruoyi Vue 項目&#xff0c;需要實現CICD&#xff0c;經過一番坎坷&#xff0c;最終達成&#xff0c;現將技術要點和踩坑呈現。 具體操作流程和后端大同小異&#xff0c;后端操作參考連接如下&#xff1a; https://blog.csdn.net/leinminna/article/detai…

RNN神經網絡

RNN神經網絡 1-核心知識 1-解釋RNN神經網絡2-RNN和傳統的神經網絡有什么區別&#xff1f;3-RNN和LSTM有什么區別&#xff1f;4-transformer的歸一化有哪幾種實現方式 2-知識問答 1-解釋RNN神經網絡 Why&#xff1a;與我何干&#xff1f; 在我們的生活中&#xff0c;很多事情…

javaweb-html

1.交互流程&#xff1a; 瀏覽器向服務器發送http請求&#xff0c;服務器對瀏覽器進行回應&#xff0c;并發送字符串&#xff0c;瀏覽器能對這些字符串&#xff08;html代碼&#xff09;進行解釋&#xff1b; 三大web語言&#xff1a;&#xff08;1&#xff09;html&#xff1a…

從混亂到高效:我們是如何重構 iOS 上架流程的(含 Appuploader實踐)

從混亂到高效&#xff1a;我們是如何重構 iOS 上架流程的 在開發團隊中&#xff0c;有一類看不見卻至關重要的問題&#xff1a;環境依賴。 特別是 iOS App 的發布流程&#xff0c;往往牢牢綁死在一臺特定的 Mac 上。每次需要發版本&#xff0c;都要找到“那臺 Mac”&#xff…

FPGA:CLB資源以及Verilog編碼面積優化技巧

本文將先介紹Kintex-7系列器件的CLB&#xff08;可配置邏輯塊&#xff09;資源&#xff0c;然后分享在Verilog編碼時節省CLB資源的技巧。以下內容基于Kintex-7系列的架構特點&#xff0c;并結合實際設計經驗進行闡述。 一、Kintex-7系列器件的CLB資源介紹 Kintex-7系列是Xilin…