網格搜索(Grid Search)詳細教學
1. 什么是網格搜索?
在機器學習模型中,算法的**超參數(Hyperparameters)**對模型的表現起著決定性作用。比如:
KNN 的鄰居數量
n_neighbors
SVM 的懲罰系數
C
和核函數參數gamma
隨機森林的決策樹數量
n_estimators
這些超參數不會在訓練過程中自動學習得到,而是需要我們人為設定。網格搜索(Grid Search)是一種最常見的超參數優化方法:
它通過遍歷給定參數網格中的所有組合,使用交叉驗證來評估每組參數的效果,最終選出表現最優的一組。
通俗理解:
👉 網格搜索 = 窮舉法找最佳參數。
2. 網格搜索的核心思想
定義參數范圍(網格):例如
C=[0.1, 1, 10]
,gamma=[0.01, 0.1, 1]
。訓練所有組合:即
(C=0.1, gamma=0.01)
、(C=0.1, gamma=0.1)
...直到(C=10, gamma=1)
。交叉驗證評估:每組參數都會在 k 折交叉驗證下計算平均性能指標(如準確率、F1 分數)。
選擇最佳參數:選出指標最優的一組參數作為最終模型配置。
3. 為什么要用網格搜索?
超參數選擇自動化:不用憑感覺拍腦袋。
保證找到最優解:只要網格覆蓋范圍足夠大,就不會遺漏最佳參數組合。
結合交叉驗證:結果更加穩健,避免過擬合或欠擬合。
但缺點也明顯:
計算開銷大:參數范圍和組合越多,訓練越耗時。
不適合大規模搜索:參數維度高時可能出現“維度災難”。
4. Scikit-Learn 中的網格搜索工具
sklearn.model_selection.GridSearchCV
是最常用的網格搜索實現。
4.1 函數原型
GridSearchCV(estimator, # 基礎模型,如SVC()、RandomForestClassifier()param_grid, # 參數字典或列表,定義搜索空間scoring=None, # 評估指標(accuracy、f1、roc_auc等)n_jobs=None, # 并行任務數,-1表示使用所有CPUcv=None, # 交叉驗證折數,如cv=5verbose=0, # 日志等級,1=簡單進度條,2=詳細refit=True, # 是否在找到最優參數后重新訓練整個模型return_train_score=False # 是否返回訓練集得分
)
GridSearchCV
常用參數表:
分類 | 參數 | 類型 | 說明 | 常用取值 |
---|---|---|---|---|
核心 | estimator | estimator 對象 | 基礎模型,必須實現 fit / predict | SVC() 、RandomForestClassifier() |
param_grid | dict / list | 要搜索的參數空間,鍵=參數名,值=候選值列表 | {'C':[0.1,1,10], 'gamma':[0.01,0.1,1]} | |
評估 | scoring | str / callable | 模型評估指標 | accuracy 、f1_macro 、roc_auc 、neg_mean_squared_error |
cv | int / 生成器 | 交叉驗證方式 | 5 (5折交叉驗證)、KFold(10) | |
refit | bool / str | 用最佳參數在全訓練集上重新訓練 | True (默認)、'f1_macro' (多指標時指定) | |
效率 | n_jobs | int | 并行任務數,-1=使用所有CPU | -1 、4 |
pre_dispatch | int / str | 并行調度策略 | '2*n_jobs' (默認) | |
日志 | verbose | int | 輸出日志等級 | 0 =無輸出,1 =進度,2 =詳細 |
錯誤處理 | error_score | str / numeric | 參數報錯時的分數 | np.nan (默認)、0 |
調試 | return_train_score | bool | 是否返回訓練集得分(用于過擬合分析) | False (默認)、True |
5. 網格搜索實戰案例
5.1 示例數據集
以鳶尾花(Iris)分類為例,使用 SVM 模型。
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split# 加載數據
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定義模型
svc = SVC()
5.2 設置參數網格
param_grid = {'C': [0.1, 1, 10, 100], # 懲罰系數'gamma': [1, 0.1, 0.01, 0.001], # 核函數參數'kernel': ['rbf', 'linear'] # 核函數類型
}
5.3 執行網格搜索
grid = GridSearchCV(estimator=svc,param_grid=param_grid,scoring='accuracy',cv=5,verbose=2,n_jobs=-1
)
grid.fit(X_train, y_train)
5.4 輸出結果
print("最佳參數:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("測試集準確率:", grid.best_estimator_.score(X_test, y_test))
結果示例:
6. 網格搜索的可視化
我們可以把不同參數組合的表現繪制出來,直觀查看最優解在哪個區域:
import matplotlib.pyplot as pltresults = pd.DataFrame(grid.cv_results_)# 只繪制 C 與 gamma 的得分熱力圖(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(index='param_gamma',columns='param_C',values='mean_test_score'
)plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()
7. 網格搜索的進階技巧
縮小搜索范圍:先用較粗粒度搜索,再在最優附近細化搜索。
并行計算:
n_jobs=-1
可利用多核 CPU。隨機搜索(RandomizedSearchCV):當參數空間太大時,可考慮隨機抽樣搜索,更高效。
貝葉斯優化:如
Optuna
、Hyperopt
,比網格搜索更智能。
8. 注意事項
參數空間不要過大,否則計算量爆炸。
交叉驗證的折數
cv
不宜過大,通常 5 或 10。選擇合適的評分指標
scoring
,分類問題常用accuracy
、f1_macro
,回歸問題用neg_mean_squared_error
等。最終模型建議用
grid.best_estimator_
,而不是手動再初始化。
9. 總結
**網格搜索(Grid Search)**是一種系統化的超參數優化方法,通過遍歷參數網格+交叉驗證,找到表現最優的參數組合。
在
sklearn
中,GridSearchCV
是核心工具。它簡單易用,但計算成本高,不適合大規模問題。
實際應用中常結合粗到細搜索、隨機搜索、貝葉斯優化來提升效率。