@(TOC)[CycleGAN圖像風格遷移呼喚]
模型介紹
模型簡介
CycleGAN(Cycle Generative Adversaial Network)即循環對抗生成網絡,來自論文Link:Unpaired lmage-to-mage Translation using Cycle-Consistent AdvesairalNetworks該模型實現了—種在沒有配對示例的情況下學習將圖像從源域×轉換到目標域Y的方法。
該模型一個重要應用領城是域遷移(Dom in Adaptation),可以通俗地理解為圖像風格遷移。其實在CycieGAV之前,就已經有了域遷移模型,比以D Pi2Pk,但是Pi2Fik要求訓練數據必須是成對的,而現實生活中,要找到兩個城(畫風)中成對出現的圖片是相當困難的,因此 CyclCGAN誕生了,它只需要兩種域的數據,而不需要他們有嚴格對應關系,是一種新的無監督的圖像遷移網絡。
模型結構
CycleGAN網絡本質上是由兩個鏡像對稱的GAN網絡組成,其結構如下圖所示(圖片來源于原論文)∶
為了方便理解,這里以蘋果和橘子為例介紹。上圖中 X X X可以理解為蘋果, Y Y Y為橘子; G G G為將蘋果生成橘子風格的生成器, F F F為將橘子生成的蘋果風格的生成器, D x D_x Dx?和 D x D_x Dx?為其相應判別器,具體生成器和判別器的結構可見下文代碼。模型最終能夠輸出兩個模型的權重,分別將兩種圖像的風格進行彼此遷移,生成新的圖像。
該模型一個很重要的部分就是損失函數,在所有損失里面循環一致損失(Cycle ConsistencyLoss)是最重要的。循環損失的計算過程如下圖所示(圖片來源于原論文)︰
圖中蘋果圖片 x x x經過生成器 G G G得到偽橘子 Y ? \^Y Y?,然后將偽橘子 Y ? \^Y Y?結果送進生成器 F F F又產生蘋果風格的結果 x ? \^x x?,最后將生成的蘋果風格結果 x ? \^x x?與原蘋果圖片 x x x一起計算出循環一致損失,反之亦然。循環損失捕捉了這樣的直覺,即如果我們從一個域轉換到另一個域,然后再轉換回來,我們應該到達我們開始的地方。詳細的訓練過程見下文代碼。
1 數據集
本案例使用的數據集里面的圖片來源于Link:ImageNet,該數據集共有17個數據包,本文只使用了其中的蘋果橘子部分。圖像被統─縮放為256×256像素大小,其中用于訓練的蘋果圖片996張、橘子圖片1020張,用于測試的蘋果圖片266張、橘子圖片248張。
這里對數據進行了隨機裁剪、水平隨機翻轉和歸—化的預處理,為了將重點聚焦到模型,此處將數據預處理后的結果轉換為MindRecord格式的數據,以省略大部分數據預處理的代碼。
1.1 數據集下載
使用download 接口下載數據集,并將下載后的數據集自動解壓到當前目錄下。數據下載之前需要使用pip install download安裝download 包。
1.2 數據集加載
使用MindSpore的MindDataset接讀取和解析數據集。
1.3 可視化
通過create_dict_iterator函數將數據轉換成字典迭代器,然后使用matplotlib 模塊可視化部分訓練數據。
2 構建生成器
本案例生成器的模型結構參考的ResNet模型的結構,參考原論文,對于128×128大小的輸入圖片采用6個殘差塊相連,圖片大小為256×256以上的需要采用9個殘差塊相連,所以本文網絡有9個殘差塊相連,超參數n_layers參數控制殘差塊數。
生成器的結構如下所示:
具體的模型結構請參照下文代碼:
3 構建判別器
判別器其實是一個二分類網絡模型,輸出判定該圖像為真實圖的概率。網絡模型使用的是Patch大小為70x70的PatchGANs模型。通過一系列的Conv2d 、 BatchNorm2d和LeakyReLu層對其進行處理,最后通過Sigmoid 激活函數得到最終概率。
4 優化器和損失函數
根據不同模型需要單獨的設置優化器,這是訓練過程決定的。
對生成器 G G G及其判別器 D y Dy Dy ,目標損失函數定義為:
其中 G G G試圖生成看起來與 Y Y Y中的圖像相似的圖像 G ( x ) G(x) G(x),而 D y D_y Dy?的目標是區分翻譯樣本 G ( x ) G(x) G(x)和真實樣本 y y y,生成器的目標是最小化這個損失函數以此來對抗判別器。即
單獨的對抗損失不能保證所學函數可以將單個輸入映射到期望的輸出,為了進一步減少可能的映射函數的空間,學習到的映射函數應該是周期一致的,例如對于X的每個圖像x,圖像轉換周期應能夠將x帶回原始圖像,可以稱之為正向循環—致性,即
對于 Y Y Y,類似的
可以理解采用了一個循環一致性損失來激勵這種行為。
循環一致損失函數定義如下:
5 前向計算
搭建模型前向計算損失的過程,過程如下代碼。
為了減少模型振蕩[1],這里遵循Shrivastava等人的策略[2],使用生成器生成圖像的歷史數據而不是生成器生成的最新圖像數據來更新鑒別器。這里創建image_pool 函數,保留了一個圖像緩沖區,用于存儲生成器生成前的50個圖像。
6 計算梯度和反向傳播
其中梯度計算也是分開不同的模型來進行的,詳情見如下代碼:
7 模型訓練
8 模型推理
下面我們通過加載生成器網絡模型參數文件來對原圖進行風格遷移,結果中第—行為原圖,第二行為對應生成的結果圖。
9 參考
[1] I.Goodfellow.NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016.2,4,5
[2]A.Shwivastava T.Pister,O. Tuzel, J.Susskind W.Wang, R.Webb.Learning from simulated and unsupervised images through adversarial training. In CVPR,2017.3,5,6,7