一.前言
本章節來介紹一下張量拼接的操作,掌握torch.cat torch.stack使?,張量的拼接操作在神經?絡搭建過程中是?常常?的?法,例如: 在后?將要學習到的殘差?絡、注意?機 制中都使?到了張量拼接。
二.torch.cat 函數的使用
torch.cat 函數可以將兩個張量根據指定的維度拼接起來.
import torchdef test():data1 = torch.randint(0, 10, [3, 5, 4])data2 = torch.randint(0, 10, [3, 5, 4])print(data1)print(data2)print('-' * 50)# 1. 按0維度拼接new_data = torch.cat([data1, data2], dim=0)print(new_data.shape)print('-' * 50)# 2. 按1維度拼接new_data = torch.cat([data1, data2], dim=1)print(new_data.shape)print('-' * 50)# 3. 按2維度拼接new_data = torch.cat([data1, data2], dim=2)print(new_data.shape)if __name__ == '__main__':test()
結果展示:
tensor([[[6, 7, 2, 6],
[4, 6, 4, 3],
[5, 3, 4, 9],
[8, 8, 6, 7],
[0, 3, 3, 0]],? ? ? ? [[6, 1, 2, 0],
[5, 6, 7, 0],
[6, 4, 8, 0],
[2, 2, 8, 3],
[0, 1, 6, 8]],? ? ? ? [[3, 5, 0, 8],
[6, 2, 1, 7],
[8, 9, 9, 8],
[3, 8, 8, 0],
[5, 8, 4, 4]]])
tensor([[[7, 2, 2, 1],
[8, 0, 6, 6],
[9, 0, 6, 5],
[1, 3, 7, 7],
[7, 0, 5, 1]],? ? ? ? [[0, 7, 3, 1],
[9, 2, 9, 0],
[9, 6, 2, 1],
[9, 3, 5, 0],
[8, 8, 6, 2]],? ? ? ? [[1, 8, 9, 9],
[4, 3, 0, 9],
[7, 3, 3, 8],
[2, 4, 6, 9],
[2, 1, 0, 5]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
--------------------------------------------------
torch.Size([3, 5, 8])?
三.torch.stack 函數的使用
torch.stack 函數可以將兩個張量根據指定的維度疊加起來.
import torchdef test():data1 = torch.randint(0, 10, [2, 3])data2 = torch.randint(0, 10, [2, 3])print(data1)print(data2)print("="*50)new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)print(new_data)print("=" * 50)new_data = torch.stack([data1, data2], dim=1)print(new_data.shape)print(new_data)print("=" * 50)new_data = torch.stack([data1, data2], dim=2)print(new_data.shape)print(new_data)if __name__ == '__main__':test()
?結果展示:
tensor([[6, 9, 6],
[3, 2, 7]])
tensor([[3, 3, 4],
[9, 1, 4]])
==================================================
torch.Size([2, 2, 3])
tensor([[[6, 9, 6],
[3, 2, 7]],? ? ? ? [[3, 3, 4],
[9, 1, 4]]])
==================================================
torch.Size([2, 2, 3])
tensor([[[6, 9, 6],
[3, 3, 4]],? ? ? ? [[3, 2, 7],
[9, 1, 4]]])
==================================================
torch.Size([2, 3, 2])
tensor([[[6, 3],
[9, 3],
[6, 4]],? ? ? ? [[3, 9],
[2, 1],
[7, 4]]])
這里十分的不好理解,大家拷貝完代碼自己執行理解一下。
四.總結?
張量的拼接操作也是在后?我們經常使??種操作。cat 函數可以將張量按照指定的維度拼接起來,stack 函數可以將張量按照指定的維度疊加起來。?
?
?