1st author: Ryutaro Tanno
video: Video from London ML meetup
paper: Adaptive Neural Trees ICML 2019
code: rtanno21609/AdaptiveNeuralTrees: Adaptive Neural Trees
背景
在機器學習領域,神經網絡(NNs)憑借其強大的表示學習能力,在諸多應用中取得了顯著成果,然而其“黑箱”特性和預設架構的局限性也常為人詬病。與之相對,決策樹(DTs)以其良好的可解釋性、自適應結構和輕量級推理等優勢受到青睞,但在特征工程和決策函數簡單性方面存在不足。
本文旨在深入解析 2019 年國際機器學習大會 (ICML) 論文《自適應神經樹 (Adaptive Neural Trees, ANTs)》,該研究提出了一種創新性的方法,旨在融合這兩類模型的優點。ANTs 將神經網絡的表示學習能力深度融入決策樹的結構中,體現在以下幾個方面:
- 邊的特征轉換: 數據在樹的路徑上傳遞過程中,通過神經網絡進行特征轉換。
- 內部節點的路由函數: 神經網絡學習如何將數據路由到合適的子節點。
- 葉子節點的最終預測: 神經網絡負責進行最終的預測。
ANTs 引入了一種基于反向傳播的自適應架構生長算法。這意味著樹的結構并非預先設定,而是能夠根據數據的特性進行動態生長和調整。
這種融合模式帶來的優勢顯而易見:
- 表示學習的深度滲透: 神經網絡在邊緣和路由函數中的應用,使得數據在樹的層層傳遞中,其表示(representation)得以不斷優化和深化。與傳統決策樹在原始特征空間進行劃分不同,ANTs 在學習到的特征空間上進行劃分,顯著提升了模型的表達能力。
- 架構的自適應性: 這是 ANTs 的核心優勢。傳統的神經網絡架構固定,需要人工設計和調參;決策樹雖然結構自適應,但其“硬劃分”的貪婪性質可能導致局部最優。ANTs 的生長算法結合了決策樹的局部決策和神經網絡的端到端優化,使得模型能夠根據數據的規模和復雜性動態調整其深度和寬度,避免在小數據集上過度參數化,同時在大數據集上充分挖掘深層結構。
- 輕量級推理(條件計算): 類似于決策樹,ANTs 在推理時僅激活從根節點到葉子節點路徑上的少量參數。與需要激活所有參數的傳統全連接神經網絡相比,這在資源受限的場景下具有顯著優勢。
ANTs 的核心機制
模型組件
ANTs 的核心抽象是一個二元組 ( T , O ) (T, O) (T,O),其中 T T T 定義了模型的拓撲結構(一棵二叉樹),而 O O O 則是一組作用于這棵樹上的可微分操作。
-
拓撲結構 T = ( N , E ) T = (\mathcal{N}, \mathcal{E}) T=(N,E):
- N \mathcal{N} N: 所有節點的集合,分為內部節點 N i n t \mathcal{N}_{int} Nint? 和葉子節點 N l e a f \mathcal{N}_{leaf} Nleaf?。
- E \mathcal{E} E: 邊的集合。特別地,每條邊上承載著對數據進行變換的“生產線”。
- 每個內部節點 j ∈ N i n t j \in \mathcal{N}_{int} j∈Nint? 有兩個子節點 l e f t ( j ) left(j) left(j) 和 r i g h t ( j ) right(j) right(j)。
-
核心操作 O = ( R , T , S ) O = (\mathcal{R}, \mathcal{T}, \mathcal{S}) O=(R,T,S): 這是 ANTs 最特別的設計,它突破了傳統決策樹的簡化模式,將神經網絡的能力注入到樹的每一個關鍵環節。
-
路由函數 r j θ ∈ R r_j^\theta \in \mathcal{R} rjθ?∈R: 每個內部節點 j ∈ N i n t j \in \mathcal{N}_{int} j∈Nint? 都配備一個路由器。圖 Figure 1 中白色節點。
- 功能: 接收來自父節點的特征表示 x j ∈ X j x_j \in \mathcal{X}_j xj?∈Xj?,輸出一個 [ 0 , 1 ] [0, 1] [0,1] 范圍內的標量,表示樣本流向左子節點的概率。
- 參數: 由 θ \theta θ 參數化。
- 決策方式: 采用隨機路由(Stochastic Routing),即決策是根據伯努利分布 B e r n o u l l i ( r j θ ( x j ) ) Bernoulli(r_j^\theta(x_j)) Bernoulli(rjθ?(xj?)) 采樣得出(1 去左,0 去右)。這保證了路由函數是可微分的,允許梯度流過。
- 實現: 論文中提到可以是小型卷積神經網絡(CNN)或多層感知機(MLP)。這與傳統決策樹中簡單的軸對齊(axis-aligned)劃分函數形成鮮明對比,使得路由決策本身也能學習到復雜的特征表示。
-
變換函數 t ψ ∈ T t^\psi \in \mathcal{T} tψ∈T: 樹的每一條邊 e ∈ E e \in \mathcal{E} e∈E 都帶有一個或一組變換模塊。圖 Figure 1 中邊上的黑色小點。
- 功能: 對流經的特征表示進行非線性變換。例如,一個卷積層加 ReLU 激活函數。
- 參數: 由 ψ \psi ψ 參數化。
- 核心意義: 這是 ANTs 與傳統決策樹(如 SDTs)最顯著的區別之一。傳統決策樹的邊通常是恒等函數,數據在樹中傳遞時其特征表示不變。而 ANTs 的邊能夠“加深”(deepen),學習到更豐富、更抽象的分層表示(Hierarchical Representations)。這意味著每一條從根到葉的路徑,本身就是一條“深度神經網絡流水線”。
-
求解器 s l ? ∈ S s_l^\phi \in \mathcal{S} sl??∈S: 每個葉子節點 l ∈ N l e a f l \in \mathcal{N}_{leaf} l∈Nleaf? 配備一個求解器。圖 Figure 1 中葉子節點。
- 功能: 接收來自父節點的變換后的特征 x l ∈ X l x_l \in \mathcal{X}_l xl?∈Xl?,并輸出對目標變量 y y y 的預測分布 p ( y ∣ x ) p(y|x) p(y∣x)。
- 參數: 由 ? \phi ? 參數化。
- 實現: 對于分類任務,可以是特征空間上的線性分類器。
Figure 1
-
圖 Figure 1 左圖中紅色陰影表示了數據 x x x 經過一系列路由函數,最終到求解器 s 4 ? s_4^\phi s4?? 的路徑。
前向預測
ANTs 將條件分布 p ( y ∣ x ) p(y|x) p(y∣x) 建模為一個分層混合專家模型(Hierarchical Mixture of Experts, HMEs),每個“專家”對應一條從根到葉的路徑。模型的總參數集為 Θ = ( θ , ψ , ? ) \Theta = (\theta, \psi, \phi) Θ=(θ,ψ,?)。
給定輸入 x x x,預測分布為:
p ( y ∣ x , Θ ) = ∑ l = 1 L π l θ , ψ ( x ) p l ? , ψ ( y ) p ( y ∣ x , Θ ) = ∑ l = 1 L p ( z l = 1 ∣ x , θ , ψ ) ? Leaf-assignment?prob.? π l θ , ψ p ( y ∣ x , z l = 1 , ? , ψ ) ? Leaf-specific?prediction.? p l ? , ψ ( 1 ) p(y|x, \Theta) = \sum_{l=1}^{L} \pi_l^{\theta,\psi}(x) p_l^{\phi,\psi}(y)\\ \begin{aligned} &p(\mathbf{y}|\mathbf{x},\Theta)=\sum_{l=1}^L\underbrace{p(z_l=1|\mathbf{x},\boldsymbol{\theta},\boldsymbol{\psi})}_{\text{Leaf-assignment prob. }\pi_l^{\boldsymbol{\theta},\boldsymbol{\psi}}}\underbrace{p(\mathbf{y}|\mathbf{x},z_l=1,\boldsymbol{\phi},\boldsymbol{\psi})}_{\text{Leaf-specific prediction. }p_l^{\boldsymbol{\phi},\boldsymbol{\psi}}} \quad (1) \end{aligned} p(y∣x,Θ)=l=1∑L?πlθ,ψ?(x)pl?,ψ?(y)?p(y∣x,Θ)=l=1∑L?Leaf-assignment?prob.?πlθ,ψ? p(zl?=1∣x,θ,ψ)??Leaf-specific?prediction.?pl?,ψ? p(y∣x,zl?=1,?,ψ)??(1)?
其中:
- L L L 是葉子節點的總數, z = { 0 , 1 } L \mathbf z =\{0,1\}^L z={0,1}L 是 L L L 維的 onehot 向量, z l = 1 z_l=1 zl?=1 表示使用 z l z_l zl? 為葉節點。
- π l θ , ψ ( x ) : = p ( z l = 1 ∣ x , ψ , θ ) \pi_l^{\theta,\psi}(x) := p(z_l=1|x, \psi, \theta) πlθ,ψ?(x):=p(zl?=1∣x,ψ,θ) 是輸入 x x x 被分配到葉子節點 l l l 的路徑概率,由從根到葉 l l l 的唯一路徑 P l \mathcal{P}_l Pl? 上所有路由器的決策概率的乘積給出。
π l ψ , θ ( x ) = ∏ r j θ ∈ P l r j θ ( x j ψ ) I [ l is?left?child?of? j ] ? ( 1 ? r j θ ( x j ψ ) ) 1 ? I [ l is?left?child?of? j ] \pi_l^{\psi,\theta}(x) = \prod_{r_j^{\theta} \in \mathcal{P}_l} r_j^{\theta}(x_j^{\psi})^{ \mathbb{I}[l \text{ is left child of } j]} \cdot (1-r_j^{\theta}(x_j^{\psi}))^{1 - \mathbb{I}[l \text{ is left child of } j]} πlψ,θ?(x)=rjθ?∈Pl?∏?rjθ?(xjψ?)I[l?is?left?child?of?j]?(1?rjθ?(xjψ?))1?I[l?is?left?child?of?j]
這里的 x j ψ x_j^{\psi} xjψ? 是輸入 x x x 經過從根到節點 j j j 的所有變換函數組合后得到的特征表示。如果從根到節點 j j j 的路徑上的變換函數序列是 t e 1 ψ , t e 2 ψ , … , t e n ψ t_{e_1}^\psi, t_{e_2}^\psi, \dots, t_{e_n}^\psi te1?ψ?,te2?ψ?,…,ten?ψ?,那么:
x j ψ : = ( t e n ψ ° ? ° t e 2 ψ ° t e 1 ψ ) ( x ) x_j^{\psi} := (t_{e_n}^{\psi} \circ \dots \circ t_{e_2}^{\psi} \circ t_{e_1}^{\psi})(x) xjψ?:=(ten?ψ?°?°te2?ψ?°te1?ψ?)(x)
° \circ ° 是函數的復合運算。 - p l ? , ψ ( y ) : = p ( y ∣ x , z l = 1 , ? , ψ ) p_l^{\phi,\psi}(y) := p(y|x, z_l=1, \phi, \psi) pl?,ψ?(y):=p(y∣x,zl?=1,?,ψ) 是葉子節點 l l l 的局部預測,由其求解器 s l ? s_l^\phi sl?? 在變換后的輸入特征 x p a r e n t ( l ) ψ x_{parent(l)}^\psi xparent(l)ψ? (或 x l ψ x_l^\psi xlψ?)上計算得出。
推斷策略:
- 多路徑推斷(Multi-path inference): 使用公式 (1) 計算所有葉子節點的加權平均預測。計算成本較高,因為它需要遍歷樹的所有分支。
- 單路徑推斷(Single-path inference): 只根據路由器最高置信度的路徑(貪婪遍歷)選擇一條從根到葉的路徑進行計算和預測,只激活模型參數的一個子集。實驗證明,由于路由器置信度通常接近 0 或 1,單路徑推斷能很好地近似多路徑推斷。
訓練與優化
ANTs 的訓練分為兩個階段,這體現了模型對架構自適應性的追求。
-
生長階段(Growth phase): 學習模型架構 T T T。
- 初始化: 從一個簡單的根節點開始。
- 迭代過程: 采用廣度優先搜索(BFS),對當前所有葉子節點逐一進行評估。
- 三種生長選項: 對于每個葉子節點,模型評估三種可能的局部架構修改,如 Figure 1 右圖所示:
- (1) “Split data”(分裂數據): 添加一個新的路由器和兩個新的葉子節點(左右子節點)。新分支上的變換函數初始化為恒等函數。
- (2) “Deepen transform”(深化變換): 在當前葉子節點對應的傳入邊上添加一個新的變換模塊,并替換舊的求解器為一個新的求解器。
- (3) “Keep”(保持): 不對當前節點進行任何修改。
- 局部優化: 對于選項 (1) 和 (2) 產生的新的模塊,僅對其參數進行局部優化(通過最小化驗證集上的 NLL)。固定已有模塊的參數,減少計算量。
- 選擇標準: 選擇驗證集 NLL 表現最好的選項。如果性能提升,則接受修改并繼續生長;否則,執行 “Keep” 選項。
- 終止條件: 直到沒有更多的“分裂數據”或“深化變換”操作能通過驗證測試。
生長階段是 ANTs 的核心。它賦予了模型在“變得更深”或“變得更寬”(劃分數據)之間進行選擇的自由。局部優化雖然可能導致次優決策,但效率高,尤其適用于大型模型,并可被后續精煉階段修正。這可以看作是一種受約束的神經架構搜索(NAS)過程,只不過搜索空間被限定在樹形結構上,且搜索是增量的。
-
精煉階段(Refinement phase): 調優全局參數 O O O。
- 目標: 一旦樹的拓撲結構 T T T 在生長階段確定,進入精煉階段。
- 優化方式: 對整個 ANT 模型的所有參數 ( θ , ψ , ? ) (\theta, \psi, \phi) (θ,ψ,?) 進行全局優化。同樣使用 NLL (負對數似然) 作為目標函數,通過端到端的反向傳播和梯度下降進行。
- 意義: 修正生長階段中由于局部優化可能導致的次優參數。實驗表明,精煉階段能顯著改善模型的泛化誤差,甚至能“剪枝”掉一些冗余或不必要的路徑,使路由器決策更加集中。
損失函數 (NLL):
? log ? p ( Y ∣ X , Θ ) = ? ∑ n = 1 N log ? ( ∑ l = 1 L π l θ , ψ ( x ( n ) ) p l ? , ψ ( y ( n ) ) ) -\log p(\mathbf{Y}|\mathbf{X},\Theta)=-\sum_{n=1}^N\log\:(\sum_{l=1}^L\pi_l^{\boldsymbol{\theta},\boldsymbol{\psi}}(\mathbf{x}^{(n)})\:p_l^{\boldsymbol{\phi},\boldsymbol{\psi}}(\mathbf{y}^{(n)})) ?logp(Y∣X,Θ)=?n=1∑N?log(l=1∑L?πlθ,ψ?(x(n))pl?,ψ?(y(n)))
由于所有組件(路由器、變換器、求解器)都是可微分的,因此可以使用標準的基于梯度的優化算法。
通過這兩階段的優化,ANTs 不僅能找到適應數據的樹形結構,還能精細地調整結構內各個神經網絡組件的參數,從而在表示學習和架構學習之間實現協同。這種“結構生成”與“參數優化”的解耦再耦合,是其區別于一般混合模型的關鍵。
實驗測評
論文在 SARCOS(多元回歸)、MNIST 和 CIFAR-10(圖像分類)這三個不同類型的數據集上進行了實驗。核心結論清晰而有力:
-
競爭力:ANTs 在 SARCOS 數據集上實現了最低的均方誤差(MSE),即便與最先進的基于樹的模型(如梯度提升樹 GBTs)和各種 MLP 相比,也展現出領先地位。在圖像分類任務上,ANTs 顯著優于傳統隨機森林(RFs)和梯度提升樹(GBTs),并且與一些輕量級、非殘差連接的 CNN 模型性能相當,甚至更優。
-
效率與權衡:
- 單路徑推理:一個關鍵的發現,ANTs 的單路徑推理(僅激活從根到葉的一條路徑)與多路徑推理(聚合所有葉子節點的預測)在準確性上差異極小(分類誤差通常小于 0.1%),但在計算開銷(FLOPS)和激活參數量上則大幅降低。這得益于路由器學習到的高置信度拆分概率(即 r j θ ( x j ) r_j^\theta(x_j) rjθ?(xj?) 趨近于 0 或 1),使得模型在推斷時能果斷地“選擇”一條路徑。
- 參數效率:在某些配置下,ANTs 甚至能以更少的參數量達到甚至超越 LeNet-5 等傳統 CNN 模型在 MNIST 上的性能。這表明樹狀的層級共享和分離機制,能夠有效地增強計算和預測性能。
-
消融實驗:
- “無路由器”(no R):此時 ANTs 退化為某種形式的自適應生長的純神經網絡。在所有數據集上,其性能均顯著低于完整的 ANTs。
- “無變換器”(no T):此時 ANTs 退化為一種帶有可學習路由器的軟決策樹(SDT/HME)。其性能下降更為劇烈,特別是在圖像數據集上,誤差大幅飆升。
總結
ANTs 不僅僅是一個性能優秀的模型,更重要的是,它提供了一種看待深度學習和決策樹的新視角,并引出了許多值得深思的問題。
- 可解釋的特征分離:論文中提及,ANTs 能夠學習到“有意義的層級劃分”,例如將圖像分為“自然物體”和“人造物體”等。這為“黑箱”的神經網絡提供了一扇窺探內部決策邏輯的窗戶。樹的結構本身就具有一定的可解釋性,而學習到的路由函數進一步增強了這種可解釋性。這對于許多對模型透明度有要求的領域(如醫療、金融)具有重要意義。
- 架構自適應性:奧卡姆剃刀的實踐:ANTs 的生長機制使其能夠根據訓練數據的大小和復雜性自適應地構建模型架構。這意味著對于小數據集,它不會過度生長導致過擬合;對于大數據集,它能夠探索更深、更復雜的結構。這本質上是在實踐機器學習的奧卡姆剃刀原理:用最簡單的模型解釋數據。
- “軟”劃分的勝利:傳統決策樹訓練難點在于“硬劃分”導致的損失函數不可微。ANTs 通過將路由器輸出解釋為伯努利分布的概率,并使用混合專家模型(HMEs)的框架,巧妙地避開了這一難題,使得整個樹結構都可微,從而能夠進行端到端的梯度下降優化。這為傳統決策樹的擴展開辟了新的道路。
局限與展望
- 貪婪生長與全局最優:雖然生長階段比預設架構更靈活,但其本質仍是局部貪婪搜索。每次只在當前葉子節點進行局部最優決策。這可能導致模型陷入局部最優,無法發現全局上更優異的架構。未來的工作可以探索更全局的架構搜索策略,例如基于強化學習或進化算法的樹結構搜索,或者像決策森林那樣,訓練多個 ANTs 進行集成。
- 計算成本:生長階段的局部優化雖然效率相對高,但對于非常大的數據集和復雜的模塊,每次評估三個選項并進行局部訓練仍然是計算密集型的。如何在保持自適應性的同時,進一步提高架構搜索的效率,仍是挑戰。
- 模塊的通用性:論文中使用的基礎模塊(如卷積層、MLP)是通用的。未來是否可以設計更適合樹狀結構、更輕量、更具表達力的神經模塊,以進一步提升效率和性能?
- 可解釋性與復雜性的平衡:雖然 ANTs 比純粹的 NNs 更具解釋性,但隨著樹的深度和復雜度的增加,特別是每個節點和邊都包含深層 NN 時,完全理解特定輸入在樹中的決策路徑和特征轉換仍可能面臨挑戰。如何在增加模型能力的同時,維持甚至增強其可解釋性,是一個持續的研究方向。