經典語義分割(一)利用pytorch復現全卷積神經網絡FCN

經典語義分割(一)利用pytorch復現全卷積神經網絡FCN

這里選擇B站up主[霹靂吧啦Wz]根據pytorch官方torchvision模塊中實現的FCN源碼。

Github連接:FCN源碼

1 FCN模型搭建

1.1 FCN網絡圖

  • pytorch官方實現的FCN網絡圖,如下所示。
  • 在這里插入圖片描述

1.2 backbone

  • FCN原文中的backbone是VGG,這里pytorch官方采用了resnet作為FCN的backbone。
    • ResNet的前兩層跟GoogLeNet中的?樣:
      • 在輸出通道數為64、步幅為2的7 × 7卷積層后,接步幅為2的3 × 3的最大匯聚層。
      • 不同之處在于ResNet每個卷積層后增加了批量規范化層。
    • GoogLeNet在后面接了4個由Inception塊組成的模塊。ResNet后接4個由殘差塊。
      • ResNet則使用4個由殘差塊組成的模塊,每個模塊使用若干個同樣輸出通道數的殘差塊。
      • 第1個模塊(layer1)由于之前已經使用了步幅為2的最大匯聚層,所以無須減小高和寬。
      • 原生的ResNet在之后的每個模塊(layer2、layer3、layer4)在第?個殘差塊里將上一個模塊的通道數翻倍,并將高和寬減半。
      • 不過,在這里和原生的ResNet不同的是,layer3和layer4使用了空洞卷積,并且高寬不減半。
# /fcn/src/backbone.py
import torch
import torch.nn as nn
from torchinfo import summarydef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class Bottleneck(nn.Module):# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)# while original implementation places the stride at the first 1x1 convolution(self.conv1)# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.# This variant is also known as ResNet V1.5 and improves accuracy according to# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groups# Both self.conv2 and self.downsample layers downsample the input when stride != 1self.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, replace_stride_with_dilation=None,norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.inplanes = 64self.dilation = 1if replace_stride_with_dilation is None:# each element in the tuple indicates if we should replace# the 2x2 stride with a dilated convolution insteadreplace_stride_with_dilation = [False, False, False]if len(replace_stride_with_dilation) != 3:raise ValueError("replace_stride_with_dilation should be None ""or a 3-element tuple, got {}".format(replace_stride_with_dilation))self.groups = groupsself.base_width = width_per_group'''1、ResNet的前兩層ResNet的前兩層跟GoogLeNet中的?樣:在輸出通道數為64、步幅為2的7 × 7卷積層后,接步幅為2的3 × 3的最?匯聚層。不同之處在于ResNet每個卷積層后增加了批量規范化層。'''self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)'''2、ResNet后接4個由殘差塊GoogLeNet在后?接了4個由Inception塊組成的模塊。ResNet則使?4個由殘差塊組成的模塊,每個模塊使?若?個同樣輸出通道數的殘差塊。第?個模塊(layer1)由于之前已經使?了步幅為2的最?匯聚層,所以?須減??和寬。之后的每個模塊(layer2、layer3、layer4)在第?個殘差塊?將上?個模塊的通道數翻倍,并將?和寬減半。不過,在這里和原生的ResNet不同的是,layer3和layer4使用了空洞卷積,并且高寬不減半。'''self.layer1 = self._make_layer(block, 64,  layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:# layer3和layer4使用了空洞卷積,高寬不減半,因此設置stride = 1self.dilation *= stridestride = 1# layer2、layer3和layer4的stride=2,滿足# layer1的stride=1,但是inplanes(64) != planes * block.expansion(64×4),因此也滿足if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)# 對于每個layer,只有第1個Bottleneck需要downsamplelayers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansion# 對于每個layer,從第2個Bottleneck開始,就不需要downsamplefor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, dilation=self.dilation,norm_layer=norm_layer))return nn.Sequential(*layers)def _forward_impl(self, x):# See note [TorchScript super()]x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef forward(self, x):return self._forward_impl(x)def _resnet(block, layers, **kwargs):model = ResNet(block, layers, **kwargs)return modeldef resnet50(**kwargs):r"""ResNet-50 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)def resnet101(**kwargs):r"""ResNet-101 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)if __name__ == '__main__':net = resnet50(replace_stride_with_dilation=[False, True, True])print(net)# pip install torchinfo# 可以看到網絡每一層的輸出shape以及網絡參數信息summary(net, input_size=(1, 3, 480, 480))

