I. What is NMS?
NMS (non-maximum suppression) is widely used both in traditional feature extraction and in deep-learning object detection algorithms.
The principle of NMS is to keep only local maxima and suppress everything else, so that the strongest responses are selected.
In 2D edge extraction, it appears after the edge response has been computed: pixels whose gradient magnitude is not a local maximum along the gradient direction are suppressed, which keeps weak, redundant responses from cluttering the edge map.
It also plays an important role in 3D keypoint detection, where features that are not local extrema are discarded.
In object detection it is used in models such as YOLO and R-CNN to remove lower-scoring output boxes, and it is used in the same way in 3D point-cloud-based detection models.
II. Examples
1. OpenCV example
Looking at the OpenCV source code, the Canny operator uses NMS: among the gradient values produced by Sobel (or another gradient operator), it keeps only the local maxima.
The rule is to compare each pixel's gradient magnitude with its neighbors along the gradient direction; if the pixel is the maximum, it is kept, otherwise it is set to 0.
The source code is analysed in:
Canny算法解析,opencv源碼實現及實例
// Read an image, convert it to grayscale, and run Canny edge detection (NMS is applied internally)
#include <opencv2/opencv.hpp>
using namespace cv;

int main()
{
    Mat img = imread("true.jpg");                          // read the image
    Mat Grayimg;
    resize(img, img, Size(400, 600), 0, 0, INTER_LINEAR);  // resize for display
    cvtColor(img, Grayimg, COLOR_BGR2GRAY);                // convert to grayscale (imread loads BGR)
    Canny(Grayimg, Grayimg, 100, 300, 3);                  // Canny with low/high thresholds 100/300
    imshow("picture0", img);
    imshow("picture", Grayimg);
    waitKey(0);
    return 0;
}
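To make the per-pixel rule above concrete, here is a minimal NumPy sketch of gradient-direction NMS (a hypothetical helper for illustration, not the actual OpenCV implementation): the gradient direction is quantized to one of four axes, and a pixel survives only if its magnitude is not smaller than its two neighbors along that axis.

import numpy as np

def gradient_nms(mag, ang):
    """Suppress pixels that are not local maxima along the gradient direction.
    mag: 2D gradient magnitudes; ang: gradient angles in degrees."""
    out = np.zeros_like(mag)
    for y in range(1, mag.shape[0] - 1):
        for x in range(1, mag.shape[1] - 1):
            a = ang[y, x] % 180
            if a < 22.5 or a >= 157.5:       # roughly horizontal gradient: compare left/right
                n1, n2 = mag[y, x - 1], mag[y, x + 1]
            elif a < 67.5:                   # 45-degree diagonal
                n1, n2 = mag[y - 1, x + 1], mag[y + 1, x - 1]
            elif a < 112.5:                  # roughly vertical gradient: compare up/down
                n1, n2 = mag[y - 1, x], mag[y + 1, x]
            else:                            # 135-degree diagonal
                n1, n2 = mag[y - 1, x - 1], mag[y + 1, x + 1]
            if mag[y, x] >= n1 and mag[y, x] >= n2:
                out[y, x] = mag[y, x]        # local maximum: keep it; everything else stays 0
    return out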
2. PCL example
Point-cloud keypoint extraction algorithms often use NMS to pick out local extrema.
For example, 3D SIFT keypoint detection looks for extrema over the 26-neighborhood of each point in scale space.
The algorithm is described in:
PCL 3D-SIFT關鍵點檢測(Z方向梯度約束)
pcl::SIFTKeypoint<pcl::PointXYZ, pcl::PointWithScale> sift;  // SIFT keypoint estimator
pcl::PointCloud<pcl::PointWithScale> result;                 // output keypoints with their scales
sift.setInputCloud(cloud_xyz);                               // input point cloud
pcl::search::KdTree<pcl::PointXYZ>::Ptr tree(new pcl::search::KdTree<pcl::PointXYZ>());
sift.setSearchMethod(tree);                                  // k-d tree for neighborhood searches
sift.setScales(0.01f, 7, 20);                                // min scale, number of octaves, scales per octave
sift.setMinimumContrast(0.001f);                             // contrast threshold for keypoints
sift.compute(result);                                        // detect keypoints (extrema over the 26-neighborhood)
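As a rough illustration of the 26-neighborhood test (a hypothetical sketch on a dense 3D volume, not the PCL point-cloud implementation), a voxel is kept as a candidate keypoint only if it is strictly greater or strictly smaller than all 26 of its neighbors:

import numpy as np

def is_local_extremum(vol, z, y, x):
    """Return True if vol[z, y, x] is a maximum or minimum of its 26-neighborhood."""
    center = vol[z, y, x]
    block = vol[z - 1:z + 2, y - 1:y + 2, x - 1:x + 2]  # 3x3x3 block of 27 voxels
    neighbors = np.delete(block.ravel(), 13)            # index 13 is the center voxel
    return center > neighbors.max() or center < neighbors.min()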
3. NMS in object detection
In deep learning, NMS is commonly applied to box scores to keep only the best candidates; it is used in models such as R-CNN, YOLO, and point-cloud detectors built on PointNet.
The algorithm roughly works as follows (a minimal sketch is given after the list):
1: Compute the score of every box and sort the boxes by score in descending order.
2: Take the highest-scoring box, compute the IoU between it and every remaining box, and delete any box whose IoU exceeds the chosen threshold; then repeat with the next highest-scoring box that remains.
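A minimal sketch of this greedy procedure (plain NumPy, with hypothetical helper names; not the YOLO implementation quoted below):

import numpy as np

def nms(boxes, scores, iou_thres=0.45):
    """Greedy NMS. boxes: (N, 4) array of [x1, y1, x2, y2]; scores: (N,). Returns kept indices."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]            # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]                          # current highest-scoring box is always kept
        keep.append(i)
        # intersection of box i with every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= iou_thres]   # drop boxes that overlap box i too much
    return keep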
In the YOLO source code, the call appears in detect.py:
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
conf_thres: confidence (score) threshold; YOLO uses 0.25 by default.
iou_thres: IoU (overlap) threshold, default 0.45.
classes: optional list of class indices; only boxes belonging to these classes are kept.
agnostic_nms: whether to run class-agnostic NMS, i.e. also suppress overlapping boxes of different classes; default False.
max_det: maximum number of detections kept per image, default 300.
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output