生成對抗網絡(Generative Adversarial Networks ,GAN)

生成對抗網絡是深度學習領域最具革命性的生成模型之一。

一 GAN框架

1.1組成

構造生成器(G)與判別器(D)進行動態對抗,實現數據的無監督生成。

G(造假者):接收噪聲 $z \sim p_z$?,生成數據 $G(z)$?。?

D(鑒定家):接收真實數據 $x \sim p_{data}$ 和生成數據 $G(z)$,輸出概率 $D(x)$?或 $D(G(z))$

1.2核心原理

對抗目標:

$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$

該公式為極大極小博弈,D和G互為對手,在動態博弈中驅動模型逐步提升性能。

其中:

第一項(真實性強化)?$\mathbb{E}_{x \sim p_{data}}[\log D(x)]$?的目標為:

讓判別器 D 將真實數據?x?識別為“真”(即讓 $D(x) \to 1$),最大化這一項可使D對真實數據的判斷置信度更高。

第二項(生成性對抗):$\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$的目標為:

生成器G希望生成的假數據$G(z)$被判別器D判為“真”(即讓 $D(G(z)) \to 1$)從而最小化$\log(1 - D(G(z)))$,判別器D則希望判假數據為“假”(即讓$D(G(z)) \to 0$),從而最大化$\log(1 - D(G(z)))$

生成器和判別器在此項上存在直接對抗。

數學原理

(1)最優判別器理論

固定生成器G時,最大化?$V(D,G)$?得到最優判別器?$D^*$

$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$

其中,?$p_g(x)$?是生成數據的分布。當生成數據完美匹配真實分布時?$p_{g} = p_{data}$,判別器無法區分真假(輸出$D(x)=0.5 \quad \forall x$)。

推導一下:

$V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (1)

轉換為積分形式,設真實數據分布為?$p_{data}(x)$,生成數據分布為$p_g(x)$,則(1)可以改寫為:

$V(D, G) = \int_{x} p_{data}(x) \log D(x) dx + \int_{x} p_g(x) \log(1 - D(x)) dx$? ? ? ? ? ? ? ? ? ? ? ? ? ? ?(2)?

逐點最大化對于每個樣本 x,單獨最大化以下函數:

$f(D(x)) = p_{data}(x) \log D(x) + p_g(x) \log(1 - D(x))$? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (3)

求導并解方程:?

$f'(D(x)) = \frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0$? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (4)

求得:

$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? ? ? ? (5)?

(2) 目標化簡:JS散度(Jensen-Shannon Divergence)?

將最優判別器?$D^*$?代入原目標函數,可得:

$V(G,D^*) = 2 \cdot JSD(p_{data} \| p_g) - 2\log 2$

最小化目標即等價于最小化?$p_{data}$$p_g$ 的JS散度。

JS散度特性:對稱、非負,衡量兩個分布的相似性。

1.3訓練過程解釋

每個訓練步驟包含兩階段:

