論文地址:https://arxiv.org/abs/2405.14867
項目官網:https://tianweiy.github.io/dmd2/
代碼地址:https://github.com/tianweiy/DMD2
發表時間:2024年5月24日
分布匹配蒸餾(DMD)生成的一步生成器能夠與教師模型在分布上保持一致,即蒸餾過程不會強制要求其采樣軌跡與教師模型形成一一對應關系。然而,為確保實際訓練的穩定性,DMD需要通過大量噪聲-圖像對計算額外的回歸損失。這些噪聲-圖像對由教師模型通過多步驟確定性采樣器生成。這不僅在大規模文本到圖像合成中計算成本高昂,還限制了學生模型的質量,使其與教師模型的原始采樣路徑過于緊密綁定。
- 首先,我們消除了回歸損失和構建昂貴數據集的需求。研究表明,由此產生的不穩定性源于“偽”評價器未能準確估計生成樣本的分布特征,為此我們提出雙時間尺度更新規則作為解決方案。
- 其次,我們將GAN損失整合到蒸餾過程中,用于區分生成樣本與真實圖像。這使得學生模型能在真實數據上進行訓練,從而緩解教師模型“真實”分數估計的不準確性,進而提升生成質量。
- 第三,提出了一種創新的訓練方法,通過在訓練過程中模擬推理階段生成器樣本,實現了學生模型的多步采樣,并有效解決了先前研究中存在的訓練與推理輸入不匹配問題。
DMD2:在ImageNet-64×64數據集上FID分數達到1.28,在零樣本COCO 2014數據集上FID分數為8.35。推理成本降低了500%×,超越了原始教師模型。
此外,通過提煉SDXL方法展示了該方案能生成百萬像素級圖像,其視覺質量在少步長方法中表現卓越,甚至超越了原始教師模型。
1 Introduction
擴散模型在效果上非常好,但是推理成本偏高。現有的少步數推理方法往往導致質量下降(學生模型通過學習教師模型的成對噪聲與圖像映射關系,卻難以完美復現其行為特征)。
DMD方法,其核心目標在于與教師模型在分布層面上達成一致——通過最小化學生模型與教師模型輸出分布之間的Jensen-Shannon(JS)散度或近似Kullback-Leibler(KL)散度,而非需要精確學習從噪聲到圖像的具體路徑。盡管DMD已取得業界領先成果,但相較于基于生成對抗網絡(GAN)的方法[23-29],其研究熱度仍顯不足。究其原因,DMD仍需額外引入回歸損失來確保訓練穩定性。這要求教師模型的采樣生成數百萬組噪聲-圖像配對,這對文本到圖像合成而言成本尤為高昂。此外,回歸損失還削弱了DMD非配對分布匹配目標的核心優勢——由于這種機制的存在,學生模型的質量上限會被教師模型所制約。
本文提出了一種在保持訓練穩定性的同時消除DMD回歸損失的方法。通過將GAN框架整合到DMD中,突破了分布匹配的極限,并開發出名為“逆向模擬”的創新訓練流程實現少步長采樣。綜合來看,我們的研究成果構建了最先進的快速生成模型,僅需四步采樣即可超越原始模型。
DMD2在單步圖像生成領域取得突破性進展:在ImageNet-64×64數據集上FID值達1.28,在零樣本COCO 2014數據集上達到8.35,創下新標桿。我們還通過從SDXL蒸餾生成高質量百萬像素圖像,驗證了該方法的可擴展性,為少步長方法樹立了新標準。
簡而言之,我們的主要貢獻包括:
- DMD2,無需依賴回歸損失即可實現穩定訓練,從而省去昂貴的數據收集環節,使訓練過程更加靈活且可擴展。
- 通過實驗證明,DMD框架[22]中不使用回歸損失導致的訓練不穩定源于偽擴散判別器訓練不足,并提出雙時間尺度更新規則來解決該問題。
- 將生成對抗網絡(GAN)目標整合到DMD框架中,通過訓練判別器區分學生生成器與真實圖像樣本。這種在分布層面施加的額外監督機制,比原始回歸損失更符合DMD的分布匹配理念,有效緩解了教師擴散模型的近似誤差并提升了圖像質量。
- 在原有僅支持單步生成器的DMD基礎上,我們創新性地引入多步生成器支持技術。與以往的多步蒸餾方法不同,通過在訓練過程中模擬推理時的生成器輸入,避免了訓練與推理之間的領域不匹配問題,從而提升了整體性能。
2 Related Work
Diffusion Distillation. 近年來,擴散加速技術主要聚焦于通過蒸餾法提升生成過程的效率[9,10,13-20,22,23,30]。這類方法通常訓練生成器以更少的采樣步驟逼近教師模型的常微分方程(ODE)采樣軌跡。值得注意的是,Luhman等人[16]預先計算了由教師模型使用ODE采樣器生成的噪聲與圖像配對數據集,并利用該數據集訓練學生模型在單次網絡評估中進行映射回歸。后續研究如漸進式蒸餾[10,13]則無需離線預計算這種配對數據集,而是通過迭代訓練一系列學生模型,每個模型的采樣步驟數量都比前序模型減半。互補技術Instaflow [11]通過拉直ODE軌跡,使得單步學生模型更容易逼近。一致性蒸餾[9,12,19,26,31,32]和TRACT [33]則訓練學生模型使其輸出在ODE軌跡的任意時間步都保持自洽性,從而與教師模型保持一致。
GANs 另一項研究采用對抗訓練方法,使生成器與判別器在更廣泛的分布層面上達成對齊。在ADD模型[23]中,生成器初始權重來自擴散模型,通過附加分類器[34]GAN目標函數進行訓練。在此基礎上,LADD模型[24]采用預訓練擴散模型作為判別器,并在潛在空間中運行,從而提升可擴展性并實現更高分辨率的合成。受DiffusionGAN [28,29]啟發,UFOGen模型[25]在判別器的真實與偽造分類前引入噪聲注入機制,通過平滑分布來穩定訓練動態。近期部分研究將對抗目標與蒸餾損失相結合,以保持原始采樣軌跡。例如,SDXL-Lightning模型[27]將DiffusionGAN損失[25]與漸進式蒸餾目標[10,13]整合;而一致性軌跡模型[26]則將生成對抗網絡[35]與改進的一致性蒸餾[9]相結合。
Score Distillation 該方法最初應用于文本到三維合成領域[36-39],通過預訓練的文本到圖像擴散模型作為分布匹配損失函數。這些方法利用預訓練擴散模型預測的分數,將渲染視圖與文本條件下的圖像分布進行對齊,從而優化三維物體。近期研究將分數蒸餾技術[36,37,40-42]拓展為擴散蒸餾[22,43-45]。值得注意的是,DMD [22]通過最小化近似KL散度實現優化,其梯度由兩個分數函數的差異構成:一個是固定且預訓練的,用于目標分布;另一個則是動態訓練的,用于生成器輸出分布。
3 Background: Diffusion and Distribution Matching Distillation
擴散模型通過迭代去噪生成圖像:在正向擴散過程中,噪聲會逐步疊加到樣本x~prealx~p_{real}x~preal?上,使其從數據分布中逐漸轉化為純高斯噪聲,整個過程分為預定的T個步驟。
因此,在每個時間步t,擴散后的樣本遵循分布
,其中
,αt和σt是根據噪聲調度確定的標量[46,47]。擴散模型通過學習逆向推導去噪過程,根據當前噪聲樣本xt和時間步t預測去噪估計值μ(xt,t),最終從數據分布prealp_{real}preal?生成圖像。訓練完成后,該去噪估計值與擴散分布的數據似然函數梯度(即評分函數[47])相關聯:
對圖像進行采樣通常需要幾十到幾百個去噪步驟。
Distribution Matching Distillation (DMD) 通過最小化擴散目標分布prealp_{real}preal?,t與生成器輸出分布pfakep_{fake}pfake?,t之間近似Kullback-Liebler(KL)散度在時間t上的期望值,該方法將多步驟擴散模型簡化為單步生成器G [22]。由于DMD通過梯度下降訓練生成器,僅需計算該損失函數的梯度,而該梯度可通過兩個評分函數的差值來實現:
其中z~N(0,I)是隨機高斯噪聲輸入,θ為生成器參數,F表示前向擴散過程(即噪聲注入),其噪聲水平對應時間步t,sreals_{real}sreal?和sfakes_{fake}sfake?則是基于各自分布訓練的擴散模型μrealμ_{real}μreal?和μfakeμ_{fake}μfake?所近似得到的分數(公式(1))。DMD采用凍結的預訓練擴散模型作為μrealμ_{real}μreal?(教師模型),在訓練生成器G時動態更新μfakeμ_{fake}μfake?,通過使用去噪分數匹配損失函數對一步生成器的樣本(即假數據)進行優化[22,46]。
YIN等人[22]發現,為了對分布匹配梯度(公式(2))進行正則化并獲得高質量的一步模型,需要引入額外的回歸項[16]。為此,他們構建了一個噪聲-圖像配對數據集(z,y),其中圖像y是通過教師擴散模型生成的,并采用確定性采樣器[48,49,52]從噪聲圖z開始生成。當輸入相同的噪聲z時,回歸損失函數會將生成器輸出與教師模型的預測結果進行對比:
其中d表示距離函數,例如LPIPS [53]在其實現中采用的方案。在大規模文本到圖像合成任務或具有復雜條件約束的模型中,這會成為重大瓶頸[54-56]。以SDXL [57]為例,生成一對噪聲-圖像樣本需要約5秒時間,若要覆蓋Yin等人[22]使用的LAION 6.0數據集[58]中的1200萬條提示,累計耗時將達700個A100天。僅數據構建成本就已超過我們總訓練計算量的4倍×(詳見附錄F)。這種正則化目標與DMD匹配師生分布的目標存在矛盾,因為它會促使學習者遵循教師的采樣路徑。
4 Improved Distribution Matching Distillation
我們重新審視了DMD算法[22]中的多個設計選擇,并確定了顯著的改進。
我們的方法將復雜的擴散模型(灰色,右)提煉為單步或多步生成器(紅色,左)。訓練過程包含兩個交替步驟:1.使用隱式分布匹配目標(紅色箭頭)的梯度和GAN損失(綠色)優化生成器;2.訓練評分函數(藍色)來建模生成器產生的“假”樣本分布,并訓練GAN判別器(綠色)以區分假樣本與真實圖像。如圖所示,學生生成器可以是單步或多步模型,并包含中間步驟輸入。
4.1 Removing the regression loss: true distribution matching and easier large-scale training
DMD [22]中使用的回歸損失函數[16]雖然能確保模式覆蓋和訓練穩定性,但設計使得大規模蒸餾過程變得復雜,并且與分布匹配的核心理念相悖,從而從根本上限制了蒸餾生成器的表現水平,使其只能達到教師模型的水平。我們的首個改進方案就是移除這個損失項。
4.2 Stabilizing pure distribution matching with a Two Time-scale Update Rule
若直接從DMD中省略公式(3)所示的回歸目標函數,會導致訓練過程不穩定且質量顯著下降(見表3)。
例如我們發現生成樣本的平均亮度及其他統計指標會出現劇烈波動,始終無法收斂到穩定狀態(詳見附錄C)。我們認為這種不穩定源于偽擴散模型μfakeμ_{fake}μfake?的近似誤差——由于該模型基于生成器非平穩輸出分布進行動態優化,無法準確追蹤偽分數。
這種誤差不僅導致近似偏差,還會產生生成器梯度偏移(如文獻[30]所述)。為此我們采用受Heusel等人[59]啟發的雙時標更新規則:通過不同頻率訓練μfakeμ_{fake}μfake?和生成器G,確保μfakeμ_{fake}μfake?能精準追蹤生成器輸出分布。實驗表明,在每個生成器更新周期內進行5次偽分數更新(不包含回歸損失),既能保持良好穩定性,又能達到與ImageNet上原始DMD相當的質量水平(見表3)。
4.3 Surpassing the teacher model using a GAN loss and real data
DMD2在訓練穩定性與性能表現方面已達到與DMD [22]相當的水平,且無需構建昂貴的數據集(表3)。但蒸餾生成器與教師擴散模型之間仍存在性能差距。我們推測這種差異可能源于DMD所使用的實數評分函數μrealμ_{real}μreal?中存在近似誤差,這些誤差會傳導至生成器并導致次優結果。由于DMD的蒸餾模型從未使用真實數據進行訓練,因此無法從這些誤差中恢復。
為解決這一問題,我們在模型訓練流程中引入了額外的GAN目標函數。通過訓練判別器來區分真實圖像與生成器生成的圖像,經過真實數據訓練的GAN分類器能夠突破教師網絡的局限性,使生成器在樣本質量上超越其性能。我們將GAN分類器整合到深度彌散模型(DMD)時采用了極簡設計:在6層假擴散去噪器瓶頸層之上添加分類分支(見圖3)。
該分類分支與UNet編碼器上游特征通過最大化標準非飽和GAN目標函數進行訓練:
其中D表示判別器,F是第3節定義的前向擴散過程(即噪聲注入),其噪聲強度對應時間步t。生成器G通過最小化該目標函數實現優化。我們的設計靈感來源于先前使用擴散模型作為判別器的研究[24,25,27]。需要指出的是,這種GAN目標函數更符合分布匹配的哲學理念,因為它不需要配對數據,并且獨立于教師的采樣軌跡。
4.4 Multi-step generator
通過本次改進方案,我們在ImageNet和COCO數據集上實現了與教師擴散模型相媲美的性能表現(詳見表1和表5)。但研究發現,像SDXL [57]這類大容量模型仍難以被整合到單步生成器中——這既源于模型容量的限制,也由于從噪聲到高度多樣化且細節豐富的圖像之間存在復雜的優化路徑。這一發現促使我們對DMD算法進行擴展,使其支持多步采樣機制。
我們預先設定了一個包含N個時間步(t1,t2,…tN)的固定時間表,在訓練和推理階段保持一致。在推理過程中,每個步驟都會交替執行去噪與噪聲注入操作,遵循一致性模型[9]以提升樣本質量。具體來說,從高斯噪聲z0~N(0,I)開始,我們交替進行去噪更新x?ti=Gθ(xti,ti)和前向擴散步驟
,直至生成最終圖像x?tN。我們的四步模型采用以下時間表:教師模型經過1000步訓練后,對應的時間步數分別為999、749、499和249。
4.5 Multi-step generator simulation to avoid training/inference mismatch
以往的多步生成器通常被訓練用于去噪含噪真實圖像[23,24,27]。然而在推理過程中,除了從純噪聲開始的第一步外,生成器的輸入都來自前一步生成器的采樣步驟x?ti。這種訓練與推理的不匹配會嚴重影響質量(圖4)。我們通過用當前學生生成器運行若干步驟后產生的含噪合成圖像xtix_{ti}xti?替代訓練時的含噪真實圖像來解決這個問題,其推理流程與第4.4節所述相似。
這種方法具有可處理性,因為與教師擴散模型不同,我們的生成器僅運行少量步驟。隨后生成器對這些模擬圖像進行去噪處理,并通過提出的損失函數對輸出進行監督。使用含噪合成圖像避免了訓練與推理的不匹配問題,從而提升了整體性能。
同期研究Imagine Flash[60]提出了類似技術方案。該團隊的逆向蒸餾算法與我們的思路一致,都希望通過在訓練階段使用學生模型生成的圖像作為后續采樣步驟的輸入,來縮小訓練集與測試集之間的差距。但他們的方法未能徹底解決數據不匹配問題——由于回歸損失函數中的教師模型從未接觸過合成圖像,導致訓練-測試鴻溝持續存在。這種誤差會沿著采樣路徑不斷累積。相比之下,我們提出的分布匹配損失函數完全獨立于學生模型的輸入參數,從而有效緩解了這一缺陷。
4.6 Putting everything together
DMD2突破了DMD [22]對預計算噪聲-圖像配對的嚴苛要求。該方法進一步整合了生成對抗網絡(GAN)的優勢,并支持多步驟生成器的構建。如圖3所示,DMD2以預訓練的擴散模型為起點,交替優化生成器Gθ以最小化原始分布匹配目標和GAN目標,并μfakeμ_{fake}μfake?使用去噪分數匹配目標對假數據進行優化,同時采用GAN分類損失來優化偽分數估計器。為確保在線優化過程中偽分數估計的準確性和穩定性,我們將其更新頻率設置得比生成器更高(5步對比1步)。
5 Experiments
我們通過多個基準測試評估DMD2方法,包括在ImageNet-64×64數據集[61]上進行類別條件圖像生成,以及使用多種教師模型[1,57]在COCO 2014數據集[62]上進行文本到圖像合成。采用Fréchet Inception Distance (FID)[59]衡量圖像質量與多樣性,并用CLIP分數[63]評估文本到圖像的對齊效果。
針對SDXL模型,我們額外報告了補丁FID [27,64]指標——該指標通過299x中心裁剪補丁對圖像進行FID計算,用于評估高分辨率細節表現。最后通過人工評估將本方法與現有前沿技術進行對比。綜合評估結果表明,采用本方法訓練的蒸餾模型不僅超越了先前研究,甚至能與教師模型的性能相媲美。詳細的訓練和評估流程詳見附錄。
5.1 Class-conditional Image Generation
表1展示了我們在ImageNet-64×64數據集上對模型的性能對比。通過單次前向傳播,我們的方法不僅顯著超越了現有的蒸餾技術,甚至在使用ODE采樣器[52]時還超越了教師模型。這一卓越表現主要歸功于兩個關鍵改進:首先移除了DMD的回歸損失(第4.1和4.2節),消除了ODE采樣器帶來的性能上限限制;其次引入了額外的GAN項(第4.3節),有效緩解了教師擴散模型評分近似誤差帶來的負面影響。
5.2 Text-to-Image Synthesis
我們在零樣本COCO 2014數據集[62]上評估了DMD2的文本到圖像生成性能。生成器分別通過蒸餾SDXL [57]和SD v1.5 [1]進行訓練,使用來自LAION-Aesthetics [58]的300萬條提示子集。此外,我們從LAIONAesthetic中收集了50萬張圖像作為GAN判別器的訓練數據。表2總結了SDXL模型的蒸餾結果。
我們的四步生成器能夠產出高質量且多樣化的樣本,FID達到了19.32,CLIP 得分為0.322。在圖像質量與提示一致性方面,我們的模型與教師擴散模型形成競爭。為驗證方法的有效性,我們通過大量用戶研究將模型輸出與教師模型及現有蒸餾方法進行對比。實驗采用PartiPrompts [69]數據集中的128個提示子集,并遵循LADD [24]方法進行評估。
每次對比時,我們隨機選取五位評審員,讓他們分別選出視覺效果更佳的圖像及最符合文本提示的圖像。具體評估細則詳見附錄H。如圖5所示,我們的模型在用戶偏好度上顯著優于基線方法。值得注意的是,在24%的樣本中,我們的模型在圖像質量上超越了教師模型,同時保持了相當的提示一致性,且僅需25×次前向傳播(4次對比100次)。定性對比結果見圖6。SDv1.5的測試數據詳見附錄A表5。同樣地,使用DMD2訓練的一步法模型表現超越所有傳統擴散加速方法,FID分數達到8.35,較原始DMD方法[22]提升3.14分。我們的結果也優于采用50步PNDM采樣器[49]的教師模型。
5.3 Ablation Studies
表3展示了我們在ImageNet數據集上對所提方法不同組件的消融實驗。若直接從原始DMD方法中移除ODE回歸損失,由于訓練不穩定導致FID值下降至3.48。但通過引入我們的雙時間尺度更新規則,這一性能下滑得到有效緩解,在無需額外構建數據集的情況下達到了與DMD基線相當的水平。加入生成對抗網絡(GAN)損失項后,FID值進一步提升了1.1分。綜合方案的表現明顯優于單獨使用GAN(未結合分布匹配目標),而將雙時間尺度更新規則添加到純GAN模型中也未能帶來改善,這充分證明了在統一框架下融合分布匹配與GAN的有效性。
在表4中,我們通過消融實驗驗證了生成對抗網絡(GAN)項(第4.3節)、分布匹配目標函數(公式2)以及反向模擬(第4.4節)對SDXL模型四步生成器的影響。
如圖7所示,當移除GAN損失時,基線模型生成的圖像出現過飽和和平滑過度現象(見圖7第三列)。類似地,若剔除分布匹配目標函數(公式2),我們的方法將退化為純GAN方法,這種純GAN方法在訓練穩定性方面存在明顯缺陷[70,71]。此外,純GAN方法還缺乏整合無分類器引導機制的天然途徑[72],而該機制對于高質量文本到圖像合成至關重要[1,2]。因此,雖然基于生成對抗網絡(GAN)的方法通過精準匹配真實分布獲得了最低的FID值,但在文本對齊和美學質量方面表現明顯遜色(圖7第二列)。同樣地,如退化補丁FID分數所示,省略反向模擬會導致圖像質量下降。
6 Limitations
雖然我們的蒸餾生成器在圖像質量與文本對齊方面表現優異,但相較于教師模型,其圖像多樣性略有不足(詳見附錄B)。此外,我們的生成器仍需經過四個步驟才能達到最大SDXL模型的質量水平。這些局限性雖非本模型獨有,卻凸顯了改進方向。與多數傳統蒸餾方法類似,我們在訓練中采用固定引導尺度,限制了用戶操作的靈活性。引入可變引導尺度[13,31]或將成為未來研究的重要方向。值得注意的是,當前方法主要針對分布匹配進行優化,若能融入人類反饋或其他獎勵函數,性能將有更顯著提升[17,73]。最后需要指出的是,大規模生成模型的訓練過程計算量極大,這使得大多數研究者難以開展相關工作。