生成式人工智能實戰 | 生成對抗網絡
- 0. 前言
- 1. 生成對抗網絡
- 2. 模型構建
- 2.1 生成器
- 2.2 判別器
- 3. 模型訓練
- 3.1 數據加載
- 3.2 訓練流程
0. 前言
生成對抗網絡 (Generative Adversarial Networks
, GAN
) 是一種由兩個相互競爭的神經網絡組成的深度學習模型,它由一個生成網絡和一個判別網絡組成,通過彼此之間的博弈來提高生成網絡的性能。生成對抗網絡使用神經網絡生成與原始圖像集非常相似的新圖像,它在圖像生成中應用廣泛,且 GAN
的相關研究正在迅速發展,以生成與真實圖像難以區分的逼真圖像。在本節中,我們將學習 GAN
網絡的原理并使用 PyTorch
實現 GAN
。
1. 生成對抗網絡
生成對抗網絡 (Generative Adversarial Networks
, GAN
) 包含兩個網絡:生成網絡( Generator
,也稱生成器)和判別網絡( discriminator
,也稱判別器)。在 GAN
網絡訓練過程中,需要有一個合理的圖像樣本數據集,生成網絡從圖像樣本中學習圖像表示,然后生成與圖像樣本相似的圖像。判別網絡接收(由生成網絡)生成的圖像和原始圖像樣本作為輸入,并將圖像分類為原始(真實)圖像或生成(偽造)圖像:
- 生成器 G ( z ; θ G ) G(z;θ_G) G(z;θG?) 接受噪聲 z ~ p z z~p_z z~pz??,學習映射到數據空間,以“欺騙”判別器
- 判別器 D ( x ; θ D ) D(x;θ_D) D(x;θD?) 輸出樣本 x x x 屬于真實數據的概率,旨在區分真實與生成數據
兩者通過以下最小–最大化 (minimax
) 目標函數進行博弈:
m i n G m a x D ? V ( D , G ) = E x ~ p d a t a [ l o g ? D ( x ) ] + E z ~ p z [ l o g ? ( 1 ? D ( G ( z ) ) ) ] \underset G {min} \underset D {max} ?V(D,G)=\mathbb E_{x~p_{data}}[log?D(x)]+\mathbb E_{z~p_z}[log?(1?D(G(z)))] Gmin?Dmax??V(D,G)=Ex~pdata??[log?D(x)]+Ez~pz??[log?(1?D(G(z)))]
生成網絡的目標是生成逼真的偽造圖像騙過判別網絡,判別網絡的目標是將生成的圖像分類為偽造圖像,將原始圖像樣本分類為真實圖像。本質上,GAN
中的對抗表示兩個網絡的相反性質,生成網絡生成圖像來欺騙判別網絡,判別網絡通過判別圖像是生成圖像還是原始圖像來對輸入圖像進行分類:
在上圖中,生成網絡根據輸入隨機噪聲生成圖像,判別網絡接收生成網絡生成的圖像,并將它們與真實圖像樣本進行比較,以判斷生成的圖像是真實的還是偽造的。生成網絡嘗試生成盡可能逼真的圖像,而判別網絡嘗試判定生成網絡生成圖像的真實性,從而學習生成盡可能逼真的圖像。
GAN
的關鍵思想是生成網絡和判別網絡之間的競爭和動態平衡,通過不斷的訓練和迭代,生成網絡和判別網絡會逐漸提高性能,生成網絡能夠生成更加逼真的樣本,而判別網絡則能夠更準確地區分真實和偽造的樣本。
通常,生成網絡和判別網絡交替訓練,將生成網絡和判別網絡視為博弈雙方,并通過兩者之間的對抗來推動模型性能的提升,直到生成網絡生成的樣本能夠以假亂真,判別網絡無法分辨真實樣本和生成樣本之間的差異:
- 生成網絡的訓練過程:凍結判別網絡權重,生成網絡以噪聲
z
作為輸入,通過最小化生成網絡與真實數據之間的差異來學習如何生成更好的樣本,以便判別網絡將圖像分類為真實圖像 - 判別網絡的訓練過程:凍結生成網絡權重,判別網絡通過最小化真實樣本和假樣本之間的分類誤差來更新判別網絡,區分真實樣本和生成樣本,將生成網絡生成的圖像分類為偽造圖像
重復訓練生成網絡與判別網絡,直到達到平衡,當判別網絡能夠很好地檢測到生成的圖像時,生成網絡對應的損失比判別網絡對應的損失要高得多。通過不斷訓練生成網絡和判別網絡,直到生成網絡可以生成逼真圖像,而判別網絡無法區分真實圖像和生成圖像。
2. 模型構建
2.1 生成器
生成器由若干全連接層與 LeakyReLU
激活構成,最后用 Tanh
將輸出映射至 [?1,1]
范圍內:
# 定義生成器 G
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=10):super().__init__()self.net = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh() # 輸出像素映射到 [-1,1])def forward(self, z):return self.net(z).view(-1,1,28,28)
2.2 判別器
判別器使用全連接層與 LeakyReLU
激活,末端使用 Sigmoid
激活函數,輸出一個標量真值估計:
# 定義判別器 D,對輸入圖片輸出真偽概率
class Discriminator(nn.Module):def __init__(self, img_dim=28*28):super().__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(img_dim, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)
3. 模型訓練
接下來,使用 MNIST
數據集訓練 GAN
模型。
3.1 數據加載
將 MNIST
像素值歸一化到生成器 Tanh
輸出所需的 [?1,1]
區間:
# 加載并歸一化 MNIST 數據集
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)) # 映射到 [-1,1]
])
train_ds = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
3.2 訓練流程
首先,初始化模型、優化器與損失函數:
import torch
import torch.optim as optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
z_dim = 20
G = Generator(z_dim).to(device)
D = Discriminator().to(device)opt_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()
訓練模型 50
個 epoch
:
epochs = 50for epoch in range(epochs):for real, _ in train_loader:real = real.to(device)batch_size = real.size(0)real_labels = torch.ones(batch_size, 1, device=device)fake_labels = torch.zeros(batch_size, 1, device=device)# 訓練判別器# 在真實樣本上進行訓練D_real = D(real)loss_D_real = loss_fn(D_real, real_labels)opt_D.zero_grad()loss_D_real.backward()opt_D.step()# 在虛假樣本上進行訓練z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake.detach())loss_D_fake = loss_fn(D_fake, fake_labels)opt_D.zero_grad()loss_D_fake.backward()opt_D.step()d_loss = (loss_D_real + loss_D_fake) / 2# 訓練生成器z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake)loss_G = loss_fn(D_fake, real_labels) # 生成器希望D認為它生成的是真的opt_G.zero_grad()loss_G.backward()opt_G.step()print(f"Epoch [{epoch+1}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {loss_G.item():.4f}")
使用訓練后的模型生成偽造數據:
# 采樣生成圖片并顯示
import matplotlib.pyplot as pltG.eval()
with torch.no_grad():z = torch.randn(16, z_dim, device=device)fake_images = G(z).cpu()fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i].squeeze().reshape(28, 28), cmap='gray')ax.axis('off')
plt.tight_layout()
plt.show()