GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks
KDD22
推薦指數:#paper/??#?
動機
本文探討了圖神經網絡(GNN)在遷移學習中“預訓練-微調”框架的局限性及改進方向。現有方法通過預訓練(如邊預測、對比學習)學習可遷移的圖結構知識,在微調時將其應用于下游任務(如節點分類)。然而,預訓練目標與下游任務之間的差異(如二元邊預測與多類節點分類)導致知識傳遞低效甚至負遷移——微調效果可能遜于從頭訓練。傳統改進方案依賴為每個下游任務定制預訓練目標(目標工程),但需大量領域知識與試錯成本。
受自然語言處理(NLP)中提示(Prompt)技術的啟發,作者提出“預訓練-提示-微調”新范式,旨在通過任務重表述縮小預訓練與下游任務差異。例如,NLP通過添加語義模板將分類任務轉化為與預訓練一致的填空任務(如情感分類轉為預測掩碼詞)。然而,圖數據面臨兩大挑戰:
- 符號化圖數據適配難題:節點為抽象符號,無法直接套用基于文本模板的語義改寫。
- 提示設計的有效性:需結合圖結構(如節點鄰域信息)設計高效的提示函數,以提升分類等任務精度。
因此,本文核心研究問題聚焦于如何設計圖感知提示函數,以橋接預訓練與下游任務,從而高效激發預訓練模型的知識。該方向有望通過任務形式統一化提升預訓練模型的泛用性,減少對定制化目標工程的依賴,推動少樣本圖分析的進一步發展。
??
圖提示框架
Pre-train, Prompt, Fine-tune
Graph prompting function(圖提示函數)
v i ′ = f p r o m p t ( v i ) v_{i}^{\prime}=f_{\mathrm{prompt}}(v_{i}) vi′?=fprompt?(vi?), v i ′ v_i' vi′?和映射頭有相似的輸入形狀
Pairwise prompting function(成對提示函數)
v i ′ = f p r o m p t ( v i ) = [ T t a s k ( y ) , T s r t ( v i ) ] v_{i}^{\prime}=f_{\mathrm{prompt}}(v_{i})=[T_{\mathbf{task}}(y),T_{\mathbf{srt}}( v_{i})] vi′?=fprompt?(vi?)=[Ttask?(y),Tsrt?(vi?)]
T t a s k T_{task} Ttask?是下有任務的token, T s r c T_{src} Tsrc?是目標節點結構的token。前者由待分類節點的標簽得到,后者由目標節點周圍子圖表示,以提供更多的結構信息。很自然,可以利用函數來捕獲他們兩個的聯系
Prompt addition
[ y 1 , ? , y C ] [y_1,\cdots,y_C] [y1?,?,yC?]為C個類的prompt。自然可以構造token對: [ T t a s k ( y c ) , T s r t ( v i ) ] , f o r c = 1 , ? , C [T_{\mathrm{task}}(y_{c}),T_{\mathrm{srt}}(v_{i})],\mathrm{for~}c=1,\cdots,C [Ttask?(yc?),Tsrt?(vi?)],for?c=1,?,C
Prompt answer
對于每個token對,我們可以拼接,并將其放入預訓練的映射頭,如果目標節點 v i v_i vi? 與某類得到最高的鏈接概率,我們就將其歸為一類。
prompt tuning:
min ? θ , ? ∑ ( v i , y c ) L p r e ( p ? p r e ( T t a s k ( y c ) , T s r t ( v i ) ) ; g ( y c , v i ) ) . \min_{\theta,\phi}\sum_{(v_i,y_c)}\mathcal{L}^{\mathrm{pre}}(p_\phi^{\mathrm{pre}}(T_{\mathrm{task}}(y_c),T_{\mathrm{srt}}(v_i));g(y_c,v_i)). minθ,??∑(vi?,yc?)?Lpre(p?pre?(Ttask?(yc?),Tsrt?(vi?));g(yc?,vi?)).其中,g為真實的標簽函數
圖形提示功能設計
任務token的生成:
e c = T t a s k ( y c ) ∈ R d e_c=T_\mathrm{task}(y_c)\in\mathbb{R}^d ec?=Ttask?(yc?)∈Rd
E = [ e 1 , ? , e C ] ? ∈ R C × d E=[e_{1},\cdots,e_{C}]^{\top}\in\mathbb{R}^{C\times d} E=[e1?,?,eC?]?∈RC×d,C是類別數。
很自然,每個節點的token可以通過查詢如上的任務token得到自己的類別。很自然的是, T t a s k ( y c ) T_{\mathbf{task}}(y_c) Ttask?(yc?)最優應該是類 y c y_c yc?的中心。因此,我們通過聚類,來獲得初始的tasktoken:
- 利用可擴展聚類(比如metis)獲得M個類: { G 1 , ? , G M } \{\mathcal{G}_1,\cdots,\mathcal{G}_M\} {G1?,?,GM?},M是類別超參。
- 對于每個類,我們得到相應的task token: E m = [ e 1 m , ? , e C m ] ? ∈ R C × d E^m=[e_1^m,\cdots,e_C^m]^\top\in\mathbb{R}^{C\times d} Em=[e1m?,?,eCm?]?∈RC×d(怎么感覺有問題這一行表述)
- 給定集群 處節點 v i v_i vi? 的任務令牌 T t a s k ( y c ) T_{task}(y_c) Ttask?(yc?) ,它使用向量嵌入 e c m e_c^m ecm? 表示。
Structure Token Generation.(結構token的升成)
如果直接用節點v用于下游分類,會失去結構信息。因此我們使用 T s t r ( v i ) T_{\mathrm{str}}(v_i) Tstr?(vi?)來表示子圖結構,來涵蓋結構信息。在本文中,作者使用一階子圖來表示。
e v i = a i ? h i + ∑ v j ∈ N ( v i ) a j ? h j . e_{v_i}=a_i*h_i+\sum_{v_j\in\mathcal{N}(v_i)}a_j*h_j. evi??=ai??hi?+∑vj?∈N(vi?)?aj??hj?.
a通過注意力機制得到
Prompt 初始化以及正交約束:
直接使用隨機初始化肯定不太好,因此我們使用預訓練的GNN來初始化 E m = [ e 1 m , ? , e C m ] ? E^{m}=[e_{1}^{m},\cdots,e_{C}^{m}]^{\top} Em=[e1m?,?,eCm?]?。
因此,我們通過節點表示來初始化標記嵌入 e c m e^m_c ecm?,節點表示由集群 m 處 y c y_c yc?類的訓練節點給出。
不同類的中心的距離應該盡可能的打,因此有: L o = ∑ m ∥ E m ( E m ) ? ? I ∥ F 2 . \mathcal{L}_o=\sum_m\|E^m(E^m)^\top-I\|_F^2. Lo?=∑m?∥Em(Em)??I∥F2?.
損失:
min ? θ , ? , E 1 , ? , E M ∑ ( v i , y c ) L p r e ( p ? p r e ( e c m , e v i ) ; g ( y c , v i ) ) + λ L o , s . t . θ i n i t = θ p r e , ? i n i t = ? p r e . \begin{aligned}\min_{\theta,\phi,E^{1},\cdots,E^{M}}&\sum_{(v_{i},y_{c})}\mathcal{L}^{\mathrm{pre}}(p_{\phi}^{\mathrm{pre}}(e_{c}^{m},e_{v_{i}});g(y_{c},v_{i}))+\lambda\mathcal{L}_{o},\\\mathrm{s.t.}&\theta^{\mathrm{init}}=\theta^{\mathrm{pre}},\phi^{\mathrm{init}}=\phi^{\mathrm{pre}}.\end{aligned} θ,?,E1,?,EMmin?s.t.?(vi?,yc?)∑?Lpre(p?pre?(ecm?,evi??);g(yc?,vi?))+λLo?,θinit=θpre,?init=?pre.?
結果:
??
??
?