前言
生成對抗網絡(GAN)是lan J. Goodfellow團隊在2014年提出的生成架構, 該架構自誕生起,就產生了很多的話題,更是被稱為生成對抗網絡是“新世紀以來機器學習領域內最有趣的想法”。如今,基于生成對抗網絡思想的架構在圖像處理、生成方面的能力越來越強大,已經成為視覺領域中不可忽視的存在,可以這樣說,GAN的誕生,讓這個世界的物理表象變得具有欺騙性了。
一、GAN的原理
最初GAN的目標就是用來生成圖像,首先GAN的目標就是利用生成器 G 根據真實數據生成的假數據,并希望生成的假數據能夠以假亂真,從而騙過判別器 D;同時,又希望判別器 D 能夠區分真假數據。如下所示:
整個網絡由生成器 G 和判別器 D 構成,隨機初始化噪聲數據,然后輸入生成器生成假數據,判別器判斷生成的數據和真實數據哪個才是真的。生成器沒有標簽,是無監督網絡,判別器是監督網絡,標簽是“真或假”(0/1)。原始論文規定判別器輸出當前數據為真的概率(標簽為1的概率),當概率大于0.5,判別器認為樣本是真實數據,小于0.5,判別器認為樣本是由生成器生成的假數據。
其核心思想就是:
- 通過兩個神經網絡(生成器 G 與判別器 D)之間的對抗博弈,讓生成器學會產生以假亂真的數據
二、GAN的損失函數
其實GAN的整過訓練過程就是一個零和博弈。在訓練過程中,生成器和判別器的目標是相互競爭的:生成器的任務是盡可能生成以假亂真的數據,讓判別器判斷不出來,其目的就是讓判別器的準確性降低;箱單,判別器的目的是盡量判斷出真偽,讓自己判斷的準確性越來越高。
當生成器生成的數據越來越真時,判別器為保證自己的準確性,就會朝著判斷能力強的方向迭代。當判斷器判斷能力越來越強大時,生成器為了保證自己生成的真實性,就會朝著生成能力強的方向迭代。在整個的關系中,判別器的準確性由論文中定義的交叉熵 來衡量,判別器和生成器共同影響著
。
1. 交叉熵 V
在生成器和判別器的特殊關系中,GAN的目標損失 ?為:
由于期望代表的是數據均值,因此,式子可以改為:
?其中,?為真實數據,
?為與真實數據結構相似的隨機噪音,
?為生成器生成的假數據,
?為判別器在真實數據?
?上的判別的結果,
?為判別器在假數據(
)上判別的結果,其中?
?與?
?都是樣本為真(標簽為1)的概率。
由于?
?與?
?都是概率,所以值在 (0,1] 之間,因此,取對數后的值域為
,所以損失?
的值域也在
。并且
在判別器的能力最好時達到最大值,說明判別器越準確,
反而越大,這顯然與普通的二分類損失相反。但是,如果分別從判別器和生成器的角度看,又是合理的。
2. 對于判別器損失
在 V 的表達式中,對數都與判別器有關,因此,對于判別器來說, 即判別器的損失:
我解釋一下,判別器的目的是盡量使自己作出正確的判斷,且判別器輸出的是標簽為真的概率,因此,判別器的最佳表現是:對于所有在真實數據上的判別器的輸出??都接近1,所有假數據上判別器?
?都接近0,因此,對于判別器的最佳損失就是:
- 因此,判別器希望?
?越大越好,即
,判別能力越強,值越大,理想情況下值最大為0。
3. 對于生成器損失
在 的表達式中,生成器只會影響?
,因此只有
的后半部分的表達式與生成器有關,即:
去掉常數項:
生成器的目標是盡可能使生成的假數據讓判別器判斷為真,即??越接近1越好,因此,對于生成器的最佳損失為:
- 可以看出,生成器希望?
?越小越好,即
無限接近負無窮,因此生成器的本質就是追求
?無限接近 1。對生成器而言,
更像是一個損失,即模型表現越好,該指標的值越低。
- 從整個GAN的角度看,我們的目標就是與生成器的目標一致,因此,對于我們而言,
就被當做損失,并且越低越好。
4. 求最優解
上面我已經推導了GAN損失函數??的由來,那么如何求最優解呢?下面我們來推導一下:
第一步:固定生成器 G,求最優判別器 D:
當我們固定 G 時,要求最優判別器 D,此時 G 是確定的,因此??生成了一個偽樣本?
,這意味這我們可以把:
,
?服從分布
,即
因此,第二項期望就可以寫作:
所以,原式就變為了:
我們將期望轉化為積分:
將兩個分布??和?
?放在同一個積分中:
這是一個關于??的泛化優化問題,對每個點?
,我們要最大化,令:
對??求導數,令導數為 0,得到最優?
:
解得:
第二步:將??代入原損失函數,優化生成器的目標??
:
我們對其化簡:
這里,我們介紹一個概念,Jensen-Shannon散度,是衡量兩個分布差異的對稱度量,它與KL散度有如下關系:
因此,最終式子變為:
若要使得??取得最小,那么KL散度應當為0,當KL散度為0時,就相當于:
由此可知,??逼近
?時,目標函數取得最優值,并且當
判別器就無法判斷出樣本是來自假數據樣本?,還是來自真實數據樣本
?了,此時生成器的生成效果便達到了最好。
三、GAN的訓練流程
接下來,我將使用MINST數據集來實現一下GAN的訓練流程:
訓練過程:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
import os# 定義生成器與判別器
class Generator(nn.Module):def __init__(self, latent_dim):super().__init__()self.net = nn.Sequential(nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0), # 1x1 -> 4x4nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, 4, 2, 1), # 4x4 -> 8x8nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 64, 4, 2, 1), # 8x8 -> 16x16nn.BatchNorm2d(64),nn.ReLU(True),nn.ConvTranspose2d(64, 1, 4, 4, 0), # 16x16 -> 64x64nn.Tanh())def forward(self, z):return self.net(z)class Discriminator(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(1, 64, 4, 4, 0), # 64x64 -> 16x16nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, 4, 2, 1), # 16x16 -> 8x8nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, 4, 2, 1), # 8x8 -> 4x4nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 1, 4, 1, 0), # 4x4 -> 1x1nn.Sigmoid())def forward(self, x):return self.net(x).view(-1, 1).squeeze(1)
# 設置超參數
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
batch_size = 128
epochs = 50
image_size = 64
lr = 2e-4
log_dir = "./GAN/runs/log"
sample_dir = "./GAN/samples"
os.makedirs(sample_dir, exist_ok=True)# 加載數據
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataset = datasets.MNIST(root='./GAN', train=True, transform=transform, download=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)G = Generator(latent_dim).to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))writer = SummaryWriter(log_dir)
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)
best_d_loss = float('inf')
# 訓練
for epoch in range(epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)valid = torch.ones(real_imgs.size(0), device=device)fake = torch.zeros(real_imgs.size(0), device=device)# -------------------# 訓練生成器 G# -------------------optimizer_G.zero_grad()z = torch.randn(real_imgs.size(0), latent_dim, 1, 1, device=device)gen_imgs = G(z)g_loss = criterion(D(gen_imgs), valid)g_loss.backward()optimizer_G.step()# -----------------------# 訓練判別器 D# -----------------------optimizer_D.zero_grad()real_loss = criterion(D(real_imgs), valid)fake_loss = criterion(D(gen_imgs.detach()), fake)d_loss = real_loss + fake_lossd_loss.backward()optimizer_D.step()# 日志記錄batches_done = epoch * len(dataloader) + iwriter.add_scalar("Loss/Generator", g_loss.item(), batches_done)writer.add_scalar("Loss/Discriminator", d_loss.item(), batches_done)print(f"Epoch {epoch+1}/{epochs} | G_loss: {g_loss.item():.4f} | D_loss: {d_loss.item():.4f}")# 保存最優模型if d_loss.item() < best_d_loss:best_d_loss = d_loss.item()torch.save(G.state_dict(), "best_generator.pth")torch.save(D.state_dict(), "best_discriminator.pth")
生成:
# 生成新圖像
加載保存訓練好的生成器
G.load_state_dict(torch.load("best_generator.pth", map_location=device))
G.eval()
noise = torch.randn(64, latent_dim, 1, 1, device=device)
with torch.no_grad():fake_imgs = G(noise).detach().cpu()
os.makedirs("final_samples", exist_ok=True)
grid = make_grid(fake_imgs, nrow=8, normalize=True)
save_image(grid, "final_samples/generated_grid.png")
在這里,我簡單的訓練了一個GAN網絡,并生成64張數字圖片,生成效果如下:
可以看到,簡單訓練的GAN能夠生成不錯的效果。
總結
?以上就是本文對GAN原理的全部介紹,相信小伙伴們在看完之后對GAN的原理會有更深刻的理解。總的來說,GAN為我們帶來了新的視角,它讓我們不再試圖去擬合復雜的數據分布,而是建立一個“博弈系統”,通過競爭機制驅動模型學習。從14年到現在,基于GAN架構的生成模型已經發展了很多,比如StyleGAN、CycleGAN等模型,但是它們的核心仍然是GAN的架構。
如果小伙伴們覺得本文對各位有幫助,歡迎:👍點贊 |?? 收藏 | ?🔔 關注。我將持續在專欄《人工智能》中更新人工智能知識,幫助各位小伙伴們打好扎實的理論與操作基礎,歡迎🔔訂閱本專欄,向AI工程師進階!