(1)判別器更新(固定G,最大化 $V(D,G)$

$\nabla_D \left[ \log D(x) + \log(1 - D(G(z))) \right]$

通過梯度上升優化D的參數,提升判別能力。

(2)生成器更新(固定D,最小化 $V(D,G)$

$\nabla_G \left[ \log(1 - D(G(z))) \right]$

實際訓練中常用 $\nabla_G \left[ -\log D(G(z)) \right]$代替以增強梯度穩定性。

訓練中出現的問題

(1)JS散度飽和導致梯度消失

(2)參數空間的非凸優化(存在無數個局部極值,優化算法極易陷入次優解,而非全局最優解)使訓練難以收斂

二 經典GAN架構

DCGAN(GAN+卷積)

特性原始GANDCGAN
網絡結構全連接層(MLPs)卷積生成器 + 卷積判別器
穩定性容易梯度爆炸/消失,難以收斂通過BN和特定激活函數穩定訓練
生成圖像分辨率低分辨率(如32x32)支持64x64及以上分辨率的清晰圖像生成
圖像質量輪廓模糊,缺乏細節細粒度紋理(如毛發、磚紋)
計算效率參數量大,訓練速度慢卷積結構參數共享,效率提升

(1)生成器架構(反卷積)

class Generator(nn.Module):def __init__(self, noise_dim=100, output_channels=3):super().__init__()self.main = nn.Sequential(# 輸入:100維噪聲,輸出:1024x4x4nn.ConvTranspose2d(noisel_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 上采樣至8x8nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 輸出層:3通道RGB圖像nn.ConvTranspose2d(64, output_channels, 4, 2, 1, bias=False),nn.Tanh()  # 將輸出壓縮到[-1,1])

(2) 判別器架構(卷積)

class Discriminator(nn.Module):def __init__(self, input_channels=3):super().__init__()self.main = nn.Sequential(# 輸入:3x64x64nn.Conv2d(input_channels, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 下采樣至32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 輸出層:二分類概率nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())

(3)雙重優化問題

保持生成器和判別器動態平衡的核心機制

for epoch in range(num_epochs):# 更新判別器optimizer_D.zero_grad()real_loss = adversarial_loss(D(real_imgs), valid)fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()# 更新生成器optimizer_G.zero_grad()g_loss = adversarial_loss(D(gen_imgs), valid)  # 欺詐判別器g_loss.backward()optimizer_G.step()

三 應用場景?

圖像合成引擎(語義圖到照片)、醫學影像增強、語音與音頻合成。

GAN作為生成式AI的基石模型,其核心價值不僅在于數據生成能力,更在于構建了一種全新的深度學習范式——通過對抗博弈驅動模型持續進化。

四 一個完整DCGAN代碼示例

MNIST數據集

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 參數設置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
image_size = 64
num_epochs = 50
latent_dim = 100
lr = 0.0002
beta1 = 0.5# 數據準備
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST是單通道
])dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)# 可視化輔助函數
def show_images(images):plt.figure(figsize=(8,8))images = images.permute(1,2,0).cpu().numpy()plt.imshow((images * 0.5) + 0.5)  # 反歸一化plt.axis('off')plt.show()# 權重初始化
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)# 生成器定義
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.main = nn.Sequential(# 輸入:latent_dim x 1 x 1nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 輸出:512 x 4 x 4nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 輸出:256 x 8 x 8nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 輸出:128 x 16x16nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 輸出:64 x 32x32nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),nn.Tanh()  # 輸出范圍[-1,1]# 最終輸出:1 x 64x64)def forward(self, input):return self.main(input)# 判別器定義
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(# 輸入:1 x 64x64nn.Conv2d(1, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 輸出:64 x32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 輸出:128x16x16nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),# 輸出:256x8x8nn.Conv2d(256, 512, 4, 2, 1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),# 輸出:512x4x4nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input).view(-1, 1).squeeze(1)# 初始化模型
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)# 應用權重初始化
generator.apply(weights_init)
discriminator.apply(weights_init)# 損失函數和優化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))# 訓練過程可視化準備
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)# 訓練循環
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 準備數據real_images = real_images.to(device)batch_size = real_images.size(0)# 真實標簽和虛假標簽real_labels = torch.full((batch_size,), 0.9, device=device)  # label smoothingfake_labels = torch.full((batch_size,), 0.0, device=device)# ========== 訓練判別器 ==========optimizer_D.zero_grad()# 真實圖片的判別結果outputs_real = discriminator(real_images)loss_real = criterion(outputs_real, real_labels)# 生成假圖片noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)fake_images = generator(noise)# 假圖片的判別結果outputs_fake = discriminator(fake_images.detach())loss_fake = criterion(outputs_fake, fake_labels)# 合并損失并反向傳播loss_D = loss_real + loss_fakeloss_D.backward()optimizer_D.step()# ========== 訓練生成器 ==========optimizer_G.zero_grad()# 更新生成器時的判別結果outputs = discriminator(fake_images)loss_G = criterion(outputs, real_labels)  # 欺騙判別器# 反向傳播loss_G.backward()optimizer_G.step()# 打印訓練狀態if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")# 每個epoch結束時保存生成結果with torch.no_grad():test_images = generator(fixed_noise)grid = torchvision.utils.make_grid(test_images, nrow=8, normalize=True)show_images(grid)# 保存模型檢查點if (epoch+1) % 5 == 0:torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')print("訓練完成!")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/81198.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/81198.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/81198.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

