TensorFlow深度學習實戰——DCGAN詳解與實現

TensorFlow深度學習實戰——DCGAN詳解與實現

    • 0. 前言
    • 1. DCGAN 架構
    • 2. 構建 DCGAN 生成手寫數字圖像
      • 2.1 生成器與判別器架構
      • 2.2 構建 DCGAN
    • 相關鏈接

0. 前言

深度卷積生成對抗網絡 (Deep Convolutional Generative Adversarial Network, DCGAN) 是一種基于生成對抗網絡 (Generative Adversarial Network, GAN) 的深度學習模型,主要用于生成圖像。它結合了卷積神經網絡 (Convolutional Neural Network,CNN) 和生成對抗網絡的優勢,以更高效地生成質量更高的圖像。

1. DCGAN 架構

深度卷積生成對抗網絡 (Deep Convolutional Generative Adversarial Network, DCGAN) 引入了卷積神經網絡 (Convolutional Neural Network,CNN) 的結構,主要設計思想是使用卷積層而不使用池化層或分類層。使用卷積的步幅參數和轉置卷積執行下采樣(維度減少)和上采樣(維度增加)。

相比于原始生成對抗網絡 (Generative Adversarial Network, GAN),DCGAN 的主要變化包括:

  • 網絡完全由卷積層組成。池化層替換為步幅卷積(即,在使用卷積層時,將步幅從 1 增加為 2 )用于判別器,而生成器使用轉置卷積
  • 移除卷積后的全連接分類層
  • 為了提高訓練的穩定性,在每個卷積層后使用批歸一化

DCGAN 的基本思想與原始 GAN 相同,生成器接受 100 維的噪聲輸入,經過全連接層后重塑形狀后,通過卷積層處理,生成器架構如下:

生成器架構

判別器接收圖像(可以是生成器生成的圖像或來自真實數據集的圖像),圖像經過卷積處理和批歸一化處理。在每一步卷積中通過步幅參數進行下采樣。卷積層的最終輸出展平后,輸入到一個具有單個神經元的分類層:

判別器

生成器和判別器組合在一起形成 DCGAN。訓練過程與原始 GAN 相同,首先在一個批數據上訓練判別器,然后凍結判別器,訓練生成器,并重復以上過程。實踐證明,使用學習率為 0.002Adam 優化器能得到更穩定的結果。接下來,使用 Tensorflow 實現一個用于生成 MNIST 手寫數字圖像的 DCGAN

2. 構建 DCGAN 生成手寫數字圖像

在本節中,構建一個用于生成 MNIST 手寫數字圖像的 DCGAN

2.1 生成器與判別器架構

生成器通過順序添加網絡層構建。第一層是一個全連接層,接受 100 維的噪聲作為輸入,全連接層將 100 維的輸入擴展為一個大小為 128 × 7 × 7 的一維向量。這樣做的目的是為了最終得到大小為 28 × 28 的輸出,也就是 MNIST 手寫數字圖像的標準大小。該向量重塑為一個大小為 7 × 7 × 128 的張量,然后使用 TensorFlowUpSampling2D 層進行上采樣。需要注意的是,該層只是通過將行和列翻倍來放大圖像,并沒有可訓練權重,因此計算開銷較小。
Upsampling2D 層將 7 × 7 × 128 (行 × 列 × 通道)的圖像的行和列翻倍,得到大小 14 × 14 × 128 的輸出。上采樣后的圖像傳遞給一個卷積層,卷積層學習填充上采樣圖像中的細節,卷積的輸出傳遞到批歸一化層。批歸一化后的輸出經過 ReLU 激活。重復以上結構,即:上采樣-卷積-批歸一化-ReLU。在生成器中,具有兩個這樣的結構,第一個卷積層中使用 128 個卷積核,第二個使用 64 個卷積核。最終輸出使用一個卷積層,使用尺寸為 3 x 3 的單個卷積核和 tanh 激活函數,生成 28 × 28 × 1 的圖像:

    def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)

生成器模型架構如下:

生成器架構