1.3 FCN Head

  • 經過backbone后,再通過FCN Head模塊。
    • 通過3×3卷積層縮小通道為原來的1/4【2048-512】,再通過一個dropout和一個1×1卷積層
    • 這里1×1卷積層調整特征層的channel為分割類別中的類別個數。
    • layer3中引出的一條FCN Head輔助分類器,是為了防止誤差梯度無法傳遞到網絡淺層。
      • 訓練的時候是可以使用輔助分類器件的。
      • 最后去預測或者部署到正式環境的時候只用主干的output,不用aux output。
  • 最后經過雙線性插值還原特征圖大小到原圖。
# /fcn/src/fcn_model.py
from collections import OrderedDictfrom typing import Dictimport torch
from torch import nn, Tensor
from torch.nn import functional as F
try:from .backbone import resnet50, resnet101
except:from backbone import resnet50, resnet101class IntermediateLayerGetter(nn.ModuleDict):_version = 2__annotations__ = {"return_layers": Dict[str, str],}def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:if not set(return_layers).issubset([name for name, _ in model.named_children()]):raise ValueError("return_layers are not present in model")orig_return_layers = return_layersreturn_layers = {str(k): str(v) for k, v in return_layers.items()}# 重新構建backbone,將沒有使用到的模塊全部刪掉layers = OrderedDict()for name, module in model.named_children():layers[name] = moduleif name in return_layers:del return_layers[name]if not return_layers:breaksuper(IntermediateLayerGetter, self).__init__(layers)self.return_layers = orig_return_layersdef forward(self, x: Tensor) -> Dict[str, Tensor]:out = OrderedDict()for name, module in self.items():x = module(x)# self.return_layers = {'layer4': 'out', 'layer3': 'aux'}if name in self.return_layers:out_name = self.return_layers[name]out[out_name] = xreturn outclass FCN(nn.Module):__constants__ = ['aux_classifier']def __init__(self, backbone, classifier, aux_classifier=None):super(FCN, self).__init__()self.backbone = backboneself.classifier = classifierself.aux_classifier = aux_classifierdef forward(self, x: Tensor) -> Dict[str, Tensor]:input_shape = x.shape[-2:]# contract: features is a dict of tensorsfeatures = self.backbone(x)result = OrderedDict()x = features["out"]x = self.classifier(x)# 原論文中雖然使用的是ConvTranspose2d,但權重是凍結的,所以就是一個bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["out"] = x# FCN Head輔助分類器,是為了防止誤差梯度無法傳遞到網絡淺層if self.aux_classifier is not None:x = features["aux"]x = self.aux_classifier(x)# 原論文中雖然使用的是ConvTranspose2d,但權重是凍結的,所以就是一個bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["aux"] = xreturn resultclass FCNHead(nn.Sequential):def __init__(self, in_channels, channels):# 通過3×3卷積層縮小通道為原來的1/4【2048-512】,再通過一個dropout和一個1×1卷積層inter_channels = in_channels // 4layers = [nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),nn.BatchNorm2d(inter_channels),nn.ReLU(),nn.Dropout(0.1),nn.Conv2d(inter_channels, channels, 1) # 這里1×1卷積層調整特征層的channel為分割類別中的類別個數]super(FCNHead, self).__init__(*layers)def fcn_resnet50(aux, num_classes=21, pretrain_backbone=False):# 'resnet50_imagenet': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'# 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth'backbone = resnet50(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 載入resnet50 backbone預訓練權重backbone.load_state_dict(torch.load("resnet50.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'# backbone經過前向傳播的結果為OrderedDict()backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modeldef fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):# 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'# 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'backbone = resnet101(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 載入resnet101 backbone預訓練權重backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modelif __name__ == '__main__':model = fcn_resnet50(aux=True, num_classes=21)print(model)x = torch.randn(size=(1, 3, 480, 480))print(model(x)['out'].shape)print(model(x)['aux'].shape)

