"近期,二維到三維感知技術的進步顯著提升了對二維圖像中三維場景的理解能力。然而,現有方法面臨諸多關鍵挑戰,包括跨場景泛化能力有限、感知精度欠佳以及重建速度緩慢。為克服這些局限,我們提出了感知高效三維重建框架(PE3R),旨在同時提升準確性與效率。PE3R采用前饋架構,實現了快速的三維語義場重建。該框架在多樣化的場景與對象上展現出強大的零樣本泛化能力,并顯著提高了重建速度。在二維到三維開放詞匯分割及三維重建上的大量實驗驗證了PE3R的有效性與多功能性。" PE3R的作者這樣寫
代碼開源在?
GitHub - hujiecpp/PE3R: PE3R: Perception-Efficient 3D Reconstruction. Take 2 - 3 photos with your phone, upload them, wait a few minutes, and then start exploring your 3D world via text!PE3R: Perception-Efficient 3D Reconstruction. Take 2 - 3 photos with your phone, upload them, wait a few minutes, and then start exploring your 3D world via text! - hujiecpp/PE3Rhttps://github.com/hujiecpp/PE3R
論文地址
https://arxiv.org/abs/2503.07507https://t.co/ec3NSH0KoN
簡單的梳理下論文背景和成果,后面會從代碼分析模型結構
背景
PE3R 誕生的背景是現有方法如NeRF和3DGS依賴于場景特定的訓練和語義提取,計算開銷大,限制了實際應用的可擴展性。
研究空白
- 現有方法在多場景泛化、感知精度和重建速度方面表現不佳。
- 缺乏一種能夠在不依賴3D數據的情況下高效進行3D語義重建的框架。
核心貢獻
- 提出了PE3R(Perception-Efficient 3D Reconstruction)框架,用于高效且準確的3D語義重建。
- 通過僅使用2D圖像實現3D場景重建,無需額外的3D數據(如相機參數或深度信息)。
技術架構
- PE3R框架包含三個關鍵模塊:像素嵌入消歧、語義場重建和全局視角感知。
- 通過前饋機制實現快速的3D語義重建。
實現細節
- 像素嵌入消歧模塊通過跨視角、多層次的語義信息解決像素級別的語義歧義。
- 語義場重建模塊將語義信息直接嵌入到重建過程中,提升重建精度。
- 全局視角感知模塊通過全局語義對齊,減少單視角引入的噪聲。
一個意想不到的細節
除了基本的 3D 重建,PE3R 還支持基于文本的查詢功能,允許用戶通過描述選擇特定的 3D 對象,這在傳統 3D 重建系統中并不常見。
結論
- PE3R框架通過高效的3D語義重建,顯著提升了2D到3D感知的速度和精度。
- 該框架在不依賴場景特定訓練或預校準3D數據的情況下,實現了零樣本泛化,具有廣泛的實際應用潛力。
下載代碼安裝必要依賴后,運行pe3r
上傳官方測試用的 4 張測試圖片
點擊‘reconstruct’? 后日志輸出
渲染glb
查詢Chair
嘗試本地圖片3D構建
構建結構
查找花盆
代碼分析
.\PE3R\modules\pe3r\models.py 展示的模型結構
sys.path.append(os.path.abspath('./modules/ultralytics'))from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
from modules.mast3r.model import AsymmetricMASt3R# from modules.sam2.build_sam import build_sam2_video_predictor
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from modules.mobilesamv2 import sam_model_registryfrom sam2.sam2_video_predictor import SAM2VideoPredictorclass Models:def __init__(self, device):# -- mast3r --# MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'self.mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)# -- sam2 --self.sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)# -- mobilesamv2 & sam1 --SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)# image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')image_encoder = sam1.vision_encoderprompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)self.mobilesamv2.prompt_encoder = prompt_encoderself.mobilesamv2.mask_decoder = mask_decoderself.mobilesamv2.image_encoder=image_encoderself.mobilesamv2.to(device=device)self.mobilesamv2.eval()# -- yolov8 --YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'self.yolov8 = ObjectAwareModel(YOLO8_CKP)# -- siglip --self.siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)self.siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
模型結構中的 Models 類初始化了多個組件,如 MASt3R、SAM2、MobileSAMv2、SAM1、YOLOv8 和 Siglip,共同處理復雜的場景。MASt3R是從checkpoint加載的非對稱模型,用于3D重建或匹配。YOLOv8用于對象檢測,Siglip則用于圖像和文本特征提取。看來這個類整合了多個尖端模型,分別處理不同任務。
Models 類初始化了以下關鍵組件:
- MASt3R:用于多視圖立體視覺,估計圖像對之間的深度和姿態。
- 加載checkpoint(如 naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric),是一個多視圖立體視覺模型。
- 功能:估計圖像對之間的深度和姿態,核心用于 3D 重建。
- SAM2:視頻分割模型,用于在圖像序列中傳播分割掩碼。
- 從預訓練模型 facebook/sam2.1-hiera-large 加載,是視頻分割預測器。
- 功能:在圖像序列中傳播分割掩碼,確保對象在不同幀中的一致性。
- MobileSAMv2 和 SAM1:用于基于對象檢測的精確圖像分割。
- MobileSAMv2 使用自定義掩碼解碼器(如 Prompt_guided_Mask_Decoder.pt)初始化,結合 SAM1 的視覺編碼器。
- 功能:基于 YOLOv8 的對象檢測結果,進行精確的圖像分割。
- YOLOv8:對象檢測模型,識別圖像中的潛在對象。
- checkpoint?ObjectAwareModel.pt 加載,是 Ultralytics 的對象檢測模型。
- 功能:識別圖像中的潛在對象,提供邊界框供后續分割使用。
- Siglip:從圖像片段提取特征,支持基于文本的查詢。
- 從 google/siglip-large-patch16-256 加載,支持圖像和文本特征提取。
- 功能:從分割后的圖像片段提取特征,支持基于文本的查詢。
對比表:組件功能與作用
組件 | 主要功能 | 在 3D 重建中的作用 |
---|---|---|
MASt3R | 多視圖立體視覺 | 估計深度和姿態,核心重建步驟 |
SAM2 | 視頻分割傳播 | 確保跨視圖分割一致性 |
MobileSAMv2 | 圖像分割 | 基于檢測生成精確掩碼 |
SAM1 | 圖像分割輔助 | 提供初始幀的分割掩碼 |
YOLOv8 | 對象檢測 | 提供邊界框,啟動分割流程 |
Siglip | 特征提取 | 支持對象分組和文本查詢 |
技術細節與優勢
- 分割的準確性:結合 YOLOv8、SAM1 和 SAM2,確保對象在不同視圖中的一致分割。
- 特征的魯棒性:Siglip 的特征提取支持跨視圖的對象分組,SLERP 處理重疊掩碼增強一致性。
- 全局優化的復雜性:使用圖優化技術(如最小生成樹初始化)確保 3D 點的準確對齊。
demo.py 中的 3D 重建流程
demo.py 中的工作流程利用這些組件進行圖像分割、特征提取和全局對齊,從而生成高質量的 3D 模型。demo.py文件使用Models類從一組圖像進行3D重建
get_reconstructed_scene函數是核心過程。
流程包括使用YOLOv8檢測圖像中的對象,然后用SAM1和SAM2進行分割。MASt3R用于多視圖立體重建,獲取深度和姿態。
get_cog_feats函數使用SAM2初始化狀態,并通過視頻傳播分割掩碼。每個幀的掩碼被裁剪、調整大小后通過Siglip提取特征。3D重建的魔力在于結合對象檢測、分割、特征提取和多視圖立體技術。全局優化確保所有視圖一致對齊。
- 圖像加載與準備:
- 使用 Images 類加載輸入圖像列表,準備用于后續處理。
- 如果圖像少于 2 張,拋出錯誤,確保有足夠視圖進行重建。
- 對象檢測與分割:
- YOLOv8 檢測:對圖像運行 YOLOv8,獲取對象邊界框,設置置信度閾值為 0.25,IOU 閾值為 0.95。
- SAM1 分割:基于 YOLOv8 的邊界框,使用 MobileSAMv2 和自定義掩碼解碼器生成精確的分割掩碼。
- SAM2 傳播:初始化 SAM2 狀態,使用第一幀的 SAM1 掩碼,之后通過 propagate_in_video 在序列中傳播掩碼。
- NMS 過濾:使用非最大抑制(NMS)過濾重疊掩碼,確保分割結果的唯一性。
- 特征提取:
- 在 get_cog_feats 函數中,對每個幀的每個分割掩碼:
- 裁剪對應區域,填充為正方形,調整大小為 256x256。
- 使用 Siglip 處理這些圖像片段,提取特征向量。
- 對于重疊的掩碼,進行球面線性插值(SLERP)以合并特征,確保特征的一致性。
- 最終生成 multi_view_clip_feats,每個對象 ID 對應一個特征向量,跨視圖平均。
- 在 get_cog_feats 函數中,對每個幀的每個分割掩碼:
- 多視圖立體視覺:
- 使用 make_pairs 函數根據場景圖類型(complete、swin 或 oneref)生成圖像對。
- 運行 MASt3R 推理,估計每對圖像的深度和姿態,輸出匹配和深度圖。
- 全局對齊:
- 使用 global_aligner 優化相機姿態和 3D 點:
- 模式為 PointCloudOptimizer(多于 2 張圖像)或 PairViewer(2 張圖像)。
- 利用分割圖(cog_seg_maps 和 rev_cog_seg_maps)和特征(cog_feats)指導對齊。
- 優化過程包括多次迭代(默認 300 次),使用線性或余弦調度調整學習率。
- 使用 global_aligner 優化相機姿態和 3D 點:
- 3D 模型生成:
- 使用 get_3D_model_from_scene 將對齊后的點云或網格導出為 GLB 文件。
- 支持選項如點云顯示、天空掩碼、深度清理和相機透明度。
get_reconstructed_scene 方法分析
其中?get_reconstructed_scene 函數展示了 3D 重建的詳細步驟:
def get_reconstructed_scene(outdir, pe3r, device, silent, filelist, schedule, niter, min_conf_thr,as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,scenegraph_type, winsize, refid):"""from a list of images, run dust3r inference, global aligner.then run get_3D_model_from_scene"""if len(filelist) < 2:raise gradio.Error("Please input at least 2 images.")images = Images(filelist=filelist, device=device)# try:cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, pe3r)imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)# except Exception as e:# rev_cog_seg_maps = []# for tmp_img in images.np_images:# rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)# rev_cog_seg_maps.append(rev_seg_map)# cog_seg_maps = rev_cog_seg_maps# cog_feats = torch.zeros((1, 1024))# imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)if len(imgs) == 1:imgs = [imgs[0], copy.deepcopy(imgs[0])]imgs[1]['idx'] = 1if scenegraph_type == "swin":scenegraph_type = scenegraph_type + "-" + str(winsize)elif scenegraph_type == "oneref":scenegraph_type = scenegraph_type + "-" + str(refid)pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewerscene_1 = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)lr = 0.01# if mode == GlobalAlignerMode.PointCloudOptimizer:loss = scene_1.compute_global_alignment(tune_flg=True, init='mst', niter=niter, schedule=schedule, lr=lr)try:import torchvision.transforms as tvfImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])for i in range(len(imgs)):# print(imgs[i]['img'].shape, scene.imgs[i].shape, ImgNorm(scene.imgs[i])[None])imgs[i]['img'] = ImgNorm(scene_1.imgs[i])[None]pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewerscene = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)ori_imgs = scene.ori_imgslr = 0.01# if mode == GlobalAlignerMode.PointCloudOptimizer:loss = scene.compute_global_alignment(tune_flg=False, init='mst', niter=niter, schedule=schedule, lr=lr)except Exception as e:scene = scene_1scene.imgs = ori_imgsscene.ori_imgs = ori_imgsprint(e)outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,clean_depth, transparent_cams, cam_size)# also return rgb, depth and confidence imgs# depth is normalized with the max value for all images# we apply the jet colormap on the confidence mapsrgbimg = scene.imgsdepths = to_numpy(scene.get_depthmaps())confs = to_numpy([c for c in scene.im_conf])# confs = to_numpy([c for c in scene.conf_2])cmap = pl.get_cmap('jet')depths_max = max([d.max() for d in depths])depths = [d / depths_max for d in depths]confs_max = max([d.max() for d in confs])confs = [cmap(d / confs_max) for d in confs]imgs = []for i in range(len(rgbimg)):imgs.append(rgbimg[i])imgs.append(rgb(depths[i]))imgs.append(rgb(confs[i]))return scene, outfile, imgs
函數輸入與初始檢查
get_reconstructed_scene 函數接受多個參數,包括輸出目錄、模型實例、設備、靜默模式、圖像文件列表、調度方式、迭代次數、最小置信度閾值等。至少包含 2 張圖像確保了有足夠的多視圖信息進行重建。
圖像加載與準備
函數創建 Images 對象,加載輸入圖像,并調用 get_cog_feats 獲取認知分割圖(cog_seg_maps)、反向認知分割圖(rev_cog_seg_maps)和認知特征(cog_feats)。這些步驟涉及對象檢測和分割:
- 對象檢測:使用 YOLOv8 檢測圖像中的對象,提供邊界框,置信度閾值為 0.25,IOU 閾值為 0.95。
- 圖像分割:通過 MobileSAMv2 和 SAM1 生成精確的分割掩碼,結合 SAM2 在序列中傳播這些掩碼,確保跨視圖的一致性。
- 特征提取:對每個分割區域裁剪、填充為正方形(256x256),使用 Siglip 提取特征向量。對于重疊掩碼,通過球面線性插值(SLERP)合并特征。
如果 get_cog_feats 失敗,函數會回退到默認分割圖(全為 -1)和零特征向量,確保流程繼續。
單圖像處理
如果輸入僅有一張圖像,函數會復制該圖像生成兩張,確保可以進行對齊和重建。這是為了處理邊緣情況,維持多視圖立體視覺的必要性。
場景圖配置
根據 scenegraph_type(complete、swin 或 oneref),函數調整場景圖參數:
- 如果為 “swin”,追加窗口大小;如果為 “oneref”,追加參考 ID。這些參數影響后續圖像對的生成。
圖像對生成與 MASt3R 推理
使用 make_pairs 函數根據場景圖生成圖像對,參數包括 scene_graph 類型、對稱化(symmetrize=True)等。然后調用 inference 函數,使用 pe3r.mast3r(MASt3R 模型)估計每對圖像的深度和姿態:
- MASt3R 是一個多視圖立體視覺模型,核心功能是生成 3D 點云和相機姿態。
- 推理過程批次大小為 1,是否顯示詳細信息由 silent 控制。
全局對齊優化
函數使用 global_aligner 進行全局對齊,模式根據圖像數量選擇:
- 如果圖像多于 2 張,使用 PointCloudOptimizer 模式;否則使用 PairViewer 模式。
- 對齊過程利用 cog_seg_maps、rev_cog_seg_maps 和 cog_feats,這些分割和特征信息指導 3D 點的分組和優化。
- 調用 compute_global_alignment 進行優化,初始化為最小生成樹(init='mst'),迭代次數為 niter(默認 300),學習率 lr=0.01,調度方式為 schedule(線性或余弦)。
二次推理與對齊
函數嘗試二次優化:
- 導入 torchvision 進行圖像歸一化(均值為 0.5,標準差為 0.5)。
- 更新 imgs 中的圖像數據,重復圖像對生成、MASt3R 推理和全局對齊步驟。
- 如果失敗,回退到第一次對齊結果,保持 scene_1 并更新圖像數據。
3D 模型生成
調用 get_3D_model_from_scene 生成最終 3D 模型:
- 提取場景中的 RGB 圖像、3D 點、掩碼、焦距和相機姿態。
- 支持后處理選項,如清理點云(clean_depth)、掩蓋天空(mask_sky)。
- 將結果導出為 GLB 文件,支持點云(as_pointcloud=True)或網格顯示,相機大小由 cam_size 控制。
返回結果
函數返回場景對象、輸出 GLB 文件路徑,以及一組圖像數組:
- 包括 RGB 圖像、歸一化深度圖(以最大值歸一化)和置信度圖(使用 jet 顏色映射)。
- 這些圖像用于可視化,深度和置信度圖幫助用戶評估重建質量。
技術細節與優勢
- 分割的準確性:YOLOv8 提供初始檢測,MobileSAMv2 和 SAM2 確保精確且一致的分割,減少背景噪聲。
- 特征的魯棒性:Siglip 提取的特征支持跨視圖對象分組,SLERP 處理重疊掩碼增強一致性。
- 全局優化的復雜性:使用圖優化技術(如最小生成樹初始化)確保 3D 點的準確對齊,迭代優化提升精度。
下面分析就不一一展示代碼實現了,有興趣的可以直接下載代碼對照 demo.py列舉的分析查看
除了 get_reconstructed_scene,其他方法如 _convert_scene_output_to_glb、get_3D_model_from_scene 等在后處理、分割、特征提取和用戶交互中扮演關鍵角色。
其他方法分析
mask_to_box
- 功能:從掩碼生成邊界框(左、上、右、下)。
- 作用:
- 將分割掩碼轉換為邊界框格式,便于后續裁剪和特征提取。
- 關鍵步驟:
- 計算掩碼中非零值的邊界,生成 [left, top, right, bottom]。
- 如果掩碼為空,返回零邊界框。
- 為什么重要:為圖像裁剪提供定位信息。
pad_img
- 功能:將圖像填充為正方形,保持寬高比。
- 作用:
- 標準化圖像尺寸,適配 Siglip 的輸入要求(256x256)。
- 關鍵步驟:
- 創建最大邊長的零矩陣,將圖像居中填充。
- 為什么重要:確保特征提取輸入一致。
get_cog_feats
- 功能:提取圖像序列的分割圖和特征。
- 作用:
- 生成認知分割圖(cog_seg_maps)、反向分割圖(rev_cog_seg_maps)和多視圖特征(cog_feats)。
- 關鍵步驟:
- 使用 SAM2 傳播掩碼,結合 SAM1 添加新掩碼。
- 對每幀分割區域裁剪、填充,提取 Siglip 特征。
- 使用 SLERP 合并重疊特征,生成多視圖特征。
- 為什么重要:為全局對齊提供對象級信息。
set_scenegraph_options
- 功能:根據場景圖類型調整 UI 參數。
- 作用:
- 配置滑動窗口(swin)或參考幀(oneref)的參數。
- 關鍵步驟:
- 根據圖像數量動態設置窗口大小和參考 ID。
- 為什么重要:優化圖像對生成策略。
get_mask_from_img_sam1
- 功能:使用 YOLOv8 和 MobileSAMv2 從圖像生成分割掩碼。
- 作用:
- 提供初始幀的精確分割結果。
- 關鍵步驟:
- YOLOv8 檢測邊界框,MobileSAMv2 生成掩碼。
- 分批處理(每批 320 個),過濾小面積掩碼(<0.2%),應用 NMS。
- 為什么重要:為 SAM2 提供初始掩碼,支持跨幀傳播。
get_seg_img
- 功能:根據掩碼和邊界框從圖像中裁剪分割區域。
- 作用:
- 提取特定對象的圖像片段,用于特征提取。
- 關鍵步驟:
- 根據掩碼面積與邊界框面積的比率,決定背景填充方式(黑色或隨機噪聲)。
- 裁剪圖像,返回分割區域。
- 為什么重要:為 Siglip 特征提取準備輸入。
mask_nms
- 功能:對一組分割掩碼執行非最大抑制(NMS),去除重疊掩碼。
- 作用:
- 確保每個對象只保留一個主要掩碼,避免重復分割。
- 關鍵步驟:
- 計算掩碼之間的交集占比(IOU),如果超過閾值(默認 0.8),抑制較小的掩碼。
- 返回保留的掩碼索引列表。
- 為什么重要:在對象分割中防止冗余,提高后續特征提取和對齊的準確性。
- 示例:從椅子照片中檢測到多個重疊掩碼,mask_nms 保留最大的一個。
?_convert_scene_output_to_glb
- 功能:將重建的場景數據(RGB 圖像、3D 點、掩碼、焦距和相機姿態)轉換為 GLB 文件格式,用于 3D 模型的導出和可視化。
- 作用:
- 將 3D 點云或網格與顏色信息結合,生成可視化的 3D 模型。
- 添加相機位置和方向,便于理解拍攝視角。
- 支持點云(as_pointcloud=True)或網格(as_pointcloud=False)兩種表示方式。
- 關鍵步驟:
- 如果選擇點云模式,合并所有點的坐標和顏色,創建 trimesh.PointCloud。
- 如果選擇網格模式,逐幀生成網格并合并為單一 trimesh.Trimesh。
- 使用 trimesh.Scene 添加相機,應用坐標變換(繞 Y 軸旋轉 180 度),導出為 GLB 文件。
這是 3D 重建的最終輸出步驟,將內部數據結構轉換為用戶可交互的格式。
綜上這些方法共同支持 PE3R 的 3D 重建流程:
- 分割與特征提取:get_mask_from_img_sam1、get_cog_feats 等提供對象級信息。
- 模型生成:_convert_scene_output_to_glb、get_3D_model_from_scene 完成 3D 輸出。
上面代碼代碼種涉及的兩個模型,作者只提供了pt文件并沒有提供訓練方式
下載地址在?https://github.com/hujiecpp/PE3R/releases/tag/checkpoints.
再多聊幾句這個 PromptModelPredictor(DetectionPredictor)
PromptModelPredictor 類是基于 Ultralytics YOLO 框架實現的自定義檢測預測器,定義在 PE3R 項目的某個模塊中(可能與 ObjectAwareModel 相關)。它通過繼承 DetectionPredictor,并重寫 __init__、adjust_bboxes_to_image_border 和 postprocess 方法,實現了特定的對象檢測和邊界框處理功能。
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):super().__init__(cfg, overrides, _callbacks)self.args.task = 'segment'
初始化預測器,設置任務類型為“分割”。后續檢測任務奠定基礎,確保與 YOLO 框架兼容。
def adjust_bboxes_to_image_border(self, boxes, image_shape, threshold=20): h, w = image_shapeboxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2return boxes
- 調整邊界框坐標,確保其不超出圖像邊界并避免過于靠近邊緣。
- 作用:將靠近圖像邊緣(小于 threshold=20 像素)的坐標設置為邊界值(0 或圖像寬高)。
-
- boxes[:, 0](x1):如果小于 20,設為 0。
- boxes[:, 1](y1):如果小于 20,設為 0。
- boxes[:, 2](x2):如果大于寬度-20,設為寬度。
- boxes[:, 3](y2):如果大于高度-20,設為高度。
- 為什么重要:防止邊界框超出圖像范圍或過于貼近邊緣,確保后續分割或特征提取的有效性。
def postprocess(self, preds, img, orig_imgs):p = ops.non_max_suppression(preds[0], self.args.conf, self.args.iou, agnostic=self.args.agnostic_nms, max_det=self.args.max_det, nc=len(self.model.names), classes=self.args.classes)results = []if len(p) == 0 or len(p[0]) == 0:print("No object detected.")return resultsfull_box = torch.zeros_like(p[0][0])full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0full_box = full_box.view(1, -1)self.adjust_bboxes_to_image_border(p[0][:, :4], img.shape[2:]) for i, pred in enumerate(p):orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgspath = self.batch[0]img_path = path[i] if isinstance(path, list) else pathif not len(pred): results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))continueif self.args.retina_masks:if not isinstance(orig_imgs, torch.Tensor):pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)else:if not isinstance(orig_imgs, torch.Tensor):pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=torch.zeros_like(img)))return results
- 對 YOLO 模型的預測結果進行后處理,返回檢測結果。
- 作用:
- 應用非最大抑制(NMS)過濾重疊邊界框。
- 調整邊界框坐標,適配原始圖像尺寸。
- 返回 Results 對象,包含邊界框信息(未生成實際掩碼)。
- 關鍵步驟:
- NMS:使用 ops.non_max_suppression 過濾預測結果,基于置信度(conf)、IOU(iou)和最大檢測數(max_det)。
- 邊界檢查:如果沒有檢測到對象,返回空結果并打印提示。
- 邊界框調整:調用 adjust_bboxes_to_image_border 修正坐標。
- 坐標縮放:使用 ops.scale_boxes 將邊界框從輸入圖像尺寸縮放到原始圖像尺寸。
- 結果封裝:創建 Results 對象,包含圖像、路徑、類別名稱和邊界框(boxes),掩碼默認為零。
- 將原始檢測結果轉換為標準格式,為后續 SAM 分割提供輸入。
PE3R 的上下文中,PromptModelPredictor 可能被實例化為 Models 類中的 yolov8 組件(ObjectAwareModel)。其作用包括:
- 提供初始檢測:為 get_mask_from_img_sam1 提供邊界框,啟動精確分割流程。
- 支持多視圖一致性:通過檢測對象位置,幫助 SAM2 在圖像序列中傳播掩碼。
- 集成到 3D 重建:邊界框信息間接支持特征提取和全局對齊。