sklearn 交叉驗證迭代器
在 scikit-learn
(sklearn) 中,交叉驗證迭代器(Cross-Validation Iterators)是一組用于生成訓練集和驗證集索引的工具。它們是 model_selection
模塊的核心組件,決定了數據如何被分割,從而支持模型評估、超參數調優等任務。
這些迭代器實現了不同的數據劃分策略,以適應各種數據類型和問題場景。下面詳細介紹 sklearn 中主要的交叉驗證迭代器。
一、核心概念
所有交叉驗證迭代器都遵循相同的接口:
- 輸入:數據集大小
n_samples
。 - 輸出:一個生成器(generator),每次迭代返回一對
(train_indices, test_indices)
的 NumPy 數組。 - 用途:可用于
cross_val_score
,GridSearchCV
等函數的cv
參數。
二、主要交叉驗證迭代器
1. KFold
- 標準 K 折交叉驗證
用途:最基礎的 K 折 CV,適用于類別均衡的分類或回歸問題。
工作方式:
- 將數據集劃分為
k
個大小基本相等的折(folds)。 - 每次使用其中 1 折作為驗證集,其余
k-1
折作為訓練集。 - 重復
k
次,確保每折都恰好被用作一次驗證集。
參數:
n_splits
:折數,默認為 5。shuffle
:是否在劃分前打亂數據順序。建議設為True
,除非數據有時間順序。random_state
:隨機種子,確保結果可復現。
代碼示例:
from sklearn.model_selection import KFold
import numpy as npX = np.array([[1], [2], [3], [4], [5]])
y = np.array([1, 2, 3, 4, 5])kf = KFold(n_splits=3, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):print("TRAIN:", train_index, "TEST:", test_index)
2. StratifiedKFold
- 分層 K 折交叉驗證
用途:分類任務的首選,尤其當類別分布不均衡時。
工作方式:
與 KFold 類似,但確保每一折中各類別的比例與原始數據集大致相同。
避免某些折中某個類別樣本過少或缺失,導致評估偏差。
為什么重要?
例如:一個二分類數據集中正類占 10%。使用普通 KFold 可能在某折中正類樣本極少,導致模型無法學習或評估失真。
StratifiedKFold 保證每折中正類比例都接近 10%。
代碼示例:
python
深色版本
from sklearn.model_selection import StratifiedKFoldy = np.array([0, 0, 0, 1, 1]) # 不均衡數據skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):print("TRAIN:", train_index, "TEST:", test_index)print("Y_TRAIN:", y[train_index], "Y_TEST:", y[test_index])
3. LeaveOneOut (LOO)
- 留一法交叉驗證
用途:樣本量非常小(如 < 100)時使用。
工作方式:
每次留出一個樣本作為驗證集,其余所有樣本作為訓練集。
重復 n_samples 次。
優缺點:
? 幾乎無偏估計(訓練集最大)。
? 計算成本極高(訓練 n 次),且方差可能很大(單個樣本影響大)。
代碼示例:
python
深色版本
from sklearn.model_selection import LeaveOneOutloo = LeaveOneOut()
for train_index, test_index in loo.split(X):print("TRAIN:", train_index, "TEST:", test_index)
4. LeavePOut
- 留 P 法交叉驗證
用途:比 LOO 更一般化,但計算更昂貴。
工作方式:
每次留出 p 個樣本作為驗證集,其余所有樣本作為訓練集。
所有可能的 p 個樣本組合都會被嘗試,因此總次數為 C(n, p)。
p=1 時退化為 LOO。
注意:當 n 或 p 稍大時,組合數爆炸,極少在實際中使用。
5. ShuffleSplit
- 隨機劃分分割
用途:靈活的隨機抽樣 CV,適合大數據集或需要控制訓練/驗證比例時。
工作方式:
不強制使用所有樣本。
每次迭代從數據中隨機抽取指定比例作為訓練集,其余作為驗證集(可重疊)。
可指定迭代次數 n_splits。
參數:
n_splits:迭代次數。
train_size, test_size:訓練/驗證集比例。
優點:
可獨立控制訓練集大小。
適用于大數據,無需完整 K 折。
代碼示例:
python
深色版本
from sklearn.model_selection import ShuffleSplitss = ShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
for train_index, test_index in ss.split(X):print("TRAIN:", train_index, "TEST:", test_index)
6. StratifiedShuffleSplit
- 分層隨機劃分
用途:ShuffleSplit 的分層版本,用于類別不均衡的分類任務。
工作方式:
在每次隨機劃分時,保持訓練集和驗證集中各類別的比例一致。
適用場景:
大數據集上的分層 CV。
需要固定驗證集大小且保持類別平衡。
7. GroupKFold
- 組 K 折交叉驗證
用途:當數據中存在組結構(如:同一用戶多次記錄、同一病人多個樣本),需確保同一組的數據不同時出現在訓練和驗證集中,防止數據泄露。
工作方式:
根據 groups 數組劃分,確保一個組的所有樣本要么全在訓練集,要么全在驗證集。
參數:
groups:長度為 n_samples 的數組,表示每個樣本所屬的組。
代碼示例:
python
深色版本
from sklearn.model_selection import GroupKFoldX = [0.1, 0.2, 2.2, 2.4, 2.3, 4.5, 5.7, 5.8]
y = [1, 1, 0, 0, 0, 1, 1, 1]
groups = [1, 1, 2, 2, 2, 3, 3, 3] # 3 個組gkf = GroupKFold(n_splits=3)
for train_index, test_index in gkf.split(X, y, groups):print("TRAIN:", train_index, "TEST:", test_index)print("GROUPS:", groups[test_index])
8. TimeSeriesSplit
- 時間序列交叉驗證
用途:處理時間序列數據,確保不使用未來數據預測過去。
工作方式:
按時間順序劃分。
每次迭代,訓練集是過去的數據,驗證集是接下來的一段數據。
訓練集逐漸增長(“前滾”交叉驗證)。
關鍵特性:
不打亂數據。
驗證集始終在訓練集之后。
代碼示例:
python
深色版本
from sklearn.model_selection import TimeSeriesSplittscv = TimeSeriesSplit(n_splits=3)
for train_index, test_index in tscv.split(X):print("TRAIN:", train_index, "TEST:", test_index)
輸出:
深色版本
TRAIN: [0 1 2] TEST: [3]
TRAIN: [0 1 2 3] TEST: [4]
三、如何選擇合適的 CV 迭代器?
場景 推薦迭代器
一般分類(類別均衡) KFold
分類(類別不均衡) ? StratifiedKFold
回歸任務 KFold 或 ShuffleSplit
小樣本數據 LeaveOneOut(謹慎使用)
大數據,靈活劃分 ShuffleSplit, StratifiedShuffleSplit
數據有組結構(避免泄露) GroupKFold, LeaveOneGroupOut
時間序列數據 ? TimeSeriesSplit
需要分層 + 隨機劃分 StratifiedShuffleSplit
四、使用建議
默認選擇:
分類:StratifiedKFold
回歸:KFold
設置 shuffle=True:除非數據有序(如時間序列),否則建議打亂。
固定 random_state:確保實驗可復現。
避免數據泄露:在使用 CV 時,任何數據預處理(如標準化、填充)都應在 CV 循環內部進行(使用 Pipeline)。
python
深色版本
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScalerpipe = Pipeline([('scaler', StandardScaler()),('clf', SVC())
])
在 cross_val_score 中使用 pipe,確保 scaler 只在訓練集上擬合
cross_val_score(pipe, X, y, cv=5)
總結
sklearn 的交叉驗證迭代器提供了豐富且靈活的工具,能夠適應從標準分類到時間序列、組數據等各種復雜場景。選擇合適的迭代器是獲得可靠、無偏模型評估的關鍵第一步。務必根據數據的結構和任務類型,選擇最匹配的 CV 策略。