Gan論文閱讀筆記

GAN論文閱讀筆記

2014年老論文了,主要記錄一些重要的東西。論文鏈接如下:

Generative Adversarial Nets (neurips.cc)

文章目錄

  • GAN論文閱讀筆記
    • 出發點
    • 創新點
    • 設計
    • 訓練代碼
    • 網絡結構代碼
    • 測試代碼

出發點

Deep generative models have had less of an impact, due to the difficulty of approximating many intractable probabilistic computations that arise in maximum likelihood estimation and related strategies, and due to difficulty of leveraging the benefits of piecewise linear units in the generative context.

? 當時的生成模型效果不佳在于近似許多棘手的概率計算十分困難,如最大似然估計等。除此之外,把利用分段線性單元運用到生成場景中也有困難。于是作者提出新的生成模型:GAN。

? 我的理解是,當時的生成模型都是去學習模型生成數據的分布,比如確定方差,確定均值之類的參數,然而這種方法十分難以學習,而且計算量大而復雜,作者考慮到這一點,對生成模型采用端到端的學習策略,不去學習生成數據的分布,而是直接學習模型,只要這個模型的生成結果能夠逼近Ground-Truth,那么就可以直接用這個模型代替分布去生成數據。這是典型的黑箱思想。

創新點

adiscriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles.

創新點1:提出對抗學習策略:提出兩個model之間相互對抗,相互抑制的策略。一個model名為生成器Generator,一個model名為判別器Discriminator,生成器盡可能生成接近真實的數據,判別器盡可能識別出生成器數據是Fake。

In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron.

創新點2:當兩個model都使用神經網絡時,可以運用反向傳播和Dropout等算法進行學習,這樣就可以避免使用馬爾科夫鏈。

設計

To learn the generator’s distribution pgover data x, we define a prior on input noise variables pz(z), then represent a mapping to data space as G(z; θg), where G is a differentiable function represented by a multilayer perceptron with parameters θg. We also define a second multilayer perceptron D(x; θd) that outputs a single scalar. D(x) represents the probability that x came from the data rather than pg.

