R3GAN訓練自己的數據集

簡介

簡介:這篇論文挑戰了"GANs難以訓練"的廣泛觀點,通過提出一個更穩定的損失函數和現代化的網絡架構,構建了一個簡潔而高效的GAN基線模型R3GAN。作者證明了通過合適的理論基礎和架構設計,GANs可以穩定訓練并達到優異性能。

論文題目:The GAN is dead; long live the GAN! A Modern Baseline GAN

會議:NeurIPS 2024

源碼地址:https://www.github.com/brownvc/R3GAN

本文在調試代碼的時候對代碼做了一些修改,如果有遇到報錯的問題可以直接復制我這篇博客修改后的代碼:R3GAN利用配置好的Pytorch訓練自己的數據集-CSDN博客這篇論文挑戰了"GANs難以訓練"的廣泛觀點,通過提出一個更穩定的損失函數和現代化的網絡架構,構建了一個簡潔而高效的GAN基線模型R3GAN。作者證明了通過合適的理論基礎和架構設計,GANs可以穩定訓練并達到優異性能。 https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link

摘要:論文反駁了GANs難以訓練的普遍觀點,提出了一個理論有保障的現代GAN基線。首先,推導出一個良好行為的正則化相對論GAN損失函數,解決了模式丟棄和不收斂問題,并數學證明了其局部收斂性。其次,該損失函數允許丟棄所有經驗性技巧,用現代架構替換常見GANs中的過時骨干網絡。以StyleGAN2為例,展示了簡化和現代化的路線圖,產生了新的極簡基線R3GAN。盡管簡單,該方法在FFHQ、ImageNet、CIFAR和Stacked MNIST數據集上超越了StyleGAN2,與最先進的GANs和擴散模型相比表現優異。

模型結構

生成器架構

核心設計原則:

  • 基于現代化ResNet架構,摒棄VGG-like設計
  • 每個分辨率階段包含一個過渡層和兩個殘差塊
  • 采用分組卷積和倒置瓶頸設計

關鍵特性:

  • 無歸一化層:避免批量歸一化等數據相關的歸一化
  • Fix-up初始化:零初始化每個殘差塊的最后一層卷積
  • 雙線性插值:用于上采樣,避免棋盤效應

鑒別器架構

設計特點:

  • 與生成器完全對稱的架構
  • 相同的殘差塊結構和過渡層設計
  • 分類器頭:全局4×4深度卷積 + 線性層

損失函數

相對論配對GAN損失 (RpGAN):

L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]

R1正則化:

R1(ψ) = (γ/2) * E[||?_x D_ψ(x)||2] ?(x~p_D)

R2正則化:

R2(θ,ψ) = (γ/2) * E[||?_x D_ψ(x)||2] ?(x~p_θ)

訓練自己的數據集

1. 準備數據集

首先使用 dataset_tool.py 將您的圖像數據轉換為適合訓練的格式:

# 從文件夾創建數據集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip# 如果需要調整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \--resolution=256x256 --transform=center-crop

數據集要求:

  • 圖像必須是正方形(如256x256, 512x512)
  • 分辨率必須是2的冪次(64, 128, 256, 512, 1024等)
  • 支持RGB或灰度圖像
  • 可以是文件夾或ZIP格式

2. 創建自定義訓練配置

train.py 中添加您自己的預設配置。參考現有預設,在 main() 函數中添加:

if opts.preset == 'YOUR_DATASET':# 網絡架構參數WidthPerStage = [768, 768, 768, 512, 256]  # 每階段寬度BlocksPerStage = [2, 2, 2, 2, 2]           # 每階段塊數CardinalityPerStage = [96, 96, 96, 48, 24] # 每階段基數FP16Stages = [-1, -2, -3, -4]              # FP16優化的階段NoiseDimension = 64                         # 噪聲維度# 如果是條件生成(有類別標簽)if opts.cond:c.G_kwargs.ConditionEmbeddingDimension = NoiseDimensionc.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]# 訓練調度參數ema_nimg = 500 * 1000      # EMA開始的圖像數decay_nimg = 2e7           # 總衰減圖像數# 各種調度器c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }

3. 開始訓練

# 無條件生成(如人臉、風景等)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200# 條件生成(有類別標簽)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--cond=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200

