第G1周:生成對抗網絡(GAN)入門

🍨 本文為[🔗365天深度學習訓練營]內部限免文章(版權歸 *K同學啊* 所有)
🍖 作者:[K同學啊]

一、理論基礎
生成對抗網絡(Generative Adversarial Networks, GAN)是近年來深度學習領域的一個熱點方向。GAN并不指代某一個具體的神經網絡,而是指一類基于博弈思想而設計的神經網絡。GAN由兩個分別被稱為生成器(Generator)和判別器(Discriminator)的神經網絡組成。其中,生成器從某種噪聲分布中隨機采樣作為輸入,輸出與訓練集中真實樣本非常相似的人工樣本;判別器的輸入則為真實樣本或人工樣本,其目的是將人工樣本與真實樣本盡可能地區分出來。生成器和判別器交替運行,相互博弈,各自的能力都得到升。理想情況下,經過足夠次數的博弈之后,判別器無法判斷給定樣本的真實性,即對于所有樣本都輸出50%真,50%假的判斷。此時,生成器輸出的人工樣本已經逼真到使判別器無法分辨真假,停止博弈。這樣就可以得到一個具有“偽造”真實樣本能力的生成器。
1. 生成器

GANs中,生成器 G 選取隨機噪聲 z 作為輸入,通過生成器的不斷擬合,最終輸出一個和真實樣本尺寸相同,分布相似的偽造樣本G(z)。生成器的本質是一個使用生成式方法的模型,它對數據的分布假設和分布參數進行學習,然后根據學習到的模型重新采樣出新的樣本。
從數學上來說,生成式方法對于給定的真實數據,首先需要對數據的顯式變量或隱含變量做分布假設;然后再將真實數據輸入到模型中對變量、參數進行訓練;最后得到一個學習后的近似分布,這個分布可以用來生成新的數據。從機器學習的角度來說,模型不會去做分布假設,而是通過不斷地學習真實數據,對模型進行修正,最后也可以得到一個學習后的模型來做樣本生成任務。這種方法不同于數學方法,學習的過程對人類理解較不直觀。

2. 判別器
GANs中,判別器 D 對于輸入的樣本 x,輸出一個[0,1]之間的概率數值D(x)。x 可能是來自于原始數據集中的真實樣本 x,也可能是來自于生成器 G 的人工樣本G(z)。通常約定,概率值D(x)越接近于1就代表此樣本為真實樣本的可能性更大;反之概率值越小則此樣本為偽造樣本的可能性越大。也就是說,這里的判別器是一個二分類的神經網絡分類器,目的不是判定輸入數據的原始類別,而是區分輸入樣本的真偽。可以注意到,不管在生成器還是判別器中,樣本的類別信息都沒有用到,也表明 GAN 是一個無監督的學習過程。

3. 基本原理
GAN是博弈論和機器學習相結合的產物,于2014年Ian Goodfellow的論文中問世,一經問世即火爆足以看出人們對于這種算法的認可和狂熱的研究熱忱。想要更詳細的了解GAN,就要知道它是怎么來的,以及這種算法出現的意義是什么。研究者最初想要通過計算機完成自動生成數據的功能,例如通過訓練某種算法模型,讓某模型學習過一些蘋果的圖片后能自動生成蘋果的圖片,具備些功能的算法即認為具有生成功能。但是GAN不是第一個生成算法,而是以往的生成算法在衡量生成圖片和真實圖片的差距時采用均方誤差作為損失函數,但是研究者發現有時均方誤差一樣的兩張生成圖片效果卻截然不同,鑒于此不足Ian Goodfellow提出了GAN。

image.png

那么GAN是如何完成生成圖片這項功能的呢,如圖1所示,GAN是由兩個模型組成的:生成模型G和判別模型D。首先第一代生成模型1G的輸入是隨機噪聲z,然后生成模型會生成一張初級照片,訓練一代判別模型1D另其進行二分類操作,將生成的圖片判別為0,而真實圖片判別為1;為了欺瞞一代鑒別器,于是一代生成模型開始優化,然后它進階成了二代,當它生成的數據成功欺瞞1D時,鑒別模型也會優化更新,進而升級為2D,按照同樣的過程也會不斷更新出N代的G和D。

二、前期準備工作

