PyTorch生成式人工智能——ACGAN詳解與實現

PyTorch生成式人工智能——ACGAN詳解與實現

    • 0. 前言
    • 1. ACGAN 簡介
      • 1.1 ACGAN 技術原理
      • 1.2 ACGAN 核心思想
      • 1.3 損失函數
    • 2. 模型訓練流程
    • 3. 使用 PyTorch 構建 ACGAN
      • 3.1 數據處理
      • 3.2 模型構建
      • 3.3 模型訓練
      • 3.4 模型測試
    • 相關鏈接

0. 前言

在生成對抗網絡 (Generative Adversarial Network, GAN) 的眾多變體中,ACGAN (Auxiliary Classifier GAN) 是一個非常經典且實用的條件生成模型。它的核心思想是:在判別器中除了保留“真假判別”這一任務外,額外加入一個輔助分類器,讓判別器同時預測輸入樣本的類別。這樣,生成器在訓練時不僅需要“欺騙判別器”,還必須生成能夠被正確分類的樣本,從而在圖像語義和類別可控性上得到顯著提升。
這一改進讓 ACGAN 能夠在條件圖像生成中表現出色,在復雜數據集上實現按類別生成的能力。相比于傳統條件生成對抗網絡 (Conditional GAN, cGAN) 簡單地把標簽拼接到輸入,ACGAN 通過 “輔助分類監督” 提供了更細粒度的學習信號,使得生成器得到的梯度更加穩定和有意義。在本節中,將詳細介紹 ACGAN 原理,并使用 PyTorch 構建 ACGAN 模型。

1. ACGAN 簡介

1.1 ACGAN 技術原理

生成對抗網絡 (Generative Adversarial Network, GAN) 的眾多變體中,ACGAN (Auxiliary Classifier GAN) 能夠從隨機噪聲中生成逼真的圖像、文本甚至音樂。然而,傳統的 GAN 有一個顯著的局限性:缺乏對生成過程的精確控制。我們無法指定要生成“數字7”的圖片還是一只“戴墨鏡的貓”。
為了解決這個問題,條件生成對抗網絡 (Conditional GAN, cGAN) 應運而生。它通過將類別標簽信息同時注入生成器 (Generator) 和判別器 (Discriminator),實現了條件生成。但這仍然不夠完美,cGAN 的判別器最終只輸出一個“真/假”的概率,它并沒有顯式地告訴生成器它生成的圖片在類別上是否正確。
ACGAN (Auxiliary Classifier GAN) 正是在 CGAN 的基礎上,對判別器的任務進行了至關重要的擴展。它不僅判斷真偽,還同時擔任一個“分類器”的角色。這個簡單的改變,極大地提升了生成圖像的質量和多樣性,尤其是在生成特定類別的圖像時。

1.2 ACGAN 核心思想

ACGAN 的核心思想非常直觀:為判別器增加一個輔助任務——對輸入圖像進行分類。其中,生成器的輸入包括隨機噪聲向量 zzz 和目標類別標簽 ccc;判別器的輸出包括:

  • 一個源 (Source) 輸出:一個標量概率,表示圖像是來自真實數據分布的概率
  • 一個輔助類別 (Class) 輸出:類別概率分布

通過引入這個輔助的分類任務,ACGAN 迫使判別器不僅要學習“什么樣的圖像是真實的”,還要學習“真實圖像屬于什么類別”。反過來,生成器為了欺騙這個更強大的判別器,也必須生成既逼真又類別分明的圖像。

1.3 損失函數