1.輸入:為了讓生成器G生成的數據分布pg與真實數據分布x接近,策略是給G輸入一個噪音變量z,然后學習參數θg,這個θg是G網絡權重。因此,G可以被寫作:G(z;θg)。
m i n G m a x D V ( D , G ) = E x ~ p d a t a ( x ) [ l o g D ( x ) ] + E z ~ p z ( z ) [ l o g ( 1 ? D ( G ( z ) ) ) ] \underset{G}{min}\underset{D}{max}V(D, G) =\mathbb{E}_{x \sim p_{data}(x)}\left[ logD(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[log(1 - D(G(z)))\right] Gmin?Dmax?V(D,G)=Expdata?(x)?[logD(x)]+Ezpz?(z)?[log(1?D(G(z)))]
2.對抗性損失函數:從代碼可知,對抗性損失是兩個BCELoss的和,V盡可能使D(x)更大,在此基礎上盡可能使G(z)更小。這是有先后順序的,在后面會做說明。

在代碼中可知,先人為生成兩個標簽,第一個標簽是用torch.ones生成的全為1的矩陣,形狀為(batch,1)。其中batch是輸入噪聲的batch,第二維度只是一個數字——1,這個標簽用于判別器D的BCELoss中,代入BCELoss即可得到上面對抗性損失中左側的期望。第二個標簽是用torch.zeors生成的全為0的矩陣,形狀同理為(batch,1),運用于生成器G的BCELoss中,代入即可得到對抗性損失的右側期望。

we alternate between k steps of optimizing D and one step of optimizing G.

This results in D being maintained near its optimal solution, so long as G changes slowly enough.

3.D與G的訓練有先后順序:判別器D先于生成器G訓練,而且要求先對D訓練k步,再為G訓練1步,這就保證G的訓練比D足夠慢。

如果生成器G足夠強大,那么判別器無法再監測生成器,也就沒有對抗的必要了。相反,如果判別器D太過于強大,那么生成器也訓練地十分緩慢。

在這里插入圖片描述

4.算法圖如上。

訓練代碼

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from Model import generator
from Model import discriminatorimport osif not os.path.exists('gan_train.py'):  # 報錯中間結果os.mkdir('gan_train.py')def to_img(x):  # 將結果的-0.5~0.5變為0~1保存圖片out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outbatch_size = 96
num_epoch = 200
z_dimension = 100# 數據預處理
img_transform = transforms.Compose([transforms.ToTensor(),  # 圖像數據轉換成了張量,并且歸一化到了[0,1]。transforms.Normalize([0.5], [0.5])  # 這一句的實際結果是將[0,1]的張量歸一化到[-1, 1]上。前面的(0.5)均值, 后面(0.5)標準差,
])
# MNIST數據集
mnist = datasets.MNIST(root='./data', train=True, transform=img_transform, download=True)
# 數據集加載器
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)D = discriminator()  # 創建生成器
G = generator()  # 創建判別器
if torch.cuda.is_available():  # 放入GPUD = D.cuda()G = G.cuda()criterion = nn.BCELoss()  # BCELoss 因為可以當成是一個分類任務,如果后面不加Sigmod就用BCEWithLogitsLoss
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)  # 優化器
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)  # 優化器# 開始訓練
for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):  # img[96,1,28,28]G.train()num_img = img.size(0)  # num_img=batchsize# =================train discriminatorimg = img.view(num_img, -1)  # 把圖片拉平,為了輸入判別器 [96,784]real_img = img.cuda()  # 裝進cuda,真實圖片real_label = torch.ones(num_img).reshape(num_img, 1).cuda()  # 希望判別器對real_img輸出為1 [96,1]fake_label = torch.zeros(num_img).reshape(num_img, 1).cuda()  # 希望判別器對fake_img輸出為0  [96,1]# 先訓練鑒別器# 計算真實圖片的lossreal_out = D(real_img)  # 將真實圖片輸入鑒別器 [96,1]d_loss_real = criterion(real_out, real_label)  # 希望real_out越接近1越好 [1]real_scores = real_out  # 后面print用的# 計算生成圖片的lossz = torch.randn(num_img, z_dimension).cuda()  # 創建一個100維度的隨機噪聲作為生成器的輸入 [96,1]#   這個z維度和生成器第一個Linear第一個參數一致# 避免計算G的梯度fake_img = G(z).detach()  # 生成偽造圖片 [96,748]fake_out = D(fake_img)  # 給判別器判斷生成的好不好 [96,1]d_loss_fake = criterion(fake_out, fake_label)  # 希望判別器給fake_out越接近0越好 [1]fake_scores = fake_out  # 后面print用的d_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# 訓練生成器# 計算生成圖片的lossz = torch.randn(num_img, z_dimension).cuda()  # 生成隨機噪聲 [96,100]fake_img = G(z)  # 生成器偽造圖像 [96,784]output = D(fake_img)  # 將偽造圖像給判別器判斷真偽 [96,1]g_loss = criterion(output, real_label)  # 生成器希望判別器給的值越接近1越好 [1]# 更新生成器g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print(f'Epoch [{epoch}/{num_epoch}], d_loss: {d_loss.cpu().detach():.6f}, g_loss: {g_loss.cpu().detach():.6f}',f'D real: {real_scores.cpu().detach().mean():.6f}, D fake: {fake_scores.cpu().detach().mean():.6f}')if epoch == 0:  # 保存圖片real_images = to_img(real_img.detach().cpu())save_image(real_images, './img_gan/real_images.png')fake_images = to_img(fake_img.detach().cpu())save_image(fake_images, f'./img_gan/fake_images-{epoch + 1}.png')G.eval()with torch.no_grad():new_z = torch.randn(batch_size, 100).cuda()test_img = G(new_z)print(test_img.shape)test_img = to_img(test_img.detach().cpu())test_path = f'./test_result/the_{epoch}.png'save_image(test_img, test_path)# 保存模型
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

網絡結構代碼

import torch
from torch import nn# 判別器 判別圖片是不是來自MNIST數據集
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),  # 784=28*28nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid()#   sigmoid輸出這個生成器是或不是原圖片,是二分類)def forward(self, x):x = self.dis(x)return x# 生成器 生成偽造的MNIST數據集
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),  # 輸入為100維的隨機噪聲nn.ReLU(),nn.Linear(256, 256),nn.ReLU(),nn.Linear(256, 784),#   生成器輸出的特征維和正常圖片一樣,這是一個可參考的點nn.Tanh())def forward(self, x):x = self.gen(x)return xclass FinetuneModel(nn.Module):def __init__(self, weights):super(FinetuneModel, self).__init__()self.G = generator()base_weights = torch.load(weights)model_parameters = dict(self.G.named_parameters())#   不是對model進行named_parameters,而是對model里面的具體網絡進行named_parameters取出參數,否則取出的是model冗余的參數去測試pretrained_weights = {k: v for k, v in base_weights.items() if k in model_parameters}new_state_dict = {k: pretrained_weights[k] for k in model_parameters.keys()}self.G.load_state_dict(new_state_dict)def forward(self, input):output = self.G(input)return output

