搭建模型時,數據都是基于張量形式的表示,網絡層與層之間很多都是以不同的shape的方式進行表現和運算。
對張量形狀的操作,以便能夠更好處理網絡各層之間的數據連接。
reshape 函數的用法
reshape 函數可以再保證張量數據不變的前提下改變數據的維度,將其轉換成指定的形狀,在神經網絡中經常使用該函數來調節數據的形狀,以適配不同網絡層之間的數據傳遞。
import torch
import numpy as np def test01():torch.manual_seed(0)data = torch.randint(0, 10, [4, 5])# 查看張量的形狀print(data.shape, data.shape[0], data.shape[1]) # shape屬性可以查看張量的形狀print(data.size(), data.size(0), data.size(1)) # size()方法可以查看張量的形狀# 修改張量的形狀new_data = data.reshape(2, 10) # 兩行十列print(new_data)# 注意:轉換之后的形狀元素個數得等于原來張量的元素個數,不然就報錯。上面創建data就是4*5=20個元素# 使用 -1 代替省略的形狀new_data = data.reshape(-1, 10) # -1表示自動計算行數print(new_data.shape) # torch.Size([2, 10])print(new_data)new_data = data.reshape(2, -1) # -1表示自動計算列數print(new_data)if __name__ == "__main__":test01()
transpose 和 permute 函數的使用
transpose 函數可以實現交換張量形狀的指定維度。
例如:一個張量的形狀為 (2, 3, 4) 可以通過 transpose 函數把 3 和 4 進行交換,將張量的形狀變為 (2, 4, 3)。
permute 函數可以一次交換更多的維度。
本質上都是在修改數據的維度。
import torch
import numpy as np # transpose 函數
def test01():torch.manual_seed(0)data = torch.randint(0, 10, [3, 4, 5])# new_data = data.reshape(4, 3, 5) # 重新計算維度# print(new_data.shape)# 直接交換兩個維度的值new_data = torch.transpose(data, 0, 1) # 只是將這兩個位置進行交換。0表示第0個維度,1表示第1個維度print(new_data.shape)# 缺點:transpose 一次只能交換兩個維度# 把數據的形狀變成 (4, 5, 3)# 進行第一次交換:(4, 3, 5)# 進行第二次交換:(4, 5, 3)new_data = torch.transpose(data, 0, 1)new_data = torch.transpose(new_data, 1, 2)print(new_data.shape)# permute 函數
def test02():torch.manual_seed(0)data = torch.randint(0, 10, [3, 4, 5])# permute 函數可以一次性交換多個維度new_data = torch.permute(data, [1, 2, 0])print(new_data.shape)if __name__ == "__main__":test02()
view 和 contigous 函數的用法
view 函數可以用于修改張量的形狀,但是其用法比較局限,只能用于存儲在整塊內存中的張量。
在 PyTorch 中,有些張量是由不同的數據塊組成的,它們并沒有存儲在整塊的內存中,view 函數無法對這樣的張量進行變形處理。
例如:一個張量經過了 transpose 或者 permute 函數的處理之后,就無法使用 view 函數進行形狀操作。
import torch
import numpy as np # view 函數的使用
def test01():data = torch.tensor([[10, 20, 30], [40, 50, 60]])data = data.view(3, 2)print(data.shape)# 通過 is_contigous 函數來判斷張量是否是連續內存空間 (整塊的內存)print(data.is_contiguous())# view 函數使用注意
def test02():# 當張量經過 transpose 或者 permute 函數之后,內存空間基本不連續# 此時,必須先把空間連續,才能使用 view 函數進行張量形狀操作data = torch.tensor([[10, 20, 30], [40, 50, 60]])data = torch.transpose(data, 0, 1)print(data.is_contiguous())# data = data.view(2, 3) # 這是報錯的data = data.contiguous().view(2, 3)print(data)if __name__ == "__main__":test02()
squeeze 和 unsqueeze 函數的用法
squeeze 函數用刪除 shape 為 1 的維度。
unsqueeze 在每個維度添加1,以增加數據的形狀。
import torch
import numpy as np # squeeze 函數使用
def test01():data = torch.randint(0, 10, [1, 3, 1, 5])print(data.shape)# 維度壓縮,默認去掉所有的1的維度new_data = data.squeeze()print(new_data.shape)# 指定去掉某個1的維度new_data = data.squeeze(2)print(new_data.shape)# unsqueeze 函數使用
def test02():data = torch.randint(0, 10, [3, 5])print(data.shape)new_data = data.unsqueeze(0)print(new_data)if __name__ == "__main__":test01()
總結
- reshape 函數可以在保證張量數據不變的前提下改變數據的維度
- transpose 函數可以實現交換張量形狀的指定維度,permute 可以一次交換更多的維度
- view 函數也可以用于修改張量的形狀,但是它要求被轉換的張量內存必須連續,所以一般配合 contiguous 函數使用。
- squeeze 和 unsqueeze 函數可以用來增加或者減少維度。