本文是對DISTS圖像質量評價指標的代碼解讀,原文解讀請看DISTS文章講解。
本文的代碼來源于IQA-Pytorch工程。
1、原文概要
以前的一些IQA方法對于捕捉紋理上的感知一致性有所欠缺,魯棒性不足。基于此,作者開發了一個能夠在圖像結構和圖像紋理上都具有與人類相同感知判斷的指標,在此之上,還希望紋理能夠resample(不需要像素級對齊)之后也是一樣的,另外區分開退化(JPEG,JPEG會損失紋理)。實現該指標可以分為4個步驟:
- 對圖像進行一個初始的變換,從像素空間變換到特征空間。
- 對特征提取所謂紋理的表示,對特征提取所謂結構的表示。
- 利用紋理和結構的表示,加入一些可學習的權重綜合計算一個評價指標。
- 利用這個評價指標,進一步優化權重得到紋理區域resample不敏感的指標,且能夠有結構和紋理上做感知相似度的模型。
實現后的指標作為優化指標對比其他IQA指標有明顯優勢,如下圖所示。
2、代碼結構
代碼實現位于pyiqa/archs/dists_arch.py中:
3 、核心代碼模塊
L2pooling
類
這個類實現了我們前面提到的預處理部分替換max-pool的操作。
class L2pooling(nn.Module):def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):super(L2pooling, self).__init__()self.padding = (filter_size - 2) // 2self.stride = strideself.channels = channelsa = np.hanning(filter_size)[1:-1]g = torch.Tensor(a[:, None] * a[None, :])g = g / torch.sum(g)self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))def forward(self, input):input = input**2out = F.conv2d(input,self.filter,stride=self.stride,padding=self.padding,groups=input.shape[1],)return (out + 1e-12).sqrt()
這里可以看到前向的過程中作者先是進行了一個平方,然后使用了一個self.filter的濾波器,kernel_size為3的hanning窗,stride=2,且是一個深度可分離的卷積,groups與輸入通道一致,這代替max-pool完成了一次抗混疊的下采樣,最后進行一個sqrt,這與講解中展示的公式一致,如下所示:
P(x)=g?(x?x)P(x)=\sqrt{g*(x*x)}P(x)=g?(x?x)?這個ggg在初始化時被復制了self.channels次,實際它一個通道的數值,讀者可以打印如下所示:
[0.06250.1250.06250.1250.250.1250.06250.1250.0625]\begin{bmatrix} 0.0625 & 0.125 & 0.0625 \\ 0.125 & 0.25 & 0.125 \\ 0.0625 & 0.125 & 0.0625 \end{bmatrix} ?0.06250.1250.0625?0.1250.250.125?0.06250.1250.0625??一個典型的低通濾波器,做了一個空間上根據距離的平均。
DISTS
類
存放著跟實際計算指標相關的代碼。
@ARCH_REGISTRY.register()
class DISTS(torch.nn.Module):r"""DISTS model.Args:pretrained_model_path (String): Pretrained model path."""def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs):"""Refer to official code https://github.com/dingkeyan93/DISTS"""super(DISTS, self).__init__()vgg_pretrained_features = models.vgg16(weights='IMAGENET1K_V1').featuresself.stage1 = torch.nn.Sequential()self.stage2 = torch.nn.Sequential()self.stage3 = torch.nn.Sequential()self.stage4 = torch.nn.Sequential()self.stage5 = torch.nn.Sequential()for x in range(0, 4):self.stage1.add_module(str(x), vgg_pretrained_features[x])self.stage2.add_module(str(4), L2pooling(channels=64))for x in range(5, 9):self.stage2.add_module(str(x), vgg_pretrained_features[x])self.stage3.add_module(str(9), L2pooling(channels=128))for x in range(10, 16):self.stage3.add_module(str(x), vgg_pretrained_features[x])self.stage4.add_module(str(16), L2pooling(channels=256))for x in range(17, 23):self.stage4.add_module(str(x), vgg_pretrained_features[x])self.stage5.add_module(str(23), L2pooling(channels=512))for x in range(24, 30):self.stage5.add_module(str(x), vgg_pretrained_features[x])for param in self.parameters():param.requires_grad = Falseself.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))self.chns = [3, 64, 128, 256, 512, 512]self.register_parameter('alpha', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.register_parameter('beta', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.alpha.data.normal_(0.1, 0.01)self.beta.data.normal_(0.1, 0.01)if pretrained_model_path is not None:load_pretrained_network(self, pretrained_model_path, False)elif pretrained:load_pretrained_network(self, default_model_urls['url'], False)def forward_once(self, x):h = (x - self.mean) / self.stdh = self.stage1(h)h_relu1_2 = hh = self.stage2(h)h_relu2_2 = hh = self.stage3(h)h_relu3_3 = hh = self.stage4(h)h_relu4_3 = hh = self.stage5(h)h_relu5_3 = hreturn [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]def forward(self, x, y):r"""Compute IQA using DISTS model.Args:- x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.- y: An reference tensor with (N, C, H, W) shape. RGB channel order for colour images.Returns:Value of DISTS model."""feats0 = self.forward_once(x)feats1 = self.forward_once(y)dist1 = 0dist2 = 0c1 = 1e-6c2 = 1e-6w_sum = self.alpha.sum() + self.beta.sum()alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)beta = torch.split(self.beta / w_sum, self.chns, dim=1)for k in range(len(self.chns)):x_mean = feats0[k].mean([2, 3], keepdim=True)y_mean = feats1[k].mean([2, 3], keepdim=True)S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_meanS2 = (2 * xy_cov + c2) / (x_var + y_var + c2)dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)score = 1 - (dist1 + dist2)return score.squeeze(-1).squeeze(-1)
3個重點如下:
- 初始化中首先會插入前面講到的L2_Pooling,來替換原始的max-pool,其他的就是初始化必要的標準化變量和用于各層結構和紋理的加權系數α\alphaα和β\betaβ,最后導入預訓練的網絡即可。
- 前向中調用的forward_once,可以看到總共有6個輸出,第一個輸出是輸入x,即我們講解中提到的identity的變換,其他5層是事先定義好的輸出位置。
- dists的計算:首先根據權重的大小對alpha和beta進行歸一化,隨后分層計算我們前面定義好的紋理特征和結構特征的相關性公式,針對于紋理的部分代碼中是S1,可以看到S1是利用了特征的在空間上的均值計算的參考圖像和待評估圖像的相關系數,然后利用alpha對計算好的S1進行加權,得到紋理上相似度dist1;針對于結構的部分代碼中是S2,S2是利用了參考圖像和待評估圖像兩個特征的協方差和方差,由于是全局的窗口所以在計算后會求取空間上的一個均值,這樣得到了結構上的相似度dist2。最后結合dist1和dist2得到最終的score。dists計算的公式如下,可以對照著公式來查看:
l(x~j(i),y~j(i))=2μx~j(i)μy~j(i)+c1(μx~j(i))2+(μy~j(i))2+c1l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\mu_{\tilde{x}_j}^{(i)}\mu_{\tilde{y}_j}^{(i)} + c_1}{(\mu_{\tilde{x}_j}^{(i)})^2 + (\mu_{\tilde{y}_j}^{(i)})^2 + c_1}l(x~j(i)?,y~?j(i)?)=(μx~j?(i)?)2+(μy~?j?(i)?)2+c1?2μx~j?(i)?μy~?j?(i)?+c1?? s(x~j(i),y~j(i))=2σx~jy~j(i)+c2(σx~j(i))2+(σy~j(i))2+c2,s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\sigma_{\tilde{x}_j\tilde{y}_j}^{(i)} + c_2}{(\sigma_{\tilde{x}_j}^{(i)})^2 + (\sigma_{\tilde{y}_j}^{(i)})^2 + c_2},s(x~j(i)?,y~?j(i)?)=(σx~j?(i)?)2+(σy~?j?(i)?)2+c2?2σx~j?y~?j?(i)?+c2??, D(x,y;α,β)=1?∑i=0m∑j=1ni(αijl(x~j(i),y~j(i))+βijs(x~j(i),y~j(i)))D(x, y; \alpha, \beta) = 1 - \sum_{i = 0}^{m} \sum_{j = 1}^{n_i} \left( \alpha_{ij} l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) + \beta_{ij} s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) \right)D(x,y;α,β)=1?i=0∑m?j=1∑ni??(αij?l(x~j(i)?,y~?j(i)?)+βij?s(x~j(i)?,y~?j(i)?))其中,lll和sss分別代表紋理和結構。
3、總結
代碼實現核心的部分講解完畢,DISTS作為一個可以同時捕獲結構和紋理相似度的全參考IQA指標,在很多比賽和論文的引用中都可以見到它的身影,實用性是毋庸置疑的。
大家有涉及到數據集篩選、紋理分類、紋理搜索類的任務可以嘗試使用DISTS指標,或者是在算法評估中利用它來做一個方面的對比評估。
感謝閱讀,歡迎留言或私信,一起探討和交流。
如果對你有幫助的話,也希望可以給博主點一個關注,感謝。