文章目錄
- 簡介
- 向量乘法
- 二維矩陣乘法
- 三維矩陣乘法
- 廣播
- 高維矩陣乘法
- 開源
簡介
一提到矩陣乘法,大家對于二維矩陣乘法都很了解,即 A 矩陣的行乘以 B 矩陣的列。
但對于高維矩陣乘法可能就不太清楚,不知道高維矩陣乘法是怎么在計算。
建議使用torch.matmul
做矩陣乘法,其支持向量乘法 和 二維、乃至多維的矩陣乘法。
向量乘法
a1 = torch.tensor([1, 2])
res1 = torch.matmul(a1, a1)
print(res1)
print(res1.shape)
輸出:
tensor(5)
torch.Size([])
torch 也支持使用 @
完成乘法操作
二維矩陣乘法
a2 = torch.tensor([[1, 2]])
res2 = torch.matmul(a2, a2.transpose(-2, -1))
print(res2)
print(res2.shape)
輸出:
tensor([[5]])
torch.Size([1, 1])
torch.mm
與 @
也可以做二維矩陣乘法:
a2 @ a2.transpose(-2, -1)
torch.mm(a2, a2.transpose(-2, -1))
三維矩陣乘法
torch.bmm 支持三維矩陣乘法,不支持更高維度的矩陣乘法
a3 = torch.randn(2, 3, 2)
res3 = torch.bmm(a3,a3.transpose(-1, -2)
)
print(res3)
print(res3.shape)
輸出:
tensor([[[ 4.5979, 0.6648, 2.9231],[ 0.6648, 0.1155, 0.4713],[ 2.9231, 0.4713, 1.9805]],[[ 1.0323, 1.8212, -0.3546],[ 1.8212, 3.5445, -0.3834],[-0.3546, -0.3834, 0.2988]]])
torch.Size([2, 3, 3])
a3 的 shape是(2, 3, 2),a3 底層的兩個維度做轉置之后變成(2, 2, 3),才可以做矩陣乘法。
可以發現第一位的數字都是2。高維矩陣做乘法的時候,除了最后兩個維度,高維矩陣前面的維度兩個矩陣要保持一致。
torch.randn(2, 3, 2) @ torch.randn(3, 2, 3)
雖然上述兩個矩陣,在最后兩個維度滿足矩陣運算的條件,但是第一個維度兩個矩陣的值不一樣,所以不能做矩陣乘法。
廣播
但是發現:
t1 = torch.randn(1, 3, 2)
t2 = torch.randn(3, 2, 3)
t1 @ t2
輸出:
tensor([[[-0.6557, 1.0518, 0.3055],[-0.2876, -2.5104, -1.4417],[ 1.4447, -0.1799, 0.4602]],[[ 0.2971, 0.0060, -0.2612],[-0.9089, 1.0824, 0.7131],[ 0.0929, -0.7898, -0.0199]],[[ 0.0027, 1.2031, 0.1543],[-0.5603, -1.8567, -0.1302],[ 0.3978, -0.9356, -0.1977]]])
理論上兩個矩陣的高維度的shape不一樣,就不可以做矩陣乘法。但上述 t1
與 t2
可以做矩陣乘法。這是因為 t1 的第一個維度是1,就會自動做廣播。
廣播的效果類似于,把 t1 在第一個維度復制成與t2一樣,第一個維度都變成3。
在下述使用 concat完成復制工作,再做矩陣乘法,發現可以得到上述一樣的結果。
torch.concat((t1, t1, t1)) @ t2
輸出:
tensor([[[-0.6557, 1.0518, 0.3055],[-0.2876, -2.5104, -1.4417],[ 1.4447, -0.1799, 0.4602]],[[ 0.2971, 0.0060, -0.2612],[-0.9089, 1.0824, 0.7131],[ 0.0929, -0.7898, -0.0199]],[[ 0.0027, 1.2031, 0.1543],[-0.5603, -1.8567, -0.1302],[ 0.3978, -0.9356, -0.1977]]])
高維矩陣乘法
矩陣乘法只會在最后兩個維度,用A矩陣的行乘以B矩陣的列。
其他的維度都是對應位置的數據,互相做乘法(類似向量乘法)。
high_matrix1 = torch.randn(2, 3, 4, 5)
high_matrix2 = torch.randn(2, 3, 5, 4)
high_result = high_matrix1 @ high_matrix2
把最后兩個維度看成一個點。更高的維度的矩陣乘法,可想象為兩個矩陣對應位置的點相乘。
比如,shape(2, 3, 4, 5)與shape(2, 3, 5, 4)的矩陣相乘,若把最后兩個維度看成一個點。就可以類比為 (2, 3) 與 (2, 3)的兩個矩陣做向量乘法,就是對應位置的點做乘法。
如下面的運行結果所示。針對兩個矩陣,在高維空間中,選取(1,2)對應的小矩陣數據做矩陣乘法得到的結果。與兩個矩陣乘法的結果對應(1,2)的值是一樣的。
(high_matrix1[1][2] @ high_matrix2[1][2]) == high_result[1][2]
輸出:
tensor([[True, True, True, True],[True, True, True, True],[True, True, True, True],[True, True, True, True]])
開源
https://github.com/JieShenAI/csdn/blob/main/25/06/torch_matmul/run.ipynb