【目標檢測】模型驗證:K-Fold 交叉驗證

K-Fold 交叉驗證

  • 1、引言
    • 1.1 K 折交叉驗證概述
  • 2、配置
    • 2.1 數據集
    • 2.2 安裝包
  • 3、 實戰
    • 3.1 生成物體檢測數據集的特征向量
    • 3.2 K 折數據集拆分
    • 3.3 保存記錄
    • 3.4 使用 K 折數據分割訓練YOLO
  • 4、總結

1、引言

我們將利用YOLO 檢測格式和關鍵的Python 庫(如 sklearn、pandas 和 PyYaml),完成必要的設置、生成特征向量的過程以及 K-Fold 數據集拆分的執行。

1.1 K 折交叉驗證概述

無論你的項目涉及水果檢測數據集還是自定義數據源,都可以使用 K 折交叉驗證,
以提高項目的可靠性和穩健性。

書說簡短,閑言少敘,咱進入正題
在這里插入圖片描述

2、配置

2.1 數據集

該數據集共包含 8479 幅圖像。
它包括 6 個類別標簽,每個標簽的實例總數如下:

類別計數
蘋果7049
葡萄7202
菠蘿1613
橙色15549
香蕉3536
西瓜1976

2.2 安裝包

必要的Python 軟件包包括

  • ultralytics
  • sklearn
  • pandas
  • pyyaml

這次實例中,我們使用 k=5 折疊次數

3、 實戰

3.1 生成物體檢測數據集的特征向量

具體步驟如下:

  • 1、首先創建一個新的 demo.py Python 文件來執行下面的步驟。

  • 2、繼續檢索數據集的所有標簽文件。

from pathlib import Pathdataset_path = Path("./Fruit-detection")  # replace with 'path/to/dataset' for your custom data
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # all data in 'labels'
  • 3、現在,讀取數據集 YAML 文件的內容并提取類標簽的索引。
yaml_file = "path/to/data.yaml"  # your data YAML with data directories and names dictionary
with open(yaml_file, "r", encoding="utf8") as y:classes = yaml.safe_load(y)["names"]
cls_idx = sorted(classes.keys())
  • 4、初始化一個空的 pandas DataFrame.
import pandas as pdindex = [label.stem for label in labels]  # uses base filename as ID (no extension)
labels_df = pd.DataFrame([], columns=cls_idx, index=index)
  • 5、計算注釋文件中每個類別標簽的實例數。
from collections import Counterfor label in labels:lbl_counter = Counter()with open(label, "r") as lf:lines = lf.readlines()for line in lines:# classes for YOLO label uses integer at first position of each linelbl_counter[int(line.split(" ")[0])] += 1labels_df.loc[label.stem] = lbl_counterlabels_df = labels_df.fillna(0.0)  # replace `nan` values with `0.0`
  • 6、以下是已填充 DataFrame 的示例視圖:
                                                       0    1    2    3    4    5
'0000a16e4b057580_jpg.rf.00ab48988370f64f5ca8ea4...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.7e6dce029fb67f01eb19aa7...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.bc4d31cdcbe229dd022957a...'  0.0  0.0  0.0  0.0  0.0  7.0
'00020ebf74c4881c_jpg.rf.508192a0a97aa6c4a3b6882...'  0.0  0.0  0.0  1.0  0.0  0.0
'00020ebf74c4881c_jpg.rf.5af192a2254c8ecc4188a25...'  0.0  0.0  0.0  1.0  0.0  0.0...                                                  ...  ...  ...  ...  ...  ...
'ff4cd45896de38be_jpg.rf.c4b5e967ca10c7ced3b9e97...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff4cd45896de38be_jpg.rf.ea4c1d37d2884b3e3cbce08...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff5fd9c3c624b7dc_jpg.rf.bb519feaa36fc4bf630a033...'  1.0  0.0  0.0  0.0  0.0  0.0
'ff5fd9c3c624b7dc_jpg.rf.f0751c9c3aa4519ea3c9d6a...'  1.0  0.0  0.0  0.0  0.0  0.0
'fffe28b31f2a70d4_jpg.rf.7ea16bd637ba0711c53b540...'  0.0  6.0  0.0  0.0  0.0  0.0