2 損失函數的計算

2.1 VOC的標注詳解

在這里插入圖片描述

  • 這張圖片大致可以分為四部分,一部分是黑色背景,一部分是粉紅色的人,一部分是大紅色的飛機,還有一部分是白色的神秘物體。

    • 圖片的背景,它是黑色的,背景類別為0,因此在調色板中0所對應的RGB值為[0,0,0],為黑色。

    • pascal_voc_classes.json中"person": 15,可知人用數字15表示,而在palette.json中,"15": [192, 128, 128]可知15對應的RGB為粉紅色,因此粉紅色的是人。

    • 同理,可知飛機"aeroplane": 1在調色板中對應的顏色為大紅色。

    • 這個白色的神秘物體其實也是一個小飛機,但很難分辨,故標注時用白色像素給隱藏起來了,最后白色對應的像素也不會參與損失計算。如果你足夠細心的話,你會發現在人和飛機的邊緣其實都是存在一圈白色的像素的,這是為了更好的區分不同類別對應的像素。同樣,這里的白色也不會參與損失計算。

  • 我們可以用程序來看看標注圖像中是否有白色像素。

    from PIL import Image
    import numpy as np
    img = Image.open('D:\\VOCdevkit\\VOC2007\\SegmentationClass\\2007_000032.png')
    img_np = np.array(img)
    

    在這里插入圖片描述

    • 可以看到地下的像素是1,表示飛機(大紅色),上面的像素為0,表示背景(黑色),中間的像素為255,這就對應著飛機周圍的白色像素。
    • 我們可以看一下255對應的RGB值, [224,224,192]表示的RGB顏色為白色。
    • 這里的255需要注意,后面計算損失時白色部分不計算正是通過忽略這個值實現的。

2.2 交叉熵損失cross_entropy

l o s s ( x , c l a s s ) = ? l o g ( e x [ c l a s s ] ∑ j e x [ j ] ) = ? x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) 舉個例子:假設輸入 x = [ 0.1 , 0.2 , 0.3 ] ,標簽 c l a s s = 1 l o s s ( x , c l a s s ) = ? x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) = ? 0.2 + l o g ( e x [ 0 ] + e x [ 1 ] + e x [ 2 ] ) = ? 0.2 + l o g ( e 0.1 + e 0.2 + e 0.3 ) loss(x,class)=-log(\frac{e^{x[class]}}{\sum\limits_{j} e^{x[j]}})=-x[class]+log(\sum\limits_{j} e^{x[j]})\\ 舉個例子:假設輸入x=[0.1,0.2,0.3],標簽class=1 \\ loss(x,class)=-x[class]+log(\sum\limits_{j} e^{x[j]})=-0.2 +log( e^{x[0]} + e^{x[1]} + e^{x[2]}) \\ = -0.2 +log( e^{0.1} + e^{0.2} + e^{0.3}) loss(x,class)=?log(j?ex[j]ex[class]?)=?x[class]+log(j?ex[j])舉個例子:假設輸入x=[0.1,0.2,0.3],標簽class=1loss(x,class)=?x[class]+log(j?ex[j])=?0.2+log(ex[0]+ex[1]+ex[2])=?0.2+log(e0.1+e0.2+e0.3)

我們可以用程序進行驗證:

