生成對抗網絡詳解與實現

生成對抗網絡詳解與實現

    • 0. 前言
    • 1. GAN 原理
    • 2. GAN 架構
    • 3. 損失函數
      • 3.1 判別器損失
      • 3.2 生成器損失
      • 3.4 VANILLA GAN
    • 4. GAN 訓練步驟

0. 前言

生成對抗網絡 (Generative Adversarial Network, GAN) 是圖像和視頻生成中的主要方法之一。在本節中,我們將了解 GAN 的架構、訓練步驟等,并實現原始 GAN

1. GAN 原理

生成模型的目的是學習數據分布并從中進行采樣以生成新數據。PixelCNN 和變分自編碼器 (Variational Autoencoder, VAE),它們的生成部分將著眼于訓練過程中的圖像分布。因此,稱為顯式密度模型 (explicit density models)。相比之下,GAN 中的生成部分不會直接查看圖像。因此,GAN 被歸類為隱式密度模型 (implicit density models)。
我們可以使用一個類比來比較顯式模型和隱式模型。假設一位藝術系學生 G 獲得了畢加索的畫作收藏,并被要求學習繪制假畢加索畫作。學生可以在學習繪畫時查看收藏,因此這是一個顯式模型。在另一種情況下,我們要求學生 G 偽造畢加索的畫,但我們沒有給他們看任何畫,他們也不知道畢加索的畫是什么樣。他們學習的唯一方法是學生 D 的反饋,后者正在學習判別假畢加索的畫作。反饋很簡單——這幅畫是假的還是真實的。這就是我們的隱式密度 GAN 模型。
也許有一天,G 偶然地畫了一張扭曲的臉,并從反饋中得知它看起來像一幅真正的畢加索畫,然后他們開始以這種方式來欺騙學生 D。學生 GDGAN 中的兩個網絡,稱為生成器和判別器。與其他生成模型相比,這是網絡體系結構的最大區別。我們將從了解 GAN 構建塊開始,然后介紹損失函數。然后,我們將為 GAN 創建自定義的訓練步驟。

2. GAN 架構

生成對抗網絡中的對抗一詞是指包含對立或異議。有兩個相互競爭的網絡,稱為生成器和判別器。顧名思義,生成器生成偽造的圖像。而辨別器將查看生成的圖像,以確定它們是真實的還是偽造的。每個網絡都試圖贏得這場比賽,判別器要正確識別每個真實和偽造的圖像,而生成器則要愚弄判別器以使其所產生的虛假圖像被判別器判定是真實的。下圖顯示了 GAN 的體系結構:

GAN

GAN 架構與 VAE 有一些相似之處。如果 VAE 由兩個獨立的網絡組成,我們可以想到:

  • GAN 的生成器作為 VAE 的解碼器
  • GAN 的判別器作為 VAE 的編碼器

生成器將低維和簡單分布轉換為具有復雜分布的高維圖像,就像解碼器一樣。生成器的輸入通常是來自正態分布的樣本,也有些樣本使用均勻分布。
我們將不同批次的圖像發送給判別器。真實圖像是來自數據集的圖像,而偽造圖像則是由生成器生成的。判別器輸出輸入圖片是真還是假的單值概率。它是一個二進制分類器,可以使用 CNN 來實現它。從技術上講,判別器的作用與編碼器不同,但它們都減小了輸入的維數。
實際上,原始的 GAN 僅使用了多層感知器,該感知器由一些基本的全連接層組成。

3. 損失函數

損失函數體現了 GAN 的工作原理。公式如下:
minGmaxDV(D,G)=EX~Pdata(x)[logD(x)]+EZ~Pz(z)[log(1?D(G(z)))]min_Gmax_DV(D,G)=E_{X\sim P_data(x)}[logD(x)]+E_{Z\sim P_z(z)}[log(1-D(G(z)))] minG?maxD?V(D,G)=EXPd?ata(x)?[logD(x)]+EZPz?(z)?[log(1?D(G(z)))]

