[For Beginners] Common Tensor Shape Transformation Methods in PyTorch
Preface
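A video by the Bilibili uploader 科研水神大隊長 introduces the "module stitching" approach and stresses that "deep learning is all about shape". Inspired by this, I have collected several built-in functions that change the shape of a tensor, to make it easier to adjust model structures in PyTorch code later on.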
squeeze, unsqueeze
- torch.squeeze()
- torch.unsqueeze()
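squeeze() removes the specified dimension. That dimension must have size 1; otherwise the call has no effect and the shape is unchanged. unsqueeze() inserts a new dimension of size 1 at the specified position.
Code example: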
import torch

# image
x1 = torch.ones([4, 3, 256, 256])  # batch_size, channels, height, width
print("x1.shape:", x1.shape)  # x1.shape: torch.Size([4, 3, 256, 256])

x2 = torch.ones([1, 1, 3])
print("x2.shape:", x2.shape)  # x2.shape: torch.Size([1, 1, 3])

# dim 0 of x1 has size 4, so squeeze(0) changes nothing
y1 = x1.squeeze(0)
print("y1.shape:", y1.shape)  # y1.shape: torch.Size([4, 3, 256, 256])

# dim 0 and dim 1 of x2 both have size 1, so either can be removed
y2 = x2.squeeze(0)
print("y2.shape:", y2.shape)  # y2.shape: torch.Size([1, 3])

y6 = x2.squeeze(1)
print("y6.shape:", y6.shape)  # y6.shape: torch.Size([1, 3])

# the last dim of x2 has size 3, so squeeze(-1) changes nothing
y7 = x2.squeeze(-1)
print("y7.shape:", y7.shape)  # y7.shape: torch.Size([1, 1, 3])

# unsqueeze inserts a size-1 dim at the given position
y3 = x1.unsqueeze(0)
print("y3.shape:", y3.shape)  # y3.shape: torch.Size([1, 4, 3, 256, 256])

y4 = x1.unsqueeze(1)
print("y4.shape:", y4.shape)  # y4.shape: torch.Size([4, 1, 3, 256, 256])

y5 = x1.unsqueeze(-1)
print("y5.shape:", y5.shape)  # y5.shape: torch.Size([4, 3, 256, 256, 1])
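Two related cases come up often: calling squeeze() with no argument removes every size-1 dimension at once, and unsqueeze(0) is a common way to give a single image a batch dimension. The following is a minimal sketch. The tensor names `mask` and `img` are made up for illustration.

```python
import torch

# Hypothetical tensors, used only for illustration.
mask = torch.ones([1, 1, 256, 256])
img = torch.ones([3, 256, 256])  # channels, height, width

# squeeze() without a dim argument drops all dimensions of size 1.
squeezed = mask.squeeze()
print("squeezed.shape:", squeezed.shape)  # squeezed.shape: torch.Size([256, 256])

# unsqueeze(0) adds a leading batch dimension of size 1.
batched = img.unsqueeze(0)
print("batched.shape:", batched.shape)  # batched.shape: torch.Size([1, 3, 256, 256])

# squeeze(0) undoes the unsqueeze(0) above.
restored = batched.squeeze(0)
print("restored.shape:", restored.shape)  # restored.shape: torch.Size([3, 256, 256])
```

Note that squeeze() with no argument would also drop a batch dimension when the batch size happens to be 1, so passing an explicit dim is usually the safer choice.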
transpose
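transpose() swaps two dimensions of a tensor. Computer vision tasks often need the channel order adjusted: some models output (channel, height, width) while others output (height, width, channel), so the dimensions have to be reordered for inputs and outputs to match.
transpose() can be called in two ways:
- torch.transpose()
- x.transpose()
Code example: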
import torch

x1 = torch.ones([4, 3, 256, 256])  # batch_size, channels, height, width
print("x1.shape:", x1.shape)  # x1.shape: torch.Size([4, 3, 256, 256])

x2 = torch.ones([1, 1, 3])
print("x2.shape:", x2.shape)  # x2.shape: torch.Size([1, 1, 3])

trans1 = torch.transpose(x1, 0, 1)
print("trans1.shape:", trans1.shape)  # trans1.shape: torch.Size([3, 4, 256, 256])

trans2 = torch.transpose(x2, 1, 2)
print("trans2.shape:", trans2.shape)  # trans2.shape: torch.Size([1, 3, 1])

trans3 = x1.transpose(0, 1)
print("trans3.shape:", trans3.shape)  # trans3.shape: torch.Size([3, 4, 256, 256])
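Because one transpose() call swaps only two dimensions, going from (channel, height, width) to (height, width, channel) takes two calls. The sketch below shows one possible way to do it on a small tensor. The names `chw`, `hcw` and `hwc` are made up for illustration.

```python
import torch

# Hypothetical image tensor in (channel, height, width) layout.
chw = torch.arange(3 * 4 * 5).reshape(3, 4, 5)

# Step 1: swap dims 0 and 1 -> (height, channel, width)
hcw = chw.transpose(0, 1)

# Step 2: swap dims 1 and 2 -> (height, width, channel)
hwc = hcw.transpose(1, 2)
print("hwc.shape:", hwc.shape)  # hwc.shape: torch.Size([4, 5, 3])

# The element at (c, h, w) in the original now sits at (h, w, c).
moved_correctly = bool(chw[2, 1, 3] == hwc[1, 3, 2])
print("moved_correctly:", moved_correctly)  # moved_correctly: True
```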
reshape
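reshape() changes the shape of a tensor while keeping the total number of elements the same. NumPy arrays have a reshape() method of the same name, so the same call also works on a numpy array.
Code example: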
import torch
import numpy as np

x3 = torch.Tensor([1, 2, 3, 4, 5, 6])
print("x3.shape:", x3.shape)  # x3.shape: torch.Size([6])

reshape1 = x3.reshape(2, 3)
print("reshape1.shape:", reshape1.shape)  # reshape1.shape: torch.Size([2, 3])

x4 = torch.ones([4, 4, 3, 256, 256])
print("x4.shape:", x4.shape)  # x4.shape: torch.Size([4, 4, 3, 256, 256])

reshape2 = x4.reshape(4*4, 3, 256, 256)
print("reshape2.shape:", reshape2.shape)  # reshape2.shape: torch.Size([16, 3, 256, 256])

x5 = np.array([1, 2, 3, 4, 5, 6])
print("x5.shape:", x5.shape)  # x5.shape: (6,)

reshape3 = x5.reshape(2, 3)
print("reshape3.shape:", reshape3.shape)  # reshape3.shape: (2, 3)
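When one dimension can be worked out from the others, reshape() accepts -1 for it and infers the size from the total element count. The sketch below reuses the x4 shape from the example above. The names `merged` and `flat` are made up for illustration.

```python
import torch

x4 = torch.ones([4, 4, 3, 256, 256])

# -1 asks reshape() to infer this dimension: 4 * 4 = 16.
merged = x4.reshape(-1, 3, 256, 256)
print("merged.shape:", merged.shape)  # merged.shape: torch.Size([16, 3, 256, 256])

# Flatten everything after the first dimension: 3 * 256 * 256 = 196608.
flat = merged.reshape(16, -1)
print("flat.shape:", flat.shape)  # flat.shape: torch.Size([16, 196608])
```

Only one dimension can be given as -1 in a single call, since otherwise the sizes could not be determined uniquely.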
view
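view() works much like reshape(): it also changes the shape while keeping the total number of elements the same. As a reshaping method, however, view() applies only to tensors.
Code example: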
import torch

x3 = torch.Tensor([1, 2, 3, 4, 5, 6])
print("x3.shape:", x3.shape)  # x3.shape: torch.Size([6])

view1 = x3.view(2, 3)
print("view1.shape:", view1.shape)  # view1.shape: torch.Size([2, 3])
print("view1:", view1)
# view1: tensor([[1., 2., 3.],
#                [4., 5., 6.]])

view2 = x3.view(3, 2)
print("view2.shape:", view2.shape)  # view2.shape: torch.Size([3, 2])
print("view2:", view2)
# view2: tensor([[1., 2.],
#                [3., 4.],
#                [5., 6.]])
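One practical difference between the two is worth knowing: view() never copies data, so it needs a memory layout compatible with the requested shape, while reshape() falls back to a copy when a view is not possible. The sketch below shows this on a transposed tensor, which is not contiguous. All variable names are made up for illustration.

```python
import torch

x = torch.arange(6).reshape(2, 3)
t = x.transpose(0, 1)  # shape (3, 2), no longer contiguous in memory
print("t.is_contiguous():", t.is_contiguous())  # t.is_contiguous(): False

# view() cannot flatten the non-contiguous tensor and raises RuntimeError.
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True
print("view_failed:", view_failed)  # view_failed: True

# reshape() handles this case by copying when it has to.
flat1 = t.reshape(6)
# Making the tensor contiguous first also lets view() work.
flat2 = t.contiguous().view(6)
print("flat1:", flat1)  # flat1: tensor([0, 3, 1, 4, 2, 5])

# A view shares memory with its source: writing through it changes x.
v = x.view(3, 2)
v[0, 0] = 100
print("x[0, 0]:", x[0, 0].item())  # x[0, 0]: 100
```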
permute
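permute() reorders the dimensions of a tensor. Unlike transpose(), which can only swap two dimensions per call, permute() can rearrange any number of dimensions in one call.
Code example: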
import torch

x4 = torch.ones([4, 4, 3, 256, 256])
print("x4.shape:", x4.shape)  # x4.shape: torch.Size([4, 4, 3, 256, 256])

permute1 = x4.permute(1, 3, 4, 0, 2)
print("permute1.shape:", permute1.shape)  # permute1.shape: torch.Size([4, 256, 256, 4, 3])
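To tie permute() back to the channel-order problem mentioned in the transpose section, the sketch below converts a batch from (batch, channel, height, width) to (batch, height, width, channel) in one call, then checks that the result matches two chained transpose() calls. The names `nchw`, `nhwc` and `same` are made up for illustration.

```python
import torch

# Hypothetical batch in (batch, channel, height, width) layout.
nchw = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Each argument names the source dim that ends up in that position.
nhwc = nchw.permute(0, 2, 3, 1)
print("nhwc.shape:", nhwc.shape)  # nhwc.shape: torch.Size([2, 4, 5, 3])

# The same reordering built from two pairwise swaps.
via_transpose = nchw.transpose(1, 2).transpose(2, 3)
same = torch.equal(nhwc, via_transpose)
print("same:", same)  # same: True
```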