學習筆記(29):訓練集與測試集劃分詳解:train_test_split 函數深度解析

學習筆記(29):訓練集與測試集劃分詳解:train_test_split 函數深度解析

一、為什么需要劃分訓練集和測試集?

在機器學習中,模型需要經歷兩個核心階段:

  1. 訓練階段:用訓練集數據學習特征與目標值的映射關系(如線性回歸的權重)。
  2. 測試階段:用測試集評估模型在未見過的數據上的表現,避免 “過擬合”(模型只記住訓練數據的噪聲,無法泛化到新數據)。

類比場景:學生通過 “練習題”(訓練集)學習知識,再通過 “考試題”(測試集)檢驗真實水平。

二、train_test_split?函數的核心參數與邏輯
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42
)
1.?輸入參數解析
  • X_scaled:特征矩陣(已標準化的面積、房齡等特征)。
  • y:目標變量(房價)。
  • test_size=0.2:測試集占總數據的比例(20%),也可設為整數(如?test_size=20?表示取 20 個樣本)。
  • random_state=42:隨機種子,確保每次劃分結果一致(與?np.random.seed(42)?作用類似)。
2.?劃分邏輯
  • 隨機抽樣:按?test_size?比例從原始數據中隨機抽取樣本作為測試集,剩余作為訓練集。
  • 數據對齊:確保?X?和?y?的樣本順序一一對應(如第 i 個特征向量對應第 i 個房價標簽)。
三、劃分結果的維度與含義

假設原始數據有 100 個樣本(n_samples=100):

  • 訓練集:80 個樣本(X_train.shape=(80, 2),?y_train.shape=(80,)),用于模型學習。
  • 測試集:20 個樣本(X_test.shape=(20, 2),?y_test.shape=(20,)),用于評估模型泛化能力。
四、關鍵參數深度解析
1.?test_size:平衡訓練與測試的樣本量
  • 取值建議
    • 小數據集(<1000 樣本):常用?test_size=0.2~0.3(20%-30% 作為測試集)。
    • 大數據集(>10000 樣本):可設?test_size=0.1?甚至更低(因少量樣本已足夠評估)。
  • 極端案例:若?test_size=1.0,則所有數據都是測試集,無訓練集;若?test_size=0,則全是訓練集。
2.?random_state:確保可復現的 “隨機” 劃分
  • 作用:固定隨機種子后,每次運行代碼時,訓練集和測試集的樣本索引完全相同。
  • 示例對比
    • 不設置?random_state:每次劃分結果不同,導致模型評估指標波動。
    • 設置?random_state=42:多次運行代碼,劃分結果一致,便于對比不同模型效果。
3.?shuffle=True(默認參數):打亂數據順序
  • 為什么需要打亂?
    若數據按順序排列(如前 50 個是小戶型,后 50 個是大戶型),不打亂會導致訓練集和測試集樣本分布不均(如測試集全是大戶型)。
  • 參數設置train_test_split?默認為?shuffle=True,即先打亂數據再劃分;若數據已隨機排列,可設?shuffle=False
五、進階應用:分層抽樣(Stratified Sampling)

當目標變量是分類變量(如二分類 “是否違約”)時,普通隨機劃分可能導致訓練 / 測試集的類別比例失衡(如測試集全是 “違約” 樣本)。此時需用?StratifiedShuffleSplit?實現分層抽樣:

from sklearn.model_selection import StratifiedShuffleSplit# 4. 使用分層抽樣(確保類別比例平衡)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_binary):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y_binary[train_idx], y_binary[test_idx]print("===== 分類模型結果 =====")
print(f"原始數據類別比例:{np.bincount(y_binary)/len(y_binary)}")
print(f"訓練集類別比例:{np.bincount(y_train)/len(y_train)}")
print(f"測試集類別比例:{np.bincount(y_test)/len(y_test)}")
六、實戰誤區與注意事項
  1. 禁止在測試集上訓練:測試集只能用于評估,若根據測試集結果調整模型參數(如調優正則化系數),本質上是 “偷看答案”,會導致評估結果過于樂觀。
  2. 數據標準化的順序
    • 正確流程:先劃分訓練測試集,再對訓練集擬合標準化器(scaler.fit(X_train)),最后用訓練集的標準化參數轉換測試集(scaler.transform(X_test))。
    • 錯誤操作:對全量數據標準化后再劃分,會導致測試集 “偷看到” 全量數據的統計特征,違反 “未知數據” 假設。
  3. 多輪劃分與交叉驗證:當數據量較小時,可使用 K 折交叉驗證(如 10 折),將數據分成 10 份,每次用 9 份訓練、1 份測試,重復 10 次取平均,減少單次劃分的隨機性誤差。
