introduce
什么樣的 latent 空間更適合用于擴散模型?作者發現:相比傳統的 VAE,結構良好、判別性強的 latent 空間才是 diffusion 成功的關鍵。
研究動機:什么才是“好的 latent 表征”?
背景:
- Diffusion Models最初在像素空間操作,但效率低;
- 后續工作(如 Latent Diffusion Models)引入tokenizer,將圖像壓縮成 latent token,再在 latent 空間進行生成,提高效率;
- VAE 是常見的 tokenizer,要求 latent 遵循高斯分布(通過 KL regularization)。
問題:
- VAE 的 KL 限制損害了圖像重建質量;
- 普通 AE 雖然重建質量高,但 latent 表征結構性較差,對擴散模型訓練不友好;
那么問題來了:什么樣的 latent 才最適合用于 diffusion?VAE 真有必要嗎?
關鍵發現:結構良好的 latent space 才是關鍵,而非 VAE 的正則。擁有更少 GMM 模式(即更清晰結構、更聚類)的 latent 表征 → 擴散模型訓練損失更小 → 生成效果更好
具體來說:
- 給不同類型的 tokenizer(AE / VAE / 表征對齊 VAE / MAETok)提取 latent;
- 擬合 Gaussian Mixture Model(GMM),觀察模式數量(mode 數);
- 對應的擴散模型的訓練損失越小、生成越好,說明 latent 更利于建模。
結論: 判別性強、結構清晰(mode 少)的 latent 比“高斯先驗 + 正則”更有價值
核心方法:MAETok——用 Masked AE 做 tokenizer
總體設計: 用 MAE(Masked AutoEncoder)訓練 AE,而非 VAE,使其 latent: 語義豐富、 判別性強(discriminative)、可恢復像素。
Encoder:
- transformer-based encoder;
- 隨機 mask 掉輸入 patch(如 50%),強迫模型從部分觀察中學習全局語義;
- 得到的 latent 表征具有更高判別能力和更強結構性(類似 DINO、SimCLR)。
Decoder:
- 兩個 decoder:
- Pixel decoder:恢復輸入圖像;
- Auxiliary decoder:恢復 DINOv2 / HOG / CLIP 特征等;
- 這兩個目標并行訓練,增強表征語義的泛化能力;
- 在推理時只保留 pixel decoder,幾乎不增加開銷。
解耦機制:
- 訓練階段:高 mask ratio(如 60%)讓 encoder 學語義;
- 微調階段:freeze encoder,fine-tune decoder,讓它學會精確恢復像素;
避免語義學習與像素精度之間的沖突。
為什么判別性強、mode 少的 latent 更適合 diffusion?
從 diffusion loss 的角度推導:
- 擴散模型學習的是如何逐步去噪 latent 表征;
- 若 latent 本身是聚合性好的結構(mode 少、類內差小),就更容易建模。
- 理論上證明: GMM mode 越少 → 模型預測誤差(loss)越小 → 更好的 sample quality
On the Latent Space and Diffusion Models
Empirical Analysis
目標: 探索不同 tokenizer(AE、VAE、VAVAE)生成的 latent space 結構復雜度,以及這種結構如何影響 diffusion 模型的訓練和生成質量。
實驗設置:
- 用同樣結構和訓練配置分別訓練 AE、VAE、VAVAE,
- 把它們當作 tokenizer,對 ImageNet 圖像進行編碼得到 latent;
- 用 latent 訓練 DDPM 擴散模型;
- 用 GMM(高斯混合模型) 來衡量 latent 空間的復雜度:
- 模式數(mode K)越多 → 表示 latent 越復雜、結構越混亂;
- 模式數(mode K)越少 → latent 越聚合、語義更清晰,越利于建模;
圖2a:GMM 擬合對比(負對數似然 NLL) ,對 AE、VAE、VAVAE 的 latent 分別進行 GMM 擬合。比較不同模式數量下的 負對數似然(NLL),即擬合誤差。發現:
模型 | 所需 mode 數 | 擬合誤差(NLL) |
---|---|---|
AE | 多 | 高 |
VAE | 中 | 中 |
VAVAE | 少 | 低? |
進一步用這些 latent 分別訓練擴散模型,發現擴散模型訓練 loss 與 GMM mode 數量 幾乎對應:
- 模式越多 → 擴散學習更難 → loss 更高;
- 模式越少 → latent 更有語義結構 → 學習更輕松,loss 更小。
實驗驗證:模式少的 latent 空間能顯著降低擴散模型訓練難度,提高生成質量
Theoretical Analysis
目標: 從理論上解釋為何“mode 少” → “訓練更容易”,即模式數越多,訓練樣本復雜度越高。
理論設定:假設 latent 空間分布為 K 個等權高斯的混合(GMM):
擴散模型訓練目標采用 score matching loss:
Theorem 2.1
為了讓生成分布接近真實分布(KL誤差小于 O(Tε2)),所需樣本數量滿足:
K = 模式數(mode 數); d = latent 維度; B = 均值向量范數的上界(大致相同); ε = 目標誤差精度。
模式數越多(K ↑),樣本復雜度呈 K? 增長。
說明: mode 越多,越難建模,需要越多訓練樣本才能達到同樣生成質量。在訓練樣本有限的現實中,mode 少(如 VAVAE / MAETok)的 latent 更利于 diffusion 學習。
Method
那么核心問題: 如何訓練一個結構性更好、語義更豐富的 latent 空間,讓擴散模型更高效、更強大?
答案是:通過帶 Mask 的 AE(MAETok)結構 + 多目標訓練 + 解耦優化?策略,構造少mode、可判別的 latent,從而提升擴散模型學習效率與生成質量。
Architecture
?如圖,架構組件:
1. 編碼器(Encoder)
2. 解碼器(Decoder)
3.?位置編碼策略(RoPE)
- 對于 image patch tokens 使用 2D Rotary Position Embedding(RoPE) 保留圖像結構;
- 對于 latent tokens 使用 1D 絕對位置編碼,表示抽象語義;
Mask Modeling
MAETok 結構的關鍵設計之一:
- 對圖像 patch token 施加 40%~60% 的隨機掩碼;
- 將被 mask 的 patch 替換為 learnable mask token;
- 讓 latent tokens 學會從剩余部分恢復被遮擋部分信息 → 增強其判別能力;
- 同時,mask 的 patch 特征通過 shallow decoder 去恢復多種語義目標;
高 mask 比例訓練迫使 encoder 抓住圖像的全局、穩定特征,從而提升 latent 表征的“結構性”。
Auxiliary Shallow Decoders
多目標特征預測:進一步強化 latent 語義。
- 使用多個淺層解碼器 D?,預測如: HOG(邊緣特征); DINOv2; CLIP; 文本 token(如 BPE index)等;
- 每個淺層解碼器結構與主 pixel decoder 類似,但層數更少;
- 訓練 loss:只在被 mask 的位置上監督,強化 latent token 對多種語義結構的恢復能力
Pixel Decoder Fine-Tuning
解碼器解耦微調。由于 mask 訓練主要優化 encoder,可能損失了重建精度,因此:
- 最后階段凍結 encoder;
- 微調 pixel decoder 若干輪,僅優化重建質量;
- 不再使用 mask 或輔助解碼器。loss 采用標準組合:
這一步讓 encoder 保持判別性結構,同時恢復 decoder 的高保真圖像輸出能力。
Experiments
Setup
Tokenizer 訓練設置
- 基于 XQ-GAN 框架訓練;
- 編碼器和主 pixel 解碼器均為 ViT-Base(176M 參數);
- 設置 latent token 數量 L=128,維度 H=32;
- 三種數據集/尺寸設置: ImageNet-256 ImageNet-512 LAION-COCO-512 子集(預測圖文 BPE token)
多目標重建:
- mask 比例 40~60%;
- 三個淺層解碼器用于 HOG、DINO-v2、SigCLIP; LAION 加一個 BPE 文本目標;
- decoder 深度 = 3(通過消融得出);
- 損失系數:λ? = 1.0,λ? = 0.4;
- pixel 解碼器微調階段:mask 從 60% 線性下降到 0%。
Diffusion 模型訓練設置:
- 用 SiT(Simple Tokenizer) 與 LightningDiT;
- patch size=1,1D Positional Embedding;
- SiT-L(458M)用于消融,SiT-XL(675M)訓練 4M 步;
- LightningDiT 訓練 400K 步;
- 分辨率:256×256 與 512×512;
評估指標:
- Tokenizer 評估:
- 重建質量:rFID、PSNR、SSIM
- 語義評估:Linear Probing Accuracy(LP)
- 生成評估:
- gFID(生成 FID)、IS(Inception Score)
- Precision/Recall(附錄中)
- CFG 與否兩種條件下(classifier-free guidance)
Design Choices of MAETok
- Mask Modeling AE 中加入 mask modeling:
- gFID 明顯下降(→更好生成);
- rFID 稍升(重建質量下降),可通過 decoder 微調恢復;
- VAE 加 mask 效果小,因為 KL 抑制了 latent 學習。
結論:mask modeling 是提高 AE 表征能力、簡化擴散學習的關鍵。
重建目標 | 特點 | 效果 |
---|---|---|
原始像素 + HOG | 低級視覺特征 | 可學好 latent,但提升有限 |
DINO-v2, CLIP | 語義特征 | gFID 顯著下降(→更好生成) |
組合使用 | 同時兼顧結構和語義 | 最佳 trade-off |
結論:語義教師(CLIP/DINO)能教 AE 學習出更判別的 latent。
Mask 比例(Mask Ratio)
- 太低 → latent 太“忠實”,不判別;
- 太高 → 重建能力差;
- 40%~60% 是最優折中(參考 MAE 系列);
Auxiliary Decoder 深度
- 太淺 → 無法處理高低語義混合目標;
- 太深 → 容易記憶任務,反而不學好的 latent;
- 最優為:中等深度(3 層),效果最佳。
Latent Space Analysis
Latent 可視化(UMAP)
- AE / VAE 的 latent 分布混疊嚴重(類間重疊);
- MAETok latent 分布:類間分明,聚類清晰 → 判別性強;
圖 4(UMAP 圖)直觀支持這個發現。
LP Accuracy 與 gFID 的相關性(圖 5a)
- LP Acc 越高(latent 更判別)→ gFID 越低(生成越好);
- 提示 latent 表征與生成性能緊密相關。
收斂速度(圖 5b)
- MAETok latent 訓練更快;
- SiT-L 在使用 MAETok latent 時,gFID 下降更迅速、值更低。
生成任務對比(表 2/3)
- MAETok + SiT-XL(128 tokens)不使用 CFG,gFID=2.79(512),擊敗 REPA;
- 使用 CFG 后:超越 2B USiT 模型,達到 SOTA: gFID = 1.69(SiT) gFID = 1.65(LightningDiT)
- 使用更強 CFG(如 Autoguidance): gFID 進一步降到 1.54 或 1.51
結論:結構化 latent > 更大模型/更多 token。
重建能力(表 4)
- 256 分辨率,僅用 128 token,rFID=0.48,SSIM=0.763;
- 超越 SoftVQ 和 TexTok(后者 token 數翻倍);
- MS-COCO 上未訓練,仍具泛化能力;
- 在 512 resolution 下依舊保持優勢。
模型 | Token 數 | GFlops | 推理速度(A100) |
---|---|---|---|
原始 SiT-XL | 1024 | 373.3 | 0.1 img/sec |
MAETok | 128 | 48.5 | 3.12 img/sec |
?
Theoretical Analysis
- Step 1:從 latent 的 GMM 模式數 K 推導訓練誤差上限
- Step 2:從訓練誤差推導采樣誤差(KL/采樣分布和真實分布差異)
核心目標是推導:
- 生成誤差 ∝ 模式數 K? → 模式多訓練難度大
- MAETok 的 latent 空間更“判別”(K 少),所以訓練快、生成質量高
Preliminaries
輸入數據建模為 GMM 分布:?latent 空間數據是一個等權重、單位協方差的高斯混合模型
DDPM 的目標函數(Score Matching):
在 GMM 下的解析 score:
即 GMM 分布的 score 函數是“softmax 加權的類中心差值”。
模擬網絡預測的 score: 訓練模型 sθ(x) 采用相同結構假設:?
推論 A.4:數據二階矩上界為:
Step 1:從模式數到訓練誤差(估計 score 的誤差)
Theorem A.5:DDPM 的收斂誤差界
結論:K 越大 → 所需樣本數量越多 → 難訓練
推導 Score Estimation Error :用真實 score 和模型輸出之間的距離展開:
Step 2:從訓練誤差到采樣誤差(生成質量)
Theorem A.6(Early Stopping)
最終結論 Theorem A.7:完整誤差界
訓練 DDPM 時:
- 結論 1:模式數 K↑ → 樣本數 n↑↑↑ → 難訓練
- 結論 2:KL 越小 → 分布越相似 → FID 越低(在高斯假設下)
推理鏈條 | 對 MAETok 的意義 |
---|---|
高模式數 K → 訓練樣本要求高 | AE latent 太 entangled → 訓練慢 |
判別性強 latent(K 小) → 更快收斂 | MAETok 顯著加快 gFID 下降(圖5b) |
分布判別性高 → gFID 更低 | LP Acc ↑ → gFID ↓(圖5a) |
Score loss 越小 → KL 越小 → FID 越低 | MAETok 結構性 latent 直接提升生成質量 |
Experiments Setup
B.1. Training Details of AEs(自編碼器訓練細節)
MAETok 和其他 AE 對比模型(如 AE、KL-VAE、VAVAE)在完全相同的設置下訓練
B.2. Training Details of Diffusion Models(擴散模型訓練細節)
用兩個 backbone:
- SiT-XL(強表征能力)
- LightningDiT(輕量加速)
訓練設置遵循各自原始論文的配置,見 Tables 8、9;
與 AE 模塊解耦,主要對比 latent 空間設計對擴散模型訓練效果的影響。
B.3. Training Details of GMM Models(高斯混合建模的細節)
對應于 Fig. 2a 中對 latent 分布的可分性度量:
實驗流程:
- Flatten Latents :把原始 AE 輸出的 latent 表示 (N,H,C) reshape 為 (N,H×C)
- Dimensionality Reduction(PCA降維) :降維到維度 K,保留>90%方差,保證所有模型輸出 latent 都變為統一維度 (N,K)?,避免“維度詛咒”
- Normalization(標準化):保證不同模型輸出分布一致,避免尺度差異
- GMM Fitting + NLL 評估:擬合 GMM,輸出 NLL loss 衡量 latent 空間是否“結構清晰”(mode 少/可分性強)
訓練配置:
- 所有模型在 ImageNet 全量數據上訓練
- GMM 模型數量:50、100、200,對應訓練時間約為 3/8/11 小時
- 使用單卡 NVIDIA A8000(分布式訓練可提速)
Experiments Results
C.1. More Quantitative Generation Results
在 256×256 和 512×512 分辨率上提供了 Precision / Recall 的補充評估(Table 10, 11);
與 gFID 等指標互補,更全面評估生成質量與多樣性。
C.2. Classifier-free Guidance Tuning Results(CFG 調參結果)
CFG 是無條件擴散模型的關鍵組件,但:
- 即使是微小的 CFG scale 變化,gFID 也會明顯變化;
- 即使用 “CFG Interval” 技術(如 [0, 0.75])跳過高步數時間段,也很難穩定控制;
- 根本原因在于 unconditional class 的語義空間不穩定
實際使用的 CFG 設置:
分辨率 | 模型 | CFG Scale | Interval |
---|---|---|---|
256×256 | SiT-XL | 1.9 | [0, 0.75] |
256×256 | LightningDiT | 1.8 | [0, 0.75] |
512×512 | SiT-XL | 1.5 | [0, 0.7] |
512×512 | LightningDiT | 1.6 | [0, 0.65] |
結論與未來方向:
- 當前線性 CFG 無法有效控制 MAETok 的強語義 latent;
- 可嘗試采用更高級的 CFG 設計
C.3. Latent Space Visualization(可視化結果)
圖 9 展示了 MAETok 及其變體在不同重建目標下的 latent 分布,顯示出明顯的 分布清晰、聚類可分、mode 少 的特點; 理論分析中的 GMM 模型假設與實驗中圖像結果高度一致。
C.4. More Ablation Results
見 Table 13,主要關注兩個因素:
Token Type | 效果 |
---|---|
圖像 patch tokens | 表現普通 |
可學習 latent tokens | 效果顯著更好 |
結論:使用 learnable latent token 更高效,128 個就能達成與 256 個相當效果
2. 2D RoPE(二維相對位置編碼):
- 幫助模型在 混合分辨率訓練場景中泛化更好;
- 對比無位置編碼或1D編碼的模型有更強的空間建模能力。
模塊 | 要點 | 啟發 |
---|---|---|
AE 訓練 | 使用統一設置進行公平比較 | 可復現、可對比 |
GMM 分析 | PCA 降維+標準化+NLL度量 | 量化 latent 可分性(mode 越少越好) |
CFG 調參 | 變化劇烈,調優困難 | MAETok latent 空間語義穩定但不適于線性 CFG |
可視化 | 顯示 clear clustering | 理論假設與實際分布一致 |
Ablation | 128 latent token+2D RoPE 最優 | 更高效、分辨率穩健泛化 |
問題
K(模式數)和樣本量(components)指什么?
作者用 GMM(Gaussian Mixture Model) 去擬合 autoencoder 的 latent space:
- 每一個 mode K,就是一個高斯分布中心(Gaussian Component),代表 latent 中聚集的一群數據。
- K 越大,說明 latent 空間越“離散化、碎片化”,分布不集中。
- GMM 會估計出每個分布的均值 μ 和權重 w,用于刻畫 latent 的整體形狀。
核心直覺:一個“好”的 latent 空間,應該是幾個“集中的簇”,而不是碎片化、重疊、高維擴散。
分析得出: 為了學習一個含 K 個模式的 GMM,score-based 模型訓練所需的樣本量為:
這意味著:K 越大 → 模型越難訓練、樣本需求呈指數增長。
為什么要最小化score matching loss ?
?DDPM 訓練函數:
目標:讓模型輸出的去噪方向盡量接近真實的概率梯度方向,從而逐步反擴散、重建圖像。