【摘要】由于圖神經網絡 (GNN) 通常會隨著分布變化而出現性能下降,因此分布外 (OOD) 泛化在圖學習中引起了越來越多的關注。挑戰在于,圖上的分布變化涉及節點之間錯綜復雜的互連,并且數據中通常不存在環境標簽。在本文中,我們采用自下而上的數據生成視角,并通過因果分析揭示了一個關鍵觀察結果:GNN 在 OOD 泛化中失敗的關鍵在于來自環境的潛在混雜偏差。后者誤導模型利用自我圖特征與目標節點標簽之間的環境敏感相關性,導致在新的未見節點上出現不良的泛化。基于這一分析,我們引入了一種概念上簡單但原則性的方法,用于在節點級分布變化下訓練穩健的 GNN,而無需事先了解環境標簽。我們的方法采用了一種源自因果推理的新學習目標,該目標協調了環境估計器和專家混合 GNN 預測器。新方法可以抵消訓練數據中的混雜偏差,并促進學習可推廣的預測關系。大量實驗表明,我們的模型可以有效地增強各種分布偏移下的泛化能力,并且在圖 OOD 泛化基準上比最先進的方法提高高達 27.4% 的準確率。
原文:Graph Out-of-Distribution Generalization via Causal Intervention
地址:https://arxiv.org/abs/2402.11494
代碼:https://github.com/fannie1208/CaNet
出版:www 24
機構: 上海交通大學寫的這么辛苦,麻煩關注微信公眾號“碼農的科研筆記”!
1 研究問題
本文研究的核心問題是: 如何設計一個圖神經網絡模型,使其能夠在結點屬性分布發生變化時,仍然保持良好的泛化性能。
假設一個社交網絡中,用戶的愛好與其朋友的年齡分布密切相關。在大學生群體中,朋友都比較年輕的用戶往往更喜歡籃球運動。但是這種相關性可能只在大學生群體中成立,對于職場社交網絡LinkedIn,用戶的年齡與其愛好的相關性可能就很弱。如果我們基于大學生的社交網絡訓練了一個用于預測用戶愛好的圖神經網絡,那么將其直接應用于LinkedIn,可能會遇到泛化失敗的問題。
本文研究問題的特點和現有方法面臨的挑戰主要體現在以下幾個方面:
-
圖數據中的分布變化往往涉及到結點之間的復雜交互與關聯,需要模型能夠充分考慮不同結點的結構化特征。
-
在圖學習問題中,每個結點所處的環境信息通常是隱含的,難以直接獲取。這為模型從觀測數據中推斷有用的環境信息,以指導學習過程,帶來了障礙。
針對這些挑戰,本文提出了一種基于因果干預的"因果網絡(CaNet)"方法:
CaNet巧妙地借鑒了因果推理中的 do-calculus 思想,通過顯式地對環境變量建模,消除了隱含的混淆偏差。具體來說,它引入了一個環境估計器,負責基于輸入的局部子圖推斷可能的環境信息。同時,圖神經網絡的每一層都配備了一組 mixture-of-expert 的傳播單元,可以動態地根據推斷的環境選擇不同的傳播方式。通過環境估計器和圖神經網絡的協同優化,CaNet 可以自動地發現觀測數據中的穩定關系,同時避免捕獲那些容易受環境變化影響的虛假相關性。這一設計理念猶如為圖神經網絡裝上了一副"透視鏡",讓它對不可見的因果機制具備了感知和適應的能力。
2 研究方法
2.1 因果分析
論文首先從因果的角度來分析GNN面臨分布外泛化問題的根本原因。如圖2(a)所示,論文使用有向無環圖建模了節點的ego-graph特征、節點標簽和環境因素之間的因果依賴關系。可以看到,作為未觀測的混淆因素,會影響和的生成過程。當使用最大似然估計來訓練GNN時,由于忽略了的影響,模型會錯誤地學習到由某些特定的引起的和之間的強相關性(如在大學生群體中,"朋友年輕"和"喜歡打籃球"往往同時成立)。然而,這類相關性是不穩定的,一旦測試環境發生變化(如職場人士的社交網絡),先前學到的相關性便不再成立。這導致了GNN在分布外數據上的泛化性能顯著下降。
2.2 因果干預
為了消除環境因素的混淆偏差,進而提升GNN的分布外泛化性能,論文提出了一種基于因果干預的方法。借助后門調整公式,論文指出,優化干預分布而非觀測分布,可以有效避免環境因素的混淆。然而,的求解需要窮舉所有可能的環境,這在實際中是不可行的。為此,論文進一步引入變分推斷,得到的一個變分下界,如式(5)所示:
其中,是根據節點的ego-graph特征來推斷環境因素的估計器,是給定ego-graph特征和推斷的環境標簽來預測節點標簽的GNN。通過最大化該變分下界,可以得到協同優化的算法:環境估計器盡可能準確地推斷環境因素,同時要求推斷結果與ego-graph特征保持獨立;而GNN預測器則根據ego-graph特征和推斷的環境因素來預測節點標簽。通過協同學習,模型可以學習到與具體環境無關的穩定預測模式,進而提升分布外泛化性能。
2.3 模型實例化
在CaNet中,環境估計器將環境因素表示為一系列偽環境標簽向量,其中表示GNN的層數。如式(6-7)所示(太難打了,公式見原文),對于每個節點的第層表示,環境估計器首先計算該節點屬于每個偽環境的概率,然后通過Gumbel-Softmax技巧對重參數化,得到。
GNN預測器的核心是一個層的混合專家傳播網絡,每一層包含個專家分支。如式(8-9)所示,每個分支采用獨立的參數,并由推斷的偽環境標簽進行選擇。不同分支學習ego-graph特征的不同組合模式,賦予模型更強的表征能力。GNN預測器的最后一層輸出節點表示,并通過全連接層映射為節點標簽的預測值。
算法1總結了CaNet的前向計算和訓練優化流程。其中,環境估計器和GNN預測器通過梯度下降法交替優化,協同學習對分布外泛化有利的預測模式。
5 實驗
5.1 實驗場景介紹
該論文提出了一個處理圖神經網絡節點級別分布差異的因果干預方法CaNet。實驗主要在節點屬性預測任務中,驗證CaNet相比其他模型在訓練集和測試集節點分布不一致情況下的泛化優勢。同時通過消融實驗、超參數分析等進一步探究模型內部機制。
5.2 實驗設置
-
Datasets:使用Cora、Citeseer、Pubmed、Twitch、Arxiv、Elliptic等6個不同規模和屬性的節點預測數據集,通過時間屬性、子圖、動態快照等不同方式構建訓練集和測試集的分布差異
-
Baseline:ERM, IRM, DeepCoral, DANN, GroupDRO, Mixup等通用OOD方法;SR-GNN, EERM等圖數據OOD方法;均使用GCN和GAT作為編碼器骨干
-
Implementation details:基于PyTorch 1.13和PyG 2.1, Adam優化器,訓練500輪,網格搜索超參數
-
metric:Accuracy, ROC-AUC, macro F1
5.3 實驗結果
5.3.1 實驗一、不同數據集上的性能對比
目的:在多個數據集上驗證CaNet相比其他模型處理分布差異節點的優勢
涉及圖表:表1,表3,圖4
實驗細節概述:在Cora、Citeseer、Pubmed上測試合成特征和結構導致的分布差異;在Arxiv上測試不同時間的論文節點;在Twitch上測試不同子圖的節點;在Elliptic上測試不同時間快照的節點
結果:
-
CaNet在所有OOD測試集上顯著優于對應的基線,在ID測試集上也有競爭力的表現
-
在Cora和Citeseer上,CaNet在OOD數據的絕對性能接近ID數據
-
在Arxiv的跨時間差異最大的測試集上,CaNet超出次優baseline 14.1%和27.4%
-
在Elliptic的動態圖快照測試集上,CaNet平均超出次優baseline 12.16%
5.3.2 實驗二、消融實驗
目的:驗證正則化損失、層級環境推斷等關鍵組件的有效性
涉及圖表:圖5
實驗細節概述:去除正則化損失、使用復雜先驗分布、采用全局環境表示、使用非參數環境估計器等簡化變體
結果:
-
正則化損失和層級環境推斷能有效提升OOD性能
-
簡單先驗分布優于復雜先驗,更利于泛化
5.3.3 實驗三、超參數分析
目的:探究偽環境數K和溫度τ對模型性能的影響
涉及圖表:圖6
實驗細節概述:在Arxiv和Twitch上分別評估不同K和τ下模型在各OOD測試集的表現
結果:
-
性能對K不太敏感,過大或過小的K在Arxiv的OOD 2/3上可能降低性能
-
適中的τ(如1)效果最佳,過大的τ會導致性能下降
5.3.4 實驗四、可視化分析
目的:直觀展現不同分支學習到的權重模式差異
涉及圖表:圖7,8,9,10
實驗細節概述:可視化K=3時模型在Arxiv和Twitch上第一層和最后一層不同分支的權重矩陣
結果:不同分支權重有明顯差異,說明mixture-of-expert結構能學習到區分不同偽環境的表達模式,利于泛化
4 總結后記
本論文針對圖神經網絡在面對分布偏移時泛化能力較差的問題,從因果分析的角度揭示了其根源在于未觀測到的環境混淆因素。基于此分析,提出了一種通過因果干預改進圖神經網絡泛化性的方法CaNet。該方法引入了環境估計器和混合專家傳播網絡,可以在沒有先驗環境標簽的情況下,通過優化一個新的學習目標來捕獲對環境不敏感的預測關系,從而提高模型的分布外泛化能力。實驗結果表明,該模型在多個具有不同類型分布偏移的數據集上,相比現有方法可以顯著提升泛化性能,泛化準確率提升高達27.4%。
疑惑和想法:
-
除了節點層面的分布偏移,該方法是否可以推廣到處理圖層面或子圖層面的分布偏移?
-
環境估計器推斷出的偽環境標簽對應著什么物理含義?能否賦予它們可解釋性?
-
除了混合專家傳播網絡,是否可以設計其他形式的條件傳播機制來建模環境因素對節點表示的影響?
可借鑒的方法點:
-
從因果分析角度揭示模型泛化不足的根源,為診斷和改進其他類型的圖學習任務提供了新的思路。
-
通過優化包含對抗正則化項的目標函數來消除混淆偏差,可以推廣到其他需要增強魯棒性的機器學習場景。
-
環境估計器和條件傳播網絡的思想可以借鑒到圖預訓練等其他圖表示學習任務中,以建模不同圖之間的差異。