Tabnet介紹(Decision Manifolds)和PyTorch TabNet之TabNetRegressor

Tabnet介紹(Decision Manifolds)和PyTorch TabNet之TabNetRegressor

  • Decision Manifolds
  • TabNet
    • 1.核心思想
    • 2. 架構組成
    • 3. 工作流程
    • 4. 優點
  • PyTorch TabNet
    • TabNetRegressor參數
      • 1. 模型相關參數
        • `n_d`
        • `n_a`
        • `n_steps`
        • `gamma`
        • `cat_idxs`
        • `cat_dims`
        • `cat_emb_dim`
      • 2. 訓練相關參數
        • `optimizer_fn`
        • `optimizer_params`
        • `scheduler_fn`
        • `scheduler_params`
        • `mask_type`
      • 3. 其他參數
        • `seed`
        • `verbose`
        • `device_name`
    • TabNetRegressor.fit 參數詳解
      • 1. 核心訓練數據參數
        • `X_train`
        • `y_train`
      • 2. 驗證數據參數
        • `eval_set`
        • `eval_name`
        • `eval_metric`
      • 3. 訓練控制參數
        • `max_epochs`
        • `patience`
        • `batch_size`
        • `virtual_batch_size`
        • `num_workers`
      • 4. 回調與日志參數
        • `drop_last`
        • `callbacks`
        • `from_unsupervised`
        • `loss_fn `
    • TabNetRegressor.predict 參數
      • 1. 核心參數
        • `X`
        • `batch_size`
        • `num_workers`
        • `from_unsupervised`
        • `return_proba`
        • `verbose`
      • 2. 返回值
  • 參考

Decision Manifolds

指在決策樹模型中,數據點通過一系列超平面的分割形成的決策邊界。具體來說:

  • 在決策樹模型中:決策流形由一系列垂直于特征軸的超平面組成,這些超平面將數據空間劃分為多個區域,每個區域代表一個決策區域。例如,一個簡單的決策樹可能通過比較特征值與某個閾值來決定數據點的分類或回歸結果。
  • 適用于表格數據:由于表格數據通常具有結構化特征,決策流形的這種分割方式能夠有效地捕捉數據中的線性關系,尤其是在特征維度較低的情況下,能夠實現較好的分類或回歸性能。
  • 可解釋性:決策流形的直觀分割使得模型的決策過程易于理解,每個分割超平面都對應一個特定的特征閾值,便于人類解釋和理解模型的決策依據。
  • 對比神經網絡:與依賴于高維非線性映射的神經網絡不同,決策流形提供了一種更直接、更簡單的決策方式,這在某些情況下使得決策樹模型在表格數據上表現更佳。

此外,決策流形的概念也與模型的歸納偏差相關,即模型在學習過程中傾向于生成符合某種先驗知識或規則的解。對于表格數據,決策樹模型的決策流形天生具備線性分割的歸納偏差,這有助于它在沒有過多參數調整的情況下,仍然能夠有效地學習到數據的結構。

TabNet

一種專門為結構化數據(表格數據)設計的深度學習模型,由 Google 提出。它通過注意力機制和可解釋性設計,解決了傳統神經網絡在處理表格數據時透明度不足的問題。以下是 TabNet 的詳細解析:

1.核心思想

  • 稀疏注意力機制:TabNet 使用稀疏注意力機制來選擇輸入特征的子集進行處理,從而減少計算量并提高模型的可解釋性。
  • 逐步特征選擇:模型逐步選擇重要的特征,并忽略不相關的特征,這使得 TabNet 能夠專注于對任務最重要的特征。

2. 架構組成

TabNet 的架構主要由以下幾個部分組成:

  • Feature Transformer:負責對輸入特征進行非線性變換。
  • Attention Mechanism:通過注意力機制選擇重要的特征子集。
  • Masking Mechanism:生成掩碼,決定哪些特征被選中參與下一步計算。
  • Decoder:用于預測任務(如分類或回歸)。

3. 工作流程

  • 輸入層:將表格數據輸入到模型中。
  • 特征變換:通過 Feature Transformer 對特征進行非線性變換。
  • 注意力選擇:使用注意力機制選擇重要的特征子集。
  • 掩碼生成:生成掩碼以決定哪些特征參與下一步計算。
  • 輸出層:通過 Decoder 輸出預測結果。

