圖像分類時,如果某個類別或者某些類別的數量遠大于其他類別的話,模型在計算的時候,更傾向于擬合數量更多的類別;因此,觀察類別數量以及對數據量多的類別進行截斷是很有必要的。
1.準備數據
數據的格式為標準的圖像分類數據集格式:根目錄下分為train和val兩個文件夾,每個文件夾下是以類別名命名的子文件夾,子文件夾中存放該類別的圖像:
.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── …
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/2.jpg
│ │ ├── ./datasets/train/B/3.jpg
│ │ ├── …
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── …
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/2.jpg
│ │ ├── ./datasets/val/B/3.jpg
│ │ ├── …
2.查看數據分布
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def count_images(directory, image_extensions):
    """Count the image files inside each class sub-folder of ``directory``.

    :param directory: Dataset split directory (e.g. ``train`` or ``val``) whose
        immediate sub-folders are named after the classes.
    :param image_extensions: Accepted file extensions (lower case, with the
        leading dot). Either a single string or an iterable of strings.
    :return: Dict mapping class name to the number of image files found in
        that class folder. Empty when ``directory`` is not a directory.
    """
    counts = {}
    # isdir() rather than exists(): a regular file passed by mistake must be
    # reported here instead of crashing os.listdir() below.
    if not os.path.isdir(directory):
        print(f"目錄不存在: {directory}")
        return counts

    # str.endswith() needs a str or a tuple; a bare string must not be split
    # into single characters by tuple().
    if isinstance(image_extensions, str):
        extensions = (image_extensions,)
    else:
        extensions = tuple(image_extensions)

    for class_name in sorted(os.listdir(directory)):
        class_path = os.path.join(directory, class_name)
        if not os.path.isdir(class_path):
            continue
        with os.scandir(class_path) as entries:
            # is_file() keeps sub-directories named like "x.jpg" out of the count.
            counts[class_name] = sum(
                1
                for entry in entries
                if entry.is_file() and entry.name.lower().endswith(extensions)
            )
    return counts


def count_images_in_single_directory(directory, image_extensions):
    """Count images per class when ``directory`` directly holds the class folders.

    Kept as a separate entry point for callers; the counting logic is the same
    as :func:`count_images`, so it simply delegates.

    :param directory: Directory whose immediate sub-folders are the classes.
    :param image_extensions: Accepted file extensions (string or iterable).
    :return: Dict mapping class name to number of image files.
    """
    return count_images(directory, image_extensions)


def autolabel(ax, rects):
    """Write each bar's height just above the bar.

    :param ax: Matplotlib axes the bars were drawn on.
    :param rects: Bar container (iterable of rectangles) returned by ``ax.bar``.
    """
    for rect in rects:
        height = rect.get_height()
        ax.annotate(
            f'{height}',
            xy=(rect.get_x() + rect.get_width() / 2, height),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha='center',
            va='bottom',
        )


def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):
    """Draw and save a bar chart of the number of images per class.

    With ``has_val`` two bars (train / validation) are drawn side by side for
    every class; otherwise a single bar per class is drawn from
    ``train_values`` and ``val_values`` is ignored.

    :param all_classes: List of class names (x-axis tick labels).
    :param train_values: Image count per class, aligned with ``all_classes``.
    :param val_values: Validation image count per class (used only when
        ``has_val`` is true).
    :param output_path: File path the figure is written to.
    :param has_val: Whether validation counts should be plotted.
    """
    x = np.arange(len(all_classes))  # bar positions, one per class
    width = 0.35  # bar width

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        if has_val:
            rects1 = ax.bar(x - width / 2, train_values, width, label='Train')
            rects2 = ax.bar(x + width / 2, val_values, width, label='Validation')
        else:
            rects1 = ax.bar(x, train_values, width, label='Count')
            rects2 = None

        ax.set_xlabel('Category')
        ax.set_ylabel('Number of Images')
        if has_val:
            ax.set_title('Number of Images in Each Category for Train and Validation')
        else:
            ax.set_title('Number of Images in Each Category')
        ax.set_xticks(x)
        ax.set_xticklabels(all_classes, rotation=45, ha='right')
        # Every bar container already carries its label, so one call covers
        # both the train/val and the single-series case.
        ax.legend()

        autolabel(ax, rects1)
        if rects2 is not None:
            autolabel(ax, rects2)

        fig.tight_layout()

        # Create the target folder when the caller points into a new directory.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"圖表已保存到 {output_path}")
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the whole
        # process, which adds up when the function is called repeatedly.
        plt.close(fig)


