生成對抗網絡是深度學習領域最具革命性的生成模型之一。
一 GAN框架
1.1組成
構造生成器(G)與判別器(D)進行動態對抗,實現數據的無監督生成。
G(造假者):接收噪聲 ?,生成數據
?。?
D(鑒定家):接收真實數據 和生成數據
,輸出概率
?或
1.2核心原理
對抗目標:
該公式為極大極小博弈,D和G互為對手,在動態博弈中驅動模型逐步提升性能。
其中:
第一項(真實性強化)??的目標為:
讓判別器 D 將真實數據?x?識別為“真”(即讓 ),最大化這一項可使D對真實數據的判斷置信度更高。
第二項(生成性對抗):的目標為:
生成器G希望生成的假數據被判別器D判為“真”(即讓
)從而最小化
,判別器D則希望判假數據為“假”(即讓
),從而最大化
。
生成器和判別器在此項上存在直接對抗。
數學原理:
(1)最優判別器理論
固定生成器G時,最大化??得到最優判別器?
。
其中,??是生成數據的分布。當生成數據完美匹配真實分布時?
,判別器無法區分真假(輸出
)。
推導一下:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (1)
轉換為積分形式,設真實數據分布為?,生成數據分布為
,則(1)可以改寫為:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?(2)?
逐點最大化,對于每個樣本 x,單獨最大化以下函數:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (3)
求導并解方程:?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (4)
求得:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? ? ? ? (5)?
(2) 目標化簡:JS散度(Jensen-Shannon Divergence)?
將最優判別器??代入原目標函數,可得:
最小化目標即等價于最小化? 與
的JS散度。
JS散度特性:對稱、非負,衡量兩個分布的相似性。
1.3訓練過程解釋
每個訓練步驟包含兩階段:
(1)判別器更新(固定G,最大化 ):
通過梯度上升優化D的參數,提升判別能力。
(2)生成器更新(固定D,最小化 ):
實際訓練中常用 代替以增強梯度穩定性。
訓練中出現的問題:
(1)JS散度飽和導致梯度消失
(2)參數空間的非凸優化(存在無數個局部極值,優化算法極易陷入次優解,而非全局最優解)使訓練難以收斂
二 經典GAN架構
DCGAN(GAN+卷積)
特性 | 原始GAN | DCGAN |
---|---|---|
網絡結構 | 全連接層(MLPs) | 卷積生成器 + 卷積判別器 |
穩定性 | 容易梯度爆炸/消失,難以收斂 | 通過BN和特定激活函數穩定訓練 |
生成圖像分辨率 | 低分辨率(如32x32) | 支持64x64及以上分辨率的清晰圖像生成 |
圖像質量 | 輪廓模糊,缺乏細節 | 細粒度紋理(如毛發、磚紋) |
計算效率 | 參數量大,訓練速度慢 | 卷積結構參數共享,效率提升 |
(1)生成器架構(反卷積)
class Generator(nn.Module):def __init__(self, noise_dim=100, output_channels=3):super().__init__()self.main = nn.Sequential(# 輸入:100維噪聲,輸出:1024x4x4nn.ConvTranspose2d(noisel_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 上采樣至8x8nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 輸出層:3通道RGB圖像nn.ConvTranspose2d(64, output_channels, 4, 2, 1, bias=False),nn.Tanh() # 將輸出壓縮到[-1,1])
(2) 判別器架構(卷積)
class Discriminator(nn.Module):def __init__(self, input_channels=3):super().__init__()self.main = nn.Sequential(# 輸入:3x64x64nn.Conv2d(input_channels, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 下采樣至32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 輸出層:二分類概率nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())
(3)雙重優化問題
保持生成器和判別器動態平衡的核心機制。
for epoch in range(num_epochs):# 更新判別器optimizer_D.zero_grad()real_loss = adversarial_loss(D(real_imgs), valid)fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()# 更新生成器optimizer_G.zero_grad()g_loss = adversarial_loss(D(gen_imgs), valid) # 欺詐判別器g_loss.backward()optimizer_G.step()
三 應用場景?
圖像合成引擎(語義圖到照片)、醫學影像增強、語音與音頻合成。
GAN作為生成式AI的基石模型,其核心價值不僅在于數據生成能力,更在于構建了一種全新的深度學習范式——通過對抗博弈驅動模型持續進化。
四 一個完整DCGAN代碼示例
MNIST數據集
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 參數設置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
image_size = 64
num_epochs = 50
latent_dim = 100
lr = 0.0002
beta1 = 0.5# 數據準備
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)) # MNIST是單通道
])dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)# 可視化輔助函數
def show_images(images):plt.figure(figsize=(8,8))images = images.permute(1,2,0).cpu().numpy()plt.imshow((images * 0.5) + 0.5) # 反歸一化plt.axis('off')plt.show()# 權重初始化
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)# 生成器定義
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.main = nn.Sequential(# 輸入:latent_dim x 1 x 1nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 輸出:512 x 4 x 4nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 輸出:256 x 8 x 8nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 輸出:128 x 16x16nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 輸出:64 x 32x32nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),nn.Tanh() # 輸出范圍[-1,1]# 最終輸出:1 x 64x64)def forward(self, input):return self.main(input)# 判別器定義
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(# 輸入:1 x 64x64nn.Conv2d(1, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 輸出:64 x32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 輸出:128x16x16nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),# 輸出:256x8x8nn.Conv2d(256, 512, 4, 2, 1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),# 輸出:512x4x4nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input).view(-1, 1).squeeze(1)# 初始化模型
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)# 應用權重初始化
generator.apply(weights_init)
discriminator.apply(weights_init)# 損失函數和優化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))# 訓練過程可視化準備
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)# 訓練循環
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 準備數據real_images = real_images.to(device)batch_size = real_images.size(0)# 真實標簽和虛假標簽real_labels = torch.full((batch_size,), 0.9, device=device) # label smoothingfake_labels = torch.full((batch_size,), 0.0, device=device)# ========== 訓練判別器 ==========optimizer_D.zero_grad()# 真實圖片的判別結果outputs_real = discriminator(real_images)loss_real = criterion(outputs_real, real_labels)# 生成假圖片noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)fake_images = generator(noise)# 假圖片的判別結果outputs_fake = discriminator(fake_images.detach())loss_fake = criterion(outputs_fake, fake_labels)# 合并損失并反向傳播loss_D = loss_real + loss_fakeloss_D.backward()optimizer_D.step()# ========== 訓練生成器 ==========optimizer_G.zero_grad()# 更新生成器時的判別結果outputs = discriminator(fake_images)loss_G = criterion(outputs, real_labels) # 欺騙判別器# 反向傳播loss_G.backward()optimizer_G.step()# 打印訓練狀態if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")# 每個epoch結束時保存生成結果with torch.no_grad():test_images = generator(fixed_noise)grid = torchvision.utils.make_grid(test_images, nrow=8, normalize=True)show_images(grid)# 保存模型檢查點if (epoch+1) % 5 == 0:torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')print("訓練完成!")