模型選擇與調優

一、模型選擇與調優

  • 在機器學習中,模型的選擇和調優是一個重要的步驟,它直接影響到最終模型的性能

1、交叉驗證

  • 在任何有監督機器學習項目的模型構建階段,我們訓練模型的目的是從標記的示例中學習所有權重和偏差的最佳值

  • 如果我們使用相同的標記示例來測試我們的模型,那么這將是一個方法論錯誤,因為一個只會重復剛剛看到的樣本標簽的模型將獲得完美的結果,但無法預測數據,這種情況稱為過擬合,為了克服過度擬合的問題,我們使用交叉驗證

  • 交叉驗證(Cross-validation)是一種統計學上的方法,用于評估機器學習模型的性能,并幫助避免過擬合。它的主要思想是在有限的數據集上劃分出一部分數據用于測試模型的泛化能力

1.1 保留交叉驗證HoldOut

  • 保留交叉驗證(Holdout Cross Validation)是最簡單的一種交叉驗證方法。在這種方法中,數據集被一次性劃分為兩個互斥的部分:一個較大的部分作為訓練集(training set),用于訓練模型;另一個較小的部分作為驗證集(validation set)或測試集(test set),用于評估模型的泛化能力

  • Holdout 方法的步驟:

    • 數據劃分:從原始數據集中隨機抽取一部分數據作為測試集,剩余的數據作為訓練集。通常的比例為70%的數據作為訓練集,30%的數據作為測試集

    • 訓練模型:使用訓練集數據來訓練模型

    • 評估模型:使用測試集數據來評估模型的性能,如準確率、召回率等指標

  • Holdout 方法的優點:

    • 簡單易行:只需要一次劃分即可完成訓練和測試

    • 計算效率高:相較于K折交叉驗證等方法,Holdout方法的計算開銷較低

  • Holdout 方法的缺點:

    • 不適用于不平衡的數據集:假設我們有一個不平衡的數據集,有 0 類和 1 類。假設80%的數據屬于 0 類,其余 20% 的數據屬于 1 類。這種情況下,訓練集的大小為 80%,測試數據的大小為數據集的 20%。可能發生的情況是,所有 80% 的 0 類數據都在訓練集中,而所有 1 類數據都在測試集中。因此,我們的模型將不能很好地概括我們的測試數據,因為它之前沒有見過 1 類的數據

    • 大塊數據被剝奪了訓練模型的機會:在小數據集的情況下,有一部分數據將被保留下來用于測試模型,這些數據可能具有重要的特征,而我們的模型可能會因為沒有在這些數據上進行訓練而錯過