其中:DDD 表示判別器,GGG 表示生成器,xxx 表示輸入數據,zzz 表示潛變量。
了解 GAN 的損失函數之后,代碼實現將變得更加容易。此外,有關 GAN 改進的許多討論都圍繞損失函數進行。GAN 損失函數也稱為對抗損失。接下來,我們將對其進行分解,并逐步向展示如何將其轉換為我們可以實現的簡單損失函數。

3.1 判別器損失

GAN 損失函數的等式右側的第一項是用于正確分類真實圖像的值。從等式左邊的項來看,我們知道判別器想要將其最大化。期望是一個數學術語,是隨機變量每個樣本的加權平均值之和。在此等式中,權重是數據的概率,而變量是判別器輸出的對數,如下所示:
EX[logD(x)]=∑i=1Np(x)logD(x)=1N∑i=1NlogD(x)E_X[logD(x)]=\sum_{i=1}^Np(x)logD(x)=\frac 1N\sum_{i=1}^NlogD(x) EX?[logD(x)]=i=1N?p(x)logD(x)=N1?i=1N?logD(x)
在大小為 NNN 的小批次中,p(x)p(x)p(x)1N\frac 1 NN1?。這是因為 xxx 是單個圖像。不必嘗試使它最大化,我們可以將符號更改為減號并嘗試使其最小化。這可以借助以下方程來完成,該方程稱為對數損失:
minDV(D)=?1N∑i=1NlogD(x)=?1N∑i=1Nyilogp(yi)min_DV(D)=-\frac 1N\sum_{i=1}^NlogD(x)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i) minD?V(D)=?N1?i=1N?logD(x)=?N1?i=1N?yi?logp(yi?)
其中:yiy_iyi? 是標簽,對于真實圖像為 1p(yi)p(y_i)p(yi?) 是樣本為真的概率。
GAN 損失函數的等式右側的第二項是關于偽造圖像的。zzz 是隨機噪聲,并且 G(z)G(z)G(z) 是生成圖像。D(G(z))D(G(z))D(G(z)) 是判別器對圖像真實可能性的置信度得分。如果我們將標簽 0 用于偽造圖像,則可以使用相同的方法將其轉換為以下等式:
?EZ~Pz(z)[log(1?D(G(z))]=?1N∑i=1N(1?yi)log(1?p(yi))-E_{Z\sim P_z(z)}[log(1-D(G(z))]=-\frac 1N\sum_{i=1}^N(1-y_i)log(1-p(y_i)) ?EZPz?(z)?[log(1?D(G(z))]=?N1?i=1N?(1?yi?)log(1?p(yi?))
現在,將所有內容放在一起,我們有了判別器損失函數,即二進制交叉熵損失:
minDV(D)=?1N∑i=1Nyilogp(yi)+(1?yi)log(1?p(yi))min_DV(D)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i)+(1-y_i)log(1-p(y_i)) minD?V(D)=?N1?i=1N?yi?logp(yi?)+(1?yi?)log(1?p(yi?))
使用以下代碼實現判別器損失:

def discriminator_loss(pred_fake, pred_real):real_loss = bce(tf.ones_like(pred_real), pred_real)fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)d_loss = 0.5 *(real_loss + fake_loss)return d_loss

在我們的訓練中,我們使用相同的批大小分別對真實和偽造圖像進行前向傳遞。因此,我們分別為它們計算二進制交叉熵損失,并取平均值作為損失。

3.2 生成器損失

僅當模型判別偽造圖像時才涉及生成器,因此我們只需要查看 GAN 損失函數的等式右側第二項并將其簡化為:
minGV(G)=EZ~Pz(z)[log(1?D(G(z))]min_GV(G)=E_{Z\sim P_z(z)}[log(1-D(G(z))] minG?V(G)=EZPz?(z)?[log(1?D(G(z))]
在訓練開始時,生成器并不擅長生成圖像,因此判別器始終有信心將其歸類為 0,使 D(G(z))D(G(z))D(G(z)) 始終為 0log(1–0)log(1 – 0)log(1–0) 也是如此。當模型輸出中的誤差始終為 0 時,則沒有反向傳播的梯度。結果,生成器的權重未更新,并且生成器未學習。由于判別器的 sigmoid 輸出幾乎沒有梯度,因此這種現象稱為梯度飽和 (saturating gradient)。為避免此問題,將等式從最小化 1?D(G(z))1-D(G(z))1?D(G(z)) 到最大化 D(G(z))D(G(z))D(G(z)) 進行如下轉換:
maxGV(G)=EZ~Pz(z)[logD(G(z))]max_GV(G)=E_{Z\sim P_z(z)}[logD(G(z))] maxG?V(G)=EZPz?(z)?[logD(G(z))]
使用此函數的 GAN 也稱為非飽和 GAN (Non-Saturating GANs, NS-GAN)。實際上,Vanilla GAN 的實現都使用此損失函數而不是原始的 GAN 損失函數。

3.4 VANILLA GAN

GAN 誕生后,研究人員對 GAN 的興趣激增,提出了一系列改進模型。Vanilla GAN 是泛指基本 GANVanilla GAN 通常使用具有兩個或三個隱藏的全連接層來實現。
我們可以對判別器使用相同的數學步驟來推導生成器損失,最終將得到相同的判別器損失函數,只是將標簽 1 用于偽造圖像。為什么要對偽造圖片使用標簽 1,我們也可以這樣理解它——因為我們想欺騙判別器以假定那些生成的圖像是真實的,因此我們使用標簽 1

def generator_loss(pred_fake):g_loss = bce(tf.ones_like(pred_fake), pred_fake)return g_loss

4. GAN 訓練步驟

為了在 TensorFlow 中訓練神經網絡,我們需要指定模型,損失函數,優化器,然后調用 model.fit()TensorFlow 將為我們完成所有工作,我們等待損失減少。
在研究 GAN 問題之前,我們首先回顧神經網絡在進行單個訓練步驟時代碼執行的情況:

  • 執行前向傳播以計算損失
  • 使用損失相對于權重的梯度向后傳播
  • 然后,這是更新權重。優化器將縮放梯度并將其添加到權重中,從而完成一個訓練步驟

這些是深度神經網絡中的通用訓練步驟。各種優化器的不同之處僅在于它們計算縮放因子的方式。
現在回到 GAN,查看梯度流。當我們訓練真實圖像時,只涉及判別器–網絡輸入是真實圖像,輸出是 1 的標簽。當我們使用偽造圖像并且梯度通過判別器反向傳播到生成器時,就會出現問題。讓我們將偽造圖像的生成器損失和判別器損失并排放置:

g_loss = bce(tf.ones_like(pred_fake), pred_fake)
fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)

可以發現它們之間的差異,它們的標簽是相反!這意味著,使用生成器損失來訓練整個模型將使判別器朝相反的方向移動,而不會學會進行判別。這適得其反,我們不想有一個未經訓練的判別器,這會阻止生成器學習。因此,我們必須分別訓練生成器和判別器。訓練生成器時,我們將凍結判別器權重。
有多種方法可以設計 GAN 訓練流程。一種是使用高級 Keras 模型,該模型需要較少的代碼,因此看起來更優雅。我們只需要定義一次模型,然后調用 train_on_batch() 即可執行所有步驟,包括前向計算,反向傳播和權重更新。但是,在實現更復雜的損失函數時,靈活性較差。
另一種方法是使用低級函數,以便控制每個步驟。在本節中,GAN 將使用自定義訓練步驟:

def train_step(g_input, real_input):with tf.GradientTape() as g_tape,\tf.GradientTape() as d_tape:# Forward passfake_input = G(g_input)pred_fake = D(fake_input)pred_real = D(real_input)   # Calculate lossesd_loss = discriminator_loss(pred_fake, pred_real)g_loss = generator_loss(pred_fake)

tf.GradientTape() 用于記錄單次通過的梯度。另一個具有類似功能的 APItf.Gradient(),但后者在 TensorFlow Eager 執行中不起作用。我們將看到如何在 train_step() 中實現前面提到的三個過程步驟。前面的代碼段顯示了執行前向傳遞以計算損失的第一步。
第二步是使用 tape 梯度從它們各自的損失計算生成器和判別器的梯度:

        gradient_g = g_tape.gradient(g_loss, G.trainable_variables)gradient_d = d_tape.gradient(d_loss, D.trainable_variables)

第三步也是最后一步是使用優化器將梯度應用于模型權重:

        G_optimizer.apply_gradients(zip(gradient_g, self.G.trainable_variables))D_optimizer.apply_gradients(zip(gradient_d, self.D.trainable_variables))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/96575.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/96575.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/96575.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

FPGA硬件開發-XPE工具的使用

目錄 XPE 工具概述? XPE 使用步驟詳解? 1. 工具獲取與初始化? 2. 器件選擇與配置? 3. 電源電壓設置? 4. 資源使用量配置? 5. 時鐘與開關活動配置? 6. 功耗計算與報告生成? 報告解讀與電源設計優化? 常見問題與最佳實踐? 與實際功耗的差異處理? 工具版本…

CentOS 7.9 RAID 10 實驗報告

文章目錄CentOS 7.9 RAID 10 實驗報告一、實驗概述1.1 實驗目的1.2 實驗環境1.3 實驗拓撲二、實驗準備2.1 磁盤準備2.2 安裝必要軟件三、RAID 10陣列創建3.1 創建RAID 10陣列3.2 創建文件系統并掛載3.3 保存RAID配置四、性能基準測試4.1 初始性能測試4.2 創建測試數據集五、故障…

機器人逆運動學進階:李代數、矩陣指數與旋轉流形計算

做機器人逆運動學(IK)的時候,你遲早會遇到矩陣指數和對數這些東西。為什么呢?因為計算三維旋轉的誤差,不能簡單地用歐氏距離那一套,那只對位置有效。旋轉得用另一套方法——你需要算兩個旋轉矩陣之間的差異…

計算機視覺(opencv)實戰十八——圖像透視轉換

圖像透視變換詳解與實戰在圖像處理中,透視變換(Perspective Transform) 是一種常見的幾何變換,用來將圖像中某個四邊形區域拉伸或壓縮,映射到一個矩形區域。常見應用場景包括:糾正拍照時的傾斜(…

【飛書多維表格插件】

coze中添加飛書多維表格記錄插件 添加單條記錄 [{"fields":{"任務詳情":"選項1","是否完成":"未完成"}}]添加多條記錄 [{"fields":{"任務詳情":"選項1","是否完成":"已完…

Java基礎 9.14

1.Collection接口遍歷對象方式2-for循環增強增強for循環,可以代替iterator選代器,特點:增強for就是簡化版的iterator本質一樣 只能用于遍歷集合或數組package com.logic.collection_;import java.util.ArrayList; import java.util.Collectio…

數據結構(C語言篇):(十三)堆的應用

目錄 前言 一、堆排序 1.1 版本一:基于已有數組建堆、取棧頂元素完成排序 1.1.1 實現邏輯 1.1.2 底層原理 1.1.3 應用示例 1.1.4 執行流程 1.2 版本二:原地排序 —— 標準堆排序 1.2.1 實現邏輯 1.2.2 底層原理 1.2.3 時間復雜度計算…

4步OpenCV-----掃秒身份證號

這段代碼用 OpenCV 做了一份“數字模板字典”,然后在銀行卡/身份證照片里自動找到身份證號那一行,把每個數字切出來跟模板比對,最終輸出并高亮顯示出完整的身份證號碼,下面是代碼解釋:模塊 1 工具箱(通用函…

馮諾依曼體系:現代計算機的基石與未來展望

馮諾依曼體系:現代計算機的基石與未來展望 引人入勝的開篇 當你用手機刷視頻、用電腦辦公時,是否想過這些設備背后共享的底層邏輯?從指尖輕滑切換APP,到電腦秒開文檔,這種「無縫銜接」的體驗,其實藏著一個改…

前端基礎 —— C / JavaScript基礎語法

以下是對《3.JavaScript(基礎語法).pdf》的內容大綱總結:---📘 一、JavaScript 簡介 - 定義:腳本語言,最初用于表單驗證,現為通用編程語言。 - 應用:網頁開發、游戲、服務器(Node.js&#xff09…

springboot 二手物品交易系統設計與實現

springboot 二手物品交易系統設計與實現 目錄 【SpringBoot二手交易系統全解析】從0到1搭建你的專屬平臺! 🔍 需求確認:溝通對接 🗣 📊 系統功能結構:附思維導圖 ☆開發技術: &#x1f6e…

【Android】可折疊式標題欄

在 Android 應用開發中,精美的用戶界面可以顯著提升應用品質和用戶體驗。Material Design 組件中的 CollapsingToolbarLayout 能夠為應用添加動態、流暢的折疊效果,讓標題欄不再是靜態的元素。本文將深入探討如何使用 CollapsingToolbarLayout 創建令人驚…

Debian13下使用 Vim + Vimspector + ST-LINK v2.1 調試 STM32F103 指南

1. 硬件準備與連接 1.1 所需硬件 STM32F103C8T6 最小系統板ST-LINK v2.1 調試器連接線(杜邦線) 1.2 硬件連接 ST-LINK v2.1 ? STM32F103C8T6 連接方式:ST-LINK v2.1 引腳STM32F103C8T6 引腳功能說明SWDIOPA13數據線SWCLKPA14時鐘線GNDGND共地…

第21課:成本優化與資源管理

第21課:成本優化與資源管理 課程目標 掌握計算資源優化 學習成本控制策略 了解資源調度算法 實踐實現成本優化系統 課程內容 21.1 成本分析框架 成本分析系統 class CostAnalysisFramework {constructor(config) {this.config

SAP HANA Scale-out 04:CalculationView優化

CV執行過程計算視圖激活時,生成Stored ModelSELECT查詢時:首先將Stored Model實例化為runtime Model 計算引擎執行優化,將runtime Model轉換為Optimized Runtime ModelOptimized Runtime Model通過SQL Optimizer進行優化計算引擎優化特性說明…

鴻蒙審核問題——Scroll中嵌套了List/Grid時滑動問題

文章目錄背景原因解決辦法1、借鑒Flutter中的解決方式,如下圖2、鴻蒙Next中對應的解決方式,如下圖3、官方文檔回訪背景 來源一次審核被拒的情況。也是出于粗心導致的。之前在flutter項目中也是遇到過這種問題的。其實就是滾動視圖內嵌滾動視圖造成的&am…

測試電商購物車功能,設計測試case

在電商場景中,購物車是連接商品瀏覽與下單支付的關鍵環節,需要從功能、性能、兼容性、安全性等多維度進行測試。以下是購物車功能的測試用例設計: 一、功能測試 1. 商品添加到購物車 - 未登錄狀態下,添加商品到購物車(…

Linux --- 常見的基本指令

一. 前言本篇博客使用的 Linux 操作系統是 centos ,用來學習Linux 的 Linux 系統的內核版本和系統架構信息版本如下所示:上圖的主要結構為:主版本號-次版本號 修正次數,3.10.0 是操作系統的主版本號;當我們在維護一段L…

微信小程序 -開發郵箱注冊驗證功能

一、前端驗證:正則表達式與插件結合正則表達式設計 使用通用郵箱格式校驗正則,并允許中文域名(如.中國): const emailReg /^[a-zA-Z0-9._%-][a-zA-Z0-9-](?:\.[a-zA-Z0-9-])*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2})?$/i;…

docker 部署 code-server

docker 部署 code-servercode-serverError response from daemon: Get "https://registry-1.docker.io/v2/": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headersdocker 配置正確步驟 阿里云源permission de…