???????
個人主頁:Icomi
在深度學習蓬勃發展的當下,PyTorch 是不可或缺的工具。它作為強大的深度學習框架,為構建和訓練神經網絡提供了高效且靈活的平臺。神經網絡作為人工智能的核心技術,能夠處理復雜的數據模式。通過 PyTorch,我們可以輕松搭建各類神經網絡模型,實現從基礎到高級的人工智能應用。接下來,就讓我們一同走進 PyTorch 的世界,探索神經網絡與人工智能的奧秘。本系列為PyTorch入門文章,若各位大佬想持續跟進,歡迎與我交流互關。
?????????前面我們學習了張量和 numpy 數組的相互轉換,這是我們在深度學習數據處理中非常實用的技能。
????????今天,咱們要講講張量的拼接操作,這可是在神經網絡搭建過程中極為常用的方法,就好比搭建一座宏偉建筑時不可或缺的連接工藝。想象一下,我們構建神經網絡就像搭建一座復雜的大廈,張量就是構成大廈的各種預制構件,而拼接操作就像是把這些構件精準連接在一起的關鍵技術。
????????比如說,在后面將要學習到的殘差網絡里,張量的拼接起到了至關重要的作用。殘差網絡能夠有效解決深度神經網絡訓練過程中的梯度消失和梯度爆炸問題,讓網絡可以更深層次地學習數據特征。這里面,通過巧妙地拼接不同層的張量,就像是把不同功能的建筑模塊合理組合,從而構建出了強大的深層網絡結構。
????????還有注意力機制,這也是深度學習領域的一個重要概念。在注意力機制中,張量的拼接幫助我們對不同的信息進行整合與聚焦,就像在紛繁復雜的信息海洋中,通過拼接操作找到最關鍵的信息片段并組合起來,讓模型能夠更加 “聰明” 地處理數據。
????????所以,掌握張量的拼接操作,對于我們理解和構建先進的神經網絡模型至關重要。接下來,咱們就深入學習一下張量的拼接到底是怎么實現的,以及在不同場景下該如何靈活運用它。
1. torch.cat 函數的使用?
torch.cat 函數可以將兩個張量根據指定的維度拼接起來.
?
import torchdef tensor_concatenation():# 創建第一個三維隨機整數張量tensor_1 = torch.randint(0, 10, [3, 5, 4])# 創建第二個三維隨機整數張量tensor_2 = torch.randint(0, 10, [3, 5, 4])print(tensor_1)print(tensor_2)print('-' * 50)# 1. 按 0 維度拼接concatenated_tensor_dim0 = torch.cat([tensor_1, tensor_2], dim=0)print(concatenated_tensor_dim0.shape)print('-' * 50)# 2. 按 1 維度拼接concatenated_tensor_dim1 = torch.cat([tensor_1, tensor_2], dim=1)print(concatenated_tensor_dim1.shape)print('-' * 50)# 3. 按 2 維度拼接concatenated_tensor_dim2 = torch.cat([tensor_1, tensor_2], dim=2)print(concatenated_tensor_dim2)if __name__ == '__main__':tensor_concatenation()
程序輸出結果:
tensor([[[6, 8, 3, 5],[1, 1, 3, 8],[9, 0, 4, 4],[1, 4, 7, 0],[5, 1, 4, 8]],[[0, 1, 4, 4],[4, 1, 8, 7],[5, 2, 6, 6],[2, 6, 1, 6],[0, 7, 8, 9]],[[0, 6, 8, 8],[5, 4, 5, 8],[3, 5, 5, 9],[3, 5, 2, 4],[3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],[0, 1, 8, 2],[4, 9, 9, 8],[5, 1, 5, 9],[9, 4, 3, 0]],[[7, 6, 3, 3],[4, 3, 3, 2],[2, 1, 1, 1],[3, 0, 8, 2],[8, 6, 6, 5]],[[0, 7, 2, 4],[4, 3, 8, 3],[4, 2, 1, 9],[4, 2, 8, 9],[3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],[1, 1, 3, 8, 0, 1, 8, 2],[9, 0, 4, 4, 4, 9, 9, 8],[1, 4, 7, 0, 5, 1, 5, 9],[5, 1, 4, 8, 9, 4, 3, 0]],[[0, 1, 4, 4, 7, 6, 3, 3],[4, 1, 8, 7, 4, 3, 3, 2],[5, 2, 6, 6, 2, 1, 1, 1],[2, 6, 1, 6, 3, 0, 8, 2],[0, 7, 8, 9, 8, 6, 6, 5]],[[0, 6, 8, 8, 0, 7, 2, 4],[5, 4, 5, 8, 4, 3, 8, 3],[3, 5, 5, 9, 4, 2, 1, 9],[3, 5, 2, 4, 4, 2, 8, 9],[3, 8, 1, 1, 3, 7, 0, 8]]])
2. torch.stack 函數的使用?
torch.stack 函數可以將兩個張量根據指定的維度疊加起來.
import torchdef test():data1= torch.randint(0, 10, [2, 3])data2= torch.randint(0, 10, [2, 3])print(data1)print(data2)new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)new_data = torch.stack([data1, data2], dim=1)print(new_data.shape)new_data = torch.stack([data1, data2], dim=2)print(new_data)if __name__ == '__main__':test()
3. 總結?
張量的拼接操作也是在后面我們經常使用一種操作。cat 函數可以將張量按照指定的維度拼接起來,stack 函數可以將張量按照指定的維度疊加起來。