1.2 K-折交叉驗證K-Fold

  • K-折交叉驗證(K-Fold Cross Validation)是一種評估機器學習模型性能的方法

  • 在這種 K 折交叉驗證技術中,整個數據集被劃分為 K 個相等大小的部分。每個分區稱為一個折疊。因為有 K 個部分,所以我們稱之為 K-Fold 。一個 Fold 折用作驗證集,其余 K-1 個 Fold 用作訓練集

  • 圖示:

  • 執行步驟:

    • 數據劃分:

      • 首先,將整個數據集隨機分成 K 個子集或者 Folds,盡量保證每個子集的大小相同

    • 模型訓練與測試:

      • 對于每一個子集,將其保留作為測試集,而其他 K-1個子集合并作為訓練集。這樣,就可以訓練一次模型并評估其性能

      • 這個過程會被重復 K 次,每次選擇不同的子集作為測試集

    • 性能評估:

      • 在所有 K 次迭代結束后,計算每次測試的結果(如準確率、召回率、F1分數等),然后求這些結果的平均值,以此作為模型性能的估計

  • cross_val_score 是 scikit-learn 中的一個函數,用于執行交叉驗證并返回模型在不同折疊上的得分,cross_val_score 函數的主要參數如下:

    • estimator:一個 scikit-learn 的估算器(estimator),通常是模型類的實例,如 LogisticRegressionSVC

    • X:特征數據,可以是 NumPy 數組、Pandas DataFrame 或其他支持索引的數據結構

    • y:目標數據,通常是一個一維數組,表示每個樣本的標簽

    • cv:交叉驗證的折疊數或交叉驗證生成器。可以是一個整數(表示 K 折交叉驗證的 K 值),也可以是一個交叉驗證生成器(如 StratifiedKFold

  • 優點:

    • 減少過擬合:由于模型是整個數據集既用作訓練集又用作驗證集,這種方法可以幫助減少模型的過擬合傾向

    • 提高模型穩定性:通過多次迭代,K-折交叉驗證能夠提供更穩定的模型性能評估,因為它考慮了不同數據劃分對模型的影響

    • 利用有限數據:當可用的數據量較小的時候,這種方法允許更有效地利用數據來評估模型性能

  • 缺點:

    • 計算成本較高:需要訓練和驗證K次模型,因此需要更多的計算資源和時間

1.3 分層K-折交叉驗證Stratified k-fold

  • 分層 K 折交叉驗證(Stratified K-Fold Cross Validation)是一種改進的交叉驗證方法,它特別適用于類別不平衡的數據集。這種技術確保在每個折疊(fold)中,不同類別的樣本比例保持一致,從而使得每次訓練和測試集的類別分布盡可能與整體數據集的類別分布相同,比如說:原始數據有 3 類,比例為 1:2:1,采用 3 折分層交叉驗證,那么劃分的 3 折中,每一折中的數據類別保持著 1:2:1 的比例

  • StratifiedKFoldscikit-learn庫中的一個類,用于實現分層 K 折交叉驗證(Stratified K-Fold Cross Validation),構造函數 StratifiedKFold 的參數如下:

    • n_splits (int, default=5):定義將數據集分割成多少個折疊(folds)。默認值為 5,意味著將數據集分成 5 個子集

    • shuffle (bool, default=False):如果設置為 True,則在劃分數據之前會對數據進行隨機化。這有助于提高結果的隨機性和公平性

    • random_state (int, RandomState instance or None, default=None):如果 shuffle=True,則可以設置 random_state 來確保每次分割的結果可復現。它可以是一個整數(作為隨機種子)、一個 RandomState 實例,或者 None(表示不使用固定的隨機種子)

  • StratifiedKFold返回的對象中,可以調用skf.split(X, y)方法,skf.split(X, y) 方法是 StratifiedKFold 類的一個重要方法,用于生成訓練集和測試集的索引。該方法返回一個迭代器,每次迭代返回一個元組 (train_index, test_index),分別對應訓練集和測試集的索引,參數和返回值如下:

    • X:特征數據,可以是任何支持索引訪問的數據結構,如 NumPy 數組、列表或 Pandas DataFrame

    • y:標簽數據,通常是一個一維數組,表示每個樣本的類別標簽

    • split 方法返回一個迭代器,每次迭代返回一個元組 (train_index, test_index),其中:

      • train_index:一個 NumPy 數組,包含訓練集樣本的索引

      • test_index:一個 NumPy 數組,包含測試集樣本的索引

案例:

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
?
def stratified_k_fold():iris = load_iris()X = iris.datay = iris.target# 模型對象model = KNeighborsClassifier()# K-折交叉驗證skf = StratifiedKFold(n_splits=5,shuffle=True,random_state=42)# 存儲每次驗證的結果scores = []for train_index,test_index in skf.split(X,y):X_train,X_test = X[train_index],X[test_index]y_train,y_test = y[train_index],y[test_index]# 特征工程:標準化standardScaler = StandardScaler()X_train = standardScaler.fit_transform(X_train)X_test = standardScaler.transform(X_test)# 訓練模型model.fit(X_train,y_train)# 在測試集上評估模型y_predict = model.predict(X_test)score = accuracy_score(y_test, y_predict)scores.append(score)print(scores)

StratifiedKFold 的工作原理

  • 目標:在劃分數據時,保持每一折的 y_trainy_test 中各類別的比例與原始數據 y 相同。

  • 實現方式

    • 每個類別單獨計算分位數,確保每個類別的樣本均勻分布在每一折中。

    • 例如,如果原始數據中類別 A 占 40%,B 占 60%,則每一折的 y_trainy_test 也會保持 4:6 的比例。

2、超參數搜索

  • 超參數是在建立模型時用于控制算法行為的參數。這些參數不能從常規訓練過程中獲得。在對模型進行訓練之前,需要對它們進行賦

  • 超參數的選擇對模型的最終性能有著至關重要的影響。不同的超參數組合可能導致模型在訓練集上過擬合或者欠擬合。因此,尋找合適的超參數組合是一項關鍵任務

  • 常見的方式:

    • 手工調參

    • 網格搜索

    • 隨機搜索

    • 貝葉斯搜索

2.1 網格搜索

  • 網格搜索是一種基本的超參數調優技術。它為網格中指定的所有給定超參數值的每個排列構建模型,評估并選擇最佳模型

  • GridSearchCVscikit-learn 庫提供的一個用于執行網格搜索的類,它可以在給定的超參數網格中自動尋找最佳的超參數組合。GridSearchCV 通過交叉驗證來評估每種超參數組合的效果,并最終返回性能最好的模型

  • GridSearchCV的構造函數的參數如下:

    • estimator:scikit-learn估計器實例

    • param_grid:以參數名稱(str)作為鍵,將參數設置列表嘗試作為值的字典

    • cv:確定交叉驗證切分策略,None 默認5折,integer 設置多少折

  • GridSearchCV的對象中,有如下幾個重要的屬性:

    • best_params_ 最佳參數

    • best_score_ 在訓練集中的準確率

    • best_estimator_ 最佳估計器

    • cv_results_ 交叉驗證結果

案例:

from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
?
?
# 用KNN算法對鳶尾花進行分類,添加網格搜索和交叉驗證
def knn_iris_gridSearchCV():iris = load_iris()X = iris.datay = iris.targetX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)# 特征工程:標準化transfer = StandardScaler()X_train = transfer.fit_transform(X_train)X_test = transfer.transform(X_test)# 創建KNN模型, 不設置n_neighbors的值,后期讓GridSearchCV來設置knn = KNeighborsClassifier()# 加入網格搜索與交叉驗證, GridSearchCV會讓k分別等于1,2,5,7,9,11進行網格搜索model = GridSearchCV(knn, param_grid={"n_neighbors": [1, 2, 5, 7, 9, 11]})# 訓練模型model.fit(X_train, y_train)# 模型評估# 比對真實值和預測值y_predict = model.predict(X_test)print("y_predict:\n", y_predict)print("真實值和預測值的準確率:\n", y_test == y_predict)# 計算預測值和真實值的準確率print("預測值和真實值的準確率:\n", accuracy_score(y_test, y_predict))# 計算準確率score = model.score(X_test, y_test)print("在測試集中的準確率為:\n", score)# 最佳參數:best_params_print("最佳參數:\n", model.best_params_)# 最佳結果:best_score_print("在訓練集中的準確率:\n", model.best_score_)# 最佳估計器:best_estimator_print("最佳估計器:\n", model.best_estimator_)# 交叉驗證的平均得分:cv_results_['mean_test_score']print("交叉驗證的平均得分:\n", model.cv_results_['mean_test_score'])# 超參數的所有組合:model.cv_results_['params']print("超參數的所有組合:\n", model.cv_results_['params'])# 交叉驗證結果:cv_results_print("交叉驗證結果:\n", model.cv_results_)

3、模型保存與加載

  • joblib 是一個 Python 庫,主要用于并行計算和持久化存儲(即保存和加載)大型 NumPy 數組和模型。joblib 提供了 dumpload 兩個函數,用于保存和恢復 Python 對象,特別是機器學習模型

函數名參數說明
joblib.dump()obj:要保存的對象 filename:保存文件的路徑將 Python 對象保存到磁盤文件中。它可以有效地壓縮數據,并且支持并行寫入,適合用于保存大型數據集或模型
joblib.load()filename:保存對象的文件路徑從磁盤文件中恢復之前保存的對象

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93246.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93246.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93246.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue+Django農產品推薦與價格預測系統、雙推薦+機器學習預測+知識圖譜

vueflask農產品推薦與價格預測系統、雙推薦機器學習價格預測知識圖譜文章結尾部分有CSDN官方提供的學長 聯系方式名片 文章結尾部分有CSDN官方提供的學長 聯系方式名片 關注B站,有好處!編號: D010 技術架構: vueflaskmysqlneo4j 核心技術: 基…

數據分析小白訓練營:基于python編程語言的Numpy庫介紹(第三方庫)(下篇)

銜接上篇文章:數據分析小白訓練營:基于python編程語言的Numpy庫介紹(第三方庫)(上篇)(十一)數組的組合核心功能:一、生成基數組np.arange().reshape() 基礎運算功能&…

負載因子(Load Factor) :哈希表(Hash Table)中的一個關鍵性能指標

負載因子(Load Factor) 是哈希表(Hash Table)中的一個關鍵性能指標,用于衡量哈希表的空間利用率和發生哈希沖突的可能性。一:定義負載因子(通常用希臘字母 λ 表示)的計算公式為&…

監控插件SkyWalking(一)原理

一、介紹 1、簡介 SkyWalking 是一個 開源的 APM(Application Performance Monitoring,應用性能監控)和分布式追蹤系統,主要用于監控、追蹤、分析分布式系統中的調用鏈路、性能指標和日志。 它由 Apache 基金會托管,…

【接口自動化測試】---自動化框架pytest

目錄 1、用例運行規則 2、pytest命令參數 3、pytest配置文件 4、前后置 5、斷言 6、參數化---對函數的參數(重要) 7、fixture 7.1、基本用法 7.2、fixture嵌套: 7.3、請求多個fixture: 7.4、yield fixture 7.5、帶參數…

Flink Stream API 源碼走讀 - socketTextStream

概述 本文深入分析了 Flink 中 socketTextStream() 方法的源碼實現,從用戶API調用到最終返回 DataStream 的完整流程。 核心知識點 1. socketTextStream 方法重載鏈 // 用戶調用入口 env.socketTextStream("hostname", 9999)↓ 補充分隔符參數 env.socket…

待辦事項小程序開發

1. 項目規劃功能需求:添加待辦事項標記完成/未完成刪除待辦事項分類或標簽管理(可選)數據持久化(本地存儲)2. 實現功能添加待辦事項:監聽輸入框和按鈕事件,將輸入內容添加到列表。 標記完成/未完…

【C#】Region、Exclude的用法

在 C# 中,Region 和 Exclude 是與圖形編程相關的概念,通常在使用 System.Drawing 命名空間進行 GDI 繪圖時出現。它們主要用于定義和操作二維空間中的區域(幾何區域),常用于窗體裁剪、控件重繪、圖形繪制優化等場景。 …

機器學習 - Kaggle項目實踐(3)Digit Recognizer 手寫數字識別

Digit Recognizer | Kaggle 題面 Digit Recognizer-CNN | Kaggle 下面代碼的kaggle版本 使用CNN進行手寫數字識別 學習到了網絡搭建手法學習率退火數據增廣 提高訓練效果。 使用混淆矩陣 以及對分類出錯概率最大的例子單獨拎出來分析。 最終以99.546%正確率 排在 86/1035 …

新手如何高效運營亞馬遜跨境電商:從傳統SP廣告到DeepBI智能策略

"為什么我的廣告點擊量很高但訂單轉化率卻很低?""如何避免新品期廣告預算被大詞消耗殆盡?""為什么手動調整關鍵詞和出價總是慢市場半拍?""競品ASIN投放到底該怎么做才有效?""有沒有…

【論文閱讀 | CVPR 2024 | UniRGB-IR:通過適配器調優實現可見光-紅外語義任務的統一框架】

