【圖像分類實用腳本】數據可視化以及高數量類別截斷

圖像分類時,如果某個類別或者某些類別的數量遠大于其他類別的話,模型在計算的時候,更傾向于擬合數量更多的類別;因此,觀察類別數量以及對數據量多的類別進行截斷是很有必要的。

1.準備數據

數據的格式為圖像分類數據集格式,根目錄下分為train和val文件夾,每個文件夾下以類別名命名的子文件夾:

.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── …
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── …
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── …
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── …

2.查看數據分布

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pddef count_images(directory, image_extensions):"""統計每個子文件夾中的圖像數量。:param directory: 主目錄路徑(train或val):param image_extensions: 允許的圖像文件擴展名元組:return: 一個字典,鍵為類別名,值為圖像數量"""counts = {}if not os.path.exists(directory):print(f"目錄不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 統計符合擴展名的文件數量image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef count_images_in_single_directory(directory, image_extensions):"""統計單個目錄下每個類別的圖像數量。:param directory: 主目錄路徑:param image_extensions: 允許的圖像文件擴展名元組:return: 一個字典,鍵為類別名,值為圖像數量"""counts = {}if not os.path.exists(directory):print(f"目錄不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef autolabel(ax, rects):"""在每個柱狀圖上方添加數值標簽。:param ax: Matplotlib 的軸對象:param rects: 柱狀圖對象"""for rect in rects:height = rect.get_height()ax.annotate(f'{height}',xy=(rect.get_x() + rect.get_width() / 2, height),xytext=(0, 3),  # 3 points vertical offsettextcoords="offset points",ha='center', va='bottom')def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):"""繪制并保存訓練集和驗證集中每個類別的圖像數量分布柱狀圖。如果沒有驗證集數據,則只繪制訓練集數據。:param all_classes: 所有類別名稱列表:param train_values: 訓練集中每個類別的圖像數量列表:param val_values: 驗證集中每個類別的圖像數量列表(如果有的話):param output_path: 保存圖表的文件路徑:param has_val: 是否包含驗證集數據"""x = np.arange(len(all_classes))  # 類別位置width = 0.35  # 柱狀圖的寬度fig, ax = plt.subplots(figsize=(12, 8))if has_val:rects1 = ax.bar(x - width/2, train_values, width, label='Train')rects2 = ax.bar(x + width/2, val_values, width, label='Validation')else:rects1 = ax.bar(x, train_values, width, label='Count')# 添加一些文本標簽ax.set_xlabel('Category')ax.set_ylabel('Number of Images')title = 'Number of Images in Each Category for Train and Validation' if has_val else 'Number of Images in Each Category'ax.set_title(title)ax.set_xticks(x)ax.set_xticklabels(all_classes, rotation=45, ha='right')ax.legend() if has_val else ax.legend(['Count'])# 自動標注柱狀圖上的數值autolabel(ax, rects1)if has_val:autolabel(ax, rects2)fig.tight_layout()# 保存圖表為圖片文件plt.savefig(output_path, dpi=300, bbox_inches='tight')print(f"圖表已保存到 {output_path}")def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):"""計算并展示統計數據,包括總圖像數量、類別數量、平均每個類別的圖像數量和類別占比。:param counts_dict: 類別名稱與圖像數量的字典:param dataset_name: 數據集名稱(例如 'Train', 'Validation', 'Dataset'):param save_csv: 是否保存統計結果為 CSV 文件"""total_images = sum(counts_dict.values())num_classes = len(counts_dict)avg_per_class = total_images / num_classes if num_classes > 0 else 0# 計算每個類別的占比category_proportions = {cls: (count / total_images * 100) if total_images > 0 else 0 for cls, count in counts_dict.items()}# 創建 DataFramedf = pd.DataFrame({'類別名稱': list(counts_dict.keys()),'圖像數量': list(counts_dict.values()),'占比 (%)': [f"{prop:.2f}" for prop in category_proportions.values()]})# 排序 DataFrame 按圖像數量降序df = df.sort_values(by='圖像數量', ascending=False)print(f"\n===== {dataset_name} 數據統計 =====")print(df.to_string(index=False))print(f"總圖像數量: {total_images}")print(f"類別數量: {num_classes}")print(f"平均每個類別的圖像數量: {avg_per_class:.2f}")# 根據 save_csv 參數決定是否保存為 CSV 文件if save_csv:# 將數據集名稱轉換為小寫并去除空格,以作為文件名的一部分sanitized_name = dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")csv_filename = f"{sanitized_name}_statistics.csv"df.to_csv(csv_filename, index=False, encoding='utf-8-sig')print(f"統計表已保存為 {csv_filename}\n")def main():# ================== 配置參數 ==================# 設置數據集的根目錄路徑dataset_root = 'datasets/device_cls_merge_manual_with_21w_1218'  # 替換為你的數據集路徑# 定義train和val目錄train_dir = os.path.join(dataset_root, 'train')val_dir = os.path.join(dataset_root, 'val')# 定義允許的圖像文件擴展名image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')# 輸出圖表的路徑output_path = 'dataset_distribution.png'  # 你可以更改為你想要的文件名和路徑# 是否保存統計結果為 CSV 文件(默認不保存)SAVE_CSV = False  # 設置為 True 以啟用保存 CSV# ================== 統計圖像數量 ==================has_train = os.path.exists(train_dir) and os.path.isdir(train_dir)has_val = os.path.exists(val_dir) and os.path.isdir(val_dir)if has_train and has_val:print("檢測到 'train' 和 'val' 目錄。統計訓練集和驗證集中的圖像數量...")train_counts = count_images(train_dir, image_extensions)val_counts = count_images(val_dir, image_extensions)# 獲取所有類別的名稱(確保train和val中的類別一致)all_classes = sorted(list(set(train_counts.keys()) | set(val_counts.keys())))# 準備繪圖數據train_values = [train_counts.get(cls, 0) for cls in all_classes]val_values = [val_counts.get(cls, 0) for cls in all_classes]# ================== 計算并展示統計數據 ==================compute_and_display_statistics(train_counts, '訓練集 (Train)', save_csv=SAVE_CSV)compute_and_display_statistics(val_counts, '驗證集 (Validation)', save_csv=SAVE_CSV)# ================== 繪制并保存圖表 ==================print("繪制并保存訓練集和驗證集的圖表...")plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)else:print("未檢測到 'train' 和 'val' 目錄。將統計主目錄下的圖像數量...")# 如果沒有train和val目錄,則統計主目錄下的圖像分布main_counts = count_images_in_single_directory(dataset_root, image_extensions)# 獲取所有類別的名稱all_classes = sorted(main_counts.keys())# 準備繪圖數據main_values = [main_counts.get(cls, 0) for cls in all_classes]# 定義輸出圖表路徑(可以區分不同的輸出文件名)output_path_single = 'dataset_distribution_single.png'  # 或者使用與train_val相同的output_path# ================== 計算并展示統計數據 ==================compute_and_display_statistics(main_counts, '數據集 (Dataset)', save_csv=SAVE_CSV)# ================== 繪制并保存圖表 ==================print("繪制并保存主目錄的圖表...")plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)if __name__ == "__main__":main()

下圖為原始數據集運行結果,可以看到數據存在嚴重不均衡問題
在這里插入圖片描述

3.數據截斷

import os
import shutil
import randomdef count_images(directory, image_extensions):"""統計每個子文件夾中的圖像文件路徑列表。:param directory: 主目錄路徑(train或val):param image_extensions: 允許的圖像文件擴展名列表:return: 一個字典,鍵為類別名,值為圖像文件路徑列表"""counts = {}if not os.path.exists(directory):print(f"目錄不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 獲取符合擴展名的文件列表images = [file for file in os.listdir(class_path)if file.lower().endswith(tuple(image_extensions))]image_paths = [os.path.join(class_path, img) for img in images]counts[class_name] = image_pathsreturn countsdef truncate_dataset(class_images, threshold, seed=42):"""對每個類別的圖像進行截斷,如果超過閾值則隨機選擇一定數量的圖像。:param class_images: 一個字典,鍵為類別名,值為圖像文件路徑列表:param threshold: 每個類別的圖像數量閾值:param seed: 隨機種子:return: 截斷后的類別圖像字典"""truncated = {}random.seed(seed)for class_name, images in class_images.items():if len(images) > threshold:truncated_images = random.sample(images, threshold)truncated[class_name] = truncated_imagesprint(f"類別 '{class_name}' 超過閾值 {threshold},已隨機選擇 {threshold} 張圖像。")else:truncated[class_name] = imagesprint(f"類別 '{class_name}' 不超過閾值 {threshold},保留所有 {len(images)} 張圖像。")return truncateddef copy_images(truncated_data, subset, output_root):"""將截斷后的圖像復制到輸出目錄,保持原有的目錄結構。:param truncated_data: 截斷后的類別圖像字典:param subset: 'train' 或 'val':param output_root: 輸出根目錄路徑"""for class_name, images in truncated_data.items():dest_dir = os.path.join(output_root, subset, class_name)os.makedirs(dest_dir, exist_ok=True)for img_path in images:img_name = os.path.basename(img_path)dest_path = os.path.join(dest_dir, img_name)shutil.copy2(img_path, dest_path)print(f"'{subset}' 子集已復制到 {output_root}")def main():"""主函數,執行數據集截斷和復制操作。"""# ================== 配置參數 ==================# 原始數據集根目錄路徑input_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224'  # 替換為你的原始數據集路徑# 截斷后數據集的輸出根目錄路徑output_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate'  # 替換為你希望保存截斷后數據集的路徑# 訓練集每個類別的圖像數量閾值train_threshold = 2000  # 設置為你需要的訓練集閾值# 驗證集每個類別的圖像數量閾值val_threshold = 400  # 設置為你需要的驗證集閾值# 允許的圖像文件擴展名image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']# 隨機種子以確保可重復性random_seed = 42# ================== 腳本實現 ==================# 設置隨機種子random.seed(random_seed)# 定義train和val目錄路徑train_input_dir = os.path.join(input_dir, 'train')val_input_dir = os.path.join(input_dir, 'val')# 統計train和val中的圖像print("統計訓練集中的圖像數量...")train_counts = count_images(train_input_dir, image_extensions)print("統計驗證集中的圖像數量...")val_counts = count_images(val_input_dir, image_extensions)# 截斷train和val中的圖像print("\n截斷訓練集中的圖像...")truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)print("\n截斷驗證集中的圖像...")truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)# 復制截斷后的圖像到輸出目錄print("\n復制截斷后的訓練集圖像...")copy_images(truncated_train, 'train', output_dir)print("復制截斷后的驗證集圖像...")copy_images(truncated_val, 'val', output_dir)print("\n數據集截斷完成。")if __name__ == "__main__":main()

再次查看已經符合截斷后的數據分布了
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63621.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63621.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63621.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Leetcode 每日一題】2545. 根據第 K 場考試的分數排序

問題背景 班里有 m m m 位學生,共計劃組織 n n n 場考試。給你一個下標從 0 0 0 開始、大小為 m n m \times n mn 的整數矩陣 s c o r e score score,其中每一行對應一位學生,而 s c o r e [ i ] [ j ] score[i][j] score[i][j] 表示…

React系列(八)——React進階知識點拓展

前言 在之前的學習中,我們已經知道了React組件的定義和使用,路由配置,組件通信等其他方法的React知識點,那么本篇文章將針對React的一些進階知識點以及React16.8之后的一些新特性進行講解。希望對各位有所幫助。 一、setState &am…

PCIe_Host驅動分析_地址映射

往期內容 本文章相關專欄往期內容,PCI/PCIe子系統專欄: 嵌入式系統的內存訪問和總線通信機制解析、PCI/PCIe引入 深入解析非橋PCI設備的訪問和配置方法 PCI橋設備的訪問方法、軟件角度講解PCIe設備的硬件結構 深入解析PCIe設備事務層與配置過程 PCIe的三…

【閱讀記錄-章節6】Build a Large Language Model (From Scratch)

文章目錄 6. Fine-tuning for classification6.1 Different categories of fine-tuning6.2 Preparing the dataset第一步:下載并解壓數據集第二步:檢查類別標簽分布第三步:創建平衡數據集第四步:數據集拆分 6.3 Creating data loa…

ip_output函數

ip_output函數是Linux內核(特別是網絡子系統)中用于發送IPv4數據包的核心函數。以下是一個示例實現,并附上詳細的中文講解: int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb) {struct iphdr *iph; /* 構建IP頭部 */iph = ip_hdr(skb);/* 設置服務…

梳理你的思路(從OOP到架構設計)_簡介設計模式

目錄 1、 模式(Pattern) 是較大的結構?編輯 2、 結構形式愈大 通用性愈小?編輯 3、 從EIT造形 組合出設計模式 1、 模式(Pattern) 是較大的結構 組合與創新 達芬奇說:簡單是複雜的終極形式 (Simplicity is the ultimate form of sophistication) —Leonardo d…

用SparkSQL和PySpark完成按時間字段順序將字符串字段中的值組合在一起分組顯示

用SparkSQL和PySpark完成以下數據轉換。 源數據: userid,page_name,visit_time 1,A,2021-2-1 2,B,2024-1-1 1,C,2020-5-4 2,D,2028-9-1 目的數據: user_id,page_name_path 1,C->A 2,B->D PySpark: from pyspark.sql import SparkSes…

【libuv】Fargo信令2:【深入】client為什么收不到服務端響應的ack消息

客戶端處理server的ack回復,判斷鏈接連接建立 【Fargo】28:字節序列【libuv】Fargo信令1:client發connect消息給到server客戶端啟動后理解監聽read消息 但是,這個代碼似乎沒有觸發ack消息的接收: // 客戶端初始化 void start_client(uv_loop_t

硬盤dma讀寫過程

pci初始化時,遍歷pci上的設置,如果BaseClassCode1,則為大容量存儲控制器,包括硬盤控制器、固態硬盤控制器、光盤驅動控制器、RAID控制器等。 BaseAdder4為DMA控制器基地址,包含兩個控制器,主控制器&#x…

Python-基于Pygame的小游戲(貪吃蛇)(一)

前言:貪吃蛇是一款經典的電子游戲,最早可以追溯到1976年的街機游戲Blockade。隨著諾基亞手機的普及,貪吃蛇游戲在1990年代變得廣為人知。它是一款休閑益智類游戲,適合所有年齡段的玩家,其最初為單機模式,后來隨著技術發…

使用k6進行MongoDB負載測試

1.安裝環境 安裝xk6-mongo擴展 ./xk6 build --with github.com/itsparser/xk6-mongo 2.安裝MongoDB 參考Docker安裝MongoDB服務-CSDN博客 連接成功后新建test數據庫和sample集合 3.編寫腳本 test_mongo.js import xk6_mongo from k6/x/mongo;const client xk6_mongo.new…

solon 集成 activemq-client (sdk)

原始狀態的 activemq-client sdk 集成非常方便&#xff0c;也更適合定制。就是有些同學&#xff0c;可能對原始接口會比較陌生&#xff0c;會希望有個具體的示例。 <dependency><groupId>org.apache.activemq</groupId><artifactId>activemq-client&l…

2024 年最新前端ES-Module模塊化、webpack打包工具詳細教程(更新中)

模塊化概述 什么是模塊&#xff1f;模塊是一個封裝了特定功能的代碼塊&#xff0c;可以獨立開發、測試和維護。模塊通過導出&#xff08;export&#xff09;和導入&#xff08;import&#xff09;與其他模塊通信&#xff0c;保持內部細節的封裝。 前端 JavaScript 模塊化是指…

uni-app商品搜索頁面

目錄 一:功能概述 二:功能實現 一:功能概述 商品搜索頁面,可以根據商品品牌,商品分類,商品價格等信息實現商品搜索和列表展示。 二:功能實現 1:商品搜索數據 <view class="search-map padding-main bg-base"> <view class…

最小堆及添加元素操作

【小白從小學Python、C、Java】 【考研初試復試畢業設計】 【Python基礎AI數據分析】 最小堆及添加元素操作 [太陽]選擇題 以下代碼執行的結果為&#xff1f; import heapq heap [] heapq.heappush(heap, 5) heapq.heappush(heap, 3) heapq.heappush(heap, 2) heapq.…

10. 考勤信息

題目描述 公司用一個字符串來表示員工的出勤信息 absent:缺勤late: 遲到leaveearly: 早退present: 正常上班 現需根據員工出勤信息&#xff0c;判斷本次是否能獲得出勤獎&#xff0c;能獲得出勤獎的條件如下: 缺勤不超過一次&#xff0c;沒有連續的遲到/早退:任意連續7次考勤&a…

【計算機網絡】期末考試預習復習|中

作業講解 轉發器、網橋、路由器和網關(4-6) 作為中間設備&#xff0c;轉發器、網橋、路由器和網關有何區別&#xff1f; (1) 物理層使用的中間設備叫做轉發器(repeater)。 (2) 數據鏈路層使用的中間設備叫做網橋或橋接器(bridge)。 (3) 網絡層使用的中間設備叫做路…

前端工程化-Vue腳手架安裝

在現代前端開發中&#xff0c;Vue.js已成為一個流行的框架&#xff0c;而Vue CLI&#xff08;腳手架&#xff09;則為開發者提供了一個方便的工具&#xff0c;用于快速創建和管理Vue項目。本文將詳細介紹如何安裝Vue腳手架&#xff0c;創建新項目以及常見問題的解決方法。 什么…

利用爬蟲獲取的數據能否用于商業分析?

在數字化時代&#xff0c;數據已成為企業獲取競爭優勢的關鍵資源。網絡爬蟲作為一種數據收集工具&#xff0c;能夠從互聯網上抓取大量數據&#xff0c;這些數據在商業分析中扮演著重要角色。然而&#xff0c;使用爬蟲技術獲取的數據是否合法、能否用于商業分析&#xff0c;是許…

羅德與施瓦茨ZN-Z129E網絡分析儀校準套件具體參數

羅德與施瓦茨ZN-Z129E網絡校準件ZN-Z129E網絡分析儀校準套件 1&#xff0c;頻率范圍從9kHz到4GHz&#xff08;ZNB4&#xff09;,8.5GHz(ZNB8)&#xff0c;20GHz(ZNB20)&#xff0c;40GHz(ZNB40) 2&#xff0c;動態范圍寬&#xff0c;高達140 dB 3&#xff0c;掃描時間短達4ms…