也可以使用轉置卷積層,轉置卷積層不僅對輸入圖像進行上采樣,而且在訓練過程中學習如何填充細節。因此,可以用一個轉置卷積層來替代上采樣和卷積層,轉置卷積層執行的是反卷積操作。
接下來,構建判別器。判別器類似于標準卷積神經網絡,但區別在于,使用步幅為 2 的卷積層來代替最大池化層。還添加了 dropout 層以避免過擬合,并使用批歸一化以提高準確性和加快收斂速度,激活函數使用 leaky ReLU。在判別器中,使用了三個卷積層,分別具有 3264128 個卷積核。最后一個卷積層的輸出展平后傳遞給一個具有單個單元的全連接層。輸出用于將圖像分類為真實圖像或偽造圖像:

    def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

判別器模型架構如下:

判別器架構

2.2 構建 DCGAN

通過將生成器和判別器組合在一起得到完整的 GAN

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adamimport matplotlib.pyplot as plt
import sys
import numpy as npclass DCGAN():def __init__(self, rows, cols, channels, z = 10):# Input shapeself.img_rows = rowsself.img_cols = colsself.channels = channelsself.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = zoptimizer_1 = Adam(0.0002, 0.5)optimizer_2 = Adam(0.0002, 0.5)# Build and compile the discriminatorself.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer_1,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise as input and generates imgsz = Input(shape=(self.latent_dim,))img = self.generator(z)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated images as input and determines validityvalid = self.discriminator(img)# The combined model  (stacked generator and discriminator)# Trains the generator to fool the discriminatorself.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_2)def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

使用 binary_crossentropy 損失函數定義生成器和判別器的損失。生成器和判別器的優化器在初始化方法中定義。最后,定義了一個 TensorFlow 檢查點,用于在模型訓練過程中保存生成器和判別器模型。
DCGAN 的訓練過程與原始 GAN 相同,在每一步中,首先將隨機噪聲輸入到生成器中。生成器的輸出與真實圖像用于訓練判別器,然后訓練生成器,使其生成能夠欺騙判別器的圖像。GAN 的訓練通常需要幾百到數千個訓練 epoch

    def train(self, epochs, batch_size=256, save_interval=50):# Load the dataset(X_train, _), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and generate a batch of new imagesnoise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)# Train the discriminator (real classified as ones and generated as zeros)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator# ---------------------# Train the generator (wants discriminator to mistake images as real)g_loss = self.combined.train_on_batch(noise, valid)# Plot the progressprint ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % save_interval == 0:self.save_imgs(epoch)

最后,定義輔助函數保存圖像:

    def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/dcgan_mnist_%d.png" % epoch)plt.close()

訓練 DCGAN 模型:

dcgan = DCGAN(28,28,1)
dcgan.train(epochs=5000, batch_size=128, save_interval=50)

隨著訓練的進行,GAN 學習生成手寫數字的能力逐漸增強:

訓練監控

在第 50 個訓練 epoch,生成的手寫數字圖像質量有了顯著提升:

結果圖像

下圖是將 DCGAN 應用到名人圖像數據集中的一些生成結果:

生成結果

相關鏈接

TensorFlow深度學習實戰(1)——神經網絡與模型訓練過程詳解
TensorFlow深度學習實戰(2)——使用TensorFlow構建神經網絡
TensorFlow深度學習實戰(3)——深度學習中常用激活函數詳解
TensorFlow深度學習實戰(4)——正則化技術詳解
TensorFlow深度學習實戰(5)——神經網絡性能優化技術詳解
TensorFlow深度學習實戰(6)——回歸分析詳解
TensorFlow深度學習實戰(7)——分類任務詳解
TensorFlow深度學習實戰(8)——卷積神經網絡
TensorFlow深度學習實戰(9)——構建VGG模型實現圖像分類
TensorFlow深度學習實戰(10)——遷移學習詳解
TensorFlow深度學習實戰(11)——風格遷移詳解
TensorFlow深度學習實戰(14)——循環神經網絡詳解
TensorFlow深度學習實戰(15)——編碼器-解碼器架構
TensorFlow深度學習實戰(16)——注意力機制詳解
TensorFlow深度學習實戰(23)——自編碼器詳解與實現
TensorFlow深度學習實戰(24)——卷積自編碼器詳解與實現
TensorFlow深度學習實戰(25)——變分自編碼器詳解與實現
TensorFlow深度學習實戰(26)——生成對抗網絡詳解與實現

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/89303.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/89303.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/89303.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpringBoot 使用MyBatisPlus

