機器學習——網格搜索(GridSearchCV)超參數優化


網格搜索(Grid Search)詳細教學

1. 什么是網格搜索?

在機器學習模型中,算法的**超參數(Hyperparameters)**對模型的表現起著決定性作用。比如:

  • KNN 的鄰居數量 n_neighbors

  • SVM 的懲罰系數 C 和核函數參數 gamma

  • 隨機森林的決策樹數量 n_estimators

這些超參數不會在訓練過程中自動學習得到,而是需要我們人為設定。網格搜索(Grid Search)是一種最常見的超參數優化方法:
它通過
遍歷給定參數網格中的所有組合
,使用交叉驗證來評估每組參數的效果,最終選出表現最優的一組。

通俗理解:
👉 網格搜索 = 窮舉法找最佳參數。


2. 網格搜索的核心思想

  1. 定義參數范圍(網格):例如 C=[0.1, 1, 10]gamma=[0.01, 0.1, 1]

  2. 訓練所有組合:即 (C=0.1, gamma=0.01)(C=0.1, gamma=0.1)...直到 (C=10, gamma=1)

  3. 交叉驗證評估:每組參數都會在 k 折交叉驗證下計算平均性能指標(如準確率、F1 分數)。

  4. 選擇最佳參數:選出指標最優的一組參數作為最終模型配置。


3. 為什么要用網格搜索?

  • 超參數選擇自動化:不用憑感覺拍腦袋。

  • 保證找到最優解:只要網格覆蓋范圍足夠大,就不會遺漏最佳參數組合。

  • 結合交叉驗證:結果更加穩健,避免過擬合或欠擬合。

但缺點也明顯:

  • 計算開銷大:參數范圍和組合越多,訓練越耗時。

  • 不適合大規模搜索:參數維度高時可能出現“維度災難”。


4. Scikit-Learn 中的網格搜索工具

sklearn.model_selection.GridSearchCV 是最常用的網格搜索實現。

4.1 函數原型

GridSearchCV(estimator,          # 基礎模型,如SVC()、RandomForestClassifier()param_grid,         # 參數字典或列表,定義搜索空間scoring=None,       # 評估指標(accuracy、f1、roc_auc等)n_jobs=None,        # 并行任務數,-1表示使用所有CPUcv=None,            # 交叉驗證折數,如cv=5verbose=0,          # 日志等級,1=簡單進度條,2=詳細refit=True,         # 是否在找到最優參數后重新訓練整個模型return_train_score=False  # 是否返回訓練集得分
)

GridSearchCV 常用參數表:

分類參數類型說明常用取值
核心estimatorestimator 對象基礎模型,必須實現 fit / predictSVC()RandomForestClassifier()
param_griddict / list要搜索的參數空間,鍵=參數名,值=候選值列表{'C':[0.1,1,10], 'gamma':[0.01,0.1,1]}
評估scoringstr / callable模型評估指標accuracyf1_macroroc_aucneg_mean_squared_error
cvint / 生成器交叉驗證方式5(5折交叉驗證)、KFold(10)
refitbool / str用最佳參數在全訓練集上重新訓練True(默認)、'f1_macro'(多指標時指定)
效率n_jobsint并行任務數,-1=使用所有CPU-14
pre_dispatchint / str并行調度策略'2*n_jobs'(默認)
日志verboseint輸出日志等級0=無輸出,1=進度,2=詳細
錯誤處理error_scorestr / numeric參數報錯時的分數np.nan(默認)、0
調試return_train_scorebool是否返回訓練集得分(用于過擬合分析)False(默認)、True


5. 網格搜索實戰案例

5.1 示例數據集

以鳶尾花(Iris)分類為例,使用 SVM 模型。

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split# 加載數據
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定義模型
svc = SVC()

5.2 設置參數網格

param_grid = {'C': [0.1, 1, 10, 100],          # 懲罰系數'gamma': [1, 0.1, 0.01, 0.001],  # 核函數參數'kernel': ['rbf', 'linear']      # 核函數類型
}

5.3 執行網格搜索

grid = GridSearchCV(estimator=svc,param_grid=param_grid,scoring='accuracy',cv=5,verbose=2,n_jobs=-1
)
grid.fit(X_train, y_train)

5.4 輸出結果

print("最佳參數:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("測試集準確率:", grid.best_estimator_.score(X_test, y_test))

結果示例


6. 網格搜索的可視化

我們可以把不同參數組合的表現繪制出來,直觀查看最優解在哪個區域:

import matplotlib.pyplot as pltresults = pd.DataFrame(grid.cv_results_)# 只繪制 C 與 gamma 的得分熱力圖(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(index='param_gamma',columns='param_C',values='mean_test_score'
)plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()