1. 定義超參數

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch## 創建文件夾
os.makedirs("./images/", exist_ok=True)         ## 記錄訓練過程的圖片效果
os.makedirs("./save/", exist_ok=True)           ## 訓練完成時模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)      ## 下載數據集存放的位置## 超參數配置
n_epochs=50
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500## 圖像的尺寸:(1, 28, 28),  和圖像的像素面積:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)## 設置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)
## mnist數據集下載
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
## 配置數據到加載器
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),         # 輸入特征數為784,輸出為512nn.LeakyReLU(0.2, inplace=True),  # 進行非線性映射nn.Linear(512, 256),              # 輸入特征數為512,輸出為256nn.LeakyReLU(0.2, inplace=True),  # 進行非線性映射nn.Linear(256, 1),                # 輸入特征數為256,輸出為1nn.Sigmoid(),                     # sigmoid是一個激活函數,二分類問題中可將實數映射到[0, 1],作為概率值, 多分類用softmax函數)def forward(self, img):img_flat = img.view(img.size(0), -1) # 鑒別器輸入是一個被view展開的(784)的一維圖像:(64, 784)validity = self.model(img_flat)      # 通過鑒別器網絡return validity                      # 鑒別器返回的是一個[0, 1]間的概率
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()## 模型中間塊兒def block(in_feat, out_feat, normalize=True):        # block(in, out )layers = [nn.Linear(in_feat, out_feat)]          # 線性變換將輸入映射到out維if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正則化layers.append(nn.LeakyReLU(0.2, inplace=True))   # 非線性激活函數return layers## prod():返回給定軸上的數組元素的乘積:1*28*28=784self.model = nn.Sequential(*block(latent_dim, 128, normalize=False), # 線性變化將輸入映射 100 to 128, 正則化, LeakyReLU*block(128, 256),                         # 線性變化將輸入映射 128 to 256, 正則化, LeakyReLU*block(256, 512),                         # 線性變化將輸入映射 256 to 512, 正則化, LeakyReLU*block(512, 1024),                        # 線性變化將輸入映射 512 to 1024, 正則化, LeakyReLUnn.Linear(1024, img_area),                # 線性變化將輸入映射 1024 to 784nn.Tanh()                                 # 將(784)的數據每一個都映射到[-1, 1]之間)## view():相當于numpy中的reshape,重新定義矩陣的形狀:這里是reshape(64, 1, 28, 28)def forward(self, z):                           # 輸入的是(64, 100)的噪聲數據imgs = self.model(z)                        # 噪聲數據通過生成器模型imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 輸出為64張大小為(1, 28, 28)的圖像
## 創建生成器,判別器對象
generator = Generator()
discriminator = Discriminator()## 首先需要定義loss的度量方式  (二分類的交叉熵)
criterion = torch.nn.BCELoss()## 其次定義 優化函數,優化函數的學習率為0.0003
## betas:用于計算梯度以及梯度平方的運行平均值的系數
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))## 如果有顯卡,都在cuda模式中運行
if torch.cuda.is_available():generator     = generator.cuda()discriminator = discriminator.cuda()criterion     = criterion.cuda()
for epoch in range(n_epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 將圖片展開為28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()      # 將tensor變成Variable放入計算圖中,tensor變成variable之后才能進行反向傳播求梯度real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定義真實的圖片label為1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定義假的圖片的label為0real_out = discriminator(real_img)            # 將真實圖片放入判別器中loss_real_D = criterion(real_out, real_label) # 得到真實圖片的lossreal_scores = real_out                        # 得到真實圖片的判別值,輸出的值越接近1越好## 計算假的圖片的損失## detach(): 從當前計算圖中分離下來避免梯度傳到G,因為G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 隨機生成一些噪聲, 大小為(128, 100)fake_img    = generator(z).detach()                                    ## 隨機噪聲放入生成網絡中,生成一張假的圖片。fake_out    = discriminator(fake_img)                                  ## 判別器判斷假的圖片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的圖片的lossfake_scores = fake_out## 損失函數和優化loss_D = loss_real_D + loss_fake_D  # 損失包括判真損失和判假損失optimizer_D.zero_grad()             # 在反向傳播之前,先將梯度歸0loss_D.backward()                   # 將誤差反向傳播optimizer_D.step()                  # 更新參數z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到隨機噪聲fake_img = generator(z)                                             ## 隨機噪聲輸入到生成器中,得到一副假的圖片output = discriminator(fake_img)                                    ## 經過判別器得到的結果## 損失函數和優化loss_G = criterion(output, real_label)                              ## 得到的假的圖片與真實的圖片的label的lossoptimizer_G.zero_grad()                                             ## 梯度歸0loss_G.backward()                                                   ## 進行反向傳播optimizer_G.step()                                                  ## step()一般用在反向傳播后面,用于更新生成網絡的參數## 打印訓練過程中的日志## item():取出單元素張量的元素值并返回該值,保持原元素類型不變if ( i + 1 ) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存訓練過程中的圖像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')

部分運行截圖:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/37733.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/37733.shtml
英文地址,請注明出處:http://en.pswp.cn/news/37733.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Windows安裝Go開發環境

Windows安裝Go開發環境 一、Go語言下載地址 https://golang.google.cn/dl/ 二、設置工作空間GOPATH目錄(Go語言開發的項目路徑) 首先進入我的C盤(你放到其他盤也行),新建一個文件夾,名字叫做mygo(這個就是你的工作目…

ArcGIS Maps SDK for JavaScript系列之一:在Vue3中加載ArcGIS地圖

目錄 ArcGIS Maps SDK for JavaScript簡介ArcGIS Maps SDK for JavaScript 4.x 的主要特點和功能AMD modules 和 ES modules兩種方式比較Vue3中使用ArcGIS Maps SDK for JavaScript的步驟創建 Vue 3 項目安裝 ArcGIS Maps SDK for JavaScript創建地圖組件 ArcGIS Maps SDK for …

“深入理解JVM:探索Java虛擬機的內部工作原理“

標題:深入理解JVM:探索Java虛擬機的內部工作原理 摘要:本文將深入探索Java虛擬機(JVM)的內部工作原理,包括JVM的架構、類加載、內存管理、垃圾回收機制等方面。通過理解JVM的內部工作原理,我們…

華為開源自研AI框架昇思MindSpore應用案例:基于MindSpore框架的UNet-2D案例實現

目錄 一、環境準備1.進入ModelArts官網2.使用CodeLab體驗Notebook實例 二、環境準備與數據讀取三、模型解析Transformer基本原理Attention模塊 Transformer EncoderViT模型的輸入整體構建ViT 四、模型訓練與推理模型訓練模型驗證模型推理 近些年,隨著基于自注意&…

改造舊項目-長安分局人事費用管理系統

一、系統環境搭建 1、搭建前臺環境 vue3vite構建項目復制“銀稅系統”頁面結構,包括:路由、vuex存儲、菜單、登錄(復制一個干凈的空架子) 2、搭建后臺環境 新三大框架 SSMP聚合工程:common、admin,新的…

JAVA冒泡排序

package com.hzh.javase.day03;public class maopao {public static void main(String[] args) {int[] arr {2, 11,4,7,5,22,15,37,12,1};int zjvalue 0;//中間值boolean boofalse; //冒泡比較相鄰元素將小的提前打的放后 // 外層循環時用來控制輪數 // 內存循…

2023國賽數學建模E題思路分析

文章目錄 0 賽題思路1 競賽信息2 競賽時間3 建模常見問題類型3.1 分類問題3.2 優化問題3.3 預測問題3.4 評價問題 4 建模資料 0 賽題思路 (賽題出來以后第一時間在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 競賽信息 全國大學生數學建模…

Linux服務器上配置HTTP和HTTPS代理

本文將向你分享如何在Linux服務器上配置HTTP和HTTPS代理的方法,解決可能遇到的問題,讓你的爬蟲項目順利運行,暢爬互聯網! 配置HTTP代理的步驟 1. 了解HTTP代理的類型:常見的有正向代理和反向代理兩種類型。根據實際需求…

涉及近300個業務場景,重慶銀行數字員工平臺建設解析

隨著數字化轉型戰略規劃的逐步落地,重慶銀行于2022年6月成功建設了數字員工平臺,該平臺已成為行內數字化轉型的標桿應用。數字員工平臺以RPA(機器人流程自動化)為基礎,AI(人工智能)技術為抓手&a…

PHP最簡單自定義自己的框架view使用引入smarty(8)--自定義的框架完成

1、實現效果。引入smarty, 實現assign和 display 2、下載smarty,創建緩存目錄cache和擴展extend 點擊下面查看具體下載使用,下載改名后放到extend PHP之Smarty使用以及框架display和assign原理_PHP隔壁老王鄰居的博客-CSDN博客 3、當前控…

leetcode 力扣刷題 旋轉矩陣(循環過程邊界控制)

力扣刷題 旋轉矩陣 二維矩陣按圈遍歷(順時針 or 逆時針)遍歷59. 旋轉矩陣Ⅱ54. 旋轉矩陣劍指 Offer 29. 順時針打印矩陣 二維矩陣按圈遍歷(順時針 or 逆時針)遍歷 下面的題目的主要考察點都是,二維數組從左上角開始順…

輸出無重復的3位數和計算無人機飛行坐標

編程題總結 題目一:輸出無重復的3位數 題目描述 從{1,2,3,4,5,6,7,8,9}中隨機挑選不重復的5個數字作為輸入數組‘selectedDigits’,能組成多少個互不相同且無重復數字的3位數?請編寫程》序,從小到大順序,以數組形式輸出這些3位…

C# Linq源碼分析之Take (一)

概要 在.Net 6 中引入的Take的另一個重載方法,一個基于Range的重載方法。因為該方法中涉及了很多新的概念,所以在分析源碼之前,先將這些概念搞清楚。 Take方法基本介紹 public static System.Collections.Generic.IEnumerable Take (this …

【LeetCode: 2811. 判斷是否能拆分數組】

🚀 算法題 🚀 🌲 算法刷題專欄 | 面試必備算法 | 面試高頻算法 🍀 🌲 越難的東西,越要努力堅持,因為它具有很高的價值,算法就是這樣? 🌲 作者簡介:碩風和煒,…

NavMeshPlus 2D尋路插件

插件地址:h8man/NavMeshPlus: Unity NavMesh 2D Pathfinding (github.com) 我對Unity官方是深惡痛覺,一個2D尋路至今都沒想解決,這破引擎早點倒閉算了. 這插件是githun的開源項目,我本身是有寫jps尋路的,但是無法解決多個單位互相阻擋的問題(可以解決但是有性能問…

vue3+ts使用antv/x6 + 自定義節點

使用 2.x 版本 x6.antv 新官網: 安裝 npm install antv/x6 //"antv/x6": "^2.1.6",項目結構 1、初始化畫布 index.vue <template><div id"container"></div> </template><script setup langts> import { onM…

Python爬蟲——scrapy_基本使用

安裝scrapy pip install scrapy創建scrapy項目&#xff0c;需要在終端里創建 注意&#xff1a;項目的名字開頭不能是數字&#xff0c;也不能包含中文 scrapy startproject 項目名稱 示例&#xff1a; scrapy startproject scra_baidu_36創建好后的文件 3. 創建爬蟲文件&…

MySQL表的操作

文章目錄 MySQL表的操作1. 創建表2. 查看表2.1 查看數據庫中存在的表2.2 查看表的屬性2.3 查看創建時表的詳細信息 3. 修改表3.1 向表中添加記錄3.2 添加列3.3 修改列的數據類型3.4 刪除列3.5 表的重命名3.6 修改列名 4. 刪除表 MySQL表的操作 1. 創建表 CREATE TABLE table_…

博客系統【架構】

用戶管理 實現用戶的注冊、登錄、注銷等功能 使用Redis做緩存處理、阿里云短信服務 確保用戶身份驗證和安全性 使用Jwt來鑒權 userId (主鍵&#xff0c;自增長) username (唯一&#xff0c;用戶名)【用于普通登錄】email (唯一&#xff0c;用戶的電子郵件地址) password (存儲…

zabbix監控tomcat

一、zabbix監控Tomcat1.1 zbx-agent配置1.1.1 關閉防火墻&#xff0c;將安裝 Tomcat 所需軟件包傳到/opt目錄下1.1.2 安裝JDK1.1.3 設置JDK環境變量1.1.4 安裝啟動Tomcat1.1.5 配置 JMX 1.2 zbx-server配置1.2.1 安裝zabbix&#xff08;省略&#xff0c;可看上一篇博客&#xf…