七、總結:劃分訓練測試集的核心原則
  1. 獨立性:測試集數據必須是模型未見過的,模擬真實應用場景。
  2. 代表性:訓練集和測試集的樣本分布應盡可能一致(如特征取值范圍、類別比例)。
  3. 可復現性:通過設置隨機種子,確保實驗結果可重復驗證。

通過合理劃分訓練集與測試集,你可以更準確地評估模型的實際能力,避免被 “過擬合” 的假象誤導 —— 這是機器學習工程化中至關重要的一步!

二分類問題(房價是否高于中位數)-全代碼

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score# 配置中文顯示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 生成模擬數據(假設房價與面積、房齡的關系)
np.random.seed(42)
n_samples = 100
# 面積(平方米),房齡(年)
X = np.random.rand(n_samples, 2) * 100
X[:, 0] = X[:, 0]  # 面積范圍:0-100
X[:, 1] = X[:, 1]  # 房齡范圍:0-100# 真實房價 = 5000*面積 + 1000*房齡 + 隨機噪聲(模擬真實場景)
y = 5000 * X[:, 0] + 1000 * X[:, 1] + np.random.randn(n_samples) * 10000# 2. 將連續的房價y轉換為分類標簽(例如分為低、中、高3個類別)
y_category = pd.qcut(y, q=3, labels=[0, 1, 2])  # 使用pandas的qcut進行分位數切割# 3. 數據預處理:標準化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 4. 使用分層抽樣
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_category):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y[train_idx], y[test_idx]  # 注意:這里仍然使用原始的連續房價作為目標_# 確保訓練集和測試集的類別比例與原始數據一致
print(f"原始數據類別比例:{np.bincount(y_category)/len(y_category)}")
print(f"訓練集類別比例:{np.bincount(y_category[train_idx])/len(y_category[train_idx])}")
print(f"測試集類別比例:{np.bincount(y_category[test_idx])/len(y_category[test_idx])}")# 后續回歸模型訓練和評估代碼保持不變
model = LinearRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 評估模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"均方誤差: {mse:.2f}")
print(f"決定系數R2: {r2:.2f}")

打印:

原始數據類別比例:[0.34 0.32 0.34]
訓練集類別比例:[0.3375 0.325 ?0.3375]
測試集類別比例:[0.35 0.3 ?0.35]
均方誤差: 101112597.45
決定系數R2: 1.00

代碼解析:
核心步驟解析
  1. 數據準備與二分類轉換

    • 生成與方案 1 相同的模擬數據(面積、房齡 → 房價)。
    • 將連續的房價y轉換為二分類標簽:
threshold = np.median(y)  # 使用中位數作為閾值
y_binary = (y > threshold).astype(int)  # 0=低于中位數,1=高于中位數
  1. 這樣做的目的是將 “預測具體房價” 轉化為 “判斷房價高低”。

分層抽樣(Stratified Sampling)

  • 使用StratifiedShuffleSplit確保訓練集和測試集中高低房價的比例與原始數據一致:
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_category):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y[train_idx], y[test_idx]  # 注意:這里仍然使用原始的連續房價作為目標_

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/87357.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/87357.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/87357.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【全網唯一】自動化編輯器 Windows版純本地離線文字識別插件

目的 自動化編輯器超輕量級RPA工具&#xff0c;零代碼制作RPA自動化任務&#xff0c;解放雙手&#xff0c;釋放雙眼&#xff0c;輕松玩游戲&#xff0c;刷任務。本篇文章主要講解下自動化編輯器的TomatoOCR純本地離線文字識別Windows版插件如何使用和集成。 準備工作 1、下載自…

GitHub 2FA綁定

GitHub 2FA綁定 作為全球最大的代碼托管平臺&#xff0c;GitHub對賬號安全的重視程度不斷提升——自2023年3月起&#xff0c;GitHub已要求所有在GitHub.com上貢獻代碼的用戶必須啟用雙因素身份驗證&#xff08;2FA&#xff09;。如果你是符合條件的用戶&#xff0c;會收到一封…

pytest fixture基礎大全詳解

一、介紹 作用 fixture主要有兩個作用&#xff1a; 復用測試數據和環境&#xff0c;可以減少重復的代碼&#xff1b;可以在測試用例運行前和運行后設置和清理資源&#xff0c;避免對測試結果產生影響&#xff0c;同時也可以提高測試用例的運行效率。 優勢 pytest框架的fix…

Unity知識點-Renderer常用材質變量

本篇總結了Unity中renderer的3種常用的材質相關的變量&#xff1a;renderer.material,renderer.sharedMaterial,renderer.MaterialPropertyBlock。以及三者對SRPBatcher的影響。 一.介紹及對比 1.概念介紹 1.material 定義&#xff1a;material 是Render組件&#xff08;如…

【算法】??如何判斷時間復雜度?

文章目錄 1. 什么是時間復雜度&#xff1f;為什么需要時間復雜度&#xff1f; 2. 常見時間復雜度對比3. 如何分析時間復雜度&#xff1f;&#xff08;Java版&#xff09;&#x1f539; 步驟1&#xff1a;找出基本操作&#x1f539; 步驟2&#xff1a;分析循環結構&#xff08;1…

MySQL使用C語言連接

文章目錄 版本查看以及編譯mysql接口介紹初始化鏈接數據庫下發mysql命令mysql_query獲取執行結果mysql_store_result獲取結果行數mysql_num_rows獲取結果列數mysql_num_fields獲取列名mysql_fetch_fields獲取結果內容mysql_fetch_row關閉mysql鏈接mysql_closeC語言操作mysql查看…

堅持每日Codeforces三題挑戰:Day 7 - 題目詳解(2025-06-11,難度:1200,1300,1500)

每天堅持寫三道題第七天&#xff1a; Problem - A - Codeforces 1200 Problem - B - Codeforces 1300 Problem - A - Codeforces 1500 目錄 題目一: 題目大意: 解題思路: 代碼(C): 題目二: 題目大意: 解題思路: 代碼(C): 題目三: 題目大意: 解題思路: 代碼(C): …

洛谷 P4305:[JLOI2011] 不重復數字 ← unordered_set

【題目來源】 https://www.luogu.com.cn/problem/P4305 【題目描述】 給定 n 個數&#xff0c;要求把其中重復的去掉&#xff0c;只保留第一次出現的數。 【輸入格式】 第一行一個整數 T&#xff0c;表示數據組數。 對于每組數據&#xff0c;第一行一個整數 n。第二行 n 個數…

STM32固件升級設計——SPIFLASH模擬U盤升級固件

目錄 概述 一、功能描述 1、BootLoader部分&#xff1a; 2、APP部分&#xff1a; 二、BootLoader程序制作 1、分區定義 2、 主函數 3、配置USB 4、配置fatfs文件系統 5、程序跳轉 三、APP程序制作 四、工程配置&#xff08;默認KEIL5&#xff09; 五、運行測試 六…

解鎖阿里云日志服務SLS:云時代的日志管理利器

引言&#xff1a;開啟日志管理新篇 在云計算時代&#xff0c;數據如同企業的血液&#xff0c;源源不斷地產生并流動。從用戶的每一次點擊&#xff0c;到系統后臺的每一個操作&#xff0c;數據都在記錄著企業運營的軌跡。而在這些海量的數據中&#xff0c;日志數據占據著至關重…

Keye-VL-8B-Preview:由快手 Kwai Keye 團隊精心打造的尖端多模態大語言模型

