1.permute
permute函數用于重新排列張量的維度。它接受一個元組作為參數,表示新的維度順序。例如,如果我們有一個形狀為(2, 3)的二維張量,我們可以使用permute函數將其維度重新排列為(3, 2),如下所示:
>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5945, 0.7441, 0.5515],[-1.3831, 0.4533, -0.6908]])
>>> y = x.permute(1,0)
>>> y
tensor([[-0.5945, -1.3831],[ 0.7441, 0.4533],[ 0.5515, -0.6908]])
>>>
首先創建了一個形狀為(2, 3)的二維張量x。然后,我們使用permute函數將其維度重新排列為(3, 2),并將結果存儲在變量y中。
2.transpose
transpose函數用于交換張量的兩個維度。它接受兩個整數作為參數,表示要交換的維度的索引。例如,如果我們有一個形狀為(2, 3)的二維張量,我們可以使用transpose函數交換第0維和第1維,如下所示:
>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5945, 0.7441, 0.5515],[-1.3831, 0.4533, -0.6908]])>>> y=x.transpose(0,1)
>>> y
tensor([[-0.5945, -1.3831],[ 0.7441, 0.4533],[ 0.5515, -0.6908]])
>>>
在上面的例子中,創建了一個形狀為(2, 3)的二維張量x。然后,我們使用transpose函數將第0維和第1維交換,并將結果存儲在變量y中。
需要注意的是,transpose函數與permute函數不同,它只交換兩個特定的維度,而permute函數可以重新排列所有維度。
3.view / reshape
view和reshape函數用于將張量重塑為不同的形狀。它們接受一個或兩個整數元組作為參數,表示新的形狀。例如,如果我們有一個形狀為(2, 3)的二維張量,我們可以使用view或reshape函數將其重塑為形狀為(6,)的一維張量,如下所示:
>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5945, 0.7441, 0.5515],[-1.3831, 0.4533, -0.6908]])>>> y = x.view(-1)
>>> y
tensor([-0.5945, 0.7441, 0.5515, -1.3831, 0.4533, -0.6908])
# 或者
>>> y=x.reshape(-1)
>>> y
tensor([-0.5945, 0.7441, 0.5515, -1.3831, 0.4533, -0.6908])
>>>
在上面的例子中,創建了一個形狀為(2, 3)的二維張量x。然后,我們使用view或reshape函數將x重塑為形狀為(6,)的一維張量,并將結果存儲在變量y中。
需要注意的是,view和reshape函數實際上不會改變張量中的數據,只是改變了數據的布局方式。因此,新的形狀必須與原始形狀兼容,否則會拋出錯誤。具體來說,新的形狀的元素總數必須與原始形狀的元素總數相同。
4.?flatten
flatten函數用于將多維張量展平為一維張量。它接受一個整數作為參數,表示展平后的一維張量的最大長度。例如,如果我們有一個形狀為(2, 3)的二維張量,我們可以使用flatten函數將其展平為一維張量,如下所示:
>>> import torch
>>> x = torch.randn(2,3)>>> y=x.flatten(1)
>>> y
tensor([[-0.5945, 0.7441, 0.5515],[-1.3831, 0.4533, -0.6908]])
>>> y=x.flatten(0)
>>> y
tensor([-0.5945, 0.7441, 0.5515, -1.3831, 0.4533, -0.6908])
>>> y=x.flatten(2)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
在上面的例子中,創建了一個形狀為(2, 3)的二維張量x。然后,我們使用flatten函數將x展平為一維張量,并將結果存儲在變量y中。需要注意的是,flatten函數的參數指定了展平后的一維張量的最大長度。在本例中,我們將最大長度設置為1,因此展平后的張量將具有形狀(6,)。如果展平后的長度超過了指定的最大長度,將會拋出錯誤。
總結:在PyTorch中,permute、transpose、view、reshape和flatten函數都是用于改變張量形狀和維度的工具。它們具有不同的用途和特點,可以根據具體需求選擇合適的函數來操作張量。