PyTorch訓練深度卷積生成對抗網絡DCGAN

文章目錄

    • DCGAN介紹
    • 代碼
    • 結果
    • 參考

DCGAN介紹

將CNN和GAN結合起來,把監督學習和無監督學習結合起來。具體解釋可以參見 深度卷積對抗生成網絡(DCGAN)

DCGAN的生成器結構:
在這里插入圖片描述
圖片來源:https://arxiv.org/abs/1511.06434

代碼

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

訓練使用的數據集:CelebA dataset (Images Only) 總共1.3GB的圖片,使用方法,將其解壓到當前目錄

圖片如下圖所示:
在這里插入圖片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

結果

訓練5個epoch,部分結果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打開tensorboard

在這里插入圖片描述

參考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/43068.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/43068.shtml
英文地址,請注明出處:http://en.pswp.cn/news/43068.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

VSCode 使用總結

快捷鍵 在 Visual Studio Code (VSCode) 中&#xff0c;有許多常用的快捷鍵可以提高編程效率。以下是一些常見的 VSCode 編程項目快捷鍵&#xff1a; 編輯器操作&#xff1a; 撤銷&#xff1a;Ctrl Z重做&#xff1a;Ctrl Shift Z復制&#xff1a;Ctrl C剪切&#xff1a;C…

Electron入門,項目啟動。

electron 簡單介紹&#xff1a; 實現&#xff1a;HTML/CSS/JS桌面程序&#xff0c;搭建跨平臺桌面應用。 electron 官方文檔&#xff1a; [https://electronjs.org/docs] 本文是基于以下2篇文章且自行實踐過的&#xff0c;可行性真實有效。 文章1&#xff1a; https://www.cnbl…

題解 | #1005.List Reshape# 2023杭電暑期多校9

1005.List Reshape 簽到題 題目大意 按一定格式給定一個純數字一維數組&#xff0c;按給定格式輸出成二維數組。 解題思路 讀入初始數組字符串&#xff0c;將每個數字分離&#xff0c;按要求輸出即可 參考代碼 參考代碼為已AC代碼主干&#xff0c;其中部分功能需讀者自行…

學習Vue:組件通信

組件化開發在現代前端開發中是一種關鍵的方法&#xff0c;它能夠將復雜的應用程序拆分為更小、更可管理的獨立組件。在Vue.js中&#xff0c;父子組件通信是組件化開發中的重要概念&#xff0c;同時我們還會討論其他組件間通信的方式。 父子組件通信&#xff1a;Props 和 Events…

TDSQL赤兔管理臺無管理員用戶密碼解決方案

解決方案 問題描述&#xff1a; tdsql使用過程中&#xff0c;可能會遇到控制臺用戶密碼忘記的情況&#xff0c;用戶登錄次數過多被鎖的情況&#xff0c;沒有管理員的用戶密碼又急需某些權限的情況。 解決過程&#xff1a; 獲取配置庫信息&#xff1a; 在瀏覽器上打開如下命…

基于Javaweb的攝影作品網站/攝影網站

摘 要 隨著信息化時代的到來&#xff0c;系統管理都趨向于智能化、系統化&#xff0c;攝影作品網站也不例外&#xff0c;但目前國內的有些網站仍然都使用人工管理&#xff0c;瀏覽網站人數越來越多&#xff0c;同時信息量也越來越龐大&#xff0c;人工管理顯然已無法應對時代的…

AMD fTPM RNG的BUG使得Linus Torvalds不滿

導讀因為在 Ryzen 系統上對內核造成了困擾&#xff0c;Linus Torvalds 最近在郵件列表中表達了對 AMD fTPM 硬件隨機數生成器的不滿&#xff0c;并提出了禁用該功能的建議。 因為在 Ryzen 系統上對內核造成了困擾&#xff0c;Linus Torvalds 最近在郵件列表中表達了對 AMD fTPM…

【【Verilog典型電路設計之FIFO設計】】

