昇思25天學習打卡營第16天 | DCGAN生成漫畫頭像

這兩天把minspore配置到我的電腦上了,然后運行就沒什么問題了?😊

今天學這個DCGAN生成漫畫頭像,我超級感興趣的嘞🦄🥰

GAN基礎原理

這部分原理介紹參考GAN圖像生成。

DCGAN原理

DCGAN(深度卷積對抗生成網絡,Deep Convolutional Generative Adversarial Networks)是GAN的直接擴展。不同之處在于,DCGAN會分別在判別器和生成器中使用卷積和轉置卷積層。

它最早由Radford等人在論文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中進行描述。判別器由分層的卷積層、BatchNorm層和LeakyReLU激活層組成。輸入是3x64x64的圖像,輸出是該圖像為真圖像的概率。生成器則是由轉置卷積層、BatchNorm層和ReLU激活層組成。輸入是標準正態分布中提取出的隱向量𝑧𝑧,輸出是3x64x64的RGB圖像。

本教程將使用動漫頭像數據集來訓練一個生成式對抗網絡,接著使用該網絡生成動漫頭像圖片。

數據準備與處理

from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)

數據處理

batch_size = 128          # 批量大小
image_size = 64           # 訓練圖像空間大小
nc = 3                    # 圖像彩色通道數
nz = 100                  # 隱向量的長度
ngf = 64                  # 特征圖在生成器中的大小
ndf = 64                  # 特征圖在判別器中的大小
num_epochs = 3           # 訓練周期數
lr = 0.0002               # 學習率
beta1 = 0.5               # Adam優化器的beta1超參數import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""數據加載"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 數據增強操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 數據映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')import matplotlib.pyplot as pltdef plot_data(data):# 可視化部分訓練數據plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

構造網絡

當處理完數據后,就可以來進行網絡的搭建了。按照DCGAN論文中的描述,所有模型權重均應從mean為0,sigma為0.02的正態分布中隨機初始化。

生成器

生成器G的功能是將隱向量z映射到數據空間。由于數據是圖像,這一過程也會創建與真實圖像大小相同的 RGB 圖像。在實踐場景中,該功能是通過一系列Conv2dTranspose轉置卷積層來完成的,每個層都與BatchNorm2d層和ReLu激活層配對,輸出數據會經過tanh函數,使其返回[-1,1]的數據范圍內。

DCGAN論文生成圖像如下所示:

dcgangenerator

我們通過輸入部分中設置的nzngfnc來影響代碼中的生成器結構。nz是隱向量z的長度,ngf與通過生成器傳播的特征圖的大小有關,nc是輸出圖像中的通道數。

以下是生成器的代碼實現:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN網絡生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判別器

如前所述,判別器D是一個二分類網絡模型,輸出判定該圖像為真實圖的概率。通過一系列的Conv2dBatchNorm2dLeakyReLU層對其進行處理,最后通過Sigmoid激活函數得到最終概率。

DCGAN論文提到,使用卷積而不是通過池化來進行下采樣是一個好方法,因為它可以讓網絡學習自己的池化特征。

判別器的代碼實現如下:

class Discriminator(nn.Cell):"""DCGAN網絡判別器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型訓練

損失函數

當定義了DG后,接下來將使用MindSpore中定義的二進制交叉熵損失函數BCELoss。

優化器

這里設置了兩個單獨的優化器,一個用于D,另一個用于G。這兩個都是lr = 0.0002beta1 = 0.5的Adam優化器。

訓練模型

訓練分為兩個主要部分:訓練判別器和訓練生成器。

  • 訓練判別器

    訓練判別器的目的是最大程度地提高判別圖像真偽的概率。按照Goodfellow的方法,是希望通過提高其隨機梯度來更新判別器,所以我們要最大化𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1?𝐷(𝐺(𝑧))𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1?𝐷(𝐺(𝑧))的值。

  • 訓練生成器

    如DCGAN論文所述,我們希望通過最小化𝑙𝑜𝑔(1?𝐷(𝐺(𝑧)))𝑙𝑜𝑔(1?𝐷(𝐺(𝑧)))來訓練生成器,以產生更好的虛假圖像。

在這兩個部分中,分別獲取訓練過程中的損失,并在每個周期結束時進行統計,將fixed_noise批量推送到生成器中,以直觀地跟蹤G的訓練進度。

下面實現模型訓練正向邏輯:

# 定義損失函數
adversarial_loss = nn.BCELoss(reduction='mean')# 為生成器和判別器設置優化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')def generator_forward(real_imgs, valid):# 將噪聲采樣為發生器的輸入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批圖像gen_imgs = generator(z)# 損失衡量發生器繞過判別器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鑒別器從生成的樣本中對真實樣本進行分類的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 為每輪訓練讀入數據for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 輸出訓練記錄print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每個epoch結束后,使用生成器生成一組圖片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存網絡模型參數為ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/41133.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/41133.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/41133.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python中的lambda函數是什么以及它有哪些用途和限制

Python中的lambda函數 定義 Python中的lambda函數是一種簡潔定義小函數的方式,也被稱為匿名函數。它允許用戶快速定義一個小的、一次性的函數對象,而無需正式地命名一個函數。lambda函數的基本語法為:lambda arguments: expression&#xf…

港三新二是那幾所大學?有哪些知名校友?中英雙語介紹

中文版 港三新二指的是香港和新加坡的五所著名大學,分別是香港大學(HKU)、香港中文大學(CUHK)、香港科技大學(HKUST)、新加坡國立大學(NUS)和南洋理工大學(N…

秒驗—手機號碼置換接口

功能說明 提交客戶端獲取到的token、opToken等數據,驗證后返回手機號碼 服務端務必不要緩存DNS,否則可能影響服務高可用性 調用地址 POST https://identify-verify.dutils.com/auth/auth/sdkClientFreeLogin 請求頭 Content-Type :appli…

圖書商城系統java項目ssm項目jsp項目java課程設計java畢業設計

文章目錄 圖書商城系統一、項目演示二、項目介紹三、部分功能截圖四、部分代碼展示五、底部獲取項目源碼(9.9¥帶走) 圖書商城系統 一、項目演示 圖書商城系統 二、項目介紹 語言: Java 數據庫:MySQL 技術棧:SpringS…

SaaS行業的AI化征程:穿越“大模型焦慮”,擁抱“AI自信”

隨著大模型技術的風起云涌,SaaS行業正站在一個充滿機遇與挑戰的十字路口。本文旨在深入剖析SaaS廠商在AI化升級過程中所遭遇的“大模型焦慮”,并探索通過戰略性的AI應用策略,如何重拾信心,實現產品與服務的華麗轉身,為…

關于虛擬機上不了網的解決辦法

先ping出ip地址 或者查詢ifconfig得到目前網絡信息 繼續輸入命令Ifconfig -a查詢是否能找到ip地址 明顯ens33是沒有打開的,所以找不到分配的ip地址,需要打開,自動隨機分配ip 輸入命令: sudo dhclient ens33 現在就可以開始上網…

公司“領導”們竟如此討論工作!小伙:此事有蹊蹺;|國家漏洞庫CNNVD:關于OpenSSH安全漏洞的通報;

公司“領導”們竟如此討論工作!小伙:此事有蹊蹺 “當時我正在等驗證碼 還好你們快了一步 不然公司的93萬余元就沒了” 一談到這件事 杜先生仍然心有余悸 近日 正在處理公司財務工作的杜先生 突然被拉進了一個QQ群聊 從頭像、昵稱上看 群聊里的竟…

累積分布函數的一些性質證明

性質1: E [ X ] ∫ 0 ∞ ( 1 ? F ( x ) ) d x ? ∫ ? ∞ 0 F ( x ) d x ( 1 ) E[X]\int_0^{\infty}(1-F(x))dx - \int_{-\infty}^0F(x)dx\quad (1) E[X]∫0∞?(1?F(x))dx?∫?∞0?F(x)dx(1) 證明: E [ X ] ∫ ? ∞ ∞ x p ( x ) d x E[X] …

SpringBoot | 大新聞項目后端(redis優化登錄)

該項目的前篇內容的使用jwt令牌實現登錄認證,使用Md5加密實現注冊,在上一篇:http://t.csdnimg.cn/vn3rB 該篇主要內容:redis優化登錄和ThreadLocal提供線程局部變量,以及該大新聞項目的主要代碼。 redis優化登錄 其實…

macOS版ChatGPT更新:修復AI對話純文本存儲問題

貓頭虎 🐯 建聯貓頭虎,商務合作,產品評測,產品推廣,個人自媒體創作,超級個體,漲粉秘籍,一起探索編程世界的無限可能! macOS版ChatGPT更新:修復AI對話純文本…

HOW - React Router v6.x Feature 實踐(react-router-dom)

目錄 基本特性ranked routes matchingactive linksNavLinkuseMatch relative links1. 相對路徑的使用2. 嵌套路由的增強行為3. 優勢和注意事項4. . 和 ..5. 總結 data loadingloading or changing data and redirectpending navigation uiskeleton ui with suspensedata mutati…

JAVA高級進階11多線程

第十一天、多線程 線程安全問題 線程安全問題 多線程給我們帶來了很大性能上的提升,但是也可能引發線程安全問題 線程安全問題指的是當個多線程同時操作同一個共享資源的時候,可能會出現的操作結果不符預期問題 線程同步方案 認識線程同步 線程同步 線程同步就是讓多個線…

內網滲透學習-殺入內網

1、靶機上線cs 我們已經拿到了win7的shell,執行whoami,發現win7是administrator權限,且在域中 執行ipconfig發現了win7存在內網網段192.168.52.0/24 kali開啟cs服務端 客戶端啟動cs 先在cs中創建一個監聽器 接著用cs生成后門,記…

Mysql 的第二次作業

一、數據庫 1、登陸數據庫 2、創建數據庫zoo 3、修改數據庫zoo字符集為gbk 4、選擇當前數據庫為zoo 5、查看創建數據庫zoo信息 6、刪除數據庫zoo 1)登陸數據庫。 打開命令行,輸入登陸用戶名和密碼。 mysql -uroot -p123456 ? 2)切換數據庫…

菜雞的原地踏步史(???)

leetcode啟動!(╯‵□′)╯︵┻━┻ 嘗試改掉想到哪寫哪的代碼壞習慣 鏈表 相交鏈表 public class Solution {/**ac(公共長度)b所以 鏈表A的長度 a c,鏈表B的長度b ca b c b c a只要指針a從headA開始走,走完再…

利用pg_rman進行備份與恢復操作

文章目錄 pg_rman簡介一、安裝配置pg_rman二、創建表與用戶三、備份與恢復 pg_rman簡介 pg_rman 是 PostgreSQL 的在線備份和恢復工具。類似oracle 的 rman pg_rman 項目的目標是提供一種與 pg_dump 一樣簡單的在線備份和 PITR 方法。此外,它還為每個數據庫集群維護…

抖音使矛,美團用盾

有市場,就有競爭。抖音全力進軍本地生活市場欲取代美團,已不是新聞。 互聯網行業進入存量時代,本地生活市場是為數不多存在較大增長空間的賽道。艾媒咨詢數據顯示,預計2025年在線餐飲外賣市場規模達到17469億元,生鮮電…

Day05-01-jenkins進階

Day05-01-jenkins進階 10. 案例07: 理解 案例06基于ans實現10.1 整體流程10.2 把shell改為Ansible劇本10.3 jk調用ansible全流程10.4 書寫劇本 11. Jenkins進階11.1 jenkins分布式1)概述2)案例08:拆分docker功能3)創建任務并綁定到…

安裝 ClamAV 并進行病毒掃描

安裝 ClamAV 并進行病毒掃描 以下是安裝 ClamAV 并使用它進行病毒掃描的步驟: 1. 安裝 ClamAV 在 Debian/Ubuntu 系統上: sudo apt update sudo apt install clamav clamav-daemon在 RHEL/CentOS 系統上: sudo yum install epel-release…

開發指南040-swagger加header

swagger可以在線生成接口文檔,便于前后端溝通,而且還可以在線調用接口,方便后臺調試。但是接口需要經過登錄校驗,部分接口還需要得到登錄token,使用token識別用戶身份進行后續操作。這種情況下,都需要接口增…