初學機器學習:直觀解讀KL散度的數學概念
轉自:初學機器學習:直觀解讀KL散度的數學概念
譯自:https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8
解讀自:https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained
代碼:https://github.com/thushv89/exercises_thushv_dot_com/blob/master/kl_divergence.ipynb
作者:Thushan Ganegedara,機器之心編譯。
本文修正了一些錯誤,優化了排版。
基礎概念
首先讓我們確立一些基本規則。我們將會定義一些我們需要了解的概念。
分布(distribution)
分布可能指代不同的東西,比如數據分布或概率分布。我們這里所涉及的是概率分布。假設你在一張紙上畫了兩根軸(即 XXX 和 YYY),我可以將一個分布想成是落在這兩根軸之間的一條線。其中 XXX 表示你有興趣獲取概率的不同值。YYY 表示觀察 XXX 軸上的值時所得到的概率。即 y=p(x)y=p(x)y=p(x)。下圖即是某個分布的可視化。
這是一個連續概率分布。比如,我們可以將 XXX 軸看作是人的身高,YYY 軸是找到對應身高的人的概率。
如果你想得到離散的概率分布,你可以將這條線分成固定長度的片段并以某種方式將這些片段水平化。然后就能根據這條線的每個片段創建邊緣互相連接的矩形。這就能得到一個離散概率分布。
事件(event)
對于離散概率分布而言,事件是指觀察到 XXX 取某個值(比如 X=1X=1X=1)的情況。我們將事件 X=1X=1X=1 的概率記為 P(X=1)P(X=1)P(X=1)。在連續空間中,你可以將其看作是一個取值范圍(比如 0.95<X<1.050.95<X<1.050.95<X<1.05)。注意,事件的定義并不局限于在 XXX 軸上取值。但是我們后面只會考慮這種情況。
回到 KL 散度
從這里開始,我將使用來自這篇博文的示例:https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained。這是一篇很好的 KL 散度介紹文章,但我覺得其中某些復雜的解釋可以更詳細的闡述。好了,讓我們繼續吧。
我們想要解決的問題
上述博文中所解決的核心問題是這樣的:假設我們是一組正在廣袤無垠的太空中進行研究的科學家。我們發現了一些太空蠕蟲,這些太空蠕蟲的牙齒數量各不相同。現在我們需要將這些信息發回地球。但從太空向地球發送信息的成本很高,所以我們需要用盡量少的數據表達這些信息。我們有個好方法:我們不發送單個數值,而是繪制一張圖表,其中 XXX 軸表示所觀察到的不同牙齒數量 (0,1,2…)(0,1,2…)(0,1,2…),YYY 軸是看到的太空蠕蟲具有 xxx 顆牙齒的概率(即具有 xxx 顆牙齒的蠕蟲數量/蠕蟲總數量)。這樣,我們就將觀察結果轉換成了分布。
發送分布比發送每只蠕蟲的信息更高效。但我們還能進一步壓縮數據大小。我們可以用一個已知的分布來表示這個分布(比如均勻分布、二項分布、正態分布)。舉個例子,假如我們用均勻分布來表示真實分布,我們只需要發送兩段數據就能恢復真實數據;均勻概率和蠕蟲數量。但我們怎樣才能知道哪種分布能更好地解釋真實分布呢?這就是 KL 散度的用武之地。
直觀解釋:KL 散度是一種衡量兩個分布(比如兩條線)之間的匹配程度的方法。
讓我們對示例進行一點修改
為了能夠檢查數值的正確性,讓我們將概率值修改成對人類更友好的值(相比于上述博文中的值)。我們進行如下假設:假設有 100 只蠕蟲,各種牙齒數的蠕蟲的數量統計結果如下。
牙齒顆數 iii | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|---|
蠕蟲數 | 2 | 3 | 5 | 14 | 16 | 15 | 12 | 8 | 10 | 8 | 7 |
概率 pip_ipi? | 0.02 | 0.03 | 0.05 | 0.14 | 0.16 | 0.15 | 0.12 | 0.08 | 0.10 | 0.08 | 0.07 |
快速做一次完整性檢查!確保蠕蟲總數為 100,且概率總和為 1.0.
- 蠕蟲總數 = 2+3+5+14+16+15+12+8+10+8+7 = 100
- 概率總和 = 0.02+0.03+0.05+0.14+0.16+0.15+0.12+0.08+0.1+0.08+0.07 = 1.0
可視化結果為:
嘗試 1:使用均勻分布建模
我們首先使用均勻分布來建模該分布。均勻分布只有一個參數:均勻概率;即給定事件發生的概率。
puniform=1totalevents=111=0.0909p_{uniform}=\frac{1}{total\ events}=\frac{1}{11}=0.0909 puniform?=total?events1?=111?=0.0909
均勻分布和我們的真實分布對比:
先不討論這個結果,我們再用另一種分布來建模真實分布。
嘗試 2:使用二項分布建模
你可能計算過拋硬幣正面或背面向上的概率,這就是一種二項分布概率。我們可以將同樣的概念延展到我們的問題上。對于有兩個可能輸出的硬幣,我們假設硬幣正面向上的概率為 ppp,并且進行了 nnn 次嘗試,那么其中成功 kkk 次的概率為:
P(X=k)=(nk)pk(1?p)n?kP(X=k)=\begin{pmatrix} n \\ k \end{pmatrix} p^k(1-p)^{n-k} P(X=k)=(nk?)pk(1?p)n?k
公式解讀
這里說明一下二項分布中每一項的含義。第一項是 pkp^kpk。我們想成功 kkk 次,其中單次成功的概率為 ppp;那么成功 kkk 次的概率為 pkp^kpk。另外要記得我們進行了 nnn 次嘗試。因此,其中失敗的次數為 n?kn-kn?k,對應失敗的概率為 (1?p)(1-p)(1?p)。所以成功 k 次的概率即為聯合概率 pk(1?p)n?kp^k(1-p)^{n-k}pk(1?p)n?k。到此還未結束。在 nnn 次嘗試中,kkk 次成功會有不同的排列方式。在數量為 nnn 的空間中 kkk 個元素的不同排列數量為
(nk)=n!k!(n?k)!\begin{pmatrix} n \\ k \end{pmatrix} =\frac{n!}{k!(n-k)!} (nk?)=k!(n?k)!n!?
將所有這些項相乘就得到了成功 kkk 次的二項概率。
二項分布的均值和方差
我們還可以定義二項分布的均值和方差:mean=npmean=npmean=np,var=np(1?p)var=np(1-p)var=np(1?p)。
均值是什么意思?均值是指你進行 nnn 次嘗試時的期望(平均)成功次數。如果每次嘗試成功的概率為 ppp,那么可以說 nnn 次嘗試的成功次數為 npnpnp。
方差又是什么意思?它表示真實的成功嘗試次數偏離均值的程度。為了理解方差,讓我們假設 n=1n=1n=1,那么等式就成了「方差= p(1?p)p(1-p)p(1?p)」。那么當 p=0.5p=0.5p=0.5 時(正面和背面向上的概率一樣),方差最大;當 p=1p=1p=1 或 p=0p=0p=0 時(只能得到正面或背面中的一種),方差最小。
回來繼續建模
現在我們已經理解了二項分布,接下來回到我們之前的問題。首先讓我們計算蠕蟲的牙齒的期望值:
∑i=011=0×p0+1×p1+…10×p10=5.44\sum_{i=0}^{11}=0\times p_0+1\times p_1+\dots 10\times p_{10}=5.44 i=0∑11?=0×p0?+1×p1?+…10×p10?=5.44
有了均值,我們可以計算 ppp 的值:
mean=np5.44=10pp=0.544mean=np\\ 5.44=10p\\ p=0.544 mean=np5.44=10pp=0.544
注意,這里的 nnn 是指在蠕蟲中觀察到的最大牙齒數。你可能會問我們為什么不把蠕蟲總數(即 100)或總事件數(即 11)設為 nnn。我們很快就將看到原因。有了這些數據,我們可以按如下方式定義任意牙齒數的概率。
鑒于牙齒數的取值最大為 10,那么看見 kkk 顆牙齒的概率是多少(這里看見一顆牙齒即為一次成功嘗試)?
從拋硬幣的角度看,這就類似于:
假設我拋 10 次硬幣,觀察到 kkk 次正面向上的概率是多少?
從形式上講,我們可以計算所有不同 kkk 值的概率 pkbip_k^{bi}pkbi?。其中 kkk 是我們希望觀察到的牙齒數量,pkbip_k^{bi}pkbi? 是第 k 個牙齒數量位置(即 0 顆牙齒、1 顆牙齒……)的二項概率。所以,計算結果如下:
p0bi=(10!/(0!10!))0.5440(1–0.544)10=0.0004p1bi=(10!/(1!9!))0.5441(1–0.544)9=0.0046p2bi=(10!/(2!8!))0.5442(1–0.544)8=0.0249…p9bi=(10!/(9!1!))0.5449(1–0.544)1=0.0190p10bi=(10!/(10!0!))0.54410(1–0.544)0=0.0023p0^{bi} = (10!/(0!10!)) 0.544? (1–0.544)^{10} = 0.0004\\ p1^{bi} = (10!/(1!9!)) 0.5441 (1–0.544)? = 0.0046\\ p2^{bi} = (10!/(2!8!)) 0.5442 (1–0.544)? = 0.0249\\ …\\ p9^{bi} = (10!/(9!1!)) 0.544? (1–0.544)1 = 0.0190\\ p10^{bi} = (10!/(10!0!)) 0.544^{10} (1–0.544)? = 0.0023\\ p0bi=(10!/(0!10!))0.5440(1–0.544)10=0.0004p1bi=(10!/(1!9!))0.5441(1–0.544)9=0.0046p2bi=(10!/(2!8!))0.5442(1–0.544)8=0.0249…p9bi=(10!/(9!1!))0.5449(1–0.544)1=0.0190p10bi=(10!/(10!0!))0.54410(1–0.544)0=0.0023
我們的真實分布和二項分布的比較如下:
總結已有情況
現在回頭看看我們已經完成的工作。首先,我們理解了我們想要解決的問題。我們的問題是將特定類型的太空蠕蟲的牙齒數據統計用盡量小的數據量發回地球。為此,我們想到用某個已知分布來表示真實的蠕蟲統計數據,這樣我們就可以只發送該分布的參數,而無需發送真實統計數據。我們檢查了兩種類型的分布,得到了以下結果。
- 均勻分布——概率為 0.0909
- 二項分布——n=10n=10n=10、p=0.544p=0.544p=0.544,kkk 取值在 0 到 10 之間。
讓我們在同一個地方可視化這三個分布:
我們如何定量地確定哪個分布更好?
經過這些計算之后,我們需要一種衡量每個近似分布與真實分布之間匹配程度的方法。這很重要,這樣當我們發送信息時,我們才無需擔憂「我是否選擇對了?」畢竟太空蠕蟲關乎我們每個人的生命。
這就是 KL 散度的用武之地。KL 散度在形式上定義如下:
DKL(p∣∣q)=∑i=1Np(xi)log?p(xi)q(xi)D_{KL}(p||q)=\sum_{i=1}^Np(x_i)\log\frac{p(x_i)}{q(x_i)} DKL?(p∣∣q)=i=1∑N?p(xi?)logq(xi?)p(xi?)?
其中 q(x)q(x)q(x) 是近似分布,p(x)p(x)p(x) 是我們想要用 q(x)q(x)q(x) 匹配的真實分布。直觀地說,這衡量的是給定任意分布偏離真實分布的程度。如果兩個分布完全匹配,那么DKL(p∣∣q)=0D_{KL}(p||q)=0DKL?(p∣∣q)=0 ,否則它的取值應該是在 0 到 ∞\infty∞ 之間。KL 散度越小,真實分布與近似分布之間的匹配就越好。
KL 散度的直觀解釋
讓我們看看 KL 散度各個部分的含義。首先看看
logp(xi)q(xi)log\frac{p(x_i)}{q(x_i)} logq(xi?)p(xi?)?
項。如果 q(xi)q(x_i)q(xi?) 大于 p(xi)p(x_i)p(xi?) 會怎樣呢?此時這個項的值為負,因為小于 1 的值的對數為負。另一方面,如果 q(xi)q(x_i)q(xi?) 總是小于 p(xi)p(x_i)p(xi?),那么該項的值為正。如果 p(xi)=q(xi)p(x_i)=q(x_i)p(xi?)=q(xi?) 則該項的值為 0。然后,為了使這個值為期望值,你要用 p(xi)p(x_i)p(xi?) 來給這個對數項加權。也就是說,p(xi)p(x_i)p(xi?) 有更高概率的匹配區域比低 p(xi)p(x_i)p(xi?) 概率的匹配區域更加重要。
直觀而言,優先正確匹配近似分布中真正高可能性的事件是有實際價值的。從數學上講,這能讓你自動忽略落在真實分布的支集(支集(support)是指分布使用的 XXX 軸的全長度)之外的分布區域。另外,這還能避免計算 log(0)log(0)log(0) 的情況——如果你試圖計算落在真實分布的支集之外的任意區域的這個對數項,就可能出現這種情況。
計算 KL 散度
我們計算一下上面兩個近似分布與真實分布之間的 KL 散度。首先來看均勻分布:
D(True∣∣Uniform)=0.02log?(0.02/0.0909)+?+0.07log?(0.07/0.0909)=0.136D(True||Uniform)=0.02\log(0.02/0.0909)+\dots+0.07\log(0.07/0.0909)=0.136 D(True∣∣Uniform)=0.02log(0.02/0.0909)+?+0.07log(0.07/0.0909)=0.136
再看看二項分布:
D(True∣∣Binomila)=0.02log?(0.02/0.0004)+?+0.07log?(0.07/0.0023)=0.427D(True||Binomila)=0.02\log(0.02/0.0004)+\dots+0.07\log(0.07/0.0023)=0.427 D(True∣∣Binomila)=0.02log(0.02/0.0004)+?+0.07log(0.07/0.0023)=0.427
玩一玩 KL 散度
現在,我們來玩一玩 KL 散度。首先我們會先看看當二元分布的成功概率變化時 KL 散度的變化情況。不幸的是,我們不能使用均勻分布做同樣的事,因為 nnn 固定時均勻分布的概率不會變化。
可以看到,當我們遠離我們的選擇(紅點)時,KL 散度會快速增大。實際上,如果你顯示輸出我們的選擇周圍小 Δ\DeltaΔ 數量的 KL 散度值,你會看到我們選擇的成功概率的 KL 散度最小。
現在讓我們看看 DKL(P∣∣Q)D_{KL}(P||Q)DKL?(P∣∣Q) 和 DKL(Q∣∣P)D_{KL}(Q||P)DKL?(Q∣∣P) 的行為方式。如下圖所示:
看起來有一個區域中的 DKL(P∣∣Q)D_{KL}(P||Q)DKL?(P∣∣Q) 和 DKL(Q∣∣P)D_{KL}(Q||P)DKL?(Q∣∣P) 之間有最小的距離。讓我們繪出兩條線之間的差異(虛線),并且放大我們的概率選擇所在的區域。
有最低差異的區域(但并不是最低差異的區域)。但這仍然是一個很有意思的發現。我不確定出現這種情況的原因是什么。如果有人知道,歡迎討論。
結論
現在我們有些可靠的結果了。盡管均勻分布看起來很簡單且信息不多而二項分布帶有更有差別的信息,但實際上均勻分布與真實分布之間的匹配程度比二項分布的匹配程度更高。說老實話,這個結果實際上讓我有點驚訝。因為我之前預計二項分布能更好地建模這個真實分布。因此,這個實驗也能告訴我們:不要只相信自己的直覺!