PyTorch 中mm和bmm函數的使用詳解

torch.mm 是 PyTorch 中用于 二維矩陣乘法(matrix-matrix multiplication) 的函數,等價于數學中的 A × B 矩陣乘積。


一、函數定義

torch.mm(input, mat2) → Tensor

執行的是兩個 2D Tensor(矩陣)的標準矩陣乘法。

  • input: 第一個二維張量,形狀為 (n × m)
  • mat2: 第二個二維張量,形狀為 (m × p)
  • 返回:形狀為 (n × p) 的張量

二、使用條件和注意事項

條件說明
僅支持 2D 張量一維或三維以上使用 torch.matmul@ 操作符
維度要匹配input.shape[1] == mat2.shape[0]
不支持廣播兩個矩陣維度不匹配會直接報錯
結果是普通矩陣乘積不是逐元素乘法(Hadamard),即不是 *torch.mul()

三、示例代碼

示例 1:基本矩陣乘法

import torchA = torch.tensor([[1., 2.], [3., 4.]])   # 2x2
B = torch.tensor([[5., 6.], [7., 8.]])   # 2x2C = torch.mm(A, B)
print(C)

輸出:

tensor([[19., 22.],[43., 50.]])

計算步驟:

C[0][0] = 1*5 + 2*7 = 19
C[0][1] = 1*6 + 2*8 = 22
...

示例 2:不匹配維度導致報錯

A = torch.rand(2, 3)
B = torch.rand(4, 2)
C = torch.mm(A, B)  # ? 會報錯

報錯:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)

示例 3:推薦寫法(推薦使用 @matmul

A = torch.rand(3, 4)
B = torch.rand(4, 5)C1 = torch.mm(A, B)
C2 = A @ B                # 推薦用法
C3 = torch.matmul(A, B)   # 推薦用法

四、與其他乘法函數的比較

函數名支持維度運算類型支持廣播
torch.mm僅限二維矩陣乘法? 不支持
torch.matmul1D, 2D, ND自動判斷點乘 / 矩陣乘? 支持
torch.bmm批量二維乘法3D Tensor batch × batch? 不支持
torch.mul任意維度元素乘(Hadamard)? 支持
* 運算符任意維度元素乘? 支持
@ 運算符ND(推薦用)矩陣乘法(和 matmul 一樣)?

五、典型應用場景

  • 神經網絡權重乘法:output = torch.mm(W, x)
  • 點云 / 圖像變換:x' = torch.mm(R, x) + t
  • 多層感知機中的矩陣計算
  • 注意力機制中 QK^T 乘積

六、總結:什么時候用 mm

使用場景用什么
僅二維矩陣乘法torch.mm
高維或支持廣播乘法torch.matmul / @
批量矩陣乘法 (如 batch_size×3×3)torch.bmm
元素乘torch.mul or *

在 PyTorch 中,torch.bmm批量矩陣乘法(batch matrix multiplication) 的操作,專用于處理三維張量(batch of matrices)。它的主要作用是對一組矩陣成對進行乘法,效率遠高于手動循環計算。


一、torch.bmm 語法

torch.bmm(input, mat2, *, out=None) → Tensor
  • input: Tensor,形狀為 (B, N, M)
  • mat2: Tensor,形狀為 (B, M, P)
  • 返回結果形狀為 (B, N, P)

這表示對 BN×MM×P 的矩陣進行成對相乘。


二、示例演示

示例 1:基礎用法

import torch# 定義兩個 batch 矩陣
A = torch.randn(4, 2, 3)  # shape: (B=4, N=2, M=3)
B = torch.randn(4, 3, 5)  # shape: (B=4, M=3, P=5)# 批量矩陣乘法
C = torch.bmm(A, B)       # shape: (4, 2, 5)print(C.shape)  # 輸出: torch.Size([4, 2, 5])

示例 2:手動循環 vs bmm 效率對比

# 慢速手動方式
C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))])# 等效于 bmm
C_bmm = torch.bmm(A, B)print(torch.allclose(C_manual, C_bmm))  # True

三、注意事項

1. 維度必須是三維張量

  • 否則會報錯:
RuntimeError: batch1 must be a 3D tensor

你可以通過 .unsqueeze() 手動調整維度:

a = torch.randn(2, 3)
b = torch.randn(3, 4)# 升維
a_batch = a.unsqueeze(0)  # (1, 2, 3)
b_batch = b.unsqueeze(0)  # (1, 3, 4)c = torch.bmm(a_batch, b_batch)  # (1, 2, 4)

2. 維度必須滿足矩陣乘法規則

  • (B, N, M) × (B, M, P)(B, N, P)
  • M 不一致會報錯:
RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor

3. bmm 不支持廣播(broadcasting)

  • 必須顯式提供相同的 batch size。
  • 如果只有一個矩陣固定,可以使用 .expand()