典型電路設計之FIFO設計 FIFO (First In First Out&#xff09;是一種先進先出的數據緩存器&#xff0c;通常用于接口電路的數據緩存。與普通存儲器的區別是沒有外部讀寫地址線&#xff0c;可以使用兩個時鐘分別進行寫和讀操作。FIFO只能順序寫入數據和順序讀出數據&#xff0…

ThinkPHP中實現IP地址定位

在網站開發中&#xff0c;我們經常需要獲取用戶的地理位置信息以提供個性化的服務。一種常見的方法是通過IP地址定位。在本文中&#xff0c;我們將介紹如何在ThinkPHP框架中實現IP地址定位。 一、IP地址定位的基本原理 IP地址是Internet上的設備在網絡中的標識符。每個設備都有…

【從0開始學架構筆記】01 基礎架構

文章目錄 一、架構的定義1. 系統與子系統2. 模塊與組件3. 框架與架構4. 重新定義架構 二、架構設計的目的三、復雜度來源&#xff1a;高性能1. 單機復雜度2. 集群復雜度2.1 任務分配2.2 任務分解&#xff08;微服務&#xff09; 四、復雜度來源&#xff1a;高可用1. 計算高可用…

GitKraken保姆級圖文使用指南

前言 寫這篇文章的原因是組內的產品和美術同學&#xff0c;開始參與到git工作流中&#xff0c;但是網上又沒有找到一個比較詳細的使用教程&#xff0c;所以干脆就自己寫了一個[doge]。文章的內容比較基礎&#xff0c;介紹了Git內的一些基礎概念和基本操作&#xff0c;適合零基…

合并多個文本文件

使用 wxPython 模塊合并多個文本文件的博客。以下是一篇示例博客&#xff1a; C:\pythoncode\blog\txtmerge.py 在 Python 編程中&#xff0c;我們經常需要處理文本文件。有時候&#xff0c;我們可能需要將多個文本文件合并成一個文件&#xff0c;以便進行進一步的處理或分析。…

QT報表Limereport v1.5.35編譯及使用

1、編譯說明 下載后QT CREATER中打開limereport.pro然后直接編譯就可以了。編譯后結果如下圖&#xff1a; 一次編譯可以得到庫文件和DEMO執行程序。 2、使用說明 拷貝如下圖編譯后的lib目錄到自己的工程目錄中。 release版本的重新命名為librelease. PRO文件中配置 QT …

openpose姿態估計【學習筆記】

文章目錄 1、人體需要檢測的關鍵點2、Top-down方法3、Openpose3.1 姿態估計的步驟3.2 PAF&#xff08;Part Affinity Fields&#xff09;部分親和場3.3 制作PAF標簽3.4 PAF權值計算3.5 匹配方法 4、CPM&#xff08;Convolutional Pose Machines&#xff09;模型5、Openpose5.1 …

怎么修改圖片的分辨率?

怎么修改圖片的分辨率&#xff1f;很多人還不知道分辨率是什么意思&#xff0c;以為代表了圖片的清晰度&#xff0c;然而并不是這樣的&#xff0c;其實圖片的分辨率就是圖片尺寸大小的意思。修改圖片的分辨率即改變圖片的尺寸&#xff0c;通常以像素為單位表示。分辨率決定了圖…

【linux基礎(四)】對Linux權限的理解

&#x1f493;博主CSDN主頁:杭電碼農-NEO&#x1f493; ? ?專欄分類:Linux從入門到開通? ? &#x1f69a;代碼倉庫:NEO的學習日記&#x1f69a; ? &#x1f339;關注我&#x1faf5;帶你學更多操作系統知識 ? &#x1f51d;&#x1f51d; Linux權限 1. 前言2. shell命…

八、Linux下,grep/wc/管道符/echo/重定向符/tail如何使用?

1、grep命令 &#xff08;1&#xff09;主要用于文件 &#xff08;2&#xff09;主要作用是“通過關鍵字&#xff0c;過濾文件行” &#xff08;3&#xff09;示例&#xff1a; 2、wc命令 &#xff08;1&#xff09;統計文件的行數、單詞數等 &#xff08;2&#xff09;示例…

react之路由的安裝與使用

一、路由安裝 路由官網2021.11月初&#xff0c;react-router 更新到 v6 版本。使用最廣泛的 v5 版本的使用 npm i react-router-dom5.3.0二、路由使用 2.1 路由的簡單使用 第一步 在根目錄下 創建 views 文件夾 ,用于放置路由頁面 films.js示例代碼 export default functio…

一文預覽 | 8 月 16 日 NVIDIA 在 WAVE SUMMIT深度學習開發者大會 2023精彩亮點搶先看!

由深度學習技術及應用國家工程研究中心主辦&#xff0c;百度飛槳和文心大模型承辦的 WAVE SUMMIT深度學習開發者大會2023&#xff0c;將于 8 月 16 日在北京與大家見面。NVIDIA 作為技術合作伙伴&#xff0c;將攜手百度飛槳參與這場技術盛會。 在這次大會中&#xff0c;NVIDIA…

Java 項目日志實例基礎:Log4j

點擊下方關注我&#xff0c;然后右上角點擊...“設為星標”&#xff0c;就能第一時間收到更新推送啦~~~ 介紹幾個日志使用方面的基礎知識。 1 Log4j 1、Log4j 介紹 Log4j&#xff08;log for java&#xff09;是 Apache 的一個開源項目&#xff0c;通過使用 Log4j&#xff0c;我…