- reshape()?
- squeeze()
- unsqueeze()
- transpose()
- permute()
- view()? ==?reshape()?
- contiguous() ==?reshape()?
一、reshape() 函數
保證張量數據不變的前提下改變數據的維度,將其轉換成指定的形狀。
def reshape_tensor():data = torch.tensor([[1, 2, 3], [4, 5, 6]])print(data, data.shape) # pytorch中shape=size() 都可以獲得張量的形狀data1=data.reshape(1, 6) #reshape行列相乘要等于數據data的總個數# data1=data.reshape(6) # 等于上一行print(data1, data1.shape)data2=data.reshape(3,-1) #-1 自動推斷,已知三行自動推斷列數# data2=data.reshape(-1) #-1 自動推斷print(data2, data2.shape)if __name__ == '__main__':reshape_tensor()
擴展
pytorch中 shape=size() ?都可以獲得張量的形狀
# 擴展:size,和shape是等價的,都是看數據的維度 # print("data1-->", data1.shape, data1.size()) # shape[0] <=> size()[0] <=> size(0) # shape[1] <=> size()[1] <=> size(1)print("data1-->", data, data.shape[0], data.size(1))
二、squeeze() 和 unsqueeze()
squeeze 函數刪除 形狀為 1 的維度(升維),unsqueeze 函數添加形狀為1的維度(降維)。
# 生維與降維
def unsqueeze_squeeze_tensor():# 準備數據data = torch.tensor([1, 2, 3, 4, 5])print('data-->', data, data.shape)# 升維: unsqueeze(), 增加一個維度,這個維度的長度為1# data1 = data.unsqueeze(dim=0) # [1, 5]# data1 = data.unsqueeze(dim=1) # [5, 1]data1 = data.unsqueeze(dim=-1).unsqueeze(dim=0) # [1, 5, 1]# data1 = data.unsqueeze(dim=2) # [5, 1] 會報錯,越界print("data1-->", data1, data1.shape)# 降維: squeeze(), 能夠減少維度為1的維度# 所有長度為1的維度都會降低。data2 = data1.squeeze()# print("data2-->", data2, data2.shape)
if __name__ == '__main__':unsqueeze_squeeze_tensor()
三、transpose() 和 permute()
transpose 函數可以實現交換張量形狀的指定維度, 例如: 一個張量的形狀為 (2, 3, 4) 可以通過 transpose 函數把 3 和 4 進行交換, 將張量的形狀變為 (2, 4, 3)
permute 函數可以一次交換更多的維度。
def transpose_permute_tensor():# 生成隨機張量,并設置隨機種子,保持隨機張量是固定值torch.manual_seed(0)data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# .transpose指定交換的兩個維度data1=data.transpose(1, 2) # torch.Size([3, 5, 4])# data1=data.transpose(0, 2)print(data1, data1.shape) # torch.Size([5, 4, 3])# .permute指定交換的多個維度data2 = data.permute(2,0,1)print(data2, data2.shape) # torch.Size([5, 3, 4])
if __name__ == '__main__':transpose_permute_tensor()
五、view() 和 contiguous()
????????view 函數也可以用于修改張量的形狀,但是其用法比較局限,只能用于存儲在整塊內存中的張量。在 PyTorch 中,有些張量是由不同的數據塊組成的,它們并沒有存儲在整塊的內存中,view 函數無法對這樣的張量進行變形處理,例如: 一個張量經過了 transpose 函數的處理之后,就無法使用 view 函數進行形狀操作。
view 函數也可以用于修改張量的形狀, 但是它要求被轉換的張量內存必須連續,所以一般配合 contiguous 函數使用。
def view_contiguous_tensor():torch.manual_seed(0)data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)data = data.transpose(1, 2) # 不連續 torch.Size([3, 5, 4])# data = data.permute(1, 2, 0)# data = data.view(1, 2, -1) # data不連續后,調用view函數會報錯print(data, data.shape)print(data.is_contiguous()) # 判斷是否連續# print(data.contiguous().is_contiguous()) # 通過contiguous把不連續的內存空間變成連續print(data.contiguous().view(3,4,5)) # 再view()
if __name__ == '__main__':view_contiguous_tensor()
六、小結
- reshape 函數可以在保證張量數據不變的前提下改變數據的維度
- squeeze 和 unsqueeze 函數可以用來增加或者減少維度
- transpose 函數可以實現交換張量形狀的指定維度, permute 可以一次交換更多的維度
- view 函數也可以用于修改張量的形狀, 但是它要求被轉換的張量內存必須連續, 所以一般配合 contiguous 函數使用