YOLOv3+mAP實現金魚檢測
Git源碼地址:傳送門
準備數據集
- 按幀數讀取視頻保存圖片 video2frame.py
- 使用labelimg標注工具對圖片進行標注
- 統一圖片大小為 416x416,并把標簽等信息寫成 .xml 文件 conver_point.py
- 讀取縮放后的 .xml 標簽(左上角、右下角坐標),轉為「類別 + 中心點坐標 + 寬高」的 YOLO 格式 voc2yolo_v3.py
自定義數據集
分析
- 準備數據集
- 使用標注工具(labelimg)給圖片標注標簽,轉換為YOLO可用的格式
- 標簽框信息:類別 + 中心點坐標 + 寬高
- cls, cx, cy, w, h
- 準備錨框
- 自定義錨框,3種特征圖尺寸(13x13、26x26、52x52)各3種錨框,共9種錨框
- 獲取錨框寬高 anchor_w, anchor_h,用于tw, th的制作
- 標簽形狀更換
- C H W --> H W C
- 如:使用13x13,四分類
- 13, 13, 27 --> (情況1) 13, 13, 3, 9 (情況2) 3, 13, 13, 9
- 填值(one-hot編碼)
- conf, tx, ty, tw, th, one-hot
- tx = 中心點x在所屬網格內的偏移量,取值范圍 [0, 1)
- ty = 中心點y在所屬網格內的偏移量,取值范圍 [0, 1)
- tw = torch.log(gt_w / anchor_w)
- th = torch.log(gt_h / anchor_h)
實現結構
- dataset.py
- init
- 讀取特征文件,獲得目標寬高所有信息
- 特征文件保存信息格式
- 文件名 類別 中心點坐標 寬高
- img_name, cls, cx, cy, gt_w, gt_h
- len
- 返回文件信息的長度
- getitem
- 根據索引讀取指定行信息 img_name, cls, cx, cy, gt_w, gt_h
- 切割獲得圖片名字 img_name、標簽框信息 cls, cx, cy, gt_w, gt_h
- 圖片轉為張量 img_name --> img_tensor
- 通道變換保存標簽 H W 27 --> H W 3 9
- 標簽框切割計算獲取 cx, tx, cy, ty
- gt_w, gt_h和錨框寬高計算 tw th
- 類別cls和類別數創建one-hot編碼
- 填值 label[cx, cy, feature_idx] = conf tx ty tw th one_hot
- init
完整代碼
dataset.py
"""Dataset that turns YOLO-style annotation lines into YOLOv3 training targets."""
import math
import os.path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from config import cfg
from util import util