4. 優點

  • 可解釋性:由于稀疏注意力機制,TabNet 可以明確指出哪些特征對預測結果有貢獻。
  • 高效性:通過選擇重要特征,減少了不必要的計算。
  • 靈活性:適用于多種任務,包括分類和回歸。

PyTorch TabNet

pytorch_tabnet 是基于 PyTorch 實現的 TabNet 模型庫,專為結構化數據(表格數據)設計。它提供了高效的特征選擇和可解釋性功能,適用于分類和回歸任務。

TabNetRegressor參數

TabNetRegressor 是一種基于 TabNet 架構的回歸模型,適用于結構化數據的回歸任務。以下是其主要參數的詳細說明:


1. 模型相關參數

n_d
  • 類型: int
  • 默認值: 8
  • 描述: 表示決策路徑中每個步驟的維度大小。較大的值會增加模型的表達能力,但也可能導致過擬合。
n_a
  • 類型: int
  • 默認值: 8
  • 描述: 表示注意力機制的維度大小。與 n_d 類似,控制模型的復雜度。
n_steps
  • 類型: int
  • 默認值: 3
  • 描述: 表示 TabNet 模型中的步數(steps),即模型在每輪迭代中選擇特征的次數。更大的值可以捕獲更多的特征組合。
gamma
  • 類型: float
  • 默認值: 1.3
  • 描述: 控制特征稀疏性的超參數。較大的值會導致更少的特征被選中。
cat_idxs
  • 類型: list[int]
  • 默認值: []
  • 描述: 指定分類特征的索引列表。如果數據集中包含分類變量,需要通過此參數指定。
cat_dims
  • 類型: list[int]
  • 默認值: []
  • 描述: 指定分類特征的類別數量。與 cat_idxs 配合使用,用于定義分類變量的嵌入維度。
cat_emb_dim
  • 類型: int 或 list[int]
  • 默認值: 1
  • 描述: 分類特征的嵌入維度。如果為整數,則所有分類特征共享相同的嵌入維度;如果為列表,則每個分類特征可以有不同的嵌入維度。

2. 訓練相關參數

optimizer_fn
  • 類型: function
  • 默認值: Adam
  • 描述: 優化器函數,默認使用 PyTorch 的 Adam 優化器。
optimizer_params
  • 類型: dict
  • 默認值: {‘lr’: 0.02}
  • 描述: 傳遞給優化器的參數字典,例如學習率(lr)。
scheduler_fn
  • 類型: function
  • 默認值: None
  • 描述: 學習率調度器函數。如果需要動態調整學習率,可以通過此參數指定。
scheduler_params
  • 類型: dict
  • 默認值: None
  • 描述: 傳遞給學習率調度器的參數字典。
mask_type
  • 類型: str
  • 默認值: “sparsemax”
  • 描述: 特征選擇掩碼的類型,可選值為 "sparsemax""entmax""sparsemax" 更加常用。

mask_type 參數用于指定特征選擇掩碼的類型,控制模型在每個決策步驟中選擇特征的方式。


3. 其他參數

seed
  • 類型: int
  • 默認值: 0
  • 描述: 隨機種子,用于確保結果的可重復性。
verbose
  • 類型: int
  • 默認值: 1
  • 描述: 控制輸出日志的詳細程度。0 表示靜默模式,1 表示普通模式,2 表示調試模式。
device_name
  • 類型: str
  • 默認值: “auto”
  • 描述: 指定計算設備。"auto" 會自動檢測是否有 GPU 可用。

TabNetRegressor.fit 參數詳解

1. 核心訓練數據參數

X_train
  • 必須為 numpy.ndarray 格式,不支持直接傳入 pandas.DataFrame 或 pandas.Series。
y_train
  • 必須為 numpy.ndarray 格式,且形狀需調整為 (n_samples, 1)。

2. 驗證數據參數

eval_set
  • 類型: list[tuple]
  • 默認值: None
  • 描述: 驗證集數據列表,格式為 [(X_valid, y_valid)]。支持多個驗證集。
eval_name
  • 類型: list[str]
  • 默認值: None
  • 描述: 每個驗證集的名稱,便于在日志中區分不同驗證集。
eval_metric
  • 類型: list[str] 或 callable
  • 默認值: [‘rmse’]
  • 描述: 評估指標,可選值包括 'rmse''mse' 等。也可以傳入自定義的評估函數。