httpclient請求出現403

問題 httpclient請求對方服務器報403,用postman是可以的 解決方案: request.setHeader( “User-Agent” ,“Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:50.0) Gecko/20100101 Firefox/50.0” ); // 設置請求頭 原因: 因為沒有設置為瀏覽器形式&#…

嵌入式硬件篇---IIC

文章目錄 前言1. IC協議基礎1.1 物理層特性兩根信號線SCLSDA支持多主多從 標準模式電平 1.2 通信流程起始條件(Start Condition)從機地址(Slave Address)應答(ACK/NACK)數據傳輸:停止條件&#…

深入探討 Java 注解:從基礎到高級應用

Java 注解自 Java 5 引入以來,已成為現代 Java 開發中不可或缺的一部分。它們通過為代碼添加元數據,簡化了配置、增強了代碼可讀性,并支持了從編譯時驗證到運行時動態行為的多種功能。本文將全面探討 Java 注解的使用、定義和處理方式,并通過一個實際的插件系統示例展示其強…

力扣-105.從前序與中序遍歷序列構造二叉樹

題目描述 給定兩個整數數組 preorder 和 inorder &#xff0c;其中 preorder 是二叉樹的先序遍歷&#xff0c; inorder 是同一棵樹的中序遍歷&#xff0c;請構造二叉樹并返回其根節點。 class Solution { public:TreeNode* buildTree(vector<int>& preorder, vecto…

NoSQL數據庫技術與應用復習總結【看到最后】

第1章 初識NoSQL 1.1 大數據時代對數據存儲的挑戰 1.高并發讀寫需求 2.高效率存儲與訪問需求 3.高擴展性 1.2 認識NoSQL NoSQL--非關系型、分布式、不提供ACID的數據庫設計模式 NoSQL特點 1.易擴展 2.高性能 3.靈活的數據模型 4.高可用 NoSQL擁有一個共同的特點&am…

【ios越獄包安裝失敗?uniapp導出ipa文件如何安裝到蘋果手機】蘋果IOS直接安裝IPA文件

問題場景&#xff1a; 提示&#xff1a;ipa是用于蘋果設備安裝的軟件包資源 設備&#xff1a;iphone 13(未越獄) 安裝包類型&#xff1a;ipa包 調試工具&#xff1a;hbuilderx 問題描述 提要&#xff1a;ios包無法安裝 uniapp導出ios包無法安裝 相信有小伙伴跟我一樣&…

php數據導出pdf,然后pdf轉圖片,再推送釘釘群

