多數度量學習的代碼都需要進行樣本挖掘。樣本挖掘的過程，就是把一個 Batch 中的所有樣本根據標籤劃分成正樣本和負樣本。
這裡我們只討論多標籤分類問題，標籤採用 one-hot 編碼（多標籤時即 multi-hot 編碼）；如果是單標籤分類任務，可以參考 pytorch_metric_learning 庫中已實現好的挖掘方法。
比如輸入樣本的形狀為 [Batch, Embedding]，對應標籤的形狀為 [Batch, Class]。
對這些樣本進行挖掘後，得到以下三部分：
- Anchor：錨點樣本，其實就是輸入 Batch 中的樣本；
- Positive Sample：挖掘出的正樣本；
- Negative Sample：挖掘出的負樣本。
import torch
import torch.nn as nn
import torchvision# 損失函數
def _mean_or_zero(terms):
    """Return ``terms.mean()``, or a graph-connected zero when ``terms`` is empty.

    ``Tensor.mean()`` of an empty tensor is NaN, which would poison training when a
    batch yields no pairs; ``Tensor.sum()`` of an empty tensor is 0 and still keeps
    the autograd graph of whatever produced ``terms``.
    """
    return terms.mean() if terms.numel() > 0 else terms.sum()


class HibCriterion(nn.Module):
    """Hedged Instance Embedding (HIB) style pair loss over sampled embeddings.

    For every combination of samples ``(i, j)`` the match probability of a pair of
    batch items ``(u, v)`` is modelled as
    ``sigmoid(-softplus(alpha) * ||z[u, i] - z[v, j]||^2 + beta)``.
    Positive pairs are pushed towards probability 1 and negative pairs towards 0
    with a binary cross-entropy, averaged over all ``n_samples ** 2`` combinations.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z_samples, alpha, beta, indices_tuple):
        """Compute the loss.

        Args:
            z_samples (torch.Tensor): sampled embeddings, indexed as
                ``z_samples[:, i, :]`` -- presumably ``(batch, n_samples, dim)``.
            alpha (torch.Tensor): raw distance scale; passed through ``softplus``
                so the effective scale is positive.
            beta (torch.Tensor): bias added to every logit.
            indices_tuple (tuple): either triplets ``(a, p, n)`` (e.g. from
                ``TripletMinner``) or pairs ``(a1, p, a2, n)`` of batch indices.

        Returns:
            torch.Tensor: scalar loss; a zero (still attached to the graph) when
            there are no positive/negative pairs.

        Raises:
            ValueError: if ``indices_tuple`` does not have 3 or 4 elements.
        """
        if len(indices_tuple) == 3:
            a, p, n = indices_tuple
            ap = an = a
        elif len(indices_tuple) == 4:
            ap, p, an, n = indices_tuple
        else:
            raise ValueError(
                "indices_tuple must be (a, p, n) or (a1, p, a2, n), "
                f"got {len(indices_tuple)} elements"
            )

        n_samples = z_samples.shape[1]
        alpha = torch.nn.functional.softplus(alpha)
        logsigmoid = torch.nn.functional.logsigmoid

        pos_terms = z_samples.new_zeros(len(p))
        neg_terms = z_samples.new_zeros(len(n))
        for i in range(n_samples):
            z_i = z_samples[:, i, :]
            for j in range(n_samples):
                z_j = z_samples[:, j, :]
                pos_logit = -alpha * torch.sum((z_i[ap] - z_j[p]) ** 2, dim=1) + beta
                neg_logit = -alpha * torch.sum((z_i[an] - z_j[n]) ** 2, dim=1) + beta
                # -log(sigmoid(x)) == -logsigmoid(x) and
                # -log(1 - sigmoid(x)) == -logsigmoid(-x): exact and finite even when
                # sigmoid saturates to 0 or 1 (the old "+ 1e-6" made 1 - prob_neg < 0).
                pos_terms = pos_terms - logsigmoid(pos_logit)
                neg_terms = neg_terms - logsigmoid(-neg_logit)

        pos_terms = pos_terms / n_samples ** 2
        neg_terms = neg_terms / n_samples ** 2
        # Positive and negative pair counts can differ in the 4-tuple form, so the
        # two terms are averaged separately (identical to mean(pos + neg) for triplets).
        return _mean_or_zero(pos_terms) + _mean_or_zero(neg_terms)


def get_matches_and_diffs(labels):
    """Build positive / negative pair masks from (multi-hot) label vectors.

    Two samples match when they share at least one label.

    Args:
        labels (torch.Tensor): 2-D label matrix ``[Batch, Class]`` (one-hot or
            multi-hot; any dtype convertible to float).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(matches, diffs)``, both uint8
        ``[Batch, Batch]``. ``matches[i, j] == 1`` iff ``i != j`` share a label;
        ``diffs[i, j] == 1`` iff they share no label.
    """
    labels = labels.float()
    # Threshold the shared-label count: a raw count of 2 would give 2 ^ 1 == 3, i.e.
    # the pair would be flagged as both a match and a diff (and uint8 wraps at 256).
    matches = (labels @ labels.T > 0).byte()
    diffs = matches ^ 1  # XOR flips 0/1: pairs that share no label
    # A sample is not its own positive (same convention as pytorch_metric_learning).
    # Done after computing diffs so a sample is not its own negative either.
    matches.fill_diagonal_(0)
    return matches, diffs


def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
    """Enumerate every (anchor, positive, negative) index triplet.

    Args:
        all_matches (torch.Tensor): ``[Batch, Batch]`` 0/1 mask of positive pairs.
        all_diffs (torch.Tensor): ``[Batch, Batch]`` 0/1 mask of negative pairs.

    Processing:
        ``all_matches.unsqueeze(2)`` -> ``[Batch, Batch, 1]`` and
        ``all_diffs.unsqueeze(1)`` -> ``[Batch, 1, Batch]`` broadcast to a
        ``[Batch, Batch, Batch]`` tensor whose entry ``(a, p, n)`` is nonzero iff
        ``all_matches[a, p]`` and ``all_diffs[a, n]`` are both set. Memory is
        cubic in the batch size.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 1-D index tensors
        ``(a, p, n)`` of the nonzero entries, as returned by ``torch.where``.
    """
    triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
    return torch.where(triplets)


class TripletMinner(nn.Module):
    """Mine all valid (anchor, positive, negative) index triplets from labels.

    The class name (sic, "Minner") is kept for backward compatibility.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Attribute names (including the ``selctor`` spelling) are kept so code
        # that accesses them keeps working.
        self.sim_mat = get_matches_and_diffs
        self.selctor = get_all_triplets_indices_vectorized_method

    def forward(self, labels):
        """Return ``(a, p, n)`` index tensors for every valid triplet in the batch."""
        matches, diffs = self.sim_mat(labels)
        return self.selctor(matches, diffs)