基于Stable Diffusion XL模型進行文本生成圖像的訓練

基于Stable Diffusion XL模型進行文本生成圖像的訓練

flyfish

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--report_to="wandb" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model" \--push_to_hub

環境變量部分

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
  • MODEL_NAME:指定預訓練模型的名稱或路徑。這里使用的是 stabilityai/stable-diffusion-xl-base-1.0,也就是Stable Diffusion XL的基礎版本1.0。
  • VAE_NAME:指定變分自編碼器(VAE)的名稱或路徑。madebyollin/sdxl-vae-fp16-fix 是針對Stable Diffusion XL的一個經過修復的VAE模型,適用于半精度(FP16)計算。
  • DATASET_NAME:指定訓練所使用的數據集名稱或路徑。這里使用的是 lambdalabs/naruto-blip-captions,是一個包含火影忍者相關圖像及其描述的數據集。

accelerate launch 命令參數部分

accelerate launch train_text_to_image_sdxl.py \

這行代碼使用 accelerate 工具來啟動 train_text_to_image_sdxl.py 腳本,accelerate 可以幫助我們在多GPU、TPU等環境下進行分布式訓練。

腳本參數部分

  • --pretrained_model_name_or_path=$MODEL_NAME:指定預訓練模型的名稱或路徑,這里使用前面定義的 MODEL_NAME 環境變量。
  • --pretrained_vae_model_name_or_path=$VAE_NAME:指定預訓練VAE模型的名稱或路徑,使用前面定義的 VAE_NAME 環境變量。
  • --dataset_name=$DATASET_NAME:指定訓練數據集的名稱或路徑,使用前面定義的 DATASET_NAME 環境變量。
  • --enable_xformers_memory_efficient_attention:啟用 xformers 庫的內存高效注意力機制,能減少訓練過程中的內存占用。
  • --resolution=512 --center_crop --random_flip
    • --resolution=512:將輸入圖像的分辨率統一調整為512x512像素。
    • --center_crop:對圖像進行中心裁剪,使其達到指定的分辨率。
    • --random_flip:在訓練過程中隨機對圖像進行水平翻轉,以增加數據的多樣性。
  • --proportion_empty_prompts=0.2:設置空提示(沒有文本描述)的樣本在訓練數據中的比例為20%。
  • --train_batch_size=1:每個訓練批次包含的樣本數量為1。
  • --gradient_accumulation_steps=4 --gradient_checkpointing
    • --gradient_accumulation_steps=4:梯度累積步數為4,即每4個批次的梯度進行一次更新,這樣可以在有限的內存下模擬更大的批次大小。
    • --gradient_checkpointing:啟用梯度檢查點機制,通過減少內存使用來支持更大的模型和批次大小。
  • --max_train_steps=10000:最大訓練步數為10000步。
  • --use_8bit_adam:使用8位Adam優化器,能減少內存占用。
  • --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
    • --learning_rate=1e-06:學習率設置為1e-6。
    • --lr_scheduler="constant":學習率調度器設置為常數,即訓練過程中學習率保持不變。
    • --lr_warmup_steps=0:學習率預熱步數為0,即不進行學習率預熱。
  • --mixed_precision="fp16":使用半精度(FP16)混合精度訓練,能減少內存使用并加快訓練速度。
  • --report_to="wandb":將訓練過程中的指標報告到Weights & Biases(WandB)平臺,方便進行可視化和監控。
  • --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
    • --validation_prompt="a cute Sundar Pichai creature":指定驗證時使用的文本提示,這里是“一個可愛的桑達爾·皮查伊形象”。
    • --validation_epochs 5:每5個訓練輪次進行一次驗證。
  • --checkpointing_steps=5000:每5000步保存一次模型的檢查點。
  • --output_dir="sdxl-naruto-model":指定訓練好的模型的輸出目錄為 sdxl-naruto-model
  • --push_to_hub:將訓練好的模型推送到Hugging Face模型庫。

離線環境運行

# 假設已經把模型、VAE和數據集下載到本地了
# 這里假設模型在當前目錄下的 sdxl-base-1.0 文件夾
# VAE 在 sdxl-vae-fp16-fix 文件夾
# 數據集在 naruto-blip-captions 文件夾# 定義本地路徑
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外網連接的參數
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"

移除需要外網連接的參數:去掉 --report_to="wandb"--push_to_hub 參數,因為 wandb 需要外網連接來上傳訓練指標,--push_to_hub 則需要外網連接把模型推送到Hugging Face模型庫。

推理

from diffusers import DiffusionPipeline
import torchmodel_path = "you-model-id-goes-here" # <-- 替換為你的模型路徑
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

訓練后的文件夾結構

