機器學習入門之KNN算法和交叉驗證與超參數搜索(三)

機器學習入門之KNN算法和交叉驗證與超參數搜索(三)

文章目錄

  • 機器學習入門之KNN算法和交叉驗證與超參數搜索(三)
  • 一、KNN算法-分類
    • 1. 樣本距離判斷
      • 明可夫斯基距離
    • 2. KNN 算法原理
    • 3. KNN 的缺點
    • 4. KNN 的 API
    • 5. 使用 sklearn 實現 KNN 示例
    • 6. 模型保存與加載
  • 二、模型選擇與調優:交叉驗證與超參數搜索
    • 1. 交叉驗證
      • (1) 保留交叉驗證(HoldOut)
      • (2) K-折交叉驗證(K-fold)
      • (3) 分層 K-折交叉驗證(Stratified K-fold)
      • (4) 其他驗證方法
      • (5) API 示例
    • 2. 超參數搜索(網格搜索,Grid Search)
    • 3. sklearn API
    • 4. 示例:鳶尾花分類


一、KNN算法-分類

1. 樣本距離判斷

KNN 算法中,樣本之間的距離是判斷相似性的關鍵。常見的距離度量方式包括:

明可夫斯基距離

  • 歐式距離:明可夫斯基距離的特殊情況,公式為 (\sqrt{\sum_{i=1}^{n}(x_i - y_i)^2})。
  • 曼哈頓距離:明可夫斯基距離的另一種特殊情況,公式為 (\sum_{i=1}^{n}|x_i - y_i|)。

2. KNN 算法原理

K-近鄰算法(K-Nearest Neighbors,簡稱 KNN)是一種基于實例的學習方法。其核心思想是:如果一個樣本在特征空間中的 k 個最相似(最鄰近)樣本中的大多數屬于某個類別,則該樣本也屬于這個類別。例如,假設我們有 10000 個樣本,選擇距離樣本 A 最近的 7 個樣本,其中類別 1 有 2 個,類別 2 有 3 個,類別 3 有 2 個,則樣本 A 被認為屬于類別 2。

3. KNN 的缺點

  • 計算量大:對于大規模數據集,需要計算測試樣本與所有訓練樣本的距離。
  • 維度災難:在高維數據中,距離度量可能變得不那么有意義。
  • 參數選擇:需要選擇合適的 k 值和距離度量方式,這可能需要多次實驗和調整。

4. KNN 的 API

class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, algorithm='auto')
  • 參數
    • n_neighbors:用于 kneighbors 查詢的近鄰數,默認為 5。
    • algorithm:找到近鄰的方式,可選值為 {'auto', 'ball_tree', 'kd_tree', 'brute'},默認為 'auto'
  • 方法
    • fit(X, y):使用 X 作為訓練數據和 y 作為目標數據。
    • predict(X):預測提供的數據,返回預測結果。

5. 使用 sklearn 實現 KNN 示例

以下是一個使用 KNN 算法對鳶尾花進行分類的完整代碼示例:

# 用 KNN 算法對鳶尾花進行分類
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier# 1)獲取數據
iris = load_iris()
print(iris.data.shape)  # (150, 4)
print(iris.feature_names)  # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.target.shape)  # (150,)
print(iris.target)  # [0 0 0 ... 2 2 2]
print(iris.target_names)  # ['setosa' 'versicolor' 'virginica']# 2)劃分數據集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)# 3)特征工程:標準化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN 算法預估器
estimator = KNeighborsClassifier(n_neighbors=7)
estimator.fit(x_train, y_train)# 5)模型評估
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比對真實值和預測值:\n", y_test == y_predict)
score = estimator.score(x_test, y_test)
print("準確率為:\n", score)  # 0.9473684210526315

6. 模型保存與加載

使用 joblib 可以方便地保存和加載模型:

import joblib# 保存模型
joblib.dump(estimator, "my_knn.pkl")# 加載模型
estimator = joblib.load("my_knn.pkl")# 使用模型預測
y_test = estimator.predict([[0.4, 0.2, 0.4, 0.7]])
print(y_test)

以下是整理后的 Markdown 格式內容:


二、模型選擇與調優:交叉驗證與超參數搜索

1. 交叉驗證

交叉驗證是評估模型性能的重要方法,常見的交叉驗證技術包括:

(1) 保留交叉驗證(HoldOut)

  • 原理:將數據集隨機劃分為訓練集和驗證集,通常比例為 70% 訓練集和 30% 驗證集。
  • 優點:簡單易行。
  • 缺點
    • 不適用于不平衡數據集。
    • 一部分數據未參與訓練,可能導致模型性能不佳。

(2) K-折交叉驗證(K-fold)

  • 原理:將數據集劃分為 K 個大小相同的部分,每次使用一個部分作為驗證集,其余部分作為訓練集,重復 K 次。
  • 優點:充分利用數據,模型性能更穩定。
  • 缺點:計算量較大。

(3) 分層 K-折交叉驗證(Stratified K-fold)

  • 原理:在每一折中保持原始數據中各個類別的比例關系,確保每個折疊的類別分布與整體數據一致。
  • 優點:適用于不平衡數據集,驗證結果更可信。

(4) 其他驗證方法

  • 留一交叉驗證:每次只留一個樣本作為驗證集。
  • 蒙特卡羅交叉驗證:隨機劃分訓練集和測試集,多次重復。
  • 時間序列交叉驗證:適用于時間序列數據。

(5) API 示例

from sklearn.model_selection import StratifiedKFoldstrat_k_fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
indexs = strat_k_fold.split(X, y)for train_index, test_index in indexs:X_train, X_test = X[train_index], X[test_index]y_train, y_test = y[train_index], y[test_index]

2. 超參數搜索(網格搜索,Grid Search)

網格搜索是一種自動尋找最佳超參數的方法,通過遍歷所有可能的參數組合來找到最優解。

3. sklearn API

from sklearn.model_selection import GridSearchCVGridSearchCV(estimator, param_grid, cv=5)
  • 參數
    • estimator:模型實例。
    • param_grid:超參數的網格,例如 {"n_neighbors": [1, 3, 5, 7, 9, 11]}
    • cv:交叉驗證的折數,默認為 5。
  • 屬性
    • best_params_:最佳參數。
    • best_score_:最佳模型的交叉驗證分數。
    • best_estimator_:最佳模型實例。
    • cv_results_:交叉驗證結果。

4. 示例:鳶尾花分類

以下是一個使用 KNN 算法對鳶尾花進行分類的完整代碼示例,結合了分層 K-折交叉驗證和網格搜索:

from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler# 加載數據
iris = load_iris()
X = iris.data
y = iris.target# 初始化分層 K-折交叉驗證器
strat_k_fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)# 創建 KNN 分類器實例
knn = KNeighborsClassifier()# 網格搜索與交叉驗證
param_grid = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
grid_search = GridSearchCV(knn, param_grid, cv=strat_k_fold)
grid_search.fit(X, y)# 輸出結果
print("最佳參數:", grid_search.best_params_)  # {'n_neighbors': 3}
print("最佳準確率:", grid_search.best_score_)  # 0.9553030303030303
print("最佳模型:", grid_search.best_estimator_)  # KNeighborsClassifier(n_neighbors=3)

通過分層 K-折交叉驗證和網格搜索,我們可以找到最優的超參數,從而提高模型的性能。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/81128.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/81128.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/81128.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

小剛說C語言刷題—1700請輸出所有的2位數中,含有數字2的整數

1.題目描述 請輸出所有的 2 位數中,含有數字 2 的整數有哪些,每行 1個,按照由小到大輸出。 比如: 12、20、21、22、23… 都是含有數字 2的整數。 輸入 無 輸出 按題意要求由小到大輸出符合條件的整數,每行 1 個。…

在MYSQL中導入cookbook.sql文件

參考資料: GitHub 項目:svetasmirnova/mysqlcookbook CSDN 博客:https://blog.csdn.net/u011868279/category_11645577.html 建庫: mysql> use mysql Reading table information for completion of table and column names …

Scrapy框架下地圖爬蟲的進度監控與優化策略

1. 引言 在互聯網數據采集領域,地圖數據爬取是一項常見但具有挑戰性的任務。由于地圖數據通常具有復雜的結構(如POI點、路徑信息、動態加載等),使用傳統的爬蟲技術可能會遇到效率低下、反爬策略限制、任務進度難以監控等問題。 …

【Win32 API】 lstrcmpA()

作用 比較兩個字符字符串(比較區分大小寫)。 lstrcmp 函數通過從第一個字符開始檢查,若相等,則檢查下一個,直到找到不相等或到達字符串的末尾。 函數 int lstrcmpA(LPCSTR lpString1, LPCSTR lpString2); 參數 lpStr…

代碼隨想錄60期day38

2維背包 #include<bits/stdc.h> using namespace std;int main(){int n,bagweight;cin>>n>>bagweight;vector<int>weight(n,0);vector<int>value(n,0);for(int i 0 ; i <n;i){cin>>weight[i];}for(int j 0;j<n;j){cin>>val…

[模型部署] 1. 模型導出

&#x1f44b; 你好&#xff01;這里有實用干貨與深度分享?? 若有幫助&#xff0c;歡迎&#xff1a;? &#x1f44d; 點贊 | ? 收藏 | &#x1f4ac; 評論 | ? 關注 &#xff0c;解鎖更多精彩&#xff01;? &#x1f4c1; 收藏專欄即可第一時間獲取最新推送&#x1f514;…

mac的Cli為什么輸入python3才有用python --version顯示無效,pyenv入門筆記,如何查看mac自帶的標準庫模塊

根據你的終端輸出&#xff0c;可以得出以下結論&#xff1a; 1. 你的 Mac 當前只有一個 Python 版本 系統默認的 Python 3 位于 /usr/bin/python3&#xff08;這是 macOS 自帶的 Python&#xff09;通過 which python3 確認當前使用的就是系統自帶的 Pythonbrew list python …

Java注解詳解:從入門到實戰應用篇

1. 引言 Java注解&#xff08;Annotation&#xff09;是JDK 5.0引入的一種元數據機制&#xff0c;用于為代碼提供附加信息。它廣泛應用于框架開發、代碼生成、編譯檢查等領域。本文將從基礎到實戰&#xff0c;全面解析Java注解的核心概念和使用場景。 2. 注解基礎概念 2.1 什…

前端方法的總結及記錄

個人簡介 &#x1f468;?&#x1f4bb;?個人主頁&#xff1a; 魔術師 &#x1f4d6;學習方向&#xff1a; 主攻前端方向&#xff0c;正逐漸往全棧發展 &#x1f6b4;個人狀態&#xff1a; 研發工程師&#xff0c;現效力于政務服務網事業 &#x1f1e8;&#x1f1f3;人生格言&…

組件導航 (HMRouter)+flutter項目搭建-混合開發+分欄效果

組件導航 (Navigation)flutter項目搭建 接上一章flutter項目的環境變量配置并運行flutter 1.flutter創建項目并運行 flutter create fluter_hmrouter 進入ohos目錄打開編輯器先自動簽名 編譯項目-生成簽名包 flutter build hap --debug 運行項目 HMRouter搭建安裝 1.安…

城市排水管網流量監測系統解決方案

一、方案背景 隨著工業的不斷發展和城市人口的急劇增加&#xff0c;工業廢水和城市污水的排放量也大量增加。目前&#xff0c;我國已成為世界上污水排放量大、增加速度快的國家之一。然而&#xff0c;總體而言污水處理能力較低&#xff0c;有相當部分未經處理的污水直接或間接排…

TCP/IP 知識體系