3. 訓練控制參數

max_epochs
  • 類型: int
  • 默認值: 100
  • 描述: 最大訓練輪數。如果提前收斂,則可能在達到最大輪數之前停止。
patience
  • 類型: int
  • 默認值: 10
  • 描述: 早停機制的耐心值。如果驗證集性能在連續 patience 輪內沒有提升,則停止訓練。
batch_size
  • 類型: int
  • 默認值: 1024
  • 描述: 每次迭代的批量大小。較大的批量可能會加速訓練,但需要更多的內存。
virtual_batch_size
  • 類型: int
  • 默認值: 128
  • 描述: 虛擬批量大小,用于模擬小批量梯度下降,減少內存占用。
num_workers
  • 類型: int
  • 默認值: 0
  • 描述: 數據加載器中的工作線程數。設置為 0 表示使用主進程加載數據。

4. 回調與日志參數

drop_last
  • 類型: bool
  • 默認值: False
  • 描述: 是否丟棄最后一個不完整的批量數據。
callbacks
  • 類型: list[Callback]
  • 默認值: None
  • 描述: 自定義回調函數列表,例如學習率調度器、早停等。
from_unsupervised
  • 類型: TabNetPretrainer
  • 默認值: None
  • 描述: 如果提供了預訓練的 TabNetPretrainer 模型,則從無監督預訓練階段繼續訓練。

loss_fn
  • 類型: Callable(可調用對象,例如函數或類方法)
  • 默認值: 默認使用均方誤差(MSE)損失函數。
  • 描述: 允許用戶自定義訓練過程中使用的損失函數。

TabNetRegressor.predict 參數

1. 核心參數

X
  • 數據格式應與訓練時使用的 X_train 一致。
batch_size
  • 類型: int
  • 默認值: 1024
  • 描述: 每次預測的批量大小。較大的批量可能會加速預測過程,但需要更多的內存。
num_workers
  • 類型: int
  • 默認值: 0
  • 描述: 數據加載器中的工作線程數。設置為 0 表示使用主進程加載數據。
from_unsupervised
  • 類型: TabNetPretrainer
  • 默認值: None
  • 描述: 如果提供了預訓練的 TabNetPretrainer 模型,則從無監督預訓練階段繼續預測。
return_proba
  • 類型: bool
  • 默認值: False
  • 描述: 是否返回預測的概率分布(僅適用于分類任務)。對于回歸任務,此參數無效。
verbose
  • 類型: int
  • 默認值: 1
  • 描述: 控制輸出日志的詳細程度。0 表示靜默模式,1 表示普通模式,2 表示調試模式。

2. 返回值

  • 當 TabNetRegressor.predict 返回的預測結果是一個二維數組(例如 (n_samples, 1))時,可以使用 flatten 方法將其轉換為一維數組 (n_samples,)。

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=2025)
for trn_idx, val_idx in skf.split(train_feats[feature_names], train_feats['V_bins']):X_train = train_feats.loc[trn_idx][feature_names].valuesY_train = train_feats.loc[trn_idx]['V'].values.reshape(-1, 1)X_val = train_feats.loc[val_idx][feature_names].valuesY_val = train_feats.loc[val_idx]['V'].values.reshape(-1, 1)print("Train Num: ", len(Y_train))print("Val Num: ", len(Y_val))model = tab_model.TabNetRegressor(n_d = 8,n_a = 8,n_steps = 1,gamma = 1.6,lambda_sparse = 6e-5,n_independent = 4,n_shared = 2,optimizer_fn = torch.optim.AdamW,optimizer_params = dict(lr=0.025),scheduler_fn = torch.optim.lr_scheduler.ReduceLROnPlateau,scheduler_params = dict(mode='min', factor=0.6, patience=3),mask_type = 'entmax',seed=2025,device_name = 'cuda',verbose = 1,)model.fit(X_train=X_train,y_train=Y_train,eval_set=[(X_val, Y_val)],eval_name=['val'],eval_metric=['rmse'],patience=10,max_epochs=100,batch_size=512, virtual_batch_size=128, num_workers=1, drop_last=False,)pred = model.predict()pred = pred.flatten()

參考

1.https://github.com/dreamquark-ai/tabnet
2.https://github.com/google-research/google-research/tree/master/tabnet
3.https://arxiv.org/abs/1908.07442

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900978.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900978.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900978.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