7. 網格搜索的進階技巧

  1. 縮小搜索范圍:先用較粗粒度搜索,再在最優附近細化搜索。

  2. 并行計算n_jobs=-1 可利用多核 CPU。

  3. 隨機搜索(RandomizedSearchCV):當參數空間太大時,可考慮隨機抽樣搜索,更高效。

  4. 貝葉斯優化:如 OptunaHyperopt,比網格搜索更智能。


8. 注意事項

  • 參數空間不要過大,否則計算量爆炸。

  • 交叉驗證的折數 cv 不宜過大,通常 5 或 10。

  • 選擇合適的評分指標 scoring,分類問題常用 accuracyf1_macro,回歸問題用 neg_mean_squared_error 等。

  • 最終模型建議用 grid.best_estimator_,而不是手動再初始化。


9. 總結

  • **網格搜索(Grid Search)**是一種系統化的超參數優化方法,通過遍歷參數網格+交叉驗證,找到表現最優的參數組合。

  • sklearn 中,GridSearchCV 是核心工具。

  • 它簡單易用,但計算成本高,不適合大規模問題。

  • 實際應用中常結合粗到細搜索、隨機搜索、貝葉斯優化來提升效率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/919603.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/919603.shtml
英文地址,請注明出處:http://en.pswp.cn/news/919603.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【LeetCode】18. 四數之和

文章目錄18. 四數之和題目描述示例 1:示例 2:提示:解題思路算法一:排序 雙指針(推薦)算法二:通用 kSum(含 2Sum 雙指針)復雜度關鍵細節代碼實現要點完整題解代碼18. 四數…

Go語言入門(10)-數組

