昇思25天學習打卡營第8天|DCGAN生成漫畫頭像

文章目錄

      • 昇思MindSpore應用實踐
        • 基于MindSpore的DCGAN生成漫畫頭像
          • 1、DCGAN 概述
            • 零和博弈 vs 極大極小博弈
            • GAN的生成對抗損失
            • DCGAN原理
          • 2、數據預處理
          • 3、DCGAN模型構建
            • 生成器部分
            • 判別器部分
          • 4、模型訓練
      • Reference

昇思MindSpore應用實踐

本系列文章主要用于記錄昇思25天學習打卡營的學習心得。

基于MindSpore的DCGAN生成漫畫頭像
1、DCGAN 概述

這部分原理介紹參考昇思官方文檔GAN圖像生成和昇思25天學習打卡營第5天_GAN圖像生成

生成對抗網絡簡介:

零和博弈 vs 極大極小博弈

生成對抗網絡Generative adversarial networks (GANs)主要包括生成器網絡(Generator)和判別器網絡(Discriminator)
這兩個網絡在GAN的訓練過程中相互競爭,形成了一種博弈論中的極大極小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈論中的一個重要概念,指的是參與者的利益完全相反,即一方的利益的增加意味著另一方的利益的減少,總利益為零。在零和博弈中,參與者之間的利益是完全對立的,因此一個參與者的利益的增加必然導致其他參與者的利益減少。在非合作博弈中,納什均衡是一種重要的解,納什均衡代表每個玩家選擇的策略都是其在對方策略給定的情況下的最優策略。在零和博弈中,尋找納什均衡通常涉及找到使每個玩家的預期收益最大化的策略組合。

極大極小博弈(MinMax game)是一種博弈論中的解決方法,用于確定參與者的最佳決策策略,此外為人所熟知用于決策的方法還有強化學習。在極大極小博弈中,每個參與者都試圖最大化自己的最小收益。也就是說,每個參與者都采取行動,以確保在對手選擇其最優策略時自己的收益最大化。

假設GAN網絡訓練達到了納什平衡狀態,那么判別器無法準確地判斷出輸入樣本是真樣本還是假樣本,此時判別器失效,生成器達到了巔峰狀態,我們就無需使用判別器并終止訓練了,得到的生成器就是我們用來生成數據的預訓練模型。

在這里插入圖片描述
從理論上講,此博弈游戲的平衡點是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG?(x;θ)=pdata?(x),此時判別器會隨機猜測輸入是真圖像還是假圖像。下面我們簡要說明生成器和判別器的博弈過程:

  1. 在訓練剛開始的時候,生成器和判別器的質量都比較差,生成器會隨機生成一個數據分布;
  2. 判別器通過求取梯度和損失函數對網絡進行優化,將接近真實數據分布的數據判定為1 D ( x ) = 1 D(x)=1 D(x)=1),將接近生成器生成數據分布數據判定為0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ? G max ? D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) Gmin?Dmax?V(G,D)
  3. 生成器通過優化,生成出更加貼近真實數據分布的數據;
  4. 生成器所生成的數據和真實數據達到相同的分布,此時判別器的輸出為1/2,如上圖中的(d)所示。
GAN的生成對抗損失