損失函數包含兩部分:

  • 源判別損失 (source loss),用來訓練真假判別,通常使用二元交叉熵
  • 類別判別損失 (auxiliary classification loss),使用多元交叉熵(真實圖像的類別為真實標簽,生成圖像的類別為生成器的條件標簽

訓練目標:

  • 判別器 D,最小化源判別損失(正確區分真實/虛假圖像)并最小化類別判別損失(正確預測類別)
  • 生成器 G:生成圖像以最大化判別器認為是“真實”的概率,并最小化判別器給出的類別預測與條件類的一致性

2. 模型訓練流程

模型訓練流程如下:

  1. 從真實數據中取一批數據 x,c{x,c}x,c
  2. 判別器更新:
    • 計算真實樣本的源判別損失與類別判別損失
    • 用噪聲和隨機標簽生成虛假樣本 x~=G(z,c)\tilde x=G(z,c)x~=G(z,c),計算虛假樣本的源判別損失與類別判別損失(可選)
    • 把這些損失加權后更新 D
  3. 生成器更新:
    • 用一批噪聲與條件標簽生成樣本 x~\tilde xx~
    • 通過 D 計算源輸出與輔助類別輸出
    • 生成器的損失是希望源輸出為“真實”,并希望輔助類別輸出為生成時的條件標簽
    • 更新 G

3. 使用 PyTorch 構建 ACGAN

接下來,使用 PyTorch 實現 ACGAN,并在 MNIST 數據集上進行訓練生成手寫數字。

3.1 數據處理

(1) 首先,導入所需庫并設置超參數:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os# 設置超參數
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10
batch_size = 64
lr = 0.0002
num_epochs = 100
sample_interval = 400# 創建輸出目錄
os.makedirs("images", exist_ok=True)
os.makedirs("models", exist_ok=True)

(2) 加載 MNIST 數據集,將圖像轉換為張量,并將像素值從 [0,1] 歸一化到 [-1,1] 范圍:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transform
)

(3) 構建數據加載器:

dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True
)

3.2 模型構建

(1) 首先,定義權重初始化函數:

def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

(2) 定義生成器。生成器接收隨機噪聲和類別標簽作為輸入,通過嵌入層將標簽轉換為與噪聲相同維度的向量,然后將二者相乘融合,之后通過全連接層和轉置卷積層逐步上采樣,最終生成 28 x 28 的圖像:

class Generator(nn.Module):def __init__(self, latent_dim, num_classes):super(Generator, self).__init__()# 將類別標簽轉換為嵌入向量self.label_emb = nn.Embedding(num_classes, latent_dim)self.init_size = 7  # 初始特征圖大小self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 1, 3, stride=1, padding=1),nn.Tanh())def forward(self, noise, labels):# 將噪聲和標簽嵌入相乘gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return img

除了將類別標簽轉換為嵌入向量進行融合外,也可以直接使用標簽的獨熱編碼與噪聲向量進行拼接。

(3) 定義判別器。判別器使用卷積層逐步提取特征,最后通過兩個全連接層分別輸出樣本真偽的概率(源判別輸出)和類別概率(類別判別輸出):

class Discriminator(nn.Module):def __init__(self, num_classes):super(Discriminator, self).__init__()# 卷積層提取特征self.features = nn.Sequential(# 輸入: 1x28x28nn.Conv2d(1, 16, 3, stride=2, padding=1),  # 16x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x7x7nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(32, 0.8),nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 64x4x4nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(64, 0.8),nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 128x2x2nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(128, 0.8),)# 計算特征圖大小: 128 * 2 * 2 = 512self.feature_size = 128 * 2 * 2# 輸出真實/虛假的概率self.adv_layer = nn.Sequential(nn.Linear(self.feature_size, 1), nn.Sigmoid())# 輸出類別概率self.aux_layer = nn.Sequential(nn.Linear(self.feature_size, num_classes), nn.Softmax(dim=1))def forward(self, img):# 提取特征features = self.features(img)features = features.view(features.size(0), -1)  # 展平# 預測真偽和類別validity = self.adv_layer(features)label = self.aux_layer(features)return validity, label

(4) 初始化生成器和判別器,并打印模型結構:

generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)# 初始化權重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 打印模型結構
print("Generator structure:")
print(generator)
print("\nDiscriminator structure:")
print(discriminator)

