核心思想與定義
擴散模型的核心思想是:學習一個去噪過程,以逆轉一個固定的加噪過程。
-
前向過程(固定): 定義一個馬爾可夫鏈,逐步向數據 x0~q(x0)\mathbf{x}_0 \sim q(\mathbf{x}_0)x0?~q(x0?) 添加高斯噪聲,產生一系列噪聲逐漸增大的隱變量 x1,...,xT\mathbf{x}_1, ..., \mathbf{x}_Tx1?,...,xT?。最終 xT\mathbf{x}_TxT? 近似為一個標準高斯分布。
q(x1:T∣x0)=∏t=1Tq(xt∣xt?1),其中q(xt∣xt?1)=N(xt;1?βtxt?1,βtI) q(\mathbf{x}_{1:T} | \mathbf{x}_0) = \prod_{t=1}^T q(\mathbf{x}_t | \mathbf{x}_{t-1}), \quad \text{其中} \quad q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}) q(x1:T?∣x0?)=t=1∏T?q(xt?∣xt?1?),其中q(xt?∣xt?1?)=N(xt?;1?βt??xt?1?,βt?I)
這里 {βt}t=1T\{\beta_t\}_{t=1}^T{βt?}t=1T? 是預先定義好的方差調度表。 -
反向過程(可學習): 我們想要學習一個參數化的反向馬爾可夫鏈 pθp_\thetapθ?,從噪聲 xT~N(0,I)\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})xT?~N(0,I) 開始,逐步去噪以生成數據。
pθ(x0:T)=p(xT)∏t=1Tpθ(xt?1∣xt),其中pθ(xt?1∣xt)=N(xt?1;μθ(xt,t),Σθ(xt,t)) p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t), \quad \text{其中} \quad p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mathbf{\mu}_\theta(\mathbf{x}_t, t), \mathbf{\Sigma}_\theta(\mathbf{x}_t, t)) pθ?(x0:T?)=p(xT?)t=1∏T?pθ?(xt?1?∣xt?),其中pθ?(xt?1?∣xt?)=N(xt?1?;μθ?(xt?,t),Σθ?(xt?,t))
我們的目標是讓 pθ(x0)p_\theta(\mathbf{x}_0)pθ?(x0?) 盡可能接近真實數據分布 q(x0)q(\mathbf{x}_0)q(x0?)。 -
前向過程的閉式解: 得益于高斯分布的可加性,我們可以直接從 x0\mathbf{x}_0x0? 采樣任意時刻 ttt 的 xt\mathbf{x}_txt?:
q(xt∣x0)=N(xt;αˉtx0,(1?αˉt)I) q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I}) q(xt?∣x0?)=N(xt?;αˉt??x0?,(1?αˉt?)I)
其中 αt=1?βt\alpha_t = 1 - \beta_tαt?=1?βt?, αˉt=∏i=1tαi\bar{\alpha}_t = \prod_{i=1}^t \alpha_iαˉt?=∏i=1t?αi?。使用重參數化技巧,可以寫為:
xt=αˉtx0+1?αˉt?,其中?~N(0,I) \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}, \quad \text{其中} \quad \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) xt?=αˉt??x0?+1?αˉt???,其中?~N(0,I)
這個公式至關重要,它允許我們隨機采樣時間步 ttt 并高效地計算訓練損失。
優化目標:變分下界 (VLB/ELBO)
我們的目標是最大化模型生成真實數據的對數似然 log?pθ(x0)\log p_\theta(\mathbf{x}_0)logpθ?(x0?)。由于其難以直接計算,我們轉而最大化其變分下界(VLB),也稱為證據下界(ELBO)。
log?pθ(x0)≥Eq(x1:T∣x0)[log?pθ(x0:T)q(x1:T∣x0)]=Eq[log?p(xT)∏t=1Tpθ(xt?1∣xt)∏t=1Tq(xt∣xt?1)]??LVLB
\begin{aligned}
\log p_\theta(\mathbf{x}_0)
&\geq \mathbb{E}_{q(\mathbf{x}_{1:T} | \mathbf{x}_0)} \left[ \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} | \mathbf{x}_0)} \right] \\
&= \mathbb{E}_{q} \left[ \log \frac{ p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) }{ \prod_{t=1}^T q(\mathbf{x}_t | \mathbf{x}_{t-1}) } \right] \\
&\triangleq -L_{\text{VLB}}
\end{aligned}
logpθ?(x0?)?≥Eq(x1:T?∣x0?)?[logq(x1:T?∣x0?)pθ?(x0:T?)?]=Eq?[log∏t=1T?q(xt?∣xt?1?)p(xT?)∏t=1T?pθ?(xt?1?∣xt?)?]??LVLB??
因此,我們最小化 LVLBL_{\text{VLB}}LVLB?。
通過對 LVLBL_{\text{VLB}}LVLB? 進行推導(利用馬爾可夫性和貝葉斯定理),可以將其分解為以下幾項:
LVLB=Eq[DKL(q(xT∣x0)∥p(xT))?LT?log?pθ(x0∣x1)?L0+∑t=2TDKL(q(xt?1∣xt,x0)∥pθ(xt?1∣xt))?Lt?1] L_{\text{VLB}} = \mathbb{E}_q [\underbrace{D_{\text{KL}}(q(\mathbf{x}_T | \mathbf{x}_0) \parallel p(\mathbf{x}_T))}_{L_T} - \underbrace{\log p_\theta(\mathbf{x}_0 | \mathbf{x}_1)}_{L_0} + \sum_{t=2}^T \underbrace{D_{\text{KL}}(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t))}_{L_{t-1}} ] LVLB?=Eq?[LT?DKL?(q(xT?∣x0?)∥p(xT?))???L0?logpθ?(x0?∣x1?)??+t=2∑T?Lt?1?DKL?(q(xt?1?∣xt?,x0?)∥pθ?(xt?1?∣xt?))??]
- LTL_TLT?: 衡量最終噪聲分布與先驗分布 N(0,I)\mathcal{N}(\mathbf{0}, \mathbf{I})N(0,I) 的差異。此項沒有可學習參數,接近于0,可以忽略。
- L0L_0L0?: 重建項,衡量最后一步生成圖像與真實圖像的差異。此項在原始DDPM中通過一個離散化decoder處理,實踐中發現其影響較小。
- Lt?1L_{t-1}Lt?1? (1≤t≤T1 \le t \le T1≤t≤T): 這是最關鍵的一項。它衡量的是對于每一個去噪步,真實的去噪分布 q(xt?1∣xt,x0)q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)q(xt?1?∣xt?,x0?) 和 學習的去噪分布 pθ(xt?1∣xt)p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)pθ?(xt?1?∣xt?) 之間的KL散度。
核心推導:真實的后驗分布 q(xt?1∣xt,x0)q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)q(xt?1?∣xt?,x0?)
根據貝葉斯定理和馬爾可夫性,我們可以推導出這個真實的后驗分布。它也是一個高斯分布,這意味著我們可以用另一個高斯分布 pθp_\thetapθ? 去匹配它。
q(xt?1∣xt,x0)=q(xt∣xt?1,x0)q(xt?1∣x0)q(xt∣x0)∝N(xt;αtxt?1,(1?αt)I)?N(xt?1;αˉt?1x0,(1?αˉt?1)I) \begin{aligned} q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) &= \frac{q(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{x}_0) q(\mathbf{x}_{t-1} | \mathbf{x}_0)}{q(\mathbf{x}_t | \mathbf{x}_0)} \\ &\propto \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t} \mathbf{x}_{t-1}, (1 - \alpha_t)\mathbf{I}) \cdot \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0, (1 - \bar{\alpha}_{t-1})\mathbf{I}) \end{aligned} q(xt?1?∣xt?,x0?)?=q(xt?∣x0?)q(xt?∣xt?1?,x0?)q(xt?1?∣x0?)?∝N(xt?;αt??xt?1?,(1?αt?)I)?N(xt?1?;αˉt?1??x0?,(1?αˉt?1?)I)?
經過一系列高斯分布密度函數的乘積和配方,可以得出其均值和方差為:
q(xt?1∣xt,x0)=N(xt?1;μ~t(xt,x0),β~tI) q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) q(xt?1?∣xt?,x0?)=N(xt?1?;μ~?t?(xt?,x0?),β~?t?I)
其中μ~t(xt,x0)=1αt(xt?βt1?αˉt?),β~t=1?αˉt?11?αˉtβt \text{其中} \quad \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{\epsilon} \right), \quad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t 其中μ~?t?(xt?,x0?)=αt??1?(xt??1?αˉt??βt???),β~?t?=1?αˉt?1?αˉt?1??βt?
注意:這里 ?\mathbf{\epsilon}? 是前向過程中添加到 x0\mathbf{x}_0x0? 上生成 xt\mathbf{x}_txt? 的噪聲。這個 μ~t\mathbf{\tilde{\mu}}_tμ~?t? 的表達式非常關鍵!
簡化損失函數:從均值預測到噪聲預測
現在我們來看要最小化的 Lt?1L_{t-1}Lt?1?,它是兩個高斯分布的KL散度。高斯分布的KL散度主要由其均值的差異主導(假設方差固定)。
Lt?1=Eq[DKL(q(xt?1∣xt,x0)∥pθ(xt?1∣xt))]=Eq[12σt2∥μ~t(xt,x0)?μθ(xt,t)∥2]+C \begin{aligned} L_{t-1} &= \mathbb{E}_q \left[ D_{\text{KL}}(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)) \right] \\ &= \mathbb{E}_q \left[ \frac{1}{2\sigma_t^2} \| \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \mathbf{\mu}_\theta(\mathbf{x}_t, t) \|^2 \right] + C \end{aligned} Lt?1??=Eq?[DKL?(q(xt?1?∣xt?,x0?)∥pθ?(xt?1?∣xt?))]=Eq?[2σt2?1?∥μ~?t?(xt?,x0?)?μθ?(xt?,t)∥2]+C?
現在我們有兩個選擇:
- 讓網絡 μθ\mathbf{\mu}_\thetaμθ? 直接預測均值 μ~t\mathbf{\tilde{\mu}}_tμ~?t?。
- 根據 μ~t\mathbf{\tilde{\mu}}_tμ~?t? 的表達式,重新參數化模型。
DDPM選擇了第二種方式,因為它效果更好。我們將 μ~t\mathbf{\tilde{\mu}}_tμ~?t? 的表達式代入:
μθ(xt,t)=1αt(xt?βt1?αˉt?θ(xt,t)) \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) μθ?(xt?,t)=αt??1?(xt??1?αˉt??βt???θ?(xt?,t))
這里,我們不再讓網絡預測均值,而是讓它預測噪聲 ?\mathbf{\epsilon}?,即 ?θ(xt,t)\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)?θ?(xt?,t)。將這個表達式代入上面的損失函數,經過簡化(忽略權重系數),我們得到最終極其簡潔的損失函數:
Lsimple=Ex0,t,?~N(0,I)[∥???θ(αˉtx0+1?αˉt?,t)∥2] L_{\text{simple}} = \mathbb{E}_{\mathbf{x}_0, t, \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \left[ \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}, t ) \|^2 \right] Lsimple?=Ex0?,t,?~N(0,I)?[∥???θ?(αˉt??x0?+1?αˉt???,t)∥2]
這個損失函數的直觀解釋是:對于一張真實圖像 x0\mathbf{x}_0x0?,隨機選擇一個時間步 ttt,隨機采樣一個噪聲 ?\mathbf{\epsilon}?,構造出噪聲圖像 xt\mathbf{x}_txt?。然后,我們訓練一個網絡 ?θ\mathbf{\epsilon}_\theta?θ?,讓它根據 xt\mathbf{x}_txt? 和 ttt 來預測出我們添加的噪聲 ?\mathbf{\epsilon}?。損失就是預測噪聲和真實噪聲之間的均方誤差。
總結:優化流程
- 輸入:從訓練集中采樣一張真實圖像 x0\mathbf{x}_0x0?。
- 加噪:
- 均勻采樣一個時間步 t~Uniform(1,...,T)t \sim \text{Uniform}(1, ..., T)t~Uniform(1,...,T)。
- 從標準高斯分布采樣噪聲 ?~N(0,I)\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})?~N(0,I)。
- 計算 xt=αˉtx0+1?αˉt?\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}xt?=αˉt??x0?+1?αˉt???。
- 預測:將 xt\mathbf{x}_txt? 和 ttt 輸入神經網絡 ?θ\mathbf{\epsilon}_\theta?θ?,得到其對噪聲的預測 ?θ(xt,t)\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)?θ?(xt?,t)。
- 優化:計算損失 L=∥???θ∥2L = \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta \|^2L=∥???θ?∥2,并通過梯度下降更新網絡參數 θ\thetaθ。
- 重復:重復步驟1-4直至收斂。
這個框架的巧妙之處在于,它將一個復雜的生成問題,分解為了 TTT 個相對簡單的去噪問題。網絡 ?θ\mathbf{\epsilon}_\theta?θ? 不需要一步生成完美圖像,只需要在每一步完成一個更簡單的任務:預測噪聲。這使得訓練非常穩定,也是擴散模型成功的核心原因。