簡介
簡介:這篇論文挑戰了"GANs難以訓練"的廣泛觀點,通過提出一個更穩定的損失函數和現代化的網絡架構,構建了一個簡潔而高效的GAN基線模型R3GAN。作者證明了通過合適的理論基礎和架構設計,GANs可以穩定訓練并達到優異性能。
論文題目:The GAN is dead; long live the GAN! A Modern Baseline GAN
會議:NeurIPS 2024
源碼地址:https://www.github.com/brownvc/R3GAN
本文在調試代碼的時候對代碼做了一些修改,如果有遇到報錯的問題可以直接復制我這篇博客修改后的代碼:R3GAN利用配置好的Pytorch訓練自己的數據集-CSDN博客這篇論文挑戰了"GANs難以訓練"的廣泛觀點,通過提出一個更穩定的損失函數和現代化的網絡架構,構建了一個簡潔而高效的GAN基線模型R3GAN。作者證明了通過合適的理論基礎和架構設計,GANs可以穩定訓練并達到優異性能。https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link
摘要:論文反駁了GANs難以訓練的普遍觀點,提出了一個理論有保障的現代GAN基線。首先,推導出一個良好行為的正則化相對論GAN損失函數,解決了模式丟棄和不收斂問題,并數學證明了其局部收斂性。其次,該損失函數允許丟棄所有經驗性技巧,用現代架構替換常見GANs中的過時骨干網絡。以StyleGAN2為例,展示了簡化和現代化的路線圖,產生了新的極簡基線R3GAN。盡管簡單,該方法在FFHQ、ImageNet、CIFAR和Stacked MNIST數據集上超越了StyleGAN2,與最先進的GANs和擴散模型相比表現優異。
模型結構
生成器架構
核心設計原則:
- 基于現代化ResNet架構,摒棄VGG-like設計
- 每個分辨率階段包含一個過渡層和兩個殘差塊
- 采用分組卷積和倒置瓶頸設計
關鍵特性:
- 無歸一化層:避免批量歸一化等數據相關的歸一化
- Fix-up初始化:零初始化每個殘差塊的最后一層卷積
- 雙線性插值:用于上采樣,避免棋盤效應
鑒別器架構
設計特點:
- 與生成器完全對稱的架構
- 相同的殘差塊結構和過渡層設計
- 分類器頭:全局4×4深度卷積 + 線性層
損失函數
相對論配對GAN損失 (RpGAN):
L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]
R1正則化:
R1(ψ) = (γ/2) * E[||?_x D_ψ(x)||2] ?(x~p_D)
R2正則化:
R2(θ,ψ) = (γ/2) * E[||?_x D_ψ(x)||2] ?(x~p_θ)
訓練自己的數據集
1. 準備數據集
首先使用 dataset_tool.py
將您的圖像數據轉換為適合訓練的格式:
# 從文件夾創建數據集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip# 如果需要調整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \--resolution=256x256 --transform=center-crop
數據集要求:
- 圖像必須是正方形(如256x256, 512x512)
- 分辨率必須是2的冪次(64, 128, 256, 512, 1024等)
- 支持RGB或灰度圖像
- 可以是文件夾或ZIP格式
2. 創建自定義訓練配置
在 train.py
中添加您自己的預設配置。參考現有預設,在 main()
函數中添加:
if opts.preset == 'YOUR_DATASET':# 網絡架構參數WidthPerStage = [768, 768, 768, 512, 256] # 每階段寬度BlocksPerStage = [2, 2, 2, 2, 2] # 每階段塊數CardinalityPerStage = [96, 96, 96, 48, 24] # 每階段基數FP16Stages = [-1, -2, -3, -4] # FP16優化的階段NoiseDimension = 64 # 噪聲維度# 如果是條件生成(有類別標簽)if opts.cond:c.G_kwargs.ConditionEmbeddingDimension = NoiseDimensionc.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]# 訓練調度參數ema_nimg = 500 * 1000 # EMA開始的圖像數decay_nimg = 2e7 # 總衰減圖像數# 各種調度器c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }
3. 開始訓練
# 無條件生成(如人臉、風景等)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200# 條件生成(有類別標簽)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--cond=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200
4. 參數說明
--gpus
: GPU數量--batch
: 總批次大小--mirror
: 是否啟用水平翻轉增強--aug
: 是否啟用數據增強--cond
: 是否訓練條件模型(需要標簽)--tick
: 多少kimg輸出一次進度--snap
: 多少tick保存一次模型
5. 生成圖像
訓練完成后,使用保存的模型生成圖像:
# 生成8張圖像
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl# 條件生成(指定類別)
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--class=5 \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl
6. 評估指標
python calc_metrics.py \--metrics=fid50k_full,kid50k_full \--data=./datasets/your_dataset.zip \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl
7.報錯指南
1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment
解決辦法:在 train.py
中,NoiseDimension
只在特定的預設配置塊中定義(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset
參數不匹配任何現有預設,這個變量就不會被定義,導致使用時出錯。可以使用作者定義好的預先設置。
--preset=CIFAR10
--preset=FFHQ-64
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64
2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".
解決辦法:這個錯誤是因為R3GAN使用了自定義的CUDA操作符,需要C++編譯器來編譯。在Windows系統上缺少MSVC/GCC/CLANG編譯器。
修改 torch_utils/custom_ops.py
:找到 get_plugin
函數(大約第84行),在函數開頭添加:
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):# 禁用所有自定義插件return Nonedef bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):# 強制使用 'ref' 實現impl = 'ref'