輸出模型結構如下所示:

模型結構

3.3 模型訓練

(1) 初始化損失函數和優化器:

# 定義損失函數
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()# 定義優化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

(2) 定義變量記錄訓練過程的損失變化:

G_losses = []
D_losses = []

(3) 實現訓練循環。訓練過程分為兩個部分,先訓練判別器,使其能正確區分真實和生成樣本,并正確分類;然后訓練生成器,使其能生成被判別器判定為真實且分類正確的樣本:

# 訓練循環
for epoch in range(num_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 準備真實/虛假標簽valid = torch.ones(batch_size, 1).to(device)fake = torch.zeros(batch_size, 1).to(device)# 真實圖像和標簽real_imgs = imgs.to(device)real_labels = labels.to(device)#  訓練判別器optimizer_D.zero_grad()# 真實樣本的損失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, real_labels)) / 2# 生成虛假樣本z = torch.randn(batch_size, latent_dim).to(device)gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)gen_imgs = generator(z, gen_labels)# 虛假樣本的損失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 總判別器損失d_loss = (d_real_loss + d_fake_loss) / 2# 計算判別器準確率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([real_labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()#  訓練生成器optimizer_G.zero_grad()# 生成器希望判別器將虛假樣本判斷為真實validity, pred_label = discriminator(gen_imgs)g_loss = (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)) / 2g_loss.backward()optimizer_G.step()# 記錄損失G_losses.append(g_loss.item())D_losses.append(d_loss.item())# 打印訓練狀態if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}, acc: {100*d_acc:.2f}%] "f"[G loss: {g_loss.item():.4f}]")# 定期保存生成的圖像樣本batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:# 保存生成器生成的圖像save_image(gen_imgs.data[:25], f"images/{batches_done}.png", nrow=5, normalize=True)

(4) 訓練完成后,保存模型權重:

torch.save(generator.state_dict(), "models/generator_final.pth")
torch.save(discriminator.state_dict(), "models/discriminator_final.pth")

(5) 繪制訓練過程中的損失曲線

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curve.png")
plt.show()

訓練過程監測

3.4 模型測試

images 文件夾中可以看到訓練過程中生成的樣本,隨著訓練進行,生成的數字越來越清晰:

模型訓練過程

使用訓練完成的模型,生成制定類別的數字:

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, i)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: {i}")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()

生成結果

生成數字 1

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, 1)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: 1")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()

生成結果

相關鏈接

