Pytorch深度學習框架60天進階學習計劃 - 第41天:生成對抗網絡進階(一)

Pytorch深度學習框架60天進階學習計劃 - 第41天:生成對抗網絡進階(一)

今天我們將深入探討生成對抗網絡(GAN)的進階內容,特別是Wasserstein GAN(WGAN)的梯度懲罰機制,以及條件生成與無監督生成在模式坍塌方面的差異。

生成對抗網絡是近年來深度學習領域最激動人心的進展之一,它由Ian Goodfellow于2014年提出,通過生成器和判別器的博弈來學習生成真實數據分布的樣本。隨著研究的深入,GAN的改進版本層出不窮,其中WGAN及其梯度懲罰版本(WGAN-GP)解決了原始GAN訓練不穩定的問題,成為了GAN研究的重要里程碑。

今天我們將從理論到實踐,系統地學習這些進階概念,并通過PyTorch實現相關模型,探索其工作原理。

1. GAN基礎回顧

在深入WGAN之前,讓我們簡要回顧GAN的基本原理:

1.1 GAN的基本架構

GAN由兩部分組成:

  • 生成器(Generator): 學習從隨機噪聲生成看起來真實的數據
  • 判別器(Discriminator): 學習區分真實數據和生成器生成的假數據

這兩個網絡通過對抗訓練相互提高:生成器嘗試生成越來越逼真的樣本以欺騙判別器,而判別器則努力提高其區分真假樣本的能力。

1.2 原始GAN的問題

雖然GAN的思想非常優雅,但原始GAN在訓練過程中存在一些問題:

  1. 訓練不穩定:很難找到生成器和判別器之間的平衡點
  2. 梯度消失:當判別器表現過好時,生成器梯度接近于零
  3. 模式坍塌:生成器只生成有限種類的樣本,無法覆蓋真實數據的全部分布
  4. 難以量化訓練進度:缺乏有效的指標來衡量生成樣本的質量

這些問題促使研究者尋找GAN的改進版本,其中WGAN是最重要的改進之一。

2. Wasserstein GAN詳解

2.1 從JS散度到Wasserstein距離

原始GAN隱式地最小化生成分布與真實分布之間的Jensen-Shannon(JS)散度,這在兩個分布沒有顯著重疊時會導致梯度問題。

Wasserstein距離(也稱Earth Mover’s Distance,簡稱EMD)提供了一種更平滑的度量方式,即使兩個分布沒有重疊或重疊很少,也能提供有意義的梯度。

Wasserstein距離的直觀解釋:想象將一個分布的概率質量移動到另一個分布所需的最小"工作量",其中工作量定義為概率質量乘以移動距離。

2.2 WGAN的核心改進

WGAN相比原始GAN有以下關鍵改進:

  1. 目標函數改變:使用Wasserstein距離而非JS散度
  2. 判別器(現稱為評論家/Critic)輸出不再是概率:移除了最后的sigmoid激活函數
  3. 權重裁剪:限制評論家的參數在一定范圍內,滿足Lipschitz約束
  4. 避免使用基于動量的優化器:建議使用RMSProp或Adam優化器(學習率較小)

2.3 WGAN的目標函數

WGAN的目標函數如下:

min ? G max ? D ∈ D E x ~ P r [ D ( x ) ] ? E z ~ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] Gmin?DDmax?ExPr??[D(x)]?EzPz??[D(G(z))]

其中 D \mathcal{D} D是滿足1-Lipschitz約束的函數集合。

2.4 Lipschitz約束與權重裁剪

為了滿足Wasserstein距離計算中的Lipschitz約束,WGAN對評論家的參數進行了權重裁剪:將權重限制在 [ ? c , c ] [-c, c] [?c,c]的范圍內,其中 c c c是一個小常數(如0.01)。

然而,權重裁剪是一種粗糙的方法,會導致優化問題和容量浪費。這就引出了WGAN的進一步改進:梯度懲罰機制。

3. WGAN的梯度懲罰機制

3.1 權重裁剪的局限性

WGAN中的權重裁剪雖然簡單有效,但存在以下問題:

  1. 容量浪費:強制權重接近0或c,導致模型傾向于使用更簡單的函數
  2. 優化困難:可能導致梯度爆炸或消失
  3. 對架構敏感:不同網絡架構可能需要不同的裁剪范圍

3.2 梯度懲罰的原理

WGAN-GP(帶梯度懲罰的WGAN)提出了一種更優雅的方式來滿足Lipschitz約束。其核心思想是:

