Pytorch深度學習框架60天進階學習計劃 - 第41天:生成對抗網絡進階(一)
今天我們將深入探討生成對抗網絡(GAN)的進階內容,特別是Wasserstein GAN(WGAN)的梯度懲罰機制,以及條件生成與無監督生成在模式坍塌方面的差異。
生成對抗網絡是近年來深度學習領域最激動人心的進展之一,它由Ian Goodfellow于2014年提出,通過生成器和判別器的博弈來學習生成真實數據分布的樣本。隨著研究的深入,GAN的改進版本層出不窮,其中WGAN及其梯度懲罰版本(WGAN-GP)解決了原始GAN訓練不穩定的問題,成為了GAN研究的重要里程碑。
今天我們將從理論到實踐,系統地學習這些進階概念,并通過PyTorch實現相關模型,探索其工作原理。
1. GAN基礎回顧
在深入WGAN之前,讓我們簡要回顧GAN的基本原理:
1.1 GAN的基本架構
GAN由兩部分組成:
- 生成器(Generator): 學習從隨機噪聲生成看起來真實的數據
- 判別器(Discriminator): 學習區分真實數據和生成器生成的假數據
這兩個網絡通過對抗訓練相互提高:生成器嘗試生成越來越逼真的樣本以欺騙判別器,而判別器則努力提高其區分真假樣本的能力。
1.2 原始GAN的問題
雖然GAN的思想非常優雅,但原始GAN在訓練過程中存在一些問題:
- 訓練不穩定:很難找到生成器和判別器之間的平衡點
- 梯度消失:當判別器表現過好時,生成器梯度接近于零
- 模式坍塌:生成器只生成有限種類的樣本,無法覆蓋真實數據的全部分布
- 難以量化訓練進度:缺乏有效的指標來衡量生成樣本的質量
這些問題促使研究者尋找GAN的改進版本,其中WGAN是最重要的改進之一。
2. Wasserstein GAN詳解
2.1 從JS散度到Wasserstein距離
原始GAN隱式地最小化生成分布與真實分布之間的Jensen-Shannon(JS)散度,這在兩個分布沒有顯著重疊時會導致梯度問題。
Wasserstein距離(也稱Earth Mover’s Distance,簡稱EMD)提供了一種更平滑的度量方式,即使兩個分布沒有重疊或重疊很少,也能提供有意義的梯度。
Wasserstein距離的直觀解釋:想象將一個分布的概率質量移動到另一個分布所需的最小"工作量",其中工作量定義為概率質量乘以移動距離。
2.2 WGAN的核心改進
WGAN相比原始GAN有以下關鍵改進:
- 目標函數改變:使用Wasserstein距離而非JS散度
- 判別器(現稱為評論家/Critic)輸出不再是概率:移除了最后的sigmoid激活函數
- 權重裁剪:限制評論家的參數在一定范圍內,滿足Lipschitz約束
- 避免使用基于動量的優化器:建議使用RMSProp或Adam優化器(學習率較小)
2.3 WGAN的目標函數
WGAN的目標函數如下:
min ? G max ? D ∈ D E x ~ P r [ D ( x ) ] ? E z ~ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] Gmin?D∈Dmax?Ex~Pr??[D(x)]?Ez~Pz??[D(G(z))]
其中 D \mathcal{D} D是滿足1-Lipschitz約束的函數集合。
2.4 Lipschitz約束與權重裁剪
為了滿足Wasserstein距離計算中的Lipschitz約束,WGAN對評論家的參數進行了權重裁剪:將權重限制在 [ ? c , c ] [-c, c] [?c,c]的范圍內,其中 c c c是一個小常數(如0.01)。
然而,權重裁剪是一種粗糙的方法,會導致優化問題和容量浪費。這就引出了WGAN的進一步改進:梯度懲罰機制。
3. WGAN的梯度懲罰機制
3.1 權重裁剪的局限性
WGAN中的權重裁剪雖然簡單有效,但存在以下問題:
- 容量浪費:強制權重接近0或c,導致模型傾向于使用更簡單的函數
- 優化困難:可能導致梯度爆炸或消失
- 對架構敏感:不同網絡架構可能需要不同的裁剪范圍
3.2 梯度懲罰的原理
WGAN-GP(帶梯度懲罰的WGAN)提出了一種更優雅的方式來滿足Lipschitz約束。其核心思想是:
對于一個1-Lipschitz函數,其梯度范數在任何地方都不應超過1。因此,我們可以通過懲罰評論家函數梯度范數偏離1的行為來滿足這一約束。
具體來說,WGAN-GP在真實數據和生成數據之間的隨機插值點上施加梯度懲罰:
L G P = E x ^ ~ P x ^ [ ( ∣ ∣ ? x ^ D ( x ^ ) ∣ ∣ 2 ? 1 ) 2 ] \mathcal{L}_{GP} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] LGP?=Ex^~Px^??[(∣∣?x^?D(x^)∣∣2??1)2]
其中 x ^ \hat{x} x^是在真實樣本 x x x和生成樣本 G ( z ) G(z) G(z)之間的隨機插值:
x ^ = ? x + ( 1 ? ? ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon)G(z) x^=?x+(1??)G(z)
? \epsilon ?是一個在 [ 0 , 1 ] [0,1] [0,1]之間均勻采樣的隨機數。
3.3 WGAN-GP的完整目標函數
將梯度懲罰添加到WGAN的目標函數中,我們得到WGAN-GP的目標函數:
L = E z ~ p ( z ) [ D ( G ( z ) ) ] ? E x ~ p d a t a [ D ( x ) ] + λ E x ^ ~ P x ^ [ ( ∣ ∣ ? x ^ D ( x ^ ) ∣ ∣ 2 ? 1 ) 2 ] \mathcal{L} = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] L=Ez~p(z)?[D(G(z))]?Ex~pdata??[D(x)]+λEx^~Px^??[(∣∣?x^?D(x^)∣∣2??1)2]
其中 λ \lambda λ是梯度懲罰的權重,通常設為10。
3.4 WGAN-GP的優勢
WGAN-GP相比WGAN有以下優勢:
- 更好的穩定性:避免了權重裁剪帶來的問題
- 更快的收斂:通常需要更少的迭代次數
- 更好的生成質量:能生成更多樣、更高質量的樣本
- 架構靈活性:適用于各種GAN架構,包括深度卷積網絡
4. PyTorch實現WGAN-GP
下面我們使用PyTorch實現一個簡單的WGAN-GP模型,用于生成MNIST手寫數字。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
np.random.seed(42)# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超參數
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
lambda_gp = 10 # 梯度懲罰權重# 數據加載
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 歸一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 生成器網絡
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_features, out_features, normalize=True):layers = [nn.Linear(in_features, out_features)]if normalize:layers.append(nn.BatchNorm1d(out_features, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 輸出歸一化到[-1, 1])def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img# 判別器網絡(評論家)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1)# 注意:沒有sigmoid激活函數)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# 初始化網絡
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 計算梯度懲罰
def compute_gradient_penalty(D, real_samples, fake_samples):"""計算WGAN-GP中的梯度懲罰"""# 在真實樣本和生成樣本之間隨機插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 計算插值點的判別器輸出d_interpolates = D(interpolates)# 計算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 計算梯度范數gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 訓練函數
def train_wgan_gp():# 用于記錄損失d_losses = []g_losses = []for epoch in range(n_epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)batch_size = real_imgs.shape[0]# ---------------------# 訓練判別器# ---------------------optimizer_D.zero_grad()# 生成隨機噪聲z = torch.randn(batch_size, latent_dim, device=device)# 生成一批假圖像fake_imgs = generator(z)# 判別器前向傳播real_validity = discriminator(real_imgs)fake_validity = discriminator(fake_imgs.detach())# 計算梯度懲罰gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)# WGAN-GP 判別器損失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代訓練一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 訓練生成器# ---------------------optimizer_G.zero_grad()# 生成一批新的假圖像gen_imgs = generator(z)# 判別器評估假圖像fake_validity = discriminator(gen_imgs)# WGAN 生成器損失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每個epoch結束后保存生成的圖像樣本if (epoch + 1) % 10 == 0:save_sample_images(epoch)# 繪制損失曲線plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('wgan_gp_loss.png')plt.close()# 保存樣本圖像
def save_sample_images(epoch):# 生成并保存樣本圖像z = torch.randn(25, latent_dim, device=device)gen_imgs = generator(z).detach().cpu()# 將圖像像素值從[-1, 1]轉換為[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5# 創建圖像網格fig, axs = plt.subplots(5, 5, figsize=(10, 10))for i in range(5):for j in range(5):axs[i, j].imshow(gen_imgs[i*5+j, 0, :, :], cmap='gray')axs[i, j].axis('off')# 保存圖像plt.savefig(f'wgan_gp_epoch_{epoch+1}.png')plt.close()# 運行訓練
if __name__ == "__main__":train_wgan_gp()
這段代碼實現了一個基本的WGAN-GP模型,用于生成MNIST數字圖像。下面我們來解析代碼的關鍵部分:
- 梯度懲罰計算:
compute_gradient_penalty
函數實現了WGAN-GP的核心——在真實樣本和生成樣本之間的插值點上計算梯度懲罰。 - 判別器損失:包括真實數據的評論家值、生成數據的評論家值,以及梯度懲罰項。
- 生成器損失:僅包含生成數據的評論家值的負期望。
- 優化器設置:使用Adam優化器,但β1參數設為0.5,這是GAN訓練的常見設置。
- 訓練循環:判別器和生成器交替訓練,但判別器通常訓練多次(n_critic=5)后才訓練一次生成器。
5. WGAN-GP訓練流程圖
以下是WGAN-GP的訓練流程圖,幫助理解整個訓練過程:
┌────────────────────┐
│ 初始化網絡和優化器 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 開始訓練循環 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 從數據集加載真實樣本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 生成隨機噪聲并產生 │
│ 假樣本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 計算判別器對真實 │
│ 和假樣本的輸出 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 在樣本插值點上計算 │
│ 梯度懲罰 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 計算判別器損失 │
│ 并更新判別器參數 │
└──────────┬─────────┘│▼┌────┴─────┐│ i % n_critic ││ == 0? │└────┬─────┘No │ Yes┌─────────┘ └──────────┐│ ▼│ ┌────────────────────┐│ │ 重新生成假樣本 ││ └──────────┬─────────┘│ ││ ▼│ ┌────────────────────┐│ │ 計算生成器損失 ││ │ 并更新生成器參數 ││ └──────────┬─────────┘│ │└─────────────────────────┘│▼
┌────────────────────┐
│ 是否達到預定訓練輪數? │
└──────────┬─────────┘No │ Yes┌────┘ └──────────┐│ ▼│ ┌────────────────────┐└──────? │ 結束訓練 │└────────────────────┘
這個流程圖展示了WGAN-GP的訓練過程,包括梯度懲罰的計算和判別器多次訓練的機制。與普通GAN相比,WGAN-GP的關鍵區別在于梯度懲罰的引入和目標函數的改變。
6. 條件生成與無監督生成的對比
接下來,我們將探討條件生成與無監督生成在模式坍塌方面的差異。
6.1 無監督生成與模式坍塌
無監督生成是指生成器僅從隨機噪聲生成樣本,沒有額外的條件輸入。
模式坍塌(Mode Collapse)是GAN訓練中的常見問題,指生成器只學會生成數據分布中的少數幾種模式,而忽略了其他模式。例如,在MNIST數據集上,模型可能只生成數字"1"而不生成其他數字。
導致模式坍塌的原因:
- 判別器更新不足:判別器無法有效區分真假樣本
- 梯度消失:當判別器表現過好時,生成器梯度接近零
- 目標函數設計問題:JS散度在兩個分布不重疊時提供有限的梯度信息
6.2 條件生成對模式坍塌的緩解
條件生成是指生成器不僅接收隨機噪聲,還接收額外的條件信息(如類別標簽)作為輸入。
條件GAN(CGAN)通過以下方式緩解模式坍塌:
- 強制生成器覆蓋所有類別:通過提供不同的類別條件,迫使生成器學習生成不同類別的樣本
- 簡化學習任務:條件信息使生成器只需要學習條件分布,而非整個聯合分布
- 提供更多監督信號:條件信息為生成器提供了額外的指導
6.3 條件生成與無監督生成的模式坍塌差異表
特性 | 無監督生成 | 條件生成 |
---|---|---|
輸入 | 僅隨機噪聲 | 隨機噪聲 + 條件信息 |
模式覆蓋 | 容易忽略部分模式 | 被條件強制覆蓋更多模式 |
生成樣本多樣性 | 較低,傾向于生成相似樣本 | 較高,不同條件生成不同樣本 |
訓練穩定性 | 較差,易發生模式坍塌 | 較好,條件信息提供穩定指導 |
應用靈活性 | 生成過程不可控 | 可控制生成特定類別/屬性的樣本 |
實現復雜度 | 相對簡單 | 需要額外的條件嵌入機制 |
7. 實現條件WGAN-GP
下面我們將實現一個條件版本的WGAN-GP,以比較其與無監督版本在模式坍塌方面的差異。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 設置隨機種子
torch.manual_seed(42)
np.random.seed(42)# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超參數
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
n_classes = 10 # MNIST有10個類別
lambda_gp = 10 # 梯度懲罰權重# 數據加載
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 歸一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 條件生成器網絡
class ConditionalGenerator(nn.Module):def __init__(self):super(ConditionalGenerator, self).__init__()# 嵌入層將類別標簽轉換為嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 輸入層處理噪聲和類別嵌入self.input_layer = nn.Linear(latent_dim + n_classes, 128)# 主要模型self.model = nn.Sequential(nn.BatchNorm1d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, noise, labels):# 將標簽嵌入向量與噪聲拼接label_embedding = self.label_embedding(labels)x = torch.cat([noise, label_embedding], dim=1)# 通過輸入層x = self.input_layer(x)# 通過主模型x = self.model(x)# 重塑為圖像格式img = x.view(x.size(0), *img_shape)return img# 條件判別器網絡
class ConditionalDiscriminator(nn.Module):def __init__(self):super(ConditionalDiscriminator, self).__init__()# 嵌入層將類別標簽轉換為嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 處理圖像和標簽self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)) + n_classes, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1))def forward(self, img, labels):# 將圖像展平img_flat = img.view(img.size(0), -1)# 獲取標簽嵌入label_embedding = self.label_embedding(labels)# 拼接圖像特征和標簽嵌入x = torch.cat([img_flat, label_embedding], dim=1)# 通過判別器網絡validity = self.model(x)return validity# 初始化網絡
generator = ConditionalGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)# 初始化優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 計算梯度懲罰(條件版本)
def compute_gradient_penalty(D, real_samples, fake_samples, labels):"""計算條件WGAN-GP的梯度懲罰"""# 在真實樣本和生成樣本之間隨機插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 計算插值點的判別器輸出(帶條件)d_interpolates = D(interpolates, labels)# 計算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 計算梯度范數gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 訓練條件WGAN-GP
def train_conditional_wgan_gp():# 用于記錄損失d_losses = []g_losses = []# 用于記錄生成樣本的多樣性(通過類別分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------# 訓練判別器# ---------------------optimizer_D.zero_grad()# 生成隨機噪聲z = torch.randn(batch_size, latent_dim, device=device)# 為生成器生成隨機標簽gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假圖像fake_imgs = generator(z, gen_labels)# 判別器前向傳播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 計算梯度懲罰gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判別器損失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代訓練一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 訓練生成器# ---------------------optimizer_G.zero_grad()# 為生成器生成新的隨機標簽gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假圖像gen_imgs = generator(z, gen_labels)# 判別器評估假圖像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器損失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每個epoch結束后,評估生成樣本的類別分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的圖像樣本save_sample_images(epoch)# 繪制損失曲線plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 繪制類別分布變化plot_class_distributions(class_distributions)# 評估生成樣本的類別分布
def evaluate_class_distribution():"""評估生成樣本在各類別上的分布情況"""# 創建一個預訓練的分類器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一個卷積層以適應灰度圖classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全連接層以適應MNIST的10個類別classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加載預先訓練好的MNIST分類器權重(這里假設我們有一個預訓練的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 簡化起見,這里我們使用一個簡單的CNN分類器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 這里假設這個簡單分類器已經在MNIST上訓練好了# 實際應用中,應該加載一個預先訓練好的模型# 生成1000個樣本z = torch.randn(1000, latent_dim, device=device)# 均勻采樣所有類別gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分類器預測類別with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 計算每個類別的樣本數量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 繪制類別分布變化
def plot_class_distributions(class_distributions):"""繪制生成樣本類別分布的變化"""epochs = [10, 20, 30, 40, 50] # 假設每10個epoch評估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3) # 限制y軸范圍,便于比較if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存樣本圖像(條件版本)
def save_sample_images(epoch):"""保存按類別排列的樣本圖像"""# 為每個類別生成樣本n_row = 10 # 每個類別一行n_col = 10 # 每個類別10個樣本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定類別fixed_class = torch.tensor([i] * n_col, device=device)# 隨機噪聲z = torch.randn(n_col, latent_dim, device=device)# 生成圖像gen_imgs = generator(z, fixed_class).detach().cpu()# 轉換到[0, 1]范圍gen_imgs = 0.5 * gen_imgs + 0.5# 顯示圖像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 運行條件WGAN-GP訓練
if __name__ == "__main__":train_conditional_wgan_gp()
清華大學全五版的《DeepSeek教程》完整的文檔需要的朋友,關注我私信:deepseek 即可獲得。
怎么樣今天的內容還滿意嗎?再次感謝朋友們的觀看,關注GZH:凡人的AI工具箱,回復666,送您價值199的AI大禮包。最后,祝您早日實現財務自由,還請給個贊,謝謝!