經典語義分割(一)利用pytorch復現全卷積神經網絡FCN
這里選擇B站up主[霹靂吧啦Wz]根據pytorch官方torchvision模塊中實現的FCN源碼。
Github連接:FCN源碼
1 FCN模型搭建
1.1 FCN網絡圖
- pytorch官方實現的FCN網絡圖,如下所示。
1.2 backbone
- FCN原文中的backbone是VGG,這里pytorch官方采用了resnet作為FCN的backbone。
- ResNet的前兩層跟GoogLeNet中的?樣:
- 在輸出通道數為64、步幅為2的7 × 7卷積層后,接步幅為2的3 × 3的最大匯聚層。
- 不同之處在于ResNet每個卷積層后增加了批量規范化層。
- GoogLeNet在后面接了4個由Inception塊組成的模塊。ResNet后接4個由殘差塊。
- ResNet則使用4個由殘差塊組成的模塊,每個模塊使用若干個同樣輸出通道數的殘差塊。
- 第1個模塊(layer1)由于之前已經使用了步幅為2的最大匯聚層,所以無須減小高和寬。
- 原生的ResNet在之后的每個模塊(layer2、layer3、layer4)在第?個殘差塊里將上一個模塊的通道數翻倍,并將高和寬減半。
- 不過,在這里和原生的ResNet不同的是,layer3和layer4使用了空洞卷積,并且高寬不減半。
- ResNet的前兩層跟GoogLeNet中的?樣:
# /fcn/src/backbone.py
import torch
import torch.nn as nn
from torchinfo import summarydef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class Bottleneck(nn.Module):# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)# while original implementation places the stride at the first 1x1 convolution(self.conv1)# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.# This variant is also known as ResNet V1.5 and improves accuracy according to# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groups# Both self.conv2 and self.downsample layers downsample the input when stride != 1self.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, replace_stride_with_dilation=None,norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.inplanes = 64self.dilation = 1if replace_stride_with_dilation is None:# each element in the tuple indicates if we should replace# the 2x2 stride with a dilated convolution insteadreplace_stride_with_dilation = [False, False, False]if len(replace_stride_with_dilation) != 3:raise ValueError("replace_stride_with_dilation should be None ""or a 3-element tuple, got {}".format(replace_stride_with_dilation))self.groups = groupsself.base_width = width_per_group'''1、ResNet的前兩層ResNet的前兩層跟GoogLeNet中的?樣:在輸出通道數為64、步幅為2的7 × 7卷積層后,接步幅為2的3 × 3的最?匯聚層。不同之處在于ResNet每個卷積層后增加了批量規范化層。'''self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)'''2、ResNet后接4個由殘差塊GoogLeNet在后?接了4個由Inception塊組成的模塊。ResNet則使?4個由殘差塊組成的模塊,每個模塊使?若?個同樣輸出通道數的殘差塊。第?個模塊(layer1)由于之前已經使?了步幅為2的最?匯聚層,所以?須減??和寬。之后的每個模塊(layer2、layer3、layer4)在第?個殘差塊?將上?個模塊的通道數翻倍,并將?和寬減半。不過,在這里和原生的ResNet不同的是,layer3和layer4使用了空洞卷積,并且高寬不減半。'''self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:# layer3和layer4使用了空洞卷積,高寬不減半,因此設置stride = 1self.dilation *= stridestride = 1# layer2、layer3和layer4的stride=2,滿足# layer1的stride=1,但是inplanes(64) != planes * block.expansion(64×4),因此也滿足if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)# 對于每個layer,只有第1個Bottleneck需要downsamplelayers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansion# 對于每個layer,從第2個Bottleneck開始,就不需要downsamplefor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, dilation=self.dilation,norm_layer=norm_layer))return nn.Sequential(*layers)def _forward_impl(self, x):# See note [TorchScript super()]x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef forward(self, x):return self._forward_impl(x)def _resnet(block, layers, **kwargs):model = ResNet(block, layers, **kwargs)return modeldef resnet50(**kwargs):r"""ResNet-50 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)def resnet101(**kwargs):r"""ResNet-101 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)if __name__ == '__main__':net = resnet50(replace_stride_with_dilation=[False, True, True])print(net)# pip install torchinfo# 可以看到網絡每一層的輸出shape以及網絡參數信息summary(net, input_size=(1, 3, 480, 480))
1.3 FCN Head
- 經過backbone后,再通過FCN Head模塊。
- 通過3×3卷積層縮小通道為原來的1/4【2048-512】,再通過一個dropout和一個1×1卷積層
- 這里1×1卷積層調整特征層的channel為分割類別中的類別個數。
- layer3中引出的一條FCN Head輔助分類器,是為了防止誤差梯度無法傳遞到網絡淺層。
- 訓練的時候是可以使用輔助分類器件的。
- 最后去預測或者部署到正式環境的時候只用主干的output,不用aux output。
- 最后經過雙線性插值還原特征圖大小到原圖。
# /fcn/src/fcn_model.py
from collections import OrderedDictfrom typing import Dictimport torch
from torch import nn, Tensor
from torch.nn import functional as F
try:from .backbone import resnet50, resnet101
except:from backbone import resnet50, resnet101class IntermediateLayerGetter(nn.ModuleDict):_version = 2__annotations__ = {"return_layers": Dict[str, str],}def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:if not set(return_layers).issubset([name for name, _ in model.named_children()]):raise ValueError("return_layers are not present in model")orig_return_layers = return_layersreturn_layers = {str(k): str(v) for k, v in return_layers.items()}# 重新構建backbone,將沒有使用到的模塊全部刪掉layers = OrderedDict()for name, module in model.named_children():layers[name] = moduleif name in return_layers:del return_layers[name]if not return_layers:breaksuper(IntermediateLayerGetter, self).__init__(layers)self.return_layers = orig_return_layersdef forward(self, x: Tensor) -> Dict[str, Tensor]:out = OrderedDict()for name, module in self.items():x = module(x)# self.return_layers = {'layer4': 'out', 'layer3': 'aux'}if name in self.return_layers:out_name = self.return_layers[name]out[out_name] = xreturn outclass FCN(nn.Module):__constants__ = ['aux_classifier']def __init__(self, backbone, classifier, aux_classifier=None):super(FCN, self).__init__()self.backbone = backboneself.classifier = classifierself.aux_classifier = aux_classifierdef forward(self, x: Tensor) -> Dict[str, Tensor]:input_shape = x.shape[-2:]# contract: features is a dict of tensorsfeatures = self.backbone(x)result = OrderedDict()x = features["out"]x = self.classifier(x)# 原論文中雖然使用的是ConvTranspose2d,但權重是凍結的,所以就是一個bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["out"] = x# FCN Head輔助分類器,是為了防止誤差梯度無法傳遞到網絡淺層if self.aux_classifier is not None:x = features["aux"]x = self.aux_classifier(x)# 原論文中雖然使用的是ConvTranspose2d,但權重是凍結的,所以就是一個bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["aux"] = xreturn resultclass FCNHead(nn.Sequential):def __init__(self, in_channels, channels):# 通過3×3卷積層縮小通道為原來的1/4【2048-512】,再通過一個dropout和一個1×1卷積層inter_channels = in_channels // 4layers = [nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),nn.BatchNorm2d(inter_channels),nn.ReLU(),nn.Dropout(0.1),nn.Conv2d(inter_channels, channels, 1) # 這里1×1卷積層調整特征層的channel為分割類別中的類別個數]super(FCNHead, self).__init__(*layers)def fcn_resnet50(aux, num_classes=21, pretrain_backbone=False):# 'resnet50_imagenet': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'# 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth'backbone = resnet50(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 載入resnet50 backbone預訓練權重backbone.load_state_dict(torch.load("resnet50.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'# backbone經過前向傳播的結果為OrderedDict()backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modeldef fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):# 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'# 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'backbone = resnet101(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 載入resnet101 backbone預訓練權重backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modelif __name__ == '__main__':model = fcn_resnet50(aux=True, num_classes=21)print(model)x = torch.randn(size=(1, 3, 480, 480))print(model(x)['out'].shape)print(model(x)['aux'].shape)
2 損失函數的計算
2.1 VOC的標注詳解
-
這張圖片大致可以分為四部分,一部分是黑色背景,一部分是粉紅色的人,一部分是大紅色的飛機,還有一部分是白色的神秘物體。
-
圖片的背景,它是黑色的,背景類別為0,因此在調色板中0所對應的RGB值為[0,0,0],為黑色。
-
pascal_voc_classes.json中
"person": 15
,可知人用數字15表示,而在palette.json中,"15": [192, 128, 128]
可知15對應的RGB為粉紅色,因此粉紅色的是人。 -
同理,可知飛機
"aeroplane": 1
在調色板中對應的顏色為大紅色。 -
這個白色的神秘物體其實也是一個小飛機,但很難分辨,故標注時用白色像素給隱藏起來了,最后白色對應的像素也不會參與損失計算。如果你足夠細心的話,你會發現在人和飛機的邊緣其實都是存在一圈白色的像素的,這是為了更好的區分不同類別對應的像素。同樣,這里的白色也不會參與損失計算。
-
-
我們可以用程序來看看標注圖像中是否有白色像素。
from PIL import Image import numpy as np img = Image.open('D:\\VOCdevkit\\VOC2007\\SegmentationClass\\2007_000032.png') img_np = np.array(img)
- 可以看到地下的像素是1,表示飛機(大紅色),上面的像素為0,表示背景(黑色),中間的像素為255,這就對應著飛機周圍的白色像素。
- 我們可以看一下255對應的RGB值, [224,224,192]表示的RGB顏色為白色。
- 這里的255需要注意,后面計算損失時白色部分不計算正是通過忽略這個值實現的。
2.2 交叉熵損失cross_entropy
l o s s ( x , c l a s s ) = ? l o g ( e x [ c l a s s ] ∑ j e x [ j ] ) = ? x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) 舉個例子:假設輸入 x = [ 0.1 , 0.2 , 0.3 ] ,標簽 c l a s s = 1 l o s s ( x , c l a s s ) = ? x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) = ? 0.2 + l o g ( e x [ 0 ] + e x [ 1 ] + e x [ 2 ] ) = ? 0.2 + l o g ( e 0.1 + e 0.2 + e 0.3 ) loss(x,class)=-log(\frac{e^{x[class]}}{\sum\limits_{j} e^{x[j]}})=-x[class]+log(\sum\limits_{j} e^{x[j]})\\ 舉個例子:假設輸入x=[0.1,0.2,0.3],標簽class=1 \\ loss(x,class)=-x[class]+log(\sum\limits_{j} e^{x[j]})=-0.2 +log( e^{x[0]} + e^{x[1]} + e^{x[2]}) \\ = -0.2 +log( e^{0.1} + e^{0.2} + e^{0.3}) loss(x,class)=?log(j∑?ex[j]ex[class]?)=?x[class]+log(j∑?ex[j])舉個例子:假設輸入x=[0.1,0.2,0.3],標簽class=1loss(x,class)=?x[class]+log(j∑?ex[j])=?0.2+log(ex[0]+ex[1]+ex[2])=?0.2+log(e0.1+e0.2+e0.3)
我們可以用程序進行驗證:
import torch
import numpy as np
import math# 官方實現
input = torch.tensor([[0.1, 0.2, 0.3],[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target)
print('官方計算 loss = ', loss.numpy())# 自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res2 = -0.3 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 + res2) / 3
print('自己計算 loss = %.7f ' % res)
# 僅精度有差別,所以這證明了我們的計算方式是沒有錯的。
官方計算 loss = 1.1019429
自己計算 loss = 1.1019428
FCN在計算損失是會忽略白色的像素,其就對應著標簽中的255。
忽略白色像素的損失其實很簡單,只要在函數調用時傳入ignore_index并指定對應的值即可。
如對本例來說,現我打算忽略target中標簽為2的數據,即不讓其參與損失計算,我們來看看如何使用cross_entropy函數來實現。
import torch
import numpy as np
import math# 官方實現
input = torch.tensor([[0.1, 0.2, 0.3],[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target, ignore_index=2)
print('官方計算 loss = ', loss.numpy())# 自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 ) / 2
print('自己計算 loss = %.6f ' % res)
官方計算 loss = 1.151943
自己計算 loss = 1.151943
2.3 FCN中損失計算過程
-
程序中輸入cross_entropy函數中的x通常是4維的tensor,即[N,C,H,W],這時候訓練損失是怎么計算的呢?我們以x的維度為[1,2,2,2]為例講解
-
我們手動計算時候,會將數據按通道方向展開,然后分別計算cross_entropy,最后求平均(如下圖所示)
import torch
import numpy as np
import math# 1、官方計算
input = torch.tensor([[[[0.1, 0.2],[0.3, 0.4]],[[0.5, 0.6],[0.7, 0.8]]]]) #shape(1 2 2 2 )target = torch.tensor([[[0, 1],[0, 1]]])loss = torch.nn.functional.cross_entropy(input, target)
print('官方計算 loss = ', loss.numpy())# 2、自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.5))
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res2 = -0.3 + np.log(math.exp(0.3) + math.exp(0.7))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = (res0 + res1 + res2 + res3)/4
print('自己計算 loss = %.8f ' % res)
官方計算 loss = 0.71301526
自己計算 loss = 0.71301525
- 如果,我們此時忽略target=0
import torch
import numpy as np
import math# 1、官方計算
input = torch.tensor([[[[0.1, 0.2],[0.3, 0.4]],[[0.5, 0.6],[0.7, 0.8]]]]) #shape(1 2 2 2 )target = torch.tensor([[[0, 1],[0, 1]]])loss = torch.nn.functional.cross_entropy(input, target , ignore_index=0)
print('官方計算 loss = ', loss.numpy())# 2、自己計算
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = ( res1 + res3)/2
print('自己計算 loss = %.7f ' % res)
官方計算 loss = 0.5130153
自己計算 loss = 0.5130153
2.4 FCN中損失代碼
- 通過上面講解,我們就很容易理解FCN的損失計算了。這里忽略了255像素,不讓其參與到損失的計算中。
- 如果輔助分類器存在,給予較小的損失權重。
# fcn/train_utils/train_and_eval.py
def criterion(inputs, target):losses = {}for name, x in inputs.items():# 忽略target中值為255的像素,255的像素是目標邊緣或者padding填充losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)if len(losses) == 1:return losses['out']return losses['out'] + 0.5 * losses['aux']
3 VOC數據集的讀取及數據預處理
我們自定義VOCSegmentation類,繼承pytorch提供的torch.utils.data.Dataset類,主要實現__getitem__
函數。再利用pytorch提供的Dataloader,就可以通過調用__getitem__
函數來批量讀取VOC數據集圖片和標簽了。
VOCSegmentation類的初始化部分,如下方的代碼所示:
# fcn/my_dataset.py
class VOCSegmentation(data.Dataset):def __init__(self, voc_root, year="2007", transforms=None, txt_name: str = "train.txt"):super(VOCSegmentation, self).__init__()assert year in ["2007", "2012"], "year must be in ['2007', '2012']"root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")assert os.path.exists(root), "path '{}' does not exist.".format(root)image_dir = os.path.join(root, 'JPEGImages')mask_dir = os.path.join(root, 'SegmentationClass')txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)with open(os.path.join(txt_path), "r") as f:file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]assert (len(self.images) == len(self.masks))self.transforms = transforms
-
首先我們需要獲取輸入(image)和標簽(target)的路徑。
-
voc_root是我們應該傳入VOCdevkit所在的文件夾。
-
最終self.image和self.masks里存儲的就是我們輸入和標簽的路徑了。
-
-
接著我們對輸入圖片和標簽進行transformer預處理(代碼如下)
- 訓練集采用了隨機縮放、水平翻轉、隨機裁剪、toTensor和Normalize。
- 驗證集僅使用了隨機縮放、toTensor和Normalize。
- crop_size設置為480,即訓練圖片都會裁剪到480*480大小,而驗證時沒有使用隨機裁剪方法,因此
驗證集的圖片尺寸是不一致的, 需要進行進一步的處理
。
# fcn/train.py
class SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)
-
預處理代碼完成后,就可以實現
__getitem__
以及__len__
方法。# fcn/my_dataset.pydef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is the image segmentation."""img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.images)@staticmethoddef collate_fn(batch):images, targets = list(zip(*batch))batched_imgs = cat_list(images, fill_value=0)batched_targets = cat_list(targets, fill_value=255)return batched_imgs, batched_targets
-
在VOCSegmentation類中,還實現了DataLoader中需要的collate_fn。
- 在collate_fn中,接受一個List類型數據,其中每個元素是一個Tuple2類型,包括了image和target。
- 在collate_fn中調用cat_list方法,對驗證集圖片尺寸是不一致進行處理。
# fcn/my_dataset.py def cat_list(images, fill_value=0):# 計算該batch數據中,channel, h, w的最大值max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))batch_shape = (len(images),) + max_sizebatched_imgs = images[0].new(*batch_shape).fill_(fill_value)for img, pad_img in zip(images, batched_imgs):pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)return batched_imgs
-
最后就可以調用Dataloader批量獲取數據了。
# fcn/train.py # VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2007",transforms=get_transform(train=True),txt_name="train.txt")# VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2007",transforms=get_transform(train=False),txt_name="val.txt")num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True,pin_memory=True,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)
4 模型訓練及測試
4.1 模型訓練
-
代碼在 fcn/train.py 中。
-
先利用Dataset和DataLoader批量獲取數據。
-
然后創建FCN網絡模型,可以加載在COCO數據集上的預訓練權重。
def create_model(aux, num_classes, pretrain=True):model = fcn_resnet50(aux=aux, num_classes=num_classes)if pretrain:weights_dict = torch.load("./fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的預訓練權重是21類(包括背景)# 如果訓練自己的數據集,將和類別相關的權重刪除,防止權重shape不一致報錯for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return model
-
設置SGD優化器
# 設置優化器 optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
設置學習率更新策略。
# 創建學習率更新策略,這里是每個step更新一次(不是每個epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)
# fcn/train_utils/train_and_eval.py def create_lr_scheduler(optimizer,num_step: int,epochs: int,warmup=True,warmup_epochs=1,warmup_factor=1e-3):assert num_step > 0 and epochs > 0if warmup is False:warmup_epochs = 0def f(x):"""根據step數返回一個學習率倍率因子,注意在訓練開始之前,pytorch會提前調用一次lr_scheduler.step()方法"""if warmup is True and x <= (warmup_epochs * num_step):alpha = float(x) / (warmup_epochs * num_step)# warmup過程中lr倍率因子從warmup_factor -> 1return warmup_factor * (1 - alpha) + alphaelse:# warmup后lr倍率因子從1 -> 0# 參考deeplab_v2: Learning rate policyreturn (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
-
訓練代碼如下,可以代碼調試。
for epoch in range(args.start_epoch, args.epochs):mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)# 測試confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# write into txtwith open(results_file, "a") as f:# 記錄每個epoch對應的train_loss、lr以及驗證集各指標train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")save_file = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"lr_scheduler": lr_scheduler.state_dict(),"epoch": epoch,"args": args}if args.amp:save_file["scaler"] = scaler.state_dict()torch.save(save_file, "save_weights/model_{}.pth".format(epoch))
4.2 模型測試
在 train_and_val.py 文件中的 evaluate 函數代碼如下:
- 創建 ConfusionMatrix 混淆矩陣
- 使用 for 循環遍歷 data_loader 得到 image 和 target 信息,并將其指給對應的設備當中
- 再將 image 圖像輸入到 model 模型中進行預測,得到 output 輸出(只使用主分支上的輸出)
- 調用 update 方法時,在計算每一批數據預測結果與真實結果對比的過程中,將 target 和 output.argmax(1) 進行 flatten 處理
- output.argmax(1) 中的 1 是指在 channel 維度,而 argmax 方法用于 將每個像素預測值最大的類別作為其預測類別(如下圖所示) 。
# fcn/train_utils/train_and_eval.py
def evaluate(model, data_loader, device, num_classes):model.eval()confmat = utils.ConfusionMatrix(num_classes)metric_logger = utils.MetricLogger(delimiter=" ")header = 'Test:'with torch.no_grad():for image, target in metric_logger.log_every(data_loader, 100, header):image, target = image.to(device), target.to(device)output = model(image)output = output['out']confmat.update(target.flatten(), output.argmax(1).flatten())confmat.reduce_from_all_processes()return confmat
ConfusionMatrix 類代碼如下:
-
ConfusionMatrix 類中的
update 函數
傳入了真實標簽 a 和預測標簽 b 等參數,代碼的具體解析:- 這里的 num_classes 是指包含了背景的類別個數。
- 如果 self.mat 是 None ,就使用 torch.zeros 創建一個全零矩陣作為混淆矩陣,大小為 n x n ,用于記錄真實標簽和預測標簽之間的關系。
- 通過檢查真實標簽 a 中的元素是否屬于有效類別范圍 [ 0 , N ) 來尋找屬于目標類別的像素索引。
- 根據像素的真實類別 a [ k ] 和預測類別 b [ k ] 計算類別索引 inds ,用于統計真實類別為 a [ k ] 被預測成 b [ k ] 的像素個數。
- 使用 torch.bincount 統計類別索引 inds 在 [ 0 , n**2 ) 內的出現次數,并將結果重塑成 ( n , n ) 的矩陣形狀,統計數據累加到混淆矩陣中。
-
ConfusionMatrix 類中的
compute 函數
計算常見的語義分割評價指標。- 語義分割評價指標主要包括 Pixel Accuracy ( Global Accuracy )、mean Accuracy、mean IoU 等:
- Pixel Accuracy = 類別預測正確的像素個數總和 ÷ 圖片的總像素個數
- mean Accuracy = 對每個類別的 Accuracy 求平均值
- mean IoU = 對每個類別的 IoU 求平均值
- 語義分割評價指標主要包括 Pixel Accuracy ( Global Accuracy )、mean Accuracy、mean IoU 等:
class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, a, b):n = self.num_classesif self.mat is None:# 創建混淆矩陣self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)with torch.no_grad():# 尋找GT中為目標的像素索引(例如:255就不是目標的像素索引)k = (a >= 0) & (a < n)# 統計像素真實類別a[k]被預測成類別b[k]的個數(這里的做法很巧妙)inds = n * a[k].to(torch.int64) + b[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h = self.mat.float()# 計算全局預測準確率(混淆矩陣的對角線為預測正確的個數)acc_global = torch.diag(h).sum() / h.sum()# 計算每個類別的準確率acc = torch.diag(h) / h.sum(1)# 計算每個類別預測與真實目標的iouiu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)
4.3 模型預測
- 模型輸出為1×c×h×w,因為這是預測,故batch=1,這里使用的是VOC數據,故這里的c=num_class=21。【包含一個背景類】
- 首先我們會取輸出中每個像素在21個通道中的最大值,如第一個像素在21個通道的最大值在通道0上取得。這個通道對應的索引是0,在VOC中是背景類,故這個像素所屬類別為背景。其它像素同理。
# fcn/predict.pymodel.eval() # 進入驗證模式with torch.no_grad():# init modelimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()output = model(img.to(device))t_end = time_synchronized()print("inference time: {}".format(t_end - t_start))# 在輸出中的chanel維度求最大值對應的類別索引prediction = output['out'].argmax(1).squeeze(0)prediction = prediction.to("cpu").numpy().astype(np.uint8)mask = Image.fromarray(prediction)mask.putpalette(pallette)mask.save("test_result.png")