def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):
    """Print per-class counts and percentages plus overall totals.

    The table is sorted by image count in descending order. With ``save_csv``
    the same table is written to ``<sanitized dataset name>_statistics.csv``
    in the current working directory.

    :param counts_dict: Dict mapping class name to image count.
    :param dataset_name: Display name of the dataset (e.g. 'Train').
    :param save_csv: Whether to also save the table as a CSV file.
    """
    total_images = sum(counts_dict.values())
    num_classes = len(counts_dict)
    avg_per_class = total_images / num_classes if num_classes > 0 else 0

    # Share of every class in percent; 0 for an empty dataset to avoid
    # dividing by zero.
    proportions = [
        (count / total_images * 100) if total_images > 0 else 0
        for count in counts_dict.values()
    ]

    df = pd.DataFrame({
        '類別名稱': list(counts_dict.keys()),
        '圖像數量': list(counts_dict.values()),
        '占比 (%)': [f"{prop:.2f}" for prop in proportions],
    })
    df = df.sort_values(by='圖像數量', ascending=False)

    print(f"\n===== {dataset_name} 數據統計 =====")
    print(df.to_string(index=False))
    print(f"總圖像數量: {total_images}")
    print(f"類別數量: {num_classes}")
    print(f"平均每個類別的圖像數量: {avg_per_class:.2f}")

    if save_csv:
        # Lower-case the name and drop spaces/parentheses so it is usable as
        # part of a file name.
        sanitized_name = (
            dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")
        )
        csv_filename = f"{sanitized_name}_statistics.csv"
        # utf-8-sig so spreadsheet tools detect the encoding of the CJK headers.
        df.to_csv(csv_filename, index=False, encoding='utf-8-sig')
        print(f"統計表已保存為 {csv_filename}\n")


def main(
    dataset_root='datasets/device_cls_merge_manual_with_21w_1218',
    output_path='dataset_distribution.png',
    save_csv=False,
    image_extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'),
    output_path_single='dataset_distribution_single.png',
):
    """Report and plot the class distribution of an image-classification dataset.

    When ``dataset_root`` contains both a ``train`` and a ``val`` directory the
    two splits are counted and plotted side by side. Otherwise the
    sub-folders of ``dataset_root`` itself are treated as the classes.

    :param dataset_root: Root directory of the dataset.
    :param output_path: Chart file used for the train/val layout.
    :param save_csv: Whether the statistics tables are also saved as CSV.
    :param image_extensions: File extensions counted as images.
    :param output_path_single: Chart file used for the single-directory layout.
    """
    train_dir = os.path.join(dataset_root, 'train')
    val_dir = os.path.join(dataset_root, 'val')

    if os.path.isdir(train_dir) and os.path.isdir(val_dir):
        print("檢測到 'train' 和 'val' 目錄。統計訓練集和驗證集中的圖像數量...")
        train_counts = count_images(train_dir, image_extensions)
        val_counts = count_images(val_dir, image_extensions)

        # Union of the class names so a class missing from one split still
        # shows up (with a count of 0) in the chart.
        all_classes = sorted(set(train_counts) | set(val_counts))
        train_values = [train_counts.get(cls, 0) for cls in all_classes]
        val_values = [val_counts.get(cls, 0) for cls in all_classes]

        compute_and_display_statistics(train_counts, '訓練集 (Train)', save_csv=save_csv)
        compute_and_display_statistics(val_counts, '驗證集 (Validation)', save_csv=save_csv)

        print("繪制并保存訓練集和驗證集的圖表...")
        plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)
    else:
        print("未檢測到 'train' 和 'val' 目錄。將統計主目錄下的圖像數量...")
        main_counts = count_images_in_single_directory(dataset_root, image_extensions)

        all_classes = sorted(main_counts)
        main_values = [main_counts[cls] for cls in all_classes]

        compute_and_display_statistics(main_counts, '數據集 (Dataset)', save_csv=save_csv)

        print("繪制并保存主目錄的圖表...")
        plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)


if __name__ == "__main__":
    main()
下圖為原始數據集運行結果,可以看到數據存在嚴重不均衡問題
3.數據截斷
import os
import shutil
import random