引入依賴<dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi3-jakarta-spring-boot-starter</artifactId><version>4.3.0</version> </dependency>寫一個interface 繼承basemapMapper public in…

Git 中如何查看提交歷史?常用命令有哪些?

回答重點在 Git 中&#xff0c;我們可以使用 git log 命令來查看提交歷史。這個命令會列出所有的提交記錄&#xff0c;顯示每個提交的哈希值、作者信息、提交時間和提交信息。常用的 git log 命令及其選項有&#xff1a;1&#xff09; git log &#xff1a;顯示完整的提交歷史。…

Flink數據流高效寫入MySQL實戰

這段代碼展示了如何使用 Apache Flink 將數據流寫入 MySQL 數據庫&#xff0c;并使用了 JdbcSink 來實現自定義的 Sink 邏輯。以下是對代碼的詳細解析和說明&#xff1a;代碼結構包聲明&#xff1a;package sink定義了代碼所在的包。導入依賴&#xff1a;導入了必要的 Flink 和…

MATLAB下載安裝教程(附安裝包)2025最新版(MATLAB R2024b)

文章目錄前言一、MATLAB R2024b下載二、MATLAB下載安裝教程前言 MATLAB R2024b 的推出&#xff0c;進一步提升了其在工程實踐中的實用性和專業性。它不僅提供了更多針對特定工程領域的解決方案&#xff0c;還在性能和兼容性方面進行了顯著改進。 本教程將一步一步引導完成 MA…

Linux 基礎命令學習,立即上手Linux操作

Linux?基礎命令學習本文挑選最常用、最容易上手的 Linux 命令。每條都附帶一句話說明 真實示例&#xff0c;直接復制即可練習&#xff0c;零基礎也能跟得上。1? 先掌握 目錄導航&#xff1a;pwd?/?ls?/?cdpwd – 顯示當前所在目錄 pwd # 輸出示例 /home/yournamels??a…

Android構建流程與Transform任務

1. 完整構建流程概覽 1.1 主要構建階段 預構建階段 → 代碼生成階段 → 資源處理階段 → 編譯階段 → Transform階段 → 打包階段1.2 詳細任務執行順序 ┌─────────────────────────────────────────────────────────…

CKS認證 | Day6 監控、審計和運行時安全 sysdig、falco、審計日志

一、分析容器系統調用&#xff1a;Sysdig Sysdig&#xff1a;定位是系統監控、分析和排障的工具&#xff0c;在 linux 平臺上&#xff0c;已有很多這方面的工具 如tcpdump、htop、iftop、lsof、netstat&#xff0c;它們都能用來分析 linux 系統的運行情況&#xff0c;而且還有…

Redis:持久化配置深度解析與實踐指南

&#x1f9e0; 1、簡述 Redis 是一款基于內存的高性能鍵值數據庫&#xff0c;為了防止數據丟失&#xff0c;Redis 提供了兩種主要的持久化機制&#xff1a;RDB&#xff08;快照&#xff09;和 AOF&#xff08;追加日志&#xff09;。本文將從原理到配置&#xff0c;再到實際項目…

共創 Rust 十年輝煌時刻:RustChinaConf 2025 贊助與演講征集正式啟動

&#x1f680; 共創 Rust 十年輝煌時刻&#xff1a;RustChinaConf 2025 贊助與演講征集正式啟動2025年&#xff0c;是 Rust 編程語言誕生十周年的里程碑時刻。在這個具有歷史意義的節點&#xff0c;RustChinaConf 2025 攜手 RustGlobal 首次登陸中國&#xff0c;聯合 GOSIM HAN…

EMS4100芯祥科技USB3.1高速模擬開關芯片規格介紹

EMS4100一款適用于USB Type-C應用的二通道差分2:1/1:2 USB 3.1高速雙向被動開關。該器件支持USB 3.1 Gen 1和Gen 2數據速率,具有高帶寬、低串擾、寬供電電壓范圍等特點。EMS4100芯片內部框架&#xff1a;EMS4100主要特性&#xff1a;2-獨立頻道1&#xff1a;2/2&#xff1a;1 M…