測試代碼

import os
import sys
import numpy as np
import torch
import argparse
import torch.utils.data
from PIL import Image
from Model import FinetuneModel
from Model import generator
from torchvision.utils import save_imageparser = argparse.ArgumentParser("GAN")
parser.add_argument('--save_path', type=str, default='./test_result')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--model', type=str, default='generator.pth')args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)def to_img(x):  # 將結果的-0.5~0.5變為0~1保存圖片out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outdef main():if not torch.cuda.is_available():print("no gpu device available")sys.exit(1)model = FinetuneModel(args.model)model = model.to(device=args.gpu)model.eval()z_dimension = 100with torch.no_grad():for i in range(100):z = torch.randn(96, z_dimension).cuda()  # 創建一個100維度的隨機噪聲作為生成器的輸入 [96,100]output = model(z)print(output.shape)u_name = f'the_{i}.png'print(f'processing {u_name}')u_path = save_path + '/' + u_nameoutput = to_img(output.cpu().detach())save_image(output, u_path)if __name__ == '__main__':main()

本文畢

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/208299.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/208299.shtml
英文地址,請注明出處:http://en.pswp.cn/news/208299.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

軟件壓力測試的重要性與用途

在當今數字化的時代,軟件已經成為幾乎所有行業不可或缺的一部分。隨著軟件應用規模的增加和用戶數量的上升,軟件的性能變得尤為關鍵。為了確保軟件在面對高并發和大負載時仍然能夠保持穩定性和可靠性,軟件壓力測試變得至關重要。下面是軟件壓…

提醒事項日歷同步怎么設置?可實時同步日歷的提醒事項工具

隨著生活節奏的加快,我們每天都需要處理許多瑣碎的事務。為了不忘記重要的事情,很多人選擇使用提醒事項工具來幫助自己。然而,市場上的提醒事項工具五花八門,有些并不具備日歷月視圖功能,也無法與手機日歷同步&#xf…

JavaScript 復雜的<三元運算符和比較操作>的組合--案例(一)