訪問數組元素:數組中的每個元素都可以通過“[]”和一個從0開始的索引進行訪問數組的長度可由內置函數len來確定。在聲明數組時,未被賦值元素的值是對應類型的零值。下面看一個例子package mainfunc main(){var planets [8]stringplanets[0] "Mercu…

為什么經過IPSec隧道后HTTPS會訪問不通?一次隧道環境下的實戰分析

在運維圈子里,大家可能都遇到過這種奇怪的問題:瀏覽器能打開 HTTP 網站,但一換成 HTTPS,頁面就死活打不開。前段時間,我們就碰到這么一個典型案例。故障現象某公司系統在 VPN 隧道里訪問 HTTPS 服務,結果就…

【Linux系統】進程信號:信號的產生和保存

上篇文章我們介紹了Syetem V IPC的消息隊列和信號量,那么信號量和我們下面要介紹的信號有什么關系嗎?其實沒有關系,就相當于我們日常生活中常說的老婆和老婆餅,二者并沒有關系1. 認識信號1.1 生活角度的信號解釋(快遞比…

WEB服務器(靜態/動態網站搭建)

簡介 名詞:HTML(超文本標記語言),網站(多個網頁組成一臺網站),主頁,網頁,URL(統一資源定位符) 網站架構:LAMP(linux(系統)+apache(服務器程序)+mysql(數據庫管理軟件)+php(中間軟件)) 靜態站點 Apache基礎 Apache官網:www.apache.org 軟件包名稱:…

開發避坑指南(29):微信昵稱特殊字符存儲異常修復方案

異常信息 Cause: java.sql.SQLException: Incorrect string value: \xF0\x9F\x8D\x8B\xE5\xBB... for column nick_name at row 1異常背景 抽獎大轉盤,抽獎后需要保存用戶抽獎記錄,用戶再次進入游戲時根據抽獎記錄判斷剩余抽獎機會。保存抽獎記錄時需要…

leetcode-python-242有效的字母異位詞

題目&#xff1a; 給定兩個字符串 s 和 t &#xff0c;編寫一個函數來判斷 t 是否是 s 的 字母異位詞。 示例 1: 輸入: s “anagram”, t “nagaram” 輸出: true 示例 2: 輸入: s “rat”, t “car” 輸出: false 提示: 1 < s.length, t.length < 5 * 104 s 和 t 僅…

【ARM】Keil MDK如何指定單文件的優化等級

1、 文檔目標解決在MDK中如何對于單個源文件去設置優化等級。2、 問題場景在正常的項目開發中&#xff0c;我們通常都是針對整個工程去做優化&#xff0c;相當于整個工程都是使用一個編譯器優化等級去進行的工程構建。那么在一些特定的情況下&#xff0c;工程師需要保證我的部分…

零基礎學Java第二十二講---異常(2)

續接上一講 目錄 一、異常的處理&#xff08;續&#xff09; 1、異常的捕獲-try-catch捕獲并處理異常 1.1關于異常的處理方式 2、finally 3、異常的處理流程 二、自定義異常類 1、實現自定義異常類 一、異常的處理&#xff08;續&#xff09; 1、異常的捕獲-try-catch捕…

自建開發工具IDE(一)之拖找排版—仙盟創夢IDE

自建拖拽布局排版在 IDE 中的優勢及初學者開發指南在軟件開發領域&#xff0c;用戶界面&#xff08;UI&#xff09;的設計至關重要。自建拖拽布局排版功能為集成開發環境&#xff08;IDE&#xff09;帶來了諸多便利&#xff0c;尤其對于初學者而言&#xff0c;是踏入開發領域的…

GitHub Copilot - GitHub 推出的AI編程助手

本文轉載自&#xff1a;GitHub Copilot - GitHub 推出的AI編程助手 - Hello123工具導航。 ** 一、GitHub Copilot 核心定位 GitHub Copilot 是由 GitHub 與 OpenAI 聯合開發的 AI 編程助手&#xff0c;基于先進大語言模型實現代碼實時補全、錯誤檢測及文檔生成&#xff0c;顯…

基于截止至 2025 年 6 月 4 日,在 App Store 上進行交易的設備數據統計,iOS/iPadOS 各版本在所有設備中所占比例詳情

iOS 和 iPadOS 使用情況 基于截止至 2025 年 6 月 4 日&#xff0c;在 App Store 上進行交易的設備數據統計。 iPhone 在過去四年推出的設備中&#xff0c;iOS 18 的普及率達 88。 88% iOS 188% iOS 174% 較早版本 所有的設備中&#xff0c;iOS 18 的普及率達 82。 82% iOS 189…

云計算-k8s實戰指南:從 ServiceMesh 服務網格、流量管理、limitrange管理、親和性、環境變量到RBAC管理全流程

介紹 本文是一份 Kubernetes 與 ServiceMesh 實戰操作指南,涵蓋多個核心功能配置場景。從 Bookinfo 應用部署入手,詳細演示了通過 Istio 創建 Ingress Gateway 實現外部訪問,以及基于用戶身份、請求路徑的服務網格路由規則配置,同時為應用微服務設置了默認目標規則。 還包…

Vue 3項目中的路由管理和狀態管理系統

核心概念理解 1. 整體架構關系 這兩個文件構成了Vue應用的導航系統和狀態管理系統&#xff1a; Router&#xff08;路由&#xff09;&#xff1a;控制頁面跳轉和URL變化Store&#xff08;狀態&#xff09;&#xff1a;管理全局數據和用戶狀態兩者協同工作實現權限控制 2. 數據流…

Linux Capability 解析

文章目錄1. 權限模型演進背景2. Capability核心原理2.1 能力單元分類2.2 進程三集合2.3 文件系統屬性3. 完整能力單元表4. 高級應用場景4.1 能力邊界控制4.2 編程控制4.3 容器安全5. 安全實踐建議6. 潛在風險提示 1. 權限模型演進背景 在傳統UNIX權限模型中&#xff0c;采用二進…

vue 監聽 sessionStorage 值的變化

<template><div class"specific-storage-watcher"><h3>僅監聽 userId 變化</h3><p>當前 userId: {{ currentUserId }}</p><p v-if"changeRecord">最近變化: {{ changeRecord }}</p><button click"…

IDEA:控制臺中文亂碼

目錄一、設置字符編碼為 UTF-8一、設置字符編碼為 UTF-8 點擊菜單 File -> settings -> Eitor -> File Encodings , 將字符全局編碼、項目編碼、配置文件編碼統一設置為UTF-8, 然后點擊 Apply 應用設置&#xff0c;點擊 OK 關閉對話框:

[Sql Server]特殊數值計算

任務一&#xff1a;求下方的Num列的中值:參考代碼:use Test go SELECT DISTINCTPERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY Num) over()AS MedianSalary FROM MedianTest;任務二: 下方表中,每個選手有多個評委打分&#xff0c;求每個選手的評委打分中值。參考代碼:use Tes…

01-Docker概述

Docker 的主要目標是:Build, Ship and Run Any App, Anywhere,也就是通過對應用組件的封裝、分發、部署、運行等生命周期的管理,使用戶的 APP 及其運行環境能做到一次鏡像,處處運行。 Docker 運行速度快的原因: 由于 Docker 不需要 Hypervisor(虛擬機)實現硬件資源虛擬化…

Laravel中如何使用php-casbin

一、&#x1f680; 安裝和配置 1. 安裝包 composer require casbin/laravel-authz2. 發布配置文件 php artisan vendor:publish這會生成兩個重要文件&#xff1a; config/lauthz.php - 主配置文件config/lauthz-rbac-model.conf - RBAC 模型配置文件 3. 運行數據庫遷移 php…