受柯爾莫哥洛夫-阿諾德表示定理的啟發,作者提出柯爾莫哥洛夫-阿諾德網絡(KAN)作為多層感知器(MLP)有前途的替代品。MLP 在節點(“神經元”)上具有固定的激活函數,而 KAN 在邊(“權重”)上具有可學習的激活函數。KAN 沒有線性權重----每個權重參數都被參數化為spline的單變量函數所取代。作者證明,這種看似簡單的改變使得 KAN 在準確性和可解釋性方面優于 MLP。就準確性而言,在數據擬合和 PDE 求解中,較小的 KAN 可以比較大的 MLP 獲得可比或更好的準確性。從理論上和經驗上來說,KAN 比 MLP 擁有更快的神經尺度法則(neural scaling laws:隨著模型參數的增加,測試損失減小)。對于可解釋性,KAN 可以直觀地可視化,并且可以輕松地與人類交互。通過數學和物理領域的兩個例子,KAN 被證明是有用的“合作者”,幫助科學家(重新)發現數學和物理定律。總之,KAN 是 MLP 的有前途的替代品,為進一步改進當今嚴重依賴 MLP 的深度學習模型提供了機會。
來自:KAN: Kolmogorov–Arnold Networks
目錄
- 引言
- KAN
- Kolmogorov-Arnold表示理論
- KAN架構
- 實現細節
- 化簡KAN使其可解釋
- 思考
引言
多層感知器 MLP,也稱為全連接前饋神經網絡,是當今深度學習模型的基礎構建塊。MLP 的重要性怎么強調都不為過,因為它們是機器學習中用于逼近非線性函數的默認模型,MLP的表達能力由通用逼近定理保證。然而,MLP 是我們可以構建的最好的非線性回歸器嗎?盡管 MLP 被廣泛使用,但它們也有明顯的缺點。 例如,在 Transformer 中,MLP 占用幾乎所有參數,并且在沒有后期分析工具的情況下難以解釋。
作者提出了一種有希望的MLP替代方案,稱為Kolmogorov-Arnold網絡 KANs。MLP受到通用近似定理的啟發,而KAN則受到Kolmogorov-Arnold表示定理的啟發。與MLP一樣,KAN具有全連接的結構。然而,MLP將固定的激活函數放在節點(“神經元”)上,而KAN將可學習的激活函數放在邊(“權重”)上,如圖0.1所示。因此,KAN根本沒有線性權重矩陣:取而代之的是,每個權重參數都被一個可學習的一維函數參數化為spline所取代。KAN的節點只是簡單地對輸入信號求和,而不應用任何非線性。
- 圖0.1:MLP與KAN的比較。
有人可能會擔心,由于每個MLP的權重參數都變成了KAN的spline,因此KAN的成本非常昂貴。幸運的是,KAN擁有比MLP更小的規模。比如:對于PDE solving,2層,每層寬度10的KAN比4層,每層寬度100的MLP精確100倍( 1 0 ? 7 10^{?7} 10?7 vs 1 0 ? 5 10^{?5} 10?5 MSE),參數效率高100倍( 1 0 2 10^{2} 102 vs 1 0 4 10^{4} 104參數)。
其實之前已經有很多研究使用Kolmogorov-Arnold表示定理構建神經網絡。然而,大多數工作都堅持使用原始的depth-2 width-(2 n n n+1),并且沒有利用更現代的技術,例如反向傳播來訓練網絡。KAN的貢獻在于將原始的Kolmogorov-Arnold表示推廣到任意寬度和深度,并在今天的深度學習世界中重新激活它,以及使用廣泛的實驗來突出其作為AI+Science基礎模型的優勢,因為它的準確性和可解釋性。
KAN是spline和MLP的組合,利用各自的優勢并避免各自的弱點。spline對于低維函數是精確的,易于局部調整,并且能夠在不同的分辨率之間切換。然而,spline曲線由于無法利用組合結構而存在嚴重的COD(維數詛咒)。另一方面,MLP由于其特征學習而較少受到COD的影響,但由于其無法優化單變量函數,因此在低維情況下不如spline曲線準確。
為了準確地學習一個函數,一個模型不僅要學習組成結構(外部自由度),而且要很好地近似單變量函數(內部自由度)。KAN是這樣的模型,因為它們在外部有MLP,在內部有spline。因此,KAN不僅可以學習特征(由于它們與MLP的外部相似性),而且還可以以很高的精度優化這些學習到的特征(由于它們與spline的內部相似性)。比如,給定一個高維函數: f ( x 1 , . . . , x N ) = e x p ( 1 N ∑ i = 1 N s i n 2 ( x i ) ) f(x_{1},...,x_{N})=exp(\frac{1}{N}\sum_{i=1}^{N}sin^{2}(x_{i})) f(x1?,...,xN?)=exp(N1?i=1∑N?sin2(xi?))當 N N N較大時,由于COD的影響,spline失效;MLP可以潛在地學習廣義加性結構,但它們對于用ReLU激活來近似指數函數和正弦函數是非常低效的。相比之下,KANs可以很好地學習組合結構和單變量函數,因此在很大程度上優于MLP(見圖3.1)。
- 圖3.1:用5個toy例子來比較KAN和MLP。KAN 幾乎可以實現理論預測的最快縮放定律,而 MLP 的縮放定律很緩慢。
KAN
MLP的靈感來自于通用近似定理。本節組織形式為:
- 回顧Kolmogorov-Arnold定理
- Kolmogorov-Arnold網絡的設計
- 提出簡化技術以使KANs可解釋
Kolmogorov-Arnold表示理論
Vladimir Arnold和Andrey Kolmogorov證明了如果 f f f是一個有界域上的多元連續函數,那么 f f f可以寫成單變量連續函數的有限復合,更具體的,對于函數 f : [ 0 , 1 ] n → R f:[0,1]^{n}\rightarrow\mathbb{R} f:[0,1]n→R: f ( x ) = f ( x 1 , . . . , x n ) = ∑ q = 1 2 n + 1 Φ q ( ∑ p = 1 n ? q , p ( x p ) ) (2.1) f(\textbf{x})=f(x_{1},...,x_{n})=\sum_{q=1}^{2n+1}\Phi_{q}(\sum_{p=1}^{n}\phi_{q,p}(x_{p}))\tag{2.1} f(x)=f(x1?,...,xn?)=q=1∑2n+1?Φq?(p=1∑n??q,p?(xp?))(2.1)其中, ? q , p : [ 0 , 1 ] → R \phi_{q,p}:[0,1]\rightarrow\mathbb{R} ?q,p?:[0,1]→R, Φ q : R → R \Phi_{q}:\mathbb{R}\rightarrow\mathbb{R} Φq?:R→R。從某種意義上說,他們證明了唯一真正的多元函數是加法,因為所有其他函數都可以用一元函數與求和來表示。有人可能會天真地認為這對機器學習來說是個好消息:學習一個高維函數可以歸結為學習一個多項式數量的一維函數。然而,這些一維函數可能是非平滑的,甚至是分形的,因此在實踐中可能無法學習。由于這種病態的行為,Kolmogorov-Arnold表示定理在機器學習中基本上被判了死刑,被認為理論上是正確的,但實際上毫無用處。
然而,作者對Kolmogorov-Arnold定理在機器學習中的有用性更為樂觀:
- 首先,不需要拘謹于原始的Eq.(2.1),它只有兩層非線性和隱藏層中少量的項(2 n n n+1):作者將網絡推廣到任意的寬度和深度。
- 其次,科學和日常生活中的大多數函數通常是光滑的,并且具有稀疏的組成結構,這有助于光滑的Kolmogorov-Arnold表示。這里的哲學接近于物理學家的心態,他們通常更關心典型的情況,而不是最壞的情況。
KAN架構
假設我們有一個由輸入輸出對 { x i , y i } \left\{\textbf{x}_{i},y_{i}\right\} {xi?,yi?}組成的監督學習任務,我們想要找到一個 f f f對所有data point有 y i ≈ f ( x i ) y_{i}\approx f(\textbf{x}_{i}) yi?≈f(xi?)。Eq.(2.1)暗示,如果我們能找到合適的單變量函數 ? q , p \phi_{q,p} ?q,p?和 Φ q \Phi_{q} Φq?,我們就完成了。這啟發作者設計一個神經網絡,顯式參數化Eq.(2.1)。由于所有需要學習的函數都是單變量函數,我們可以將每個一維函數參數化為一條B-spline,以及具有局部B-spline基函數的可學習系數(見圖2.2右)。
- 圖2.2:左----流經網絡的激活符號(從下往上前向計算)。右----激活函數被參數化為B-spline,它允許在粗粒度和細粒度網格之間切換。
現在我們有了一個KAN的原型,其計算圖由Eq.(2.1)精確指定,如圖0.1 b所示(輸入維數 n n n=2),表現為一個兩層神經網絡,激活函數放置在邊緣而不是節點上(對節點進行簡單求和),中間層寬度為 2 n n n + 1。
如前所述,這樣的網絡被認為太簡單,無法在實踐中任意地用光滑spline近似任何函數!因此,作者希望將KAN擴展到更寬和更深。目前還不清楚如何使KAN更深,因為Kolmogorov-Arnold表示對應于兩層的KAN。然而,目前還沒有一個“一般化”版本的定理對應于更深層次的KAN。
作者注意到MLP和KAN之間的類比時,突破出現了。在MLP中,一旦定義了一個層(它由線性變換和非線性組成),我們就可以堆疊更多的層來使網絡更深。為了建立更深的KAN,我們應該首先回答:什么是KAN layer?事實證明,具有 n i n n_{in} nin?-dim輸入和 n o u t n_{out} nout?-dim輸出的KAN層可以定義為一維函數的矩陣: Φ = { ? q , p } , p = 1 , 2 , . . . , n i n , q = 1 , 2 , . . . , n o u t (2.2) \Phi=\left\{\phi_{q,p}\right\},\thinspace\thinspace p=1,2,...,n_{in},\thinspace\thinspace q=1,2,...,n_{out}\tag{2.2} Φ={?q,p?},p=1,2,...,nin?,q=1,2,...,nout?(2.2)其中 ? q , p \phi_{q,p} ?q,p?有可學習參數,在Kolmogov-Arnold定理中,內部函數形成一個 n i n = n , n o u t = 2 n + 1 n_{in}=n,n_{out}=2n+1 nin?=n,nout?=2n+1的KAN層,外部函數形成一個 n i n = 2 n + 1 , n o u t = n n_{in}=2n+1,n_{out}=n nin?=2n+1,nout?=n的KAN層。因此,Eq.(2.1)中的Kolmogov-Arnold表示只是兩個KAN層的簡單組合。現在很清楚擁有更深的Kolmogorov-Arnold表示意味著什么:簡單地堆疊更多的KAN層!
作者引入一些符號。這段話有點技術性,但我們可以參考圖2.2-左來獲得具體的例子和直觀的理解。KAN的形狀由整數數組表示: [ n 0 , n 1 , . . . , n L ] [n_{0},n_{1},...,n_{L}] [n0?,n1?,...,nL?]其中, n i n_{i} ni?是計算圖第 i i i層的節點數量。用 ( l , i ) (l,i) (l,i)表示第 l l l層的第 i i i個神經元,用 x l , i x_{l,i} xl,i?表示 ( l , i ) (l,i) (l,i)神經元的激活值。在第 l l l層和第 l + 1 l+1 l+1層之間,有 n l n l + 1 n_{l}n_{l+1} nl?nl+1?個激活函數:連接 ( l , j ) (l, j) (l,j)和 ( l + 1 , i ) (l +1, i) (l+1,i)的激活函數表示為 ? l , i , j , l = 0 , . . . , L ? 1 , i = 1 , . . . , n l + 1 , j = 1 , . . . , n l \phi_{l,i,j},\thinspace\thinspace l=0,...,L-1,\thinspace\thinspace i=1,...,n_{l+1},\thinspace\thinspace j=1,...,n_{l} ?l,i,j?,l=0,...,L?1,i=1,...,nl+1?,j=1,...,nl?其中, ? l , i , j \phi_{l,i,j} ?l,i,j?的激活之前是 x l , i x_{l,i} xl,i?, ? l , i , j \phi_{l,i,j} ?l,i,j?的激活之后被記為 x ~ l , i , j = ? l , i , j ( x l , i ) \widetilde{x}_{l,i,j}=\phi_{l,i,j}(x_{l,i}) x l,i,j?=?l,i,j?(xl,i?)。 ( l + 1 , j ) (l + 1, j) (l+1,j)神經元的激活值就是所有傳入的后激活值之和: x l + 1 , j = ∑ i = 1 n l x ~ l , i , j = ∑ i = 1 n l ? l , i , j ( x l , i ) , j = 1 , . . . , n l + 1 (2.5) x_{l+1,j}=\sum_{i=1}^{n_{l}}\widetilde{x}_{l,i,j}=\sum_{i=1}^{n_{l}}\phi_{l,i,j}(x_{l,i}),\thinspace\thinspace j=1,...,n_{l+1}\tag{2.5} xl+1,j?=i=1∑nl??x l,i,j?=i=1∑nl???l,i,j?(xl,i?),j=1,...,nl+1?(2.5)在矩陣形式中,有:
其中, Φ l \Phi_{l} Φl?是第 l l l層KAN對應的函數矩陣。一般的KAN網絡是 L L L層的組合:給定一個輸入向量 x 0 ∈ R n 0 \textbf{x}_{0}\in\mathbb{R}^{n_{0}} x0?∈Rn0?,KAN的輸出為: K A N ( x ) = ( Φ L ? 1 ° Φ L ? 2 ° ? ? ? ° Φ 1 ° Φ 0 ) x KAN(\textbf{x})=(\Phi_{L-1}\circ\Phi_{L-2}\circ\cdot\cdot\cdot\circ\Phi_{1}\circ\Phi_{0})\textbf{x} KAN(x)=(ΦL?1?°ΦL?2?°???°Φ1?°Φ0?)x也可以重寫上述方程,假設輸出維數 n L = 1 n_{L}=1 nL?=1,則有 f ( x ) = K A N ( x ) f(\textbf{x})=KAN(\textbf{x}) f(x)=KAN(x):
原始Kolmogorov-Arnold表示Eq.(2.1)對應于形狀為 [ n , 2 n + 1 , 1 ] [n, 2n + 1,1] [n,2n+1,1]的2層KAN。所有的運算都是可微的,所以我們可以用反向傳播來訓練KAN。為了比較,一個MLP可以寫成仿射變換 W \textbf{W} W和非線性 σ σ σ的交織: M L P ( x ) = ( W L ? 1 ° σ ° W L ? 2 ° σ ° ? ? ? ° W 1 ° σ ° W 0 ) x MLP(\textbf{x})=(\textbf{W}_{L-1}\circ\sigma\circ\textbf{W}_{L-2}\circ\sigma\circ\cdot\cdot\cdot\circ\textbf{W}_{1}\circ\sigma\circ\textbf{W}_{0})\textbf{x} MLP(x)=(WL?1?°σ°WL?2?°σ°???°W1?°σ°W0?)x很明顯,MLP將線性變換和非線性分別處理為 W \textbf{W} W和 σ σ σ,而KAN在 Φ Φ Φ中將它們一起處理。在圖0.1 ?和(d)中,可視化了三層MLP和三層KAN,以澄清它們的區別。
實現細節
1.殘差激活函數
我們包含一個基函數 b ( x ) b(x) b(x)(類似于殘差連接),使得激活函數 ? ( x ) \phi(x) ?(x)是基函數 b ( x ) b(x) b(x)和spline函數的和: ? ( x ) = w ( b ( x ) + s p l i n e ( x ) ) \phi(x)=w(b(x)+spline(x)) ?(x)=w(b(x)+spline(x))設置 b ( x ) = s i l u ( x ) = x 1 + e ? x b(x)=silu(x)=\frac{x}{1+e^{-x}} b(x)=silu(x)=1+e?xx?在大部分情況下,spline函數被參數化為B-spline的線性組合: s p l i n e ( x ) = ∑ i c i B i ( x ) spline(x)=\sum_{i}c_{i}B_{i}(x) spline(x)=i∑?ci?Bi?(x)其中, c i c_{i} ci?是可學習的。原則上 w w w是多余的,因為它可以被吸收到 b ( x ) b(x) b(x)和 s p l i n e ( x ) spline(x) spline(x)中。然而,我們仍然包括這個 w w w因子,以更好地控制激活函數的總體大小。
2.初始化scales
每個激活函數初始化為 s p l i n e ( x ) ≈ 0 spline(x)\approx 0 spline(x)≈0
3.更新spline grids
作者根據輸入激活動態更新每個網格。
化簡KAN使其可解釋
這個想法是從一個足夠大的KAN開始,用稀疏性正則化訓練它,然后進行修剪。作者將證明這些經過修剪的KAN比未經過修剪的KAN更易于解釋。
1.稀疏化
對于MLP,線性權重的L1正則化用于支持稀疏性。KAN可以適應這個高層次的想法,但需要兩個修改:
- 在KAN中沒有線性的“權重”。線性權重被可學習的激活函數所取代,因此我們應該定義這些激活函數的L1范數。
- 實驗發現L1不足以使KAN稀疏化;相反,一個額外的熵正則化是必要的。
定義激活函數 ? \phi ?的L1范數為在 N p N_{p} Np?輸入上的平均幅度: ∣ ? ∣ 1 = 1 N p ∑ s = 1 N p ∣ ? ( x ( s ) ) ∣ |\phi|_{1}=\frac{1}{N_{p}}\sum_{s=1}^{N_{p}}|\phi(x^{(s)})| ∣?∣1?=Np?1?s=1∑Np??∣?(x(s))∣然后對于輸入 n i n n_{in} nin?和輸出 n o u t n_{out} nout?的KAN layer Φ \Phi Φ,其L1范數為: ∣ Φ ∣ 1 = ∑ i = 1 n i n ∑ j = 1 n o u t ∣ ? i , j ∣ 1 |\Phi|_{1}=\sum_{i=1}^{n_{in}}\sum_{j=1}^{n_{out}}|\phi_{i,j}|_{1} ∣Φ∣1?=i=1∑nin??j=1∑nout??∣?i,j?∣1?因此,訓練目標為: l t o t a l = l p r e d + λ ∑ l = 0 L ? 1 ∣ Φ l ∣ 1 l_{total}=l_{pred}+\lambda\sum_{l=0}^{L-1}|\Phi_{l}|_{1} ltotal?=lpred?+λl=0∑L?1?∣Φl?∣1?
2.可視化
當可視化KAN時,將激活函數的透明度設置為與 t a n h ( β A l , i , j ) tanh(β A_{l,i,j}) tanh(βAl,i,j?)成比例,其中 β = 3 β = 3 β=3。因此,貢獻較小的函數會逐漸消失,讓我們專注于重要的函數。
3.修剪
在使用稀疏化懲罰進行訓練后,我們可能還希望將網絡修剪成更小的子網。作者在節點級別(而不是在邊緣級別)對KANs進行稀疏化。對于每個節點(假設是第1層的第 i i i個神經元),作者將其傳入和傳出的分數定義為: I l , i = m a x k ( ∣ ? l ? 1 , k , i ∣ 1 ) , O l , i = m a x j ( ∣ ? l + 1 , j , i ∣ 1 ) I_{l,i}=max_{k}(|\phi_{l-1,k,i}|_{1}),O_{l,i}=max_{j}(|\phi_{l+1,j,i}|_{1}) Il,i?=maxk?(∣?l?1,k,i?∣1?),Ol,i?=maxj?(∣?l+1,j,i?∣1?)如果傳入和傳出的分數都大于閾值超參數 θ = 1 0 ? 2 \theta=10^{-2} θ=10?2,則認為節點是重要的。不重要的節點被修剪。
4.符號化
在我們懷疑某些激活函數實際上是符號的情況下,例如,cos或log,作者提供了一個接口來將它們設置為指定的符號形式。
思考
KAN的設計看起來更像是為了求解物理公式,尤其是可解釋性符號化部分的內容。在修剪的時候,我們可以發現保留的節點是隨著輸入動態改變的。盡管我一開始總在想KAN相比MLP有什么缺陷,但是現在我真的覺得KAN是代替MLP的一個未來方案。