4. 參數說明

  • --gpus: GPU數量
  • --batch: 總批次大小
  • --mirror: 是否啟用水平翻轉增強
  • --aug: 是否啟用數據增強
  • --cond: 是否訓練條件模型(需要標簽)
  • --tick: 多少kimg輸出一次進度
  • --snap: 多少tick保存一次模型

5. 生成圖像

訓練完成后,使用保存的模型生成圖像:

# 生成8張圖像
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl# 條件生成(指定類別)
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--class=5 \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

6. 評估指標

python calc_metrics.py \--metrics=fid50k_full,kid50k_full \--data=./datasets/your_dataset.zip \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

7.報錯指南

1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment

解決辦法:在 train.py 中,NoiseDimension 只在特定的預設配置塊中定義(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset 參數不匹配任何現有預設,這個變量就不會被定義,導致使用時出錯。可以使用作者定義好的預先設置。

--preset=CIFAR10
--preset=FFHQ-64  
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64

2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".

解決辦法:這個錯誤是因為R3GAN使用了自定義的CUDA操作符,需要C++編譯器來編譯。在Windows系統上缺少MSVC/GCC/CLANG編譯器。

修改 torch_utils/custom_ops.py:找到 get_plugin 函數(大約第84行),在函數開頭添加:

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):# 禁用所有自定義插件return Nonedef bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):# 強制使用 'ref' 實現impl = 'ref'

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82937.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82937.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82937.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【PhysUnits】15.1 引入P1后的加一特質(add1.rs)

一、源碼 代碼實現了類型系統中的"加一"操作(Add1 trait),用于在編譯期進行數字的增量計算。 //! 加一操作特質實現 / Increment operation trait implementation //! //! 說明: //! 1. Z0、P1,、N1 1&#xff0…

記錄算法筆記(2025.5.29)最小棧

設計一個支持 push ,pop ,top 操作,并能在常數時間內檢索到最小元素的棧。 實現 MinStack 類: MinStack() 初始化堆棧對象。void push(int val) 將元素val推入堆棧。void pop() 刪除堆棧頂部的元素。int top() 獲取堆棧頂部的元素。int get…

Android高級開發第一篇 - JNI(初級入門篇)

文章目錄 Android高級開發JNI開發第一篇(初級入門篇)🧠 一、什么是 JNI?? 為什么要用 JNI? ?? 二、開發環境準備開發工具 🚀 三、創建一個支持 JNI 的 Android 項目第一步:創建新項目項目結構…

PyTorch Image Models (timm) 技術指南

timm PyTorch Image Models (timm) 技術指南功能概述 一、引言二、timm 庫概述三、安裝 timm 庫四、模型加載與推理示例4.1 通用推理流程4.2 具體模型示例4.2.1 ResNeXt50-32x4d4.2.2 EfficientNet-V2 Small 模型4.2.3 DeiT-3 large 模型4.2.4 RepViT-M2 模型4.2.5 ResNet-RS-1…

openEuler安裝MySql8(tar包模式)

操作系統版本: openEuler release 22.03 (LTS-SP4) MySql版本: 下載地址: https://dev.mysql.com/downloads/mysql/ 準備安裝: 關閉防火墻: 停止防火墻 #systemctl stop firewalld.service 關閉防火墻 #systemc…

從零開始的數據結構教程(六) 貪心算法

🍬 標題一:貪心核心思想——發糖果時的最優分配策略 貪心算法 (Greedy Algorithm) 是一種簡單直觀的算法策略。它在每一步選擇中都采取在當前狀態下最好或最優(即最有利)的選擇,從而希望得到一個全局最優解。這就像你…

CPP中CAS std::chrono 信號量與Any類的手動實現

前言 CAS(Compare and Swap) 是一種用于多線程同步的原子指令。它通過比較和交換操作來確保數據的一致性和線程安全性。CAS操作涉及三個操作數:內存位置V、預期值E和新值U。當且僅當內存位置V的值與預期值E相等時,CAS才會將內存位…

Axure設計案例——科技感對比柱狀圖

想讓數據對比展示擺脫平淡無奇,瞬間抓住觀眾的眼球嗎?那就來看看這個Axure設計的科技感對比柱狀圖案例!科技感設計風格運用獨特元素打破傳統對比柱狀圖的常規,營造出一種極具沖擊力的視覺氛圍。每一組柱狀體都仿佛是科技戰場上的士…