def count_images(directory, image_extensions):
    """Collect the image file paths inside each class sub-folder of ``directory``.

    Class names and file names are sorted so that the result, and any seeded
    sampling done on it, does not depend on the order in which the file
    system happens to list the entries.

    :param directory: Dataset split directory (e.g. ``train`` or ``val``) whose
        immediate sub-folders are named after the classes.
    :param image_extensions: Accepted file extensions (lower case, with the
        leading dot). Either a single string or an iterable of strings.
    :return: Dict mapping class name to the list of image file paths in that
        class folder. Empty when ``directory`` is not a directory.
    """
    class_images = {}
    # isdir() rather than exists(): a regular file passed by mistake must be
    # reported here instead of crashing os.listdir() below.
    if not os.path.isdir(directory):
        print(f"目錄不存在: {directory}")
        return class_images

    # str.endswith() needs a str or a tuple; a bare string must not be split
    # into single characters by tuple().
    if isinstance(image_extensions, str):
        extensions = (image_extensions,)
    else:
        extensions = tuple(image_extensions)

    for class_name in sorted(os.listdir(directory)):
        class_path = os.path.join(directory, class_name)
        if not os.path.isdir(class_path):
            continue
        with os.scandir(class_path) as entries:
            # is_file() skips directories named like images, which would make
            # the later shutil.copy2() call fail.
            image_names = sorted(
                entry.name
                for entry in entries
                if entry.is_file() and entry.name.lower().endswith(extensions)
            )
        class_images[class_name] = [
            os.path.join(class_path, name) for name in image_names
        ]
    return class_images


def truncate_dataset(class_images, threshold, seed=42):
    """Cap every class at ``threshold`` images by random sampling.

    Classes with more than ``threshold`` images are reduced to a random sample
    of exactly ``threshold`` images; all other classes are kept unchanged.
    Sampling uses a private generator seeded with ``seed``, so the global
    ``random`` state of the caller is left untouched.

    :param class_images: Dict mapping class name to list of image paths.
    :param threshold: Maximum number of images kept per class.
    :param seed: Seed of the private random generator.
    :return: New dict mapping class name to the (possibly sampled) list of
        image paths. The lists are new objects, never the input lists.
    :raises ValueError: Raised by ``random.Random.sample`` when ``threshold``
        is negative and a class has to be sampled.
    """
    rng = random.Random(seed)
    truncated = {}
    for class_name, images in class_images.items():
        if len(images) > threshold:
            truncated[class_name] = rng.sample(images, threshold)
            print(f"類別 '{class_name}' 超過閾值 {threshold},已隨機選擇 {threshold} 張圖像。")
        else:
            # Copy so later edits to the result cannot alter the input dict.
            truncated[class_name] = list(images)
            print(f"類別 '{class_name}' 不超過閾值 {threshold},保留所有 {len(images)} 張圖像。")
    return truncated


def copy_images(truncated_data, subset, output_root):
    """Copy the selected images to ``output_root/subset/<class name>/``.

    File names are preserved and ``shutil.copy2`` is used for the copy.

    :param truncated_data: Dict mapping class name to list of image paths.
    :param subset: Name of the split folder, 'train' or 'val'.
    :param output_root: Root directory of the output dataset.
    """
    for class_name, images in truncated_data.items():
        dest_dir = os.path.join(output_root, subset, class_name)
        # Created even for an empty class so the class set stays the same.
        os.makedirs(dest_dir, exist_ok=True)
        for img_path in images:
            dest_path = os.path.join(dest_dir, os.path.basename(img_path))
            shutil.copy2(img_path, dest_path)
    print(f"'{subset}' 子集已復制到 {output_root}")


def main(
    input_dir='datasets/device_cls_merge_manual_with_21w_1218_train_val_224',
    output_dir='datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate',
    train_threshold=2000,
    val_threshold=400,
    image_extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'),
    random_seed=42,
):
    """Write a copy of the dataset in which every class is capped in size.

    Reads ``input_dir/train`` and ``input_dir/val``, caps each class at the
    given threshold by seeded random sampling, and copies the kept images to
    the same layout below ``output_dir``. The input dataset is not modified.

    :param input_dir: Root directory of the original dataset.
    :param output_dir: Root directory the truncated dataset is written to.
    :param train_threshold: Maximum images per class in the training split.
    :param val_threshold: Maximum images per class in the validation split.
    :param image_extensions: File extensions treated as images.
    :param random_seed: Seed used for the sampling of both splits.
    """
    train_input_dir = os.path.join(input_dir, 'train')
    val_input_dir = os.path.join(input_dir, 'val')

    print("統計訓練集中的圖像數量...")
    train_counts = count_images(train_input_dir, image_extensions)
    print("統計驗證集中的圖像數量...")
    val_counts = count_images(val_input_dir, image_extensions)

    print("\n截斷訓練集中的圖像...")
    truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)
    print("\n截斷驗證集中的圖像...")
    truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)

    print("\n復制截斷后的訓練集圖像...")
    copy_images(truncated_train, 'train', output_dir)
    print("復制截斷后的驗證集圖像...")
    copy_images(truncated_val, 'val', output_dir)

    print("\n數據集截斷完成。")


if __name__ == "__main__":
    main()
再次查看數據分布,可以看到截斷後各類別的數量已經符合預期了。