論文鏈接:https://arxiv.org/pdf/2304.03977.pdf
代碼:https://github.com/tsb0601/EMP-SSL
其他學習鏈接:突破自監督學習效率極限!馬毅、LeCun聯合發布EMP-SSL:無需花哨trick,30個epoch即可實現SOTA
主要思想
如圖,一張圖片裁剪成不同的 patch,對不同的 patch 做數據增強,分別輸入 encoder,得到多個 embedding,對它們求均值,得到??作為這張圖片的 embedding。最后,拉近每個 patch 的 embedding 和圖片的 embedding(
)之間的余弦距離;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 對所有輸入都輸出相同的 embedding)
Total Coding Rate(TCR)
公式如下:
其中,det 表示求矩陣的行列式,d 是 feature vector 的 dimension,b 是 batch size
查了查該公式的含義:expand all features of Z as large as possible,即盡可能拉遠矩陣中特征之間的距離。
源自 PPT 第 24 頁:
https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf
至于為什么最大化該公式的值就可以拉遠矩陣中特征之間的距離,這背后的數學原理真難啃啊 /(ㄒoㄒ)/~~
核心代碼解讀
數據處理
https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27
class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(), normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x
由此看出返回的 數據 為:長度為 num_patches 個 tensor 的列表。其中,每個?tensor 的 shape 為 (B, C, H, W)。
主函數
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63
for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)
這里要稍微注意一下幾個變量的 shape:
- data 被 cat 完后:(num_patches * B,C,H,W)
- z_proj:(num_patches * B,C)
- z_list:(num_patches,B,C)
- z_avg:(B,C)
其中,chunk_avg 就是對來自同一張圖片的不同?patch 的 embedding 求均值():
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67
def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)
loss
contractive_loss 就是計算每個 patch 的 embedding 和均值()的余弦距離:
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76
class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out
TCR loss:最大化矩陣之間特征的距離,即拉遠負樣本(不是來自同一個樣本的 patches)之間的距離
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96
def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss
需要注意:函數輸入的 z 是?z_proj,形狀為(num_patches * B,C)。
所以,函數內部?z_list?的形狀為(num_patches,B,C),即將數據分為了?num_patches 個組,每個組包含了來自不同圖片里 patch 的 embedding。再分別對每個組求 TCR loss,最大化組內(不同圖片的 patch)特征的距離。
所以,公式中的??指的是一組來自不同圖片里 patch 的 embedding,形狀為(B,C)。
每個組內求 TCR loss 的代碼按照公式計算,如下:?
https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76
class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)