在逆向的時候,碰上有些復雜的js代碼,邏輯弄得人有點混; 因此本帖用來記錄一些棘手的代碼,方便自己記憶,也讓大家拓展認識~ ----前言 內容: function(e, t, n) {try {1 (e "{" e[0] ? JSON.parse(e) : JSON.parse(webInstace.shell(e))).Status || 200 e.Code…

Linux學習筆記7-IIC的應用和AP3216C

接下來進入其他兩種串行通信方式:SPI和I2C的學習,因為以后的項目中會用到這些通信方式,而且正點原子的開發板里面也有用I2C和SPI通信的傳感器來做實例,分別是一個距離傳感器和六軸陀螺儀,這樣就可以很好的通過實例來學…

GRE與順豐圓通快遞盒子

1. DNS污染 隨想: 在輸入一串網址后,會發生如下變化如果你在系統中配置了 Hosts 文件,那么電腦會先查詢 Hosts 文件如果 Hosts 里面沒有這個別名,就通過域名服務器查詢域名服務器回應了,那么你的電腦就可以根據域名服…

第六屆“強網”擬態防御國際精英挑戰賽——入圍戰隊篇

第六屆“強網”擬態防御國際精英挑戰賽即將于2023年12月6日在南京盛大開賽!本屆挑戰賽再次為全球頂尖戰隊提供實戰機會,向多類擬態防御設備系統發起挑戰,在眾測實戰中持續檢驗中國制造內生安全數字產品所具有的中國力量。 本屆挑戰賽參賽戰隊…

【LeetCode:1466. 重新規劃路線 | DFS + 圖 + 樹】

🚀 算法題 🚀 🌲 算法刷題專欄 | 面試必備算法 | 面試高頻算法 🍀 🌲 越難的東西,越要努力堅持,因為它具有很高的價值,算法就是這樣? 🌲 作者簡介:碩風和煒,…

Vue 子路由頁面發消息給主路由頁面 ,實現主頁面顯示子頁面的信息

需求 子頁面進入后,能在主頁面顯示子頁的相關信息,比如說主頁面的菜單激活的是哪個子頁面的菜單項 如上圖,當刷新瀏覽器頁面時,讓菜單的激活項仍保持在【最近瀏覽】。 實現方式: 在子頁面的create事件中增加&#xff…

Java File類詳解(下)練習一

練習 第一題 需求:在當前模塊下的aaa文件夾中創建一個a.txt文件 import java.io.File; import java.io.IOException;public class FileExer01 {public static void main(String[] args) throws IOException {File f1 new File("AllInOne\\aaa");f1.mk…

docker-compose腳本編寫關鍵詞詳解

docker-compose腳本編寫高頻關鍵詞(一) 此處關鍵詞應該必須能靈活運用 關鍵詞 解釋 例子 version 定義使用的docker-compose文件版本。較新的版本支持更豐富的功能和選項。 version: 3.8 services 定義應用程序的各個服務及其配置。每個服務通常…

Vue:繪制圖例

本文記錄使用Vue框架繪制圖例的代碼片段。 可以嵌入到cesium視圖中,也可以直接繪制到自己的原生系統中。 一、繪制圖例Vue組件 <div v-for="(color, index) in colors" :key="index" class="legend-item"><div class="color-…

深度學習還可以從如下方面進行創新!!

文章目錄 一、我認為可以從如下5個方向進行創新總結 一、我認為可以從如下5個方向進行創新 新的模型結構&#xff1a;盡管現在的深度學習模型已經非常強大&#xff0c;但是還有很多未被探索的模型結構。探索新的模型結構可以帶來更好的性能和更低的計算成本。 新的優化算法&a…

JavaScript數組面試題

JavaScript數組面試題 創建一個包含多個元素的數組&#xff0c;并打印輸出數組的內容。 const array ["apple", "banana", "orange"]; console.log(array);如何訪問數組中的特定元素&#xff1f; const array ["apple", "banan…

JS判斷數組中是否包含某個值

方法一&#xff1a;array.indexOf 此方法判斷數組中是否存在某個值&#xff0c;如果存在&#xff0c;則返回數組元素的下標&#xff0c;否則返回-1。 var arr[1,2,3,4]; var indexarr.indexOf(3); console.log(index);方法二&#xff1a;array.includes(searcElement[,fromIn…

一個簡單的postman設置斷言,為何會難住一個工作5年的測試?

postman設置斷言 作為一款接口測試工 具&#xff0c;postman需要對發送請求后返回的結果是否正確做驗證&#xff0c;在postman中通過 tests頁簽做請求的驗證&#xff0c;也稱為斷言。 postman設置斷言的流程 1、在tests頁簽截取要對比的實際響應信息&#xff08;響應頭、響應…

眼花繚亂的ADN/ADX/DSP/DMP/SSP和他們的關系鏈

做過互聯網廣告尤其是程序化廣告的同學都遇到過以下這些名詞&#xff0c;或許正被他們折磨的焦頭爛額&#xff0c;這篇文章&#xff0c;我們就來說說這些概念的含義及他們之間的關系鏈。 ADN&#xff1a;AD Network——廣告網絡或廣告聯盟。連接廣告主和媒體的中間商。 ADX&…

stm32串口編程實例-實現數據的收發功能

大家好&#xff0c;今天給大家介紹stm32串口編程實例&#xff0c;文章末尾附有分享大家一個資料包&#xff0c;差不多150多G。里面學習內容、面經、項目都比較新也比較全&#xff01;可進群免費領取。 串口是USART(通用同步/異步收發器)的俗稱。 實際上&#xff0c;串行總線并不…

2023年8月8日 Go生態洞察:Go 1.21 版本發布探索

&#x1f337;&#x1f341; 博主貓頭虎&#xff08;&#x1f405;&#x1f43e;&#xff09;帶您 Go to New World?&#x1f341; &#x1f984; 博客首頁——&#x1f405;&#x1f43e;貓頭虎的博客&#x1f390; &#x1f433; 《面試題大全專欄》 &#x1f995; 文章圖文…

中小企業都在用哪些開源項目管理工具?分享15款

推薦15個優秀的開源項目管理工具&#xff0c;比如&#xff1a;ProjectLibre、OpenProject、ERPNext、Redmine、禪道、Tuleap、Restyaboard等。 項目經理面臨各種復雜任務&#xff0c;包括追蹤任務的進度、評估交付風險和管理整體工作量。為了順利達成目標&#xff0c;一款靠譜的…

ALLEGRO PCB 如何設置增加的過孔

Allegro添加過孔 1、首先建立焊盤&#xff08;熱風焊盤&#xff09; Via20x10mil(tr30x45x12mil_45) 2、設置過孔的焊盤 Setup-->Constraints&#xff08;約束&#xff09;-->Physical 彈出以下對話框Allegro Constraint Manager 可以通過右鍵點擊PC S&#xff08;…