文章目錄
- 昇思MindSpore應用實踐
- 基于MindSpore的DCGAN生成漫畫頭像
- 1、DCGAN 概述
- 零和博弈 vs 極大極小博弈
- GAN的生成對抗損失
- DCGAN原理
- 2、數據預處理
- 3、DCGAN模型構建
- 生成器部分
- 判別器部分
- 4、模型訓練
- Reference
昇思MindSpore應用實踐
本系列文章主要用于記錄昇思25天學習打卡營的學習心得。
基于MindSpore的DCGAN生成漫畫頭像
1、DCGAN 概述
這部分原理介紹參考昇思官方文檔GAN圖像生成和昇思25天學習打卡營第5天_GAN圖像生成
生成對抗網絡簡介:
零和博弈 vs 極大極小博弈
生成對抗網絡Generative adversarial networks (GANs)主要包括生成器網絡(Generator)和判別器網絡(Discriminator)
這兩個網絡在GAN的訓練過程中相互競爭,形成了一種博弈論中的極大極小博弈(MinMax game)
零和博弈(Zero-sum game)是博弈論中的一個重要概念,指的是參與者的利益完全相反,即一方的利益的增加意味著另一方的利益的減少,總利益為零。在零和博弈中,參與者之間的利益是完全對立的,因此一個參與者的利益的增加必然導致其他參與者的利益減少。在非合作博弈中,納什均衡是一種重要的解,納什均衡代表每個玩家選擇的策略都是其在對方策略給定的情況下的最優策略。在零和博弈中,尋找納什均衡通常涉及找到使每個玩家的預期收益最大化的策略組合。
極大極小博弈(MinMax game)是一種博弈論中的解決方法,用于確定參與者的最佳決策策略,此外為人所熟知用于決策的方法還有強化學習。在極大極小博弈中,每個參與者都試圖最大化自己的最小收益。也就是說,每個參與者都采取行動,以確保在對手選擇其最優策略時自己的收益最大化。
假設GAN網絡訓練達到了納什平衡狀態,那么判別器無法準確地判斷出輸入樣本是真樣本還是假樣本,此時判別器失效,生成器達到了巔峰狀態,我們就無需使用判別器并終止訓練了,得到的生成器就是我們用來生成數據的預訓練模型。
從理論上講,此博弈游戲的平衡點是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG?(x;θ)=pdata?(x),此時判別器會隨機猜測輸入是真圖像還是假圖像。下面我們簡要說明生成器和判別器的博弈過程:
- 在訓練剛開始的時候,生成器和判別器的質量都比較差,生成器會隨機生成一個數據分布;
- 判別器通過求取梯度和損失函數對網絡進行優化,將接近真實數據分布的數據判定為1( D ( x ) = 1 D(x)=1 D(x)=1),將接近生成器生成數據分布的數據判定為0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ? G max ? D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) Gmin?Dmax?V(G,D);
- 生成器通過優化,生成出更加貼近真實數據分布的數據;
- 生成器所生成的數據和真實數據達到相同的分布,此時判別器的輸出為1/2,如上圖中的(d)所示。
GAN的生成對抗損失
min ? G max ? D V ( G , D ) = E x ~ p data ( x ) [ log ? D ( x ) ] + E z ~ p z ( z ) [ log ? ( 1 ? D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] Gmin?Dmax?V(G,D)=Ex~pdata(x)?[logD(x)]+Ez~pz?(z)?[log(1?D(G(z)))]
GAN網絡本身就是在訓練一個能達到平衡狀態的損失函數,生成對抗損失是GANs中最基本的損失函數。
當生成對抗損失達到納什均衡時,判別器對真假數據的判別概率都是0.5,即 D ( x ) = 1 ? G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1?G(z)=0.5,
即 l o g ( D ( x ) ) = l o g ( 1 ? G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1?G(z))≈0.693
由于數據x和G(z)不僅是一張圖片,再分別取兩者的均值 E \mathbb{E} E,相加,就得到了生成對抗損失。
近十年來著名的GAN網絡結構:
DCGAN原理
如上圖所示,DCGAN(深度卷積對抗生成網絡,Deep Convolutional Generative Adversarial Networks)是GAN的直接擴展。
不同之處在于,DCGAN會分別在判別器和生成器中使用卷積和轉置卷積層。
它最早由Radford等人在論文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中進行描述。判別器由分層的卷積層、BatchNorm層和LeakyReLU激活層組成。輸入是3x64x64的圖像,輸出是該圖像為真圖像的概率。生成器則是由轉置卷積層、BatchNorm層和ReLU激活層組成。輸入是標準正態分布中提取出的隱向量 z z z,輸出是3x64x64的RGB圖像。
本教程將使用動漫頭像數據集來訓練一個生成式對抗網絡,接著使用該網絡生成動漫頭像圖片。
2、數據預處理
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""數據加載"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 數據增強操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 數據映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')# 通過create_dict_iterator函數將數據轉換成字典迭代器,然后使用matplotlib模塊可視化部分訓練數據。import matplotlib.pyplot as pltdef plot_data(data):# 可視化部分訓練數據plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
3、DCGAN模型構建
生成器部分
生成器G
的功能是將隱向量z
映射到數據空間。由于數據是圖像,這一過程也會創建與真實圖像大小相同的 RGB 圖像。在實踐場景中,該功能是通過一系列Conv2dTranspose
轉置卷積層來完成的,每個層都與BatchNorm2d
層和ReLu
激活層配對,輸出數據會經過tanh
函數,使其返回[-1,1]
的數據范圍內。
DCGAN生成器生成圖像的大致流程如下:
1、將一個1x100的高斯潛在噪聲向量投影變換為一個4x4x1024的特征圖;
2、在經過CONV1卷積輸出為8x8x512的特征圖;
3、逐步增大分辨率,縮小通道數,經過CONV2卷積輸出為16x16x256的特征圖;
4、經過CONV3卷積輸出為32x32x128的特征圖;
5、最后經過CONV4卷積輸出為64x64x3的生成圖像,與真實圖像一起送入判別器進行鑒定;
6、在訓練過程中盡可能地生成逼近真實圖像分布的效果從而欺騙判別器,令其失效,這樣生成對抗就達到了平衡狀態,生成器的訓練過程完畢,拿去用作模型推理。
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN網絡生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判別器部分
class Discriminator(nn.Cell):"""DCGAN網絡判別器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
4、模型訓練
# 定義損失函數
adversarial_loss = nn.BCELoss(reduction='mean')# 為生成器和判別器設置優化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')# 定義訓練時要用到的功能函數
def generator_forward(real_imgs, valid):# 將噪聲采樣為發生器的輸入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批圖像gen_imgs = generator(z)# 損失衡量發生器繞過判別器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鑒別器從生成的樣本中對真實樣本進行分類的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 為每輪訓練讀入數據for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 輸出訓練記錄print('[%2d/%d][%3d/%d] Loss_D:%7.4f Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每個epoch結束后,使用生成器生成一組圖片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存網絡模型參數為ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
cpu訓練5個epoch的訓練效果:
可以明顯看出Loss_D和Loss_G的分數并沒有達到0.5:0.5的納什平衡狀態,生成圖像自然是很可怕的抽象二次元漫畫頭像,這里忘了截圖了就不放效果了。
申請了Ascend910 NPU的算力,訓練50輪效果:
910太快了啊,吃頓飯回來就跑完了,不過結果還是蚌埠住了…
還是很糊,練崩了,今天先到這里了,先打次卡,有時間再調整一下網絡結構試試,DCGAN可能對Anime數據集來說還是太簡單了,不太好控制的樣子。
兩個網絡訓練的log:
Reference
昇思大模型平臺
什么是GAN生成對抗網絡,使用DCGAN生成動漫頭像