表格數據是按行和列組織的電子表格形式,在從生物醫學、粒子物理到經濟學和氣候科學等各個科學領域中無處不在 。基于表格其余列來填充標簽列缺失值的基本預測任務,對于生物醫學風險模型、藥物研發和材料科學等各種應用至關重要。盡管深度學習徹底改變了從原始數據中學習的方式,并帶來了許多備受矚目的成功案例,但在過去20年里,梯度提升決策樹在表格數據處理方面占據主導地位。在此,作者提出表格先驗數據擬合網絡(Tabular Prior-data Fitted Network,TabPFN),這是一種表格基礎模型,在樣本數量高達10,000的數據集上,它的表現遠超以往所有方法,且訓練時間大幅縮短。在分類場景中,TabPFN能在2.8秒內超越經過4小時調優的最強基線模型。作為一種基于生成式Transformer的基礎模型,該模型還支持微調、數據生成、密度估計以及學習可復用的嵌入。TabPFN是一種通過在數百萬個合成數據集上進行學習得到的學習算法,展示了這種方法在算法開發方面的強大能力。通過提升不同領域的建模能力,TabPFN有潛力加速科學發現,并在各個領域改進重要決策。
來自:Accurate predictions on small data with a tabular foundation model, Nature, 2025
工程地址:https://github.com/PriorLabs/tabpfn
目錄
- 背景概述
- 上下文學習
- 為表格設計的架構
- 基于因果模型的合成數據
- 在線使用TabPFN
- Reference
背景概述
在人工智能的發展歷程中,人工創建的算法組件逐漸被性能更優的端到端學習組件所取代。計算機視覺中手工設計的特征,如尺度不變特征變換(SIFT)和方向梯度直方圖(HOG),已被學習得到的卷積所替代;自然語言處理中基于語法的方法已被學習得到的Transformer取代;游戲領域中定制的開局和殘局庫設計,也已被端到端學習策略所超越。在此,作者將這種端到端學習擴展到廣泛存在的表格數據領域。
表格數據的多樣性使其有別于文本和圖像等未經處理的數據。例如,在語言建模中,一個單詞的含義在不同文檔中是一致的,但在表格數據集中,相同的值可能有著截然不同的含義。比如,一個藥物研發數據集可能記錄化學性質,而材料科學領域的另一個數據集可能記錄熱學和電學性質。這種專業性導致了大量較小的獨立數據集和相關模型的涌現。舉例來說,在廣受歡迎的表格數據基準測試網站openml.org上,截至撰寫本文時,76% 的數據集包含的行數少于10,000行。
傳統上,深度學習方法在處理表格數據時困難重重,這是由于數據集之間的異質性以及原始數據本身的異質性:表格包含的列(也稱為特征)具有不同的尺度和類型(布爾型、類別型、有序型、整型、浮點型),還存在數據不平衡、缺失數據、無關特征、異常值等問題。這使得非深度學習方法,如基于樹的模型,成為目前為止最有力的競爭者。
然而,這些傳統的機器學習模型存在一些缺點。如果不進行大量修改,它們在分布外預測方面表現不佳,并且很難將一個數據集的知識遷移到另一個數據集 。最后,它們很難與神經網絡結合,因為它們無法傳播梯度。
作為一種解決方案,TabPFN是一種針對中小規模表格數據的基礎模型。這種新的有監督表格學習方法可應用于任何中小規模的數據集,在樣本數高達10,000、特征數達500的數據集上表現卓越。在基準測試中,TabPFN僅需一次前向傳遞,就能顯著超越包括梯度提升決策樹在內的最先進基線模型,即便這些基線模型經過4小時的調優,TabPFN在分類任務中的速度提升可達5140倍,在回歸任務中可達3000倍。最后,作者展示了TabPFN的多種基礎模型特性,包括微調、生成能力和密度估計。
上下文學習
TabPFN利用上下文學習(ICL)機制(正是這一機制使得大語言模型展現出驚人的性能),生成了一種完全通過學習得到的強大表格預測算法。雖然ICL最早在大語言模型中被發現,但最近的研究表明,Transformer可以通過ICL學習諸如邏輯回歸這樣的簡單算法。先驗數據擬合網絡(PFNs)已經證明,即使是像高斯過程和貝葉斯神經網絡這樣的復雜算法,也可以通過ICL進行近似 [1]。ICL使我們能夠探索更廣泛的可能算法空間,包括那些不存在封閉形式解的情況。
作者在TabPFN的初步版本 [2] 基礎上進行開發。該初步版本原則上證明了上下文學習在表格數據中的適用性,但存在許多局限性,導致其在大多數情況下無法應用。經過一系列改進,新的TabPFN能夠處理規模大50倍的數據集;支持回歸任務、分類數據和缺失值處理;并且對無關特征和異常值具有魯棒性。
TabPFN背后的核心思想是生成大量合成表格數據集,然后訓練一個基于Transformer的神經網絡,使其學會解決這些合成預測任務。傳統方法在處理諸如缺失值等數據挑戰時需要手動設計解決方案,而TabPFN通過解決包含這些挑戰的合成任務,自主學習有效的策略。通過生成展示期望行為的多樣化合成數據集來設計所需的算法行為,然后訓練模型對滿足該行為的算法進行編碼。這將算法設計過程從編寫明確指令轉變為定義輸入-輸出示例,為在各個領域創建算法開辟了可能性。在此,將這種方法應用于具有高影響力的表格學習領域,生成了一種強大的表格預測算法。
上下文學習(ICL)方法與標準的有監督深度學習有著根本區別。通常情況下,模型是針對每個數據集進行訓練的,根據手工設計的權重更新算法(如Adam算法)在單個樣本或批次上更新模型參數。在推理時,將訓練好的模型應用于測試樣本。相比之下,作者的方法是跨數據集進行訓練的,并且在推理時應用于整個數據集,而不是單個樣本。在應用于真實世界的數據集之前,模型會在代表不同預測任務的數百萬個合成數據集上進行一次預訓練。在推理時,模型接收一個包含有標記訓練樣本和無標記測試樣本的未知數據集,并通過一次神經網絡前向傳遞在這個數據集上進行訓練和預測。
圖1和圖2概述了作者的方法:
- 數據生成:作者定義了一個生成過程(稱為 prior)來合成各種表格數據集,這些數據集的特征與target之間具有不同的關系,旨在涵蓋模型可能遇到的各種潛在場景。作者從這個生成過程中采樣數百萬個數據集。對于每個數據集,會將一部分樣本的目標值進行mask處理,模擬有監督預測問題。作者在先驗設計上的更多細節在 “基于因果模型的合成數據” 部分展示(Synthetic data based on casual models)。
- 預訓練:作者訓練一個基于Transformer的模型,即先驗數據擬合網絡(PFN),以輸入特征和未掩碼樣本為上下文,預測所有合成數據集的掩碼target值 。這一步在模型開發過程中只進行一次,目的是學習一種通用的學習算法,可用于預測任何數據集。
- 真實世界預測:訓練好的模型現在可以應用于任意未知的真實世界數據集。訓練樣本作為上下文提供給模型,模型通過上下文學習來預測這些未知數據集的標簽。
正如 [1] 中所述,該方法可以被視為對由合成數據集定義的先驗進行貝葉斯預測的近似。經過訓練的先驗數據擬合網絡(PFN)將逼近后驗預測分布 p ( y ^ t e s t ∣ X t e s t , X t r a i n , y t r a i n ) p(\hat{y}_{test } | X_{test }, X_{train }, y_{train }) p(y^?test?∣Xtest?,Xtrain?,ytrain?),從而針對PFN預訓練期間使用的人工數據集的指定分布返回貝葉斯預測結果。
- 圖1:a. TabPFN預訓練和使用的概述。b. TabPFN架構。作者訓練一個模型來解決超過1億個合成任務。作者的架構是對標準Transformer編碼器的改編,以適應表格中遇到的二維數據。
- 圖2:prior的示意圖。a. 對于每個數據集,作者首先對高級超參數進行采樣。b. 基于這些超參數,作者構建一個結構因果模型,該模型對合成數據集的計算函數進行編碼。每個節點持有一個向量,計算圖中的每條邊根據其中一種連接類型來實現一個函數。在步驟1中,作者使用隨機噪聲變量生成初始化數據,將其輸入到圖的根節點,并為每個待生成的樣本通過計算圖進行傳播。在步驟2中,作者在圖中隨機采樣特征節點和目標節點的位置,分別標記為F和T。在步驟3中,作者在采樣得到的特征節點和目標節點位置提取中間數據表示。在步驟4中,作者對提取的數據進行后處理。c. 獲取最終的數據集。繪制特征對之間的相互作用圖,節點顏色代表樣本的類別。
為表格設計的架構
Transformer架構目前是靈活深度學習和基礎模型青睞的架構。Transformer模型作用于序列,通過注意力機制整合序列元素間的信息,使其能夠有效捕捉長距離依賴關系并學習數據中的復雜關系。盡管基于Transformer的模型可以應用于表格數據,但TabPFN解決了它們固有的兩個關鍵局限性。第一,由于Transformer是為處理序列而設計的,它將輸入數據視為單個序列,未利用表格結構。第二,機器學習模型常采用擬合-預測模式,即模型在訓練集上擬合一次,然后在多個測試數據集上重復使用。然而,基于Transformer的上下文學習(ICL)算法在單次傳遞中接收訓練數據和測試數據,因此會同時進行訓練和預測。這樣一來,當重復使用擬合好的模型時,就必須重新對訓練集進行計算。
為了更好地利用表格結構,作者提出了一種為表格中的每個單元格分配單獨表示的架構。如圖1b所示,架構采用雙向注意力機制,每個單元格先關注其所在行中的其他特征(即當前樣本),然后關注其所在列中的相同特征(即所有其他樣本)。這種設計使架構對樣本和特征的順序都具有不變性,并且在樣本數量和特征數量方面,相比訓練時遇到的表格,能夠更高效地訓練并外推到更大的表格。
為了減少擬合-預測設置中每個測試樣本對訓練集的重復計算,模型可以將訓練樣本和測試樣本的推理分開。這使我們能夠在訓練集上進行一次ICL,保存結果狀態,并將其重新用于多個測試集的推理。對于有10,000個訓練樣本和10個特征的數據集,優化后的訓練狀態緩存使CPU上的推理速度提高了約300倍(從32秒提升到0.1秒),GPU上提高了6倍。當特征數量增加10倍(即100個特征)時,CPU上的速度提升至800倍,GPU上提升至30倍。這些測量僅關注核心推理過程,不包括“推理細節”部分詳細介紹的預處理和集成步驟。GPU上速度提升較低是由于其大規模并行架構未得到充分利用。
通過使用半精度計算層歸一化、FlashAttention、激活檢查點和狀態的順序計算,進一步優化了架構的內存和計算需求。優化將內存需求降低為原來的四分之一,每個單元格的內存占用不到1000字節。這使得在單個H100 GPU上能夠對多達5000萬個單元格的數據集(例如,500萬行×10個特征)進行預測。
對于回歸任務,使用分段常數輸出分布,這使模型能夠預測目標值的概率分布,而不是單個值,例如可以處理雙峰分布。
基于因果模型的合成數據
TabPFN的性能依賴于生成合適的合成訓練數據集,這些數據集能夠捕捉現實世界表格數據的特征和挑戰。為了生成這樣的數據集,作者開發了一種基于結構因果模型(SCMs)的方法 。SCMs提供了一個正式的框架,用于表示數據背后的因果關系和生成過程。通過使用合成數據,而非大量公開的表格數據,避免了基礎模型常見的問題,如隱私和版權侵犯、訓練數據被測試數據污染,或是數據可用性受限等問題。
如圖2所示,生成流程首先對高級超參數進行采樣,例如數據集大小、特征數量和難度級別,以此來控制每個合成數據集的整體屬性。在這些超參數的指導下,構建一個有向無環圖,該圖明確了數據集背后的因果結構。
為了生成數據集中的每個樣本,作者將隨機生成的噪聲(稱為初始化數據)通過因果圖的根節點進行傳播。這種初始化數據是從隨機正態分布或均勻分布中采樣得到的,樣本之間具有不同程度的非獨立性。當這些數據在計算圖的邊中傳播時,作者應用一系列不同的計算映射:具有線性或非線性激活函數(例如,sigmoid、ReLU(修正線性單元)、取模、正弦)的小型神經網絡,用于生成分類特征的離散化機制,以及用于編碼局部基于規則的依賴關系的決策樹結構。在每條邊上,添加高斯噪聲,為生成的數據引入不確定性。作者將每個節點的中間數據表示保存下來,以便后續檢索。
在遍歷因果圖之后,在采樣得到的特征節點和目標節點處提取中間表示,從而得到一個由特征值和相關目標值組成的樣本。
通過將各種數據挑戰和復雜性納入合成數據集,作者創建了一個訓練環境,使TabPFN能夠開發出處理現實世界數據集中類似問題的策略。例如,考慮表格數據中常見的缺失值問題。在合成數據生成過程中,讓TabPFN接觸具有不同缺失值模式和比例的合成數據集,模型就能學習到處理缺失值的有效方法,這些方法可以推廣到現實世界的數據集中。作者應用后處理技術進一步增強數據的真實性,并檢驗所學預測算法的穩健性。這包括使用Kumaraswamy分布進行變換,引入復雜的非線性扭曲,以及模擬離散特征的量化處理。
通過這個生成過程,在每次模型訓練時創建了大約1億個合成數據集,每個數據集都具有獨特的因果結構、特征類型和功能特性。
在線使用TabPFN
在線使用地址:https://ux.priorlabs.ai/predict,需要注冊后才能使用。
TabPFN適合處理表格數據,本質是填補指定target列中的缺失值,從而作為預測結果。以下面表格為例,target列被指定了readmitted,其他列被指定為feature columns,任務被設置為分類,模型將通過填補表中的target列缺失值實現預測:
Reference
[1] Transformers can do Bayesian inference, ICLR, 2022
[2] TabPFN: a transformer that solves small tabular classification problems in a second, ICLR, 2023