點積運算要求第一個矩陣 shape:(n, m),第二個矩陣 shape: (m, p), 兩個矩陣點積運算shape為:(n,p)
- 運算符 @ 用于進行兩個矩陣的點乘運算
- torch.mm 用于進行兩個矩陣點乘運算,要求輸入的矩陣為3維 (mm 代表 mat, mul)
- torch.bmm 用于批量進行矩陣點乘運算,要求輸入的矩陣為3維 (b 代表 batch)
- torch.matmul 對進行點乘運算的兩矩陣形狀沒有限定。
a. 對于輸入都是二維的張量相當于 mm 運算
b. 對于輸入都是三維的張量相當于 bmm 運算
c. 對數輸入的shape不同的張量,對應的最后幾個維度必須符合矩陣運算規則
代碼
import torch
import numpy as np # 使用@運算符
def test01():# 形狀為:3行2列 data1 = torch.tensor([[1,2], [3,4], [5,6]])# 形狀為:2行2列data2 = torch.tensor([[5,6], [7,8]])data = data1 @ data2print(data) # 使用 mm 函數
def test02():# 要求輸入的張量形狀都是二維的# 形狀為:3行2列 data1 = torch.tensor([[1,2], [3,4], [5,6]])# 形狀為:2行2列data2 = torch.tensor([[5,6], [7,8]])data = torch.mm(data1, data2) print(data)print(data.shape)# 使用 bmm 函數
def test03():# 第一個維度:表示批次# 第二個維度:多少行# 第三個維度:多少列data1 = torch.randn(3, 4, 5)data2 = torch.randn(3, 5, 8)data = torch.bmm(data1, data2) print(data.shape)# 使用 matmul 函數
def test04():# 對二維進行計算data1 = torch.randn(4,5)data2 = torch.randn(5,8)print(torch.matmul(data1, data2).shape)# 對三維進行計算data1 = torch.randn(3, 4, 5)data2 = torch.randn(3, 5, 8)print(torch.matmul(data1, data2).shape)data1 = torch.randn(3, 4, 5)data2 = torch.randn(5, 8)print(torch.matmul(data1, data2).shape) if __name__ == "__main__":test04()