min ? G max ? D V ( G , D ) = E x ~ p data ( x ) [ log ? D ( x ) ] + E z ~ p z ( z ) [ log ? ( 1 ? D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] Gmin?Dmax?V(G,D)=Expdata(x)?[logD(x)]+Ezpz?(z)?[log(1?D(G(z)))]

GAN網絡本身就是在訓練一個能達到平衡狀態的損失函數,生成對抗損失是GANs中最基本的損失函數。

當生成對抗損失達到納什均衡時,判別器對真假數據的判別概率都是0.5,即 D ( x ) = 1 ? G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1?G(z)=0.5

l o g ( D ( x ) ) = l o g ( 1 ? G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1?G(z))0.693

由于數據x和G(z)不僅是一張圖片,再分別取兩者的均值 E \mathbb{E} E,相加,就得到了生成對抗損失。

近十年來著名的GAN網絡結構:
在這里插入圖片描述

DCGAN原理

如上圖所示,DCGAN(深度卷積對抗生成網絡,Deep Convolutional Generative Adversarial Networks)是GAN的直接擴展。
不同之處在于,DCGAN會分別在判別器和生成器中使用卷積和轉置卷積層

它最早由Radford等人在論文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中進行描述。判別器由分層的卷積層、BatchNorm層和LeakyReLU激活層組成。輸入是3x64x64的圖像,輸出是該圖像為真圖像的概率。生成器則是由轉置卷積層、BatchNorm層和ReLU激活層組成。輸入是標準正態分布中提取出的隱向量 z z z,輸出是3x64x64的RGB圖像。

本教程將使用動漫頭像數據集來訓練一個生成式對抗網絡,接著使用該網絡生成動漫頭像圖片。

2、數據預處理
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""數據加載"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 數據增強操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 數據映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')# 通過create_dict_iterator函數將數據轉換成字典迭代器,然后使用matplotlib模塊可視化部分訓練數據。import matplotlib.pyplot as pltdef plot_data(data):# 可視化部分訓練數據plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

在這里插入圖片描述

3、DCGAN模型構建
生成器部分

生成器G的功能是將隱向量z映射到數據空間。由于數據是圖像,這一過程也會創建與真實圖像大小相同的 RGB 圖像。在實踐場景中,該功能是通過一系列Conv2dTranspose轉置卷積層來完成的,每個層都與BatchNorm2d層和ReLu激活層配對,輸出數據會經過tanh函數,使其返回[-1,1]的數據范圍內。

DCGAN生成器生成圖像的大致流程如下:

1、將一個1x100的高斯潛在噪聲向量投影變換為一個4x4x1024的特征圖;
2、在經過CONV1卷積輸出為8x8x512的特征圖;
3、逐步增大分辨率,縮小通道數,經過CONV2卷積輸出為16x16x256的特征圖;
4、經過CONV3卷積輸出為32x32x128的特征圖;
5、最后經過CONV4卷積輸出為64x64x3的生成圖像,與真實圖像一起送入判別器進行鑒定;
6、在訓練過程中盡可能地生成逼近真實圖像分布的效果從而欺騙判別器,令其失效,這樣生成對抗就達到了平衡狀態,生成器的訓練過程完畢,拿去用作模型推理。
在這里插入圖片描述

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN網絡生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判別器部分

在這里插入圖片描述

class Discriminator(nn.Cell):"""DCGAN網絡判別器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
4、模型訓練
# 定義損失函數
adversarial_loss = nn.BCELoss(reduction='mean')# 為生成器和判別器設置優化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')# 定義訓練時要用到的功能函數
def generator_forward(real_imgs, valid):# 將噪聲采樣為發生器的輸入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批圖像gen_imgs = generator(z)# 損失衡量發生器繞過判別器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鑒別器從生成的樣本中對真實樣本進行分類的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 為每輪訓練讀入數據for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 輸出訓練記錄print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每個epoch結束后,使用生成器生成一組圖片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存網絡模型參數為ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

cpu訓練5個epoch的訓練效果:
在這里插入圖片描述
可以明顯看出Loss_D和Loss_G的分數并沒有達到0.5:0.5的納什平衡狀態,生成圖像自然是很可怕的抽象二次元漫畫頭像,這里忘了截圖了就不放效果了。

申請了Ascend910 NPU的算力,訓練50輪效果:

910太快了啊,吃頓飯回來就跑完了,不過結果還是蚌埠住了…
在這里插入圖片描述
還是很糊,練崩了,今天先到這里了,先打次卡,有時間再調整一下網絡結構試試,DCGAN可能對Anime數據集來說還是太簡單了,不太好控制的樣子。
在這里插入圖片描述
兩個網絡訓練的log:

在這里插入圖片描述

Reference

昇思大模型平臺
什么是GAN生成對抗網絡,使用DCGAN生成動漫頭像

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/38894.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/38894.shtml
英文地址,請注明出處:http://en.pswp.cn/web/38894.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

機器學習基礎概念

1.機器學習定義 2.機器學習工作流程 (1)數據集 ①一行數據:一個樣本 ②一列數據:一個特征 ③目標值(標簽值):有些數據集有目標值,有些數據集沒有。因此數據類型由特征值目標值構成或…

Java實現圖書管理系統

一、框架 1. 創建類 用戶:管理員AdminUser 普通用戶NormalUser 繼承抽象類User 書:書Book 書架BookList 操作對象:書Book 2. 知識點 主要涉及的知識點:數據類型 變量 if for 數組 方法 類和對象 封裝繼承多態 抽象類和接口 …

Linux運維之需掌握的基本Linux命令

前言:本博客僅作記錄學習使用,部分圖片出自網絡,如有侵犯您的權益,請聯系刪除 目錄 一、SHELL 二、執行命令 三、常用系統工作命令 四、系統狀態檢測命令 五、查找定位文件命令 六、文本文件編輯命令 七、文件目錄管理命令…

【JavaWeb】登錄校驗-會話技術(一)Cookie與Session

登錄校驗 實現登陸后才能訪問后端系統頁面,不登陸則跳轉登陸頁面進行登陸。 首先我們在宏觀上先有一個認知: HTTP協議是無狀態協議。即每一次請求都是獨立的,下一次請求并不會攜帶上一次請求的數據。 因此當我們通過瀏覽器訪問登錄后&#…

go語言怎么獲取文件的大小并且轉化為kb為單位呢?

在Go語言中,你可以使用os包中的IsExist和Stat函數來獲取文件的信息,包括文件的大小。文件的大小通常是以字節為單位的,但你可以很容易地將其轉換為KB(千字節)。 下面是一個簡單的Go程序示例,該程序打開指定…

Simulink 模型生成 C 代碼(一):使用 Embedded Coder 快速向導生成代碼

以matlab自帶的示例模型RollAxisAutopilot為例進行講解。RollAxisAutopilot為飛機自動駕駛控制系統模型。 使用快速向導工具生成代碼 通過鍵入以下命令打開模型 RollAxisAutopilot: openExample(RollAxisAutopilot); 如果 C 代碼選項卡尚未打開,請在 …

【C++】宏定義

嚴格來說,這個題目起名為C是不合適的,因為宏定義是C語言的遺留特性。CleanCode并不推薦C中使用宏定義。我當時還在公司做過宏定義為什么應該被取代的報告。但是適當使用宏定義對代碼是有好處的。壞處也有一些。 無參宏定義 最常見的一種宏定義&#xf…

makefile總結

1,Makefile規則介紹 一個簡單的 Makefile 描述規則組成: TARGET... : PREREQUISITES... COMMAND 注意: 每一個命令行必須以[Tab]字符開始, [Tab]字符告訴 make 此行是一個命令行。 make 按照命令完成相應的動作。這也是書寫 Makefile 中容易產生,而且比較隱蔽的錯…

油煙凈化器:餐飲業健康環保的守護者

我最近分析了餐飲市場的油煙凈化器等產品報告,解決了餐飲業廚房油膩的難題,更加方便了在餐飲業和商業場所有需求的小伙伴們。 在現代餐飲業,油煙凈化器已經成為不可或缺的重要設備。它不僅是保障餐飲環境清潔的利器,更是守護健康…

新聲創新20年:無線技術給助聽器插上“娛樂”的翅膀

聽力損失并非現代人的專利,古代人也會有聽力損失。助聽器距今發展已經有二百多年了,從當初單純的聲音放大器到如今的全數字時代助聽器,助聽器發生了翻天覆地的變化,現代助聽器除了助聽功能,還具有看電視,聽…

【LeetCode】368. 最大整除子集

雖然這題挺難寫的,但是仍然提醒了我:解題要注意方法。在明確分析當一條道路走不通的時候,就不要再猶豫了,就要果斷的換方法,嘗試用其它方法解決。否則一味的消耗時間,得不償失。換方法的前提是明確的分析&a…

C++ 和C#的差別

首先把眼睛瞪大,然后憋住一口氣,讀下去: 1、CPP 就是C plus plus的縮寫,中國大陸的程序員圈子中通常被讀做"C加加",而西方的程序員通常讀做"C plus plus",它是一種使用非常廣泛的計算…

Maya崩潰閃退常見原因及解決方案

Autodesk Maya 是一款功能強大的 3D 計算機圖形程序,被電影、游戲和建筑等各個領域的設計師廣泛使用。然而,Maya 就像任何其他軟件一樣可能會發生崩潰問題。在前文中,小編給大家介紹了3ds Max使用V-Ray渲染時的崩潰閃退解決方案: …

Neo4j 圖數據庫 高級操作

Neo4j 圖數據庫 高級操作 文章目錄 Neo4j 圖數據庫 高級操作1 批量添加節點、關系1.1 直接使用 UNWIND 批量創建關系1.2 使用 CSV 文件批量創建關系1.3 選擇方法 2 索引2.1 創建單一屬性索引2.2 創建組合屬性索引2.3 創建全文索引2.4 列出所有索引2.5 刪除索引2.6 注意事項 3 清…

后端之路第三站(Mybatis)——JDBC跟Mybatis、lombok

一、什么是JDBC JDBC就是sun公司研發的一套通過java來操控數據庫的工具,對應不同的數據庫系統有不同的JDBC,而他們統稱【驅動】,這就是上一篇我們提到創建Mybatis項目時要引入的依賴、以及連接數據庫四要素里的第一要素。 JDBC有自己一套原始…

SerialportToTCP② 全

效果補全(代碼): namespace SerialportToTCP {public partial class Form1 : Form{IniHelper Ini;string[] botelvs new string[] { "1200", "4800", "9600", "13200" };public Form1(){Initializ…

Elasticsearch:Painless scripting 語言(一)

Painless 是一種高性能、安全的腳本語言,專為 Elasticsearch 設計。你可以使用 Painless 在 Elasticsearch 支持腳本的任何地方安全地編寫內聯和存儲腳本。 Painless 提供眾多功能,這些功能圍繞以下核心原則: 安全性:確保集群的…

安卓gdb 建立鏈接

adbshell gdbserver :1234 testdcam --sensor 0 --workmode 0 --args preview-size1024x600,picture-size640x480, --time 10 adb forwardtcp:1234 tcp:1234 //設置adb的轉發 ./prebuilts/gcc/linux-x86/arm/arm-linux-androideabi-4.7/bin/arm-linux-androideabi-gdb out/tar…

近紅外光譜腦功能成像(fNIRS):1.光學原理、變量選取與預處理

一、朗伯-比爾定律與修正的朗伯-比爾定律 朗伯-比爾定律 是一個描述光通過溶液時被吸收的規律。想象你有一杯有色液體,比如一杯紅茶。當你用一束光照射這杯液體時,光的一部分會被液體吸收,導致透過液體的光變弱。朗伯-比爾定律告訴我們&#…

mmdetection3D指定版本安裝指南

1. 下載指定版本號 選擇指定版本號下載mmdetection3d的源碼,如這里選擇的是0.17.2版本 git clone https://github.com/open-mmlab/mmdetection3d.git -b v0.17.22. 安裝 cd mmdetection3d安裝依賴庫 pip install -r requirment.txt編譯安裝 pip install -v e .…