解析

  • 行是標簽文件的索引,每個標簽文件對應數據集中的一幅圖像,列則對應類標簽索引。
  • 每一行代表一個偽特征向量,其中包含數據集中每個類標簽的計數。
  • 這種數據結構可以將 K 折交叉驗證應用于對象檢測數據集。

3.2 K 折數據集拆分

  • 1、使用 KFold 從 sklearn.model_selection 以產生 k 對數據集進行分割。

    • 敲黑板:
      • 設置 shuffle=True 確保了分班中班級的隨機分布。
      • 通過設置 random_state=M 其中 M 是一個選定的整數,這樣就可以得到可重復的結果。
from sklearn.model_selection import KFoldksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # setting random_state for repeatable resultskfolds = list(kf.split(labels_df))
  • 2、數據集現已分為 k 折疊,每個折疊都有一個 train 和 val 指數。我們將構建一個 DataFrame 來更清晰地顯示這些結果。
folds = [f"split_{n}" for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=index, columns=folds)for i, (train, val) in enumerate(kfolds, start=1):folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"
  • 3、將計算每個褶皺的類別標簽分布,并將其作為褶皺中出現的類別的比率。
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# To avoid division by zero, we add a small value (1E-7) to the denominatorratio = val_totals / (train_totals + 1e-7)fold_lbl_distrb.loc[f"split_{n}"] = ratio
最理想的情況是,每次分割和不同類別的所有類別比率都相當相似。不過,這取決于數據集的具體情況。
  • 4、為每個分割創建目錄和數據集 YAML 文件。
import datetimesupported_extensions = [".jpg", ".jpeg", ".png"]# Initialize an empty list to store image file paths
images = []# Loop through supported extensions and gather image files
for ext in supported_extensions:images.extend(sorted((dataset_path / "images").rglob(f"*{ext}")))# Create the necessary directories and dataset YAML files (unchanged)
save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)
ds_yamls = []for split in folds_df.columns:# Create directoriessplit_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)# Create dataset YAML filesdataset_yaml = split_dir / f"{split}_dataset.yaml"ds_yamls.append(dataset_yaml)with open(dataset_yaml, "w") as ds_y:yaml.safe_dump({"path": split_dir.as_posix(),"train": "train","val": "val","names": classes,},ds_y,)
  • 5、最后,將圖像和標簽復制到每個分割的相應目錄("train "或 “val”)中。
import shutilfor image, label in zip(images, labels):for split, k_split in folds_df.loc[image.stem].items():# Destination directoryimg_to_path = save_path / split / k_split / "images"lbl_to_path = save_path / split / k_split / "labels"# Copy image and label files to new directory (SamefileError if file already exists)shutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)

3.3 保存記錄

將 K 折分割和標簽分布數據框的記錄保存為 CSV 文件。

folds_df.to_csv(save_path / "kfold_datasplit.csv")
fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")

3.4 使用 K 折數據分割訓練YOLO

  • 首先,加載YOLO 模型。
from ultralytics import YOLOweights_path = "path/to/weights.pt"
model = YOLO(weights_path, task="detect")
  • 其次,遍歷數據集 YAML 文件以運行訓練。結果將保存到由 project 和 name 參數。默認情況下,該目錄為 “exp/runs#”,其中 # 為整數索引。
results = {}# Define your additional arguments here
batch = 16
project = "kfold_demo"
epochs = 100for k in range(ksplit):dataset_yaml = ds_yamls[k]model = YOLO(weights_path, task="detect")model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)  # include any train argumentsresults[k] = model.metrics  # save output metrics for further analysis

4、總結

這篇小魚使用了 K 折交叉驗證來訓練YOLO 物體檢測模型的過程。

還創建報告 DataFrames 的程序,以可視化數據拆分和標簽在這些拆分中的分布,清楚地了解訓練集和驗證集的結構。

此外,還保存了記錄,這在大型項目或排除模型性能故障時尤為有用。

最后,在一個循環中使用每個拆分來執行實際的模型訓練,保存訓練結果,以便進一步分析和比較。

