1.torch.nonzero(input, *, as_tuple=False)
作用:在PyTorch中用于返回輸入張量中非零元素的位置索引。
返回值:返回一個張量,每行代表一個非零元素的索引。
參數含義:
(1)input:輸入的PyTorch 張量。
(2)as_tuple:一個布爾值,指定返回結果的格式。默認為 False,返回一個張量。如果設置為 True,則返回一個元組,其中每個元素代表一個維度上的索引。
應用場景:
(1)高級索引:
使用as_tuple=True返回的元組可以用于對原始張量進行高級索引,例如提取所有非零元素;
(2)掩碼操作:
結合torch.nonzero()和其他函數來創建掩碼,例如選擇特定條件下的元素;
示例:
index_list = torch.nonzero(scores > det_thr, as_tuple=True)[0]
其中scores是torch.Tensor ,維度ndim=1; det_thr類型為<float>
代碼含義:
其他函數scores>det_thr: 是逐元素比較,返回一個布爾型張量,得到result = tensor([True,True, False,False,…])
然后torch.nonzero(result)? 得到所有值為True的元素的位置索引。例如tensor([[0],[2]])
參數as_tuple用于控制輸出的格式:
- 默認 (as_tuple=False): 返回形如 二維張量 的結果,比如 tensor([[0], [2]]);
- 加上 as_tuple=True: 返回一個元組,其中每一維是一個 1D 張量,表示每個軸的索引
最后的[0]用于從元組中提取索引值張量;
最后整句代碼含義:
取出所有 scores 中 大于閾值 det_thr 的元素的索引,并保存為 index_list
(3)稀疏張量處理:
在處理稀疏張量時,torch.nonzero()可以用于獲取非零元素的索引。