&#x1f525; News 2025.06.26 &#x1f31f; 我們非常自豪地推出Kwai Keye-VL&#xff0c;這是快手Kwai Keye團隊精心打造的前沿多模態大語言模型。作為快手先進技術生態中的核心AI產品&#xff0c;Keye在視頻理解、視覺感知和推理任務方面表現卓越&#xff0c;樹立了新的性…

Web前端之JavaScript實現圖片圓環、圓環元素根據角度指向圓心、translate、rotate

MENU 前言效果HtmlStyleJavaScript 前言 代碼段創建了一個由6個WiFi圖標組成的圓形排列&#xff0c;每個圖標均勻分布在圓周上。 效果 Html 代碼 <div class"ring"><div class"item"><img class"img" src"../image/icon/W…

1 Studying《Computer Vision: Algorithms and Applications 2nd Edition》11-15

目錄 Chapter 11 Structure from motion and SLAM 11.1 幾何內稟校準 11.2 姿態估計 11.3 從運動中獲得的雙幀結構 11.4 從運動中提取多幀結構 11.5 同步定位與建圖&#xff08;SLAM&#xff09; 11.6 額外閱讀 Chapter 12 Depth estimation 12.1 極點幾何 12.2 稀疏…

phpstudy 可以按照mysql 數據庫

phpstudy 可以按照mysql 數據庫 PHPStudy&#xff08;小皮面板&#xff09;是一款專為開發者設計的集成環境工具&#xff0c;涵蓋服務器配置、開發環境搭建、網站部署等多項功能。以下是其核心用途及優勢的詳細解析&#xff1a; 一、開發環境快速搭建 一站式集成環境集成Apa…

Python搭建HTTP服務,如何用內網穿透快速遠程訪問?

Python的內置HTTP服務模塊是開發者工具箱中的瑞士軍刀&#xff0c;只需一行命令即可啟動一個功能完備的Web服務器。無論是前端工程師調試頁面、數據科學家共享Jupyter Notebook&#xff0c;還是后端開發者快速驗證API原型&#xff0c;Python HTTP服務都能以零配置的方式滿足需求…

撥號音識別系統的設計與實現

撥號音識別系統的設計與實現 摘要 本文設計并實現了一個完整的撥號音識別系統&#xff0c;該系統能夠自動識別電話號碼中的數字。系統基于雙音多頻(DTMF)技術原理&#xff0c;使用MATLAB開發&#xff0c;包含GUI界面展示處理過程和結果。系統支持從麥克風實時錄音或加載音頻文…

數據結構-樹詳解

樹簡介 樹存儲和組織具有層級結構的數據&#xff08;例&#xff1a;公司職級&#xff09;&#xff0c;就是一顆倒立生長的樹。 屬性&#xff1a; 遞歸n個節點有n-1個連接節點x的深度&#xff1a;節點x到根節點的最長路徑節點x的高度&#xff1a;節點x到葉子節點的最長路徑 …

【安卓Sensor框架-2】應用注冊Sensor 流程

注冊傳感器的核心流程為如下&#xff1a;應用層調用 SensorManager注冊傳感器&#xff0c;framework層創建SensorEventQueue對象&#xff08;事件隊列&#xff09;&#xff0c;通過JNI調用Native方法nativeEnableSensor()&#xff1b;SensorService服務端createEventQueue()創建…

新版本沒有docker-desktop-data分發 | docker desktop 鏡像遷移

在新版本的docker desktop中&#xff08;如4.42版本&#xff09;&#xff0c;鏡像遷移只需要更改路徑即可。如下&#xff1a; 打開docker desktop的設置&#xff08;圖1&#xff09;&#xff0c;將圖2的原來的地址C:\Users\用戶\AppData\Local\Docker\wsl修改為你想要的空文件…

EtherCAT SOEM源碼分析 - ec_init

ec_init SOEM主站一切開始的地方始于ec_init, 它是EtherCAT主站初始化的入口。初始化SOEM 主站&#xff0c;并綁定到socket到ifname。 /** Initialise lib in single NIC mode* param[in] ifname Dev name, f.e. "eth0"* return >0 if OK* see ecx_init*/ in…