PaperNotes(6)-GAN/DCGAN/WGAN/WGAN-GP/WGAN-SN-網絡結構/實驗效果

GAN模型網絡結構+實驗效果演化

  • 1.GAN
    • 1.1網絡結構
    • 1.2實驗結果
  • 2.DCGAN
    • 2.1網絡結構
    • 2.2實驗結果
  • 3.WGAN
    • 3.1網絡結構
    • 3.2實驗結果
  • 4.WGAN-GP
    • 4.1網絡結構
    • 4.2實驗結果
  • 5.WGAN-SN
    • 5.1網絡結構
    • 5.2實驗結果
  • 小結

1.GAN

文章: https://arxiv.org/pdf/1406.2661.pdf
代碼: Pylearn2, theano, https://github.com/goodfeli/adversarial

1.1網絡結構

多層感知機器(沒有在文章中找到)
G: ReLU, sigmoid
D:maxout, dropout

1.2實驗結果

1.數據集:MNIST,the Toronto Face Database (TFD) , CIFAR-10

2.Gaussian Parzen window 擬合樣本,輸出對應的log-likelihood.

3.直接展示了在三個圖像集合上的效果,最右遍一列顯示的是與第二列最相似的訓練樣本(具體如何衡量相近,需要查論文)

在這里插入圖片描述
a) MNIST,b) TFD, c) CIFAR-10 (fully connected model), d) CIFAR-10 (convolutional discriminatorand “deconvolutional” generator)

訓練次數呢?
這時候的cifar數據集基本不能看·

2.DCGAN

文章:https://arxiv.org/pdf/1511.06434.pdf
代碼:https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html,pytorch 官網DCAGAN教程,示例是人臉圖像生成

2.1網絡結構

P3:網絡結構表
1.去除所有的poling層
2.D,G中都使用batchnorm
3.移除全聯接結構
4.G激活函數:ReLU+Tanh(最后一層)
5.D激活函數:LeakyReLU(所有層)

2.2實驗結果

Lsun–視覺效果,300萬張圖像
Cifar10-分類實驗
人臉加減法實驗

3.WGAN

文章:https://arxiv.org/pdf/1701.07875.pdf
代碼:https://github.com/martinarjovsky/WassersteinGAN,作者github 提供的代碼,pytorch

3.1網絡結構

p9:以DCGAN為baseline, baseline 損失使用-logD 技巧
lipschitz約束實現:clip D網絡參數

3.2實驗結果

Lsun-bedromm 穩定性視覺實驗

WGAN本身是為了提高GAN模型訓練的穩定性而生的。文章強調的兩個優點啊:有意義的loss+穩定訓練過程。同一作者的后續文章(improved Training of Wasserstein GANs) 圖3,展示了clip 版本WGAN IS指標確實比不上DCGAN。
在這里插入圖片描述

4.WGAN-GP

文章:https://arxiv.org/pdf/1704.00028.pdf
代碼:https://github.com/igul222/improved_wgan_training,作者github 提供的代碼,tensorflow(明明是同一個作者寫的平臺還不一樣)
自己復現代碼時使用的是:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations WGAN_GP中寫的GP方法。

4.1網絡結構

1.G網絡帶BN,D網絡不實用Batch normalization, 轉而使用 layer Normalization
2.clip 會使優化變得困難,懲罰D網絡的梯度,使其不至于太大
3.在cifar-10 數據集合上D和G都使用resnet 結構

4.2實驗結果

在這里插入圖片描述
WGAN-GP能看出來一個輪廓,算是比較好的一個視覺效果了。

