昇思Mindspore25天學習打卡Day20:DCGAN生成漫畫頭像
- 1 GAN基礎原理
- 2 DCGAN原理
- 3 數據準備與處理
- 數據處理
- 4 構造網絡
- 4.1 生成器
- 4.2 判別器
- 5 模型訓練
- 損失函數
- 優化器
- 訓練模型
- 6 結果展示
- 7 訓練結束打上標簽和時間
在下面的教程中,我們將通過示例代碼說明DCGAN網絡如何設置網絡、優化器、如何計算損失函數以及如何初始化模型權重。在本教程中,使用的動漫頭像數據集共有70,171張動漫頭像圖片,圖片大小均為96*96.
1 GAN基礎原理
這部分原理介紹參考Link:GAN圖像生成
2 DCGAN原理
DCGAN(深度卷積對抗生成網絡,Deep Convolutional Generative Adversarial Networks)是GAN的直接擴展。不同之處在于,DCGAN會分別在判別器和生成器中使用卷積和轉置卷積層。
它最早由Radford等人在論文Link:Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中進行描述。判別器由分層的卷積層、BatchNorm層和LeakReLU激活層組成。輸入是3x64x64的圖像,輸出是該圖像為真圖像的概率。生成器則是由轉置卷積層、BatchNorm層和ReLU激活層組成。輸入是標準正態分布中提取出的隱向量z,輸出是3x64x64的RGB圖像。
本教程將使用動漫頭像數據集來訓練一個生成式對抗網絡,接著使用該網絡生成動漫頭像圖片。
3 數據準備與處理
首先我們將數據集下載到指定目錄下并解壓。示例代碼如下:
數據處理
定義 create_dataset_imagenet 函數對數據進行處理和增強操作.
通過 create dict iterator 函數將數據轉換成字典迭代器,然后使用 matplotlib 模塊可視化部分訓練數據。
4 構造網絡
當處理完數據后,就可以來進行網絡的搭建了。按照DCGAN論文中的描述,所有模型權重均應從 mean為0,sigma為0.02的正態分布中隨機初始化。
4.1 生成器
生成器 G的功能是將隱向量z 映射到數據空間。由于數據是圖像,這一過程也會創建與真實圖像大小相同的 RG8 圖像。在實踐場景中,該功能是通過一系列
Conv2dTranspose 轉置卷積層來完成的,每個層都與 BatchNorm2d 層和 ReLu 激活層配對,輸出數據會經過 tanh 函數,使其返回[-1,1]的數據范圍內。
DCGAN論文生成圖像如下所示:
- 圖片來源: Liuk:Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks
我們通過輸入部分中設置的 nz、ngf 和 nc 來影響代碼中的生成器結構。 nz 是隱向量 z的長度,
ngf 與通過生成器傳播的特征圖的大小有關,nc 是輸出圖像中的通道數。
以下是生成器的代碼實現:
4.2 判別器
如前所述,判別器D是一個二分類網絡模型,輸出判定該圖像為真實圖的概率。通過一系列的 conv2d、
BatchNorm2d 和 LeakyReLu 層對其進行處理,最后通過 Sigmoid 激活函數得到最終概率。
DCGAN論文提到,使用卷積而不是通過池化來進行下采樣是一個好方法,因為它可以讓網絡學習自己的池化特征。
判別器的代碼實現如下:
5 模型訓練
損失函數
當定義了 D和G后,接下來將使用Mindspore中定義的二進制交叉熵損失函數Link:BCELoss
優化器
這里設置了兩個單獨的優化器,一個用于D,另一個用于G。這兩個都是 1r=0.8002 和 beta1 = 0.5 的Adam優化器。
訓練模型
訓練分為兩個主要部分:訓練判別器和訓練生成器.
- 訓練判別器
訓練判別器的目的是最大程度地提高判別圖像真偽的概率。按照Goodfelow的方法,是希望通過提高其隨機梯度來更新判別器,所以我們要最大化
l o g D ( x ) + l o g ( 1 ? D ( G ( z ) ) logD(x)+log(1-D(G(z)) logD(x)+log(1?D(G(z))的值, - 訓練生成器
如DCGAN論文所述,我們希望通過最小化 l o g ( 1 ? D ( G ( z ) ) ) log(1- D(G(z))) log(1?D(G(z)))來訓練生成器,以產生更好的虛假圖像,。
在這兩個部分中,分別獲取訓練過程中的損失,并在每個周期結束時進行統計,將 fixed_noise 批量推送到生成器中,以直觀地跟蹤 G的訓練進度.
下面實現模型訓練正向邏輯:
循環訓練網絡,每經過50次迭代,就收集生成器和判別器的損失,以便于后面繪制訓練過程中損失函數的圖像。
6 結果展示
運行下面代碼,描繪 D和G損失與訓練迭代的關系圖
可視化訓練過程中通過隱向量 fixed noise 生成的圖像。
從上面的圖像可以看出,隨著訓練次數的增多,圖像質量也越來越好。如果增大訓練周期數,當 num_epochs 達到50以上時,生成的動漫頭像圖片與數據集中的較為相似,下面我們通過加載生成器網絡模型參數文件來生成圖像,代碼如下: