張量的拼接操作在神經網絡搭建過程中是非常常用的方法,例如:殘差網絡,注意力機制中都使用張量拼接。
torch.cat 函數的使用
可以將兩個張量根據指定的維度拼接起來。
import torch
import numpy as np def test01():data1 = torch.randint(0, 10, [3, 4, 5])data2 = torch.randint(0, 10, [3, 4, 5])print(data1.shape)print(data2.shape)# dim 對應的值可以是負數,可以通過list來思考# 按照第 0 維度進行拼接new_data = torch.cat([data1, data2], dim = 0) # 是列表print(new_data.shape)# 按照第 1 維度進行拼接new_data = torch.cat([data1, data2], dim = 1)print(new_data.shape)# 按照第 2 維度進行拼接new_data = torch.cat([data1, data2], dim = 2)print(new_data.shape)if __name__ == "__main__":test01()
torch.stack 函數的使用
torch.stack 函數可以將兩個張量根據指定的維度疊加起來,或者組合成新的元素。疊加
的意思:當兩個元素疊在一起,我們就將這兩個元素當作一個元素。
import torch
import numpy as np def test01():data1 = torch.randint(0, 10, [2, 3])data2 = torch.randint(0, 10, [2, 3])print(data1)print(data2)# 將兩個張量 stack 疊加起來,像 cat 一樣指定維度# 1. 按照第0維度進行疊加new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)# 2. 按照第1維度進行疊加new_data = torch.stack([data1, data2], dim=1)print(new_data)# 3. 按照第2維度進行疊加new_data = torch.stack([data1, data2], dim=2)print(new_data)if __name__ == "__main__":test01()