論文閱讀 | CVPR 2024 | UniRGB-IR:通過適配器調優實現可見光-紅外語義任務的統一框架?1&&2. 摘要&&引言3.方法3.1 整體架構3.2 多模態特征池3.3 補充特征注入器3.4 適配器調優范式4 實驗4.1 RGB-IR 目標檢測4.2 RGB-IR 語義分割4.3 RGB-IR 顯著目…

Hyperf 百度翻譯接口實現方案

保留 HTML/XML 標簽結構,僅翻譯文本內容,避免破壞富文本格式。采用「HTML 解析 → 文本提取 → 批量翻譯 → 回填」的流程。百度翻譯集成方案:富文本內容翻譯系統 HTML 解析 百度翻譯 API 集成 文件結構 app/ ├── Controller/ │ └──…

字節跳動 VeOmni 框架開源:統一多模態訓練效率飛躍!

資料來源:火山引擎-開發者社區 多模態時代的訓練痛點,終于有了“特效藥” 當大模型從單一語言向文本 圖像 視頻的多模態進化時,算法工程師們的訓練流程卻陷入了 “碎片化困境”: 當業務要同時迭代 DiT、LLM 與 VLM時&#xff0…

配置docker pull走http代理

之前寫了一篇自建Docker鏡像加速器服務的博客,需要用到境外服務器作為代理,但是一般可能沒有境外服務器,只有http代理,所以如果本地使用想走代理可以用以下方式 臨時生效(只對當前終端有效) 設置環境變量…

OpenAI 開源模型 gpt-oss 本地部署詳細教程

OpenAI 最近發布了其首個開源的開放權重模型gpt-oss,這在AI圈引起了巨大的轟動。對于廣大開發者和AI愛好者來說,這意味著我們終于可以在自己的機器上,完全本地化地運行和探索這款強大的模型了。 本教程將一步一步指導你如何在Windows和Linux…

力扣-5.最長回文子串

題目鏈接 5.最長回文子串 class Solution {public String longestPalindrome(String s) {boolean[][] dp new boolean[s.length()][s.length()];int maxLen 0;String str s.substring(0, 1);for (int i 0; i < s.length(); i) {dp[i][i] true;}for (int len 2; len …

Apache Ignite超時管理核心組件解析

這是一個非常關鍵且設計精巧的 定時任務與超時管理組件 —— GridTimeoutProcessor&#xff0c;它是 Apache Ignite 內核中負責 統一調度和處理所有異步超時事件的核心模塊。&#x1f3af; 一、核心職責統一管理所有需要“在某個時間點觸發”的任務或超時邏輯。它相當于 Ignite…

DAY 42 Grad-CAM與Hook函數

知識點回顧回調函數lambda函數hook函數的模塊鉤子和張量鉤子Grad-CAM的示例# 定義一個存儲梯度的列表 conv_gradients []# 定義反向鉤子函數 def backward_hook(module, grad_input, grad_output):# 模塊&#xff1a;當前應用鉤子的模塊# grad_input&#xff1a;模塊輸入的梯度…

基于 NVIDIA 生態的 Dynamo 風格分布式 LLM 推理架構

網羅開發&#xff08;小紅書、快手、視頻號同名&#xff09;大家好&#xff0c;我是 展菲&#xff0c;目前在上市企業從事人工智能項目研發管理工作&#xff0c;平時熱衷于分享各種編程領域的軟硬技能知識以及前沿技術&#xff0c;包括iOS、前端、Harmony OS、Java、Python等方…

《吃透 C++ 類和對象(中):拷貝構造函數與賦值運算符重載深度解析》

&#x1f525;個人主頁&#xff1a;草莓熊Lotso &#x1f3ac;作者簡介&#xff1a;C研發方向學習者 &#x1f4d6;個人專欄&#xff1a; 《C語言》 《數據結構與算法》《C語言刷題集》《Leetcode刷題指南》 ??人生格言&#xff1a;生活是默默的堅持&#xff0c;毅力是永久的…