對于一個1-Lipschitz函數,其梯度范數在任何地方都不應超過1。因此,我們可以通過懲罰評論家函數梯度范數偏離1的行為來滿足這一約束。

具體來說,WGAN-GP在真實數據和生成數據之間的隨機插值點上施加梯度懲罰:

L G P = E x ^ ~ P x ^ [ ( ∣ ∣ ? x ^ D ( x ^ ) ∣ ∣ 2 ? 1 ) 2 ] \mathcal{L}_{GP} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] LGP?=Ex^Px^??[(∣∣?x^?D(x^)2??1)2]

其中 x ^ \hat{x} x^是在真實樣本 x x x和生成樣本 G ( z ) G(z) G(z)之間的隨機插值:

x ^ = ? x + ( 1 ? ? ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon)G(z) x^=?x+(1??)G(z)

? \epsilon ?是一個在 [ 0 , 1 ] [0,1] [0,1]之間均勻采樣的隨機數。

3.3 WGAN-GP的完整目標函數

將梯度懲罰添加到WGAN的目標函數中,我們得到WGAN-GP的目標函數:

L = E z ~ p ( z ) [ D ( G ( z ) ) ] ? E x ~ p d a t a [ D ( x ) ] + λ E x ^ ~ P x ^ [ ( ∣ ∣ ? x ^ D ( x ^ ) ∣ ∣ 2 ? 1 ) 2 ] \mathcal{L} = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] L=Ezp(z)?[D(G(z))]?Expdata??[D(x)]+λEx^Px^??[(∣∣?x^?D(x^)2??1)2]

其中 λ \lambda λ是梯度懲罰的權重,通常設為10。

3.4 WGAN-GP的優勢

WGAN-GP相比WGAN有以下優勢:

  1. 更好的穩定性:避免了權重裁剪帶來的問題
  2. 更快的收斂:通常需要更少的迭代次數
  3. 更好的生成質量:能生成更多樣、更高質量的樣本
  4. 架構靈活性:適用于各種GAN架構,包括深度卷積網絡

4. PyTorch實現WGAN-GP

