昇思訓練營打卡第二十一天(DCGAN生成漫畫頭像)

DCGAN,即深度卷積生成對抗網絡(Deep Convolutional Generative Adversarial Network),是一種深度學習模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成對抗網絡(GAN)的基礎上,引入了深度卷積神經網絡(CNN)的結構,用于生成高質量、高分辨率的圖像。

DCGAN的原理可以概括為兩個主要部分:生成器(Generator)和判別器(Discriminator)。生成器和判別器都是深度卷積神經網絡,它們通過對抗過程互相博弈,最終達到納什均衡。

  1. 生成器(Generator):生成器的輸入是一個隨機的噪聲向量,通過一系列的卷積、反卷積、批標準化(Batch Normalization)和激活函數(如ReLU、Tanh等)操作,生成一個與真實圖像具有相同尺寸的圖像。生成器的目標是生成盡可能逼真的圖像,以欺騙判別器。

  2. 判別器(Discriminator):判別器的輸入是一個圖像,它通過一系列的卷積、批標準化和激活函數操作,判斷輸入圖像是真實圖像還是生成器生成的假圖像。判別器的目標是能夠準確地區分真實圖像和假圖像。

在訓練過程中,生成器和判別器交替進行優化。生成器嘗試生成逼真的圖像,而判別器嘗試更好地識別真實圖像和假圖像。這個過程可以看作是一種博弈,生成器和判別器在不斷的迭代過程中提高自己的性能。最終,當生成器和判別器達到納什均衡時,生成器能夠生成高質量的逼真圖像,判別器無法準確地區分真實圖像和假圖像。

DCGAN在計算機視覺領域有廣泛的應用,如圖像生成、圖像修復、圖像轉換等。通過調整網絡結構和訓練策略,DCGAN還可以應用于其他領域,如自然語言處理、音頻生成等。

數據準備與處理

%%capture captured_output
# 實驗環境已經預裝了mindspore==2.2.14,如需更換mindspore版本,可更改下面mindspore的版本號
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""數據加載"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 數據增強操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 數據映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可視化部分訓練數據plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

構造網絡

生成器

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN網絡生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判別器

