第40周——GAN入門

目錄

目錄

目錄

前言

一、定義超參數

二、下載數據

三、配置數據

四、定義鑒別器

五、訓練模型并保存

總結


前言

  • 🍨?本文為🔗365天深度學習訓練營中的學習記錄博客
  • 🍖?原作者:K同學啊

一、定義超參數

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch## 創建文件夾
os.makedirs("./images/", exist_ok=True)         # 記錄訓練過程的圖片效果
os.makedirs("./save/", exist_ok=True)           # 訓練完成時模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)  # 下載數據集存放的位置## 超參數配置
n_epochs  = 50
batch_size= 64
lr        = 0.0002
b1        = 0.5
b2        = 0.999
n_cpu     = 2
latent_dim= 100
img_size  = 28
channels  = 1
sample_interval=500# 圖像的尺寸:(1, 28, 28),  和圖像的像素面積:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)# 設置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

二、下載數據

# mnist數據集下載
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), 
)


三、配置數據

# 配置數據到加載器
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)

四、定義鑒別器

# 將圖片28x28展開成784,然后通過多層感知器,中間經過斜率設置為0.2的LeakyReLU激活函數,
# 最后接sigmoid激活函數得到一個0到1之間的概率進行二分類
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),         # 輸入特征數為784,輸出為512nn.LeakyReLU(0.2, inplace=True),  # 進行非線性映射nn.Linear(512, 256),              # 輸入特征數為512,輸出為256nn.LeakyReLU(0.2, inplace=True),  # 進行非線性映射nn.Linear(256, 1),                # 輸入特征數為256,輸出為1nn.Sigmoid(),                     # sigmoid是一個激活函數,二分類問題中可將實數映射到[0, 1],作為概率值, 多分類用softmax函數)def forward(self, img):img_flat = img.view(img.size(0), -1) # 鑒別器輸入是一個被view展開的(784)的一維圖像:(64, 784)validity = self.model(img_flat)      # 通過鑒別器網絡return validity       

五、訓練模型并保存

## 創建生成器,判別器對象
generator = Generator()
discriminator = Discriminator()## 首先需要定義loss的度量方式  (二分類的交叉熵)
criterion = torch.nn.BCELoss()## 其次定義 優化函數,優化函數的學習率為0.0003
## betas:用于計算梯度以及梯度平方的運行平均值的系數
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))## 如果有顯卡,都在cuda模式中運行
if torch.cuda.is_available():generator     = generator.cuda()discriminator = discriminator.cuda()criterion     = criterion.cuda()## 進行多個epoch的訓練
for epoch in range(n_epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)## =============================訓練判別器==================## view(): 相當于numpy中的reshape,重新定義矩陣的形狀, 相當于reshape(128,784)  原來是(128, 1, 28, 28)imgs = imgs.view(imgs.size(0), -1)    # 將圖片展開為28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()      # 將tensor變成Variable放入計算圖中,tensor變成variable之后才能進行反向傳播求梯度real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定義真實的圖片label為1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定義假的圖片的label為0## ---------------------##  Train Discriminator## 分為兩部分:1、真的圖像判別為真;2、假的圖像判別為假## ---------------------## 計算真實圖片的損失real_out = discriminator(real_img)            # 將真實圖片放入判別器中loss_real_D = criterion(real_out, real_label) # 得到真實圖片的lossreal_scores = real_out                        # 得到真實圖片的判別值,輸出的值越接近1越好## 計算假的圖片的損失## detach(): 從當前計算圖中分離下來避免梯度傳到G,因為G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 隨機生成一些噪聲, 大小為(128, 100)fake_img    = generator(z).detach()                                    ## 隨機噪聲放入生成網絡中,生成一張假的圖片。 fake_out    = discriminator(fake_img)                                  ## 判別器判斷假的圖片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的圖片的lossfake_scores = fake_out                                              ## 得到假圖片的判別值,對于判別器來說,假圖片的損失越接近0越好## 損失函數和優化loss_D = loss_real_D + loss_fake_D  # 損失包括判真損失和判假損失optimizer_D.zero_grad()             # 在反向傳播之前,先將梯度歸0loss_D.backward()                   # 將誤差反向傳播optimizer_D.step()                  # 更新參數## -----------------##  Train Generator## 原理:目的是希望生成的假的圖片被判別器判斷為真的圖片,## 在此過程中,將判別器固定,將假的圖片傳入判別器的結果與真實的label對應,## 反向傳播更新的參數是生成網絡里面的參數,## 這樣可以通過更新生成網絡里面的參數,來訓練網絡,使得生成的圖片讓判別器以為是真的, 這樣就達到了對抗的目的## -----------------z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到隨機噪聲fake_img = generator(z)                                             ## 隨機噪聲輸入到生成器中,得到一副假的圖片output = discriminator(fake_img)                                    ## 經過判別器得到的結果## 損失函數和優化loss_G = criterion(output, real_label)                              ## 得到的假的圖片與真實的圖片的label的lossoptimizer_G.zero_grad()                                             ## 梯度歸0loss_G.backward()                                                   ## 進行反向傳播optimizer_G.step()                                                  ## step()一般用在反向傳播后面,用于更新生成網絡的參數## 打印訓練過程中的日志## item():取出單元素張量的元素值并返回該值,保持原元素類型不變if (i + 1) % 300 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存訓練過程中的圖像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)## 保存模型
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')