下面我們使用PyTorch實現一個簡單的WGAN-GP模型,用于生成MNIST手寫數字。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
np.random.seed(42)# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超參數
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
lambda_gp = 10  # 梯度懲罰權重# 數據加載
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])  # 歸一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 生成器網絡
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_features, out_features, normalize=True):layers = [nn.Linear(in_features, out_features)]if normalize:layers.append(nn.BatchNorm1d(out_features, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh()  # 輸出歸一化到[-1, 1])def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img# 判別器網絡(評論家)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1)# 注意:沒有sigmoid激活函數)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# 初始化網絡
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 計算梯度懲罰
def compute_gradient_penalty(D, real_samples, fake_samples):"""計算WGAN-GP中的梯度懲罰"""# 在真實樣本和生成樣本之間隨機插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 計算插值點的判別器輸出d_interpolates = D(interpolates)# 計算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 計算梯度范數gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 訓練函數
def train_wgan_gp():# 用于記錄損失d_losses = []g_losses = []for epoch in range(n_epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)batch_size = real_imgs.shape[0]# ---------------------#  訓練判別器# ---------------------optimizer_D.zero_grad()# 生成隨機噪聲z = torch.randn(batch_size, latent_dim, device=device)# 生成一批假圖像fake_imgs = generator(z)# 判別器前向傳播real_validity = discriminator(real_imgs)fake_validity = discriminator(fake_imgs.detach())# 計算梯度懲罰gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)# WGAN-GP 判別器損失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代訓練一次生成器n_critic = 5if i % n_critic == 0:# ---------------------#  訓練生成器# ---------------------optimizer_G.zero_grad()# 生成一批新的假圖像gen_imgs = generator(z)# 判別器評估假圖像fake_validity = discriminator(gen_imgs)# WGAN 生成器損失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每個epoch結束后保存生成的圖像樣本if (epoch + 1) % 10 == 0:save_sample_images(epoch)# 繪制損失曲線plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('wgan_gp_loss.png')plt.close()# 保存樣本圖像
def save_sample_images(epoch):# 生成并保存樣本圖像z = torch.randn(25, latent_dim, device=device)gen_imgs = generator(z).detach().cpu()# 將圖像像素值從[-1, 1]轉換為[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5# 創建圖像網格fig, axs = plt.subplots(5, 5, figsize=(10, 10))for i in range(5):for j in range(5):axs[i, j].imshow(gen_imgs[i*5+j, 0, :, :], cmap='gray')axs[i, j].axis('off')# 保存圖像plt.savefig(f'wgan_gp_epoch_{epoch+1}.png')plt.close()# 運行訓練
if __name__ == "__main__":train_wgan_gp()

這段代碼實現了一個基本的WGAN-GP模型,用于生成MNIST數字圖像。下面我們來解析代碼的關鍵部分:

  1. 梯度懲罰計算compute_gradient_penalty函數實現了WGAN-GP的核心——在真實樣本和生成樣本之間的插值點上計算梯度懲罰。
  2. 判別器損失:包括真實數據的評論家值、生成數據的評論家值,以及梯度懲罰項。
  3. 生成器損失:僅包含生成數據的評論家值的負期望。
  4. 優化器設置:使用Adam優化器,但β1參數設為0.5,這是GAN訓練的常見設置。
  5. 訓練循環:判別器和生成器交替訓練,但判別器通常訓練多次(n_critic=5)后才訓練一次生成器。

5. WGAN-GP訓練流程圖

以下是WGAN-GP的訓練流程圖,幫助理解整個訓練過程:

┌────────────────────┐
│  初始化網絡和優化器  │
└──────────┬─────────┘│▼
┌────────────────────┐
│    開始訓練循環     │
└──────────┬─────────┘│▼
┌────────────────────┐
│  從數據集加載真實樣本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│  生成隨機噪聲并產生  │
│     假樣本         │
└──────────┬─────────┘│▼
┌────────────────────┐
│  計算判別器對真實   │
│   和假樣本的輸出    │
└──────────┬─────────┘│▼
┌────────────────────┐
│  在樣本插值點上計算  │
│    梯度懲罰        │
└──────────┬─────────┘│▼
┌────────────────────┐
│   計算判別器損失    │
│   并更新判別器參數   │
└──────────┬─────────┘│▼┌────┴─────┐│ i % n_critic ││   == 0?   │└────┬─────┘No        │       Yes┌─────────┘       └──────────┐│                            ▼│              ┌────────────────────┐│              │   重新生成假樣本    ││              └──────────┬─────────┘│                         ││                         ▼│              ┌────────────────────┐│              │   計算生成器損失    ││              │   并更新生成器參數   ││              └──────────┬─────────┘│                         │└─────────────────────────┘│▼
┌────────────────────┐
│ 是否達到預定訓練輪數? │
└──────────┬─────────┘No   │   Yes┌────┘       └──────────┐│                       ▼│            ┌────────────────────┐└──────?     │      結束訓練      │└────────────────────┘

這個流程圖展示了WGAN-GP的訓練過程,包括梯度懲罰的計算和判別器多次訓練的機制。與普通GAN相比,WGAN-GP的關鍵區別在于梯度懲罰的引入和目標函數的改變。

6. 條件生成與無監督生成的對比

接下來,我們將探討條件生成與無監督生成在模式坍塌方面的差異。

6.1 無監督生成與模式坍塌

無監督生成是指生成器僅從隨機噪聲生成樣本,沒有額外的條件輸入。

模式坍塌(Mode Collapse)是GAN訓練中的常見問題,指生成器只學會生成數據分布中的少數幾種模式,而忽略了其他模式。例如,在MNIST數據集上,模型可能只生成數字"1"而不生成其他數字。

導致模式坍塌的原因:

  1. 判別器更新不足:判別器無法有效區分真假樣本
  2. 梯度消失:當判別器表現過好時,生成器梯度接近零
  3. 目標函數設計問題:JS散度在兩個分布不重疊時提供有限的梯度信息

6.2 條件生成對模式坍塌的緩解

條件生成是指生成器不僅接收隨機噪聲,還接收額外的條件信息(如類別標簽)作為輸入。

條件GAN(CGAN)通過以下方式緩解模式坍塌:

  1. 強制生成器覆蓋所有類別:通過提供不同的類別條件,迫使生成器學習生成不同類別的樣本
  2. 簡化學習任務:條件信息使生成器只需要學習條件分布,而非整個聯合分布
  3. 提供更多監督信號:條件信息為生成器提供了額外的指導

6.3 條件生成與無監督生成的模式坍塌差異表

特性無監督生成條件生成
輸入僅隨機噪聲隨機噪聲 + 條件信息
模式覆蓋容易忽略部分模式被條件強制覆蓋更多模式
生成樣本多樣性較低,傾向于生成相似樣本較高,不同條件生成不同樣本
訓練穩定性較差,易發生模式坍塌較好,條件信息提供穩定指導
應用靈活性生成過程不可控可控制生成特定類別/屬性的樣本
實現復雜度相對簡單需要額外的條件嵌入機制

7. 實現條件WGAN-GP

下面我們將實現一個條件版本的WGAN-GP,以比較其與無監督版本在模式坍塌方面的差異。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 設置隨機種子
torch.manual_seed(42)
np.random.seed(42)# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超參數
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
n_classes = 10  # MNIST有10個類別
lambda_gp = 10  # 梯度懲罰權重# 數據加載
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])  # 歸一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 條件生成器網絡
class ConditionalGenerator(nn.Module):def __init__(self):super(ConditionalGenerator, self).__init__()# 嵌入層將類別標簽轉換為嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 輸入層處理噪聲和類別嵌入self.input_layer = nn.Linear(latent_dim + n_classes, 128)# 主要模型self.model = nn.Sequential(nn.BatchNorm1d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, noise, labels):# 將標簽嵌入向量與噪聲拼接label_embedding = self.label_embedding(labels)x = torch.cat([noise, label_embedding], dim=1)# 通過輸入層x = self.input_layer(x)# 通過主模型x = self.model(x)# 重塑為圖像格式img = x.view(x.size(0), *img_shape)return img# 條件判別器網絡
class ConditionalDiscriminator(nn.Module):def __init__(self):super(ConditionalDiscriminator, self).__init__()# 嵌入層將類別標簽轉換為嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 處理圖像和標簽self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)) + n_classes, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1))def forward(self, img, labels):# 將圖像展平img_flat = img.view(img.size(0), -1)# 獲取標簽嵌入label_embedding = self.label_embedding(labels)# 拼接圖像特征和標簽嵌入x = torch.cat([img_flat, label_embedding], dim=1)# 通過判別器網絡validity = self.model(x)return validity# 初始化網絡
generator = ConditionalGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)# 初始化優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 計算梯度懲罰(條件版本)
def compute_gradient_penalty(D, real_samples, fake_samples, labels):"""計算條件WGAN-GP的梯度懲罰"""# 在真實樣本和生成樣本之間隨機插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 計算插值點的判別器輸出(帶條件)d_interpolates = D(interpolates, labels)# 計算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 計算梯度范數gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 訓練條件WGAN-GP
def train_conditional_wgan_gp():# 用于記錄損失d_losses = []g_losses = []# 用于記錄生成樣本的多樣性(通過類別分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------#  訓練判別器# ---------------------optimizer_D.zero_grad()# 生成隨機噪聲z = torch.randn(batch_size, latent_dim, device=device)# 為生成器生成隨機標簽gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假圖像fake_imgs = generator(z, gen_labels)# 判別器前向傳播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 計算梯度懲罰gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判別器損失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代訓練一次生成器n_critic = 5if i % n_critic == 0:# ---------------------#  訓練生成器# ---------------------optimizer_G.zero_grad()# 為生成器生成新的隨機標簽gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假圖像gen_imgs = generator(z, gen_labels)# 判別器評估假圖像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器損失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每個epoch結束后,評估生成樣本的類別分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的圖像樣本save_sample_images(epoch)# 繪制損失曲線plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 繪制類別分布變化plot_class_distributions(class_distributions)# 評估生成樣本的類別分布
def evaluate_class_distribution():"""評估生成樣本在各類別上的分布情況"""# 創建一個預訓練的分類器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一個卷積層以適應灰度圖classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全連接層以適應MNIST的10個類別classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加載預先訓練好的MNIST分類器權重(這里假設我們有一個預訓練的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 簡化起見,這里我們使用一個簡單的CNN分類器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 這里假設這個簡單分類器已經在MNIST上訓練好了# 實際應用中,應該加載一個預先訓練好的模型# 生成1000個樣本z = torch.randn(1000, latent_dim, device=device)# 均勻采樣所有類別gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分類器預測類別with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 計算每個類別的樣本數量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 繪制類別分布變化
def plot_class_distributions(class_distributions):"""繪制生成樣本類別分布的變化"""epochs = [10, 20, 30, 40, 50]  # 假設每10個epoch評估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3)  # 限制y軸范圍,便于比較if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存樣本圖像(條件版本)
def save_sample_images(epoch):"""保存按類別排列的樣本圖像"""# 為每個類別生成樣本n_row = 10  # 每個類別一行n_col = 10  # 每個類別10個樣本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定類別fixed_class = torch.tensor([i] * n_col, device=device)# 隨機噪聲z = torch.randn(n_col, latent_dim, device=device)# 生成圖像gen_imgs = generator(z, fixed_class).detach().cpu()# 轉換到[0, 1]范圍gen_imgs = 0.5 * gen_imgs + 0.5# 顯示圖像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 運行條件WGAN-GP訓練
if __name__ == "__main__":train_conditional_wgan_gp()