圖像變換方式區別對比(Opencv)

1. 變換示例 import cv2 import matplotlib.pyplot as plotimg cv2.imread(url) img_cut img[100:200, 200:300] img_rsize cv2.resize(img, (50, 50)) (hight,width) img.shape[:2] rotate_matrix cv2.getRotationMatrix2D((hight//2, width//2), 50, 1) img_wa cv2.wa…

Navicat分組、查詢分享

1、分組 有些項目業務表比較多,多達幾百張,如果通過人眼看,很容易頭暈。這時候可以通過Navicat表分組來進行分類。 使用場景 按版本分組按業務功能分組 創建分組 示例:按版本分組,可以將1.0版本的表放到1.0中。 分組…

大模型在初治CLL成人患者診療全流程風險預測與方案制定中的應用研究

目錄 一、緒論 1.1 研究背景與意義 1.2 國內外研究現狀 1.3 研究目的與內容 二、大模型技術與慢性淋巴細胞白血病相關知識 2.1 大模型技術原理與特點 2.2 慢性淋巴細胞白血病的病理生理與診療現狀 三、術前風險預測與手術方案制定 3.1 術前數據收集與預處理 3.2 大模…

for循環的優化方式、循環的種類、使用及平替方案。

本篇文章主要圍繞for循環,來講解循環處理數據中常見的六種方式及其特點,性能。通過本篇文章你可以快速了解循環的概念,以及循環在實際使用過程中的調優方案。 作者:任聰聰 日期:2025年4月11日 一、循環的種類 1.1 默認有以下類型 原始 for 循環 for(i = 0;i<10;i++){…

穿透三層內網VPC1

網絡拓撲: 打開入口web服務 信息收集發現漏洞CVE-2024-4577 PHP CGI Windows平臺遠程代碼執行漏洞&#xff08;CVE-2024-4577&#xff09;復現_cve-2024-4577漏洞復現-CSDN博客 利用POC&#xff1a; 執行成功&#xff0c;那么直接上傳馬子&#xff0c;注意&#xff0c;這里要…

【計算機網絡】同步操作 vs 異步操作:核心區別與實戰場景解析

&#x1f4cc; 引言 在網絡通信和分布式系統中&#xff0c;**同步&#xff08;Synchronous&#xff09;和異步&#xff08;Asynchronous&#xff09;**是兩種基礎卻易混淆的操作模式。本文將通過代碼示例、生活類比和對比表格&#xff0c;幫你徹底理解它們的區別與應用場景。 1…

TensorFlow充分并行化使用CPU

關鍵字&#xff1a;TensorFlow 并行化、TensorFlow CPU多線程 場景&#xff1a;在沒有GPU或者GPU性能一般、環境不可用的機器上&#xff0c;對于多核CPU&#xff0c;有時TensorFlow或上層的Keras默認并沒有完全利用機器的計算能力&#xff08;CPU占用沒有接近100%&#xff09;…

Kubernetes容器編排與云原生實踐

第一部分&#xff1a;Kubernetes基礎架構與核心原理 第1章 容器技術的演進與Kubernetes的誕生 1.1 虛擬化技術的三次革命 物理機時代&#xff1a;資源浪費嚴重&#xff0c;利用率不足15% 虛擬機突破&#xff1a;VMware與Hyper-V實現硬件虛擬化&#xff0c;利用率提升至50% …

Windows 錄音格式為什么是 M4A?M4A 怎樣轉為 MP3 格式

M4A 格式憑借其高效的壓縮技術和卓越的音質表現脫穎而出&#xff0c;成為了包括 Windows 在內的眾多操作系統默認的錄音格式選擇。然而&#xff0c;盡管 M4A 格式擁有諸多優點&#xff0c;不同的應用場景有時需要將這些文件轉換為其他格式以滿足特定需求。 本文將探討 M4A 格式…

Qt之OpenGL使用Qt封裝好的著色器和編譯器

代碼 #include "sunopengl.h"sunOpengl::sunOpengl(QWidget *parent) {}unsigned int VBO,VAO; float vertices[]{0.5f,0.5f,0.0f,0.5f,-0.5f,0.0f,-0.5f,-0.5f,0.0f,-0.5f,0.5f,0.0f };unsigned int indices[]{0,1,3,1,2,3, }; unsigned int EBO; sunOpengl::~sunO…

HCIP-17 BGP基礎2

HCIP-17 BGP基礎2 一、bgp的路由黑洞問題 1.bgp的同步功能 ipv4-family unicast IPV4的地址簇 undo synchronization 關閉BGP同步功能 bgp的同步功能原理 當邊界路由器從ibgp鄰居收到一條路由后&#xff0c;會使用該路由和igp路由表進行比較。 如果在igp路由表中存在…

leetcode_15. 三數之和_java

15. 三數之和https://leetcode.cn/problems/3sum/ 1、題目 給你一個整數數組 nums &#xff0c;判斷是否存在三元組 [nums[i], nums[j], nums[k]] 滿足 i ! j、i ! k 且 j ! k &#xff0c;同時還滿足 nums[i] nums[j] nums[k] 0 。請你返回所有和為 0 且不重復的三元組。…

Open Interpreter:重新定義人機交互的開源革命

引言 在人工智能技術蓬勃發展的今天&#xff0c;人機交互的方式正經歷著前所未有的變革。Open Interpreter&#xff0c;作為一個開源項目&#xff0c;正在重新定義我們與計算機的互動方式。它允許大型語言模型&#xff08;LLMs&#xff09;在本地運行代碼&#xff0c;通過自然…

【JavaScript】錯誤處理與調試

個人主頁&#xff1a;Guiat 歸屬專欄&#xff1a;HTML CSS JavaScript 文章目錄 1. JavaScript 錯誤處理基礎1.1 錯誤類型1.2 try...catch 語句 2. 錯誤拋出與自定義錯誤2.1 throw 語句2.2 自定義錯誤類型 3. 異步錯誤處理3.1 Promise 錯誤處理3.2 async/await 錯誤處理 4. 調試…

算法基礎模板

高精度加法 #include <bits/stdc.h> using namespace std; const int N10005; int A[N],B[N],C[N],al,bl,cl; void add(int A[],int B[],int C[]) {for(int icl-1;~i;i--){C[cl]A[i]B[i];C[cl1]C[cl]/10;C[cl]%10;}if(C[cl])cl; } int main() {string a,b;cin>>a&…

自行搭建一個Git倉庫托管平臺

1.安裝Git sudo apt install git 2.Git本地倉庫創建&#xff08;自己選擇一個文件夾&#xff09; git init 這里我在 /home/test 下面初始化了代碼倉庫 1. 首先在倉庫中新建一個txt文件&#xff0c;并輸入一些內容 2. 將文件添加到倉庫 git add test.txt 執行之后沒有任何輸…

[MySQL]數據庫與表創建

歡迎來到啾啾的博客&#x1f431;。 這是一個致力于構建完善 Java 程序員知識體系的博客&#x1f4da;。 它記錄學習點滴&#xff0c;分享工作思考和實用技巧&#xff0c;偶爾也分享一些雜談&#x1f4ac;。 歡迎評論交流&#xff0c;感謝您的閱讀&#x1f604;。 本篇簡單記錄…

相機回調函數為靜態函數原因

在注冊相機SDK的回調函數時&#xff0c;是否需要設置為靜態函數取決于具體SDK的設計要求&#xff0c;但通常需要遵循以下原則&#xff1a; 1. 必須使用靜態函數的情況 當相機SDK是C語言接口或要求普通函數指針時&#xff0c;回調必須聲明為靜態成員函數或全局函數&#xff1a;…

《Vue Router實戰教程》4.路由的匹配語法

歡迎觀看《Vue Router 實戰&#xff08;第4版&#xff09;》視頻課程 路由的匹配語法 大多數應用都會使用 /about 這樣的靜態路由和 /users/:userId 這樣的動態路由&#xff0c;就像我們剛才在動態路由匹配中看到的那樣&#xff0c;但是 Vue Router 可以提供更多的方式&#…

Debezium報錯處理系列之第128篇:增量快照報錯java.lang.OutOfMemoryError: Java heap space

Debezium報錯處理系列之第128篇:增量快照報錯java.lang.OutOfMemoryError: Java heap space 一、完整報錯二、錯誤原因三、解決方法Debezium從入門到精通系列之:研究Debezium技術遇到的各種錯誤解決方法匯總: Debezium從入門到精通系列之:百篇系列文章匯總之研究Debezium技…