PETR和位置編碼

PETR和位置編碼

petr檢測網絡中有2種類型的位置編碼。
正弦編碼和petr論文提出的3D Position Embedding。transformer模塊輸入除了qkv,還有query_pos和key_pos。這里重點記錄下query_pos和key_pos的生成

  • query pos的生成
    先定義reference_points, shape為(n_query, 3),編碼部分有兩部分構成,經過pos2posemb3d編碼(sin編碼)后,再用FFN(query_embed)編碼一次后用作transformer的query_pos. 至于為什么多了個一次FFN編碼,GPT這么解釋的:

這種兩步編碼的設計實際上是將固定的位置編碼(pos2posemb3d)和可學習的位置編碼(query_embedding)相結合,既保留了位置的幾何信息,又允許模型學習任務相關的位置表示。這種設計在3D視覺任務中特別有效,因為它既考慮了空間的周期性特征,又保持了位置編碼的可學習性。

  1. pos2posemb3d(sin編碼)
    標準的正弦編碼
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):scale = 2 * math.pipos = pos * scale # map pos from [-1, 1] to [-2pi, 2pi]dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)pos_x = pos[..., 0, None] / dim_tpos_y = pos[..., 1, None] / dim_tpos_z = pos[..., 2, None] / dim_tpos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)return posemb
  1. query_embed(FFN)
    self.query_embedding = nn.Sequential(nn.Linear(self.embed_dims*3//2, self.embed_dims),nn.ReLU(),nn.Linear(self.embed_dims, self.embed_dims),)
    
  • key_pos的生成
    對于二維目標檢測來說,對像素位置做編碼就行了(sin_embed), 如下圖的backbone這個分支,對于三維目標檢測,petr對每個像素還做了三維位置編碼(coords_position_embeding), 下圖最下面一個分支。
    最終給transformer的key_pos = 3d位置編碼+ 2d像素位置編碼
    在這里插入圖片描述

    1. 像素的3d位置編碼
      根據圖像尺寸定義一個視錐空間(coords),每個點用(u, v, d)表示,結合相機內參,可以將其轉為世界坐標系下的點(coords3d),在用position_encoder(卷積)處理得到位置編碼。
      在這里插入圖片描述
    def position_embeding(self, img_feats, img_metas, masks=None):eps = 1e-5pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]B, N, C, H, W = img_feats[self.position_level].shapecoords_h = torch.arange(H, device=img_feats[0].device).float() * pad_h / Hcoords_w = torch.arange(W, device=img_feats[0].device).float() * pad_w / Wif self.LID:index  = torch.arange(start=0, end=self.depth_num, step=1, device=img_feats[0].device).float()index_1 = index + 1bin_size = (self.position_range[3] - self.depth_start) / (self.depth_num * (1 + self.depth_num))coords_d = self.depth_start + bin_size * index * index_1else:index  = torch.arange(start=0, end=self.depth_num, step=1, device=img_feats[0].device).float()bin_size = (self.position_range[3] - self.depth_start) / self.depth_numcoords_d = self.depth_start + bin_size * indexD = coords_d.shape[0]coords = torch.stack(torch.meshgrid([coords_w, coords_h, coords_d])).permute(1, 2, 3, 0) # W, H, D, 3coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)coords[..., :2] = coords[..., :2] * torch.maximum(coords[..., 2:3], torch.ones_like(coords[..., 2:3])*eps)img2lidars = []for img_meta in img_metas:img2lidar = []for i in range(len(img_meta['lidar2img'])):img2lidar.append(np.linalg.inv(img_meta['lidar2img'][i]))img2lidars.append(np.asarray(img2lidar))img2lidars = np.asarray(img2lidars)img2lidars = coords.new_tensor(img2lidars) # (B, N, 4, 4)coords = coords.view(1, 1, W, H, D, 4, 1).repeat(B, N, 1, 1, 1, 1, 1)img2lidars = img2lidars.view(B, N, 1, 1, 1, 4, 4).repeat(1, 1, W, H, D, 1, 1)coords3d = torch.matmul(img2lidars, coords).squeeze(-1)[..., :3]coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / (self.position_range[3] - self.position_range[0])coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / (self.position_range[4] - self.position_range[1])coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / (self.position_range[5] - self.position_range[2])coords_mask = (coords3d > 1.0) | (coords3d < 0.0)coords_mask = coords_mask.flatten(-2).sum(-1) > (D * 0.5)coords_mask = masks | coords_mask.permute(0, 1, 3, 2)coords3d = coords3d.permute(0, 1, 4, 5, 3, 2).contiguous().view(B*N, -1, H, W)coords3d = inverse_sigmoid(coords3d)coords_position_embeding = self.position_encoder(coords3d) # position_encoder:conv+relu+convreturn coords_position_embeding.view(B, N, self.embed_dims, H, W), coords_mask
    
    1. 像素的2d正弦編碼
      通過圖像的寬高,可以對每個像素坐標生成位置編碼
    #SinePositionalEncoding3D
    def forward(self, mask):        """Forward function for `SinePositionalEncoding`.Args:mask (Tensor): ByteTensor mask. Non-zero values representingignored positions, while zero values means valid positionsfor this image. Shape [bs, h, w].Returns:pos (Tensor): Returned position embedding with shape[bs, num_feats*2, h, w]."""# For convenience of exporting to ONNX, it's required to convert# `masks` from bool to int.mask = mask.to(torch.int)not_mask = 1 - mask  # logical_notn_embed = not_mask.cumsum(1, dtype=torch.float32)y_embed = not_mask.cumsum(2, dtype=torch.float32)x_embed = not_mask.cumsum(3, dtype=torch.float32)if self.normalize:n_embed = (n_embed + self.offset) / \(n_embed[:, -1:, :, :] + self.eps) * self.scaley_embed = (y_embed + self.offset) / \(y_embed[:, :, -1:, :] + self.eps) * self.scalex_embed = (x_embed + self.offset) / \(x_embed[:, :, :, -1:] + self.eps) * self.scaledim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)pos_n = n_embed[:, :, :, :, None] / dim_tpos_x = x_embed[:, :, :, :, None] / dim_tpos_y = y_embed[:, :, :, :, None] / dim_t# use `view` instead of `flatten` for dynamically exporting to ONNXB, N, H, W = mask.size()pos_n = torch.stack((pos_n[:, :, :, :, 0::2].sin(), pos_n[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos = torch.cat((pos_n, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)return posdef __repr__(self):"""str: a string that describes the module"""repr_str = self.__class__.__name__repr_str += f'(num_feats={self.num_feats}, 'repr_str += f'temperature={self.temperature}, 'repr_str += f'normalize={self.normalize}, 'repr_str += f'scale={self.scale}, 'repr_str += f'eps={self.eps})'return repr_str

順便記錄下未使用的可學習編碼

    @POSITIONAL_ENCODING.register_module()class LearnedPositionalEncoding3D(BaseModule):"""Position embedding with learnable embedding weights.Args:num_feats (int): The feature dimension for each positionalong x-axis or y-axis. The final returned dimension foreach position is 2 times of this value.row_num_embed (int, optional): The dictionary size of row embeddings.Default 50.col_num_embed (int, optional): The dictionary size of col embeddings.Default 50.init_cfg (dict or list[dict], optional): Initialization config dict."""def __init__(self,num_feats,row_num_embed=50,col_num_embed=50,init_cfg=dict(type='Uniform', layer='Embedding')):super(LearnedPositionalEncoding3D, self).__init__(init_cfg)self.row_embed = nn.Embedding(row_num_embed, num_feats)self.col_embed = nn.Embedding(col_num_embed, num_feats)self.num_feats = num_featsself.row_num_embed = row_num_embedself.col_num_embed = col_num_embeddef forward(self, mask):"""Forward function for `LearnedPositionalEncoding`.Args:mask (Tensor): ByteTensor mask. Non-zero values representingignored positions, while zero values means valid positionsfor this image. Shape [bs, h, w].Returns:pos (Tensor): Returned position embedding with shape[bs, num_feats*2, h, w]."""h, w = mask.shape[-2:]x = torch.arange(w, device=mask.device)y = torch.arange(h, device=mask.device)x_embed = self.col_embed(x)y_embed = self.row_embed(y)pos = torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)),dim=-1).permute(2, 0,1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)return pos

參考鏈接:
https://blog.csdn.net/qq_16137569/article/details/123576866

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/79243.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/79243.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/79243.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu搭建 Nginx以及Keepalived 實現 主備

目錄 前言1. 基本知識2. Keepalived3. 腳本配置4. Nginx前言 ?? 找工作,來萬碼優才:?? #小程序://萬碼優才/r6rqmzDaXpYkJZF 爬蟲神器,無代碼爬取,就來:bright.cn Java基本知識: java框架 零基礎從入門到精通的學習路線 附開源項目面經等(超全)【Java項目】實戰CRU…

文章記單詞 | 第56篇(六級)

一&#xff0c;單詞釋義 interview /??nt?vju?/&#xff1a; 名詞&#xff1a;面試&#xff1b;采訪&#xff1b;面談動詞&#xff1a;對… 進行面試&#xff1b;采訪&#xff1b;接見 radioactive /?re?di???kt?v/&#xff1a;形容詞&#xff1a;放射性的&#xff…

MATLAB函數調用全解析:從入門到精通

在MATLAB編程中&#xff0c;函數是代碼復用的核心單元。本文將全面解析MATLAB中各類函數的調用方法&#xff0c;包括內置函數、自定義函數、匿名函數等&#xff0c;幫助提升代碼效率&#xff01; 一、MATLAB函數概述 MATLAB函數分為以下類型&#xff1a; 內置函數&#xff1a…

