dataset.py
ultralytics\data\dataset.py
目錄
dataset.py
1.所需的庫和模塊
2.class YOLODataset(BaseDataset):?
3.class YOLOMultiModalDataset(YOLODataset):?
4.class GroundingDataset(YOLODataset):?
5.class YOLOConcatDataset(ConcatDataset):?
6.class SemanticDataset(BaseDataset):?
7.class ClassificationDataset:?
1.所需的庫和模塊
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/licenseimport json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Pathimport cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDatasetfrom ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18from .augment import (Compose,Format,Instances,LetterBox,RandomLoadText,classify_augmentations,classify_transforms,v8_transforms,
)
from .base import BaseDataset
from .utils import (HELP_URL,LOGGER,get_hash,img2label_paths,load_dataset_cache_file,save_dataset_cache_file,verify_image,verify_image_label,
)# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 Ultralytics 數據集 *.cache 版本,對于 YOLOv8,版本 >= 1.0.0 。
DATASET_CACHE_VERSION = "1.0.3"
2.class YOLODataset(BaseDataset):?
# 這段代碼定義了一個名為 YOLODataset 的類,繼承自 BaseDataset ,用于處理 YOLO 模型的數據加載和預處理。它支持多種任務(目標檢測、分割、姿態估計等),并提供了緩存標簽、數據增強和數據格式化等功能。
# 定義了 YOLODataset 類,繼承自 BaseDataset ,用于處理 YOLO 模型的數據加載和預處理。
class YOLODataset(BaseDataset):# 用于以 YOLO 格式加載對象檢測和/或分割標簽的數據集類。"""Dataset class for loading object detection and/or segmentation labels in YOLO format.Args:data (dict, optional): A dataset YAML dictionary. Defaults to None.task (str): An explicit arg to point current task, Defaults to 'detect'.Returns:(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model."""# 這段代碼是 YOLODataset 類的構造函數 __init__ ,用于初始化數據集類的實例。# 定義了 YOLODataset 類的構造函數。 接受以下參數 :# 1.*args 和 4.**kwargs :傳遞給父類 BaseDataset 的參數。# 2.data :數據集的配置信息(例如類別名稱、關鍵點信息等),默認為 None 。# 3.task :指定任務類型,默認為 "detect" ,支持以下任務 :# "detect" :目標檢測。# "segment" :分割任務。# "pose" :姿態估計任務。# "obb" :定向邊界框任務。def __init__(self, *args, data=None, task="detect", **kwargs):# 使用可選的片段和關鍵點配置初始化 YOLODataset。"""Initializes the YOLODataset with optional configurations for segments and keypoints."""# 根據任務類型 task 初始化布爾標志。# 如果任務是 "segment" ,則為 True ,表示啟用分割任務。self.use_segments = task == "segment"# 如果任務是 "pose" ,則為 True ,表示啟用姿態估計任務。self.use_keypoints = task == "pose"# 如果任務是 "obb" ,則為 True ,表示啟用定向邊界框任務。self.use_obb = task == "obb"# 將 數據集配置 存儲在實例變量 self.data 中,供后續方法使用。self.data = data# 檢查是否同時啟用了 分割任務 和 姿態估計 任務。 如果同時啟用,拋出 AssertionError ,因為這兩種任務不兼容。assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." # 不能同時使用段和關鍵點。# 調用父類 BaseDataset 的構造函數,將 *args 和 **kwargs 傳遞給父類。 這一步確保父類的初始化邏輯被執行,例如設置數據集路徑、圖像大小等。super().__init__(*args, **kwargs)# 這段代碼的功能是。初始化任務標志:根據任務類型 task 設置布爾標志,決定是否啟用分割、姿態估計或定向邊界框任務。存儲數據集配置:將數據集的配置信息存儲在實例變量中。檢查任務沖突:確保不會同時啟用分割和姿態估計任務,因為這兩種任務不兼容。調用父類構造函數:執行父類的初始化邏輯,確保繼承的屬性和方法被正確初始化。通過這種方式, YOLODataset 類能夠根據任務類型動態調整其行為,并為后續的數據加載和預處理提供必要的配置信息。# 這段代碼定義了 YOLODataset 類中的 cache_labels 方法,用于驗證和緩存數據集的標簽信息。它通過多線程并行處理圖像和標簽文件,統計標簽的有效性、缺失、空標簽和損壞情況,并將結果保存到緩存文件中。# 定義了 cache_labels 方法,它接受一個參數。# 1.path :保存緩存文件的路徑。默認緩存路徑為 ./labels.cache 。def cache_labels(self, path=Path("./labels.cache")):# 緩存數據集標簽,檢查圖像并讀取形狀。"""Cache dataset labels, check images and read shapes.Args:path (Path): Path where to save the cache file. Default is Path("./labels.cache").Returns:(dict): labels."""# 初始化一個字典 x ,用于 存儲緩存信息 ,其中 "labels" 鍵存儲標簽數據。x = {"labels": []}# 初始化統計變量。# nm :缺失標簽的數量。# nf :找到的標簽數量。# ne :空標簽的數量。# nc :損壞的標簽數量。# msgs :驗證過程中生成的消息列表。nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages# 初始化描述信息。# desc 描述信息,用于進度條顯示,說明正在掃描的路徑。desc = f"{self.prefix}Scanning {path.parent / path.stem}..." # {self.prefix} 正在掃描 {path.parent / path.stem}...# 圖像文件的總數 ,用于進度條的總進度。total = len(self.im_files)# 從數據集配置 self.data 中獲取關鍵點信息。# nkpt :每個對象的 關鍵點數量 。# ndim :每個關鍵點的 維度 (通常是 2 或 3)。nkpt, ndim = self.data.get("kpt_shape", (0, 0))# 如果啟用了關鍵點任務( self.use_keypoints 為 True ),檢查關鍵點配置是否正確 : nkpt 必須大于 0。 ndim 必須為 2 或 3。if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):# 如果配置錯誤,拋出 ValueError 并附帶錯誤信息。raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " # data.yaml 中的“kpt_shape”缺失或不正確。"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" # 應為包含 [關鍵點數量、維度數量(x、y 為 2,x、y、visible 為 3)] 的列表,即“kpt_shape:[17, 3]”。)# 使用 ThreadPool 創建多線程池,用于并行驗證圖像和標簽文件。# NUM_THREADS -> 計算YOLO多進程線程數,最多8個,最少1個,通常是CPU核心數減1。with ThreadPool(NUM_THREADS) as pool:results = pool.imap(# verify_image_label 是驗證函數,逐個處理圖像和標簽文件。func=verify_image_label,# 輸入參數包括 :iterable=zip(# 圖像文件路徑。self.im_files,# 標簽文件路徑。self.label_files,# 日志前綴。repeat(self.prefix),# 是否處理關鍵點。repeat(self.use_keypoints),# 類別數量。repeat(len(self.data["names"])),# 關鍵點數量和維度。repeat(nkpt),repeat(ndim),),)# 這段代碼是 cache_labels 方法的核心部分,用于處理多線程驗證的結果,并實時更新進度條信息。# 使用 TQDM 創建一個進度條,用于顯示驗證過程的進度。# results 是多線程池 ThreadPool 返回的迭代器,包含 每個圖像和標簽文件的驗證結果 。# desc 是 進度條的描述信息 ,顯示當前正在處理的任務。# total 是 進度條的總進度 ,等于圖像文件的總數。pbar = TQDM(results, desc=desc, total=total)# 遍歷多線程驗證的結果,每次迭代返回以下內容 :# im_file :圖像文件路徑。# lb :標簽數據(NumPy 數組)。# shape :圖像尺寸( (height, width) )。# segments :多邊形分割數據(如果有)。# keypoint :關鍵點數據(如果有)。# nm_f 、 nf_f 、 ne_f 、 nc_f :分別表示當前圖像的 缺失 、 找到 、 空 和 損壞 的標簽數量。# msg :驗證過程中生成的消息。for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:# 將當前圖像的統計結果累加到全局統計變量中。# 缺失標簽總數。nm += nm_f# 找到的標簽總數。nf += nf_f# 空標簽總數。ne += ne_f# 損壞的標簽總數。nc += nc_f# 如果圖像文件有效( im_file 不為 None ),將 標簽信息存 儲到緩存字典 x["labels"] 中。if im_file:x["labels"].append(# 存儲的內容包括 :{# 圖像文件路徑。"im_file": im_file,# 圖像尺寸。"shape": shape,# 類別編號(從標簽數組的第 0 列提取)。"cls": lb[:, 0:1], # n, 1# 邊界框坐標(從標簽數組的第 1 列到最后一列提取)。"bboxes": lb[:, 1:], # n, 4# 多邊形分割數據。"segments": segments,# 關鍵點數據。"keypoints": keypoint,# 標記坐標是否歸一化。"normalized": True,# 邊界框格式( "xywh" )。"bbox_format": "xywh",})# 如果驗證過程中生成了消息( msg 不為空),將其添加到消息列表 msgs 中。if msg:msgs.append(msg)# 動態更新進度條的描述信息,實時顯示當前的驗證結果。# nf :已找到的圖像數量。# nm + ne :缺失或空標簽的圖像數量。# nc :損壞的圖像數量。pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" # {desc} {nf} 圖像,{nm + ne} 背景,{nc} 損壞。# 在驗證完成后關閉進度條。pbar.close()# 這段代碼的功能是。創建進度條:使用 TQDM 顯示驗證進度。遍歷驗證結果:逐個處理多線程驗證的結果。更新統計變量:累加每個圖像的驗證統計結果。存儲有效的標簽信息:將有效的圖像和標簽信息存儲到緩存字典中。記錄驗證消息:將驗證過程中生成的消息存儲到列表中。動態更新進度條:實時顯示驗證進度和結果。關閉進度條:在驗證完成后關閉進度條。通過這種方式,代碼能夠高效地處理多線程驗證的結果,并實時反饋驗證進度和統計信息,便于用戶了解數據集的質量和驗證狀態。# 這段代碼是 cache_labels 方法的最后部分,負責處理驗證過程中的日志輸出、統計結果的保存以及緩存文件的寫入。# 如果驗證過程中生成了任何消息(存儲在 msgs 列表中)。if msgs:# 將這些消息合并為一個字符串,并通過日志記錄器 LOGGER 輸出為信息。這些消息通常是關于圖像或標簽文件的警告信息,例如修復損壞的 JPEG 文件或移除重復標簽。LOGGER.info("\n".join(msgs))# 如果在整個驗證過程中沒有找到任何有效的標簽文件( nf == 0 ),則通過日志記錄器 LOGGER 輸出警告信息。 警告信息中包含前綴 self.prefix ,路徑 path ,以及一個幫助鏈接 HELP_URL ,以便用戶查找解決問題的方法。if nf == 0:LOGGER.warning(f"{self.prefix}WARNING ?? No labels found in {path}. {HELP_URL}") # {self.prefix}警告 ?? 在 {path} 中未找到標簽。{HELP_URL} 。# 調用 get_hash 函數,計算所有標簽文件和圖像文件路徑的哈希值。 將哈希值存儲在緩存字典 x 中,鍵為 "hash" 。 這個哈希值用于后續驗證數據集的一致性,確保數據未被修改。x["hash"] = get_hash(self.label_files + self.im_files)# 將 驗證過程中的統計結果 存儲在緩存字典 x 中,鍵為 "results" 。 統計結果包括 :# nf :找到的標簽文件數量。# nm :缺失的標簽文件數量。# ne :空的標簽文件數量。# nc :損壞的標簽文件數量。# len(self.im_files) :圖像文件的總數。x["results"] = nf, nm, ne, nc, len(self.im_files)# 將驗證過程中生成的所有消息(警告信息)存儲在緩存字典 x 中,鍵為 "msgs" 。 這些消息后續可以用于調試或用戶提示。x["msgs"] = msgs # warnings# 調用 save_dataset_cache_file 函數,將緩存字典 x 保存到指定路徑 path 。 緩存文件中包含標簽信息、統計結果、哈希值和驗證消息。 緩存文件的版本號為 DATASET_CACHE_VERSION ,用于確保緩存文件的兼容性。# def save_dataset_cache_file(prefix, path, x, version): -> 它將一個字典 x 保存為一個以 .cache 結尾的文件,并將其存儲到指定路徑 path 。save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)# 返回緩存字典 x ,包含 標簽信息 、 統計結果 、 哈希值 和 驗證消息 。 這個字典后續可以用于快速加載數據集信息,避免重復驗證。return x# 這段代碼的功能是.處理驗證消息:將驗證過程中生成的消息輸出到日志。檢查標簽文件:如果未找到任何標簽文件,發出警告。計算哈希值:為數據集生成哈希值,用于后續驗證數據一致性。保存統計結果:將驗證過程中的統計結果存儲到緩存字典中。保存緩存文件:將緩存字典保存到文件中,便于后續快速加載。返回緩存結果:返回包含標簽信息和統計結果的緩存字典。通過這種方式, cache_labels 方法不僅驗證了數據集的完整性和一致性,還通過緩存機制提高了數據加載的效率,同時為用戶提供了詳細的驗證反饋。# 這段代碼的功能是。驗證標簽文件:通過多線程并行處理圖像和標簽文件,檢查標簽文件是否存在、格式是否正確、坐標是否歸一化等。統計標簽信息:統計找到的標簽數量、缺失的標簽數量、空標簽數量和損壞的標簽數量。保存緩存文件:將驗證后的標簽信息和統計結果保存到緩存文件中,便于后續快速加載。記錄驗證消息:記錄驗證過程中生成的消息(例如警告或修復信息),便于調試和排查問題。通過這種方式, cache_labels 方法能夠高效地驗證和緩存數據集的標簽信息,確保數據集在訓練前的質量和一致性。# 這段代碼定義了 YOLODataset 類中的 get_labels 方法,用于加載和驗證數據集的標簽信息。它通過檢查緩存文件來決定是否重新生成緩存,并根據緩存內容更新數據集的標簽信息。# 定義了 get_labels 方法,用于加載和驗證數據集的標簽信息。def get_labels(self):# 返回 YOLO 訓練的標簽詞典。"""Returns dictionary of labels for YOLO training."""# 這段代碼是 get_labels 方法的核心部分,負責加載或生成數據集的緩存文件,并驗證其完整性和一致性。# 使用 img2label_paths 函數將 圖像文件路徑列表 self.im_files 轉換為 對應的標簽文件路徑列表 self.label_files 。 img2label_paths 函數通常會將路徑中的 /images/ 替換為 /labels/ ,并將文件擴展名從圖像格式(如 .jpg )替換為 .txt ,以匹配標簽文件的命名規則。# def img2label_paths(img_paths): -> 它接收一個包含圖像路徑的列表 img_paths ,并返回一個對應的標簽路徑列表。返回一個列表,其中包含 與輸入圖像路徑對應的標簽路徑 。 -> return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]self.label_files = img2label_paths(self.im_files)# 通過 Path 對象操作,確定 緩存文件的路徑 。 self.label_files[0] 是第一個標簽文件的路徑。 .parent 獲取該路徑的父目錄。 .with_suffix(".cache") 將文件擴展名替換為 .cache ,生成緩存文件的完整路徑。cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")# 嘗試加載緩存文件。try:# 調用 load_dataset_cache_file(cache_path) 函數加載緩存文件,返回一個 包含緩存信息的字典 cache 。# 設置 exists 為 True ,表示 緩存文件存在 。# def load_dataset_cache_file(path): -> 它從指定路徑加載一個以 .cache 結尾的文件,并將其內容解析為一個字典。返回加載并解析后的字典對象,即 .cache 文件的內容。 -> return cachecache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file# 驗證緩存文件的 完整性 和 一致性 。# 檢查 緩存文件的版本號 是否與 當前版本一致 ( DATASET_CACHE_VERSION )。assert cache["version"] == DATASET_CACHE_VERSION # matches current version# 檢查 緩存文件的哈希值 是否與 當前數據集的哈希值 一致(通過 get_hash 函數計算)。哈希值是基于所有標簽文件和圖像文件路徑生成的,用于 確保數據集未被修改 。assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash# 處理緩存文件不存在或驗證失敗的情況。# 如果在加載或驗證緩存文件時發生以下異常 : FileNotFoundError 緩存文件不存在。 AssertionError 版本號或哈希值不匹配。 AttributeError 緩存文件格式錯誤或缺少關鍵字段。except (FileNotFoundError, AssertionError, AttributeError):# 調用 self.cache_labels(cache_path) 方法重新生成緩存文件。 設置 exists 為 False ,表示緩存文件是新生成的。cache, exists = self.cache_labels(cache_path), False # run cache ops# 這段代碼的功能是。初始化標簽文件路徑:將圖像文件路徑轉換為對應的標簽文件路徑。確定緩存路徑:根據標簽文件路徑確定緩存文件的存儲位置。嘗試加載緩存文件:加載緩存文件并驗證其版本和哈希值是否與當前數據集一致。處理緩存文件不存在或驗證失敗的情況:如果緩存文件不存在或驗證失敗,重新生成緩存文件。通過這種方式,代碼能夠高效地加載或生成緩存文件,確保數據集的一致性和完整性,同時避免重復驗證已處理的數據集。# 這段代碼的功能是展示緩存文件的內容,包括驗證過程中統計的結果和生成的消息。# Display cache# 從緩存字典 cache 中提取統計結果,鍵為 "results" 。 統計結果包含以下內容 :# nf :找到的標簽文件數量。# nm :缺失的標簽文件數量。# ne :空的標簽文件數量。# nc :損壞的標簽文件數量。# n :圖像文件的總數。# 使用 cache.pop("results") 提取并移除該鍵,避免后續重復處理。nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total# 判斷 是否顯示緩存信息 。# exists :表示緩存文件是否存在(如果為 True ,則緩存文件已成功加載)。# LOCAL_RANK :用于分布式訓練中標識當前進程的本地排名。 -1 表示單進程運行, 0 表示多進程中的主進程。# 只有當緩存文件存在且當前進程是主進程時,才顯示緩存信息。if exists and LOCAL_RANK in {-1, 0}:# 構造描述信息 d ,用于展示緩存文件的內容。# cache_path :緩存文件的路徑。# {nf} images :表示找到的標簽文件數量。# {nm + ne} backgrounds :表示缺失或空的標簽文件數量(被視為背景圖像)。# {nc} corrupt :表示損壞的標簽文件數量。d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" # 掃描 {cache_path}...{nf} 幅圖像、{nm + ne} 幅背景、{nc} 幅損壞圖像。# 使用 TQDM 創建一個進度條,用于顯示緩存信息。# None :不傳遞迭代對象,僅用于顯示靜態信息。# desc :進度條的描述信息,包含前綴 self.prefix 和動態描述 d 。# total 和 initial :設置進度條的 總進度 和 初始進度為 n (圖像文件的總數),使進度條顯示為已完成狀態。# 這種方式用于靜態展示緩存信息,而不是用于動態進度跟蹤。TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results# 檢查緩存字典中是否存在驗證消息( cache["msgs"] )。if cache["msgs"]:# 如果存在消息,將這些消息合并為一個字符串,并通過日志記錄器 LOGGER 輸出為信息。 這些消息通常是驗證過程中生成的警告信息,例如修復損壞的 JPEG 文件或移除重復標簽。LOGGER.info("\n".join(cache["msgs"])) # display warnings# 這段代碼的功能是。提取緩存中的統計結果:從緩存字典中提取驗證過程中的統計信息。判斷是否顯示緩存信息:根據緩存文件是否存在以及當前進程是否為主進程,決定是否顯示緩存信息。構造描述信息:動態生成描述緩存內容的字符串。使用 TQDM 顯示結果:通過進度條靜態展示緩存信息。顯示驗證消息:將驗證過程中生成的消息輸出到日志。通過這種方式,代碼能夠清晰地展示緩存文件的內容和驗證結果,同時將重要的警告信息記錄到日志中,便于用戶了解數據集的狀態和潛在問題。# 這段代碼的功能是讀取緩存文件中的標簽信息,并根據緩存內容更新數據集的狀態。# Read cache# 使用列表推導式從緩存字典 cache 中移除以下鍵 :# "hash" :數據集的哈希值。# "version" :緩存文件的版本號。# "msgs" :驗證過程中生成的消息列表。# 這些鍵在后續處理中不再需要,因此移除以簡化緩存字典。[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items# 從緩存字典中提取 標簽信息 ,存儲在變量 labels 中。 labels 是一個列表,每個元素是一個字典,包含單個圖像的標簽信息(例如圖像路徑、類別、邊界框、分割數據等)。labels = cache["labels"]# 檢查 labels 是否為空。if not labels:# 如果為空,說明緩存中沒有有效的標簽信息。 使用 LOGGER.warning 輸出警告信息,提示用戶數據集可能為空,訓練可能無法正常進行。 提供幫助鏈接 HELP_URL ,以便用戶查找解決方案。LOGGER.warning(f"WARNING ?? No images found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ?? 在 {cache_path} 中未找到圖像,訓練可能無法正常工作。{HELP_URL}。# 使用列表推導式 從每個標簽字典中 提取 圖像文件路徑 lb["im_file"] 。 更新 self.im_files ,確保其包含所有有效圖像的路徑。 這一步確保后續數據加載和訓練時使用的是經過驗證的圖像文件。self.im_files = [lb["im_file"] for lb in labels] # update im_files# 這段代碼的功能是。移除緩存中的冗余信息:從緩存字典中移除不再需要的鍵( "hash" 、 "version" 、 "msgs" )。提取標簽信息:從緩存字典中提取標簽列表 labels 。檢查是否有有效的標簽信息:如果 labels 為空,發出警告,提示數據集可能為空。更新圖像文件路徑列表:根據標簽信息更新 self.im_files ,確保后續處理使用的是有效的圖像路徑。通過這種方式,代碼能夠確保數據集的完整性和一致性,并為后續的數據加載和訓練提供準確的圖像路徑和標簽信息。# 這段代碼的功能是檢查數據集是否包含邊界框(boxes)或分割掩碼(segments),并確保數據集的類型一致。如果發現不一致,代碼會發出警告并調整數據集以避免潛在問題。# Check if the dataset is all boxes or all segments# 檢查數據集類型。# 生成一個元組列表,每個元組包含每個標簽的 類別數量 、 邊界框數量 和 分割掩碼數量 。lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)# zip(*lengths) :將元組列表解包,分別對 類別數量 、 邊界框數量 和 分割掩碼數量 進行匯總。# len_cls, len_boxes, len_segments :分別計算 總類別數量 、 總邊界框數量 和 總分割掩碼數量 。len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))# 檢查邊界框和分割掩碼數量是否一致。# 如果數據集中包含分割掩碼( len_segments > 0 )且邊界框數量與分割掩碼數量不一致( len_boxes != len_segments ),說明數據集是混合類型(檢測和分割)。if len_segments and len_boxes != len_segments:# 使用 LOGGER.warning 發出警告,說明邊界框數量和分割掩碼數量不一致,并提示用戶應提供單一類型的數據集(檢測或分割)。LOGGER.warning(f"WARNING ?? Box and segment counts should be equal, but got len(segments) = {len_segments}, " # 警告 ?? 框和段數應該相等,但得到的 len(segments) = {len_segments},len(boxes) = {len_boxes}。f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " # 要解決此問題,將僅使用框,并將刪除所有段。"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." # 為避免這種情況,請提供檢測或段數據集,而不是檢測-段混合數據集。)# 調整數據集。為了避免問題,代碼 將所有標簽的分割掩碼清空 ( lb["segments"] = [] ), 僅保留邊界框 。for lb in labels:lb["segments"] = []# 如果 總類別數量為 0( len_cls == 0 ),說明數據集中沒有有效的標簽。if len_cls == 0:# 使用 LOGGER.warning 發出警告,提示用戶數據集中沒有標簽,訓練可能無法正常進行,并提供幫助鏈接 HELP_URL 。LOGGER.warning(f"WARNING ?? No labels found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ?? 在 {cache_path} 中未找到標簽,訓練可能無法正常工作。{HELP_URL}。# 返回 處理后的標簽信息列表 labels ,供后續數據加載和訓練使用。return labels# 這段代碼的功能是。檢查數據集類型:統計數據集中邊界框和分割掩碼的數量。確保數據集類型一致:如果發現邊界框和分割掩碼數量不一致,發出警告并清空所有分割掩碼,僅保留邊界框。檢查是否有標簽:如果數據集中沒有標簽,發出警告并提示用戶。返回標簽信息:返回處理后的標簽列表,確保數據集的一致性和完整性。通過這種方式,代碼能夠確保數據集的類型一致,避免因混合類型數據集導致的潛在問題,同時為后續的數據加載和訓練提供準確的標簽信息。# 這段代碼的功能是。初始化標簽文件路徑:將圖像文件路徑轉換為對應的標簽文件路徑。嘗試加載緩存文件:檢查緩存文件是否存在,版本和哈希值是否匹配。如果不匹配,重新生成緩存。顯示緩存信息:顯示緩存內容,包括統計結果和驗證消息。讀取緩存內容:提取緩存中的標簽信息,并移除不必要的鍵。檢查標簽信息:驗證標簽信息是否為空,更新圖像文件路徑列表。檢查數據一致性:確保數據集類型一致(檢測或分割),移除不一致的數據。返回標簽信息:返回處理后的標簽信息,供后續使用。通過這種方式, get_labels 方法能夠高效地加載和驗證標簽信息,確保數據集的一致性和完整性,同時為用戶提供詳細的反饋。# 這段代碼定義了 YOLODataset 類中的 build_transforms 方法,用于構建數據增強和格式化流程。# 定義了 build_transforms 方法,用于構建數據增強和格式化流程。該方法接受一個參數。# 1.hyp :表示超參數配置(例如數據增強的參數)。def build_transforms(self, hyp=None):# 構建轉換并將其附加到列表。"""Builds and appends transforms to the list."""# 如果 啟用了數據增強 ( self.augment 為 True )。if self.augment:# 調整 hyp 中的 mosaic 和 mixup 參數。# 如果啟用了數據增強但未啟用矩形訓練( self.rect 為 False ),則保留 mosaic 和 mixup 的值。# 否則,將 mosaic 和 mixup 設置為 0.0 ,禁用這些增強方法。hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0# 調用 v8_transforms 函數,生成 數據增強流程 transforms ,并傳遞 當前實例 、 圖像尺寸 self.imgsz 和 超參數 hyp 。# def v8_transforms(dataset, imgsz, hyp, stretch=False):# -> 用于構建一個綜合的數據增強流程,適用于目標檢測和分割任務。該函數根據傳入的參數和配置,組合了多種增強操作,以提高模型的魯棒性和泛化能力。構建并返回一個 綜合的數據增強流程 。# -> return Compose([pre_transform, MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), Albumentations(p=1.0), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),]) # transformstransforms = v8_transforms(self, self.imgsz, hyp)# 如果未啟用數據增強( self.augment 為 False )。else:# 使用 Compose 創建一個簡單的數據處理流程,僅包含 LetterBox 操作。 LetterBox 用于調整圖像大小,使其適應指定的尺寸( self.imgsz, self.imgsz ),同時禁用上采樣( scaleup=False )。# class Compose:# -> 用于將多個圖像變換(或數據處理)操作組合在一起,并按順序應用到輸入數據上。這種設計模式在數據預處理、數據增強以及機器學習任務中非常常見。# -> def __init__(self, transforms):# class LetterBox:# -> 用于對圖像進行縮放和填充操作,以適應指定的目標尺寸。這種操作通常用于深度學習中的圖像預處理,尤其是在目標檢測和分割任務中。 LetterBox 的核心功能是將圖像縮放到指定大小,同時保持原始圖像的寬高比,并通過填充(通常是灰色)來補充剩余部分。# -> def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])# 向 數據增強或處理流程 中添加 Format 操作,用于格式化數據。transforms.append(# class Format:# -> 用于對圖像及其標注信息(如邊界框、分割掩碼、關鍵點等)進行格式化處理。該類的主要功能是將標注信息轉換為模型訓練所需的格式,并支持多種選項,如歸一化、掩碼生成、關鍵點處理等。# -> def __init__(self, bbox_format="xywh", normalize=True, return_mask=False, return_keypoint=False, return_obb=False, mask_ratio=4, mask_overlap=True, batch_idx=True, bgr=0.0,):Format(# 指定邊界框格式為 (x, y, w, h) 。bbox_format="xywh",# 將坐標歸一化到 [0, 1] 范圍內。normalize=True,# 根據是否啟用分割任務,決定是否返回分割掩碼。return_mask=self.use_segments,# 根據是否啟用關鍵點任務,決定是否返回關鍵點數據。return_keypoint=self.use_keypoints,# 根據是否啟用定向邊界框任務,決定是否返回定向邊界框數據。return_obb=self.use_obb,# 為每個樣本添加批量索引。batch_idx=True,# mask_ratio=hyp.mask_ratio 和 mask_overlap=hyp.overlap_mask :控制分割掩碼的生成參數。mask_ratio=hyp.mask_ratio,mask_overlap=hyp.overlap_mask,# 在訓練時根據超參數 hyp.bgr 決定是否啟用 BGR 轉換;在非增強模式下禁用。bgr=hyp.bgr if self.augment else 0.0, # only affect training.))# 返回構建好的 數據處理流程 transforms ,供后續數據加載和訓練使用。return transforms# 這段代碼的功能是。配置數據增強:根據是否啟用數據增強和矩形訓練,調整增強參數(如 mosaic 和 mixup )。構建增強流程:如果啟用增強,調用 v8_transforms 生成增強流程;否則,僅使用 LetterBox 調整圖像大小。添加格式化操作:在處理流程中添加 Format 操作,用于格式化邊界框、分割掩碼、關鍵點和定向邊界框數據。返回處理流程:返回完整的數據處理流程,供數據加載器使用。通過這種方式, build_transforms 方法能夠靈活地配置數據增強和格式化流程,適應不同的任務需求(檢測、分割、關鍵點估計等),并確保數據在訓練前被正確處理。# 這段代碼定義了 YOLODataset 類中的 close_mosaic 方法,用于關閉 Mosaic 數據增強功能,并調整其他相關增強參數。# 定義了 close_mosaic 方法,用于關閉 Mosaic 數據增強功能。該方法接受一個參數。# 1.hyp :表示超參數配置。def close_mosaic(self, hyp):# 將馬賽克、復制粘貼和混合選項設置為 0.0 并構建轉換。"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""# 將超參數 hyp 中的 mosaic 參數設置為 0.0 ,表示關閉 Mosaic 數據增強功能。 Mosaic 是一種數據增強技術,通過將多個圖像拼接在一起,增強模型對不同背景和目標組合的泛化能力。hyp.mosaic = 0.0 # set mosaic ratio=0.0# 將超參數 hyp 中的 copy_paste 參數設置為 0.0 ,表示關閉 Copy-Paste 數據增強功能。 Copy-Paste 是一種增強技術,通過將一個圖像中的對象復制并粘貼到另一個圖像中,增加數據多樣性。hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic# 將超參數 hyp 中的 mixup 參數設置為 0.0 ,表示關閉 Mixup 數據增強功能。 Mixup 是一種增強技術,通過將兩個圖像及其標簽進行線性組合,生成新的訓練樣本。hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic# 調用 build_transforms 方法,根據更新后的超參數 hyp 重新構建數據處理流程。 更新后的 self.transforms 不再包含 Mosaic、Copy-Paste 和 Mixup 數據增強,確保數據增強行為與關閉 Mosaic 時一致。self.transforms = self.build_transforms(hyp)# 這段代碼的功能是。關閉 Mosaic 數據增強:通過將 hyp.mosaic 設置為 0.0 ,禁用 Mosaic 數據增強。關閉 Copy-Paste 和 Mixup 數據增強:通過將 hyp.copy_paste 和 hyp.mixup 設置為 0.0 ,禁用這兩種增強技術。重新構建數據處理流程:調用 build_transforms 方法,根據更新后的超參數重新生成數據處理流程。通過這種方式, close_mosaic 方法能夠快速調整數據增強行為,確保在需要時關閉 Mosaic 及其他相關增強功能,同時保持數據處理流程的一致性。# 這段代碼定義了 YOLODataset 類中的 update_labels_info 方法,用于更新和格式化標簽信息,特別是處理邊界框、分割掩碼和關鍵點數據。# 定義了 update_labels_info 方法,用于更新和格式化單個標簽信息。該方法接受一個參數。# 1.label :一個標簽字典,包含邊界框、分割掩碼和關鍵點等信息。def update_labels_info(self, label):# 在此處自定義您的標簽格式。# 注意:# cls 現在不包含 bboxes,分類和語義分割需要獨立的 cls 標簽# 還可以通過添加或刪除字典鍵來支持分類和語義分割。"""Custom your label format here.Note:cls is not with bboxes now, classification and semantic segmentation need an independent cls labelCan also support classification and semantic segmentation by adding or removing dict keys there."""# 從 label 字典中提取以下信息。# 邊界框坐標。bboxes = label.pop("bboxes")# 分割掩碼(默認為空列表,如果不存在)。segments = label.pop("segments", [])# 關鍵點數據(默認為 None ,如果不存在)。keypoints = label.pop("keypoints", None)# 邊界框格式(例如 "xywh" )。bbox_format = label.pop("bbox_format")# 坐標是否歸一化。normalized = label.pop("normalized")# NOTE: do NOT resample oriented boxes# 如果啟用了定向邊界框( self.use_obb 為 True ),將分割數據的 重采樣數量 設置為 100 。 否則,設置為 1000 。 這個值決定了分割掩碼的點數,用于后續插值。segment_resamples = 100 if self.use_obb else 1000# 這段代碼的功能是對分割掩碼( segments )進行處理,確保所有分割掩碼的長度一致,并將其轉換為統一的 NumPy 數組格式。# 檢查 segments 是否為空。如果 segments 是一個非空列表,說明存在分割掩碼數據。if len(segments) > 0:# make sure segments interpolate correctly if original length is greater than segment_resamples# 遍歷 segments 列表,計算每個分割掩碼的長度(即每個分割掩碼的點數)。 使用 max() 函數找到所有分割掩碼中的最大長度 max_len 。max_len = max(len(s) for s in segments)# 如果當前的重采樣數量 segment_resamples 小于最大長度 max_len ,則將其調整為 max_len + 1 。 這是為了確保在插值過程中不會丟失信息。 否則,保持 segment_resamples 不變。segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples# list[np.array(segment_resamples, 2)] * num_samples# 調用 resample_segments 函數,對每個分割掩碼進行重采樣,使其長度一致為 segment_resamples 。# 使用 np.stack 將重采樣后的分割掩碼堆疊為一個 NumPy 數組,形狀為 (num_samples, segment_resamples, 2) ,其中 :# num_samples 是分割掩碼的數量。# segment_resamples 是每個分割掩碼的點數。# 2 表示每個點的 (x, y) 坐標。# def resample_segments(segments, n=1000): -> 用于對輸入的二維線段數據進行重采樣,使其每個線段的點數統一為指定的數量 n 。返回處理后的線段數據列表。 -> return segmentssegments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)# 如果 segments 為空(即沒有分割掩碼數據)。else:# 創建一個形狀為 (0, segment_resamples, 2) 的空 NumPy 數組。 數據類型為 np.float32 ,表示坐標值為浮點數。segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)# 這段代碼的功能是。檢查是否存在分割掩碼:通過檢查 segments 是否為空。確保分割掩碼長度一致:計算所有分割掩碼的最大長度,并調整重采樣數量。對分割掩碼進行重采樣:使用 resample_segments 函數對分割掩碼進行插值,使其長度一致。處理空的分割掩碼:如果不存在分割掩碼,生成一個空的 NumPy 數組。通過這種方式,代碼能夠標準化分割掩碼的格式,確保所有分割掩碼的長度一致,便于后續處理和訓練。# 使用提取的邊界框、分割掩碼和關鍵點數據,創建一個 Instances 對象,并將其存儲在 label 字典中,鍵為 "instances" 。 Instances 是一個封裝類,用于統一管理目標檢測、分割和關鍵點任務的實例信息。 傳遞的參數包括 :# bboxes :邊界框坐標。# segments :分割掩碼。# keypoints :關鍵點數據。# bbox_format :邊界框格式。# normalized :坐標是否歸一化。# class Instances:# -> 用于封裝和處理目標檢測中的邊界框、分割掩碼和關鍵點信息。# -> def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)# 返回更新后的標簽字典 label ,其中包含格式化后的 Instances 對象。return label# 這段代碼的功能是。提取和移除標簽信息:從輸入的標簽字典中提取邊界框、分割掩碼、關鍵點等信息。配置分割數據的重采樣數量:根據是否啟用定向邊界框任務,設置分割數據的重采樣數量。處理分割數據:對分割掩碼進行重采樣,確保所有分割掩碼的長度一致。創建 Instances 對象:將邊界框、分割掩碼和關鍵點封裝為一個統一的 Instances 對象。返回更新后的標簽信息:返回包含格式化后實例信息的標簽字典。通過這種方式, update_labels_info 方法能夠標準化標簽信息的格式,確保數據在后續處理和訓練中的一致性和兼容性。# 這段代碼定義了 YOLODataset 類中的 collate_fn 靜態方法,用于將一個批次(batch)的數據合并為一個統一的張量格式,以便用于訓練。@staticmethod# 定義了一個靜態方法 collate_fn ,它接受一個參數。# 1.batch :一個列表,其中每個元素是一個字典,表示單個樣本的數據(例如圖像、標簽、分割掩碼等)def collate_fn(batch):# 將數據樣本整理成批次。"""Collates data samples into batches."""# 初始化一個空字典 new_batch ,用于 存儲合并后的批次數據 。new_batch = {}# 提取第一個樣本的鍵( keys ),假設所有樣本的鍵是相同的。keys = batch[0].keys()# 使用列表推導式和 zip 函數,將每個樣本的值按鍵分組,形成一個列表 values 。 values[i] 包含所有樣本的第 i 個鍵對應的值。# 這行代碼的作用是將一個批次( batch )中的所有樣本數據 按鍵值分組 ,以便后續對每個鍵對應的值進行批量處理。# batch 的結構 :# batch 是一個列表,其中每個元素是一個字典,表示單個樣本的數據。例如 :# batch = [# {"img": img1, "bboxes": bboxes1, "cls": cls1},# {"img": img2, "bboxes": bboxes2, "cls": cls2},# ...# ]# 每個字典包含相同的鍵(如 "img" 、 "bboxes" 、 "cls" 等),但值是不同的樣本數據。# 列表推導式 :# [list(b.values()) for b in batch]# 遍歷 batch 中的每個樣本 b 。 使用 b.values() 提取每個樣本字典的值(按鍵的順序)。 將這些值轉換為列表,形成一個新的列表。例如 :# [# [img1, bboxes1, cls1],# [img2, bboxes2, cls2],# ...# ]# 解包和 zip :# zip(*[list(b.values()) for b in batch])# 使用 * 解包操作符,將上述列表中的每個子列表解包為獨立的參數傳遞給 zip 。 zip 函數會將這些子列表按列分組,即 :# 第一列 :所有樣本的 "img" 數據。# 第二列 :所有樣本的 "bboxes" 數據。# 第三列 :所有樣本的 "cls" 數據。# 例如 :# [# (img1, img2, ...),# (bboxes1, bboxes2, ...),# (cls1, cls2, ...)# ]# 轉換為列表 :# list(zip(*[list(b.values()) for b in batch]))# 使用 list 將 zip 的結果轉換為一個列表,確保可以多次迭代。# 最終結果是一個列表,其中每個元素是一個元組,包含所有樣本對應鍵的值。values = list(zip(*[list(b.values()) for b in batch]))# 遍歷每個鍵 k 和對應的值 value 。for i, k in enumerate(keys):value = values[i]# 如果鍵是 "img" (圖像數據),使用 torch.stack 將圖像張量堆疊為一個批次張量。if k == "img":value = torch.stack(value, 0)# 如果鍵是 "masks" 、 "keypoints" 、 "bboxes" 、 "cls" 、 "segments" 或 "obb" (標簽數據),使用 torch.cat 將這些張量拼接為一個批次張量。if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:value = torch.cat(value, 0)# 將合并后的值存儲到 new_batch 中,鍵為 k 。new_batch[k] = value# 將 new_batch["batch_idx"] 轉換為列表。new_batch["batch_idx"] = list(new_batch["batch_idx"])# 這兩行代碼的作用是為每個目標(例如邊界框或分割掩碼)添加一個唯一的批量索引( batch_idx ),以便在后續處理中區分不同樣本的目標。這種操作通常用于目標檢測或分割任務中,特別是在構建目標張量時。# 示例場景 :# 假設有一個批次( batch )的數據,包含多個樣本(圖像)。每個樣本包含多個目標(例如邊界框或分割掩碼)。需要為每個目標分配一個唯一的索引,以便在后續處理中區分它們。# 假設 new_batch["batch_idx"] 是一個列表,其中每個元素是一個目標的批量索引。初始時,這些索引可能都是相同的(例如,所有目標的索引都是 0 ),需要為每個目標添加一個唯一的索引,以區分不同樣本的目標。# 示例執行 :# 假設 new_batch["batch_idx"] 的初始值為 :# new_batch["batch_idx"] = [[0, 0, 0], [0, 0], [0, 0, 0, 0]]# 執行這兩行代碼后 :# 第一個樣本( i = 0 ) : new_batch["batch_idx"][0] += 0# 結果 : [0, 0, 0]# 第二個樣本( i = 1 ) : new_batch["batch_idx"][1] += 1# 結果 : [1, 1]# 第三個樣本( i = 2 ) : new_batch["batch_idx"][2] += 2# 結果 : [2, 2, 2, 2]# 最終, new_batch["batch_idx"] 的值變為 :# [[0, 0, 0], [1, 1], [2, 2, 2, 2]]# 總結 :# 這兩行代碼的作用是 :# 遍歷每個樣本 :通過 for i in range(len(new_batch["batch_idx"])) 遍歷批次中的每個樣本。# 為每個目標添加唯一的索引 :將當前樣本的索引 i 加到每個目標的索引上,從而為每個目標分配一個唯一的索引。# 這種操作確保了在后續處理中(例如構建目標張量時),每個目標可以通過其唯一的索引被正確區分。# 遍歷每個 批量索引 。for i in range(len(new_batch["batch_idx"])):# 將其值加上 當前樣本的索引 i 。這一步是為了在后續處理中區分不同樣本的目標索引。new_batch["batch_idx"][i] += i # add target image index for build_targets()# 使用 torch.cat 將批量索引合并為一個張量。new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)# 返回 合并后的批次數據 new_batch ,其中每個鍵對應的值是一個批次張量。return new_batch# 這段代碼的功能是。初始化新批次:提取樣本的鍵,并按鍵分組值。合并圖像和標簽數據:使用 torch.stack 合并圖像張量。使用 torch.cat 合并標簽數據(如邊界框、分割掩碼、關鍵點等)。處理批量索引:為每個樣本的目標索引加上當前樣本的索引,確保目標索引唯一。返回合并后的批次:返回一個包含批次數據的字典,供后續訓練使用。通過這種方式, collate_fn 方法能夠將一個批次的樣本數據合并為統一的張量格式,確保數據在訓練過程中的一致性和兼容性。
# YOLODataset 類是一個用于目標檢測、分割和關鍵點估計任務的數據集類,繼承自 BaseDataset 。它通過靈活的配置支持多種任務類型(如檢測、分割、姿態估計和定向邊界框檢測),并提供了高效的數據加載、預處理和增強功能。該類的核心功能包括。任務配置:根據任務類型(如檢測、分割或姿態估計)動態調整行為,支持多任務數據集。數據增強:通過 build_transforms 方法配置多種增強技術(如 Mosaic、Mixup 和 Copy-Paste),并可通過 close_mosaic 方法關閉增強。標簽處理:在 get_labels 方法中加載和驗證標簽文件,支持緩存機制以提高效率,并通過 update_labels_info 方法標準化標簽格式。數據格式化:在 collate_fn 方法中將批次數據合并為統一的張量格式,便于訓練。數據一致性檢查:通過 cache_labels 方法驗證數據集的完整性和一致性,并提供詳細的日志信息。通過這些功能, YOLODataset 類能夠高效地處理大規模數據集,確保數據在訓練前的質量和一致性,同時為 YOLO 模型的訓練提供了強大的數據支持。
3.class YOLOMultiModalDataset(YOLODataset):?
# 這段代碼定義了 YOLOMultiModalDataset 類,它是 YOLODataset 的一個擴展,用于處理多模態數據(例如結合圖像和文本信息)。
# 定義了 YOLOMultiModalDataset 類,繼承自 YOLODataset ,用于處理多模態數據集(例如同時包含圖像和文本信息)。
class YOLOMultiModalDataset(YOLODataset):# 用于以 YOLO 格式加載對象檢測和/或分割標簽的數據集類。"""Dataset class for loading object detection and/or segmentation labels in YOLO format.Args:data (dict, optional): A dataset YAML dictionary. Defaults to None.task (str): An explicit arg to point current task, Defaults to 'detect'.Returns:(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model."""# 定義了構造函數,接受以下參數 :# 1.*args 和 4.**kwargs :傳遞給父類 YOLODataset 的參數。# 2.data :數據集的配置信息(例如類別名稱、關鍵點信息等)。# 3.task :任務類型,默認為 "detect" 。def __init__(self, *args, data=None, task="detect", **kwargs):# 使用可選規范初始化用于對象檢測任務的數據集對象。"""Initializes a dataset object for object detection tasks with optional specifications."""# 調用父類 YOLODataset 的構造函數,初始化父類的屬性和方法。這一步確保了 YOLOMultiModalDataset 繼承了 YOLODataset 的所有功能。super().__init__(*args, data=data, task=task, **kwargs)# 重寫了 update_labels_info 方法,用于更新和格式化標簽信息。def update_labels_info(self, label):# 添加用于多模態模型訓練的文本信息。"""Add texts information for multi-modal model training."""# 調用父類 YOLODataset 的 update_labels_info 方法,獲取 基礎的標簽信息 。labels = super().update_labels_info(label)# NOTE: some categories are concatenated with its synonyms by `/`.# 為每個類別 添加文本信息 。假設類別名稱中包含多個同義詞,通過 / 分隔。 使用列表推導式將每個類別的名稱按 / 分割,生成一個包含同義詞的列表。 將這些文本信息存儲在 labels 字典中,鍵為 "texts" 。labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]# 返回 更新后的標簽信息 ,包含圖像標簽和文本信息。return labels# 重寫了 build_transforms 方法,用于構建數據增強和格式化流程。def build_transforms(self, hyp=None):# 通過可選的文本增強功能增強數據轉換,以實現多模式訓練。"""Enhances data transformations with optional text augmentation for multi-modal training."""# 調用父類 YOLODataset 的 build_transforms 方法,獲取基礎的數據增強流程。transforms = super().build_transforms(hyp)# 如果啟用了數據增強( self.augment 為 True )。if self.augment:# NOTE: hard-coded the args for now. 注意:目前參數是硬編碼的。# 在數據增強流程中插入一個 RandomLoadText 操作。 RandomLoadText 是一個自定義的數據增強操作,用于隨機加載文本數據。 max_samples 限制了最大樣本數量,取類別數量 self.data["nc"] 和 80 的較小值。 padding=True 表示對文本數據進行填充。transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))# 返回構建好的數據增強和格式化流程。return transforms
# 這段代碼的功能是。定義多模態數據集類: YOLOMultiModalDataset 繼承自 YOLODataset ,用于處理多模態數據(圖像和文本)。重寫標簽更新方法:在 update_labels_info 中,為每個類別添加文本信息,支持類別名稱中包含多個同義詞。重寫數據增強方法:在 build_transforms 中,添加了文本加載增強操作,用于隨機加載文本數據。擴展功能:通過繼承和重寫方法, YOLOMultiModalDataset 在 YOLODataset 的基礎上增加了對文本信息的支持,適用于多模態任務。通過這種方式, YOLOMultiModalDataset 類能夠處理包含圖像和文本的多模態數據集,為多模態任務提供了靈活的數據加載和增強功能。
4.class GroundingDataset(YOLODataset):?
# “基于標注文件的目標檢測任務”是指使用預先定義好的標注文件來指導目標檢測模型的訓練和驗證。標注文件通常包含了圖像中目標對象的位置、類別以及其他相關信息,這些信息被用來監督模型的學習過程,使其能夠準確地識別和定位圖像中的目標。
# 標注文件的作用標注文件是目標檢測任務中的關鍵組成部分,它提供了以下信息 :
# 目標位置 :標注文件中通常會包含目標對象在圖像中的位置信息,例如邊界框(Bounding Box)的坐標(通常是左上角和右下角的坐標,或者中心點坐標加上寬度和高度)。
# 目標類別 :標注文件會指定每個目標對象的類別,例如“人”、“汽車”、“貓”等。
# 其他信息 :標注文件可能還會包含其他信息,如目標的屬性(如“戴帽子的人”)、目標之間的關系(如“人騎自行車”)等。
# 標注文件的格式可以是多種多樣的,常見的格式包括 :
# JSON 格式 :例如 COCO 數據集的標注文件就是 JSON 格式,它以鍵值對的形式存儲了圖像信息、目標信息等。
# XML 格式 :例如 Pascal VOC 數據集的標注文件是 XML 格式,它以標簽和屬性的形式存儲了標注信息。
# TXT 格式 :例如 YOLO 數據集的標注文件是 TXT 格式,它以簡單的文本形式存儲了邊界框和類別信息。
# 目標檢測任務的流程 :
# 基于標注文件的目標檢測任務通常包括以下幾個步驟 :
# 數據準備 :收集圖像數據和對應的標注文件。確保圖像和標注文件之間是一一對應的。
# 數據預處理 :讀取圖像和標注文件。將標注信息轉換為模型需要的格式,例如將邊界框坐標歸一化到 [0, 1] 范圍內。對圖像進行預處理,例如調整大小、歸一化等。
# 模型訓練 :使用標注信息作為監督信號,訓練目標檢測模型。模型學習如何根據圖像內容預測目標的位置和類別。
# 模型評估 :使用驗證集或測試集評估模型的性能。評估指標通常包括準確率、召回率、mAP(Mean Average Precision)等。
# 模型應用 :將訓練好的模型應用于實際的圖像,檢測其中的目標對象。
# 基于標注文件的目標檢測任務的優勢 :
# 數據驅動 :模型的學習過程完全依賴于標注數據,能夠自動學習到目標的特征和模式。
# 可擴展性 :可以通過增加標注數據來提高模型的性能和泛化能力。
# 靈活性 :可以處理多種類型的目標檢測任務,如單目標檢測、多目標檢測、實例分割等。
# 基于標注文件的目標檢測任務的挑戰 :
# 標注成本 :標注數據需要大量的人力和時間,尤其是對于復雜的標注任務(如實例分割)。
# 標注質量 :標注數據的質量直接影響模型的性能,錯誤的標注可能導致模型學習到錯誤的模式。
# 數據不平衡 :某些類別可能有大量的標注數據,而某些類別可能只有少量標注數據,這可能導致模型對某些類別有偏見。
# 總結 :“基于標注文件的目標檢測任務”是一種常見的計算機視覺任務,它依賴于標注文件來指導模型的學習過程。標注文件提供了目標的位置、類別等信息,模型通過學習這些標注信息來識別和定位圖像中的目標。這種任務在實際應用中非常廣泛,例如自動駕駛、安防監控、醫療影像等領域。# 這段代碼定義了 GroundingDataset 類,它是 YOLODataset 的一個擴展,專門用于處理基于標注文件(如 COCO 格式的 JSON 文件)的目標檢測任務。
# 定義了 GroundingDataset 類,繼承自 YOLODataset ,用于處理基于標注文件的目標檢測任務。
class GroundingDataset(YOLODataset):# 通過從指定的 JSON 文件加載注釋來處理對象檢測任務,支持 YOLO 格式。"""Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""# 定義了構造函數,接受以下參數 :# 1.*args 和 4.**kwargs :傳遞給父類 YOLODataset 的參數。# 2.task :任務類型,默認為 "detect" 。目前僅支持目標檢測任務。# 3.json_file :標注文件的路徑,通常是一個 JSON 文件。def __init__(self, *args, task="detect", json_file, **kwargs):# 初始化 GroundingDataset 以進行對象檢測,從指定的 JSON 文件加載注釋。"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""# 使用 assert 確保任務類型為 "detect" ,因為目前該類僅支持目標檢測任務。 如果任務類型不是 "detect" ,拋出異常。assert task == "detect", "`GroundingDataset` only support `detect` task for now!" # `GroundingDataset` 目前僅支持 `detect` 任務!# 將標注文件路徑存儲在實例變量 self.json_file 中。self.json_file = json_file# 調用父類 YOLODataset 的構造函數,初始化父類的屬性和方法。 傳遞一個空字典 data={} 作為 數據集配置 ,因為標注信息將從 JSON 文件中加載。super().__init__(*args, task=task, data={}, **kwargs)# 定義了一個空的 get_img_files 方法,用于獲取圖像文件路徑。def get_img_files(self, img_path):# 圖像文件將在“get_labels”函數中讀取,在此處返回空列表。"""The image files would be read in `get_labels` function, return empty list here."""# 目前該方法返回一個空列表,可能需要根據具體需求實現。return []# 這段代碼定義了 GroundingDataset 類中的 get_labels 方法,用于從標注文件(JSON 格式)中加載和處理目標檢測任務的標簽信息。# 定義了 get_labels 方法,用于從標注文件中加載和格式化標簽信息。def get_labels(self):# 從 JSON 文件加載注釋,過濾并規范化每個圖像的邊界框。"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""# 初始化一個空列表 labels ,用于 存儲處理后的標簽信息 。labels = []# 使用 LOGGER.info 輸出加載標注文件的信息。LOGGER.info("Loading annotation file...") # 正在加載注釋文件...# 打開標注文件( self.json_file ),并使用 json.load 加載其內容。with open(self.json_file) as f:annotations = json.load(f)# 遍歷 標注文件 中的 "images" 部分,將每張圖像的信息存儲在一個字典中,鍵為圖像的 ID,值為圖像的詳細信息。images = {f"{x['id']:d}": x for x in annotations["images"]}# 使用 defaultdict 創建一個默認值為空列表的字典 img_to_anns 。img_to_anns = defaultdict(list)# 遍歷 標注文件 中的 "annotations" 部分,將每個標注( ann )添加到對應圖像 ID 的列表中。for ann in annotations["annotations"]:img_to_anns[ann["image_id"]].append(ann)# 這段代碼是 get_labels 方法的核心部分,用于處理每個圖像的標注信息。它從標注文件中提取邊界框和類別信息,并進行必要的預處理。# 使用 TQDM 創建一個進度條,顯示處理標注文件的進度。 遍歷 img_to_anns 字典中的每個條目。 img_id 是圖像的 ID。 anns 是該圖像的所有標注信息(一個列表)。for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):# 提取圖像信息。 從 images 字典中提取 當前圖像的信息 。img = images[f"{img_id:d}"]# h :圖像的高度。# w :圖像的寬度。# f :圖像的文件名。h, w, f = img["height"], img["width"], img["file_name"]# 使用 Path 對象構造 圖像文件的完整路徑 。 self.img_path 是圖像文件所在的目錄。 f 是圖像文件名。im_file = Path(self.img_path) / f# 檢查圖像文件是否存在。如果文件不存在,跳過當前圖像的處理。if not im_file.exists():continue# 將 圖像文件路徑 添加到 self.im_files 列表中,供后續加載圖像使用。self.im_files.append(str(im_file))# 初始化以下列表和字典。# 用于存儲邊界框信息。bboxes = []# 用于存儲類別名稱到類別 ID 的映射。cat2id = {}# 用于存儲文本信息(類別名稱)。texts = []# 遍歷當前圖像的所有標注信息 anns 。for ann in anns:# 如果標注標記為 "iscrowd" (表示人群標注),跳過該標注。if ann["iscrowd"]:continue# 提取標注中的邊界框信息。# ann["bbox"] 是一個列表 [x, y, w, h] ,表示邊界框的左上角坐標和寬度、高度。box = np.array(ann["bbox"], dtype=np.float32)# 將邊界框從 [x, y, w, h] 格式轉換為 [cx, cy, w, h] 格式(中心點坐標)。 box[:2] += box[2:] / 2 :將左上角坐標 (x, y) 轉換為中心點坐標 (cx, cy) 。box[:2] += box[2:] / 2# 將邊界框坐標歸一化到 [0, 1] 范圍內。# 將 x 和 w 除以圖像寬度。box[[0, 2]] /= float(w)# 將 y 和 h 除以圖像高度。box[[1, 3]] /= float(h)# 檢查邊界框的寬度和高度是否為正數。如果寬度或高度為零或負數,跳過該標注。if box[2] <= 0 or box[3] <= 0:continue# 這段代碼的功能是。遍歷圖像標注:從 img_to_anns 中提取每個圖像的標注信息。提取圖像信息:獲取圖像的高度、寬度和文件名。檢查圖像文件是否存在:跳過不存在的圖像文件。初始化邊界框和類別信息:為當前圖像準備存儲邊界框和類別信息的結構。遍歷標注:處理每個標注,跳過人群標注。提取邊界框信息:將邊界框從 [x, y, w, h] 格式轉換為 [cx, cy, w, h] 格式,并歸一化坐標。檢查邊界框的有效性:跳過寬度或高度為零的邊界框。通過這些步驟,代碼能夠高效地處理每個圖像的標注信息,為后續的目標檢測任務準備數據。# 這段代碼的功能是從標注信息中提取文本描述(caption)和類別名稱,并將它們與邊界框信息關聯起來,最終構建用于目標檢測的標簽數據。# 從圖像信息字典 img 中提取 描述文本 ( caption ),通常是一個字符串,描述圖像的內容。caption = img["caption"]# 遍歷標注 ann 中的 tokens_positive ,這是一個列表,包含描述文本中與當前標注相關的文本片段的起始和結束索引。 使用列表推導式從 caption 中提取這些片段,并將它們拼接為一個完整的類別名稱 cat_name 。cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])# 如果類別名稱 cat_name 不在 cat2id 字典中。if cat_name not in cat2id:# 為其分配一個唯一的類別 ID(當前 cat2id 的長度)。cat2id[cat_name] = len(cat2id)# 將 類別名稱 添加到 texts 列表中,用于后續存儲文本信息。texts.append([cat_name])# 從 cat2id 字典中獲取 當前類別名稱對應的類別 ID 。cls = cat2id[cat_name] # class# 將類別 ID 添加到邊界框信息的開頭,形成一個完整的邊界框標簽 [cls, cx, cy, w, h] 。box = [cls] + box.tolist()# 如果 當前邊界框信息 尚未存在于 bboxes 列表中,將其添加進去。 這一步確保每個邊界框是唯一的,避免重復。if box not in bboxes:bboxes.append(box)# 如果 bboxes 列表不為空,將其轉換為一個 NumPy 數組 lb ,數據類型為 float32 。 如果 bboxes 為空,創建一個形狀為 (0, 5) 的空數組,表示沒有邊界框。lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)# 構建一個標簽字典,包含以下信息 :labels.append({# 圖像文件路徑。"im_file": im_file,# 圖像的尺寸 (h, w) 。"shape": (h, w),# 類別 ID 列表(從 lb 的第一列提取)。"cls": lb[:, 0:1], # n, 1# 邊界框坐標列表(從 lb 的第二列到第五列提取)。"bboxes": lb[:, 1:], # n, 4# 標記邊界框坐標是否歸一化。"normalized": True,# 邊界框的格式( "xywh" ,表示中心點坐標和寬度、高度)。"bbox_format": "xywh",# 文本信息列表,包含類別名稱。"texts": texts,})# 這段代碼的功能是。提取圖像描述和類別名稱:從標注信息中提取與當前標注相關的文本片段,并拼接為類別名稱。構建類別到 ID 的映射:為每個類別名稱分配一個唯一的 ID,并存儲文本信息。構建邊界框信息:將類別 ID 和邊界框坐標組合成一個完整的標簽。添加邊界框到列表:確保每個邊界框是唯一的,避免重復。構建標簽數組:將邊界框信息轉換為 NumPy 數組。構建標簽字典:將處理后的信息存儲為一個字典,供后續數據加載和訓練使用。通過這種方式,代碼能夠高效地從標注文件中提取和格式化目標檢測任務的標簽信息,同時保留與類別相關的文本描述,適用于多模態任務(結合圖像和文本信息)。# 返回格式化后的標簽列表 labels 。return labels# 這段代碼的功能是。加載標注文件:從 JSON 文件中加載標注信息。構建圖像到標注的映射:將每個圖像的標注信息組織起來。處理每個圖像的標注:提取圖像信息(高度、寬度、文件名)。遍歷標注,提取邊界框和文本信息。將邊界框坐標歸一化,并構建類別到 ID 的映射。構建標簽字典:將處理后的信息存儲為一個字典,包含圖像路徑、圖像尺寸、類別、邊界框和文本信息。返回標簽列表:返回處理后的標簽列表,供后續數據加載和訓練使用。通過這種方式, get_labels 方法能夠高效地從標注文件中加載和格式化目標檢測任務的標簽信息,確保數據在訓練前的一致性和完整性。# 這段代碼定義了 GroundingDataset 類中的 build_transforms 方法,用于構建數據增強和格式化流程。# 定義了 build_transforms 方法,該方法接受一個參數。# 1.hyp :表示超參數配置(例如數據增強的參數)。def build_transforms(self, hyp=None):# 配置用于訓練的增強功能,并帶有可選的文本加載;`hyp` 調整增強強度。"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""# 調用父類 YOLODataset 的 build_transforms 方法,獲取基礎的數據增強流程。 父類方法會根據 hyp 參數和數據集的配置生成一個數據處理流程( transforms ),可能包括圖像調整大小、歸一化等操作。transforms = super().build_transforms(hyp)# 檢查是否啟用了數據增強( self.augment 為 True )。如果啟用,將繼續執行后續增強操作。if self.augment:# NOTE: hard-coded the args for now. 注意:目前參數是硬編碼的。# 在數據增強流程中插入一個 RandomLoadText 操作。# max_samples=80 :限制最多加載 80 個文本樣本。# padding=True :對文本數據進行填充,確保所有文本長度一致。# 使用 insert(-1, ...) 將該操作插入到數據增強流程的倒數第二位。這通常是為了確保在圖像增強操作之后加載文本數據。transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))# 返回 構建好的數據增強和格式化流程 transforms ,供后續數據加載和訓練使用。return transforms# 這段代碼的功能是。調用父類方法:獲取基礎的數據增強流程。檢查是否啟用數據增強:如果啟用,繼續添加額外的增強操作。添加文本加載增強:在數據增強流程中插入 RandomLoadText 操作,用于加載和處理文本數據。返回增強流程:返回完整的數據增強和格式化流程。通過這種方式, build_transforms 方法能夠靈活地擴展父類的數據增強功能,支持多模態任務(結合圖像和文本信息),為后續訓練提供了更豐富的數據處理能力。
# GroundingDataset 類是 YOLODataset 的擴展,專門用于處理基于標注文件(如 COCO 格式的 JSON 文件)的目標檢測任務,同時支持多模態數據(結合圖像和文本信息)。該類通過重寫 get_labels 方法,從標注文件中加載和格式化標簽信息,支持從描述文本(caption)中提取類別名稱,并將其與邊界框信息關聯。此外, GroundingDataset 在 build_transforms 方法中擴展了數據增強流程,加入了文本加載增強操作,以支持多模態任務。通過這些功能, GroundingDataset 提供了從標注文件加載數據、處理多模態信息以及應用數據增強的完整解決方案,適用于需要結合圖像和文本的目標檢測任務。
5.class YOLOConcatDataset(ConcatDataset):?
# 這段代碼定義了 YOLOConcatDataset 類,它是 ConcatDataset 的一個擴展,用于將多個數據集合并為一個數據集,同時繼承了 YOLODataset 的數據處理功能。
# 定義了 YOLOConcatDataset 類,繼承自 ConcatDataset 。 ConcatDataset 是 PyTorch 提供的一個類,用于將多個數據集合并為一個數據集。 YOLOConcatDataset 的目的是將多個 YOLODataset 實例合并,并統一處理數據。
class YOLOConcatDataset(ConcatDataset):# 數據集為多個數據集的串聯。# 此類可用于組裝不同的現有數據集。"""Dataset as a concatenation of multiple datasets.This class is useful to assemble different existing datasets."""# 定義了一個靜態方法 collate_fn ,用于將一個批次( batch )的數據合并為一個統一的張量格式。@staticmethod# 1.batch :一個列表,其中每個元素是一個樣本的數據(通常是一個字典)。def collate_fn(batch):# 將數據樣本整理成批次。"""Collates data samples into batches."""# 調用 YOLODataset 類中的 collate_fn 方法,將批次數據合并為統一的張量格式。 YOLODataset.collate_fn 方法負責將圖像、標簽等數據合并為批次張量,并處理批量索引等信息。return YOLODataset.collate_fn(batch)
# 這段代碼的功能是。定義 YOLOConcatDataset 類:繼承自 ConcatDataset ,用于合并多個數據集。實現 collate_fn 方法:通過調用 YOLODataset.collate_fn ,將批次數據合并為統一的張量格式。統一數據處理:確保合并后的數據集在數據加載和處理時保持一致的格式。通過這種方式, YOLOConcatDataset 類能夠將多個數據集合并為一個數據集,同時繼承 YOLODataset 的數據處理邏輯,適用于需要合并多個數據集進行訓練的場景。
6.class SemanticDataset(BaseDataset):?
# TODO: support semantic segmentation TODO:支持語義分割。
# 這段代碼定義了一個名為 SemanticDataset 的類,它繼承自 BaseDataset ,用于處理語義分割任務的數據集
# 定義了 SemanticDataset 類,繼承自 BaseDataset 。 BaseDataset 是一個基礎類,通常包含數據加載和預處理的通用方法。 SemanticDataset 的目的是為語義分割任務提供特定的數據處理邏輯。
class SemanticDataset(BaseDataset):# 語義分割數據集。# 此類負責處理用于語義分割任務的數據集。它從 BaseDataset 類繼承功能。# 注意:# 此類當前為占位符,需要填充方法和屬性以支持語義分割任務。"""Semantic Segmentation Dataset.This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalitiesfrom the BaseDataset class.Note:This class is currently a placeholder and needs to be populated with methods and attributes for supportingsemantic segmentation tasks."""# 定義了 SemanticDataset 類的構造函數,用于初始化數據集類的實例。def __init__(self):# 初始化 SemanticDataset 對象。"""Initialize a SemanticDataset object."""# 調用父類 BaseDataset 的構造函數,初始化父類的屬性和方法。這一步確保 SemanticDataset 繼承了 BaseDataset 的所有功能,例如數據路徑的設置、圖像加載等。super().__init__()
# 這段代碼的功能是。定義 SemanticDataset 類:繼承自 BaseDataset ,用于處理語義分割任務的數據集。初始化父類:通過調用父類的構造函數,確保繼承了基礎的數據加載和預處理功能。通過這種方式, SemanticDataset 類能夠利用 BaseDataset 提供的基礎功能,同時可以在此基礎上擴展語義分割任務特定的邏輯,例如加載分割掩碼、處理類別映射等。
7.class ClassificationDataset:?
# 這段代碼定義了 ClassificationDataset 類,用于處理圖像分類任務的數據集。它支持數據增強、內存緩存和磁盤緩存功能,并通過緩存機制提高數據加載效率。
# 定義了 ClassificationDataset 類,用于處理圖像分類任務的數據集。
class ClassificationDataset:# 擴展 torchvision ImageFolder 以支持 YOLO 分類任務,提供圖像增強、緩存和驗證等功能。它旨在高效處理用于訓練深度學習模型的大型數據集,并具有可選的圖像轉換和緩存機制以加快訓練速度。# 此類允許使用 torchvision 和 Albumentations 庫進行增強,并支持在 RAM 或磁盤上緩存圖像以減少訓練期間的 IO 開銷。此外,它還實現了強大的驗證過程以確保數據的完整性和一致性。"""Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like imageaugmentation, caching, and verification. It's designed to efficiently handle large datasets for training deeplearning models, with optional image transformations and caching mechanisms to speed up training.This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching imagesin RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification processto ensure data integrity and consistency.Attributes:cache_ram (bool): Indicates if caching in RAM is enabled.cache_disk (bool): Indicates if caching on disk is enabled.samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cachefile (if caching on disk), and optionally the loaded image array (if caching in RAM).torch_transforms (callable): PyTorch transforms to be applied to the images."""# 這段代碼定義了 ClassificationDataset 類的構造函數 __init__ ,用于初始化圖像分類任務的數據集。# 定義了構造函數,接受以下參數 :# 1.root :數據集的根目錄路徑。# 2.args :包含數據集配置和訓練參數的對象。# 3.augment :布爾值,表示是否啟用數據增強,默認為 False 。# 4.prefix :日志信息的前綴,默認為空字符串。def __init__(self, root, args, augment=False, prefix=""):# 使用 root、圖像大小、增強和緩存設置初始化 YOLO 對象。"""Initialize YOLO object with root, image size, augmentations, and cache settings.Args:root (str): Path to the dataset directory where images are stored in a class-specific folder structure.args (Namespace): Configuration containing dataset-related settings such as image size, augmentationparameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fractionof data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification anddebugging. Default is an empty string."""# 導入 torchvision 模塊,用于加載和處理圖像數據集。這里通過局部導入的方式,避免在全局范圍內加載 torchvision ,從而加快 ultralytics 模塊的導入速度。import torchvision # scope for faster 'import ultralytics'# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import 將基類指定為屬性,而不是用作基類,以允許對緩慢的 torchvision 導入進行范圍界定。# 檢查 torchvision 的版本是否為 0.18 或更高。if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18# torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=None, is_valid_file=None)# torchvision.datasets.ImageFolder 是 PyTorch 的一個類,它提供了一種方便的方式來加載結構化存儲的圖像數據集。這種結構化存儲意味著圖像被組織在不同的文件夾中,每個文件夾的名稱對應一個類別。# 參數 :# root :數據集的根目錄路徑,其中包含所有類別的子文件夾。# transform :一個可選的函數或可調用對象,用于對圖像進行預處理或數據增強。它在圖像加載后、返回前應用于圖像。# target_transform :一個可選的函數或可調用對象,用于對標簽進行預處理。它在標簽加載后、返回前應用于標簽。# loader :一個函數,用于加載圖像文件。默認情況下,使用 PIL 庫加載圖像。# is_valid_file :一個函數,用于檢查文件名是否有效。如果提供,它將被用于過濾文件。# 返回值 :# 返回一個 ImageFolder 實例,該實例包含圖像數據集的加載和預處理邏輯。# ImageFolder 類是 PyTorch 中處理圖像分類任務時常用的工具之一,它簡化了數據加載和預處理的過程,使得用戶可以專注于模型的訓練和評估。# torchvision.datasets.ImageFolder 類的實例通常包含以下常見的屬性 :# root : 字符串,表示數據集的根目錄路徑。# samples : 列表,包含數據集中所有圖像的元組信息,通常每個元組包含圖像的路徑和對應的標簽索引。# 如果是,使用 allow_empty=True 參數初始化 ImageFolder , 允許加載空文件夾 。self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)# 否則,使用默認參數初始化 ImageFolder 。else:self.base = torchvision.datasets.ImageFolder(root=root)# 初始化樣本列表和根目錄。# 從基礎數據集中獲取 樣本列表 ,每個樣本是一個元組 (路徑, 類別索引) 。self.samples = self.base.samples# 數據集的 根目錄路徑 。self.root = self.base.root# Initialize attributes# 如果啟用了數據增強且 args.fraction 小于 1.0,減少訓練數據的比例。 args.fraction 表示訓練數據的比例,例如 0.5 表示使用一半的數據。if augment and args.fraction < 1.0: # reduce training fractionself.samples = self.samples[: round(len(self.samples) * args.fraction)]# 初始化日志前綴。如果提供了 prefix ,將其格式化為彩色字符串并存儲在 self.prefix 中。 如果未提供 prefix ,則為空字符串。self.prefix = colorstr(f"{prefix}: ") if prefix else ""# 檢查是否啟用 內存緩存 ( cache_ram )。如果 args.cache 是 True 或字符串 "ram" ,啟用內存緩存。self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM# 如果啟用了內存緩存,發出警告,說明存在已知的內存泄漏問題。if self.cache_ram:LOGGER.warning("WARNING ?? Classification `cache_ram` training has known memory leak in " # 警告??分類`cache_ram`訓練在https://github.com/ultralytics/ultralytics/issues/9824中存在已知內存泄漏,設置`cache_ram=False`。"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`.")# 并將 cache_ram 設置為 False 。self.cache_ram = False# 檢查是否啟用 磁盤緩存 ( cache_disk )。如果 args.cache 是字符串 "disk" ,啟用磁盤緩存。self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files# 調用 verify_images 方法,過濾掉損壞的圖像。self.samples = self.verify_images() # filter out bad images# 為每個樣本添加 兩個額外的字段 。# Path(x[0]).with_suffix(".npy") :圖像對應的 .npy 文件路徑,用于緩存。# None :用于存儲緩存的圖像數據。self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im# 定義了 圖像縮放的范圍 scale ,表示圖像在預處理時可以被縮放的比例。 args.scale 是一個參數,表示圖像縮放的最小比例(例如 0.08 表示圖像可以縮小到原始尺寸的 8%)。 scale 的范圍是從 1.0 - args.scale 到 1.0 ,即圖像可以縮小到最小比例,但不會放大。scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)# 根據是否啟用數據增強( augment 參數),選擇合適的數據處理流程。self.torch_transforms = (# 如果啟用數據增強( augment=True ),調用 classify_augmentations 函數。# def classify_augmentations(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, scale=None, ratio=None, hflip=0.5, vflip=0.0, auto_augment=None, hsv_h=0.015, hsv_s=0.4, hsv_v=0.4, force_color_jitter=False, erasing=0.0, interpolation="BILINEAR",):# -> 用于生成圖像分類任務中的數據增強(Data Augmentation)變換序列。將 主變換列表 primary_tfl 、輔助變換列表 secondary_tfl 和最終變換列表 final_tfl 合并為一個完整的變換序列。 使用 T.Compose 將所有變換組合在一起,返回一個可以應用于圖像的變換對象。# -> return T.Compose(primary_tfl + secondary_tfl + final_tfl)classify_augmentations(# 指定圖像的最終尺寸。size=args.imgsz,# 指定圖像縮放的范圍。scale=scale,# 是否啟用水平翻轉增強。hflip=args.fliplr,# 是否啟用垂直翻轉增強。vflip=args.flipud,# 是否啟用隨機擦除增強。erasing=args.erasing,# 是否啟用自動增強策略(如 AutoAugment 或 RandAugment)。auto_augment=args.auto_augment,# HSV 色彩空間中色調(H)的變化范圍。hsv_h=args.hsv_h,# HSV 色彩空間中飽和度(S)的變化范圍。hsv_s=args.hsv_s,# HSV 色彩空間中亮度(V)的變化范圍。hsv_v=args.hsv_v,)if augment# 如果不啟用數據增強( augment=False ),調用 classify_transforms 函數。# size=args.imgsz :指定圖像的最終尺寸。# crop_fraction=args.crop_fraction :指定裁剪比例,用于中心裁剪或隨機裁剪。# def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION,):# -> 用于生成圖像分類任務中常用的圖像預處理流程。使用 torchvision.transforms.Compose 將所有變換組合成一個完整的變換序列。 返回的 T.Compose 對象可以作為 PyTorch 數據預處理管道的一部分,對輸入圖像依次應用所有定義的變換。# -> return T.Compose(tfl)else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction))# 這段代碼的功能是。初始化基礎數據集:使用 torchvision.datasets.ImageFolder 加載數據集。初始化樣本列表和根目錄:從基礎數據集中獲取樣本列表和根目錄路徑。減少訓練數據比例:根據 args.fraction 減少訓練數據的比例。初始化日志前綴:設置日志信息的前綴。初始化內存緩存:檢查是否啟用內存緩存,并發出警告(如果存在內存泄漏問題)。初始化磁盤緩存:檢查是否啟用磁盤緩存。驗證圖像:過濾掉損壞的圖像。初始化樣本列表:為每個樣本添加緩存路徑和緩存數據字段。初始化數據增強:根據是否啟用數據增強,選擇合適的預處理流程。通過這種方式, ClassificationDataset 類能夠高效地加載和處理圖像分類任務的數據集,支持數據增強和緩存功能,提高數據加載效率。# 這段代碼定義了 ClassificationDataset 類中的 __getitem__ 方法,用于獲取單個樣本的數據。# 定義了 __getitem__ 方法,用于根據索引 1.i 獲取單個樣本的數據。這是 PyTorch 數據集類的標準方法,用于在數據加載器中迭代數據。def __getitem__(self, i):# 返回與給定索引相對應的數據子集和目標。"""Returns subset of data and targets corresponding to given indices."""# 從 self.samples 列表中提取第 i 個樣本的信息。# f :圖像文件路徑。# j :類別索引。# fn :緩存的 .npy 文件路徑。# im :緩存的圖像數據(如果啟用了緩存)。f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image# 如果啟用了 內存緩存 ( self.cache_ram 為 True )。if self.cache_ram:# 檢查 im 是否為 None 。如果是,則加載圖像并緩存到 self.samples[i][3] 中。if im is None: # Warning: two separate if statements required here, do not combine this with previous line 注意:這里需要兩個獨立的 if 語句,不能與前面的提取樣本信息合并,否則會導致邏輯錯誤。im = self.samples[i][3] = cv2.imread(f)# 如果啟用了磁盤緩存( self.cache_disk 為 True )。elif self.cache_disk:# 檢查緩存的 .npy 文件是否存在。如果不存在,則加載圖像并保存為 .npy 文件。if not fn.exists(): # load npynp.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)# 加載緩存的 .npy 文件。im = np.load(fn)# 如果未啟用緩存,直接讀取圖像文件。else: # read imageim = cv2.imread(f) # BGR# Convert NumPy array to PIL image# 將圖像從 BGR 格式轉換為 RGB 格式。 將 NumPy 數組轉換為 PIL 圖像格式,以便后續處理。im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))# 使用 self.torch_transforms 對圖像進行數據增強或預處理。 self.torch_transforms 是在構造函數中初始化的,包含一系列 PyTorch 圖像處理操作。sample = self.torch_transforms(im)# 返回一個字典,包含 處理后的圖像數據 和 類別索引 。# "img" :處理后的圖像張量。# "cls" :類別索引。return {"img": sample, "cls": j}# 這段代碼的功能是。提取樣本信息:從 self.samples 中提取圖像路徑、類別索引、緩存路徑和緩存圖像。內存緩存邏輯:如果啟用了內存緩存,加載或緩存圖像到內存。磁盤緩存邏輯:如果啟用了磁盤緩存,加載或緩存圖像到磁盤。讀取圖像:如果未啟用緩存,直接讀取圖像。轉換圖像格式:將圖像從 BGR 格式轉換為 RGB 格式,并轉換為 PIL 圖像。應用數據增強或預處理:使用 self.torch_transforms 對圖像進行處理。返回樣本數據:返回處理后的圖像和類別索引。通過這種方式, __getitem__ 方法能夠高效地加載和處理單個樣本的數據,支持內存緩存和磁盤緩存功能,提高數據加載效率。# 這段代碼定義了 ClassificationDataset 類中的 __len__ 方法,用于返回數據集的樣本數量。# 定義了 __len__ 方法,這是 Python 中的一個特殊方法,用于返回對象的長度。在這里,它返回數據集的樣本數量。返回值類型為 int ,表示樣本數量是一個整數。def __len__(self) -> int:# 返回數據集中的樣本總數。"""Return the total number of samples in the dataset."""# 使用 len() 函數獲取 self.samples 列表的長度,即數據集中樣本的數量。 self.samples 是一個列表,每個元素表示一個樣本的信息(例如圖像路徑和類別索引)。return len(self.samples)# 這段代碼的功能是。定義 __len__ 方法:返回數據集的樣本數量。返回樣本數量:通過 len(self.samples) 獲取樣本數量。通過這種方式, __len__ 方法能夠提供數據集的大小信息,這在 PyTorch 數據加載器中非常有用,例如在訓練循環中確定數據集的迭代次數。# 這段代碼定義了 ClassificationDataset 類中的 verify_images 方法,用于驗證數據集中的圖像文件是否有效,并生成或加載緩存文件以提高驗證效率。# 定義了 verify_images 方法,用于驗證數據集中的圖像文件是否有效,并過濾掉損壞的圖像。def verify_images(self):# 驗證數據集中的所有圖像。"""Verify all images in dataset."""# 初始化描述信息和緩存路徑。# 描述信息,用于進度條顯示,說明正在掃描的數據集路徑。desc = f"{self.prefix}Scanning {self.root}..." # {self.prefix}正在掃描 {self.root}...# 緩存文件的路徑,文件名為 .cache ,存儲在數據集根目錄下。path = Path(self.root).with_suffix(".cache") # *.cache file path# 這段代碼的功能是嘗試加載緩存文件,并驗證其內容是否與當前數據集一致。如果緩存文件有效,它會返回緩存中的樣本列表。# 嘗試加載指定路徑 path 的緩存文件。try:# load_dataset_cache_file 是一個函數,用于加載緩存文件并返回其內容。# def load_dataset_cache_file(path): -> 它從指定路徑加載一個以 .cache 結尾的文件,并將其內容解析為一個字典。返回加載并解析后的字典對象,即 .cache 文件的內容。 -> return cachecache = load_dataset_cache_file(path) # attempt to load a *.cache file# 檢查緩存文件的版本號是否與當前數據集的版本號一致。 DATASET_CACHE_VERSION 是一個常量,表示當前數據集的版本。 如果版本號不匹配,拋出 AssertionError 。assert cache["version"] == DATASET_CACHE_VERSION # matches current version# 檢查緩存文件的哈希值是否與當前數據集的哈希值一致。 get_hash 是一個函數,用于計算數據集的哈希值,通常基于圖像路徑列表。 如果哈希值不匹配,拋出 AssertionError 。assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash# 從緩存字典中提取統計結果。# nf :有效圖像數量。# nc :損壞圖像數量。# n :總圖像數量。# samples :有效樣本列表。nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total# 如果當前進程是主進程( LOCAL_RANK 為 -1 或 0 )。if LOCAL_RANK in {-1, 0}:d = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 圖像,{nc} 損壞。# 使用 TQDM 顯示緩存信息。TQDM(None, desc=d, total=n, initial=n)# 如果緩存中有消息( cache["msgs"] ),將這些消息輸出到日志。if cache["msgs"]:LOGGER.info("\n".join(cache["msgs"])) # display warnings# 返回 緩存中的有效樣本列表 。return samples# 這段代碼的功能是。嘗試加載緩存文件:加載指定路徑的緩存文件。驗證緩存文件的版本和哈希值:確保緩存文件與當前數據集一致。提取緩存結果:從緩存文件中提取統計結果和樣本列表。顯示緩存信息:在主進程中顯示緩存信息,并輸出驗證消息。返回有效樣本列表:返回緩存中的有效樣本列表,供后續使用。通過這種方式,代碼能夠高效地利用緩存文件,避免重復驗證數據集,提高數據加載效率。# 這段代碼的功能是處理緩存文件加載失敗的情況,并重新運行圖像驗證流程,生成新的緩存文件。# 捕獲三種可能的異常。# FileNotFoundError :緩存文件不存在。# AssertionError :緩存文件的版本或哈希值與當前數據集不匹配。# AttributeError :緩存文件格式錯誤或缺少關鍵字段。except (FileNotFoundError, AssertionError, AttributeError):# Run scan if *.cache retrieval failed# 初始化驗證統計變量。# nf :有效圖像數量。# nc :損壞圖像數量。# msgs :驗證消息列表。# samples :有效樣本列表。# x :緩存字典,用于存儲驗證結果。nf, nc, msgs, samples, x = 0, 0, [], [], {}# 使用 ThreadPool 創建多線程池,驗證每個樣本。with ThreadPool(NUM_THREADS) as pool:# verify_image 是驗證函數,檢查圖像是否損壞。# zip(self.samples, repeat(self.prefix)) :將 樣本信息 和 日志前綴 傳遞給驗證函數。results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))# 使用 TQDM 創建進度條,顯示驗證進度。pbar = TQDM(results, desc=desc, total=len(self.samples))# 遍歷進度條 pbar 中的 驗證結果 。每個結果是一個元組,包含以下內容 :# sample :當前樣本的信息(例如圖像路徑和類別索引)。# nf_f :布爾值,表示當前樣本是否有效( True 表示有效)。# nc_f :布爾值,表示當前樣本是否損壞( True 表示損壞)。# msg :驗證過程中生成的消息(例如警告信息)。for sample, nf_f, nc_f, msg in pbar:# 如果當前樣本有效( nf_f 為 True ),將其添加到 samples 列表中。 samples 列表 用于存儲所有有效的樣本信息 。if nf_f:samples.append(sample)# 如果驗證過程中生成了消息( msg 不為空),將其添加到 msgs 列表中。 msgs 列表用于 存儲所有驗證消息 ,通常包含警告或錯誤信息。if msg:msgs.append(msg)# 更新統計變量。# 有效圖像數量 。如果 nf_f 為 True , nf 加 1。nf += nf_f# 損壞圖像數量 。如果 nc_f 為 True , nc 加 1。nc += nc_f# 動態更新進度條的描述信息,顯示當前驗證的進度。# desc :初始描述信息,例如 "Scanning /path/to/dataset..." 。# nf :當前有效圖像的數量。# nc :當前損壞圖像的數量。pbar.desc = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 圖像,{nc} 損壞。# 關閉進度條。pbar.close()# 如果有驗證消息,將這些消息輸出到日志。if msgs:LOGGER.info("\n".join(msgs))# 計算當前數據集的哈希值。 將驗證結果存儲到緩存字典 x 中 :# 數據集的哈希值。x["hash"] = get_hash([x[0] for x in self.samples])# 包含統計結果( nf 、 nc 、總樣本數、有效樣本列表)。x["results"] = nf, nc, len(samples), samples# 驗證消息列表。x["msgs"] = msgs # warnings# 調用 save_dataset_cache_file 函數,將緩存文件保存到指定路徑。# # def save_dataset_cache_file(prefix, path, x, version): -> 它將一個字典 x 保存為一個以 .cache 結尾的文件,并將其存儲到指定路徑 path 。save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)# 返回驗證后的有效樣本列表。return samples# 這段代碼的功能是。捕獲異常:處理緩存文件加載失敗的情況。初始化驗證變量:為驗證流程準備統計變量和緩存字典。使用多線程驗證圖像:并行驗證每個樣本,檢查圖像是否損壞。處理驗證結果:統計有效和損壞的圖像數量,并收集驗證消息。保存驗證結果到緩存文件:將驗證結果和消息保存到緩存文件中,便于后續快速加載。返回有效樣本列表:返回驗證后的有效樣本列表,供后續數據加載和訓練使用。通過這種方式,代碼能夠高效地驗證數據集中的圖像文件,過濾掉損壞的圖像,并利用緩存機制提高驗證效率。# 這段代碼的功能是。加載緩存文件:嘗試加載緩存文件,驗證其版本和哈希值是否與當前數據集一致。顯示緩存信息:如果緩存文件有效,顯示緩存信息并返回緩存中的樣本列表。運行驗證:如果緩存文件無效或不存在,使用多線程驗證每個圖像文件是否損壞。保存緩存文件:將驗證結果保存到緩存文件中,便于后續快速加載。返回有效樣本列表:返回驗證后的有效樣本列表,供后續數據加載和訓練使用。通過這種方式, verify_images 方法能夠高效地驗證數據集中的圖像文件,過濾掉損壞的圖像,并利用緩存機制提高驗證效率。
# ClassificationDataset 類是一個用于圖像分類任務的數據集類,繼承自 torchvision.datasets.ImageFolder 。它提供了高效的數據加載和預處理功能,支持數據增強、內存緩存和磁盤緩存。通過緩存機制,該類能夠快速驗證圖像文件的有效性,并過濾掉損壞的圖像。此外,它還支持多種數據增強策略,如隨機翻轉、擦除、自動增強和 HSV 調整,以提高模型的泛化能力。通過靈活的配置和優化的數據處理流程, ClassificationDataset 類適用于大規模圖像分類任務的訓練和驗證。