基于Stable Diffusion XL模型進行文本生成圖像的訓練
flyfish
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--report_to="wandb" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model" \--push_to_hub
環境變量部分
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
MODEL_NAME
:指定預訓練模型的名稱或路徑。這里使用的是stabilityai/stable-diffusion-xl-base-1.0
,也就是Stable Diffusion XL的基礎版本1.0。VAE_NAME
:指定變分自編碼器(VAE)的名稱或路徑。madebyollin/sdxl-vae-fp16-fix
是針對Stable Diffusion XL的一個經過修復的VAE模型,適用于半精度(FP16)計算。DATASET_NAME
:指定訓練所使用的數據集名稱或路徑。這里使用的是lambdalabs/naruto-blip-captions
,是一個包含火影忍者相關圖像及其描述的數據集。
accelerate launch
命令參數部分
accelerate launch train_text_to_image_sdxl.py \
這行代碼使用 accelerate
工具來啟動 train_text_to_image_sdxl.py
腳本,accelerate
可以幫助我們在多GPU、TPU等環境下進行分布式訓練。
腳本參數部分
--pretrained_model_name_or_path=$MODEL_NAME
:指定預訓練模型的名稱或路徑,這里使用前面定義的MODEL_NAME
環境變量。--pretrained_vae_model_name_or_path=$VAE_NAME
:指定預訓練VAE模型的名稱或路徑,使用前面定義的VAE_NAME
環境變量。--dataset_name=$DATASET_NAME
:指定訓練數據集的名稱或路徑,使用前面定義的DATASET_NAME
環境變量。--enable_xformers_memory_efficient_attention
:啟用xformers
庫的內存高效注意力機制,能減少訓練過程中的內存占用。--resolution=512 --center_crop --random_flip
:--resolution=512
:將輸入圖像的分辨率統一調整為512x512像素。--center_crop
:對圖像進行中心裁剪,使其達到指定的分辨率。--random_flip
:在訓練過程中隨機對圖像進行水平翻轉,以增加數據的多樣性。
--proportion_empty_prompts=0.2
:設置空提示(沒有文本描述)的樣本在訓練數據中的比例為20%。--train_batch_size=1
:每個訓練批次包含的樣本數量為1。--gradient_accumulation_steps=4 --gradient_checkpointing
:--gradient_accumulation_steps=4
:梯度累積步數為4,即每4個批次的梯度進行一次更新,這樣可以在有限的內存下模擬更大的批次大小。--gradient_checkpointing
:啟用梯度檢查點機制,通過減少內存使用來支持更大的模型和批次大小。
--max_train_steps=10000
:最大訓練步數為10000步。--use_8bit_adam
:使用8位Adam優化器,能減少內存占用。--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
:--learning_rate=1e-06
:學習率設置為1e-6。--lr_scheduler="constant"
:學習率調度器設置為常數,即訓練過程中學習率保持不變。--lr_warmup_steps=0
:學習率預熱步數為0,即不進行學習率預熱。
--mixed_precision="fp16"
:使用半精度(FP16)混合精度訓練,能減少內存使用并加快訓練速度。--report_to="wandb"
:將訓練過程中的指標報告到Weights & Biases(WandB)平臺,方便進行可視化和監控。--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
:--validation_prompt="a cute Sundar Pichai creature"
:指定驗證時使用的文本提示,這里是“一個可愛的桑達爾·皮查伊形象”。--validation_epochs 5
:每5個訓練輪次進行一次驗證。
--checkpointing_steps=5000
:每5000步保存一次模型的檢查點。--output_dir="sdxl-naruto-model"
:指定訓練好的模型的輸出目錄為sdxl-naruto-model
。--push_to_hub
:將訓練好的模型推送到Hugging Face模型庫。
離線環境運行
# 假設已經把模型、VAE和數據集下載到本地了
# 這里假設模型在當前目錄下的 sdxl-base-1.0 文件夾
# VAE 在 sdxl-vae-fp16-fix 文件夾
# 數據集在 naruto-blip-captions 文件夾# 定義本地路徑
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外網連接的參數
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"
移除需要外網連接的參數:去掉 --report_to="wandb"
和 --push_to_hub
參數,因為 wandb
需要外網連接來上傳訓練指標,--push_to_hub
則需要外網連接把模型推送到Hugging Face模型庫。
推理
from diffusers import DiffusionPipeline
import torchmodel_path = "you-model-id-goes-here" # <-- 替換為你的模型路徑
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")
訓練后的文件夾結構
.
├── checkpoint-10000
│ ├── optimizer.bin
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ ├── scheduler.bin
│ └── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│ ├── optimizer.bin
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ ├── scheduler.bin
│ └── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
│ └── model.safetensors
├── text_encoder_2
│ ├── config.json
│ └── model.safetensors
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── tokenizer_2
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
└── vae├── config.json└── diffusion_pytorch_model.safetensors
LoRA訓練
accelerate launch train_text_to_image_lora_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME --caption_column="text" \--resolution=1024 --random_flip \--train_batch_size=1 \--num_train_epochs=2 --checkpointing_steps=500 \--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--seed=42 \--output_dir="sd-naruto-model-lora-sdxl" \--validation_prompt="cute dragon creature"
推理
from diffusers import DiffusionPipeline
import torchsdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")
LoRA訓練后的文件夾結構
├── checkpoint-1000
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-1500
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-2000
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-500
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
└── pytorch_lora_weights.safetensors