PyTorch生成式人工智能實戰:從零打造創意引擎
PyTorch生成式人工智能(1)——神經網絡與模型訓練過程詳解
PyTorch生成式人工智能(2)——PyTorch基礎
PyTorch生成式人工智能(3)——使用PyTorch構建神經網絡
PyTorch生成式人工智能(4)——卷積神經網絡詳解
PyTorch生成式人工智能(5)——分類任務詳解
PyTorch生成式人工智能(6)——生成模型(Generative Model)詳解
PyTorch生成式人工智能(7)——生成對抗網絡實踐詳解
PyTorch生成式人工智能(8)——深度卷積生成對抗網絡
PyTorch生成式人工智能(9)——Pix2Pix詳解與實現
PyTorch生成式人工智能(10)——CyclelGAN詳解與實現
PyTorch生成式人工智能(11)——神經風格遷移
PyTorch生成式人工智能(12)——StyleGAN詳解與實現
PyTorch生成式人工智能(13)——WGAN詳解與實現
PyTorch生成式人工智能(14)——條件生成對抗網絡(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成對抗網絡(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自編碼器(AutoEncoder)詳解
PyTorch生成式人工智能(17)——變分自編碼器詳解與實現
PyTorch生成式人工智能(18)——循環神經網絡詳解與實現
PyTorch生成式人工智能(19)——自回歸模型詳解與實現
PyTorch生成式人工智能(20)——像素卷積神經網絡(PixelCNN)
PyTorch生成式人工智能(21)——歸一化流模型(Normalizing Flow Model)
PyTorch生成式人工智能(27)——從零開始訓練GPT模型
PyTorch生成式人工智能(28)——MuseGAN詳解與實現

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/96333.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/96333.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/96333.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python + 淘寶 API 開發:自動化采集商品數據的完整流程?

在電商數據分析、競品監控和市場調研等場景中,高效采集淘寶商品數據是關鍵環節。本文將詳細介紹如何利用 Python 結合 API,構建一套自動化的商品數據采集系統,涵蓋從 API 申請到數據存儲的完整流程,并提供可直接運行的代碼實現。?…

2025.8.21總結

工作一年多了,在這期間,確實也有不少壓力,但每當工作有壓力的時候,最后面都會解決。好像每次遇到解決不了的事情,都有同事給我兜底。這種壓力,確實會加速一個人的成長。這種狼性文化,這種環境&a…

VS2022 - C#程序簡單打包操作

文章目錄VS2022 - C#程序簡單打包操作概述筆記實驗過程新建工程讓依賴的運行時程序安裝包在安裝時運行(如果發現運行時不能每次都安裝程序,就不要做這步)關于”運行時安裝程序無法每次都安裝成功“的應對知識點嘗試打包舊工程bug修復從需求屬性中,可以原…

在JAVA中如何給Main方法傳參?

一、在IDEA中進行傳參:先創建一個類:MainTestimport java.util.Arrays;public class MainTest {public static void main(String[] args) {System.out.println(args.length);System.out.println(Arrays.toString(args));} }1.IDEA ---> 在運行的按鈕上…

ORACLE中如何批量重置序列

背景:數據庫所有序列都重置為1了,所以要將所有的序列都更新為對應的表主鍵(這里是id)的最大值1。我這里序列的規則是SEQ_表名。BEGINENHANCED_SYNC_SEQUENCES(WJ_CPP); -- 替換為你的模式名 END; / CREATE OR REPLACE PROCEDURE E…

公號文章排版教程:圖文雙排、添加圖片超鏈接、往期推薦、推文采集(2025-08-21)

文章目錄 排版的基本原則 I 圖片超鏈接 方式1: 利用公號原生編輯器 方式2:在CSDN平臺使用markdown編輯器, 利用標簽實現圖片鏈接。 II 排版小技巧 自定義頁面模版教程 使用壹伴進行文章素材的采集 美編助手的往期推薦還不錯 利用365編輯器創建圖文雙排效果 排版的基本原則 親…

計算兩幅圖像在特定交點位置的置信度評分。置信度評分反映了該位置特征匹配的可靠性,通常用于圖像處理任務(如特征匹配、立體視覺等)

這段代碼定義了一個名為compute_confidence的函數,用于計算兩幅圖像在特定交點位置的置信度評分。置信度評分反映了該位置特征匹配的可靠性,通常用于圖像處理任務(如特征匹配、立體視覺等)。以下是逐部分解析: 3. 結果…

計算機視覺第一課opencv(三)保姆級教學

簡介 計算機視覺第一課opencv(一)保姆級教學 計算機視覺第一課opencv(二)保姆級教學 今天繼續學習opencv。 一、 圖像形態學 什么是形態學:圖像形態學是一種處理圖像形狀特征的圖像處理技術,主要用于描…

24.早期目標檢測

早期目標檢測 第一步,計算機圖形學做初步大量候選框,把物體圈出來 第二步,依次將所有的候選框圖片,輸入到分類模型進行判斷 選擇性搜索 選擇搜索算法(Selective Search),是一種熟知的計算機圖像…

Java基礎知識點匯總(三)

一、面向對象的特征有哪些方面 Java中面向對象的特征主要包括以下四個核心方面:封裝(Encapsulation) 封裝是指將對象的屬性(數據)和方法(操作)捆綁在一起,隱藏對象的內部實現細節&am…

GEO優化專家孟慶濤:讓AI“聰明”選擇,為企業“精準”生長

在生成式AI席卷全球的今天,企業最常遇到的困惑或許是:“為什么我的AI生成內容總像‘模板套娃’?”“用戶明明想要A,AI卻拼命輸出B?”當生成式AI從“能用”邁向“好用”的關鍵階段,如何讓AI真正理解用戶需求…

【交易系統系列04】交易所版《速度與激情》:如何為狂飆的BTC交易引擎上演“空中加油”?

交易所版《速度與激情》:如何為狂飆的BTC交易引擎上演“空中加油”? 想象一下這個場景:你正端著一杯熱氣騰騰的咖啡,看著窗外我家那只貪睡的橘貓趴在陽光下打著呼嚕。突然,手機上的警報開始尖叫,交易系統監…

windows下jdk環境切換為jdk17后,臨時需要jdk1.8的處理

近段時間,終于決定把開發環境全面轉向jdk17,這不就遇到了問題。 windows主環境已經設置為jdk17了。 修改的JAVA_HOME D:\java\jdk-17CLASSPATH設置 .;D:\java\jdk-17\lib\dt.jar;D:\java\jdk-17\lib\tools.jar;PATH中增加 D:\java\jdk-17\bin但是有些程序…

Android URC 介紹及源碼案例參考

1. URC 含義 URC 是 Unsolicited Result Code(非請求結果碼)的縮寫。 它是 modem(基帶)在不需要 AP 主動請求的情況下向上層主動上報的消息。 典型例子:短信到達提示、網絡狀態變更、來電通知、信號質量變化等。 URC 一般以 AT 命令擴展的形式從 modem 發到 AP,例如串口…

VB.NET發送郵件給OUTLOOK.COM的用戶,用OUTLOOK.COM郵箱賬號登錄給別人發郵件

在VB.NET中通過代碼發送郵件時,確實會遇到郵箱服務的身份認證(Authentication)要求。特別是微軟Outlook/Hotmail等服務,已經逐步禁用傳統的“基本身份驗證”(Basic Authentication),轉而強制要求…

【網絡運維】Shell:變量進階知識

Shell 變量進階知識 Shell 中的特殊變量 位置參數變量 Shell 腳本中常用的位置參數變量如下: $0:獲取當前執行的 Shell 腳本文件名(包含路徑時包括路徑)$n:獲取第 n 個參數值(n>9 時需使用 ${n}&#xf…

部署Qwen2.5-VL-7B-Instruct-GPTQ-Int3

模型下載 from modelscope import snapshot_download model_dir snapshot_download(ChineseAlpacaGroup/Qwen2.5-VL-7B-Instruct-GPTQ-Int3)相關包導入 import os import numpy as np import pandas as pd from tqdm import tqdm from datetime import datetime,timedelta fro…

sourcetree 拉取代碼

提示:文章旨在于教授大家 sourcetree 拉取代碼的方式,關于代碼的提交合并等操作后續會補充。 文章目錄前言一、sourcetree 安裝二、http 與 ssh 拉取代碼1.http 方式(1)生成 token(2)拼接項目的 url&#x…

epoll模型網絡編程知識要領

1、程序初始化創建監聽socket調用bind函數綁定ip地址、port端口號調用listen函數監聽調用epoll_create函數創建epollfd調用epoll_ctrl函數將listenfd綁定到epollfd上,監測listenfd的讀事件在一個無限循環中,調用epoll_wait函數等待事件發生2、處理客戶端…

15-day12LLM結構變化、位置編碼和投機采樣

多頭機制transformer結構歸一化層選擇 歸一化層位置歸一化層類型激活函數Llama2結構MoE架構 混合專家模型DeepSeek MLA為何需要位置編碼目前的主流位置編碼正余弦位置編碼可學習位置編碼ROPE旋轉位置編碼推導參考: https://spaces.ac.cn/archives/8265 https://zhua…