這種 K 折交叉驗證技術是充分利用可用數據的一種穩健方法,有助于確保模型在不同數據子集中的性能是可靠和一致的。這將產生一個更具通用性和可靠性的模型,從而減少對特定數據模式的過度擬合。

我是小魚

  • CSDN 博客專家
  • 阿里云 專家博主
  • 51CTO博客專家
  • 企業認證金牌面試官
  • 多個名企認證&特邀講師等
  • 名企簽約職場面試培訓、職場規劃師
  • 多個國內主流技術社區的認證專家博主
  • 多款主流產品(阿里云等)評測一等獎獲得者

關注小魚,學習【人工智能&大模型】/【深度學習&機器學習】領域最新最全的知識。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/68102.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/68102.shtml
英文地址,請注明出處:http://en.pswp.cn/web/68102.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Android studio ternimal 中gradle 指令失效(gradle環境變量未配置)

默認gradle路徑:C:\Users\ylwj.gradle\wrapper\dists\gradle-8.10.2-bin\a04bxjujx95o3nb99gddekhwo\gradle-8.10.2\bin 環境變量-系統環境變量-雙擊path-配置上即可-注意重啟studio才會生效

Axure大屏可視化動態交互設計:解鎖數據魅力,引領決策新風尚

可視化組件/模板預覽:https://8dge09.axshare.com 一、大屏可視化技術概覽 在數據驅動決策的時代,大屏可視化技術憑借直觀、動態的展示方式,已成為眾多行業提升管理效率和優化決策過程的關鍵工具。它能夠將復雜的數據轉化為易于理解的圖形和…

Resnet 改進:嘗試在不同位置加入Transform模塊

目錄 1. TransformerBlock 2. resnet 3. 替換部分卷積層 4. 在特定位置插入Transformer模塊 5. 使用Transformer全局特征提取器 6. 其他 Tips:融入模塊后的網絡經過測試,可以直接使用,設置好輸入和輸出的圖片維度即可 1. TransformerBlock TransformerBlock是Transfo…

PromptSource和LangChain哪個更好

目錄 1. 設計目標與定位 PromptSource LangChain 2. 功能對比 3. 優缺點分析 PromptSource LangChain 4. 如何選擇? 5. 總結 PromptSource 和 LangChain 是兩個在自然語言處理(NLP)領域非常有用的工具,但它們的設計目標和…

MySQL調優02 - SQL語句的優化

SQL語句的優化 文章目錄 SQL語句的優化一:SQL優化的小技巧1:編寫SQL時的注意點1.1:查詢時盡量不要使用*1.2:連表查詢時盡量不要關聯太多表1.3:多表查詢時一定要以小驅大1.4:like不要使用左模糊或者全模糊1.…

langchain教程-12.Agent/工具定義/Agent調用工具/Agentic RAG

前言 該系列教程的代碼: https://github.com/shar-pen/Langchain-MiniTutorial 我主要參考 langchain 官方教程, 有選擇性的記錄了一下學習內容 這是教程清單 1.初試langchain2.prompt3.OutputParser/輸出解析4.model/vllm模型部署和langchain調用5.DocumentLoader/多種文檔…

如何實現網頁不用刷新也能更新

要實現用戶在網頁上不用刷新也能到下一題,可以使用 前端和后端交互的技術,比如 AJAX(Asynchronous JavaScript and XML)、Fetch API 或 WebSocket 來實現局部頁面更新。以下是一個實現思路: 1. 使用前端 AJAX 或 Fetch…

在ubuntu22.04上先部署docker,再編譯安裝kamailio,附詳細操作流程及docker和makailio的版本號

以下是在Ubuntu 22.04上部署Docker并編譯安裝Kamailio的詳細操作流程,包含版本號信息: 一、部署Docker(版本:24.0.7) 更新系統包 sudo apt update && sudo apt upgrade -y安裝依賴工具 sudo apt install -y ap…

大模型中提到的超參數是什么