A = torch.randn(1, 2, 3)  # 單個矩陣
B = torch.randn(4, 3, 5)  # 4 個矩陣# 擴展 A 以進行 batch 乘法
A_expand = A.expand(4, -1, -1)
C = torch.bmm(A_expand, B)  # (4, 2, 5)

四、在實際應用中的例子

在點云變換中:批量乘旋轉矩陣

# 假設有 B 個旋轉矩陣和點坐標
R = torch.randn(B, 3, 3)       # 旋轉矩陣
points = torch.randn(B, 3, N)  # 點云# 先轉置點坐標為 (B, N, 3)
points_T = points.transpose(1, 2)  # (B, N, 3)# 用 bmm 做點變換:每組點乘旋轉
transformed = torch.bmm(points_T, R.transpose(1, 2))  # (B, N, 3)

五、總結

特性torch.bmm
操作對象三維張量(batch of matrices)
核心規則(B, N, M) x (B, M, P) = (B, N, P)
是否支持廣播? 不支持,需要手動 .expand()
matmul 區別matmul 支持更多廣播,bmm 更高效用于純批量矩陣乘法
應用場景批量線性變換、點云配準、神經網絡前向傳播等

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/910304.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/910304.shtml
英文地址,請注明出處:http://en.pswp.cn/news/910304.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Qt 解析復雜對象構成