總結:

本項目中,我們實現了一個基于 MNIST 數據集的生成對抗網絡(GAN),主要流程從參數配置、數據準備,到模型構建與訓練,最后再到結果保存,形成了一個完整的生成式模型訓練管線。

首先,在超參數設置上,我們采用了經典 GAN 論文中推薦的組合:學習率設為 0.0002,Adam 優化器的 \beta_1 和 \beta_2 分別為 0.5 和 0.999,訓練 50 個周期,批量大小為 64。這些參數能在保證穩定訓練的同時,加快收斂速度。

數據部分選用了MNIST 手寫數字集,先將像素歸一化到 [-1, 1],再通過 DataLoader 按批次讀取并打亂順序。這一處理不僅保證了輸入分布的穩定性,也提升了訓練效率。

在模型結構方面,判別器(D)是一個多層全連接網絡,將 28×28 圖像展平為 784 維向量后輸入,激活函數使用 LeakyReLU,最后通過 Sigmoid 得到真假概率;生成器(G)則以 100 維隨機噪聲為輸入,經過多層全連接與 ReLU/Tanh 激活,輸出與真實圖像同尺寸的 28×28 結果。這樣的結構簡單直觀,適合入門實驗。

訓練時,判別器與生成器交替優化:

  • 判別器的目標是最大化對真實圖像的判真概率、對生成圖像的判假概率;

  • 生成器的目標則是讓判別器將其生成的圖像判為真。

    損失函數統一采用二分類交叉熵(BCELoss),并為 D、G 分別設置優化器以避免梯度更新沖突。訓練過程中會定期輸出損失值并保存生成樣本,方便對生成效果進行直觀評估。

GAN 的核心思想,是讓生成器與判別器在對抗博弈中共同提升能力:G 學會捕捉數據分布特征,D 學會分辨真實與偽造,兩者在動態平衡中逼近真實分布。這種機制使得 GAN 特別適合用于圖像生成、數據增強和風格遷移等任務。

從實驗體驗來看,這套代碼的優點在于結構清晰、可視化直觀、參數穩定,非常適合作為學習 GAN 的起點。同時,經過適當修改,還能擴展到更復雜的生成任務,為后續的研究打下基礎。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/93250.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/93250.shtml
英文地址,請注明出處:http://en.pswp.cn/web/93250.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Nginx性能優化與安全配置:打造高性能Web服務器

系列文章索引: 第一篇:《Nginx入門與安裝詳解:從零開始搭建高性能Web服務器》第二篇:《Nginx基礎配置詳解:nginx.conf核心配置與虛擬主機實戰》第三篇:《Nginx代理配置詳解:正向代理與反向代理…

二分算法(模板)