在大模型中提到的超參數是指在模型訓練之前需要手動設置的參數,這些參數決定了模型的訓練過程和最終性能。超參數與模型內部通過訓練獲得的參數(如權重和偏置)不同,它們通常不會通過訓練自動學習,而是需要開發者根據任…

位運算及常用技巧

涉及位運算的運算符如下表所示: 位運算的運算律: 負數的位運算 首先,我們要知道,在計算機中,運算是使用的二進制補碼,而正數的補碼是它本身,負數的補碼則是符號位不變,其余按位取反…

組合總和III(力扣216)

這道題在回溯的基礎上加入了剪枝操作。回溯方面我就不過多贅述,與組合(力扣77)-CSDN博客 大差不差,主要講解一下剪枝(下面的代碼也有回溯操作的詳細注釋)。我們可以發現,如果我們遞歸到后面,可能集合過小,無法滿足題目…

hot100(8)

71.10. 正則表達式匹配 - 力扣(LeetCode) 動態規劃 題解:10. 正則表達式匹配題解 - 力扣(LeetCode) 72.5. 最長回文子串 - 力扣(LeetCode) 動態規劃 1.dp數組及下標含義 dp[i][j] : 下標i到…

二進制/源碼編譯安裝httpd 2.4,提供系統服務管理腳本并測試

方法一:使用 systemd 服務文件 安裝所需依賴 yum install gcc make apr-devel apr-util-devel pcre-devel 1.下載源碼包 wget http://archive.apache.org/dist/httpd/httpd-2.4.62.tar.gz 2.解壓源碼 tar -xf httpd-2.4.62.tar.gz cd httpd-2.4.62 3.編譯安裝 指定…

Java 中 LinkedList 的底層源碼

在 Java 的集合框架中,LinkedList是一個獨特且常用的成員。它基于雙向鏈表實現,與數組結構的集合類如ArrayList有著顯著差異。深入探究LinkedList的底層源碼,有助于我們更好地理解其工作原理和性能特點,以便在實際開發中做出更合適…

Level2逐筆成交逐筆委托數據分享下載:20250127

Level2逐筆成交逐筆委托數據分享下載 采用Level2逐筆成交與逐筆委托的毫秒級數據,可以揭露眾多有用信息,如莊家策略、偽裝交易,讓所有交易行為透明化。這對于交易高手的策略分析極為有用,對人工智能領域的機器學習也極為合適&…

金蝶云星空k3cloud webapi報“java.lang.Class cannot be cast to java.lang.String”的錯誤

最近在對接金蝶云星空k3cloud webapi時,報一個莫名其妙的轉換異常,具體如下: 同步部門異常! ERP接口登錄異常:java.lang.Class cannot be cast to java.lang.String at com.jkwms.k3cloudSyn.service.basics.DeptK3CloudService.…

【Android】jni開發之導入opencv和libyuv來進行圖像處理

做視頻圖像處理時需要對其進行水印的添加,放在應用層調用工具性能方面不太滿意,于是當下采用opencvlibyuv方法進行處理。 對于Android的jni開發不是很懂,我的需求是導入opencv方便在cpp中調用,但目前找到的教程都是把opencv作為模…

【MySQL】centos 7 忘記數據庫密碼

vim /etc/my.cnf文件; 在[mysqld]后添加skip-grant-tables(登錄時跳過權限檢查) 重啟MySQL服務:sudo systemctl restart mysqld 登錄mysql,輸入mysql –uroot –p;直接回車(Enter) 輸…

國產編輯器EverEdit - 自定義標記使用詳解

1 自定義標記使用詳解 1.1 應用場景 當閱讀日志等文件,用于調試或者檢查問題時,往往日志中會有很多關鍵性的單詞,比如:ERROR, FATAL等,但由于文本模式對這些關鍵詞并沒有突出顯示,造成檢查問題時&#xff…

Golang 并發機制-6:掌握優雅的錯誤處理藝術

并發編程可能是提高軟件系統效率和響應能力的一種強有力的技術。它允許多個工作負載同時運行,充分利用現代多核cpu。然而,巨大的能力帶來巨大的責任,良好的錯誤管理是并發編程的主要任務之一。 并發代碼的復雜性 并發編程增加了順序程序所不…