class ODDateset(Dataset):
    """Object-detection dataset yielding one target tensor per feature-map size.

    Each line of ``cfg.BASE_LABEL_PATH`` is whitespace separated:
    ``img_name cls cx cy w h [cls cx cy w h ...]`` where the box values are
    pixels of the resized ``cfg.IMG_ORI_SIZE`` image.
    """

    def __init__(self):
        super().__init__()
        with open(cfg.BASE_LABEL_PATH, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) so __getitem__ never
            # receives an empty record.
            self.lines = [line for line in f if line.strip()]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        """Build the targets for one image.

        :param index: line index in the annotation file.
        :return: ``(label_13, label_26, label_52, img_tensor)`` in the key order of
            ``cfg.ANCHORS_GROUP``. Each label has shape
            ``(feature, feature, 3, 5 + cfg.CLASS_NUM)`` laid out as
            H W anchor [conf, tx, ty, tw, th, one_hot...].
        :raises FileNotFoundError: if the image cannot be read.
        """
        infos = self.lines[index].strip().split()
        img_name = infos[0]  # e.g. '1.jpg'
        img_path = os.path.join(cfg.BASE_IMG_PATH, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None instead of raising; fail with a clear message.
            raise FileNotFoundError(f'cannot read image: {img_path}')
        img_tensor = util.t(img)

        # e.g. ['2', '163', '218', '228', '246', '1', '288', '205', '159', '263']
        box_info = infos[1:]
        # One row per box: [cls, cx, cy, w, h]. reshape also handles images with
        # no boxes (shape (0, 5)). np.float_ was removed in NumPy 2.0; np.float64
        # is the same dtype.
        boxes = np.array(box_info, dtype=np.float64).reshape(-1, 5)

        label_dic = {}
        for feature, anchors in cfg.ANCHORS_GROUP.items():
            # H W 3 (5 + CLASS_NUM)
            label = torch.zeros((feature, feature, 3, 5 + cfg.CLASS_NUM))
            scale_factor = cfg.IMG_ORI_SIZE / feature  # e.g. 416 / 13 = 32
            for box in boxes:
                cls, cx, cy, gt_w, gt_h = (float(v) for v in box)
                if gt_w <= 0 or gt_h <= 0:
                    # log(0) / log(negative) would poison the tw/th targets.
                    continue
                # Grid cell of the box centre and the offset inside that cell.
                # The index is clamped so a centre lying exactly on the right or
                # bottom edge (cx == IMG_ORI_SIZE) maps to the last cell instead of
                # indexing past the grid.
                grid_x = cx / scale_factor
                grid_y = cy / scale_factor
                cx_idx = min(math.floor(grid_x), feature - 1)
                cy_idx = min(math.floor(grid_y), feature - 1)
                offset_x = grid_x - cx_idx
                offset_y = grid_y - cy_idx
                one_hot = F.one_hot(torch.tensor(int(cls), dtype=torch.int64),
                                    num_classes=cfg.CLASS_NUM)  # e.g. tensor([0, 0, 1, 0])
                # NOTE(review): every box is written into all 3 anchors of its cell on
                # every scale (no IoU-based anchor matching).
                for idx, (anchor_w, anchor_h) in enumerate(anchors):
                    # Log-space size targets; decoded later as exp(t) * anchor.
                    tw = math.log(gt_w / anchor_w)
                    th = math.log(gt_h / anchor_h)
                    conf = 1.0
                    label[cy_idx, cx_idx, idx] = torch.tensor(
                        [conf, offset_x, offset_y, tw, th, *one_hot.tolist()])
            label_dic[feature] = label
        f1, f2, f3 = cfg.ANCHORS_GROUP.keys()  # 13 26 52
        return label_dic[f1], label_dic[f2], label_dic[f3], img_tensor


if __name__ == '__main__':
    dataset = ODDateset()
    print(dataset[0])
構建網絡模型
網絡結構
- 主干網絡 backbone
- 卷積層 CBL
- Conv
- BN
- LeakyReLU
- 殘差單元 ResUnit
- 下采樣 DownSample
- 卷積層 CBL
- neck
- 卷積集合 ConvolutionSet
- 卷積層 CBL
- 上采樣 UpSample
- 拼接操作 torch.cat
- head
- 卷積層 CBL
- 全卷積預測
- 輸出通道數 = 錨框數 x (置信度 + 坐標 + 類別數)
- 3 x ( 1 + 4 + 4 ) = 27
實現結構
- module.py
- 卷積層 CBL
- 殘差單元 ResUnit
- 下采樣 DownSample
- 上采樣 UpSample
- 卷積集合 ConvolutionSet
- data.yaml
- 保存主干網絡結構的參數:通道數、殘差塊數量
- darknet53.py
- 實現主干網絡結構,輸出out_13x13, out_26x26, out_52x52
- yolov3.py
- 初始化主干網絡,實現neck、head網絡結構,輸出detect_13_out, detect_26_out, detect_52_out
完整代碼
module.py
"""
網絡結構
- backbone- 卷積層 CBL- Conv- BN- LeakReLu- 殘差單元 ResUnit- 下采樣 DownSample
- neck- 卷積集合 ConvolutionSet- 上采樣 UpSample- 拼接操作 torch.cat"""
import torch
from torch import nnclass CBL(nn.Module):# Conv+BN+LeakReLudef __init__(self, c_in, c_out, k, s):super().__init__()self.cnn_layer = nn.Sequential(nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=k // 2, bias=False),nn.BatchNorm2d(c_out),nn.LeakyReLU())def forward(self, x):return self.cnn_layer(x)class ResUnit(nn.Module):# 殘差單元def __init__(self, c_num):super().__init__()self.block = nn.Sequential(CBL(c_num, c_num // 2, 1, 1),CBL(c_num // 2, c_num, 3, 1))def forward(self, x):return self.block(x) + xclass DownSample(nn.Module):# 下采樣def __init__(self, c_in, c_out):super().__init__()self.down_sample = nn.Sequential(CBL(c_in, c_out, 3, 2))def forward(self, x):return self.down_sample(x)class ConvolutionSet(nn.Module):# 卷積集合def __init__(self, c_in, c_out):super().__init__()self.cnn_set = nn.Sequential(CBL(c_in, c_out, 1, 1),CBL(c_out, c_in, 3, 1),CBL(c_in, c_out, 1, 1),CBL(c_out, c_in, 3, 1),CBL(c_in, c_out, 1, 1))def forward(self, x):return self.cnn_set(x)class UpSample(nn.Module):# 上采樣def __init__(self):super().__init__()self.up_sample = nn.Upsample(scale_factor=2, mode='nearest')def forward(self, x):return self.up_sample(x)if __name__ == '__main__':# data = torch.randn(1, 3, 416, 416)# cnn = nn.Sequential(# CBL(3, 32, 3, 1),# DownSample(32, 64)# )# res = ResUnit(64)## cnn_out = cnn(data)# res_out = res(cnn_out)# print(cnn_out.shape)# # torch.Size([1, 64, 208, 208])# print(res_out.shape)# # torch.Size([1, 64, 208, 208])data = torch.randn(1, 1024, 13, 13)con_set = ConvolutionSet(1024, 512)cnn = CBL(512, 256, 1, 1)up_sample = UpSample()P0_out = up_sample(cnn(con_set(data)))print(P0_out.shape)# torch.Size([1, 256, 26, 26])pass
data.yaml
block_nums:
- 1
- 2
- 8
- 8
- 4
channels:
- 32
- 64
- 128
- 256
- 512
- 1024
darknet53.py
"""DarkNet-53 backbone returning the 52x52, 26x26 and 13x13 feature maps."""
import torch
import yaml
from torch import nn

from module import CBL, ResUnit, DownSample


class DarkNet53(nn.Module):
    """DarkNet-53 feature extractor.

    The stage layout is read from a YAML file with keys ``channels`` and
    ``block_nums``. Stage ``i`` down-samples ``channels[i] -> channels[i + 1]``
    and then stacks ``block_nums[i]`` residual units. ``forward`` taps the
    outputs after stages 3, 4 and 5, so exactly five stages are required.

    :param param_path: path of the YAML file. The default ``'data.yaml'`` is
        resolved against the current working directory, as before.
    :raises ValueError: if the YAML does not describe five stages.
    """

    def __init__(self, param_path='data.yaml'):
        super().__init__()
        self.input_layer = nn.Sequential(CBL(3, 32, 3, 1))

        with open(param_path, 'r', encoding='utf-8') as file:
            dic = yaml.safe_load(file)
        channels = dic['channels']
        block_nums = dic['block_nums']
        if len(block_nums) != 5 or len(channels) != len(block_nums) + 1:
            raise ValueError(
                f'{param_path}: expected 5 block_nums and 6 channels, '
                f'got {len(block_nums)} and {len(channels)}')

        layers = [self.make_layer(channels[idx], channels[idx + 1], block_num)
                  for idx, block_num in enumerate(block_nums)]
        self.hidden_layer = nn.Sequential(*layers)

    def make_layer(self, c_in, c_out, block_num):
        """One stage: a DownSample to ``c_out`` followed by ``block_num`` ResUnits."""
        units = [DownSample(c_in, c_out)]
        units.extend(ResUnit(c_out) for _ in range(block_num))
        return nn.Sequential(*units)

    def forward(self, x):
        """Return ``(unit52_out, unit26_out, unit13_out)`` for a 416x416 input."""
        x = self.input_layer(x)
        unit52_out = self.hidden_layer[:3](x)
        unit26_out = self.hidden_layer[3](unit52_out)
        unit13_out = self.hidden_layer[4](unit26_out)
        return unit52_out, unit26_out, unit13_out


if __name__ == '__main__':
    data = torch.randn(1, 3, 416, 416)
    net = DarkNet53()
    for out in net(data):
        print(out.shape)
    # torch.Size([1, 256, 52, 52])
    # torch.Size([1, 512, 26, 26])
    # torch.Size([1, 1024, 13, 13])
yolov3.py
"""YOLOv3 detector: DarkNet-53 backbone, FPN-style neck and three detection heads."""
import torch
from torch import nn

from darknet53 import DarkNet53
from module import ConvolutionSet, CBL, UpSample


class YoLov3(nn.Module):
    """YOLOv3 network producing raw predictions on 13x13, 26x26 and 52x52 grids.

    Each head outputs ``num_anchors * (5 + num_classes)`` channels
    (conf, tx, ty, tw, th and class scores per anchor). The defaults give the
    original 27 = 3 * (1 + 4 + 4).
    """

    def __init__(self, num_anchors=3, num_classes=4):
        super().__init__()
        out_channels = num_anchors * (5 + num_classes)
        self.backbone = DarkNet53()

        # NOTE(review): the last layer of every head is a CBL, so predictions pass
        # through BatchNorm + LeakyReLU. YOLOv3 normally predicts with a plain
        # biased Conv2d. It is kept as-is so existing checkpoints still load.
        self.conv1 = nn.Sequential(ConvolutionSet(1024, 512))
        self.detect_13 = nn.Sequential(
            CBL(512, 256, 3, 1),
            CBL(256, out_channels, 1, 1),
        )
        self.neck_hidden1 = nn.Sequential(
            CBL(512, 256, 1, 1),
            UpSample(),
        )

        self.conv2 = nn.Sequential(ConvolutionSet(256 + 512, 256))
        self.detect_26 = nn.Sequential(
            CBL(256, 128, 3, 1),
            CBL(128, out_channels, 1, 1),
        )
        self.neck_hidden2 = nn.Sequential(
            CBL(256, 128, 1, 1),
            UpSample(),
        )

        self.conv3 = nn.Sequential(ConvolutionSet(128 + 256, 128))
        self.detect_52 = nn.Sequential(
            CBL(128, 64, 3, 1),
            CBL(64, out_channels, 1, 1),
        )

    def forward(self, x):
        """Return ``(detect_13_out, detect_26_out, detect_52_out)``, each N C H W."""
        backbone_unit52_out, backbone_unit26_out, backbone_unit13_out = self.backbone(x)

        conv1_out = self.conv1(backbone_unit13_out)
        detect_13_out = self.detect_13(conv1_out)

        # Upsample the coarse features and fuse them with the 26x26 backbone map.
        neck_hidden1_out = self.neck_hidden1(conv1_out)
        route26_out = torch.cat((neck_hidden1_out, backbone_unit26_out), dim=1)
        conv2_out = self.conv2(route26_out)
        detect_26_out = self.detect_26(conv2_out)

        # Same again for the 52x52 scale.
        neck_hidden2_out = self.neck_hidden2(conv2_out)
        route52_out = torch.cat((neck_hidden2_out, backbone_unit52_out), dim=1)
        conv3_out = self.conv3(route52_out)
        detect_52_out = self.detect_52(conv3_out)

        return detect_13_out, detect_26_out, detect_52_out


if __name__ == '__main__':
    data = torch.randn(1, 3, 416, 416)
    yolov3 = YoLov3()
    for out in yolov3(data):
        print(out.shape)
    # N 27 H W, 27 = (1 conf + 4 box + 4 classes) * 3 anchors
    # torch.Size([1, 27, 13, 13])
    # torch.Size([1, 27, 26, 26])
    # torch.Size([1, 27, 52, 52])
訓練模型
分析
- 準備數據dataset
- 初始化網絡模型
- 損失函數
- 目標檢測
- 正樣本
- 置信度:二分類交叉熵
- 坐標:均方差損失
- 類別:交叉熵損失
- 負樣本
- 置信度:二分類交叉熵
- 正樣本
- 目標檢測
- 優化器
實現結構
- train.py
- init
- 準備數據 dataset
- 初始化自定義數據集
- 數據加載器:批次、打亂次序
- 初始化網絡模型 yolov3
- 切換設備
- 損失函數
- 置信度:二分類交叉熵BCEWithLogitsLoss
- 坐標:均方差損失MSELoss
- 類別:交叉熵損失CrossEntropyLoss
- 優化器 Adam
- 準備數據 dataset
- train
- 網絡模型開啟訓練
- 遍歷數據加載器獲得三種特征大小標簽值、圖片張量,并切換設備
- 圖片張量傳入網絡獲得三種預期輸出
- 三種標簽值、對應預期輸出值和正負樣本因子傳入loss_fn,計算獲得對應損失,并求和獲得模型損失
- 優化器進行梯度清零
- 模型損失反向傳播
- 優化器進行梯度更新
- 累加模型損失計算平均損失
- 保存模型權重
- loss_fn
- 預期輸出值更換通道 N 27 H W --> N H W 27 --> N H W 3 9
- 獲取位置索引值
- 正樣本數據位置 target[..., 0] > 0
- 負樣本數據位置 target[..., 0] == 0
- 計算損失
- 正樣本:置信度 坐標 類別
- 負樣本:置信度
- 索引獲取
- 0 置信度
- 1:5 坐標
- 5: 類別
- 正負樣本乘上對應規模因子的累加和
- run
- 設定迭代次數,循環調用train訓練模型
- init
完整代碼
train.py
"""
分析1. 準備數據dataset
2. 初始化網絡模型
3. 損失函數- 目標檢測- 正樣本- 置信度:二分類交叉熵- 坐標:均方差損失- 類別:交叉熵損失- 負樣本- 置信度:二分類交叉熵4. 優化器
"""
import osimport torch.optim
from torch import nn
from torch.utils.data import DataLoader
from yolov3 import YoLov3
from dataset import ODDateset
from config import cfgdevice = cfg.deviceclass Train:def __init__(self):# 1. 準備數據datasetod_dataset = ODDateset()self.dataloader = DataLoader(od_dataset, batch_size=6, shuffle=True)# 2. 初始化網絡模型self.net = YoLov3().to(device)# 加載參數# if os.path.exists(cfg.WEIGHT_PATH):# self.net.load_state_dict(torch.load(cfg.WEIGHT_PATH))# print('loading weights successfully')# 3. 損失函數# - 置信度:二分類交叉熵BCEWithLogitsLossself.conf_loss_fn = nn.BCEWithLogitsLoss()# - 坐標:均方差損失MSELossself.loc_loss_fn = nn.MSELoss()# - 類別:交叉熵損失CrossEntropyLossself.cls_loss_fn = nn.CrossEntropyLoss()# 4. 優化器self.opt = torch.optim.Adam(self.net.parameters())def train(self, epoch):""":param epoch: 迭代訓練的次數:return: None1. 開啟訓練2. 遍歷數據加載器獲取三種特征大小標簽值、圖片張量,并切換設備3. 圖片張量傳入網絡獲得三種預期輸出4. 三種標簽值、對應預期輸出值和正負樣本因子傳入loss_fn,計算獲得對應損失,并求和獲得模型損失5. 優化器進行梯度清零6. 模型損失反向傳播7. 優化器進行梯度更新8. 累加模型損失計算平均損失9. 保存模型權重"""# 1. 開啟訓練self.net.train()# 累加損失sum_loss = 0for target_13, target_26, target_52, img in self.dataloader:# 2. 獲取三種特征大小標簽值、圖片張量,并切換設備target_13, target_26, target_52 = target_13.to(device), target_26.to(device), target_52.to(device)img = img.to(device)# 3. 圖片張量傳入網絡獲得三種預期輸出pred_out_13, pred_out_26, pred_out_52 = self.net(img)# 4.loss_13 = self.loss_fn(target_13, pred_out_13, scale_factor=cfg.SCALE_FACTOR_BIG)loss_26 = self.loss_fn(target_26, pred_out_26, scale_factor=cfg.SCALE_FACTOR_MID)loss_52 = self.loss_fn(target_52, pred_out_52, scale_factor=cfg.SCALE_FACTOR_SML)loss = loss_13 + loss_26 + loss_52# 5. 梯度清零self.opt.zero_grad()# 6. 反向傳播loss.backward()# 7. 梯度更新self.opt.step()sum_loss += loss.item()avg_loss = sum_loss / len(self.dataloader)print(f'{epoch}\t{avg_loss}')if epoch % 10 == 0:print('save weight')torch.save(self.net.state_dict(), cfg.WEIGHT_PATH)def loss_fn(self, target, pre_out, scale_factor):""":param target: 標簽:param pre_out: 預期輸出:param scale_factor: 正負樣本因子:return: 正負樣本乘上對應規模因子的累加和1. 預期輸出值更換通道 N 27 H W --> N H W 27 --> N H W 3 92. 獲取位置索引值- 正樣本數據位置 target[..., 0] > 0- 負樣本數據位置 target[..., 0] == 03. 
計算損失- 正樣本:置信度 坐標 類別- 負樣本:置信度- 索引獲取- 0 置信度- 1:5 坐標- 5: 類別4. 正負樣本乘上對應規模因子的累加和"""# 1. 預期輸出值更換通道 N 27 H W --> N H W 27 --> N H W 3 9pre_out = pre_out.permute((0, 2, 3, 1))n, h, w, _ = pre_out.shapepre_out = torch.reshape(pre_out, (n, h, w, 3, -1))# 2. 獲取位置索引值 正樣本數據位置 target[..., 0] > 0 負樣本數據位置 target[..., 0] == 0mask_obj = target[..., 0] > 0mask_noobj = target[..., 0] == 0# 3. 計算損失# 正樣本:置信度 坐標 類別target_obj = target[mask_obj]output_obj = pre_out[mask_obj]conf_loss = self.conf_loss_fn(output_obj[:, 0], target_obj[:, 0])loc_loss = self.loc_loss_fn(output_obj[:, 1:5], target_obj[:, 1:5])cls_loss = self.cls_loss_fn(output_obj[:, 5:], torch.argmax(target_obj[:, 5:], dim=1))loss_obj = conf_loss + loc_loss + cls_loss# 負樣本:置信度target_noobj = target[mask_noobj]output_noobj = pre_out[mask_noobj]loss_noobj = self.conf_loss_fn(output_noobj[:, 0], target_noobj[:, 0])# 4. 正負樣本乘上對應規模因子的累加和return loss_obj * scale_factor + loss_noobj * (1 - scale_factor)def run(self):for epoch in range(500):self.train(epoch)passif __name__ == '__main__':train = Train()train.run()pass
推理預測
分析
- 網絡初始化,加載權重參數 net
- 輸入數據預處理(歸一化) img_norm
- 前向傳播獲得輸出,輸出數據形狀是 N C H W --> N H W 27 --> N H W 3(錨框數量 anchor_num) 9
- 根據給定的閾值 thresh 獲取符合閾值要求目標的索引
- idx = torch.where(pred_out[..., 0] > thresh)
- N: idx[0]
- H(rows): idx[1]
- W(cols): idx[2]
- anchor_num = idx[3]
- 解碼中心點坐標 cx cy
- cx_idx = idx[2](列索引)
- cy_idx = idx[1](行索引)
- tx = pred_out[..., 1]
- ty = pred_out[..., 2]
- tw = pred_out[..., 3]
- th = pred_out[..., 4]
- cx = (cx_idx + tx) * 32(32 為 13x13 特征圖的規模因子 416 / 13;26x26、52x52 分別為 16、8)
- cy = (cy_idx + ty) * 32
- pred_w = exp(tw) * anchor_w
- pred_h = exp(th) * anchor_h
實現結構
- detect.py
- init
- 初始化網絡
- 網絡開啟驗證
- 網絡加載參數
- forward
- 圖像預處理
- 圖片轉為張量
- 擴張維度,表示批次
- 圖片張量傳給網絡獲得檢測輸出
- 對檢測輸出進行解碼 decode,獲得檢測框信息
- 拼接大中小目標框信息并返回
- 圖像預處理
- decode
- 預期輸出值更換通道 N 27 H W --> N H W 27 --> N H W 3 9
- 獲取檢測框的坐標索引 錨框數量
- 獲取檢測框的標簽信息 [[conf, tx, ty, tw, th, cls], …]
- 方式1:label = pred_out[idx[0], idx[1], idx[2], idx[3], :]
- 方式2:label = pred_out[idx]
- 計算檢測框的中心坐標 寬高
- 規模因子 = 原圖大小 / 特征大小
- 獲取當前特征對應的三種錨框
- 獲取索引對應的錨框的寬高
- 坐標轉換:中心點坐標+寬高 --> 左上角坐標+右下角坐標
- torch.stack 整合坐標 [conf, x_min, y_min, x_max, y_max, cls]
- run
- 傳入圖片進行前向傳播,獲得預測框信息
- 根據不同類別,遍歷框信息,進行NMS,獲得各類別最優框
- 不同類別繪制不同顏色的檢測框,并標注類別名
- 保存框的置信度和坐標信息,以便計算mAP
- init
- util.py
- bbox_iou
- 計算標簽框和輸出框的交并比
- nms
- 模型輸出的框,按置信度排序
- 置信度最高的,作為當前類別最優的框 max_conf_box = detect_boxes[0]
- 剩余的框 detect_boxes[1:] 和當前最優框 max_conf_box 計算IOU 獲取 iou_val
- 和給定閾值(超參數)作比較 iou_idx = iou_val < thresh
- detect_boxes[1:][iou_idx] 則為保留的框
完整代碼
detect.py
"""
分析1. 網絡初始化,加載權重參數 net
2. 輸入數據預處理(歸一化) img_norm
3. 前向傳播獲取預期輸出,計算獲取 cx cy pred_w pred_h- 切片獲取數據 cls, tx, ty, tw, th, one_hot- cx = (索引 + tx) * 特征規模大小- cy = (索引 + ty) * 特征規模大小- pred_w = exp(tw) * anchor_w- pred_h = exp(th) * anchor_h
4. 繪制檢測框"""
import osimport cv2
import torch
from torch import nnfrom yolov3 import YoLov3
from config import cfg
from util import utilclass Detector(nn.Module):def __init__(self):super().__init__()# 1. 網絡初始化net = YoLov3()# 開啟驗證net.eval()# 加載權重參數 netnet.load_state_dict(torch.load(cfg.WEIGHT_PATH))print('loading weights successfully')self.net = netdef normalize(self, frame):frame_tensor = util.t(frame)frame_tensor = torch.unsqueeze(frame_tensor, dim=0)return frame_tensordef decode(self, pred_out, feature, threshold):# 1. 預期輸出值更換通道 N 27 H W --> N H W 27 --> N H W 3 9pred_out = pred_out.permute((0, 2, 3, 1))n, h, w, _ = pred_out.shapepred_out = torch.reshape(pred_out, (n, h, w, 3, -1))# 2. 獲取檢測框的坐標索引 錨框數量idx = torch.where(pred_out[:, :, :, :, 0] > threshold)# N H W 3(錨框數量)# - N: idx[0]# - H(rows): idx[1]# - W(cols): idx[2]# - anchor_num = idx[3]h_idx = idx[1]w_idx = idx[2]anchor_num = idx[3]# 3. 獲取檢測框的標簽信息 [[conf, tx, ty, tw, th, cls], ...]# 方式1# label = pred_out[idx[0], idx[1], idx[2], idx[3], :]# 方式2label = pred_out[idx]# N V# [[conf, tx, ty, tw, th, cls], ...]conf = label[:, 0]tx = label[:, 1]ty = label[:, 2]tw = label[:, 3]th = label[:, 4]cls = torch.argmax(label[:, 5:], dim=1)# 4. 計算檢測框的中心坐標 寬高# 規模因子 = 原圖大小 / 特征大小scale_factor = cfg.IMG_ORI_SIZE / featurecx = (tx + w_idx) * scale_factorcy = (ty + h_idx) * scale_factor# 當前特征對應的三種錨框anchors = cfg.ANCHORS_GROUP[feature]# anchors 類型是list 轉為張量便于高級索引anchors = torch.tensor(anchors)# 獲取索引對應的錨框的寬高anchor_w = anchors[anchor_num][:, 0]anchor_h = anchors[anchor_num][:, 1]pred_w = torch.exp(tw) * anchor_wpred_h = torch.exp(th) * anchor_h# 5. 
坐標轉換:中心點坐標+寬高 --> 左上角坐標+右下角坐標x_min = cx - pred_w / 2y_min = cy - pred_h / 2x_max = cx + pred_w / 2y_max = cy + pred_h / 2# torch.stack 整合坐標 [conf, x_min, y_min, x_max, y_max, cls]out = torch.stack((conf, x_min, y_min, x_max, y_max, cls), dim=1)return outdef show_image(self, img, x1, y1, x2, y2, cls):cv2.rectangle(img,(int(x1), int(y1)),(int(x2), int(y2)),color=cfg.COLOR_DIC[int(cls)],thickness=2)cv2.putText(img,text=cfg.CLS_DIC[int(cls)],org=(int(x1) + 5, int(y1) + 10),color=cfg.COLOR_DIC[int(cls)],fontScale=0.5,fontFace=cv2.FONT_ITALIC)cv2.imshow('img', img)cv2.waitKey(25)def forward(self, img, threshold):img_norm = self.normalize(img)pred_out_13, pred_out_26, pred_out_52 = self.net(img_norm)f_big, f_mid, f_sml = cfg.ANCHORS_GROUP.keys()box_13 = self.decode(pred_out_13, f_big, threshold)box_26 = self.decode(pred_out_26, f_mid, threshold)box_52 = self.decode(pred_out_52, f_sml, threshold)boxes = torch.cat((box_13, box_26, box_52), dim=0)return box_52def run(self, img_names):for img_name in img_names:img_path = os.path.join(cfg.BASE_IMG_PATH, img_name)img = cv2.imread(img_path)detect_out = detect(img, cfg.THRESHOLD_BOX)if len(detect_out) == 0:continuefilter_boxes = []for cls in range(4):mask_cls = detect_out[..., -1] == cls_boxes = detect_out[mask_cls]boxes = util.nms(_boxes, cfg.THRESHOLD_NMS)if len(boxes) == 0:continuefilter_boxes.append(boxes)for boxes in filter_boxes:for box in boxes:conf, x1, y1, x2, y2, cls = boxself.show_image(img, x1, y1, x2, y2, cls)# cv2.imwrite(os.path.join(f"./run/imgs/{img_name}"), img)# 保存box信息# file_name = img_name.split('.')[0] + '.txt'# file_path = os.path.join('../data/cal_map/input/detection-results', file_name)# with open(file_path, 'a', encoding='utf-8') as file:# conf_norm = nn.Sigmoid()(conf)# file.write(f"{cfg.CLS_DIC[int(cls)]} {conf_norm} {int(x1)} {int(y1)} {int(x2)} {int(y2)}\n")if __name__ == '__main__':detect = Detector()# frame = cv2.imread('../data/VOC2007/YOLOv3_JPEGImages/2.jpg')# boxes = detect(frame, 1)# # 
獲取同種類的框,進行NMS# boxes = util.nms(boxes, 0.1)# for box in boxes:# conf, x1, y1, x2, y2, cls = box.detach().cpu().numpy()# detect.show_image(frame, x1, y1, x2, y2, cls)# # cv2.imshow('frame', frame)# cv2.waitKey(0)# cv2.destroyAllWindows()# 多張圖片img_names = os.listdir(cfg.BASE_IMG_PATH)detect.run(img_names)pass
util.py
"""Image transform, IoU and non-maximum-suppression helpers."""
import torch
from torchvision import transforms

from config import cfg

# H W C --> C H W, with values scaled to [0, 1]
t = transforms.Compose([
    transforms.ToTensor(),
])


def bbox_iou(box, boxes):
    """Return the IoU between ``box`` and every row of ``boxes``.

    :param box: tensor (4,) as ``[x_min, y_min, x_max, y_max]``.
    :param boxes: tensor (K, 4) in the same layout.
    :return: tensor (K,) of IoU values.
    """
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    l_x = torch.maximum(box[0], boxes[:, 0])
    l_y = torch.maximum(box[1], boxes[:, 1])
    r_x = torch.minimum(box[2], boxes[:, 2])
    r_y = torch.minimum(box[3], boxes[:, 3])

    # Non-overlapping boxes get zero intersection width/height.
    w = (r_x - l_x).clamp(min=0)
    h = (r_y - l_y).clamp(min=0)
    inter_area = w * h
    union = box_area + boxes_area - inter_area
    # A zero-area union would give 0 / 0 = NaN. NaN compares False with every
    # threshold, which would silently drop the box in nms().
    return inter_area / union.clamp(min=1e-9)


def nms(detect_boxes, threshold=0.5):
    """Greedy non-maximum suppression.

    1. Sort the boxes by confidence (descending).
    2. Keep the most confident box.
    3. Compute its IoU with the remaining boxes.
    4. Drop every remaining box whose IoU is >= ``threshold``; repeat.

    :param detect_boxes: tensor (K, 6) of ``[conf, x_min, y_min, x_max, y_max, cls]``.
    :param threshold: IoU at or above which a lower-confidence box is suppressed.
    :return: list of kept box tensors (each of shape (6,)), highest confidence first.
    """
    best_boxes = []
    order = torch.argsort(detect_boxes[:, 0], descending=True)
    detect_boxes = detect_boxes[order]
    while detect_boxes.size(0) > 0:
        max_conf_box = detect_boxes[0]
        best_boxes.append(max_conf_box)
        detect_boxes = detect_boxes[1:]
        iou_val = bbox_iou(max_conf_box[1:5], detect_boxes[:, 1:5])
        detect_boxes = detect_boxes[iou_val < threshold]
    return best_boxes
參數配置文件
cfg.py
"""Project-wide configuration constants."""
import torch

# Custom anchors (w, h in pixels of the 416x416 input), three per feature-map size.
# Key order matters: callers unpack the keys as (13, 26, 52).
ANCHORS_GROUP = {
    13: [[360, 360], [360, 180], [180, 360]],
    26: [[180, 180], [180, 90], [90, 180]],
    52: [[90, 90], [90, 45], [45, 90]],
}

# Default COCO anchors (the standard YOLOv3 / YOLOv5 set), kept for reference.
ANCHORS_DIC = {
    13: [[116, 90], [156, 198], [373, 326]],
    26: [[30, 61], [62, 45], [59, 119]],
    52: [[10, 13], [16, 30], [33, 23]],
}

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
CLASS_NUM = 4
IMG_ORI_SIZE = 416
BASE_IMG_PATH = r'E:\pythonProject\yolo3\data\VOC2007\YOLOv3_JPEGImages'
BASE_LABEL_PATH = r'E:\pythonProject\yolo3\data\VOC2007\yolo_annotation.txt'
WEIGHT_PATH = r'E:\pythonProject\yolo3\net\weights\best.pt'

# Per-scale weight of the positive-sample loss (negatives get 1 - factor).
SCALE_FACTOR_BIG = 0.9
SCALE_FACTOR_MID = 0.9
SCALE_FACTOR_SML = 0.9

# Thresholds
THRESHOLD_BOX = 0.9
THRESHOLD_NMS = 0.1

# Video paths
VIDEO_PATH = r'E:\pythonProject\yolo3\data\video\fish_video.mp4'
VIDEO2FRAME_PATH = r'E:\pythonProject\yolo3\data\VOC2007\JPEGImages'

# Network (DarkNet-53 backbone) structure parameters
DARKNET35_PARAM_PATH = r'E:\pythonProject\yolo3\config\data.yaml'
# Correctly spelled alias; the misspelled name is kept for existing callers.
DARKNET53_PARAM_PATH = DARKNET35_PARAM_PATH

# Detection classes: id -> name
# NOTE(review): CLASS_NUM is 4 but only two class names are defined; ids 2 and 3
# have no name here — confirm whether the network should be built for 2 classes.
CLS_DIC = {
    0: 'big_fish',
    1: 'small_fish',
}
COLOR_DIC = {0: (0, 0, 255), 1: (100, 200, 255), 2: (255, 0, 0), 3: (0, 255, 0)}
計算評價指標
- Github上下載一個mAP源碼(如:https://github.com/Cartucho/mAP.git)
- 手動創建計算mAP的輸入數據文件夾
- input
- detection-results:模型輸出數據集
- ground-truth:標簽數據集
- images-optional:原圖縮放后的數據集
- input
- data/VOC2007/YOLOv3_JPEGImages數據拷貝到images-optional
- data/VOC2007/Annotations數據拷貝到ground-truth
- 運行convert_gt_xml.py,把 .xml 文件轉為 .txt 文件
  - 其中 .txt 文件保存的是圖片標簽框的類別名+坐標信息 cls_name xmin ymin xmax ymax
- 在 detect.py 中取消 run 方法末尾「保存box信息」代碼的注釋(原倉庫中為第142-146行),運行代碼后,detection-results文件夾會保存模型輸出框的 .txt 文件
  - 其中 .txt 文件保存的是圖片檢測框的類別名+置信度+坐標信息 cls_name conf xmin ymin xmax ymax
- 運行map.py會自動生成output文件,彈出mAP圖
完整代碼
convert_gt_xml.py
import sys
import os
import glob
import xml.etree.ElementTree as ET

# Convert the VOC .xml ground-truth files in data/cal_map/input/ground-truth into
# "cls_name xmin ymin xmax ymax" .txt files, moving each .xml into ./backup.

# make sure that the cwd() in the beginning is the location of the python script
# (so that every path makes sense)
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Project root = parent of this script's directory (e.g. 'E:\\pythonProject\\yolo3').
# os.path.dirname handles both '\\' and '/' separators, unlike rsplit('\\').
yolo3_path = os.path.dirname(os.getcwd())
GT_PATH = os.path.join(yolo3_path, 'data', 'cal_map', 'input', 'ground-truth')
# change directory to the one with the files to be changed
os.chdir(GT_PATH)

# old files (xml format) will be moved to a "backup" folder
os.makedirs("backup", exist_ok=True)

# create VOC format files
xml_list = glob.glob('*.xml')
if len(xml_list) == 0:
    print("Error: no .xml files found in ground-truth")
    # Non-zero status so callers can tell that nothing was converted.
    sys.exit(1)

for tmp_file in xml_list:
    # 1. create new file (VOC format)
    txt_file = os.path.splitext(tmp_file)[0] + ".txt"
    with open(txt_file, "a") as new_f:
        root = ET.parse(tmp_file).getroot()
        for obj in root.findall('object'):
            obj_name = obj.find('name').text
            bndbox = obj.find('bndbox')
            left = bndbox.find('xmin').text
            top = bndbox.find('ymin').text
            right = bndbox.find('xmax').text
            bottom = bndbox.find('ymax').text
            new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
    # 2. move old file (xml format) to backup
    os.rename(tmp_file, os.path.join("backup", tmp_file))
print("Conversion completed!")
voc2yolo_v3.py
import glob
import xml.etree.ElementTree as ET
import os

# Class name -> class id written into the YOLO annotation file.
CLS_DIC = {"big_fish": 0, "small_fish": 1}


def xml_to_yolo(xml_path, img_path, save_dir, cls_dic=None):
    """Append one YOLO-format annotation line for ``img_path`` built from ``xml_path``.

    The line is the image file name followed, for every object, by
    ``\\t\\t<cls_id>\\t<cx>\\t<cy>\\t<w>\\t<h>\\t`` (pixel values truncated to int),
    computed from the VOC ``xmin/ymin/xmax/ymax`` corners.

    :param xml_path: labelImg / VOC xml file.
    :param img_path: matching image; its stem must equal the xml stem.
    :param save_dir: a directory (the line goes to ``<xml stem>.txt`` inside it)
        or a file path that all lines are appended to.
    :param cls_dic: class name -> id mapping; defaults to ``CLS_DIC``.
    :return: the path written to, or None if the xml and image names differ.
    :raises KeyError: if an object's class name is missing from ``cls_dic``.
    """
    if cls_dic is None:
        cls_dic = CLS_DIC
    img_name = os.path.basename(img_path)
    xml_name_pre = os.path.splitext(os.path.basename(xml_path))[0]
    img_name_pre = os.path.splitext(img_name)[0]
    if xml_name_pre != img_name_pre:
        print("xml_name is not equal to img_name")
        return None

    root = ET.parse(xml_path).getroot()
    # format: image name, then "cls cx cy w h" for every object
    img_annotation = img_name
    for obj in root.findall('object'):
        class_name = obj.find('name').text
        xmin = float(obj.find('bndbox/xmin').text)
        ymin = float(obj.find('bndbox/ymin').text)
        xmax = float(obj.find('bndbox/xmax').text)
        ymax = float(obj.find('bndbox/ymax').text)
        # corners --> centre + width/height
        x_center = int((xmin + xmax) / 2)
        y_center = int((ymin + ymax) / 2)
        box_width = int(xmax - xmin)
        box_height = int(ymax - ymin)
        cls_id = cls_dic[class_name]
        img_annotation += f"\t\t{cls_id}\t{x_center}\t{y_center}\t{box_width}\t{box_height}\t"

    file_name = xml_name_pre + '.txt'
    # Use the save_dir argument; the original tested the module-level
    # `save_directory`, which only exists when run as a script.
    if os.path.isdir(save_dir):
        save_path = os.path.join(save_dir, file_name)
    else:
        save_path = save_dir
    # utf-8 matches how dataset.py reads the annotation file.
    with open(save_path, 'a+', encoding='utf-8') as file:
        file.write(img_annotation + '\n')
    return save_path


if __name__ == '__main__':
    # Output file in the txt format YOLOv3 training expects
    save_directory = r'../data/VOC2007/yolo_annotation.txt'
    # Labels already rescaled to 416x416
    xml_paths = sorted(glob.glob(os.path.join(r'../data/VOC2007/Annotations', "*")))
    # Resized images
    img_paths = glob.glob(os.path.join(r"../data/VOC2007/YOLOv3_JPEGImages", "*"))
    cls_dic = CLS_DIC
    # glob order is arbitrary, so pair each xml with the image of the same stem
    # instead of relying on both listings sharing an order.
    img_by_stem = {os.path.splitext(os.path.basename(p))[0]: p for p in img_paths}
    for xml_path in xml_paths:
        stem = os.path.splitext(os.path.basename(xml_path))[0]
        img_path = img_by_stem.get(stem)
        if img_path is None:
            print(f"no image found for {xml_path}")
            continue
        saved_path = xml_to_yolo(xml_path, img_path, save_directory, cls_dic)
        print(f"YOLO annotations saved to: {saved_path}")
map.py
import glob
import json
import os
import shutil
import operator
import sys
import argparse
import math

import numpy as np

MINOVERLAP = 0.5  # default value (defined in the PASCAL VOC2012 challenge)

parser = argparse.ArgumentParser()
parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true")
parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true")
parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true")
# argparse receiving list of classes to be ignored (e.g., python main.py --ignore person book)
parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.")
# argparse receiving list of classes with specific IoU (e.g., python main.py --set-class-iou person 0.7)
parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.")
args = parser.parse_args()

# Box coordinate convention:
#   0,0 ------> x (width)
#    |
#    |  (Left,Top)
#    |      *_________
#    |      |         |
#           |         |
#    y      |_________|
# (height)            *
#               (Right,Bottom)

# if there are no classes to ignore then replace None by empty list
if args.ignore is None:
    args.ignore = []

specific_iou_flagged = False
if args.set_class_iou is not None:
    specific_iou_flagged = True

# make sure that the cwd() is the location of the python script (so that every path makes sense)
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Inputs live under <project root>/data/cal_map/input. The project root is the
# parent of this script's directory (e.g. 'E:\\pythonProject\\yolo3');
# os.path.dirname works with both '\\' and '/' separators.
yolo3_path = os.path.dirname(os.getcwd())
GT_PATH = os.path.join(yolo3_path, 'data', 'cal_map', 'input', 'ground-truth')
DR_PATH = os.path.join(yolo3_path, 'data', 'cal_map', 'input', 'detection-results')
# if there are no images then no animation can be shown
IMG_PATH = os.path.join(yolo3_path, 'data', 'cal_map', 'input', 'images-optional')
if os.path.exists(IMG_PATH):
    for dirpath, dirnames, files in os.walk(IMG_PATH):
        if not files:
            # no image files found
            args.no_animation = True
else:
    args.no_animation = True

# try to import OpenCV if the user didn't choose the option --no-animation
show_animation = False
if not args.no_animation:
    try:
        import cv2
        show_animation = True
    except ImportError:
        print("\"opencv-python\" not found, please install to visualize the results.")
        args.no_animation = True

# try to import Matplotlib if the user didn't choose the option --no-plot
draw_plot = False
if not args.no_plot:
    try:
        import matplotlib.pyplot as plt
        draw_plot = True
    except ImportError:
        print("\"matplotlib\" not found, please install it to get the resulting plots.")
        args.no_plot = True


def log_average_miss_rate(prec, rec, num_images):
    """Log-average miss rate.

    Calculated by averaging miss rates at 9 evenly spaced FPPI points
    between 10e-2 and 10e0, in log-space.

    output:
        lamr | log-average miss rate
        mr   | miss rate
        fppi | false positives per image

    references:
        [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
            State of the Art." Pattern Analysis and Machine Intelligence, IEEE
            Transactions on 34.4 (2012): 743 - 761.
    """
    # if there were no detections of that class
    if prec.size == 0:
        lamr = 0
        mr = 1
        fppi = 0
        return lamr, mr, fppi

    # NOTE(review): num_images is unused; fppi is approximated as 1 - precision
    # rather than computed as false positives per image.
    fppi = (1 - prec)
    mr = (1 - rec)

    fppi_tmp = np.insert(fppi, 0, -1.0)
    mr_tmp = np.insert(mr, 0, 1.0)

    # Use 9 evenly spaced reference points in log-space
    ref = np.logspace(-2.0, 0.0, num=9)
    for i, ref_i in enumerate(ref):
        # np.where() will always find at least 1 index, since min(ref) = 0.01 and min(fppi_tmp) = -1.0
        j = np.where(fppi_tmp <= ref_i)[-1][-1]
        ref[i] = mr_tmp[j]

    # log(0) is undefined, so we use the np.maximum(1e-10, ref)
    lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))

    return lamr, mr, fppi


# throw error and exit
def error(msg):
    """Print ``msg`` and terminate the script with a non-zero exit status.

    :raises SystemExit: always, with code 1.
    """
    print(msg)
    # Exit status 1 so shells and CI can tell a failed evaluation from a
    # successful run (the original exited with 0, i.e. "success").
    sys.exit(1)


# check if the number is a float between 0.0 and 1.0
def is_float_between_0_and_1(value):
    """Return True when ``value`` parses as a float strictly inside (0.0, 1.0).

    Strings that are not numbers yield False instead of raising.
    """
    try:
        parsed = float(value)
    except ValueError:
        return False
    return 0.0 < parsed < 1.0


# Calculate the AP given the recall and precision array
#   1st) We compute a version of the measured precision/recall curve with
#        precision monotonically decreasing
#   2nd) We compute the AP as the area under this curve by numerical integration.
def voc_ap(rec, prec):
    """Compute the VOC-style average precision.

    Port of the official VOC2012 MATLAB code::

        mrec=[0 ; rec ; 1];
        mpre=[0 ; prec ; 0];
        for i=numel(mpre)-1:-1:1
            mpre(i)=max(mpre(i),mpre(i+1));
        end
        i=find(mrec(2:end)~=mrec(1:end-1))+1;
        ap=sum((mrec(i)-mrec(i-1)).*mpre(i));

    :param rec: cumulative recall values, one per ranked detection.
    :param prec: cumulative precision values, same length as ``rec``.
    :return: ``(ap, mrec, mpre)`` where ``mrec`` is the padded recall list and
        ``mpre`` the padded, monotonically non-increasing precision envelope.
        The input sequences are not modified.
    """
    # Pad with the sentinels into new lists. The original inserted/appended into
    # the caller's ``rec`` and ``prec`` in place.
    mrec = [0.0] + list(rec) + [1.0]
    mpre = [0.0] + list(prec) + [0.0]

    # Make the precision monotonically decreasing, going from the end to the
    # beginning (matlab: for i=numel(mpre)-1:-1:1 mpre(i)=max(mpre(i),mpre(i+1))).
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])

    # Indexes where the recall changes (matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1).
    i_list = [i for i in range(1, len(mrec)) if mrec[i] != mrec[i - 1]]

    # AP is the area under the curve (matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i))).
    ap = 0.0
    for i in i_list:
        ap += ((mrec[i] - mrec[i - 1]) * mpre[i])
    return ap, mrec, mpre


# Convert the lines of a file to a list
def file_lines_to_list(path):
    """Return the lines of the text file at ``path`` with surrounding whitespace removed."""
    with open(path) as handle:
        return [line.strip() for line in handle]


# Draws text in image
def draw_text_in_image(img, text, pos, color, line_width):
    """Draw ``text`` on ``img`` with its bottom-left corner at ``pos``.

    :return: ``(img, line_width + rendered text width)`` so the caller can place
        the next piece of text on the same line.
    """
    font = cv2.FONT_HERSHEY_PLAIN
    font_scale = 1
    # Named lineType in the upstream code; as the 7th positional argument of
    # cv2.putText (and the 4th of cv2.getTextSize) it is the thickness.
    thickness = 1
    cv2.putText(img, text, pos, font, font_scale, color, thickness)
    (text_width, _text_height), _baseline = cv2.getTextSize(text, font, font_scale, thickness)
    return img, (line_width + text_width)


# Plot - adjust axes
def adjust_axes(r, t, fig, axes):
    """Stretch the x-axis upper limit so the text artist ``t`` fits.

    The figure width grows virtually by the text's rendered width (in
    inches); the right x-limit is scaled by the same ratio.
    """
    extent = t.get_window_extent(renderer=r)
    extra_inches = extent.width / fig.dpi
    old_width = fig.get_figwidth()
    scale = (old_width + extra_inches) / old_width
    left, right = axes.get_xlim()
    axes.set_xlim([left, right * scale])


# Draw plot using Matplotlib
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    """Draw a horizontal bar chart of ``dictionary`` and save it to ``output_path``.

    Args:
        dictionary: mapping label -> numeric value; one bar per entry.
        n_classes: number of bars (y positions ``0 .. n_classes-1``); expected
            to equal ``len(dictionary)``.
        window_title: title of the figure window.
        plot_title: title drawn above the axes.
        x_label: label of the x axis.
        output_path: file the figure is saved to.
        to_show: if true, also display the figure with ``plt.show()``.
        plot_color: bar/text colour used in the plain (non TP/FP) mode.
        true_p_bar: ``""`` for a plain chart; otherwise a mapping
            label -> true-positive count, and each bar is drawn as a red
            false-positive part (value - TP) followed by a green TP part.

    An empty ``dictionary`` draws and saves nothing and returns immediately.
    Previously, unpacking the empty sorted list raised ValueError, e.g. when
    no detection-results file contains a single object.
    """
    if not dictionary:
        return
    # Sort by increasing value: barh puts index 0 at the bottom, so the
    # largest value ends up as the top bar.
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    # unpacking the list of tuples into two lists
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
    if true_p_bar != "":
        # Special case, one stacked bar per class:
        #   - red (crimson)       -> FP: detected but not matching ground-truth
        #   - green (forestgreen) -> TP: detected and matching ground-truth
        fp_sorted = []
        tp_sorted = []
        for key in sorted_keys:
            fp_sorted.append(dictionary[key] - true_p_bar[key])
            tp_sorted.append(true_p_bar[key])
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
        # add legend
        plt.legend(loc='lower right')
        # Write number on side of bar
        fig = plt.gcf()  # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_val = fp_sorted[i]
            tp_val = tp_sorted[i]
            fp_str_val = " " + str(fp_val)
            tp_str_val = fp_str_val + " " + str(tp_val)
            # trick to paint multicolor with offset:
            # first paint everything and then repaint the first number
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            if i == (len(sorted_values) - 1):  # largest bar
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)
        # Write number on side of bar
        fig = plt.gcf()  # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)  # add a space before
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            # re-set axes to show number inside the figure
            if i == (len(sorted_values) - 1):  # largest bar
                adjust_axes(r, t, fig, axes)
    # set window title
    fig.canvas.manager.set_window_title(window_title)
    # write classes in y axis
    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
    # Re-scale height accordingly: compute the height needed for the tick
    # labels in points and inches, and only ever grow the figure.
    init_height = fig.get_figheight()
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)  # 1.4 (some spacing)
    height_in = height_pt / dpi
    # compute the required figure height
    top_margin = 0.15  # in percentage of the figure height
    bottom_margin = 0.05  # in percentage of the figure height
    figure_height = height_in / (1 - top_margin - bottom_margin)
    # set new height
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    # set plot title
    plt.title(plot_title, fontsize=14)
    # set axis titles
    plt.xlabel(x_label, fontsize='large')
    # adjust size of window
    fig.tight_layout()
    # save the plot
    fig.savefig(output_path)
    # show image
    if to_show:
        plt.show()
    # close the plot
    plt.close()


# Create a ".temp_files/" and "output/" directory
# Scratch directory for the per-image / per-class JSON files; created only
# when it does not exist yet.
TEMP_FILES_PATH = ".temp_files"
if not os.path.exists(TEMP_FILES_PATH):
    os.makedirs(TEMP_FILES_PATH)

# NOTE(review): machine-specific absolute Windows path (the commented-out
# default was the relative "output"); the whole directory is deleted and
# recreated on every run, so double-check it before running elsewhere.
output_files_path = r'E:\pythonProject\yolo3\data\cal_map\output'
# Start every run from an empty output directory.
if os.path.exists(output_files_path):
    shutil.rmtree(output_files_path)
os.makedirs(output_files_path)

# Sub-directories needed only by the optional plot / animation outputs.
for subdir_enabled, subdir_parts in (
        (draw_plot, ("classes",)),
        (show_animation, ("images", "detections_one_by_one")),
):
    if subdir_enabled:
        os.makedirs(os.path.join(output_files_path, *subdir_parts))

# ground-truth
#   Load each of the ground-truth files into a temporary ".json" file.
#   Create a list of all the class names present in the ground-truth (gt_classes).
# get a list with the ground-truth files
ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
if len(ground_truth_files_list) == 0:
    error("Error: No ground-truth files found!")
ground_truth_files_list.sort()
# gt_counter_per_class: class -> number of non-"difficult" ground-truth objects
# counter_images_per_class: class -> number of files containing that class
gt_counter_per_class = {}
counter_images_per_class = {}

gt_files = []
for txt_file in ground_truth_files_list:
    # File id = file name without directory and without its extension.
    # (Splitting the whole path on ".txt" gave a wrong id when a parent
    # directory contained ".txt", or for upper-case ".TXT" file names.)
    file_id = os.path.splitext(os.path.basename(os.path.normpath(txt_file)))[0]
    # check if there is a correspondent detection-results file
    temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
    if not os.path.exists(temp_path):
        error_msg = "Error. File not found: {}\n".format(temp_path)
        error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
        error(error_msg)
    lines_list = file_lines_to_list(txt_file)
    # create ground-truth dictionary
    bounding_boxes = []
    is_difficult = False
    already_seen_classes = []
    for line in lines_list:
        try:
            # expected: <class_name> <left> <top> <right> <bottom> ['difficult']
            if "difficult" in line:
                class_name, left, top, right, bottom, _difficult = line.split()
                is_difficult = True
            else:
                class_name, left, top, right, bottom = line.split()
        except ValueError:
            error_msg = "Error: File " + txt_file + " in the wrong format.\n"
            error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
            error_msg += " Received: " + line
            error_msg += "\n\nIf you have a <class_name> with spaces between words you should remove them\n"
            error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
            # NOTE(review): assumes error() terminates the script; it is
            # defined outside this chunk.
            error(error_msg)
        # check if class is in the ignore list, if yes skip
        if class_name in args.ignore:
            continue
        bbox = left + " " + top + " " + right + " " + bottom
        if is_difficult:
            # difficult objects are stored but not counted
            bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False, "difficult": True})
            is_difficult = False
        else:
            bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
            # count that object
            if class_name in gt_counter_per_class:
                gt_counter_per_class[class_name] += 1
            else:
                # if class didn't exist yet
                gt_counter_per_class[class_name] = 1

            # count each file at most once per class
            if class_name not in already_seen_classes:
                if class_name in counter_images_per_class:
                    counter_images_per_class[class_name] += 1
                else:
                    # if class didn't exist yet
                    counter_images_per_class[class_name] = 1
                already_seen_classes.append(class_name)

    # dump bounding_boxes into a ".json" file
    new_temp_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
    gt_files.append(new_temp_file)
    with open(new_temp_file, 'w') as outfile:
        json.dump(bounding_boxes, outfile)

# let's sort the classes alphabetically
gt_classes = sorted(gt_counter_per_class.keys())
n_classes = len(gt_classes)

# Check format of the flag --set-class-iou (if used)
#   e.g. check if class exists
# --set-class-iou takes [class] [IoU] pairs: every class must appear in the
# ground-truth and every IoU must be a float strictly between 0 and 1.
if specific_iou_flagged:
    error_msg = \
        '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]'
    class_iou_args = args.set_class_iou
    if len(class_iou_args) % 2 != 0:
        error('Error, missing arguments. Flag usage:' + error_msg)
    # even positions hold class names, odd positions their IoU thresholds
    specific_iou_classes = class_iou_args[0::2]
    iou_list = class_iou_args[1::2]
    if len(specific_iou_classes) != len(iou_list):
        error('Error, missing arguments. Flag usage:' + error_msg)
    for tmp_class in specific_iou_classes:
        if tmp_class in gt_classes:
            continue
        error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
    for num in iou_list:
        if is_float_between_0_and_1(num):
            continue
        error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg)

# detection-results
#   Load each of the detection-results files into a temporary ".json" file.
# get a list with the detection-results files
dr_files_list = glob.glob(DR_PATH + '/*.txt')
dr_files_list.sort()

# Parse every detection-results file once, instead of once per class, and
# validate it on the way. Its ground-truth counterpart must exist, and every
# line must read <class_name> <confidence> <left> <top> <right> <bottom>.
# Files and lines are visited in the same order as before, so errors are
# reported in the same order. The guard keeps the old behaviour of doing
# nothing when there are no classes.
parsed_detections = []  # (class_name, confidence, file_id, bbox) in file/line order
if gt_classes:
    for txt_file in dr_files_list:
        # same file-id rule as for the ground-truth files
        file_id = os.path.splitext(os.path.basename(os.path.normpath(txt_file)))[0]
        temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
        if not os.path.exists(temp_path):
            error_msg = "Error. File not found: {}\n".format(temp_path)
            error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
            error(error_msg)
        lines = file_lines_to_list(txt_file)
        for line in lines:
            try:
                tmp_class_name, confidence, left, top, right, bottom = line.split()
            except ValueError:
                error_msg = "Error: File " + txt_file + " in the wrong format.\n"
                error_msg += " Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n"
                error_msg += " Received: " + line
                error(error_msg)
            bbox = left + " " + top + " " + right + " " + bottom
            parsed_detections.append((tmp_class_name, confidence, file_id, bbox))

# One "<class>_dr.json" per ground-truth class, sorted by decreasing confidence
# (sort is stable, so ties keep file/line order).
for class_name in gt_classes:
    bounding_boxes = [
        {"confidence": det_confidence, "file_id": det_file_id, "bbox": det_bbox}
        for det_class, det_confidence, det_file_id, det_bbox in parsed_detections
        if det_class == class_name
    ]
    bounding_boxes.sort(key=lambda x: float(x['confidence']), reverse=True)
    with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
        json.dump(bounding_boxes, outfile)

# Calculate the AP for each class
sum_AP = 0.0
ap_dictionary = {}
lamr_dictionary = {}
# open file to store the output
with open(output_files_path + "/output.txt", 'w') as output_file:
    output_file.write("# AP and precision/recall per class\n")
    count_true_positives = {}
    for class_index, class_name in enumerate(gt_classes):
        count_true_positives[class_name] = 0
        # Load the detection-results of this class (the "_dr.json" files are
        # written sorted by decreasing confidence).
        dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
        with open(dr_file) as dr_handle:
            dr_data = json.load(dr_handle)

        # Assign detection-results to ground-truth objects.
        nd = len(dr_data)
        tp = [0] * nd  # tp[i] = 1 if detection i is a true positive
        fp = [0] * nd  # fp[i] = 1 if detection i is a false positive
        for idx, detection in enumerate(dr_data):
            file_id = detection["file_id"]
            if show_animation:
                # find ground truth image
                ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
                if len(ground_truth_img) == 0:
                    error("Error. Image not found with id: " + file_id)
                elif len(ground_truth_img) > 1:
                    error("Error. Multiple image with id: " + file_id)
                else:  # found image
                    # Load image
                    img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
                    # load image with draws of multiple detections
                    img_cumulative_path = output_files_path + "/images/" + ground_truth_img[0]
                    if os.path.isfile(img_cumulative_path):
                        img_cumulative = cv2.imread(img_cumulative_path)
                    else:
                        img_cumulative = img.copy()
                    # Add bottom border to image (room for two lines of text)
                    bottom_border = 60
                    BLACK = [0, 0, 0]
                    img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
            # Open the ground-truth of this image. It is re-read for every
            # detection because matched objects get "used": True written back
            # to this file below, and later detections must see that flag.
            gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
            with open(gt_file) as gt_handle:
                ground_truth_data = json.load(gt_handle)
            ovmax = -1
            gt_match = -1
            # load detected object bounding-box (left, top, right, bottom)
            bb = [float(x) for x in detection["bbox"].split()]
            for obj in ground_truth_data:
                # look for a class_name match
                if obj["class_name"] == class_name:
                    bbgt = [float(x) for x in obj["bbox"].split()]
                    bi = [max(bb[0], bbgt[0]), max(bb[1], bbgt[1]), min(bb[2], bbgt[2]), min(bb[3], bbgt[3])]
                    # "+ 1": coordinates are treated as inclusive pixel indices
                    iw = bi[2] - bi[0] + 1
                    ih = bi[3] - bi[1] + 1
                    if iw > 0 and ih > 0:
                        # compute overlap (IoU) = area of intersection / area of union
                        ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                        ov = iw * ih / ua
                        if ov > ovmax:
                            ovmax = ov
                            gt_match = obj

            # assign detection as true positive/don't care/false positive
            if show_animation:
                status = "NO MATCH FOUND!"  # status is only used in the animation
            # set minimum overlap (optionally overridden per class)
            min_overlap = MINOVERLAP
            if specific_iou_flagged:
                if class_name in specific_iou_classes:
                    index = specific_iou_classes.index(class_name)
                    min_overlap = float(iou_list[index])
            if ovmax >= min_overlap:
                # a match with a "difficult" object counts as neither TP nor FP
                if "difficult" not in gt_match:
                    if not bool(gt_match["used"]):
                        # true positive
                        tp[idx] = 1
                        gt_match["used"] = True
                        count_true_positives[class_name] += 1
                        # update the ".json" file
                        with open(gt_file, 'w') as f:
                            f.write(json.dumps(ground_truth_data))
                        if show_animation:
                            status = "MATCH!"
                    else:
                        # false positive (multiple detection)
                        fp[idx] = 1
                        if show_animation:
                            status = "REPEATED MATCH!"
            else:
                # false positive
                fp[idx] = 1
                if ovmax > 0:
                    status = "INSUFFICIENT OVERLAP"

            # Draw image to show animation
            if show_animation:
                height, width = img.shape[:2]
                # colors (OpenCV works with BGR)
                white = (255, 255, 255)
                light_blue = (255, 200, 100)
                green = (0, 255, 0)
                light_red = (30, 30, 255)
                # 1st line
                margin = 10
                v_pos = int(height - margin - (bottom_border / 2.0))
                text = "Image: " + ground_truth_img[0] + " "
                img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
                img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
                if ovmax != -1:
                    color = light_red
                    if status == "INSUFFICIENT OVERLAP":
                        text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
                    else:
                        text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
                        color = green
                    img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
                # 2nd line
                v_pos += int(bottom_border / 2.0)
                rank_pos = str(idx+1)  # rank position (idx starts at 0)
                text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
                img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                color = light_red
                if status == "MATCH!":
                    color = green
                text = "Result: " + status + " "
                img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)

                font = cv2.FONT_HERSHEY_SIMPLEX
                if ovmax > 0:  # if there is intersections between the bounding-boxes
                    bbgt = [int(round(float(x))) for x in gt_match["bbox"].split()]
                    cv2.rectangle(img, (bbgt[0], bbgt[1]), (bbgt[2], bbgt[3]), light_blue, 2)
                    cv2.rectangle(img_cumulative, (bbgt[0], bbgt[1]), (bbgt[2], bbgt[3]), light_blue, 2)
                    cv2.putText(img_cumulative, class_name, (bbgt[0], bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
                bb = [int(i) for i in bb]
                cv2.rectangle(img, (bb[0], bb[1]), (bb[2], bb[3]), color, 2)
                cv2.rectangle(img_cumulative, (bb[0], bb[1]), (bb[2], bb[3]), color, 2)
                cv2.putText(img_cumulative, class_name, (bb[0], bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
                # show image
                cv2.imshow("Animation", img)
                cv2.waitKey(20)  # show for 20 ms
                # save image to output
                output_img_path = output_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
                cv2.imwrite(output_img_path, img)
                # save the image with all the objects drawn to it
                cv2.imwrite(img_cumulative_path, img_cumulative)

        # compute precision/recall: turn tp/fp into running (cumulative) counts
        cumsum = 0
        for idx, val in enumerate(fp):
            fp[idx] += cumsum
            cumsum += val
        cumsum = 0
        for idx, val in enumerate(tp):
            tp[idx] += cumsum
            cumsum += val
        rec = tp[:]
        for idx, val in enumerate(tp):
            rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
        prec = tp[:]
        for idx, val in enumerate(tp):
            # Detections that only matched "difficult" objects are neither TP
            # nor FP, so the leading ranks can have tp + fp == 0. Report
            # precision 0 there (as VOC's eps-guarded division does) instead
            # of raising ZeroDivisionError.
            n_counted = fp[idx] + tp[idx]
            prec[idx] = float(tp[idx]) / n_counted if n_counted > 0 else 0.0

        ap, mrec, mprec = voc_ap(rec[:], prec[:])
        sum_AP += ap
        text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP "
        # Write to output.txt
        rounded_prec = ['%.2f' % elem for elem in prec]
        rounded_rec = ['%.2f' % elem for elem in rec]
        output_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
        if not args.quiet:
            print(text)
        ap_dictionary[class_name] = ap

        n_images = counter_images_per_class[class_name]
        lamr, mr, fppi = log_average_miss_rate(np.array(prec), np.array(rec), n_images)
        lamr_dictionary[class_name] = lamr

        # Draw plot
        if draw_plot:
            plt.plot(rec, prec, '-o')
            # add a new penultimate point to the list (mrec[-2], 0.0)
            # since the last line segment (and respective area) do not affect the AP value
            area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
            area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
            plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
            # set window title
            fig = plt.gcf()  # gcf - get current figure
            fig.canvas.manager.set_window_title('AP ' + class_name)
            # set plot title
            plt.title('class: ' + text)
            # set axis titles
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            # optional - set axes
            axes = plt.gca()  # gca - get current axes
            axes.set_xlim([0.0, 1.0])
            axes.set_ylim([0.0, 1.05])  # .05 to give some extra space
            # save the plot
            fig.savefig(output_files_path + "/classes/" + class_name + ".png")
            plt.cla()  # clear axes for next plot

    if show_animation:
        cv2.destroyAllWindows()

    output_file.write("\n# mAP of all classes\n")
    mAP = sum_AP / n_classes
    text = "mAP = {0:.2f}%".format(mAP*100)
    output_file.write(text + "\n")
    print(text)

# Draw false negatives
if show_animation:
    pink = (203, 192, 255)
    # prefix that precedes the image id inside each temp file path (loop invariant)
    start = TEMP_FILES_PATH + '/'
    for tmp_file in gt_files:
        with open(tmp_file) as gt_handle:
            ground_truth_data = json.load(gt_handle)
        # get name of corresponding image:
        # ".temp_files/<img_id>_ground_truth.json" -> "<img_id>"
        img_id = tmp_file[tmp_file.find(start)+len(start):tmp_file.rfind('_ground_truth.json')]
        # NOTE(review): both lookups assume a ".jpg" image; for other
        # extensions cv2.imread returns None and cv2.rectangle will fail.
        img_cumulative_path = output_files_path + "/images/" + img_id + ".jpg"
        img = cv2.imread(img_cumulative_path)
        if img is None:
            img_path = IMG_PATH + '/' + img_id + ".jpg"
            img = cv2.imread(img_path)
        # draw false negatives: ground-truth objects never marked "used"
        for obj in ground_truth_data:
            if not obj['used']:
                bbgt = [int(round(float(x))) for x in obj["bbox"].split()]
                cv2.rectangle(img, (bbgt[0], bbgt[1]), (bbgt[2], bbgt[3]), pink, 2)
        cv2.imwrite(img_cumulative_path, img)

# remove the temp_files directory
shutil.rmtree(TEMP_FILES_PATH)

# Count total of detection-results
# Count the detected objects of each class over all detection-results files,
# skipping classes in the ignore list. The dict keeps first-seen order.
det_counter_per_class = {}
for txt_file in dr_files_list:
    for line in file_lines_to_list(txt_file):
        class_name = line.split()[0]
        if class_name in args.ignore:
            continue
        det_counter_per_class[class_name] = det_counter_per_class.get(class_name, 0) + 1
dr_classes = list(det_counter_per_class)

# Plot the total number of occurrences of each class in the ground-truth
if draw_plot:
    draw_plot_func(
        gt_counter_per_class,
        n_classes,
        window_title="ground-truth-info",
        plot_title="ground-truth\n({} files and {} classes)".format(
            len(ground_truth_files_list), n_classes),
        x_label="Number of objects per class",
        output_path=output_files_path + "/ground-truth-info.png",
        to_show=False,
        plot_color='forestgreen',
        true_p_bar='',
    )

# Write number of ground-truth objects per class to results.txt
# Append one "<class>: <count>" line per ground-truth class, alphabetically.
with open(output_files_path + "/output.txt", 'a') as output_file:
    output_file.write("\n# Number of ground-truth objects per class\n")
    output_file.writelines(
        "{}: {}\n".format(gt_name, gt_counter_per_class[gt_name])
        for gt_name in sorted(gt_counter_per_class)
    )

# Finish counting true positives
# A class that appears in the detection-results but not in the ground-truth
# cannot have true positives; record an explicit 0 so later lookups succeed.
count_true_positives.update(
    {dr_name: 0 for dr_name in dr_classes if dr_name not in gt_classes}
)

# Plot the total number of occurrences of each class in the "detection-results" folder
if draw_plot:
    # number of classes with at least one detection
    n_detected_classes = sum(1 for n_det in det_counter_per_class.values() if int(n_det) > 0)
    draw_plot_func(
        det_counter_per_class,
        len(det_counter_per_class),
        window_title="detection-results-info",
        plot_title="detection-results\n({} files and {} detected classes)".format(
            len(dr_files_list), n_detected_classes),
        x_label="Number of objects per class",
        output_path=output_files_path + "/detection-results-info.png",
        to_show=False,
        plot_color='forestgreen',
        # per-class TP counts -> each bar is split into FP (red) and TP (green)
        true_p_bar=count_true_positives,
    )

# Write number of detected objects per class to output.txt
# Append "<class>: <n> (tp:<tp>, fp:<n - tp>)" per detected class, alphabetically.
with open(output_files_path + "/output.txt", 'a') as output_file:
    output_file.write("\n# Number of detected objects per class\n")
    for det_name in sorted(dr_classes):
        n_det = det_counter_per_class[det_name]
        n_tp = count_true_positives[det_name]
        output_file.write("{}: {} (tp:{}, fp:{})\n".format(det_name, n_det, n_tp, n_det - n_tp))

# Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
if draw_plot:
    draw_plot_func(
        lamr_dictionary,
        n_classes,
        window_title="lamr",
        plot_title="log-average miss rate",
        x_label="log-average miss rate",
        output_path=output_files_path + "/lamr.png",
        to_show=False,
        plot_color='royalblue',
        true_p_bar="",
    )

# Draw mAP plot (Show AP's of all classes in decreasing order)
if draw_plot:
    # final summary chart; the only one displayed interactively (to_show=True)
    draw_plot_func(
        ap_dictionary,
        n_classes,
        window_title="mAP",
        plot_title="mAP = {0:.2f}%".format(mAP*100),
        x_label="Average Precision",
        output_path=output_files_path + "/mAP.png",
        to_show=True,
        plot_color='royalblue',
        true_p_bar="",
    )