怒更一波免費聲音克隆和AI配音功能

寶子們! 最近咱軟件TransDuck的免費聲音克隆和AI配音功能被大家用爆啦!感謝各位自來水瘋狂安利!! DD這里也是收到好多用戶提的寶貴建議!所以,連夜肝了波更新! 這次重點更新使用克隆音色進行A…

UDP協議原理與Java編程實戰:無連接通信的奧秘

1.UDP協議核心原理 1. 無連接特性:快速通信的基石 UDP(User Datagram Protocol,用戶數據報協議)是TCP/IP協議族中無連接的輕量級傳輸層協議。與TCP的“三次握手”建立連接不同,UDP通信無需提前建立鏈路,發送…

vue-seamless-scroll 結束從頭開始,加延時后滾動

今天遇到一個大屏需求: 1??初始進入頁面停留5秒,然后開始滾動 2??最后一條數據出現在最后一行時候暫停5秒,然后返回1?? 依次循環,發現vue-seamless-scroll的方法 ScrollEnd是監測最后一條數據消失在第一行才回調&#xff…

[Protobuf] 快速上手:安全高效的序列化指南

標題:[Protobuf] (1)快速上手 水墨不寫bug 文章目錄 一、什么是protobuf?二、protobuf的特點三、使用protobuf的過程?1、定義消息格式(.proto文件)(1)指定語法版本(2)package 聲明符 2、使用protoc編譯器生成代碼&…

uniapp調用java接口 跨域問題

前言 之前在Windows10本地 調試一個舊項目,手機移動端用的是Uni-app,vue的版本是v2。后端是java spring-boot。運行手機移動端的首頁請求后臺接口老是提示錯誤信息。 錯誤信息如下: Access to XMLHttpRequest at http://localhost:8080/api/…

[ Qt ] | Qlabel使用

目錄 屬性 setTextFormat 插入圖片 設置圖片根據窗口大小實時變化 邊框和對其方式 ?編輯 設置縮進 設置伙伴 Qlabel可以用來顯式圖片和文字 屬性 text textFormat Qlabel獨有的機制:buddy setTextFormat 插入圖片 設置圖片根據窗口大小實時變化 Qt中表…

Springboot 項目一啟動就獲取HttpSession

在 Spring Boot 項目中,HttpSession 是有狀態的,通常只有在用戶發起 HTTP 請求并建立會話后才會創建。因此,在項目啟動時(即應用剛啟動還未處理任何請求)是無法獲取到 HttpSession 的。 方法一:使用 HttpS…

Step9—Ambari Web UI 初始化安裝 (Ambari3.0.0)

Ambari Web UI 安裝 如果還不會系統性的部署,或者前置內容不熟悉,建議從Step1 開始閱讀。不通版本針對于不同操作系統可能存在差異!這里我也整理好了 https://doc.janettr.com/install/manual/ 1. 進入 Ambari Web UI 并登錄 在瀏覽器中訪…

熱門大型語言模型(LLM)應用開發框架

我們來深入探索這些強大的大型語言模型(LLM)應用開發框架,并且我會嘗試用文本形式描述一些核心的流程圖,幫助您更好地理解它們的工作機制。由于我無法直接生成圖片,我會用文字清晰地描述流程圖的各個步驟和連接。 Lang…

機器學習數據降維方法

1.數據類型 2.如何選擇降維方法進行數據降維 3.線性降維:主成分分析(PCA)、線性判別分析(LDA) 4.非線性降維 5.基于特征選擇的降維 6.基于神經網絡的降維 數據降維是將高維數據轉換為低維表示的過程,旨在保…

太陽系運行模擬程序-html動畫

太陽系運行模擬程序-html動畫 by AI: <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>交互式太陽系…

2025年全國青少年信息素養大賽 scratch圖形化編程挑戰賽 小低組初賽 內部集訓模擬題解析

2025年信息素養大賽初賽scratch模擬題解析 博主推薦 所有考級比賽學習相關資料合集【推薦收藏】 scratch資料 Scratch3.0系列視頻課程資料零基礎學習scratch3.0【入門教學 免費】零基礎學習scratch3.0【視頻教程 114節 免費】 歷屆藍橋杯scratch國賽真題解析歷屆藍橋杯scr…