本人用pytroch 復現WGAN_GP ,參考了作者梯度懲罰的源碼(https://github.com/igul222/improved_wgan_training.)主體代碼是在WGAN的基礎上(https://github.com/martinarjovsky/WassersteinGAN),注釋了CLIP部分的代碼,在D損失函數的計算上增加了梯度懲罰項目(計算方式參考了網上的實現博文)。雖然生成的圖像視覺指標輪廓不錯,但是IS曲線(與baseline-WGAN 相比)并沒特別的優勢。

現在WGAN實現的時候,D網絡的更新次數在100/5之間切換,直接換成5 試一試

倉庫:https://github.com/caogang/wgan-gp (1000star)是pytorch復現的WGAN_GP具體效果沒有考察。

// 梯度懲罰的計算法函數
def compute_gradient_penalty(D, real_samples, fake_samples):"""Calculates the gradient penalty loss for WGAN GP"""# Random weight term for interpolation between real and fake samplesalpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).cuda()# Get random interpolation between real and fake samples# print(real_samples.size(),fake_samples.size())# interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)interpolates = (alpha * fake_samples + ((1 - alpha) * real_samples)).requires_grad_(True)d_interpolates = D(interpolates)# d_interpolates = d_interpolates.resize(d_interpolates.size()[0],1)fake = Variable(torch.Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False).cuda()# Get gradient w.r.t. interpolatesgradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True,)[0]gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
...
// 判別器的損失函數的計算
gradient_penalty = utils.compute_gradient_penalty(netD,inputv_real,inputv_fake)
gradient_penalty *= lambda_gp
gradient_penalty.backward()
errD = errD_real - errD_fake - gradient_penalty
optimizerD.step()
d_iterations += 1

5.WGAN-SN

文章:https://arxiv.org/pdf/1802.05957.pdf
代碼:https://pytorch.org/docs/stable/generated/torch.nn.utils.spectral_norm.html?highlight=nn%20utils%20spectra#torch.nn.utils.spectral_norm,pytorch 官網上實現了D網絡參數譜正則化的代碼,直接在定義層的時候調用就可以了。

5.1網絡結構

D卷積結構,沒有BN
G卷積結構+BN
(沒有LN的情況)

5.2實驗結果

在這里插入圖片描述
cifar-10 上的結構,雖然也只能是看一個大致輪廓,但是,效果還是比較好的。

小結

1.BN 在mini-batch 較小或者RNN等動態網絡里效果不好,因為少量樣本的均值和方差無法反應整體的情況。BN強調了mini-batch 樣本之間的聯系。D網絡本身是將一個輸入映射到一個得分輸出,不應該考慮樣本之間的聯系,所以不應該使用BN,在WGAN-GP中轉而使用layer-normalization,對同一個樣本的各個通道做歸一化。

2.網絡越深,其生成能力越強,WGAN-GP論文中cifar-10 IS可以達到7左右,WGAN-Sn中也可以達到6.41,都是因為網絡結構不同,所以在淺層只有卷積的G(DCGAN,WGAN)想要達到那么高的IS一般是不可能的。

3.嘗試人臉生成數據集合the Toronto Face Database (TFD)

4.整理一下各個實驗的G訓練次數。

5.stack gan 的網絡結構基本還行。在做GP實驗的時候,至少得吧BN該成LN,再看看SN中是如何做的。


涉及WGAN的論文總共三篇:

WGAN前作:Towards Principled Methods for Training Generative Adversarial Networks
論文鏈接:https://arxiv.org/abs/1701.04862

WGAN:Wasserstein GAN
論文鏈接:https://arxiv.org/abs/1701.07875

WGAN后作:Improved Training of Wasserstein GANs
論文鏈接:https://arxiv.org/abs/1704.00028v3
都是神人Ishaan Gulrajani 寫的,連GAN之父Ian Goodfellow都十分驚嘆WGAN的改進內容。

神員各種類型GAN代碼實現(TensorFlow框架):https://github.com/LynnHo/AttGAN-Tensorflow

這三篇論文理論性都比較強,尤其是第一篇,涉及到比較多的理論公式推導。知乎鄭華濱的兩個論述,Wasserstein GAN最新進展:從weight clipping到gradient penalty,更加先進的Lipschitz限制手法在理論方面已經做了一個很好的介紹。不過對于很多數學不太好的同學(包括我自己),看著還是不太好理解,所以這里盡量站在做工程的角度,理一下這三篇文章的思路,這樣可以對作者的思路有一個比較清晰的理解。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/444918.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/444918.shtml
英文地址,請注明出處:http://en.pswp.cn/news/444918.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Security使用

Spring Security 在web應用開發中,安全無疑是十分重要的,選擇Spring Security來保護web應用是一個非常好的選擇。 Spring Security 是spring項目之中的一個安全模塊,可以非常方便與spring項目無縫集成。特別是在spring boot項目中加入sprin…

nginx python webpy 配置安裝

安裝webpy$ wget http://webpy.org/static/web.py-0.34.tar.gz$ tar xvzf web.py-0.34.tar.gz$ cd web.py-0.34$ sudo python setup.py install安裝 Fluphttp://www.saddi.com/software/flup/dist/flup-1.0.2.tar.gz$ wget http://www.saddi.com/software/flup/dist/flup-1.0.2…

PaperNotes(7)-GANs模式坍塌/訓練不穩定

GANs-模式坍塌-訓練不穩定1.訓練不穩定問題相關文章1.1 DCGAN1.2Big-GAN1.3WGAN 、WGAN-GP、SN-WGAN1.4其他工作2.模式坍塌問題相關文章2.1 MAD-GAN2.2 Unrolled GAN2.3 DRAGAN2.4 D2GAN2.5 InfoGAN2.6 Deligan2.7 EBGAN2.8 Maximum Entropy Generators for Energy-Based Model…

thinkphp框架起步認識

先看看thinkphp的文檔吧:這是我在網上找的一個不錯的鏈接地址,對自己有用,同時相信對讀者也有用吧。 http://doc.thinkphp.cn/manual/class.html ThinkPHP 跨模塊調用操作方法(A方法與R方法) 跨模塊調用操作方法 前面說…

leetcode403 青蛙過河

一只青蛙想要過河。 假定河流被等分為 x 個單元格,并且在每一個單元格內都有可能放有一石子(也有可能沒有)。 青蛙可以跳上石頭,但是不可以跳入水中。 給定石子的位置列表(用單元格序號升序表示)&#xff…

PaperNotes(8)-Stein Variational Gradient Descent A General Purpose Bayesian Inference Algorithm

通用貝葉斯推理算法-Stein Variational Gradient DescentAbstract1 Introduction2 Background3 Variational Inference Using Smooth Transforms3.1 Stein Operator as the Derivative of KL Divergence定理3.1引理3.23.2 Stein Variational Gradient Descent4 Related Works5 …

thinkphp的增刪改查

ThinkPHP 添加數據 add 方法 ThinkPHP 內置的 add 方法用于向數據表添加數據,相當于 SQL 中的 INSERT INTO 行為。ThinkPHP Insert 添加數據添加數據 add 方法是 CURD(Create,Update,Read,Delete / 創建,修改,讀取,刪除)中的 Create 的實現&a…

leetcode115 不同的子序列

給定一個字符串 S 和一個字符串 T,計算在 S 的子序列中 T 出現的個數。 一個字符串的一個子序列是指,通過刪除一些(也可以不刪除)字符且不干擾剩余字符相對位置所組成的新字符串。(例如,"ACE" 是…

ThinkPHP 模板循環輸出 Volist 標簽

volist 標簽用于在模板中循環輸出數據集或者多維數組。volist 標簽在模塊操作中&#xff0c;select() 方法返回的是一個二維數組&#xff0c;可以用 volist 直接輸出&#xff1a;<volist name"list" id"vo"> 用 戶 名&#xff1a;{$vo[username]}&l…

MachineLearning(9)-最大似然、最小KL散度、交叉熵損失函數三者的關系

最大似然-最小KL散度-最小化交叉熵損失-三者的關系問題緣起&#xff1a;給定一組數據(x1,x2,...,xm)(x^1,x^2,...,x^m)(x1,x2,...,xm),希望找到這組數據服從的分布。此種情況下&#xff0c;分布規律用概率密度p(x)表征。 問題歸處&#xff1a;如果能夠建模/近似建模p(x)&#…

ThinkPHP redirect 頁面重定向使用詳解與實例

ThinkPHP redirect 方法ThinkPHP redirect 方法可以實現頁面的重定向&#xff08;跳轉&#xff09;功能。redirect 方法語法如下&#xff1a;$this->redirect(string url, array params, int delay, string msg) 參數說明&#xff1a;url 必須&#xff0c;重定向的 URL 表達…

PaperNotes(9)-Learning deep energy model: contrastive divergence vs. Amortized MLE

Learning deep energy model: contrastive divergence vs. Amortized MLEabstract1 Introduction2 Background2.1 stein variational gradient descent2.2 learning energy model**contrastive Divergence**abstract 受SVGD算法的啟發,本文提出兩個算法用于從數據中學習深度能…

windows下的gvim配置

首要任務是下載安裝Gvim7.3 。 安裝完后&#xff0c;gvim菜單中文出現亂碼&#xff0c;在_vimrcset文件中增加&#xff1a; " 配置多語言環境,解決中文亂碼問題 if has("multi_byte") " UTF-8 編碼 set encodingutf-8 set termencodingutf…

leetcode104 二叉樹的最大深度

給定一個二叉樹&#xff0c;找出其最大深度。 二叉樹的深度為根節點到最遠葉子節點的最長路徑上的節點數。 說明: 葉子節點是指沒有子節點的節點。 示例&#xff1a; 給定二叉樹 [3,9,20,null,null,15,7]&#xff0c; 3 / \ 9 20 / \ 15 7 返回它的最大深度…

C++的安全類型轉換的討論

關于強制類型轉換的問題,很多書都討論過,寫的最詳細的是C++ 之父的《C++的設計和演化》。最好的解決方法就是不要使用C風格的強制類型轉換,而是使用標準C++的類型轉換符:static_cast, dynamic_cast。標準C++中有四個類型轉換符:static_cast、dynamic_cast、reinterpret_ca…

PaperNotes(10)-Maximum Entropy Generators for Energy-Based Models

Maximum Entropy Generators for Energy-Based ModelsAbstract1 Introduction2 Background3 Maximum Entropy Generators for Energy-Based Models4 Experiments5 Related Work6 Conclusion7 AcknowledgementsAbstract 由于對數似然梯度的難以計算&#xff0c;能量模型的最大似…

leetcode105 前序中序遍歷序列構造二叉樹

根據一棵樹的前序遍歷與中序遍歷構造二叉樹。 注意: 你可以假設樹中沒有重復的元素。 例如&#xff0c;給出 前序遍歷 preorder [3,9,20,15,7] 中序遍歷 inorder [9,3,15,20,7] 返回如下的二叉樹&#xff1a; 3 / \ 9 20 / \ 15 7 思路&#xff1a; 1、…

c++的虛擬繼承 的一些思考吧

虛擬繼承是多重繼承中特有的概念。虛擬基類是為解決多重繼承而出現的。如:類D繼承自類B1、B2,而類B1、B2都繼承自類A,因此在類D中兩次出現類A中的變量和函數。為了節省內存空間,可以將B1、B2對A的繼承定義為虛擬繼承,而A就成了虛擬基類。實現的代碼如下: class A class …

對于linux socket與epoll配合相關的一些心得記錄

對于linux socket與epoll配合相關的一些心得記錄 沒有多少高深的東西&#xff0c;全當記錄&#xff0c;雖然簡單&#xff0c;但是沒有做過測試還是挺容易讓人糊涂的int nRecvBuf32*1024;//設置為32Ksetsockopt(s,SOL_SOCKET,SO_RCVBUF,(const char*)&nRecvBuf,sizeof(int))…

leetcode144 二叉樹的前序遍歷

給定一個二叉樹&#xff0c;返回它的 前序 遍歷。 示例: 輸入: [1,null,2,3] 1 \ 2 / 3 輸出: [1,2,3] 進階: 遞歸算法很簡單&#xff0c;你可以通過迭代算法完成嗎&#xff1f; 思路&#xff1a;模仿遞歸的思路壓棧即可。 /*** Definition for a bi…