例題1: 704. 二分查找 - 力扣(LeetCode) 算法原理:(二分) 通過遍歷也可以通過,但是二分更優且數據量越大越能體現。 二分思路: 1.mid1 (left right)/2 與 mid2 right (right …

VUE3 學習筆記2 computed、watch、生命周期、hooks、其他組合式API

computed 計算屬性在vue3中,雖然也能寫vue2的computed,但還是更推薦使用vue3語法的computed。在Vue3中,計算屬性是組合式API,要想使用computed,需要先對computed進行引入:import { computed } from vuecomp…

【java面試day13】mysql-定位慢查詢

文章目錄問題💬 Question 1相關知識問題 💬 Question 1 Q:這條sql語句執行很慢,你如何分析呢? A:當一條 SQL 執行較慢時,可以先使用 EXPLAIN 查看執行計劃,通過 key 和 key_len 判…

3分鐘解鎖網頁“硬盤“能力:離線運行VSCode的新一代Web存儲技術

Hi,我是前端人類學(之前叫布蘭妮甜)! “這不是瀏覽器,這是裝了個硬盤。” —— 用戶對現代Web應用能力的驚嘆 隨著Origin Private File System和IndexedDB Stream等新技術的出現,Web應用現在可以在用戶的設…

LT6911GXD,HD-DVI2.1/DP1.4a/Type-C 轉 Dual-port MIPI/LVDS with Audio 帶音頻

簡介LT6911GXD是一款高性能HD-DVI2.1/DP1.4a/Type-c轉Dual-port MIPI/LVDS芯片,兼容 HDMI2.1、HDMI2.0b、HDMI1.4、DVI1.0、DisplayPort 1.4a、eDP1.4b 等多種視頻接口標準。支持4K(38402160)60Hz的DSC直通。應用場景AR/VR設備LT6911GXD 支持高達 4K(384…

【100頁PPT】數字化轉型某著名企業集團信息化頂層規劃方案(附下載方式)

篇幅所限,本文只提供部分資料內容,完整資料請看下面鏈接 https://download.csdn.net/download/2501_92808811/91662628 資料解讀:數字化轉型某著名企業集團信息化頂層規劃方案 詳細資料請看本解讀文章的最后內容 作為企業數字化轉型領域的…

高精度標準鋼卷尺優質廠家、選購建議

高精度標準鋼卷尺的優質廠家通常具備精湛工藝與權威精度認證等特征,能為產品質量提供保障。其選購需兼顧精度標識、使用場景、結構細節等多方面,具體介紹如下:一、高精度標準鋼卷尺優質廠家**1、河南普天同創:**PTTC-C5標準鋼卷尺…

38 C++ STL模板庫7-迭代器

C STL模板庫7-迭代器 文章目錄C STL模板庫7-迭代器一、迭代器的核心作用二、迭代器的五大分類與操作三、關鍵用法與代碼示例1. 迭代器的原理2. 迭代器用法與示例3. 迭代工具用法示例4. 使用技巧迭代器是C中連接容器與算法的通用接口,提供了一種訪問容器元素的統一方…

【0基礎3ds Max】學習計劃

3ds Max 作為一款功能強大的專業 3D 計算機圖形軟件,在影視動畫、游戲開發、建筑可視化、產品設計和工業設計等眾多領域有著廣泛的應用。 目錄前言一、第一階段:基礎認知(第 1 - 2 周)?二、第二階段:建模技術學習&…

用 Enigma Virtual Box 將 Qt 程序打包成單 exe

上一篇介紹了用windeployqt生成可運行的多文件程序,但一堆文件分發起來不夠方便。有沒有辦法將所有文件合并成一個 exe? 答案是肯定的 用Enigma Virtual Box工具就能實現。本文就來講解如何用它將 Qt 多文件程序打包為單一 exe,讓分發更輕松。 其中的 一定要選 第二個 一…

【LeetCode 熱題 100】45. 跳躍游戲 II

Problem: 45. 跳躍游戲 II 給定一個長度為 n 的 0 索引整數數組 nums。初始位置為 nums[0]。 每個元素 nums[i] 表示從索引 i 向后跳轉的最大長度。換句話說&#xff0c;如果你在索引 i 處&#xff0c;你可以跳轉到任意 (i j) 處&#xff1a; 0 < j < nums[i] 且 i j &…

池式管理之線程池

1.初識線程池問&#xff1a;線程池是什么&#xff1f;答&#xff1a;維持管理一定數量的線程的池式結構。&#xff08;維持&#xff1a;線程復用 。 管理&#xff1a;沒有收到任務的線程處于阻塞休眠狀態不參與cpu調度 。一定數量&#xff1a;數量太多的線程會給操作系統帶來線…

嬰兒 3D 安睡系統專利拆解:搭扣與智能系帶的鎖定機制及松緊調節原理

凌晨2點&#xff0c;你盯著嬰兒床里的小肉團直嘆氣。剛用襁褓裹成小粽子才哄睡的寶寶&#xff0c;才半小時就蹬開了裹布&#xff0c;小胳膊支棱得像只小考拉&#xff1b;你手忙腳亂想重新裹緊&#xff0c;結果越裹越松&#xff0c;裹布滑到脖子邊&#xff0c;寶寶突然一個翻身&…

pandas中df.to _dict(orient=‘records‘)方法的作用和場景說明

df.to _dict(orientrecords) 是 Pandas DataFrame 的一個方法&#xff0c;用于將數據轉換為字典列表格式。以下是詳細解釋及實例說明&#xff1a; 一、核心含義作用 將 DataFrame 的每一行轉換為一個字典&#xff0c;所有字典組成一個列表。 每個字典的鍵&#xff08;key&#…

阿里云Anolis OS 8.6的公有云倉庫源配置步驟

文章目錄一、備份現有倉庫配置&#xff08;防止誤操作&#xff09;二、配置阿里云鏡像源2.1 修改 BaseOS 倉庫2.2 修改 AppStream 倉庫三、清理并重建緩存四、驗證配置4.1 ?檢查倉庫狀態?&#xff1a;五、常見問題解決5.1 ?HTTP 404 錯誤5.2 ?網絡連接問題附&#xff1a;其…

回歸預測 | Matlab實現CNN-BiLSTM-self-Attention多變量回歸預測

回歸預測 | Matlab實現CNN-BiLSTM-self-Attention多變量回歸預測 目錄回歸預測 | Matlab實現CNN-BiLSTM-self-Attention多變量回歸預測預測效果基本介紹程序設計參考資料預測效果 基本介紹 1.Matlab實現CNN-BiLSTM融合自注意力機制多變量回歸預測&#xff0c;CNN-BiLSTM-self-…

103、【OS】【Nuttx】【周邊】文檔構建渲染:Sphinx 配置文件

【聲明】本博客所有內容均為個人業余時間創作&#xff0c;所述技術案例均來自公開開源項目&#xff08;如Github&#xff0c;Apache基金會&#xff09;&#xff0c;不涉及任何企業機密或未公開技術&#xff0c;如有侵權請聯系刪除 背景 接之前 blog 【OS】【Nuttx】【周邊】文…

轉換一個python項目到moonbit,碰到報錯輸出:編譯器對workflow.mbt文件中的類方法要求不一致的類型注解,導致無法正常編譯

先上結論&#xff1a;現在是moon test的時候有很多報錯&#xff0c;消不掉。問題在Trae中用GLM-4.5模型&#xff0c;轉換一個python項目到moonbit&#xff0c;碰到報錯輸出&#xff1a;報錯輸出經過多次嘗試修復&#xff0c;我發現這是一個MoonBit編譯器的bug。編譯器對workflo…

【C#補全計劃】事件

一、事件的概念1. 事件是基于委托的存在&#xff0c;是委托的安全包裹&#xff0c;讓委托的使用更具有安全性2. 事件是一種特殊的變量類型二、事件的使用1. 語法&#xff1a;event 委托類型 事件名;2. 使用&#xff1a;&#xff08;1&#xff09;事件是作為成員變量存在與類中&…