Qt 解析復雜對象構成 dumpStructure 如 QComboBox / QCalendarWidget / QSpinBox … void Widget::Widget(QWidget* parent){auto c new QCalendarWidget(this);dumpStructure(c,4); }void Widget::dumpStructure(const QObject *obj, int spaces) {qDebug() << QString…

山姆·奧特曼:從YC到OpenAI,硅谷創新之星的崛起

名人說&#xff1a;路漫漫其修遠兮&#xff0c;吾將上下而求索。—— 屈原《離騷》 創作者&#xff1a;Code_流蘇(CSDN)&#xff08;一個喜歡古詩詞和編程的Coder&#x1f60a;&#xff09; 山姆奧特曼&#xff1a;從YC到OpenAI&#xff0c;硅谷創新之星的崛起 在人工智能革命…

PHP語法基礎篇(五):流程控制

任何 PHP 腳本都是由一系列語句構成的。一條語句可以是一個賦值語句&#xff0c;一個函數調用&#xff0c;一個循環&#xff0c;一個條件語句或者甚至是一個什么也不做的語句&#xff08;空語句&#xff09;。語句通常以分號結束。此外&#xff0c;還可以用花括號將一組語句封裝…

怎么隱藏關閉或恢復顯示輸入法的懸浮窗

以搜狗輸入法為例&#xff0c;隱藏輸入法懸浮窗 懸浮窗在輸入法里的官方叫法為【狀態欄】。 假設目前大家的輸入法相關顯示呈現如下狀態&#xff1a; 那我們只需在輸入法懸浮窗&#xff08;狀態欄&#xff09;的任意位置鼠標右鍵單擊&#xff0c;調出輸入法菜單&#xff0c;就…

Electron (02)集成 SpringBoot:服務與桌面程序協同啟動方案

本篇是關于把springboot生成的jar打到electron里&#xff0c;在生成的桌面程序啟動時springboot服務就會自動啟動。 雖然之后并不需要這種方案&#xff0c;更好的是部署[一套服務端&#xff0c;多個客戶端]...但是既然搭建成功了&#xff0c;也記錄一下。 前端文件 1、main.js…

2025年計算機應用與神經網絡國際會議(CANN 2025)

2025 International Conference on Computer Applications and Neural Networks &#xff08;一&#xff09;會議信息 會議簡稱&#xff1a;CANN 2025 大會地點&#xff1a;中國重慶 收錄檢索&#xff1a;提交Ei Compendex,CPCI,CNKI,Google Scholar等 &#xff08;二&#x…

振動分析中的低頻噪聲問題:從理論到實踐的完整解決方案

前言 在振動監測和結構健康監測領域&#xff0c;我們經常需要從加速度信號計算速度和位移。然而&#xff0c;許多工程師在實際應用中都會遇到一個令人困擾的問題&#xff1a;通過積分計算得到的速度和位移頻譜中低頻噪聲異常放大。 本文將深入分析這個問題的根本原因&#xf…

ncu學習筆記01——合并訪存

全局內存通過緩存實現加載和存儲過程。其中&#xff0c;L1為一級緩存&#xff0c;每個SM都有自己的L1&#xff1b;L2為二級緩存&#xff0c;L2則被所有SM共有。 數據從全局內存到SM的傳輸過程中&#xff0c;會去L1和L2中查詢是否有緩存。對全局內存的訪問將經過L1&#xff1b;…

2012 - 正方形矩陣

????題目描述 晶晶同學非常喜歡方形&#xff0c;她希望打印出來的字符串也是方形的。老師給了晶晶同學一個字符串"ACM"&#xff0c;晶晶同學突發奇想&#xff0c;如果任意給定義一個整數n&#xff0c;能不能打印出由這個字符串組成的正方形字符串呢&#xff1f;…

C++中set的常見用法

在 C 里&#xff0c;std::set屬于標準庫容器的一種&#xff0c;其特性是按照特定順序存儲唯一的元素。下面為你詳細介紹它的常見使用方法&#xff1a; 1. 頭文件引入 要使用std::set&#xff0c;需要在代碼中包含相應的頭文件&#xff1a; #include <set> 2. 集合的定…

stm32移植freemodbus

1、設置串口 開啟串口中斷 2、設置定時器 已知在freemodbus中默認定義&#xff1a;當波特率大于19200時&#xff0c;判斷一幀數據超時時間固定為1750us&#xff0c;當波特率小于19200時&#xff0c;超時時間為3.5個字符時間。這里移植的是115200&#xff0c;所以一幀數據超時…

鴻蒙next 使用canvas實現ecg動態波形繪制

該代碼可在Arkts 與 前端使用&#xff0c;基于canvas 倉庫地址&#xff1a;https://gitee.com/harmony_os_example/harmony-os-ecg-waveform.git 代碼中的list數組為波形數據&#xff0c;該示例需要根據自己業務替換繪制頻率&#xff0c;波形數據&#xff0c;ecg原始數據生成…

基于原生能力的鍵盤控制

基于原生能力的鍵盤控制 前言一、進入頁面TextInput獲焦1、方案2、核心代碼 二、點擊按鈕或其他事件觸發TextInput獲焦1、方案2、核心代碼 三、鍵盤彈出后只上抬特定的輸入組件1、方案2、核心代碼 四、監聽鍵盤高度1、方案2、核心代碼 五、設置窗口在鍵盤抬起時的頁面避讓模式為…

大數據治理域——數據存儲與成本管理

摘要 本文主要探討了數據存儲與成本管理的多種策略。介紹了數據壓縮技術&#xff0c;如MaxCompute的archive壓縮方法&#xff0c;通過RAID file形式存儲數據&#xff0c;可有效節省空間&#xff0c;但恢復時間較長&#xff0c;適用于冷備與日志數據。還詳細闡述了數據生命周期…

國產Linux銀河麒麟操作系統上使用自帶openssh遠程工具SSH方式登陸華為交換機或服務器

在Windows和Linux Debian系統上我一直使用electerm遠程工具訪問服務器或交換機&#xff0c; 一、 electerm簡介 簡介&#xff1a;electerm是一款開源免費的SSH工具&#xff0c;具有良好的跨平臺兼容性&#xff0c;適用于Windows、macOS、Linux以及麒麟操作系統。特點&#xf…

Logback 在java中的使用

Logback 是 Java 應用中廣泛使用的日志框架&#xff0c;以下是其核心使用方法及最佳實踐&#xff1a; 1. 引入依賴 在 Maven 或 Gradle 項目中添加 Logback 及 SLF4J 依賴&#xff1a; <!-- Maven --> <dependency><groupId>ch.qos.logback</groupId>…

Axure應用交互設計:中繼器—整行、條件行、當前行賦值

親愛的小伙伴,如有幫助請訂閱專欄!跟著老師每課一練,系統學習Axure交互設計課程! Axure產品經理精品視頻課https://edu.csdn.net/course/detail/40420 課程主題:對中繼器中:整行、符合某種條件的任意行、當前行的賦值操作 課程視頻:

ToolsSet之:TTS及Morse編解碼

ToolsSet是微軟商店中的一款包含數十種實用工具數百種細分功能的工具集合應用&#xff0c;應用基本功能介紹可以查看以下文章&#xff1a; Windows應用ToolsSet介紹https://blog.csdn.net/BinField/article/details/145898264其中Text菜單中的TTS & Morse可用于將文本轉換…

【C++】編碼傳輸:創建零拷貝幀對象4:shared_ptr轉unique_ptr給到rtp打包

【C++】編碼傳輸:創建零拷貝幀對象3: dll api轉換內部的共享內存根本原因 你想要的是基于 packet 指向的那個已有對象,拷貝(或移動)出一個新的 VideoDataPacket3 實例,因此需要把那個對象本身傳進去——也就是 *packet。copilot的原因分析與gpt一致 The issue is with t…

基于UDP的套接字通信

udp是一個面向無連接的&#xff0c;不安全的&#xff0c;報式傳輸層協議&#xff0c;udp的通信過程默認也是阻塞的。使用UDP進行通信&#xff0c;服務器和客戶端的處理步驟比TCP要簡單很多&#xff0c;并且兩端是對等的 &#xff08;通信的處理流程幾乎是一樣的&#xff09;&am…