清華大學全五版的《DeepSeek教程》完整的文檔需要的朋友,關注我私信:deepseek 即可獲得。

怎么樣今天的內容還滿意嗎?再次感謝朋友們的觀看,關注GZH:凡人的AI工具箱,回復666,送您價值199的AI大禮包。最后,祝您早日實現財務自由,還請給個贊,謝謝!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76751.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76751.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76751.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大模型到底是怎么產生的?一文了解大模型誕生全過程

前言 大模型到底是怎么產生的呢? 本文將從最基礎的概念開始,逐步深入,用通俗易懂的語言為大家揭開大模型的神秘面紗。 大家好,我是大 F,深耕AI算法十余年,互聯網大廠核心技術崗。 知行合一,不寫水文,喜歡可關注,分享AI算法干貨、技術心得。 【專欄介紹】: 歡迎關注《…

五子棋(測試報告)

文章目錄 一、項目介紹二、測試用例三、自動化測試用例的部分展示注冊登錄游戲大廳游戲匹配 總結 一、項目介紹 本項目是一款基于Spring、SpringMVC、MyBatis、WebSocket的雙人實時對戰五子棋游戲,游戲操作便捷,功能清晰明了。 二、測試用例 三、自動化測試用例的…

idea開發工具多賬號使用拉取代碼報錯問題

設置git不使用憑證管理 把 use credential helper 取消勾選 然后重新pull代碼,并勾選remember 這樣就可以使用多賬號來連接管理代碼了

【OpenCV】【XTerminal】talk程序運用和linux進程之間通信程序編寫,opencv圖像庫編程聯系

目錄 一、talk程序的運用&Linux進程間通信程序的編寫 1.1使用talk程序和其他用戶交流 1.2用c語言寫一個linux進程之間通信(聊天)的簡單程序 1.服務器端程序socket_server.c編寫 2.客戶端程序socket_client.c編寫 3.程序編譯與使用 二、編寫一個…

【軟考系統架構設計師】信息系統基礎知識點

1、 信息的特點:客觀性(真偽性)、動態性、層次性、傳遞性、滯后性、擴壓性、分享性 2、 信息化:是指從工業社會到信息社會的演進與變革 3、 信息系統是由計算機硬件、網絡和通信設備、計算機軟件、信息資源、信息用戶和規章制度…

一種基于學習的多尺度方法及其在非彈性碰撞問題中的應用

A learning-based multiscale method and its application to inelastic impact problems 摘要: 我們在工程應用中觀察和利用的材料宏觀特性,源于電子、原子、缺陷、域等多尺度物理機制間復雜的相互作用。多尺度建模旨在通過利用固有的層次化結構來理解…

基于PyQt5的Jupyter Notebook轉Python工具

一、項目背景與核心價值 在數據科學領域,Jupyter Notebook因其交互特性廣受歡迎,但在生產環境中通常需要將其轉換為標準Python文件。本文介紹一款基于PyQt5開發的桌面級轉換工具,具有以下核心價值: 可視化操作:提供友好的GUI界面,告別命令行操作 批量處理:支持目錄遞歸…

圖論之并查集——含例題

目錄 介紹 秩是什么 例子——快速入門 例題 使用路徑壓縮,不使用秩合并 使用路徑壓縮和秩合并 無向圖和有向圖 介紹 并查集是一種用于 處理不相交集合的合并與查詢問題的數據結構。它主要涉及以下基本概念和操作: 基本概念: 集合&…

【數學建模】(智能優化算法)天牛須算法(Beetle Antennae Search, BAS)詳解與Python實現

天牛須算法(Beetle Antennae Search, BAS)詳解與Python實現 文章目錄 天牛須算法(Beetle Antennae Search, BAS)詳解與Python實現1. 引言2. 算法原理2.1 基本思想2.2 數學模型 3. Python實現4.實測效果測試1. Michalewicz函數的最小化測試2. Goldstein-Price函數的約束最小化 5…