HTML 常用語義標簽與常見搭配詳解

一、什么是語義標簽&#xff1f; 語義標簽是 HTML5 引入的一組具有特定含義的標簽&#xff0c;用于描述頁面中不同部分的內容類型&#xff0c;如頁眉、導航欄、主內容區域、側邊欄、頁腳等。相比傳統的 <div> 和 <span>&#xff0c;語義標簽更具表達力和結構化。 …

遷移學習的概念和案例

遷移學習概念 預訓練模型 定義: 簡單來說別人訓練好的模型。一般預訓練模型具備復雜的網絡模型結構&#xff1b;一般是在大量的語料下訓練完成的。 預訓練語言模型的類別&#xff1a; 現在我們接觸到的預訓練語言模型&#xff0c;基本上都是基于transformer這個模型迭代而來…

DAOS系統架構-RDB

1. 概述 基于Raft共識算法和強大的領導地位策略&#xff0c;pool service和container service可以通過復制其內部的元數據來實現高可用。通過這種方法實現具有副本能力的服務可以容忍少數副本中的任何一個出現故障。通過將每個服務的副本分布在容災域中&#xff0c;pool servic…

深入GPU硬件架構及運行機制

轉自深入GPU硬件架構及運行機制 - 0向往0 - 博客園&#xff0c;基本上是其理解。 一、GPU概述 1.1 GPU是什么&#xff1f; GPU全稱是Graphics Processing Unit&#xff0c;圖形處理單元。它的功能最初與名字一致&#xff0c;是專門用于繪制圖像和處理圖元數據的特定芯片&…

數值計算庫:Eigen與Boost.Multiprecision全方位解析

在科學計算、工程模擬、機器學習等領域&#xff0c;高效的數值計算能力是構建高性能應用的基石。C作為性能優先的編程語言&#xff0c;擁有眾多優秀的數值計算庫&#xff0c;其中Eigen和Boost.Multiprecision是兩個極具代表性的工具。本文將深入探討這兩個庫的核心特性、使用場…

第十八節:第三部分:java高級:反射-獲取構造器對象并使用

Class提供的獲取類構造器的方法以及獲取類構造器的作用代碼&#xff1a;掌握獲取類的構造器&#xff0c;并對其進行操作 Cat類 package com.itheima.day9_reflect;public class Cat {private String name;private int age;private Cat(String name, int age) {this.name name;…

集中打印和轉換Office 批量打印精靈:Word/Excel/PDF 全兼容,效率翻倍

各位辦公小能手們&#xff01;你們平時辦公的時候&#xff0c;是不是經常要打印一堆文件&#xff0c;煩得要命&#xff1f;別慌&#xff0c;今天我給大家介紹一款超厲害的神器——Office批量打印精靈&#xff01; 軟件下載地址安裝包 這玩意兒啊&#xff0c;是專門為高效辦公設…

docker的搭建

一、安裝docker使用以下命令進行安裝dockerapt-get install docker.io docker-compose使用以下命令進行查看docker是否開啟systemctl status docker由此可見&#xff0c;docker沒有打開&#xff0c;進行使用命令打開。systemctl start docker再次查看是否開啟。肉眼可見&#x…

數據庫管理-第349期 Oracle DB 23.9新特性一覽(20250717)

數據庫管理349期 2025-07-17數據庫管理-第349期 Oracle DB 23.9新特性一覽&#xff08;20250717&#xff09;1 JavaScript過程和函數的編譯時語法檢查2 不再需要JAVASCRIPT上的EXECUTE權限3 GROUP BY ALL4 使用SQL創建并測試UUID5 IVF索引在線重組6 JSON到二元性遷移器&#xf…

將CSDN文章導出為PDF

作者&#xff1a;翟天保Steven 版權聲明&#xff1a;著作權歸作者所有&#xff0c;商業轉載請聯系作者獲得授權&#xff0c;非商業轉載請注明出處前言在日常學習和技術積累過程中&#xff0c;我們經常會在 CSDN 等技術博客平臺上閱讀高質量的技術文章。然而&#xff0c;網頁閱讀…