class Discriminator(nn.Cell):"""DCGAN網絡判別器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型訓練

損失函數

# 定義損失函數
adversarial_loss = nn.BCELoss(reduction='mean')

優化器

# 為生成器和判別器設置優化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

訓練模型

訓練分為兩個主要部分:訓練判別器和訓練生成器。

def generator_forward(real_imgs, valid):# 將噪聲采樣為發生器的輸入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批圖像gen_imgs = generator(z)# 損失衡量發生器繞過判別器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鑒別器從生成的樣本中對真實樣本進行分類的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 為每輪訓練讀入數據for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 輸出訓練記錄print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每個epoch結束后,使用生成器生成一組圖片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存網絡模型參數為ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/44013.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/44013.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/44013.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【CentOS】Linux命令之docker命令(持續更新)

刪除所有容器 該命令將刪除所有已停止的容器。你還可以使用其他狀態值,例如created、restarting或dead docker container rm $(docker container ls -aqf statusexited)刪除所有鏡像 該命令將刪除所有鏡像,包括被使用的鏡像。請注意,如果某…

【深度學習】第5章——卷積神經網絡(CNN)

一、卷積神經網絡 1.定義 卷積神經網絡(Convolutional Neural Network, CNN)是一種專門用于處理具有網格狀拓撲結構數據的深度學習模型,特別適用于圖像和視頻處理。CNN 通過局部連接和權重共享機制,有效地減少了參數數量&#x…

使用OpencvSharp實現人臉識別

在網上有很多關于這方面的博客,但是都沒有說完整,按照他們的博客做下來代碼都不能跑。所以我就自己寫個博客補充一下 我這使用的.NET框架版本是 .NetFramework4.7.1 使用Nuget安裝這兩個程序包就夠了,不需要其他的配置 一定要安裝OpenCvSha…

大模型日報 2024-07-09

大模型日報 2024-07-09 大模型資訊 大模型最強架構TTT問世!斯坦福UCSD等5年磨一劍,一夜推翻Transformer 斯坦福UCSD等機構研究者提出的TTT方法,直接替代了注意力機制,語言模型方法從此或將徹底改變。這個模型通過對輸入token進行梯…

在亞馬遜云科技AWS上利用SageMaker機器學習模型平臺搭建生成式AI應用(附Llama大模型部署和測試代碼)

項目簡介: 接下來,小李哥將會每天介紹一個基于亞馬遜云科技AWS云計算平臺的全球前沿AI技術解決方案,幫助大家快速了解國際上最熱門的云計算平臺亞馬遜云科技AWS AI最佳實踐,并應用到自己的日常工作里。本次介紹的是如何在Amazon …

802.11漫游流程簡單解析與筆記_Part2_05_wpa_supplicant如何通過nl80211控制內核開始關聯

最近在進行和802.11漫游有關的工作,需要對wpa_supplicant認證流程和漫游過程有更多的了解,所以通過閱讀論文等方式,記錄整理漫游相關知識。Part1將記錄802.11漫游的基本流程、802.11R的基本流程、與認證和漫游都有關的三層秘鑰基礎。Part1將包…

Vue 3與Pinia:下一代狀態管理的探索

引言 隨著Vue 3的推出,Pinia應運而生,成為官方推薦的狀態管理庫,旨在替代Vuex。Pinia與Vuex相比,帶來了以下主要區別和優勢: 更簡潔的API:Pinia的API設計更加直觀和簡潔,易于理解和使用。更好…

220V降5V芯片輸出電壓電流封裝選型WT

220V降5V芯片輸出電壓電流封裝選型WT 220V降5V恒壓推薦:非隔離芯片選型及其應用方案 在考慮220V轉低壓應用方案時,以下非隔離芯片型號及其封裝形式提供了不同的電壓電流輸出能力: 1. WT5101A(SOT23-3封裝)適用于將2…

【實戰場景】大文件解析入庫的方案有哪些?

【實戰場景】大文件解析入庫的方案有哪些? 開篇詞:干貨篇:分塊解析內存映射文件流式處理數據庫集群處理分布式計算框架 總結篇:我是杰叔叔,一名滬漂的碼農,下期再會! 開篇詞: 需求背…

14-57 劍和詩人31 - LLM/SLM 中的高級 RAG

??? 首先確定幾個縮寫的意思 SLM 小模型 LLM 大模型 檢索增強生成 (RAG) 已成為一種增強語言模型能力的強大技術。通過檢索和調整外部知識,RAG 可讓模型生成更準確、更相關、更全面的文本。 RAG 架構主要有三種類型:簡單型、模塊化和高級 RAG&…

性能測試的流程(企業真實流程詳解)(二)

性能測試的流程 1.需求分析以及需求確定(指標值,場景,環境,人員) 一般提出需求的人員有:客戶,產品經理,項目組領導等 2.性能測試計劃和方案制定 基準測試: 負覡測試: 壓力測試: 穩定性測試: 其他:配置測試…

Git安裝使用教程

# 《Git 操作使用教程》 一、Git 簡介 Git 是一個分布式版本控制系統,用于敏捷高效地處理任何或小或大的項目。它讓開發者可以輕松地跟蹤代碼的更改、與團隊成員協作,并管理項目的不同版本。 二、安裝 Git 在 Windows 系統上,可以從 Git 官…

刷題Day47|1143.最長公共子序列、1035.不相交的線、53. 最大子序和、

1143.最長公共子序列 1143. 最長公共子序列 - 力扣(LeetCode) 思路:dp數組含義是以i-1和j-1為結尾的最長公共子序列。當text1[i - 1] text2[i - 1], dp[i][j] dp[i - 1][j - 1] 1; 否則dp[i][j] max(dp[i - 1][j], dp[i][j - 1]); 因為兩…

無法連接Linux遠程服務器的Mysql,解決辦法

問題描述 如果是關閉虛擬機之后,二次打開無法連接Mysql,則可嘗試一下方法進行解決 解決方法 關閉虛擬機的防火墻 1:查看防火墻狀態 systemctl status firewalld 一下顯示說明防火墻是啟動的狀態 2:關閉防火墻 systemctl st…

git提交emoji指南

emoji 指南 emojiemoji 代碼commit 說明🎉 (慶祝)tada初次提交? (火花)sparkles引入新功能🔖 (書簽)bookmark發行/版本標簽🐛 (bug)bug修復 bug🚑 (急救車)ambulance重要補丁🌐 (地球)globe_with_meridians國際化與本…

PTA - 編寫函數計算圓面積

題目描述: 1.要求編寫函數getCircleArea(r)計算給定半徑r的圓面積,函數返回圓的面積。 2.要求編寫函數get_rList(n) 輸入n個值放入列表并將列表返回 函數接口定義: getCircleArea(r); get_rList(n); 傳入的參數r表示圓的半徑&#xff0c…

音視頻解封裝demo:將FLV文件解封裝(demux)得到文件中的H264數據和AAC數據(純手工,不依賴第三方開源庫)

1、README 前言 注意:flv是不支持h.265封裝的。目前解封裝功能正常,所得到的H.264文件與AAC文件均可正常播放。 a. demo使用 $ make clean && make DEBUG1 $ $ $ ./flv_demux_h264_aac Usage: ./flv_demux_h264_aac avfile/test1.flv./flv_d…

壓縮感知1——算法簡介

傳統的數據采集 傳統的數字信號采樣定律就是有名的香農采樣定理,又稱那奎斯特采樣定律定理內容如下:為了不失真地恢復模擬信號,采樣頻率應該不小于模擬信號頻譜中最高頻率的2倍 上述步驟得到的數字信號的數據量比較大,一方面不利…

C語言程序題(一)

一.三個整數從大到小輸出 首先做這個題目需要知道理清排序的思路,通過比較三個整數的值,使之從大到小輸出。解這道題有很多方法我就總結了兩種方法:一是通過中間變量比較和交換,二是可以用冒泡排序法(雖然三個數字排序…

車載聚合路由器應用場景分析

乾元通QYT-X1z車載式1U多卡聚合路由器,支持最多8路聚合,無論是應急救援,還是車載交通,任何寬帶服務商無法覆蓋的區域,聚合路由器可提供現場需要的穩定、流暢、安全的視頻傳輸網絡,聚合路由器可無縫接入應急…