Recently a project at my company needed to deal with long-tail samples, so I tested PaddlePaddle's few-shot detection capability.
Environment: T4, Ubuntu, CUDA 11.6, Python 3.9, paddlepaddle-gpu==2.6.0, plus pip install opencv-python==4.5.5.64 -i https://pypi.tuna.tsinghua.edu.cn/simple and pip install numpy==1.23.0
Pretrained model: ppyoloe_crn_s_obj365_pretrained.pdparams
Dataset download: "Five fruits object detection dataset, COCO format" on the PaddlePaddle AI Studio community site
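Before training, it can help to confirm that the installed packages really match the versions listed above. The sketch below uses only the Python standard library; the check_pins helper and the PINS dict are my own illustration, not part of PaddlePaddle. For the Paddle GPU build itself, paddle.utils.run_check() is the built-in sanity check.

```python
from importlib import metadata

# Hypothetical pin list, copied from the environment line above.
PINS = {
    "paddlepaddle-gpu": "2.6.0",
    "opencv-python": "4.5.5.64",
    "numpy": "1.23.0",
}


def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for every pin that does not match."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the distribution is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed}")
    if not problems:
        print("All pinned packages match.")
```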
1. Dataset preparation: five fruit classes (tomato, walnut, tangerine, longan, green jujube), 300 images in total, 640×480, COCO format.
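To see the class balance before cutting the data down, a small summary of the COCO annotation file is handy. This is a sketch with hypothetical helper names (summarize_coco, summarize_file); it only relies on the standard COCO fields images (with width and height), annotations and categories.

```python
import json
from collections import Counter, defaultdict


def summarize_coco(coco):
    """Summarize a COCO-style dict: image count, image sizes, boxes and images per category."""
    id_to_name = {c["id"]: c["name"] for c in coco.get("categories", [])}
    sizes = Counter((img["width"], img["height"]) for img in coco.get("images", []))
    boxes_per_cat = Counter()
    images_per_cat = defaultdict(set)  # category name -> ids of images containing it
    for ann in coco.get("annotations", []):
        name = id_to_name.get(ann["category_id"], f"id_{ann['category_id']}")
        boxes_per_cat[name] += 1
        images_per_cat[name].add(ann["image_id"])
    return {
        "num_images": len(coco.get("images", [])),
        "sizes": dict(sizes),
        "boxes_per_category": dict(boxes_per_cat),
        "images_per_category": {k: len(v) for k, v in images_per_cat.items()},
    }


def summarize_file(path):
    """Load a COCO annotation file from disk and summarize it."""
    with open(path, "r", encoding="utf-8") as f:
        return summarize_coco(json.load(f))
```

Pass the path of your own train.json to summarize_file; for this dataset every entry in "sizes" should be (640, 480).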
2. First, a normal training run
Results: at step 165 the accuracy reached 0.735.
3. Use a script to extract 10 images per COCO category from the original train.json. The code:
import argparse
import json
import os
import sys
from collections import defaultdict


def create_small_sample_coco(original_json, output_json, samples_per_class=10):
    """Extract a given number of samples (images) per category from a COCO
    annotation file and write them to a new COCO annotation file.

    Args:
        original_json (str): path to the original COCO annotation file
        output_json (str): path of the few-shot COCO annotation file to write
        samples_per_class (int): number of images to pick per category
    """
    # Load the original annotations
    with open(original_json, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)

    # Make sure the required fields exist; fill in defaults if they are missing
    required_fields = {
        'info': {'description': 'Small sample dataset'},
        'licenses': [{'id': 0, 'name': 'Unknown'}],
        'categories': [],
        'images': [],
        'annotations': [],
    }
    for field, default in required_fields.items():
        if field not in coco_data:
            print(f"Warning: annotation file has no '{field}' field, using the default")
            coco_data[field] = default

    # 1. Count the annotations of each category
    category_counts = defaultdict(int)
    for ann in coco_data['annotations']:
        category_counts[ann['category_id']] += 1

    # Make sure there is at least one category
    if not category_counts:
        print("Error: no categories or annotations found in the annotation file")
        return

    # 2. Pick the requested number of images for each category
    selected_images = set()              # image_ids that have been picked
    category_samples = defaultdict(int)  # images picked so far per category
    for ann in coco_data['annotations']:
        cat_id = ann['category_id']
        img_id = ann['image_id']
        # Pick the image if this category still needs samples and the image is not taken yet
        if category_samples[cat_id] < samples_per_class and img_id not in selected_images:
            selected_images.add(img_id)
            category_samples[cat_id] += 1
        # Stop early only once EVERY category in the file has enough samples
        # (checking just the categories seen so far can stop before later classes get any)
        if all(category_samples[c] >= samples_per_class for c in category_counts):
            break

    # 3. Keep the selected images and all of their annotations
    filtered_images = [img for img in coco_data['images'] if img['id'] in selected_images]
    filtered_annotations = [ann for ann in coco_data['annotations']
                            if ann['image_id'] in selected_images]

    # 4. Build the new COCO dataset
    small_coco = {
        'info': coco_data['info'],
        'licenses': coco_data['licenses'],
        'categories': coco_data['categories'],
        'images': filtered_images,
        'annotations': filtered_annotations,
    }

    # 5. Save the new annotation file
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(small_coco, f, indent=2, ensure_ascii=False)

    # Print some statistics
    print("Small-sample dataset created!")
    print(f"Original number of images: {len(coco_data['images'])}")
    print(f"Images after filtering: {len(filtered_images)}")
    print(f"Samples per category: {samples_per_class}")
    print(f"Saved to: {output_json}")

    # Count the annotations actually kept for each category
    actual_counts = defaultdict(int)
    for ann in filtered_annotations:
        actual_counts[ann['category_id']] += 1

    # Map category IDs to category names
    id_to_name = {cat['id']: cat['name'] for cat in coco_data['categories']}
    print("\nAnnotations kept per category:")
    for cat_id, count in actual_counts.items():
        cat_name = id_to_name.get(cat_id, f"category_{cat_id}")
        print(f"  {cat_name} (ID: {cat_id}): {count} annotations")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Create a few-shot subset of a COCO dataset')
    parser.add_argument('--input', '-i', required=True, help='path to the original COCO annotation file')
    parser.add_argument('--output', '-o', required=True, help='path of the few-shot COCO annotation file to write')
    parser.add_argument('--samples', '-s', type=int, default=10, help='images per category (default: 10)')
    args = parser.parse_args()

    # Check that the input file exists
    if not os.path.exists(args.input):
        print(f"Error: input file '{args.input}' does not exist")
        sys.exit(1)

    # Create the output directory if it does not exist yet
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    create_small_sample_coco(args.input, args.output, args.samples)
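Before training on the subset, a quick sanity check of the generated file can catch a category that ended up under-sampled (for example, one with fewer than 10 images in the original train.json) or an annotation pointing at an image that was dropped. This is a sketch; check_small_coco is a hypothetical helper, and it counts every image that contains a category, so a class can legitimately show more than 10.

```python
import json
import sys
from collections import defaultdict


def check_small_coco(coco, samples_per_class=10):
    """Sanity-check a few-shot COCO dict and return a list of problems (empty if none)."""
    problems = []
    image_ids = {img["id"] for img in coco["images"]}
    images_per_cat = defaultdict(set)  # category_id -> ids of kept images containing it
    for ann in coco["annotations"]:
        if ann["image_id"] not in image_ids:
            problems.append(f"annotation {ann['id']} points to missing image {ann['image_id']}")
            continue
        images_per_cat[ann["category_id"]].add(ann["image_id"])
    for cat in coco["categories"]:
        n = len(images_per_cat.get(cat["id"], set()))
        if n < samples_per_class:
            problems.append(f"category {cat['name']} has only {n} image(s)")
    return problems


if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1], "r", encoding="utf-8") as f:
        issues = check_small_coco(json.load(f), samples_per_class=10)
    print("\n".join(issues) if issues else "annotation file looks consistent")
```

For example, save it as check_small.py (a hypothetical name) and run python check_small.py /home/PaddleDetection/dataset/small.json.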
4. Train again
python tools/train.py -c configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml --amp --eval --use_vdl=True --vdl_log_dir=./visdrone/
At step 39 the accuracy reached 0.69.
5. Run a prediction
python tools/infer.py -c configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml -o weights=output1/best_model.pdparams --infer_img=/home/PaddleDetection/dataset/coco/fruit5_coco/images/106.jpg
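To predict a whole folder instead of a single image, tools/infer.py also accepts --infer_dir, together with --output_dir and --draw_threshold, as far as I know; python tools/infer.py --help will confirm the options for your PaddleDetection version. The sketch below only assembles the command; build_infer_cmd and run_infer are hypothetical helpers and the default values are my own choices.

```python
import shlex
import subprocess


def build_infer_cmd(config, weights, infer_dir, output_dir="output", draw_threshold=0.5):
    """Build a tools/infer.py command that predicts every image in infer_dir."""
    return [
        "python", "tools/infer.py",
        "-c", config,
        "-o", f"weights={weights}",
        f"--infer_dir={infer_dir}",
        f"--output_dir={output_dir}",
        f"--draw_threshold={draw_threshold}",
    ]


def run_infer(cmd):
    """Run the command; call this from the PaddleDetection root directory."""
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    cmd = build_infer_cmd(
        "configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml",
        "output1/best_model.pdparams",
        "/home/PaddleDetection/dataset/coco/fruit5_coco/images",
    )
    print(shlex.join(cmd))  # print only; call run_infer(cmd) to execute it
```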
6. Training configuration
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_80e.yml',
  './_base_/ppyoloe_plus_crn.yml',
  './_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_contrast_pcb/model_final

pretrain_weights: ./ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50

epoch: 190

LearningRate:
  base_lr: 0.0001
  schedulers:
    - !CosineDecay
      max_epochs: 596
    - !LinearWarmup
      start_factor: 0.
      epochs: 5

YOLOv3:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEContrastHead
  post_process: ~

PPYOLOEContrastHead:
  fpn_strides: [32, 16, 8]
  grid_cell_scale: 5.0
  grid_cell_offset: 0.5
  static_assigner_epoch: 100
  use_varifocal_loss: True
  loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, contrast: 0.2}
  static_assigner:
    name: ATSSAssigner
    topk: 9
  assigner:
    name: TaskAlignedAssigner
    topk: 13
    alpha: 1.0
    beta: 6.0
  contrast_loss:
    name: SupContrast
    temperature: 100
    sample_num: 2048
    thresh: 0.75
  nms:
    name: MultiClassNMS
    nms_top_k: 1000
    keep_top_k: 300
    score_threshold: 0.01
    nms_threshold: 0.7

num_classes: 5
metric: COCO
map_type: integral

TrainDataset:
  !COCODataSet
    image_dir: images
    anno_path: /home/PaddleDetection/dataset/small.json
    dataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/
    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']

EvalDataset:
  !COCODataSet
    image_dir: images
    anno_path: /home/PaddleDetection/dataset/coco/fruit5_coco/annotations/instance_val.json
    dataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/

TestDataset:
  !ImageFolder
    anno_path: /home/PaddleDetection/dataset/coco/fruit5_coco/annotations/instance_val.json
    dataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/
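The schedule in this config warms up linearly for 5 epochs from start_factor 0, then follows a cosine curve whose length is set by max_epochs: 596, while training stops at epoch: 190. The sketch below approximates the resulting learning rate. It is evaluated per epoch rather than per iteration, and it assumes the cosine term uses the global epoch index (warmup included), which is my reading of PaddleDetection's CosineDecay; some versions may offset it by the warmup length, which barely changes the numbers here. lr_at_epoch is a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr=0.0001, warmup_epochs=5, start_factor=0.0, max_epochs=596):
    """Approximate LinearWarmup + CosineDecay learning rate at a (fractional) epoch."""
    if epoch < warmup_epochs:
        # Linear ramp from start_factor * base_lr up to base_lr
        alpha = epoch / warmup_epochs
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # Cosine decay over max_epochs, using the global epoch index (assumption)
    return base_lr * 0.5 * (math.cos(math.pi * epoch / max_epochs) + 1)


if __name__ == "__main__":
    for e in (0, 2.5, 5, 100, 190):
        print(f"epoch {e:>5}: lr = {lr_at_epoch(e):.3e}")
```

Under these assumptions the learning rate at epoch 190 is still about 7.7e-5, roughly 77% of base_lr, so the cosine decay stays far from zero within this run.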