【家政平臺開發(42)】筑牢家政平臺安全防線:安全測試與漏洞修復指南

本【家政平臺開發】專欄聚焦家政平臺從 0 到 1 的全流程打造。從前期需求分析,剖析家政行業現狀、挖掘用戶需求與梳理功能要點,到系統設計階段的架構選型、數據庫構建,再到開發階段各模塊逐一實現。涵蓋移動與 PC 端設計、接口開發及性能優化,測試階段多維度保障平臺質量,…

學習筆記八——內存管理相關

📘 目錄 內存結構基礎:棧、堆、數據段Rust 的內存管理機制(對比 C/C、Java)Drop:Rust 的自動清理機制Deref:為什么 *x 能訪問結構體內部值Rc:多個變量“共享一個資源”怎么辦?Weak&…

ReliefF 的原理

🌟 ReliefF 是什么? ReliefF 是一種“基于鄰居差異”的特征選擇方法,用來評估每個特征對分類任務的貢獻大小。 它的核心問題是: “我怎么知道某個特征是不是重要?是不是有能力把不同類別的數據區分開?” 而…

?asm匯編源代碼之-漢字點陣字庫顯示程序源代碼下載?

漢字點陣字庫顯示程序 源代碼下載 文本模式下顯示16x16點陣漢字庫內容的程序(標準16x16字庫需要使用CHGHZK轉換過后才能使用本程序正常顯示) 本程序需要調用file.asm和string.asm中的子程序,所以連接時需要把它們連接進來,如下 C:\> tlink showhzk file string 調用參…

【已更新完畢】2025泰迪杯數據挖掘競賽B題數學建模思路代碼文章教學:基于穿戴裝備的身體活動監測

基于穿戴裝備的身體活動監測 摘要 本研究基于加速度計采集的活動數據,旨在分析和統計100名志愿者在不同身體活動類別下的時長分布。通過對加速度數據的處理,活動被劃分為睡眠、靜態活動、低強度、中等強度和高強度五類,進而計算每個志愿者在…

Ubuntu24.04裝機安裝指南

文章目錄 Ubuntu24.04裝機安裝指南一、分區說明二、基礎軟件三、使用fcitx5配置中文輸入法四、安裝搜狗輸入法【**不推薦**】1. 安裝fcitx2. 安裝輸入法 五、禁用/home目錄下自動生成文件夾六、更新軟件源1. 針對**新配置方式**的清華源替換方法2. 針對**老配置方式**的清華源替…

互聯網三高-數據庫高并發之分庫分表ShardingJDBC

1 ShardingJDBC介紹 1.1 常見概念術語 ① 數據節點Node:數據分片的最小單元,由數據源名稱和數據表組成 如:ds0.product_order_0 ② 真實表:再分片的數據庫中真實存在的物理表 如:product_order_0 ③ 邏輯表&#xff1a…

BM25、BGE以及text2vec-base-chinese的區別

BM25、BGE以及text2vec-base-chinese的區別 BM25 原理:BM25(Best Matching 25)是一種基于概率檢索模型的算法,它通過考慮查詢詞與文檔之間的匹配程度、文檔的長度等因素,來計算文檔對于查詢的相關性得分。具體來說,它會給包含查詢詞次數較多、文檔長度適中的文檔更高的分…

Python中try用法、內置異常類型與自定義異常類型拓展

目錄 try介紹與語法格式try具體使用案例except的異常類型簡介案例內置的常見異常類型自定義異常類型繼承關系用途 注意事項 try介紹與語法格式 在 Python 里,try 語句主要用于異常處理,其作用是捕獲并處理代碼運行期間可能出現的異常,避免程…

【第41節】windows的中斷與異常及異常處理方式

目錄 一、中斷與異常處理 1.1 中斷與異常 1.2 IDT 1.3 異常的概念 1.4 異常分類 二、windows異常處理方式 2.1 概述 2.2 結構化異常處理 2.3 向量化異常處理之VEH 2.4 向量化異常處理之VCH 2.5 默認的異常處理函數 2.6 如何手動安裝 SEH 節點 2.7 異常處理的優先級…

分布式日志治理:Log4j2自定義Appender寫日志到RocketMQ

🧑 博主簡介:CSDN博客專家,歷代文學網(PC端可以訪問:https://literature.sinhy.com/#/?__c1000,移動端可微信小程序搜索“歷代文學”)總架構師,15年工作經驗,精通Java編…