public function takePdf($data_plan, $data_act, $file_name, $type){$pdf new \TCPDF(L); // L - 橫向 P-豎向// 設置文檔信息//$file_name 外協批價單;$pdf->SetCreator($file_name);$pdf->SetAuthor($file_name);$pdf->SetTitle($file_name);$pdf->SetSubjec…

每日算法-250513

每日算法 - 2024-05-13 記錄今天學習的算法題解。 2335. 裝滿杯子需要的最短總時長 題目 思路 貪心 這道題的關鍵在于每次操作盡可能多地減少杯子的數量。我們每次操作可以裝一杯或兩杯&#xff08;不同類型&#xff09;。為了最小化總時間&#xff0c;應該優先選擇裝兩杯不同…

城市生命線綜合管控系統解決方案-守護城市生命線安全

一、政策背景 國務院辦公廳《城市安全風險綜合監測預警平臺建設指南》?要求&#xff1a;將燃氣、供水、排水、橋梁、熱力、綜合管廊等納入城市生命線監測體系&#xff0c;建立"能監測、會預警、快處置"的智慧化防控機制。住建部?《"十四五"全國城市基礎…

分布式AI推理的成功之道

隨著AI模型逐漸成為企業運營的核心支柱&#xff0c;實時推理已成為推動這一轉型的關鍵引擎。市場對即時、可決策的AI洞察需求激增&#xff0c;而AI代理——正迅速成為推理技術的前沿——即將迎來爆發式普及。德勤預測&#xff0c;到2027年&#xff0c;超半數采用生成式AI的企業…

auto.js面試題及答案

以下是常見的 Auto.js 面試題及參考答案&#xff0c;涵蓋基礎知識、腳本編寫、運行機制、權限、安全等方面&#xff0c;適合開發崗位的技術面試準備&#xff1a; 一、基礎類問題 什么是 Auto.js&#xff1f;它的主要用途是什么&#xff1f; 答案&#xff1a; Auto.js 是一個…

C語言中的指定初始化器

什么是指定初始化器? C99標準引入了一種更靈活、直觀的初始化語法——指定初始化器(designated initializer), 可以在初始化列表中直接引用結構體或聯合體成員名稱的語法。通過這種方式,我們可以跳過某些不需要初始化的成員,并且可以以任意順序對特定成員進行初始化。這…

高德地圖在Vue3中的使用方法

1.地圖初始化 容器創建&#xff1a;通過 <div> 標簽定義地圖掛載點。 <div id"container" style"height: 300px; width: 100%; margin-top: 10px;"></div> 密鑰配置&#xff1a;綁定高德地圖安全密鑰&#xff0c;確保 API 合法調用。 參…

RabbitMQ發布訂閱模式深度解析與實踐指南

目錄 RabbitMQ發布訂閱模式深度解析與實踐指南1. 發布訂閱模式核心原理1.1 消息分發模型1.2 核心組件對比 2. 交換機類型詳解2.1 交換機類型矩陣2.2 消息生命周期 3. 案例分析與實現案例1&#xff1a;基礎廣播消息系統案例2&#xff1a;分級日志處理系統案例3&#xff1a;分布式…

中小型培訓機構都用什么教務管理系統?

在教育培訓行業快速發展的今天&#xff0c;中小型培訓機構面臨著學員管理復雜、課程體系多樣化、教學效果難以量化等挑戰。一個高效的教務管理系統已成為機構運營的核心支撐。本文將深入分析當前市場上適用于中小型培訓機構的教務管理系統&#xff0c;重點介紹愛耕云這一專業解…

C++虛函數食用筆記

虛函數定義與作用&#xff1a; virtual關鍵字聲明虛函數&#xff0c;虛函數可被派生類override(保證返回類型與參數列表&#xff0c;名字均相同&#xff09;&#xff0c;從而通過基類指針調用時&#xff0c;實現多態的功能 virtual關鍵字: 將函數聲明為虛函數 override關鍵…

運算放大器相關的電路

1運算放大器介紹 解釋&#xff1a;運算放大器本質就是一個放大倍數很大的元件&#xff0c;就如上圖公式所示 Vp和Vn相差很小但是放大后輸出還是會很大。 運算放大器不止上面的三個引腳&#xff0c;他需要獨立供電&#xff1b; 如圖比較器&#xff1a; 解釋&#xff1a;Vp&…

華為OD機試真題——通信系統策略調度(用戶調度問題)(2025B卷:100分)Java/python/JavaScript/C/C++/GO最佳實現

2025 B卷 100分 題型 本專欄內全部題目均提供Java、python、JavaScript、C、C++、GO六種語言的最佳實現方式; 并且每種語言均涵蓋詳細的問題分析、解題思路、代碼實現、代碼詳解、3個測試用例以及綜合分析; 本文收錄于專欄:《2025華為OD真題目錄+全流程解析+備考攻略+經驗分…

Ubuntu 系統默認已安裝 python,此處只需添加一個超鏈接即可

步驟 1&#xff1a;確認 Python 3 的安裝路徑 查看當前 Python 3 的路徑&#xff1a; which python3 輸出類似&#xff1a; /usr/bin/python3 步驟 2&#xff1a;創建符號鏈接 使用 ln -s 創建符號鏈接&#xff0c;將 python 指向 python3&#xff1a; sudo ln -s /usr/b…

深度學習-分布式訓練機制

1、分布式訓練時&#xff0c;包括train.py的全部的代碼都會在每個gpu上運行嗎&#xff1f; 在分布式訓練&#xff08;如使用 PyTorch 的 DistributedDataParallel&#xff0c;DDP&#xff09;時&#xff0c;每個 GPU 上運行的進程會執行 train.py 的全部代碼&#xff0c;但通過…