本系列博文為深度學習/計算機視覺論文筆記,轉載請注明出處
標題:Conditional Generative Adversarial Nets
鏈接:[1411.1784] Conditional Generative Adversarial Nets (arxiv.org)
摘要
生成對抗網絡(Generative Adversarial Nets)[8] 最近被引入為訓練生成模型的一種新穎方法。
在這項工作中,我們介紹了生成對抗網絡的條件版本,通過簡單地將我們希望依賴的數據 y y y同時提供給生成器和判別器,就可以構建它。我們展示了這個模型可以生成依據類標簽條件化的MNIST數字。
我們還說明了如何使用這個模型學習一個多模態模型(multi-modal model),并提供了一個初步的圖像標記應用示例,在其中我們展示了如何使用這種方法生成并不是訓練標簽部分的描述性標簽。
1 引言
生成對抗網絡最近被引入為訓練生成模型的一種替代框架,以便繞過許多難以處理的概率計算的困難。
對抗網絡具有以下優勢:
-
從不需要馬爾可夫鏈,僅使用反向傳播來獲取梯度
-
學習過程中不需要推理,而且
-
各種因素和相互作用都可以輕松納入模型中
此外,正如[8]中所展示的,它可以產生最先進的對數似然估計和逼真的樣本。
在一個無條件的生成模型中,生成數據的模式沒有控制。
然而,通過在模型上附加額外的信息進行條件化,就可以引導數據生成過程。這種條件化可能基于類標簽,像[5]那樣基于部分數據進行修補,甚至基于不同的模態數據。
在這項工作中,我們展示了如何構建條件生成對抗網絡。至于實證結果,我們展示了兩組實驗。一組是基于類標簽的MNIST數字數據集,另一組是用于多模態學習的MIR Flickr 25,000數據集[10]。
2 相關工作
2.1 針對圖像標記的多模態學習
盡管監督神經網絡(特別是卷積網絡)[13, 17]近來取得了許多成功,但將這些模型擴展以容納極大數量的預測輸出類別仍然具有挑戰性。第二個問題是迄今為止的大部分工作都集中在學習輸入到輸出的一對一映射。然而,許多有趣的問題更自然地被認為是概率性的一對多映射。例如,在圖像標記的情況下,可能有許多不同的標簽可以適當地應用于給定的圖像,不同的(人類)注釋者可能使用不同的(但通常是同義或相關的)術語來描述同一圖像。
-
解決第一個問題的一種方法
- 是利用其他模態的附加信息:例如,使用自然語言語料庫學習標簽的向量表示,其中幾何關系在語義上有意義。
- 在這樣的空間中進行預測時,我們從事實中受益,即當預測錯誤時,我們仍然通常接近真相(例如,預測“桌子”而不是“椅子”),并且也從我們可以自然地對訓練期間未見過的標簽進行預測概括的事實中受益。
- 諸如[3]的作品已經表明,即使是從圖像特征空間到單詞表示空間的簡單線性映射也可以提高分類性能。
-
解決第二個問題的一種方法
- 是使用條件概率生成模型,輸入被視為條件變量,一對多映射被實例化為條件預測分布。
- [16]對此問題采取了類似的方法,并在MIR Flickr 25,000數據集上訓練了一種多模態深度玻爾茲曼機,就像我們在這項工作中所做的那樣。
此外,在[12]中,作者展示了如何訓練一種受監督的多模態神經語言模型,并且他們能夠為圖像生成描述性句子。
3 條件生成對抗網絡
3.1 生成對抗網絡
生成對抗網絡最近被引入作為訓練生成模型的一種新穎方法。
它們由兩個“對抗”的模型組成:一個生成模型G,用于捕獲數據分布;和一個判別模型D,用于估計樣本來自訓練數據還是G的概率。G和D都可以是非線性映射函數,例如多層感知器。
為了學習生成器分布 p g p_g pg?在數據 x x x上的分布,生成器從先驗噪聲分布 p z ( z ) p_z(z) pz?(z)構建到數據空間的映射函數 G ( z ; θ g ) G(z; \theta_g) G(z;θg?)。而判別器 D ( x ; θ d ) D(x; \theta_d) D(x;θd?)輸出一個標量,表示 x x x來自訓練數據而不是 p g p_g pg?的概率。
G和D都同時進行訓練:我們調整G的參數以使 log ? ( 1 ? D ( G ( z ) ) \log(1 - D(G(z)) log(1?D(G(z))最小化,并調整D的參數以使 log ? D ( X ) \log D(X) logD(X)最小化,就好像它們在遵循具有值函數 V ( G , D ) V(G, D) V(G,D)的兩玩家極小極大博弈:
min ? G max ? D V ( D , G ) = E x ~ p data ( x ) [ log ? D ( x ) ] + E z ~ p z ( z ) [ log ? ( 1 ? D ( G ( z ) ) ) ] 。 (1) \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log(1 - D(G(z)))]。 \tag{1} Gmin?Dmax?V(D,G)=Ex~pdata?(x)?[logD(x)]+Ez~pz?(z)?[log(1?D(G(z)))]。(1)
3.2 條件生成對抗網絡
生成對抗網絡可以擴展到條件模型,如果生成器和判別器都基于一些額外的信息 y y y進行條件化。 y y y可以是任何類型的輔助信息,例如類標簽或來自其他模態的數據。我們可以通過將 y y y作為附加的輸入層輸入到判別器和生成器中來執行條件化。
在生成器中,先驗輸入噪聲 p z ( z ) p_z(z) pz?(z)和 y y y結合在聯合隱藏表示中,而對抗訓練框架允許在組成這個隱藏表示方面具有相當大的靈活性。1
在判別器中, x x x和 y y y被呈現為輸入,并輸入到判別函數(在這種情況下再次由MLP體現)。
兩個玩家極小極大博弈的目標函數將與等式2相同
min ? G max ? D V ( D , G ) = E x ~ p data ( x ) [ log ? D ( x ∣ y ) ] + E z ~ p z ( z ) [ log ? ( 1 ? D ( G ( z ∣ y ) ) ) ] 。 (2) \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x|y)] + \mathbb{E}_{z \sim p_z(z)} [\log(1 - D(G(z|y)))]。 \tag{2} Gmin?Dmax?V(D,G)=Ex~pdata?(x)?[logD(x∣y)]+Ez~pz?(z)?[log(1?D(G(z∣y)))]。(2)
圖1說明了一個簡單條件對抗網絡的結構。
圖1:條件生成對抗網絡
4 實驗結果
4.1 單模態
我們在MNIST圖像上訓練了一個條件生成對抗網絡,并基于它們的類標簽進行條件化,以one-hot向量進行編碼。
在生成器網絡中,從單位超立方體中均勻分布抽取了一個具有100個維度的噪聲先驗 z z z。 z z z和 y y y都被映射到具有整流線性單元(ReLu)激活[4, 11]的隱藏層,層大小分別為200和1000,然后再被映射到維度為1200的第二個組合隱藏ReLu層。然后我們有一個最終的Sigmoid單元層作為生成784維MNIST樣本的輸出。
判別器將 x x x映射到具有240個單元和5個部分的maxout [6]層,并將 y y y映射到具有50個單元和5個部分的maxout層。在被送入Sigmoid層之前,兩個隱藏層都映射到一個具有240個單元和4個部分的聯合maxout層。(判別器的確切架構并不關鍵,只要它具有足夠的能力;我們發現maxout單元通常適合這項任務。)
該模型使用具有大小為100的小批量和初始學習率0.1的隨機梯度下降進行訓練,該學習率以1.00004的衰減因子呈指數遞減至0.000001。初始動量為0.5,增加到0.7。Dropout [9]概率為0.5應用于生成器和判別器。并以驗證集上的對數似然的最佳估計作為停止點。
表1顯示了用于MNIST數據集測試數據的高斯Parzen窗口對數似然估計。從每個10個類別中抽取1000個樣本,并對這些樣本擬合高斯Parzen窗口。然后我們使用Parzen窗口分布估計測試集的對數似然。(有關如何構建此估計的更多詳細信息,請參見[8]。)
表1:基于Parzen窗口的MNIST對數似然估計。我們遵循了與[8]相同的程序來計算這些值。
我們展示的條件生成對抗網絡結果與其他一些基于網絡的結果相當,但被其他幾種方法所超過,包括非條件生成對抗網絡。我們更多地將這些結果作為概念驗證而非有效性展示,并相信通過進一步探索超參數空間和架構,條件模型應當匹配或超過非條件結果。
圖2顯示了一些生成的樣本。每一行都以一個標簽為條件,每一列都是一個不同的生成樣本。
圖2:生成的MNIST數字,每一行都是基于一個標簽
4.2 多模態
像Flickr這樣的照片網站是圖像及其關聯的用戶生成元數據(UGM)形式的豐富標簽數據源,特別是用戶標簽。用戶生成的元數據與更“規范”的圖像標簽方案不同,因為它們通常更具描述性,并且在語義上更接近人們如何用自然語言描述圖像,而不僅僅是識別圖像中存在的對象。UGM的另一個方面是同義詞普遍存在,不同的用戶可能會使用不同的詞匯來描述相同的概念,因此,有效地標準化這些標簽變得重要。概念詞嵌入[14]在這里可能非常有用,因為相關概念最終會被表示為相似的向量。
在本節中,我們演示了使用條件對抗網絡生成圖像的自動標簽(可能是多模態的)標簽向量分布的多標簽預測。
對于圖像特征,我們使用與[13]類似的卷積模型預訓練了具有21,000個標簽[15]的完整ImageNet數據集。我們使用最后一個完全連接的層的輸出,該層具有4096個單元作為圖像表示。
對于世界表示,我們首先從YFCC100M2數據集元數據的用戶標簽、標題和描述的串聯中收集文本。在文本的預處理和清理之后,我們用單詞向量大小為200訓練了一個跳過的gram模型[14]。我們省略了在詞匯表中出現少于200次的任何單詞,從而最終得到大小為247465的字典。
在對抗網絡的訓練期間,我們保持卷積模型和語言模型固定。并將通過這些模型的反向傳播留作未來工作。
對于我們的實驗,我們使用MIR Flickr 25,000數據集[10],并使用我們上述描述的卷積模型和語言模型提取圖像和標簽特征。未加任何標簽的圖像被省略,注釋被視為額外標簽。前150,000個示例用作訓練集。具有多個標簽的圖像在訓練集內重復,每個關聯標簽重復一次。
對于評估,我們為每個圖像生成100個樣本,并使用詞匯表中單詞的向量表示與每個樣本的余弦相似度找到最接近的前20個單詞。然后我們選擇所有100個樣本中最常見的前10個單詞。表4.2顯示了用戶分配的標簽和注釋以及生成的標簽的一些樣本。
最佳工作模型的生成器接收大小為100的高斯噪聲作為噪聲先驗,并將其映射到500維ReLu層。并將4096維圖像特征向量映射到2000維ReLu隱藏層。這兩層都映射到200維線性層,該層將輸出生成的單詞向量。
鑒別器由單詞向量和圖像特征分別為500和1200維的ReLu隱藏層組成,并且具有1000個單位和3個部分的最大層作為連接層,最終輸入到一個單一的S形單元。
該模型使用隨機梯度下降進行訓練,批量大小為100,并且初始學習速率為0.1,這個速率呈指數下降至.000001,衰減因子為1.00004。還使用了初始值為.5的動量,該動量增加到0.7。在生成器和鑒別器上均應用了概率為0.5的丟棄。
通過交叉驗證和隨機網格搜索與手動選擇的混合(盡管在有限的搜索空間內)獲得了超參數和架構選擇。
5 未來工作
本文所示的結果非常初步,但它們展示了條件對抗網絡的潛力,并對有趣和有用的應用展示了希望。
在現在和工作坊之間的未來探索中,我們期望呈現更復雜的模型,以及對它們的性能和特性進行更詳細和徹底的分析。
表格2:生成標簽樣本
此外,在當前的實驗中,我們只單獨使用每個標簽。但是,通過同時使用多個標簽(有效地將生成問題提出為“集合生成”問題),我們希望能夠取得更好的結果。
未來工作的另一個明顯方向是構建聯合訓練方案以學習語言模型。例如[12]的工作表明,我們可以為特定任務學習適合的語言模型。
致謝
本項目是在Pylearn2 [7] 框架中開發的,我們想要感謝Pylearn2的開發者們。我們還要感謝Ian Goodfellow在蒙特利爾大學任職期間的有益討論。作者衷心感謝Flickr的視覺與機器學習團隊以及生產工程團隊的支持(按字母順序:Andrew Stadlen, Arel Cordero, Clayton Mellina, Cyprien Noel, Frank Liu, Gerry Pesavento, Huy Nguyen, Jack Culpepper, John Ko, Pierre Garrigues, Rob Hess, Stacey Svetlichnaya, Tobi Baumgartner, 和 Ye Lu)。
參考文獻
- Bengio, Y., Mesnil, G., Dauphin, Y.和Rifai, S.(2013)。通過深度表示實現更好的混合。在ICML’2013上。
- Bengio, Y., Thibodeau-Laufer, E., Alain, G.和Yosinski, J.(2014)。可以通過反向傳播進行訓練的深度生成隨機網絡。在第30屆國際機器學習大會(ICML’14)論文集中。
- Frome, A., Corrado, G. S., Shlens, J., Bengio, S., Dean, J., Mikolov, T.等(2013)。Devise:一種深度視覺語義嵌入模型。在神經信息處理系統的進展中,頁碼:2121–2129。
- Glorot, X., Bordes, A.和Bengio, Y.(2011)。深度稀疏整流器神經網絡。在人工智能與統計學國際會議上,頁碼:315–323。
- Goodfellow, I.,Mirza, M.,Courville, A.和Bengio, Y.(2013a)。多預測深度Boltzmann機。在神經信息處理系統的進展中,頁碼:548–556。
- Goodfellow, I. J.,Warde-Farley, D.,Mirza, M.,Courville, A.和Bengio, Y.(2013b)。最大輸出網絡。在ICML’2013上。
- Goodfellow, I. J., Warde-Farley, D., Lamblin, P., Dumoulin, V., Mirza, M., Pascanu, R., Bergstra, J., Bastien, F.和Bengio, Y.(2013c)。Pylearn2:一個機器學習研究庫。arXiv預印本arXiv:1308.4214。
- Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A.和Bengio, Y.(2014)。生成對抗網絡。在NIPS’2014上。
- Hinton, G. E.,Srivastava, N.,Krizhevsky, A.,Sutskever, I.和Salakhutdinov, R.(2012)。通過防止特征檢測器的共適應來改善神經網絡。技術報告,編號:arXiv:1207.0580。
- Huiskes, M. J.和Lew, M. S.(2008)。mir flickr檢索評估。在MIR’08:2008年ACM國際多媒體信息檢索大會上,紐約,美國。ACM。
- Jarrett, K.,Kavukcuoglu, K.,Ranzato, M.和LeCun, Y.(2009)。用于對象識別的最佳多級架構是什么?在ICCV’09上。
- Kiros, R.,Zemel, R.和Salakhutdinov, R.(2013)。多模態神經語言模型。在NIPS深度學習研討會的論文集中。
- Krizhevsky, A.,Sutskever, I.和Hinton, G.(2012)。使用深度卷積神經網絡的ImageNet分類。在神經信息處理系統25的進展(NIPS’2012)中。
- Mikolov, T.,Chen, K.,Corrado, G.和Dean, J.(2013)。在向量空間中有效估計單詞表示。在學習表示國際會議:研討會跟蹤上。
- Russakovsky, O.和Fei-Fei, L.(2010)。大規模數據集中的屬性學習。在歐洲計算機視覺大會(ECCV),希臘克里特島的部分和屬性國際研討會上。
- Srivastava, N.和Salakhutdinov, R.(2012)。用深度Boltzmann機進行多模態學習。在NIPS’2012上。
- Szegedy, C.,Liu, W.,Jia, Y.,Sermanet, P.,Reed, S.,Anguelov, D.,Erhan, D.,Vanhoucke, V.和Rabiovich, A.(2014)。用卷積深入探究。arXiv預印本arXiv:1409.4842。
References
- Bengio, Y., Mesnil, G., Dauphin, Y., and Rifai, S. (2013). Better mixing via deep representations. In ICML’2013.
- Bengio, Y., Thibodeau-Laufer, E., Alain, G., and Yosinski, J. (2014). Deep generative stochastic networks trainable by backprop. In Proceedings of the 30th International Conference on Machine Learning (ICML’14).
- Frome, A., Corrado, G. S., Shlens, J., Bengio, S., Dean, J., Mikolov, T., et al. (2013). Devise: A deep visual-semantic embedding model. In Advances in Neural Information Processing Systems, pages 2121–2129.
- Glorot, X., Bordes, A., and Bengio, Y. (2011). Deep sparse rectifier neural networks. In International Conference on Artificial Intelligence and Statistics, pages 315–323.
- Goodfellow, I., Mirza, M., Courville, A., and Bengio, Y. (2013a). Multi-prediction deep Boltzmann machines. In Advances in Neural Information Processing Systems, pages 548–556.
- Goodfellow, I. J., Warde-Farley, D., Mirza, M., Courville, A., and Bengio, Y. (2013b). Maxout networks. In ICML’2013.
- Goodfellow, I. J., Warde-Farley, D., Lamblin, P., Dumoulin, V., Mirza, M., Pascanu, R., Bergstra, J., Bastien, F., and Bengio, Y. (2013c). Pylearn2: a machine learning research library. arXiv preprint arXiv:1308.4214.
- Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative adversarial nets. In NIPS’2014.
- Hinton, G. E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. (2012). Improving neural networks by preventing co-adaptation of feature detectors. Technical report, arXiv:1207.0580.
- Huiskes, M. J. and Lew, M. S. (2008). The mir flickr retrieval evaluation. In MIR ’08: Proceedings of the 2008 ACM International Conference on Multimedia Information Retrieval, New York, NY, USA. ACM.
- Jarrett, K., Kavukcuoglu, K., Ranzato, M., and LeCun, Y. (2009). What is the best multi-stage architecture for object recognition? In ICCV’09.
- Kiros, R., Zemel, R., and Salakhutdinov, R. (2013). Multimodal neural language models. In Proc. NIPS Deep Learning Workshop.
- Krizhevsky, A., Sutskever, I., and Hinton, G. (2012). ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25 (NIPS’2012).
- Mikolov, T., Chen, K., Corrado, G., and Dean, J. (2013). Efficient estimation of word representations in vector space. In International Conference on Learning Representations: Workshops Track.
- Russakovsky, O. and Fei-Fei, L. (2010). Attribute learning in large-scale datasets. In European Conference of Computer Vision (ECCV), International Workshop on Parts and Attributes, Crete, Greece.
- Srivastava, N. and Salakhutdinov, R. (2012). Multimodal learning with deep Boltzmann machines. In NIPS’2012.
Vision (ECCV), International Workshop on Parts and Attributes, Crete, Greece. - Srivastava, N. and Salakhutdinov, R. (2012). Multimodal learning with deep Boltzmann machines. In NIPS’2012.
- Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., and Rabiovich, A. (2014). Going deeper with convolutions. arXiv preprint arXiv:1409.4842.
目前,我們簡單地將條件輸入和先驗噪聲作為MLP的單個隱藏層的輸入,但人們可以想象使用更高階的交互作用,允許復雜的生成機制,這在傳統的生成框架中將非常難以處理。 ??
Yahoo Flickr Creative Common 100M 數據集:http://webscope.sandbox.yahoo.com/catalog.php?datatype=i&did=67。 ??