一、概述
在視覺任務中,圖像分割任務是一個很廣泛的領域,應用于交互式分割,邊緣檢測,超像素化,感興趣目標生成,前景分割,語義分割,實例分割,泛視分割等。
交互式分割,這種分割任務,它允許用戶手動細化掩碼來分割任意類型的對象。然而,這種方法需要用戶的不斷參與和指導,類似于ps里面的摳圖快速選擇工具。
實例分割任務是它能夠自動分割特定類別的對象,例如行人,狗,電視或椅子,但需要大量的手動標注數據,標注樣本要以上萬個樣本,然后要經過大量的計算資源和代碼算法知識來訓練模型。這種方式應用最廣泛應該是人像自動摳圖:
為了 解決這些分割任務的局限性,Meta 推出了「分割一切」AI 算法Segment Anything,為分割任務提供一種通用的、全自動的分割解決方案。
二、Segment Anything 萬物分割
1.算法摘要
作者介紹了Segment Anything (SA) 項目,這是一個旨在進行圖像分割的新任務,同時提供了相應的模型和數據集。在該項目中,作者采用了一種高效的模型來進行數據收集,以構建迄今為止最大的分割數據集。他們在超過1,100萬張公開的圖像上進行了標注,生成了超過10億個掩碼。
這個模型(SAM)訓練完成之后,以使其具備"promptable"(可提示)的性質,因此意味著它可以零樣本(zero-shot)地適應新的數據集和任務,而無需先對數據進行標注和訓練。作者對該模型進行了廣泛的評估,發現它在許多任務上的零樣本表現通常與完全監督的性能相媲美,甚至更好。作者公開了他們的模型(SAM),還發布了相應的圖像數據集(SA-1B)。
2. 算法介紹
LLM的出現,讓研究人員感受到,使用互聯網規模的數據集上預訓練的大型語言模型已經改變了自然語言處理(NLP)領域,因為它們表現出強大的零樣本和少樣本泛化能力,可以應對未在訓練中出現的任務和數據分布。這種泛化通常通過提示工程(prompt engineering)來實現,其中手工制作的文本提示可以引導語言模型生成有效的文本響應。這些基礎模型在使用豐富的互聯網文本語料庫進行預訓練時,表現出令人驚訝的零樣本和少樣本性能,有時甚至可以與經過精細調整(fine-tune)的模型相媲美。研究經驗表明,這種零樣本和少樣本性能會隨著模型規模、數據集大小和總訓練計算量的增加而改善。
在計算機視覺領域也在探索基礎模型的應用,例如,CLIP和ALIGN使用對比學習來訓練文本和圖像編碼器,經過訓練后,這些編碼器可以用于零樣本泛化到新的視覺概念和數據分布。這些編碼器還可以有效地與其他模塊結合,用于解決下游任務,比如圖像生成。然而,計算機視覺領域涉及的問題遠不止這些,而且許多問題缺乏豐富的訓練數據。
在這項研究工作中,SAM作者的目標是建立一個圖像分割的基礎模型,也就是一個可提示的模型,它可以在廣泛的數據集上進行預訓練以實現強大的泛化能力。一旦有了這個模型,作者進一步探索如何通過快速流程來解決各種新的數據分布上的下游分割問題。
這個計劃的成功取決于三個關鍵要素:任務、模型和數據。作者需要解決以下關于圖像分割的問題:
-
什么樣的視覺分割任務可以實現零樣本泛化?
-
為了實現這一個分割任務,對應的模型架構應該是什么樣的?
-
哪些數據可以支持這個任務和模型的預訓練?
作者首先定義了一個可提示的分割任務,這個任務足夠通用,可以作為強大的預訓練目標,同時也可以支持廣泛的下游應用。這個任務要求一個支持多種提示的模型,并能夠實時生成分割掩碼,以支持交互式使用。然而,互聯網上目前尚沒有足夠大規模的分割數據集來滿足這個任務的需求。作者提出了“數據引擎”來應對這個問題,即通過模型輔助數據收集和不斷迭代來改進數據,以填補數據的不足。這個方法可以在模型訓練和數據收集之間進行交互,以實現更好的性能。
- 分割任務
在自然語言處理和計算機視覺領域,基礎模型具有很大的前景,因為它們可以用于執行零樣本學習和少樣本學習,通過利用提示來適應新的數據集和任務。受到這種思路的啟發,本文提出了一個稱為"可提示分割任務"的新領域,其主要目標是在給定分割提示的情況下生成有效的分割掩碼(如圖1a所示)。
這些分割提示可以簡單地指定圖像中要分割的對象,例如,提示可以包括對象的位置信息或文本描述。
這里的"有效輸出掩碼"意味著,即使提示信息模糊不清,可能指向多個不同對象(例如,在圖像上一個點可能表示襯衫或穿襯衫的人),生成的分割掩碼也應該合理,至少應該包括這些對象中的一個。
在這項研究中,作者將可提示分割任務作為預訓練目標,然后使用提示工程方法來解決各種不同的下游分割任務。這種方法有望為計算機視覺領域帶來一種強大的學習范式,可以在面對新任務時從有限的提示信息中進行學習,而不需要大量的標記數據。這對于處理多樣化和復雜的視覺任務可能具有很大的潛力。
- 模型選擇
可提示分割任務對模型的架構提出了一些嚴格的要求,這包括對提示的支持靈活性、實時計算的需求,以便允許交互使用,以及能夠處理歧義。作者提出了一個簡單的模型設計,可以滿足這些要求,被稱為"Segment Anything"模型,簡稱SAM(見圖1b)。SAM的架構包括以下組成部分: - 圖像編碼器:這是一個強大的模型,負責將輸入圖像轉化為圖像嵌入(image embedding),以捕捉圖像的特征信息。
- 提示編碼器:這是一個用于嵌入提示信息的模型,它將提示信息轉化為提示嵌入,以使模型能夠理解提示中的內容。
- 控碼解碼器:這是一個輕量級的模型,負責將圖像嵌入和提示嵌入結合,然后預測分割掩碼。這一部分的設計使得SAM可以實現對相同圖像嵌入的不同提示信息的分配,從而使模型能夠處理多樣性的提示。
SAM的設計還允許它在不超過50毫秒的時間內從提示符中預測掩碼,實現了實時性能,這對于實際應用和交互式任務非常重要。
作者的主要關注點包括邊界框、關鍵點和分割掩碼提示。為了解決歧義問題,SAM被設計成能夠預測多個掩碼,即使給定相同的提示。這使得SAM可以自然地處理提示中的歧義,比如前文提到的襯衫和穿襯衫的人之間的歧義示例。這個能力對于處理復雜的圖像場景和多義性提示非常有幫助。
- 數據引擎
為了使SAM能夠在新的數據分布上實現強大的泛化能力,需要在一個大型數據集上進行訓練,該數據集應該覆蓋各種不同的分割任務和場景。然而,典型的訓練方法通常依賴于在線獲取數據,而掩碼標注信息通常相對稀缺,因此需要采用替代策略。作者提出的解決方案是構建一個稱為"數據引擎"的系統,這個引擎包括三個主要階段:輔助手動、半自動和全自動。
-
輔助手動階段:在這個階段,SAM與人工注釋人員協作,類似于傳統的交互式分割設置。人工注釋人員手動為圖像中的對象生成掩碼,同時SAM提供輔助信息,例如提示信息,以幫助人工注釋人員完成掩碼的生成。這一階段有助于收集一些基本的分割標注。
-
半自動階段:在這個階段,SAM能夠自動為圖像中的對象的某些子區域生成掩碼。它會根據已有的掩碼和提示信息,自動預測可能的對象位置,并生成相應的掩碼。這減輕了人工注釋人員的工作負擔,因為他們可以專注于注釋剩余的對象,從而提高了標注的多樣性。
-
全自動階段:在最后一個階段,作者采用一種規則網格提示SAM,用于生成大量高質量掩碼。這個提示方法能夠為每張圖像平均產生約100個掩碼,以增加數據的多樣性和覆蓋不同的情況。
通過這種數據引擎的階段性設計,作者能夠有效地利用協作注釋和自動化方法,以構建一個大規模的數據集,為SAM的訓練提供了足夠豐富和多樣的標注數據,從而使其在新的數據分布上實現強大的泛化能力。這種方法有助于克服標注數據稀缺性的問題,尤其是對于復雜的分割任務。
3.數據集
作者最終的數據集SA-1B,包括來自1100萬張經許可和隱私保護圖像的超過10億個掩碼(見圖2)。SA-1B使用作者的數據引擎的最后階段完全自動收集,比現有的最大分割數據集擁有400多倍的掩碼,并且作者廣泛驗證,掩碼具有高質量和多樣性。作者希望SA-1B能夠成為一種有價值的資源,用于建立新的基礎模型。
4.實驗
作者廣泛地評估SAM。首先,在23個分割數據集上的測試,作者發現SAM從單個前景點生成了高質量的掩碼,通常僅略低于手動注釋的真實值。其次,作者在使用提示工程的零樣本傳輸協議(zero-shot transfer protocol)下的各種下游任務上發現了持續強大的定量和定性結果,包括邊緣檢測、感興趣目標生成、實例分割和文本到掩碼預測。這些結果表明,SAM可以在即時工程中開箱即用,解決涉及SAM訓練數據之外的圖像分布的各種任務。
三、模型C++推理
1.實現代碼
#include "include/segment_anything.h"
namespace sam{
SegmentAnything::~SegmentAnything()
{image_encoder_net_.clear();mask_decoder_net_.clear();
}static inline float intersection_area(const sam_result_t& a, const sam_result_t& b)
{cv::Rect_<float> inter = a.box & b.box;return inter.area();
}static void qsort_descent_inplace(std::vector<sam_result_t>& faceobjects, int left, int right)
{int i = left;int j = right;float p = faceobjects[(left + right) / 2].iou_pred;while (i <= j){while (faceobjects[i].iou_pred > p)i++;while (faceobjects[j].iou_pred < p)j--;if (i <= j){// swapstd::swap(faceobjects[i], faceobjects[j]);i++;j--;}}#pragma omp parallel sections{#pragma omp section{if (left < j) qsort_descent_inplace(faceobjects, left, j);}#pragma omp section{if (i < right) qsort_descent_inplace(faceobjects, i, right);}}
}static void qsort_descent_inplace(std::vector<sam_result_t>& faceobjects)
{if (faceobjects.empty())return;qsort_descent_inplace(faceobjects, 0, faceobjects.size() - 1);
}static void nms_sorted_bboxes(const cv::Mat& bgr,const std::vector<sam_result_t>& faceobjects, std::vector<int>& picked, float nms_threshold)
{picked.clear();const int n = faceobjects.size();std::vector<float> areas(n);for (int i = 0; i < n; i++){areas[i] = faceobjects[i].box.area();}cv::Mat img = bgr.clone();for (int i = 0; i < n; i++){const sam_result_t& a = faceobjects[i];int keep = 1;for (int j = 0; j < (int)picked.size(); j++){const sam_result_t& b = faceobjects[picked[j]];// intersection over unionfloat inter_area = intersection_area(a, b);float union_area = areas[i] + areas[picked[j]] - inter_area;// float IoU = inter_area / union_areaif (inter_area / union_area > nms_threshold){keep = 0;}}if (keep)picked.push_back(i);}
}
int SegmentAnything::NMS(const cv::Mat& bgr, std::vector<sam_result_t>& proposals, std::vector<int>& picked, float nms_threshold)
{qsort_descent_inplace(proposals);nms_sorted_bboxes(bgr, proposals, picked, nms_threshold);return 0;
}int SegmentAnything::Load(const std::string& image_encoder_param, const std::string& image_encoder_bin, const std::string& mask_decoder_param, const std::string& mask_decoder_bin)
{int ret = 0;ret = image_encoder_net_.load_param(image_encoder_param.c_str());if (ret < 0)return -1;ret = image_encoder_net_.load_model(image_encoder_bin.c_str());if (ret < 0)return -1;ret = mask_decoder_net_.load_param(mask_decoder_param.c_str());if (ret < 0)return -1;ret = mask_decoder_net_.load_model(mask_decoder_bin.c_str());if (ret < 0)return -1;return 0;
}
int SegmentAnything::ImageEncoder(const cv::Mat& bgr, ncnn::Mat& image_embeddings, image_info_t& image_info)
{const int target_size = 1024;int img_w = bgr.cols;int img_h = bgr.rows;int w = img_w;int h = img_h;float scale = 1.f;if (w > h){scale = (float)target_size / w;w = target_size;h = h * scale;}else{scale = (float)target_size / h;h = target_size;w = w * scale;}ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);int wpad = target_size - w;int hpad = target_size - h;ncnn::Mat in_pad;ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 0.f);in_pad.substract_mean_normalize(means_, norms_);ncnn::Extractor image_encoder_ex = image_encoder_net_.create_extractor();image_encoder_ex.input("image", in_pad);image_encoder_ex.extract("image_embeddings", image_embeddings);image_info.img_h = img_h;image_info.img_w = img_w;image_info.pad_h = h;image_info.pad_w = w;image_info.scale = scale;return 0;
}int SegmentAnything::embed_masks(const prompt_info_t& prompt_info, ncnn::Mat& mask_input, ncnn::Mat& has_mask)
{mask_input = ncnn::Mat(256, 256, 1);mask_input.fill(0.f);has_mask = ncnn::Mat(1);has_mask.fill(0.f);return 0;
}
int SegmentAnything::transform_coords(const image_info_t& image_info, ncnn::Mat& point_coords)
{for(int h = 0; h < point_coords.h; ++h){float* ptr = point_coords.row(h);ptr[0] *= image_info.scale;ptr[1] *= image_info.scale;}return 0;
}
int SegmentAnything::embed_points(const prompt_info_t& prompt_info, std::vector<ncnn::Mat>& point_labels, ncnn::Mat& point_coords)
{int num_points = prompt_info.points.size() / 2;point_coords = ncnn::Mat(num_points * 2, (void*)prompt_info.points.data()).reshape(2, num_points).clone();ncnn::Mat point_labels1 = ncnn::Mat(256, num_points);ncnn::Mat point_labels2 = ncnn::Mat(256, num_points);ncnn::Mat point_labels3 = ncnn::Mat(256, num_points);ncnn::Mat point_labels4 = ncnn::Mat(256, num_points);ncnn::Mat point_labels5 = ncnn::Mat(256, num_points);ncnn::Mat point_labels6 = ncnn::Mat(256, num_points);point_labels1.row_range(0, num_points - 1).fill(1.f);point_labels1.row_range(num_points - 1, 1).fill(0.f);for (int i = 0; i < num_points - 1; ++i) {if (prompt_info.labels[i] == -1)point_labels2.row_range(i, 1).fill(1.f);elsepoint_labels2.row_range(i, 1).fill(0.f);}point_labels2.row_range(num_points - 1, 1).fill(1.f);for (int i = 0; i < num_points - 1; ++i) {if (prompt_info.labels[i] == 0)point_labels3.row_range(i, 1).fill(1.f);elsepoint_labels3.row_range(i, 1).fill(0.f);}point_labels3.row_range(num_points - 1, 1).fill(0.f);for (int i = 0; i < num_points - 1; ++i) {if (prompt_info.labels[i] == 1)point_labels4.row_range(i, 1).fill(1.f);elsepoint_labels4.row_range(i, 1).fill(0.f);}point_labels4.row_range(num_points - 1, 1).fill(0.f);for (int i = 0; i < num_points - 1; ++i) {if (prompt_info.labels[i] == 2)point_labels5.row_range(i, 1).fill(1.f);elsepoint_labels5.row_range(i, 1).fill(0.f);}point_labels5.row_range(num_points - 1, 1).fill(0.f);for (int i = 0; i < num_points - 1; ++i) {if (prompt_info.labels[i] == 3)point_labels6.row_range(i, 1).fill(1.f);elsepoint_labels6.row_range(i, 1).fill(0.f);}point_labels6.row_range(num_points - 1, 1).fill(0.f);point_labels.push_back(point_labels1);point_labels.push_back(point_labels2);point_labels.push_back(point_labels3);point_labels.push_back(point_labels4);point_labels.push_back(point_labels5);point_labels.push_back(point_labels6);return 0;
}
int SegmentAnything::MaskDecoder(const ncnn::Mat& image_embeddings, image_info_t& image_info, const prompt_info_t& prompt_info, std::vector<sam_result_t>& sam_results, float pred_iou_thresh, float stability_score_thresh)
{std::vector<ncnn::Mat> point_labels;ncnn::Mat point_coords;embed_points(prompt_info, point_labels, point_coords);transform_coords(image_info, point_coords);ncnn::Mat mask_input, has_mask;embed_masks(prompt_info, mask_input, has_mask);ncnn::Extractor mask_decoder_ex = mask_decoder_net_.create_extractor();mask_decoder_ex.input("mask_input", mask_input);mask_decoder_ex.input("point_coords", point_coords);mask_decoder_ex.input("point_labels1", point_labels[0]);mask_decoder_ex.input("point_labels2", point_labels[1]);mask_decoder_ex.input("point_labels3", point_labels[2]);mask_decoder_ex.input("point_labels4", point_labels[3]);mask_decoder_ex.input("point_labels5", point_labels[4]);mask_decoder_ex.input("point_labels6", point_labels[5]);mask_decoder_ex.input("image_embeddings", image_embeddings);mask_decoder_ex.input("has_mask_input", has_mask);ncnn::Mat scores;mask_decoder_ex.extract("scores", scores);ncnn::Mat masks;mask_decoder_ex.extract("masks", masks);//postprocessstd::vector<std::pair<float, int>> scores_vec;for (int i = 1; i < scores.w; ++i) {scores_vec.push_back(std::pair<float, int>(scores[i], i));}std::sort(scores_vec.begin(), scores_vec.end(), std::greater<std::pair<float, int>>());if (scores_vec[0].first > pred_iou_thresh) {sam_result_t sam_result;ncnn::Mat mask = masks.channel(scores_vec[0].second);cv::Mat cv_mask_32f = cv::Mat::zeros(cv::Size(mask.w, mask.h), CV_32F);std::copy((float*)mask.data, (float*)mask.data + mask.w * mask.h, (float*)cv_mask_32f.data);cv::Mat single_mask_32f;cv::resize(cv_mask_32f(cv::Rect(0, 0, image_info.pad_w, image_info.pad_h)), single_mask_32f, cv::Size(image_info.img_w,image_info.img_h), 0, 0, 1);float stable_score = calculate_stability_score(single_mask_32f);if (stable_score < stability_score_thresh)return -1;single_mask_32f = single_mask_32f > 0;single_mask_32f.convertTo(sam_result.mask, CV_8UC1, 1, 0);if (postprocess_mask(sam_result.mask, sam_result.box) < 0)return -1;sam_results.push_back(sam_result);}else {return -1;}return 0;
}
int SegmentAnything::postprocess_mask(cv::Mat& mask, cv::Rect& box)
{std::vector<std::vector<cv::Point>> contours;std::vector<cv::Vec4i> hierarchy;cv::findContours(mask.clone(), contours, hierarchy, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);if(contours.size() == 0)return -1;if (contours.size() > 1) {float max_area = 0;int max_idx = 0;std::vector<std::pair<float,int>> areas;for (size_t i = 0; i < contours.size(); ++i) {float area = cv::contourArea(contours[i]);if (area > max_area) {max_idx = i;max_area = area;}areas.push_back(std::pair<float,int>(area,i));}for (size_t i = 0; i < areas.size(); ++i) {//if (i == max_idx)// continue;//else {// cv::drawContours(mask, contours, i, cv::Scalar(0), -1);//}if(areas[i].first < max_area * 0.3){cv::drawContours(mask, contours, i, cv::Scalar(0), -1);}else{box = box | cv::boundingRect(contours[i]);}}}else {box = cv::boundingRect(contours[0]);}return 0;
}
float SegmentAnything::calculate_stability_score(cv::Mat& mask, float mask_threshold, float stable_score_offset)
{float intersections = (float)cv::countNonZero(mask > (mask_threshold + stable_score_offset));float unions = (float)cv::countNonZero(mask > (mask_threshold - stable_score_offset));return intersections / unions;
}
}
2. 交互方法
分割交互方式中有好四種,開放式點,可以多個點組合, 矩形框, 分割一切,還有文字提示這幾種方式。但文字提示效果不太穩定,C++代碼沒有實現這一部分。
開放式點
點擊要分割目標的中間,分割包含該點的物體,會按最小分割的結果展示出來,如果想分割的物體大于展示的結果,可以在物體的其他部分也點擊下:
選擇矩形框
使用鼠標拖動在目標選擇,分割目標:
分割一切
將圖片中所有物體的分割都展示出來: