張量的拼接操作在神經網絡搭建過程中是非常常用的方法,例如: 在后面將要學習的注意力機制中都使用到了張量拼接。
torch.cat 函數可以將兩個張量根據指定的維度拼接起來,不改變數據維度。
前提:除了拼接的維度,其他維度一定要相同。
def cat_tensor():data1=torch.randint(0,10,(1,2,3))data2=torch.randint(0,10,(1,2,3))# 0軸拼接 dim=0# data3=torch.cat([data1,data2]) # 默認dim=0,torch.Size([2, 2, 3])# data3=torch.cat([data1,data2],dim=0)# 1軸拼接 dim=1# data3=torch.cat([data1,data2],dim=1) # torch.Size([1, 4, 3])# 2軸拼接 dim=2data3=torch.cat([data1,data2],dim=2) # torch.Size([1, 2, 6])# data3=torch.concat([data1,data2],dim=2) # torch.Size([1, 2, 6]) # 第二種寫法 同種結果# data3=torch.concatenate([data1,data2],dim=2) # torch.Size([1, 2, 6]) # 第三種寫法 同種結果print(data3.shape)if __name__ == '__main__':cat_tensor()