生成對抗網絡詳解與實現
- 0. 前言
- 1. GAN 原理
- 2. GAN 架構
- 3. 損失函數
- 3.1 判別器損失
- 3.2 生成器損失
- 3.4 VANILLA GAN
- 4. GAN 訓練步驟
0. 前言
生成對抗網絡 (Generative Adversarial Network
, GAN
) 是圖像和視頻生成中的主要方法之一。在本節中,我們將了解 GAN
的架構、訓練步驟等,并實現原始 GAN
。
1. GAN 原理
生成模型的目的是學習數據分布并從中進行采樣以生成新數據。PixelCNN 和變分自編碼器 (Variational Autoencoder, VAE),它們的生成部分將著眼于訓練過程中的圖像分布。因此,稱為顯式密度模型 (explicit density models
)。相比之下,GAN
中的生成部分不會直接查看圖像。因此,GAN
被歸類為隱式密度模型 (implicit density models
)。
我們可以使用一個類比來比較顯式模型和隱式模型。假設一位藝術系學生 G
獲得了畢加索的畫作收藏,并被要求學習繪制假畢加索畫作。學生可以在學習繪畫時查看收藏,因此這是一個顯式模型。在另一種情況下,我們要求學生 G
偽造畢加索的畫,但我們沒有給他們看任何畫,他們也不知道畢加索的畫是什么樣。他們學習的唯一方法是學生 D
的反饋,后者正在學習判別假畢加索的畫作。反饋很簡單——這幅畫是假的還是真實的。這就是我們的隱式密度 GAN
模型。
也許有一天,G
偶然地畫了一張扭曲的臉,并從反饋中得知它看起來像一幅真正的畢加索畫,然后他們開始以這種方式來欺騙學生 D
。學生 G
和 D
是 GAN
中的兩個網絡,稱為生成器和判別器。與其他生成模型相比,這是網絡體系結構的最大區別。我們將從了解 GAN
構建塊開始,然后介紹損失函數。然后,我們將為 GAN
創建自定義的訓練步驟。
2. GAN 架構
生成對抗網絡中的對抗一詞是指包含對立或異議。有兩個相互競爭的網絡,稱為生成器和判別器。顧名思義,生成器生成偽造的圖像。而辨別器將查看生成的圖像,以確定它們是真實的還是偽造的。每個網絡都試圖贏得這場比賽,判別器要正確識別每個真實和偽造的圖像,而生成器則要愚弄判別器以使其所產生的虛假圖像被判別器判定是真實的。下圖顯示了 GAN
的體系結構:
GAN
架構與 VAE
有一些相似之處。如果 VAE
由兩個獨立的網絡組成,我們可以想到:
GAN
的生成器作為VAE
的解碼器GAN
的判別器作為VAE
的編碼器
生成器將低維和簡單分布轉換為具有復雜分布的高維圖像,就像解碼器一樣。生成器的輸入通常是來自正態分布的樣本,也有些樣本使用均勻分布。
我們將不同批次的圖像發送給判別器。真實圖像是來自數據集的圖像,而偽造圖像則是由生成器生成的。判別器輸出輸入圖片是真還是假的單值概率。它是一個二進制分類器,可以使用 CNN
來實現它。從技術上講,判別器的作用與編碼器不同,但它們都減小了輸入的維數。
實際上,原始的 GAN
僅使用了多層感知器,該感知器由一些基本的全連接層組成。
3. 損失函數
損失函數體現了 GAN
的工作原理。公式如下:
minGmaxDV(D,G)=EX~Pdata(x)[logD(x)]+EZ~Pz(z)[log(1?D(G(z)))]min_Gmax_DV(D,G)=E_{X\sim P_data(x)}[logD(x)]+E_{Z\sim P_z(z)}[log(1-D(G(z)))] minG?maxD?V(D,G)=EX~Pd?ata(x)?[logD(x)]+EZ~Pz?(z)?[log(1?D(G(z)))]
其中:DDD 表示判別器,GGG 表示生成器,xxx 表示輸入數據,zzz 表示潛變量。
了解 GAN
的損失函數之后,代碼實現將變得更加容易。此外,有關 GAN
改進的許多討論都圍繞損失函數進行。GAN
損失函數也稱為對抗損失。接下來,我們將對其進行分解,并逐步向展示如何將其轉換為我們可以實現的簡單損失函數。
3.1 判別器損失
GAN
損失函數的等式右側的第一項是用于正確分類真實圖像的值。從等式左邊的項來看,我們知道判別器想要將其最大化。期望是一個數學術語,是隨機變量每個樣本的加權平均值之和。在此等式中,權重是數據的概率,而變量是判別器輸出的對數,如下所示:
EX[logD(x)]=∑i=1Np(x)logD(x)=1N∑i=1NlogD(x)E_X[logD(x)]=\sum_{i=1}^Np(x)logD(x)=\frac 1N\sum_{i=1}^NlogD(x) EX?[logD(x)]=i=1∑N?p(x)logD(x)=N1?i=1∑N?logD(x)
在大小為 NNN 的小批次中,p(x)p(x)p(x) 為 1N\frac 1 NN1?。這是因為 xxx 是單個圖像。不必嘗試使它最大化,我們可以將符號更改為減號并嘗試使其最小化。這可以借助以下方程來完成,該方程稱為對數損失:
minDV(D)=?1N∑i=1NlogD(x)=?1N∑i=1Nyilogp(yi)min_DV(D)=-\frac 1N\sum_{i=1}^NlogD(x)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i) minD?V(D)=?N1?i=1∑N?logD(x)=?N1?i=1∑N?yi?logp(yi?)
其中:yiy_iyi? 是標簽,對于真實圖像為 1
。p(yi)p(y_i)p(yi?) 是樣本為真的概率。
GAN
損失函數的等式右側的第二項是關于偽造圖像的。zzz 是隨機噪聲,并且 G(z)G(z)G(z) 是生成圖像。D(G(z))D(G(z))D(G(z)) 是判別器對圖像真實可能性的置信度得分。如果我們將標簽 0
用于偽造圖像,則可以使用相同的方法將其轉換為以下等式:
?EZ~Pz(z)[log(1?D(G(z))]=?1N∑i=1N(1?yi)log(1?p(yi))-E_{Z\sim P_z(z)}[log(1-D(G(z))]=-\frac 1N\sum_{i=1}^N(1-y_i)log(1-p(y_i)) ?EZ~Pz?(z)?[log(1?D(G(z))]=?N1?i=1∑N?(1?yi?)log(1?p(yi?))
現在,將所有內容放在一起,我們有了判別器損失函數,即二進制交叉熵損失:
minDV(D)=?1N∑i=1Nyilogp(yi)+(1?yi)log(1?p(yi))min_DV(D)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i)+(1-y_i)log(1-p(y_i)) minD?V(D)=?N1?i=1∑N?yi?logp(yi?)+(1?yi?)log(1?p(yi?))
使用以下代碼實現判別器損失:
def discriminator_loss(pred_fake, pred_real):real_loss = bce(tf.ones_like(pred_real), pred_real)fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)d_loss = 0.5 *(real_loss + fake_loss)return d_loss
在我們的訓練中,我們使用相同的批大小分別對真實和偽造圖像進行前向傳遞。因此,我們分別為它們計算二進制交叉熵損失,并取平均值作為損失。
3.2 生成器損失
僅當模型判別偽造圖像時才涉及生成器,因此我們只需要查看 GAN
損失函數的等式右側第二項并將其簡化為:
minGV(G)=EZ~Pz(z)[log(1?D(G(z))]min_GV(G)=E_{Z\sim P_z(z)}[log(1-D(G(z))] minG?V(G)=EZ~Pz?(z)?[log(1?D(G(z))]
在訓練開始時,生成器并不擅長生成圖像,因此判別器始終有信心將其歸類為 0
,使 D(G(z))D(G(z))D(G(z)) 始終為 0
,log(1–0)log(1 – 0)log(1–0) 也是如此。當模型輸出中的誤差始終為 0
時,則沒有反向傳播的梯度。結果,生成器的權重未更新,并且生成器未學習。由于判別器的 sigmoid
輸出幾乎沒有梯度,因此這種現象稱為梯度飽和 (saturating gradient
)。為避免此問題,將等式從最小化 1?D(G(z))1-D(G(z))1?D(G(z)) 到最大化 D(G(z))D(G(z))D(G(z)) 進行如下轉換:
maxGV(G)=EZ~Pz(z)[logD(G(z))]max_GV(G)=E_{Z\sim P_z(z)}[logD(G(z))] maxG?V(G)=EZ~Pz?(z)?[logD(G(z))]
使用此函數的 GAN
也稱為非飽和 GAN
(Non-Saturating GANs
, NS-GAN
)。實際上,Vanilla GAN
的實現都使用此損失函數而不是原始的 GAN
損失函數。
3.4 VANILLA GAN
GAN
誕生后,研究人員對 GAN
的興趣激增,提出了一系列改進模型。Vanilla GAN
是泛指基本 GAN
,Vanilla GAN
通常使用具有兩個或三個隱藏的全連接層來實現。
我們可以對判別器使用相同的數學步驟來推導生成器損失,最終將得到相同的判別器損失函數,只是將標簽 1
用于偽造圖像。為什么要對偽造圖片使用標簽 1
,我們也可以這樣理解它——因為我們想欺騙判別器以假定那些生成的圖像是真實的,因此我們使用標簽 1
:
def generator_loss(pred_fake):g_loss = bce(tf.ones_like(pred_fake), pred_fake)return g_loss
4. GAN 訓練步驟
為了在 TensorFlow
中訓練神經網絡,我們需要指定模型,損失函數,優化器,然后調用 model.fit()
,TensorFlow
將為我們完成所有工作,我們等待損失減少。
在研究 GAN
問題之前,我們首先回顧神經網絡在進行單個訓練步驟時代碼執行的情況:
- 執行前向傳播以計算損失
- 使用損失相對于權重的梯度向后傳播
- 然后,這是更新權重。優化器將縮放梯度并將其添加到權重中,從而完成一個訓練步驟
這些是深度神經網絡中的通用訓練步驟。各種優化器的不同之處僅在于它們計算縮放因子的方式。
現在回到 GAN
,查看梯度流。當我們訓練真實圖像時,只涉及判別器–網絡輸入是真實圖像,輸出是 1
的標簽。當我們使用偽造圖像并且梯度通過判別器反向傳播到生成器時,就會出現問題。讓我們將偽造圖像的生成器損失和判別器損失并排放置:
g_loss = bce(tf.ones_like(pred_fake), pred_fake)
fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)
可以發現它們之間的差異,它們的標簽是相反!這意味著,使用生成器損失來訓練整個模型將使判別器朝相反的方向移動,而不會學會進行判別。這適得其反,我們不想有一個未經訓練的判別器,這會阻止生成器學習。因此,我們必須分別訓練生成器和判別器。訓練生成器時,我們將凍結判別器權重。
有多種方法可以設計 GAN
訓練流程。一種是使用高級 Keras
模型,該模型需要較少的代碼,因此看起來更優雅。我們只需要定義一次模型,然后調用 train_on_batch()
即可執行所有步驟,包括前向計算,反向傳播和權重更新。但是,在實現更復雜的損失函數時,靈活性較差。
另一種方法是使用低級函數,以便控制每個步驟。在本節中,GAN
將使用自定義訓練步驟:
def train_step(g_input, real_input):with tf.GradientTape() as g_tape,\tf.GradientTape() as d_tape:# Forward passfake_input = G(g_input)pred_fake = D(fake_input)pred_real = D(real_input) # Calculate lossesd_loss = discriminator_loss(pred_fake, pred_real)g_loss = generator_loss(pred_fake)
tf.GradientTape()
用于記錄單次通過的梯度。另一個具有類似功能的 API
為 tf.Gradient()
,但后者在 TensorFlow Eager
執行中不起作用。我們將看到如何在 train_step()
中實現前面提到的三個過程步驟。前面的代碼段顯示了執行前向傳遞以計算損失的第一步。
第二步是使用 tape
梯度從它們各自的損失計算生成器和判別器的梯度:
gradient_g = g_tape.gradient(g_loss, G.trainable_variables)gradient_d = d_tape.gradient(d_loss, D.trainable_variables)
第三步也是最后一步是使用優化器將梯度應用于模型權重:
G_optimizer.apply_gradients(zip(gradient_g, self.G.trainable_variables))D_optimizer.apply_gradients(zip(gradient_d, self.D.trainable_variables))