在操作張量時,經常要去獲取某些元素進行處理或者修改操作,在這里需要了解torch中的索引操作。
準備數據:
data = torch.randint(0,10,[4,5])
print('data--->',data)
輸出結果:
data---> tensor([[3, 9, 4, 0, 5],[7, 5, 9, 9, 7],[5, 9, 8, 9, 7],[9, 2, 6, 7, 7]])
-
簡單行、列索引
print('第一行:',data[0]) print('第一列:',data[:,0])
輸出結果:
第一行: tensor([3, 9, 4, 0, 5]) 第一列: tensor([3, 7, 5, 9])
-
列表索引
print('-----------------返回(0,1)、(1,2) 2個位置的元素------------------') print(data[[0,1],[1,2]]) print('-----------------返回0、1 行的1、2 列共4個元素------------------') print(data[[[0],[1]],[1,2]])
輸出結果:
-----------------返回(0,1)、(1,2) 2個位置的元素------------------ tensor([9, 9]) -----------------返回0、1 行的1、2 列共4個元素------------------ tensor([[9, 4],[5, 9]])
-
范圍索引
print('-----------------前3行、前2列的數據------------------') print(data[:3,:2]) print('-----------------第2行到最后的前2列數據------------------') print(data[2:,:2])
輸出結果:
-----------------前3行、前2列的數據------------------ tensor([[3, 9],[7, 5],[5, 9]]) -----------------第2行到最后的前2列數據------------------ tensor([[5, 9],[9, 2]])
-
布爾索引
print('-----------------第三列大于5的行數據------------------') print(data[data[:,2] > 5]) print('-----------------第二行大于5的行數據------------------') print(data[:,data[1] > 5])
輸出結果:
-----------------第三列大于5的行數據------------------ tensor([[7, 5, 9, 9, 7],[5, 9, 8, 9, 7],[9, 2, 6, 7, 7]]) -----------------第二行大于5的行數據------------------ tensor([[3, 4, 0, 5],[7, 9, 9, 7],[5, 8, 9, 7],[9, 6, 7, 7]])
-
多維索引
data = torch.randint(0,10,[3,4,5]) print(data) # 獲取0軸上的第一個數據 print(data[0,:,:]) # 獲取1軸上的第一個數據 print(data[:,0,:]) # 獲取2軸上的第一個數據 print(data[:,:,0])
輸出結果:
tensor([[[8, 3, 6, 1, 5],[5, 0, 4, 3, 8],[8, 3, 3, 5, 0],[6, 4, 0, 8, 4]],[[7, 2, 3, 8, 5],[6, 2, 9, 5, 0],[4, 2, 7, 1, 1],[5, 4, 4, 1, 1]],[[2, 4, 7, 2, 5],[6, 1, 4, 5, 6],[9, 2, 3, 1, 0],[2, 1, 2, 7, 9]]]) tensor([[8, 3, 6, 1, 5],[5, 0, 4, 3, 8],[8, 3, 3, 5, 0],[6, 4, 0, 8, 4]]) tensor([[8, 3, 6, 1, 5],[7, 2, 3, 8, 5],[2, 4, 7, 2, 5]]) tensor([[8, 5, 8, 6],[7, 6, 4, 5],[2, 6, 9, 2]])