哈希表筆記(二)redis

Redis哈希表實現分析 這份代碼是Redis核心數據結構之一的字典(dict)實現&#xff0c;本質上是一個哈希表的實現。Redis的字典結構被廣泛用于各種內部數據結構&#xff0c;包括Redis數據庫本身和哈希鍵類型。 核心特點 雙表設計&#xff1a;每個字典包含兩個哈希表&#xff0…

PDF嵌入隱藏的文字

所需依賴 <dependency><groupId>com.itextpdf</groupId><artifactId>itext-core</artifactId><version>9.0.0</version><type>pom</type> </dependency>源碼 /*** PDF工具*/ public class PdfUtils {/*** 在 PD…

RAG工程-基于LangChain 實現 Advanced RAG(預檢索-查詢優化)(下)

Multi-Query 多路召回 多路召回流程圖 多路召回策略利用大語言模型&#xff08;LLM&#xff09;對原始查詢進行拓展&#xff0c;生成多個與原始查詢相關的問題&#xff0c;再將原始查詢和生成的所有相關問題一同發送給檢索系統進行檢索。它適用于用戶查詢比較寬泛、模糊或者需要…

【業務領域】PCIE協議理解

PCIE協議理解 提示&#xff1a;這里可以添加系列文章的所有文章的目錄&#xff0c;目錄需要自己手動添加 PCIE學習理解。 文章目錄 PCIE協議理解[TOC](文章目錄) 前言零、PCIE掌握點&#xff1f;一、PCIE是什么&#xff1f;二、PCIE協議總結物理層切速 鏈路層事務層6.2 TLP的路…

Jupyter notebook快捷鍵

文章目錄 Jupyter notebook鍵盤模式快捷鍵&#xff08;常用的已加粗&#xff09; Jupyter notebook鍵盤模式 命令模式&#xff1a;鍵盤輸入運行程序命令&#xff1b;這時單元格框線為藍色 編輯模式&#xff1a;允許你往單元格中鍵入代碼或文本&#xff1b;這時單元格框線是綠色…

Unity圖片導入設置

&#x1f3c6; 個人愚見&#xff0c;沒事寫寫筆記 &#x1f3c6;《博客內容》&#xff1a;Unity3D開發內容 &#x1f3c6;&#x1f389;歡迎 &#x1f44d;點贊?評論?收藏 &#x1f50e;Unity支持的圖片格式 ??BMP:是Windows操作系統的標準圖像文件格式&#xff0c;特點是…

Spark-小練試刀

任務1&#xff1a;HDFS上有三份文件&#xff0c;分別為student.txt&#xff08;學生信息表&#xff09;result_bigdata.txt&#xff08;大數據基礎成績表&#xff09;&#xff0c; result_math.txt&#xff08;數學成績表&#xff09;。 加載student.txt為名稱為student的RDD…

內存安全的攻防戰:工具鏈與語言特性的協同突圍

一、內存安全&#xff1a;C 開發者永恒的達摩克利斯之劍 在操作系統內核、游戲引擎、金融交易系統等對穩定性要求苛刻的領域&#xff0c;內存安全問題始終是 C 開發者的核心挑戰。緩沖區溢出、懸空指針、雙重釋放等經典漏洞&#xff0c;每年在全球范圍內造成數千億美元的損失。…

OceanBase數據庫-學習筆記1-概論

多租戶概念 集群和分布式 隨著互聯網、物聯網和大數據技術的發展&#xff0c;數據量呈指數級增長&#xff0c;單機數據庫難以存儲和處理如此龐大的數據。現代應用通常需要支持大量用戶同時訪問&#xff0c;單機數據庫在高并發場景下容易成為性能瓶頸。單點故障是單機數據庫的…

計算機網絡——鍵入網址到網頁顯示,期間發生了什么?

瀏覽器做的第一步工作是解析 URL&#xff0c;分清協議是http還是https&#xff0c;主機名&#xff0c;路徑名&#xff0c;然后生成http消息&#xff0c;之后委托操作系統將消息發送給 Web 服務器。在發送之前&#xff0c;還需要先去查詢dns&#xff0c;首先是查詢緩存瀏覽器緩存…

Qwen3本地化部署,準備工作:SGLang

文章目錄 SGLang安裝deepseek運行Qwen3-30B-A3B官網:https://github.com/sgl-project/sglang SGLang SGLang 是一個面向大語言模型和視覺語言模型的高效服務框架。它通過協同設計后端運行時和前端編程語言,使模型交互更快速且具備更高可控性。核心特性包括: 1. 快速后端運…

全面接入!Qwen3現已上線千帆

百度智能云千帆正式上線通義千問團隊開源的最新一代Qwen3系列模型&#xff0c;包括旗艦級MoE模型Qwen3-235B-A22B、輕量級MoE模型Qwen3-30B-A3B。千帆大模型平臺開源模型進一步擴充&#xff0c;以多維開放的模型服務、全棧模型開發、應用開發工具鏈、多模態數據治理及安全的能力…

藍橋杯Python(B)省賽回憶

Q&#xff1a;為什么我要寫這篇博客&#xff1f; A&#xff1a;在藍橋杯軟件類競賽&#xff08;Python B組&#xff09;的備賽過程中我在網上搜索關于藍橋杯的資料&#xff0c;感謝你們提供的參賽經歷&#xff0c;對我的備賽起到了整體調整的幫助&#xff0c;讓我知道如何以更…

數據轉儲(go)

? 隨著時間推移&#xff0c;數據庫中的數據量不斷累積&#xff0c;可能導致查詢性能下降、存儲壓力增加等問題。數據轉儲作為一種有效的數據管理策略&#xff0c;能夠將歷史數據從生產數據庫中轉移到其他存儲介質&#xff0c;從而減輕數據庫負擔&#xff0c;提高系統性能&…

Git Stash 詳解

Git Stash 詳解 在使用 Git 進行版本控制時&#xff0c;經常會遇到需要臨時保存當前工作狀態的情況。git stash 命令就是為此設計的&#xff0c;它允許你將未提交的更改暫存起來&#xff0c;在處理其他任務或分支后&#xff0c;再恢復這些更改。 目錄 基本概念常用命令示例和…

Windows下Dify安裝及使用

Dify安裝及使用 Dify 是開源的 LLM 應用開發平臺。提供從 Agent 構建到 AI workflow 編排、RAG 檢索、模型管理等能力&#xff0c;輕松構建和運營生成式 AI 原生應用。比 LangChain 更易用。 前置條件 windows下安裝了docker環境-Windows11安裝Docker-CSDN博客 下載 Git下載…

Clang-Tidy協助C++編譯期檢查

文章目錄 在Visual Studio中啟用clang-tidyClang-tidy 常用的檢查項readability-inconsistent-declaration-parameter-namemisc-static-assert 例子 C/C語言是一門編譯型語言&#xff0c;比起python,javascript 這些&#xff0c;有很多BUG可以在編譯期被排除掉&#xff0c;當然…