.
├── checkpoint-10000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── model.safetensors
├── text_encoder_2
│   ├── config.json
│   └── model.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model-00001-of-00002.safetensors
│   ├── diffusion_pytorch_model-00002-of-00002.safetensors
│   └── diffusion_pytorch_model.safetensors.index.json
└── vae├── config.json└── diffusion_pytorch_model.safetensors

LoRA訓練

accelerate launch train_text_to_image_lora_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME --caption_column="text" \--resolution=1024 --random_flip \--train_batch_size=1 \--num_train_epochs=2 --checkpointing_steps=500 \--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--seed=42 \--output_dir="sd-naruto-model-lora-sdxl" \--validation_prompt="cute dragon creature"

推理

from diffusers import DiffusionPipeline
import torchsdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

LoRA訓練后的文件夾結構

├── checkpoint-1000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-1500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-2000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
└── pytorch_lora_weights.safetensors

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/80763.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/80763.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/80763.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于React的高德地圖api教程001:初始化地圖

文章目錄 1、初始化地圖1.1 創建react項目1.2 可視化地圖1.3 設置衛星地圖1.4 添加開關開啟3D地圖1.5 代碼下載1、初始化地圖 1.1 創建react項目 創建geodeapi項目: npx create-react-app gaodeapi安裝高德地圖包: npm install @amap/amap-jsapi-loader1.2 可視化地圖 在…

uniapp使用npm下載

uniapp的項目在使用HBuilder X創建時是不會有node_modules文件夾的&#xff0c;如下圖所示&#xff1a; 但是uni-app不管基于哪個框架&#xff0c;它內部一定是有node.js的&#xff0c;否則沒有辦法去實現框架層面的一些東西&#xff0c;只是說它略微有點差異。具體差異表現在…

輕量在線工具箱系統源碼 附教程

源碼介紹 輕量在線工具箱系統源碼,直接扔服務器 修改config/config.php文件里面的數據庫 后臺賬號admin 密碼admin123 本工具是AI寫的 所以工具均是第三方接口直接寫的。 需要加工具直接自己找接口寫好扔到goju目錄 后臺自動讀取 效果預覽 源碼獲取 輕量在線工具箱系統源…

圖解gpt之Seq2Seq架構與序列到序列模型

今天深入探討如何構建更強大的序列到序列模型&#xff0c;特別是Seq2Seq架構。序列到序列模型&#xff0c;顧名思義&#xff0c;它的核心任務就是將一個序列映射到另一個序列。這個序列可以是文本&#xff0c;也可以是其他符號序列。最早&#xff0c;人們嘗試用一個單一的RNN來…

mac M2能安裝的虛擬機和linux系統系統

能適配MAC M2芯片的虛擬機下Linux系統的搭建全是深坑&#xff0c;目前網上的資料能搜到的都是錯誤的&#xff0c;自己整理并分享給坑友們~ 網上搜索到的推薦安裝的改造過的centos7也無法進行yum操作&#xff0c;我這邊建議安裝centos8 VMware Fusion下載地址&#xff1a; htt…

「國產嵌入式仿真平臺:高精度虛實融合如何終結Proteus時代?」——從教學實驗到低空經濟,揭秘新一代AI賦能的產業級教學工具

引言&#xff1a;從Proteus到國產平臺的范式革新 在高校嵌入式實驗教學中&#xff0c;仿真工具的選擇直接影響學生的工程能力培養與創新思維發展。長期以來&#xff0c;Proteus作為經典工具占據主導地位&#xff0c;但其設計理念已難以滿足現代復雜系統教學與國產化技術需求。…

【Linux】在Arm服務器源碼編譯onnxruntime-gpu的whl

服務器信息&#xff1a; aarch64架構 ubuntu20.04 nvidia T4卡 編譯onnxruntime-gpu前置條件&#xff1a; 已經安裝合適的cuda已經安裝合適的cudnn已經安裝合適的cmake 源碼編譯onnxruntime-gpu的步驟 1. 下載源碼 git clone --recursive https://github.com/microsoft/o…

前端上傳el-upload、原生input本地文件pdf格式(純前端預覽本地文件不走后端接口)

前端實現本地文件上傳與預覽&#xff08;PDF格式展示&#xff09;不走后端接口 實現步驟 第一步&#xff1a;文件選擇 使用前端原生input上傳本地文件&#xff0c;或者是el-upload組件實現文件選擇功能&#xff0c;核心在于文件渲染處理。&#xff08;input只不過可以自定義樣…

Python 數據分析與可視化:開啟數據洞察之旅(5/10)

一、Python 數據分析與可視化簡介 在當今數字化時代&#xff0c;數據就像一座蘊藏無限價值的寶藏&#xff0c;等待著我們去挖掘和探索。而 Python&#xff0c;作為數據科學領域的明星語言&#xff0c;憑借其豐富的庫和強大的功能&#xff0c;成為了開啟這座寶藏的關鍵鑰匙&…

C語言學習記錄——深入理解指針(4)

OK&#xff0c;這一篇主要是講我學習的3種指針類型。 正文開始&#xff1a; 一.字符指針 所謂字符指針&#xff0c;顧名思義就是指向字符的指針。一般寫作 " char* " 直接來說說它的使用方法吧&#xff1a; &#xff08;1&#xff09;一般使用情況&#xff1a; i…

springboot3+vue3融合項目實戰-大事件文章管理系統獲取用戶詳細信息-ThreadLocal優化

一句話本質 為每個線程創建獨立的變量副本&#xff0c;實現多線程環境下數據的安全隔離&#xff08;線程操作自己的副本&#xff0c;互不影響&#xff09;。 關鍵解讀&#xff1a; 核心機制 ? 同一個 ThreadLocal 對象&#xff08;如示意圖中的紅色區域 tl&#xff09;被多個線…

Nacos源碼—8.Nacos升級gRPC分析六

大綱 7.服務端對服務實例進行健康檢查 8.服務下線如何注銷注冊表和客戶端等信息 9.事件驅動架構源碼分析 一.處理ClientChangedEvent事件 也就是同步數據到集群節點&#xff1a; public class DistroClientDataProcessor extends SmartSubscriber implements DistroDataSt…

設計雜談-工廠模式

“工廠”模式在各種框架中非常常見&#xff0c;包括 MyBatis&#xff0c;它是一種創建對象的設計模式。使用工廠模式有很多好處&#xff0c;尤其是在復雜的框架中&#xff0c;它可以帶來更好的靈活性、可維護性和可配置性。 讓我們以 MyBatis 為例&#xff0c;來理解工廠模式及…

AI與IoT攜手,精準農業未來已來

AIoT&#xff1a;農業領域的變革先鋒 在科技飛速發展的當下&#xff0c;人工智能&#xff08;AI&#xff09;與物聯網&#xff08;IoT&#xff09;的融合 ——AIoT&#xff0c;正逐漸成為推動各行業變革的關鍵力量&#xff0c;農業領域也不例外。AIoT 技術通過將 AI 的智能分析…

排錯-harbor-db容器異常重啟

排錯-harbor-db容器異常重啟 環境&#xff1a; docker 19.03 , harbor-db(postgresql) goharbor/harbor-db:v2.5.6 現象&#xff1a; harbor-db 容器一直restart&#xff0c;查看日志發現 報錯 initdb: error: directory "/var/lib/postgresql/data/pg13" exists…

Docker容器啟動失敗?無法啟動?

Docker容器無法啟動的疑難雜癥解析與解決方案 一、問題現象 Docker容器無法啟動是開發者在容器化部署中最常見的故障之一。盡管Docker提供了豐富的調試工具&#xff0c;但問題的根源往往隱藏在復雜的配置、環境依賴或資源限制中。本文將從環境變量配置錯誤這一細節問題入手&am…

查看購物車

一.查看購物車 查看購物車使用get請求。我們要查看當前用戶的購物車&#xff0c;就要獲取當前用戶的userId字段進行條件查詢。因為在用戶登錄時就已經將userId封裝在token中了&#xff0c;因此我們只需要解析token獲取userId即可&#xff0c;不需要前端再傳入參數了。 Control…

C/C++ 內存管理深度解析:從內存分布到實踐應用(malloc和new,free和delete的對比與使用,定位 new )

一、引言&#xff1a;理解內存管理的核心價值 在系統級編程領域&#xff0c;內存管理是決定程序性能、穩定性和安全性的關鍵因素。C/C 作為底層開發的主流語言&#xff0c;賦予開發者直接操作內存的能力&#xff0c;卻也要求開發者深入理解內存布局與生命周期管理。本文將從內…

使用Stable Diffusion(SD)中CFG參數指的是什么?該怎么用!

1.定義&#xff1a; CFG參數控制模型在生成圖像時&#xff0c;對提示詞&#xff08;Prompt&#xff09;的“服從程度”。 它衡量模型在“完全根據提示詞生成圖像”和“自由生成圖像”&#xff08;不參考提示詞&#xff09;之間的權衡程度。 數值范圍&#xff1a;常見范圍是 1 …

【GESP】C++三級練習 luogu-B2156 最長單詞 2

GESP三級練習&#xff0c;字符串練習&#xff08;C三級大綱中6號知識點&#xff0c;字符串&#xff09;&#xff0c;難度★★☆☆☆。 題目題解詳見&#xff1a;https://www.coderli.com/gesp-3-luogu-b2156/ 【GESP】C三級練習 luogu-B2156 最長單詞 2 | OneCoderGESP三級練…