import torch
import numpy as np
import math# 官方實現
input = torch.tensor([[0.1, 0.2, 0.3],[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target)
print('官方計算 loss = ', loss.numpy())# 自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res2 = -0.3 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 + res2) / 3
print('自己計算 loss = %.7f ' % res)
# 僅精度有差別,所以這證明了我們的計算方式是沒有錯的。
官方計算 loss = 1.1019429
自己計算 loss = 1.1019428 

FCN在計算損失是會忽略白色的像素,其就對應著標簽中的255。

忽略白色像素的損失其實很簡單,只要在函數調用時傳入ignore_index并指定對應的值即可。

如對本例來說,現我打算忽略target中標簽為2的數據,即不讓其參與損失計算,我們來看看如何使用cross_entropy函數來實現。

import torch
import numpy as np
import math# 官方實現
input = torch.tensor([[0.1, 0.2, 0.3],[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target,  ignore_index=2)
print('官方計算 loss = ', loss.numpy())# 自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 ) / 2
print('自己計算 loss = %.6f ' % res)
官方計算 loss = 1.151943
自己計算 loss = 1.151943 

2.3 FCN中損失計算過程

  • 程序中輸入cross_entropy函數中的x通常是4維的tensor,即[N,C,H,W],這時候訓練損失是怎么計算的呢?我們以x的維度為[1,2,2,2]為例講解

  • 我們手動計算時候,會將數據按通道方向展開,然后分別計算cross_entropy,最后求平均(如下圖所示)

在這里插入圖片描述

import torch
import numpy as np
import math# 1、官方計算
input = torch.tensor([[[[0.1, 0.2],[0.3, 0.4]],[[0.5, 0.6],[0.7, 0.8]]]])    #shape(1 2 2 2 )target = torch.tensor([[[0, 1],[0, 1]]])loss = torch.nn.functional.cross_entropy(input, target)
print('官方計算 loss = ', loss.numpy())# 2、自己計算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.5))
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res2 = -0.3 + np.log(math.exp(0.3) + math.exp(0.7))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = (res0 + res1 + res2 + res3)/4
print('自己計算 loss = %.8f ' % res)
官方計算 loss = 0.71301526
自己計算 loss = 0.71301525 
  • 如果,我們此時忽略target=0

在這里插入圖片描述

import torch
import numpy as np
import math# 1、官方計算
input = torch.tensor([[[[0.1, 0.2],[0.3, 0.4]],[[0.5, 0.6],[0.7, 0.8]]]])    #shape(1 2 2 2 )target = torch.tensor([[[0, 1],[0, 1]]])loss = torch.nn.functional.cross_entropy(input, target , ignore_index=0)
print('官方計算 loss = ', loss.numpy())# 2、自己計算
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = ( res1  + res3)/2
print('自己計算 loss = %.7f ' % res)
官方計算 loss =  0.5130153
自己計算 loss =  0.5130153 

2.4 FCN中損失代碼

  • 通過上面講解,我們就很容易理解FCN的損失計算了。這里忽略了255像素,不讓其參與到損失的計算中。
  • 如果輔助分類器存在,給予較小的損失權重。
# fcn/train_utils/train_and_eval.py
def criterion(inputs, target):losses = {}for name, x in inputs.items():# 忽略target中值為255的像素,255的像素是目標邊緣或者padding填充losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)if len(losses) == 1:return losses['out']return losses['out'] + 0.5 * losses['aux']

3 VOC數據集的讀取及數據預處理

我們自定義VOCSegmentation類,繼承pytorch提供的torch.utils.data.Dataset類,主要實現__getitem__函數。再利用pytorch提供的Dataloader,就可以通過調用__getitem__函數來批量讀取VOC數據集圖片和標簽了。

VOCSegmentation類的初始化部分,如下方的代碼所示:

# fcn/my_dataset.py
class VOCSegmentation(data.Dataset):def __init__(self, voc_root, year="2007", transforms=None, txt_name: str = "train.txt"):super(VOCSegmentation, self).__init__()assert year in ["2007", "2012"], "year must be in ['2007', '2012']"root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")assert os.path.exists(root), "path '{}' does not exist.".format(root)image_dir = os.path.join(root, 'JPEGImages')mask_dir = os.path.join(root, 'SegmentationClass')txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)with open(os.path.join(txt_path), "r") as f:file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]assert (len(self.images) == len(self.masks))self.transforms = transforms
  • 首先我們需要獲取輸入(image)和標簽(target)的路徑。

    • voc_root是我們應該傳入VOCdevkit所在的文件夾。

    • 最終self.image和self.masks里存儲的就是我們輸入和標簽的路徑了。

  • 接著我們對輸入圖片和標簽進行transformer預處理(代碼如下)

    • 訓練集采用了隨機縮放、水平翻轉、隨機裁剪、toTensor和Normalize。
    • 驗證集僅使用了隨機縮放、toTensor和Normalize。
    • crop_size設置為480,即訓練圖片都會裁剪到480*480大小,而驗證時沒有使用隨機裁剪方法,因此驗證集的圖片尺寸是不一致的, 需要進行進一步的處理
# fcn/train.py
class SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)
  • 預處理代碼完成后,就可以實現__getitem__以及__len__方法。

     # fcn/my_dataset.pydef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is the image segmentation."""img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.images)@staticmethoddef collate_fn(batch):images, targets = list(zip(*batch))batched_imgs = cat_list(images, fill_value=0)batched_targets = cat_list(targets, fill_value=255)return batched_imgs, batched_targets
    
  • 在VOCSegmentation類中,還實現了DataLoader中需要的collate_fn。

    • 在collate_fn中,接受一個List類型數據,其中每個元素是一個Tuple2類型,包括了image和target。
    • 在collate_fn中調用cat_list方法,對驗證集圖片尺寸是不一致進行處理。
    # fcn/my_dataset.py
    def cat_list(images, fill_value=0):# 計算該batch數據中,channel, h, w的最大值max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))batch_shape = (len(images),) + max_sizebatched_imgs = images[0].new(*batch_shape).fill_(fill_value)for img, pad_img in zip(images, batched_imgs):pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)return batched_imgs
    
  • 最后就可以調用Dataloader批量獲取數據了。

# fcn/train.py   # VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2007",transforms=get_transform(train=True),txt_name="train.txt")# VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2007",transforms=get_transform(train=False),txt_name="val.txt")num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True,pin_memory=True,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)

4 模型訓練及測試

4.1 模型訓練

  • 代碼在 fcn/train.py 中。

  • 先利用Dataset和DataLoader批量獲取數據。

  • 然后創建FCN網絡模型,可以加載在COCO數據集上的預訓練權重。

    def create_model(aux, num_classes, pretrain=True):model = fcn_resnet50(aux=aux, num_classes=num_classes)if pretrain:weights_dict = torch.load("./fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的預訓練權重是21類(包括背景)# 如果訓練自己的數據集,將和類別相關的權重刪除,防止權重shape不一致報錯for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return model
    
  • 設置SGD優化器

    # 設置優化器
    optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    
  • 設置學習率更新策略。

     # 創建學習率更新策略,這里是每個step更新一次(不是每個epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)
    
    # fcn/train_utils/train_and_eval.py
    def create_lr_scheduler(optimizer,num_step: int,epochs: int,warmup=True,warmup_epochs=1,warmup_factor=1e-3):assert num_step > 0 and epochs > 0if warmup is False:warmup_epochs = 0def f(x):"""根據step數返回一個學習率倍率因子,注意在訓練開始之前,pytorch會提前調用一次lr_scheduler.step()方法"""if warmup is True and x <= (warmup_epochs * num_step):alpha = float(x) / (warmup_epochs * num_step)# warmup過程中lr倍率因子從warmup_factor -> 1return warmup_factor * (1 - alpha) + alphaelse:# warmup后lr倍率因子從1 -> 0# 參考deeplab_v2: Learning rate policyreturn (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
    
  • 訓練代碼如下,可以代碼調試。

    for epoch in range(args.start_epoch, args.epochs):mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)# 測試confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# write into txtwith open(results_file, "a") as f:# 記錄每個epoch對應的train_loss、lr以及驗證集各指標train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")save_file = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"lr_scheduler": lr_scheduler.state_dict(),"epoch": epoch,"args": args}if args.amp:save_file["scaler"] = scaler.state_dict()torch.save(save_file, "save_weights/model_{}.pth".format(epoch))
    

4.2 模型測試

在 train_and_val.py 文件中的 evaluate 函數代碼如下:

  • 創建 ConfusionMatrix 混淆矩陣
  • 使用 for 循環遍歷 data_loader 得到 image 和 target 信息,并將其指給對應的設備當中
  • 再將 image 圖像輸入到 model 模型中進行預測,得到 output 輸出(只使用主分支上的輸出)
  • 調用 update 方法時,在計算每一批數據預測結果與真實結果對比的過程中,將 target 和 output.argmax(1) 進行 flatten 處理
    • output.argmax(1) 中的 1 是指在 channel 維度,而 argmax 方法用于 將每個像素預測值最大的類別作為其預測類別(如下圖所示) 。
    • 在這里插入圖片描述
# fcn/train_utils/train_and_eval.py
def evaluate(model, data_loader, device, num_classes):model.eval()confmat = utils.ConfusionMatrix(num_classes)metric_logger = utils.MetricLogger(delimiter="  ")header = 'Test:'with torch.no_grad():for image, target in metric_logger.log_every(data_loader, 100, header):image, target = image.to(device), target.to(device)output = model(image)output = output['out']confmat.update(target.flatten(), output.argmax(1).flatten())confmat.reduce_from_all_processes()return confmat

ConfusionMatrix 類代碼如下:

  • ConfusionMatrix 類中的 update 函數傳入了真實標簽 a 和預測標簽 b 等參數,代碼的具體解析:

    • 這里的 num_classes 是指包含了背景的類別個數。
    • 如果 self.mat 是 None ,就使用 torch.zeros 創建一個全零矩陣作為混淆矩陣,大小為 n x n ,用于記錄真實標簽和預測標簽之間的關系。
    • 通過檢查真實標簽 a 中的元素是否屬于有效類別范圍 [ 0 , N ) 來尋找屬于目標類別的像素索引。
    • 根據像素的真實類別 a [ k ] 和預測類別 b [ k ] 計算類別索引 inds ,用于統計真實類別為 a [ k ] 被預測成 b [ k ] 的像素個數。
    • 使用 torch.bincount 統計類別索引 inds 在 [ 0 , n**2 ) 內的出現次數,并將結果重塑成 ( n , n ) 的矩陣形狀,統計數據累加到混淆矩陣中。
      在這里插入圖片描述
  • ConfusionMatrix 類中的 compute 函數計算常見的語義分割評價指標。

    • 語義分割評價指標主要包括 Pixel Accuracy ( Global Accuracy )、mean Accuracy、mean IoU 等:
      • Pixel Accuracy = 類別預測正確的像素個數總和 ÷ 圖片的總像素個數
      • mean Accuracy = 對每個類別的 Accuracy 求平均值
      • mean IoU = 對每個類別的 IoU 求平均值
class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, a, b):n = self.num_classesif self.mat is None:# 創建混淆矩陣self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)with torch.no_grad():# 尋找GT中為目標的像素索引(例如:255就不是目標的像素索引)k = (a >= 0) & (a < n)# 統計像素真實類別a[k]被預測成類別b[k]的個數(這里的做法很巧妙)inds = n * a[k].to(torch.int64) + b[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h = self.mat.float()# 計算全局預測準確率(混淆矩陣的對角線為預測正確的個數)acc_global = torch.diag(h).sum() / h.sum()# 計算每個類別的準確率acc = torch.diag(h) / h.sum(1)# 計算每個類別預測與真實目標的iouiu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)

4.3 模型預測

  • 模型輸出為1×c×h×w,因為這是預測,故batch=1,這里使用的是VOC數據,故這里的c=num_class=21。【包含一個背景類】
  • 首先我們會取輸出中每個像素在21個通道中的最大值,如第一個像素在21個通道的最大值在通道0上取得。這個通道對應的索引是0,在VOC中是背景類,故這個像素所屬類別為背景。其它像素同理。
  • 在這里插入圖片描述
    # fcn/predict.pymodel.eval()  # 進入驗證模式with torch.no_grad():# init modelimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()output = model(img.to(device))t_end = time_synchronized()print("inference time: {}".format(t_end - t_start))# 在輸出中的chanel維度求最大值對應的類別索引prediction = output['out'].argmax(1).squeeze(0)prediction = prediction.to("cpu").numpy().astype(np.uint8)mask = Image.fromarray(prediction)mask.putpalette(pallette)mask.save("test_result.png")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/717865.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/717865.shtml
英文地址,請注明出處:http://en.pswp.cn/news/717865.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

為raspberrypi編譯bpftrace調試工具

基于eBPF的嵌入式應用調試 筆者之前寫過幾篇有關于使用eBPF調試Linux內核和應用的博客&#xff0c;其中提到&#xff0c;在嵌入式設備上使用BCC或bpftrace是不可行的&#xff1b;主要原因在于嵌入式設備的資源有限&#xff0c;而這兩個調試工具依賴python/clang/llvm等庫&…

Scratch 第十六課-彈珠臺游戲

第十六課-彈珠臺游戲 大家好&#xff0c;今天我們一起做一款彈珠臺scratch游戲&#xff0c;我們也可以叫它彈球游戲&#xff01;這款游戲在剛出來的時候非常火爆。小朋友們要認真學習下&#xff01; 這節課的學習目標 物體碰撞如何處理轉向問題。復習鍵盤對角色的控制方式。…

STL-內存的配置與釋放

STL-內存的配置與釋放 STL有兩級空間配置器&#xff0c;默認是使用第二級。第二級空間配置器會在某些情況下去調用第一級空間配置器。空間配置器都是在allocate函數內分配內存&#xff0c;在deallocate函數內釋放內存。 第一級空間配置器 第一級配置器只是對malloc函數和fre…

【自然語言處理】BitNet b1.58:1bit LLM時代

論文地址&#xff1a;https://arxiv.org/pdf/2402.17764.pdf 相關博客 【自然語言處理】BitNet b1.58&#xff1a;1bit LLM時代 【自然語言處理】【長文本處理】RMT&#xff1a;能處理長度超過一百萬token的Transformer 【自然語言處理】【大模型】MPT模型結構源碼解析(單機版)…

如何在 Mac 上成功輕松地恢復 Excel 文件

Microsoft Excel 的 Mac 版本始終略落后于 Windows 版本&#xff0c;這也許可以解釋為什么如此多的用戶渴望學習如何在 Mac 上恢復 Excel 文件。 但導致重要電子表格不可用的不僅僅是 Mac 版 Excel 的不完全穩定性。用戶有時會失去注意力并刪除錯誤的文件&#xff0c;存儲設備…

2024-03-03 c++

&#x1f338; MFC進度條控件 | Progress Control 1。新建MFC項目&#xff08;基于對話框、靜態庫&#xff09; 2。添加控件&#xff0c;刪除初始的3個多余控件 加1個progress control&#xff0c;修改其marquee為true&#xff0c;添加變量&#xff1a;變量名為test_progress。…

Angular基礎---HelloWorld---Day1

文章目錄 1. 創建Angular 項目2.對Angular架構的最基本了解3.創建并引用新的組件&#xff08;component&#xff09;4.對Angular架構新的認識&#xff08;多組件&#xff09;5.組件中業務邏輯文件的編輯&#xff08;ts文件&#xff09;6.標簽中屬性的綁定(1) ID的綁定(2) class…

String和String Builder

String和StringBuilder的區別 String類 String類代表字符串。java程序中所有字符串文字&#xff08;例如“abc”&#xff09;都被實現為此類的實例。 String類源碼是用final修飾的&#xff0c;它們的值在創建后不能被更改。字符串緩沖區支持可變字符串。 String對象是不可變…

STM32 (2)

1.stm32編程模型 將C語言程序燒錄到芯片中會存儲在單片機的flsah存儲器中&#xff0c;給芯片上電后&#xff0c;Flash中的程序會逐條進入到CPU中去執行&#xff0c;進而CPU去控制各種模塊&#xff08;即外設&#xff09;去實現各種功能。 2.寄存器和寄存器編程 CPU通過控制其…

Apache POI的簡單介紹與應用

介紹 Apache POI 是一個處理Miscrosoft Office各種文件格式的開源項目。我們可以使用 POI 在 Java 程序中對Miscrosoft Office各種文件進行讀寫操作。PS&#xff1a; 一般情況下&#xff0c;POI 都是用于操作 Excel 文件&#xff0c;如圖&#xff1a; Apache POI 的應用場景&…

SQL無列名注入

SQL無列名注入 ? 前段時間&#xff0c;隊里某位大佬發了一個關于sql注入無列名的文章&#xff0c;感覺好像很有用&#xff0c;特地研究下。 關于 information_schema 數據庫&#xff1a; ? 對于這一個庫&#xff0c;我所知曉的內容并不多&#xff0c;并且之前總結SQL注入的…

設計模式-橋接模式實踐案例

橋接模式&#xff08;Bridge Pattern&#xff09;是一種結構型設計模式&#xff0c;用于將抽象與實現分離&#xff0c;使它們可以獨立地變化。這種模式通過提供一個橋接結構&#xff0c;可以將實現接口的實現部分和抽象層中可變化的部分分離開來。 以下是一個使用 Java 實現橋…

【數據結構】_包裝類與泛型

目錄 1. 包裝類 1.1 基本數據類型和對應的包裝類 1.2 &#xff08;自動&#xff09;裝箱和&#xff08;自動&#xff09;拆箱 1.2.1 裝箱與拆箱 1.2.2 自動&#xff08;顯式&#xff09;裝箱與自動&#xff08;顯式&#xff09;拆箱 1.3 valueOf()方法 2. 泛型類 2.1 泛…

【深度學習筆記】計算機視覺——目標檢測和邊界框

目標檢測和邊界框 前面的章節&#xff08;例如 sec_alexnet— sec_googlenet&#xff09;介紹了各種圖像分類模型。 在圖像分類任務中&#xff0c;我們假設圖像中只有一個主要物體對象&#xff0c;我們只關注如何識別其類別。 然而&#xff0c;很多時候圖像里有多個我們感興趣…

某大型制造企業數字化轉型規劃方案(附下載)

目錄 一、項目背景和目標 二、業務現狀 1. 總體應用現狀 2. 各模塊業務問題 2.1 設計 2.2 仿真 2.3 制造 2.4 服務 2.5 管理 三、業務需求及預期效果 1. 總體業務需求 2. 各模塊業務需求 2.1 設計 2.2 仿真 2.3 制造 2.4 服務 2.5 管理 四、…

在vue中對keep-alive的理解,它是如何實現的,具體緩存的是什么?

對keep-alive的理解&#xff0c;它是如何實現的&#xff0c;具體緩存的是什么&#xff1f; &#xff08;1&#xff09;keep-alive有以下三個屬性&#xff1a;注意&#xff1a;keep-alive 包裹動態組件時&#xff0c;會緩存不活動的組件實例。主要流程 &#xff08;2&#xff09…

數字化轉型導師堅鵬:證券公司數字化營銷

證券公司數字化營銷 ——借力數字化技術實現零售業務的批量化、精準化、場景化、智能化營銷 課程背景&#xff1a; 很多證券公司存在以下問題&#xff1a; 不知道如何提升證券公司數字化營銷能力&#xff1f; 不知道證券公司如何開展數字化營銷工作&#xff1f; 不知道…

胎神游戲集第二期

延續上一期 一、海島奇胎 #include<bits/stdc.h> #include<windows.h> #include<stdio.h> #include<conio.h> #include<time.h> using namespace std; typedef BOOL (WINAPI *PROCSETCONSOLEFONT)(HANDLE, DWORD); PROCSETCONSOLEFONT SetCons…

Linux 安裝pip和換源

一 配置文檔 Linux和macOS&#xff1a; 全局配置&#xff1a;/etc/pip.conf 用戶級配置&#xff1a;~/.pip/pip.conf 或 ~/.config/pip/pip.conf 二 下載 和 安裝 # pip 安裝 wget https://bootstrap.pypa.io/get-pip.py python get-pip.py 三 查看和升級 pip -Vpython -m…

GO語言學習筆記(與Java的比較學習)(十一)

協程與通道 什么是協程 一個應用程序是運行在機器上的一個進程&#xff1b;進程是一個運行在自己內存地址空間里的獨立執行體。一個進程由一個或多個操作系統線程組成&#xff0c;這些線程其實是共享同一個內存地址空間的一起工作的執行體。 并行是一種通過使用多處理器以提…