TCP/IP 知識體系 一、TCP/IP 定義 全稱&#xff1a;Transmission Control Protocol/Internet Protocol&#xff08;傳輸控制協議/網際協議&#xff09;核心概念&#xff1a; 跨網絡實現信息傳輸的協議簇&#xff08;包含 TCP、IP、FTP、SMTP、UDP 等協議&#xff09;因 TCP 和…

5G行業專網部署費用詳解:投資回報如何最大化?

隨著數字化轉型的加速&#xff0c;5G行業專網作為企業提升生產效率、保障業務安全和實現智能化管理的重要基礎設施&#xff0c;正受到越來越多行業客戶的關注。部署5G專網雖然前期投入較大&#xff0c;但通過合理規劃和技術選擇&#xff0c;能夠實現投資回報的最大化。 在5G行…

網頁工具-OTU/ASV表格物種分類匯總工具

AI輔助下開發了個工具&#xff0c;功能如下&#xff0c;分享給大家&#xff1a; 基于Shiny開發的用戶友好型網頁應用&#xff0c;專為微生物組數據分析設計。該工具能夠自動處理OTU/ASV_taxa表格&#xff08;支持XLS/XLSX/TSV/CSV格式&#xff09;&#xff0c;通過調用QIIME1&a…

【超分辨率專題】一種考量視頻編碼比特率優化能力的超分辨率基準

這是一個Benchmark&#xff0c;超分辨率視頻編碼&#xff08;2024&#xff09; 專題介紹一、研究背景二、相關工作2.1 SR的發展2.2 SR benchmark的發展 三、Benchmark細節3.1 數據集制作3.2 模型選擇3.3 編解碼器和壓縮標準選擇3.4 Benchmark pipeline3.5 質量評估和主觀評價研…

保姆教程-----安裝MySQL全過程

1.電腦從未安裝過mysql的&#xff0c;先找到mysql官網&#xff1a;MySQL :: Download MySQL Community Server 然后下載完成后&#xff0c;找到文件&#xff0c;然后雙擊打開 2. 選擇安裝的產品和功能 依次點開“MySQL Servers”、“MySQL Servers”、“MySQL Servers 5.7”、…

【React中函數組件和類組件區別】

在 React 中,函數組件和類組件是兩種構建組件的方式,它們在多個方面存在區別,以下詳細介紹: 1. 語法和定義 類組件:使用 ES6 的類(class)語法定義,繼承自 React.Component。需要通過 this.props 來訪問傳遞給組件的屬性(props),并且通常要實現 render 方法返回 JSX…

[基礎] HPOP、SGP4與SDP4軌道傳播模型深度解析與對比

HPOP、SGP4與SDP4軌道傳播模型深度解析與對比 文章目錄 HPOP、SGP4與SDP4軌道傳播模型深度解析與對比第一章 引言第二章 模型基礎理論2.1 歷史演進脈絡2.2 動力學方程統一框架 第三章 數學推導與攝動機制3.1 SGP4核心推導3.1.1 J?攝動解析解3.1.2 大氣阻力建模改進 3.2 SDP4深…

搭建運行若依微服務版本ruoyi-cloud最新教程

搭建運行若依微服務版本ruoyi-cloud 一、環境準備 JDK > 1.8MySQL > 5.7Maven > 3.0Node > 12Redis > 3 二、后端 2.1數據庫準備 在navicat上創建數據庫ry-seata、ry-config、ry-cloud運行SQL文件ry_20250425.sql、ry_config_20250224.sql、ry_seata_2021012…

Google I/O 2025 觀看攻略一鍵收藏,開啟技術探索之旅!

AIGC開放社區https://lerhk.xetlk.com/sl/1SAwVJ創業邦https://weibo.com/1649252577/PrNjioJ7XCSDNhttps://live.csdn.net/room/csdnnews/OOFSCy2g/channel/collectiondetail?sid2941619DONEWShttps://www.donews.com/live/detail/958.html鳳凰科技https://flive.ifeng.com/l…