PyTorch生成式人工智能——ACGAN詳解與實現
- 0. 前言
- 1. ACGAN 簡介
- 1.1 ACGAN 技術原理
- 1.2 ACGAN 核心思想
- 1.3 損失函數
- 2. 模型訓練流程
- 3. 使用 PyTorch 構建 ACGAN
- 3.1 數據處理
- 3.2 模型構建
- 3.3 模型訓練
- 3.4 模型測試
- 相關鏈接
0. 前言
在生成對抗網絡 (Generative Adversarial Network, GAN) 的眾多變體中,ACGAN
(Auxiliary Classifier GAN
) 是一個非常經典且實用的條件生成模型。它的核心思想是:在判別器中除了保留“真假判別”這一任務外,額外加入一個輔助分類器,讓判別器同時預測輸入樣本的類別。這樣,生成器在訓練時不僅需要“欺騙判別器”,還必須生成能夠被正確分類的樣本,從而在圖像語義和類別可控性上得到顯著提升。
這一改進讓 ACGAN
能夠在條件圖像生成中表現出色,在復雜數據集上實現按類別生成的能力。相比于傳統條件生成對抗網絡 (Conditional GAN, cGAN) 簡單地把標簽拼接到輸入,ACGAN
通過 “輔助分類監督” 提供了更細粒度的學習信號,使得生成器得到的梯度更加穩定和有意義。在本節中,將詳細介紹 ACGAN
原理,并使用 PyTorch
構建 ACGAN
模型。
1. ACGAN 簡介
1.1 ACGAN 技術原理
生成對抗網絡 (Generative Adversarial Network, GAN) 的眾多變體中,ACGAN
(Auxiliary Classifier GAN
) 能夠從隨機噪聲中生成逼真的圖像、文本甚至音樂。然而,傳統的 GAN
有一個顯著的局限性:缺乏對生成過程的精確控制。我們無法指定要生成“數字7”的圖片還是一只“戴墨鏡的貓”。
為了解決這個問題,條件生成對抗網絡 (Conditional GAN, cGAN) 應運而生。它通過將類別標簽信息同時注入生成器 (Generator
) 和判別器 (Discriminator
),實現了條件生成。但這仍然不夠完美,cGAN
的判別器最終只輸出一個“真/假”的概率,它并沒有顯式地告訴生成器它生成的圖片在類別上是否正確。
ACGAN
(Auxiliary Classifier GAN
) 正是在 CGAN
的基礎上,對判別器的任務進行了至關重要的擴展。它不僅判斷真偽,還同時擔任一個“分類器”的角色。這個簡單的改變,極大地提升了生成圖像的質量和多樣性,尤其是在生成特定類別的圖像時。
1.2 ACGAN 核心思想
ACGAN
的核心思想非常直觀:為判別器增加一個輔助任務——對輸入圖像進行分類。其中,生成器的輸入包括隨機噪聲向量 zzz 和目標類別標簽 ccc;判別器的輸出包括:
- 一個源 (
Source
) 輸出:一個標量概率,表示圖像是來自真實數據分布的概率 - 一個輔助類別 (
Class
) 輸出:類別概率分布
通過引入這個輔助的分類任務,ACGAN
迫使判別器不僅要學習“什么樣的圖像是真實的”,還要學習“真實圖像屬于什么類別”。反過來,生成器為了欺騙這個更強大的判別器,也必須生成既逼真又類別分明的圖像。
1.3 損失函數
損失函數包含兩部分:
- 源判別損失 (
source loss
),用來訓練真假判別,通常使用二元交叉熵 - 類別判別損失 (
auxiliary classification loss
),使用多元交叉熵(真實圖像的類別為真實標簽,生成圖像的類別為生成器的條件標簽
訓練目標:
- 判別器
D
,最小化源判別損失(正確區分真實/虛假圖像)并最小化類別判別損失(正確預測類別) - 生成器
G
:生成圖像以最大化判別器認為是“真實”的概率,并最小化判別器給出的類別預測與條件類的一致性
2. 模型訓練流程
模型訓練流程如下:
- 從真實數據中取一批數據 x,c{x,c}x,c
- 判別器更新:
- 計算真實樣本的源判別損失與類別判別損失
- 用噪聲和隨機標簽生成虛假樣本 x~=G(z,c)\tilde x=G(z,c)x~=G(z,c),計算虛假樣本的源判別損失與類別判別損失(可選)
- 把這些損失加權后更新
D
- 生成器更新:
- 用一批噪聲與條件標簽生成樣本 x~\tilde xx~
- 通過
D
計算源輸出與輔助類別輸出 - 生成器的損失是希望源輸出為“真實”,并希望輔助類別輸出為生成時的條件標簽
- 更新
G
3. 使用 PyTorch 構建 ACGAN
接下來,使用 PyTorch
實現 ACGAN
,并在 MNIST
數據集上進行訓練生成手寫數字。
3.1 數據處理
(1) 首先,導入所需庫并設置超參數:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os# 設置超參數
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10
batch_size = 64
lr = 0.0002
num_epochs = 100
sample_interval = 400# 創建輸出目錄
os.makedirs("images", exist_ok=True)
os.makedirs("models", exist_ok=True)
(2) 加載 MNIST
數據集,將圖像轉換為張量,并將像素值從 [0,1]
歸一化到 [-1,1]
范圍:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transform
)
(3) 構建數據加載器:
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True
)
3.2 模型構建
(1) 首先,定義權重初始化函數:
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)
(2) 定義生成器。生成器接收隨機噪聲和類別標簽作為輸入,通過嵌入層將標簽轉換為與噪聲相同維度的向量,然后將二者相乘融合,之后通過全連接層和轉置卷積層逐步上采樣,最終生成 28 x 28
的圖像:
class Generator(nn.Module):def __init__(self, latent_dim, num_classes):super(Generator, self).__init__()# 將類別標簽轉換為嵌入向量self.label_emb = nn.Embedding(num_classes, latent_dim)self.init_size = 7 # 初始特征圖大小self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 1, 3, stride=1, padding=1),nn.Tanh())def forward(self, noise, labels):# 將噪聲和標簽嵌入相乘gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return img
除了將類別標簽轉換為嵌入向量進行融合外,也可以直接使用標簽的獨熱編碼與噪聲向量進行拼接。
(3) 定義判別器。判別器使用卷積層逐步提取特征,最后通過兩個全連接層分別輸出樣本真偽的概率(源判別輸出)和類別概率(類別判別輸出):
class Discriminator(nn.Module):def __init__(self, num_classes):super(Discriminator, self).__init__()# 卷積層提取特征self.features = nn.Sequential(# 輸入: 1x28x28nn.Conv2d(1, 16, 3, stride=2, padding=1), # 16x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.Conv2d(16, 32, 3, stride=2, padding=1), # 32x7x7nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(32, 0.8),nn.Conv2d(32, 64, 3, stride=2, padding=1), # 64x4x4nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(64, 0.8),nn.Conv2d(64, 128, 3, stride=2, padding=1), # 128x2x2nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(128, 0.8),)# 計算特征圖大小: 128 * 2 * 2 = 512self.feature_size = 128 * 2 * 2# 輸出真實/虛假的概率self.adv_layer = nn.Sequential(nn.Linear(self.feature_size, 1), nn.Sigmoid())# 輸出類別概率self.aux_layer = nn.Sequential(nn.Linear(self.feature_size, num_classes), nn.Softmax(dim=1))def forward(self, img):# 提取特征features = self.features(img)features = features.view(features.size(0), -1) # 展平# 預測真偽和類別validity = self.adv_layer(features)label = self.aux_layer(features)return validity, label
(4) 初始化生成器和判別器,并打印模型結構:
generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)# 初始化權重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 打印模型結構
print("Generator structure:")
print(generator)
print("\nDiscriminator structure:")
print(discriminator)
輸出模型結構如下所示:
3.3 模型訓練
(1) 初始化損失函數和優化器:
# 定義損失函數
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()# 定義優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
(2) 定義變量記錄訓練過程的損失變化:
G_losses = []
D_losses = []
(3) 實現訓練循環。訓練過程分為兩個部分,先訓練判別器,使其能正確區分真實和生成樣本,并正確分類;然后訓練生成器,使其能生成被判別器判定為真實且分類正確的樣本:
# 訓練循環
for epoch in range(num_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 準備真實/虛假標簽valid = torch.ones(batch_size, 1).to(device)fake = torch.zeros(batch_size, 1).to(device)# 真實圖像和標簽real_imgs = imgs.to(device)real_labels = labels.to(device)# 訓練判別器optimizer_D.zero_grad()# 真實樣本的損失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, real_labels)) / 2# 生成虛假樣本z = torch.randn(batch_size, latent_dim).to(device)gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)gen_imgs = generator(z, gen_labels)# 虛假樣本的損失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 總判別器損失d_loss = (d_real_loss + d_fake_loss) / 2# 計算判別器準確率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([real_labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()# 訓練生成器optimizer_G.zero_grad()# 生成器希望判別器將虛假樣本判斷為真實validity, pred_label = discriminator(gen_imgs)g_loss = (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)) / 2g_loss.backward()optimizer_G.step()# 記錄損失G_losses.append(g_loss.item())D_losses.append(d_loss.item())# 打印訓練狀態if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}, acc: {100*d_acc:.2f}%] "f"[G loss: {g_loss.item():.4f}]")# 定期保存生成的圖像樣本batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:# 保存生成器生成的圖像save_image(gen_imgs.data[:25], f"images/{batches_done}.png", nrow=5, normalize=True)
(4) 訓練完成后,保存模型權重:
torch.save(generator.state_dict(), "models/generator_final.pth")
torch.save(discriminator.state_dict(), "models/discriminator_final.pth")
(5) 繪制訓練過程中的損失曲線
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curve.png")
plt.show()
3.4 模型測試
在 images
文件夾中可以看到訓練過程中生成的樣本,隨著訓練進行,生成的數字越來越清晰:
使用訓練完成的模型,生成制定類別的數字:
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, i)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: {i}")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()
生成數字 1
:
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, 1)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: 1")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()
相關鏈接
PyTorch生成式人工智能實戰:從零打造創意引擎
PyTorch生成式人工智能(1)——神經網絡與模型訓練過程詳解
PyTorch生成式人工智能(2)——PyTorch基礎
PyTorch生成式人工智能(3)——使用PyTorch構建神經網絡
PyTorch生成式人工智能(4)——卷積神經網絡詳解
PyTorch生成式人工智能(5)——分類任務詳解
PyTorch生成式人工智能(6)——生成模型(Generative Model)詳解
PyTorch生成式人工智能(7)——生成對抗網絡實踐詳解
PyTorch生成式人工智能(8)——深度卷積生成對抗網絡
PyTorch生成式人工智能(9)——Pix2Pix詳解與實現
PyTorch生成式人工智能(10)——CyclelGAN詳解與實現
PyTorch生成式人工智能(11)——神經風格遷移
PyTorch生成式人工智能(12)——StyleGAN詳解與實現
PyTorch生成式人工智能(13)——WGAN詳解與實現
PyTorch生成式人工智能(14)——條件生成對抗網絡(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成對抗網絡(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自編碼器(AutoEncoder)詳解
PyTorch生成式人工智能(17)——變分自編碼器詳解與實現
PyTorch生成式人工智能(18)——循環神經網絡詳解與實現
PyTorch生成式人工智能(19)——自回歸模型詳解與實現
PyTorch生成式人工智能(20)——像素卷積神經網絡(PixelCNN)
PyTorch生成式人工智能(21)——歸一化流模型(Normalizing Flow Model)
PyTorch生成式人工智能(27)——從零開始訓練GPT模型
PyTorch生成式人工智能(28)——MuseGAN詳解與實現