昇思MindSpore學習總結九——FCN語義分割

1、語義分割

????????圖像語義分割(semantic segmentation)是圖像處理和機器視覺技術中關于圖像理解的重要一環,AI領域中一個重要分支,常被應用于人臉識別、物體檢測、醫學影像、衛星圖像分析、自動駕駛感知等領域。

????????語義分割的目的是對圖像中每個像素點進行分類。?要識別出整張圖片的每個部分,就意味著要精確到像素點,所以語義分割實際上是對圖像中每一個像素點進行分類,確定每個點的類別(如屬于背景、人、汽車、馬等),從而進行區域劃分。

????????與普通的分類任務只輸出某個類別不同,語義分割任務輸出與輸入大小相同的圖像,輸出圖像的每個像素對應了輸入圖像每個像素的類別。語義在圖像領域指的是圖像的內容,對圖片意思的理解,下圖是一些語義分割的實例:

2、全卷積網絡

????????全卷積網絡(Fully Convolutional Networks,FCN)是UC Berkeley的Jonathan Long等人于2015年在Fully Convolutional Networks for Semantic Segmentation[1]一文中提出的用于圖像語義分割的一種框架。

核心思想:

1.不含全連接層(fc)的全卷積(fully conv)網絡。可適應任意尺寸輸入。?

2.增大數據尺寸的反卷積(deconv)層。能夠輸出精細的結果。?

3.結合不同深度層結果的跳級(skip)結構。同時確保魯棒性和精確性。

FCN是首個端到端(end to end)進行像素級(pixel level)預測的全卷積網絡。

?3、模型簡介

FCN主要用于圖像分割領域,是一種端到端的分割方法,是深度學習應用在圖像語義分割的開山之作。通過進行像素級的預測直接得出與原圖大小相等的label map。因FCN丟棄全連接層替換為全卷積層,網絡所有層均為卷積層,故稱為全卷積網絡。

全卷積神經網絡主要使用以下三種技術:

3.1?卷積化(Convolutional)

使用VGG-16作為FCN的backbone。VGG-16的輸入為224*224的RGB圖像,輸出為1000個預測值。VGG-16只能接受固定大小的輸入,丟棄了空間坐標,產生非空間輸出。VGG-16中共有三個全連接層,全連接層也可視為帶有覆蓋整個區域的卷積。將全連接層轉換為卷積層能使網絡輸出由一維非空間輸出變為二維矩陣,利用輸出能生成輸入圖片映射的heatmap。

?

?3*3 conv, 64:使用64個size是3*3,stride步長為1,padding填充為1的卷積核。

池化:最大池化,使用size是2*2,stride步長為2,padding填充為0進行池化。

第6層和第7層分別是一個長度為4096的一維向量,第8層是長度為1000的一維向量,分別對應1000個類別的概率。FCN將這3層表示為卷積層,卷積核的大小(通道數,寬,高)分別為(4096,7,7)、(4096,1,1)、(1000,1,1)。所有的層都是卷積層,故稱為全卷積網絡。

3.2 上采樣(Upsample)

在卷積過程的卷積操作和池化操作會使得特征圖的尺寸變小,為得到原圖的大小的稠密圖像預測,需要對得到的特征圖進行上采樣操作。使用雙線性插值的參數來初始化上采樣逆卷積的參數,后通過反向傳播來學習非線性上采樣。在網絡中執行上采樣,以通過像素損失的反向傳播進行端到端的學習。

3.3 跳躍結構(Skip Layer)

利用上采樣技巧對最后一層的特征圖進行上采樣得到原圖大小的分割是步長為32像素的預測,稱之為FCN-32s。由于最后一層的特征圖太小,損失過多細節,采用skips結構將更具有全局信息的最后一層預測和更淺層的預測結合,使預測結果獲取更多的局部細節。將底層(stride 32)的預測(FCN-32s)進行2倍的上采樣得到原尺寸的圖像,并與從pool4層(stride 16)進行的預測融合起來(相加),這一部分的網絡被稱為FCN-16s。隨后將這一部分的預測再進行一次2倍的上采樣并與從pool3層得到的預測融合起來,這一部分的網絡被稱為FCN-8s。 Skips結構將深層的全局信息與淺層的局部信息相結合。

?4、數據處理

????????由于PASCAL VOC 2012數據集中圖像的分辨率大多不一致,無法放在一個tensor中,故輸入前需做標準化處理。

4.2 加載數據集

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"download(url, "./dataset", kind="tar", replace=True)

4.2 數據預處理?

import numpy as np
import cv2
import mindspore.dataset as dsclass SegDataset:def __init__(self,image_mean,  #圖像的均值,用于圖像標準化。通常是一個列表或數組,包含每個通道的均值值。image_std,#圖像的標準差,也用于圖像標準化。通常是一個列表或數組,包含每個通道的標準差值。data_file='',#數據文件的路徑。這個參數默認是一個空字符串,表示可能的默認值或未提供文件路徑。batch_size=32,#批處理大小,表示一次訓練中使用的樣本數量。32是一個常見的默認值。crop_size=512,#裁剪大小,表示圖像在訓練或測試時裁剪的尺寸。512通常用于高分辨率圖像。max_scale=2.0,#最大縮放比例,用于數據增強,通過隨機縮放圖像來增加模型的魯棒性。min_scale=0.5,#最小縮放比例,同樣用于數據增強。ignore_label=255,#忽略標簽,通常用于語義分割任務,表示某些像素點不參與訓練的標簽值。num_classes=21,#類別數量,表示數據集中不同類別的數量。21個類別可能用于某個特定的數據集,比如Pascal VOC。num_readers=2,#讀取器數量,表示用于讀取數據文件的并行讀取器數量,可以加快數據加載速度。num_parallel_calls=4):#并行調用數量,用于數據預處理的并行調用次數,可以加快數據預處理過程。self.data_file = data_fileself.batch_size = batch_sizeself.crop_size = crop_sizeself.image_mean = np.array(image_mean, dtype=np.float32)self.image_std = np.array(image_std, dtype=np.float32)self.max_scale = max_scaleself.min_scale = min_scaleself.ignore_label = ignore_labelself.num_classes = num_classesself.num_readers = num_readersself.num_parallel_calls = num_parallel_callsmax_scale > min_scaledef preprocess_dataset(self, image, label):# np.frombuffer(image, dtype=np.uint8):將原始字節數據轉換為NumPy數組,以便OpenCV可以處理。# cv2.imdecode(..., cv2.IMREAD_COLOR):使用OpenCV解碼這個NumPy數組,并將其轉換為一個圖像矩陣,以便后續的圖像處理操作。image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)#生成一個在self.min_scale和self.max_scale之間的隨機浮點數。sc = np.random.uniform(self.min_scale, self.max_scale)new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])#將圖像 image_out 的大小調整為 new_w 寬和 new_h 高,并使用(nterpolation=cv2.INTER_CUBIC)雙三次插值方法進行插值。image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)image_out = (image_out - self.image_mean) / self.image_stdout_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)pad_h, pad_w = out_h - new_h, out_w - new_wif pad_h > 0 or pad_w > 0:# cv2.copyMakeBorder:這是OpenCV中的一個函數,用于為圖像添加邊框。image_out:這是要添加邊框的輸入圖像矩陣。# 0, pad_h, 0, pad_w:這些參數指定了邊框的大小:image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)offset_h = np.random.randint(0, out_h - self.crop_size + 1)offset_w = np.random.randint(0, out_w - self.crop_size + 1)image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]if np.random.uniform(0.0, 1.0) > 0.5:image_out = image_out[:, ::-1, :]label_out = label_out[:, ::-1]image_out = image_out.transpose((2, 0, 1))image_out = image_out.copy()label_out = label_out.copy()label_out = label_out.astype("int32")return image_out, label_outdef get_dataset(self):ds.config.set_numa_enable(True)dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],shuffle=True, num_parallel_workers=self.num_readers)transforms_list = self.preprocess_datasetdataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],output_columns=["data", "label"],num_parallel_workers=self.num_parallel_calls)dataset = dataset.shuffle(buffer_size=self.batch_size * 10)dataset = dataset.batch(self.batch_size, drop_remainder=True)return dataset# 定義創建數據集的參數
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"# 定義模型訓練參數
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21# 實例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,image_std=IMAGE_STD,data_file=DATA_FILE,batch_size=train_batch_size,crop_size=crop_size,max_scale=max_scale,min_scale=min_scale,ignore_label=ignore_label,num_classes=num_classes,num_readers=2,num_parallel_calls=4)dataset = dataset.get_dataset()

4.3 訓練集可視化

import numpy as np
import matplotlib.pyplot as pltplt.figure(figsize=(16, 8))# 對訓練集中的數據進行展示
for i in range(1, 9):plt.subplot(2, 4, i)show_data = next(dataset.create_dict_iterator())show_images = show_data["data"].asnumpy()show_images = np.clip(show_images, 0, 1)
# 將圖片轉換HWC格式后進行展示plt.imshow(show_images[0].transpose(1, 2, 0))plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

【運行結果】?

?5、網絡構建

FCN網絡的流程如下圖所示:

  1. 輸入圖像image,經過pool1池化后,尺寸變為原始尺寸的1/2。
  2. 經過pool2池化,尺寸變為原始尺寸的1/4。
  3. 接著經過pool3、pool4、pool5池化,大小分別變為原始尺寸的1/8、1/16、1/32。
  4. 經過conv6-7卷積,輸出的尺寸依然是原圖的1/32。
  5. FCN-32s是最后使用反卷積,使得輸出圖像大小與輸入圖像相同。
  6. FCN-16s是將conv7的輸出進行反卷積,使其尺寸擴大兩倍至原圖的1/16,并將其與pool4輸出的特征圖進行融合,后通過反卷積擴大到原始尺寸。
  7. FCN-8s是將conv7的輸出進行反卷積擴大4倍,將pool4輸出的特征圖反卷積擴大2倍,并將pool3輸出特征圖拿出,三者融合后通反卷積擴大到原始尺寸。

?5.1 構建代碼

import mindspore.nn as nnclass FCN8s(nn.Cell):def __init__(self, n_class):super().__init__()self.n_class = n_classself.conv1 = nn.SequentialCell(nn.Conv2d(in_channels=3, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU())self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.SequentialCell(nn.Conv2d(in_channels=64, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU())self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv3 = nn.SequentialCell(nn.Conv2d(in_channels=128, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU())self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv4 = nn.SequentialCell(nn.Conv2d(in_channels=256, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv5 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv6 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=4096,kernel_size=7, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.conv7 = nn.SequentialCell(nn.Conv2d(in_channels=4096, out_channels=4096,kernel_size=1, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=16, stride=8, weight_init='xavier_uniform')def construct(self, x):x1 = self.conv1(x)p1 = self.pool1(x1)x2 = self.conv2(p1)p2 = self.pool2(x2)x3 = self.conv3(p2)p3 = self.pool3(x3)x4 = self.conv4(p3)p4 = self.pool4(x4)x5 = self.conv5(p4)p5 = self.pool5(x5)x6 = self.conv6(p5)x7 = self.conv7(x6)sf = self.score_fr(x7)u2 = self.upscore2(sf)s4 = self.score_pool4(p4)f4 = s4 + u2u4 = self.upscore_pool4(f4)s3 = self.score_pool3(p3)f3 = s3 + u4out = self.upscore8(f3)return out

?6、訓練準備

導入VGG-6部分預訓練權重,使用下面代碼導入VGG-16預訓練模型的部分預訓練權重。

from download import download
from mindspore import load_checkpoint, load_param_into_neturl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"param_vgg = load_checkpoint(ckpt_vgg16)load_param_into_net(net, param_vgg)

?7、損失函數

????????語義分割是對圖像中每個像素點進行分類,仍是分類問題,故損失函數選擇交叉熵損失函數來計算FCN網絡輸出與mask之間的交叉熵損失。這里我們使用的是mindspore.nn.CrossEntropyLoss()作為損失函數。

7.1 自定義評價指標Metrics

????????這一部分主要對訓練出來的模型效果進行評估,為了便于解釋,假設如下:共有?𝑘+1 個類(從?𝐿0?到?𝐿𝑘, 其中包含一個空類或背景),?𝑝𝑖𝑗?表示本屬于𝑖類但被預測為𝑗類的像素數量。即,?𝑝𝑖𝑖?表示真正的數量, 而?𝑝𝑖𝑗𝑝𝑗𝑖則分別被解釋為假正和假負, 盡管兩者都是假正與假負之和。

  • Pixel Accuracy(PA, 像素精度):這是最簡單的度量,為標記正確的像素占總像素的比例。

  • Mean Pixel Accuracy(MPA, 均像素精度):是PA的一種簡單提升,計算每個類內被正確分類像素數的比例,之后求所有類的平均。

  • Mean Intersection over Union(MloU, 均交并比):為語義分割的標準度量。其計算兩個集合的交集和并集之,在語義分割的問題中,這兩個集合為真實值(ground truth) 和預測值(predicted segmentation)。這個比例可以變形為正真數 (intersection) 比上真正、假負、假正(并集)之和。在每個類上計算loU,之后平均。

  • Frequency Weighted Intersection over Union(FWIoU, 頻權交井比):為MloU的一種提升,這種方法根據每個類出現的頻率為其設置權重。

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as trainclass PixelAccuracy(train.Metric):def __init__(self, num_class=21):# 初始化方法,設置類別數,并調用父類的初始化方法super(PixelAccuracy, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):# 掩碼,僅保留有效類別范圍內的像素mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]# 計算標簽count = np.bincount(label, minlength=self.num_class**2)# 統計每個類別組合的頻數confusion_matrix = count.reshape(self.num_class, self.num_class)# 重塑為混淆矩陣return confusion_matrixdef clear(self):# 將混淆矩陣重置為全零矩陣self.confusion_matrix = np.zeros((self.num_class,) * 2)def update(self, *inputs):# 更新混淆矩陣y_pred = inputs[0].asnumpy().argmax(axis=1)# 獲取預測類別y = inputs[1].asnumpy().reshape(4, 512, 512) # 獲取真實類別并重塑self.confusion_matrix += self._generate_matrix(y, y_pred) # 更新混淆矩陣def eval(self):# 計算并返回像素準確率pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()return pixel_accuracyclass PixelAccuracyClass(train.Metric):def __init__(self, num_class=21):super(PixelAccuracyClass, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)return mean_pixel_accuracyclass MeanIntersectionOverUnion(train.Metric):def __init__(self, num_class=21):super(MeanIntersectionOverUnion, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):mean_iou = np.diag(self.confusion_matrix) / (np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -np.diag(self.confusion_matrix))mean_iou = np.nanmean(mean_iou)return mean_iouclass FrequencyWeightedIntersectionOverUnion(train.Metric):def __init__(self, num_class=21):super(FrequencyWeightedIntersectionOverUnion, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)iu = np.diag(self.confusion_matrix) / (np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -np.diag(self.confusion_matrix))frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()return frequency_weighted_iou

8、模型訓練

?????????導入VGG-16預訓練參數后,實例化損失函數、優化器,使用Model接口編譯網絡,訓練FCN-8s網絡。

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Modeldevice_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)train_batch_size = 4
num_classes = 21
# 初始化模型結構
net = FCN8s(n_class=21)
# 導入vgg16預訓練參數
load_vgg16()
# 計算學習率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochslr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,           # 最低學習率base_lr,          # 基礎學習率total_step,       # 總步數iters_per_epoch,  # 每個 epoch 的迭代次數decay_epoch=2)    # 開始衰減的 epoch
# 從學習率調度器中取出最后一個學習率值,并將其轉換為 Tensor
lr = Tensor(lr_scheduler[-1])# 定義損失函數
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定義優化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# 定義loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# 初始化模型
if device_target == "Ascend":model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})# 設置ckpt文件保存的參數
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",directory="./ckpt",config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)

【運行結果】

?????????FCN網絡在訓練的過程中需要大量的訓練數據和訓練輪數,這里只提供了小數據單個epoch的訓練來演示loss收斂的過程,下文中使用已訓練好的權重文件進行模型評估和推理效果的展示。

9、模型評估

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"# 下載已訓練好的權重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)if device_target == "Ascend":model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})# 實例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,image_std=IMAGE_STD,data_file=DATA_FILE,batch_size=train_batch_size,crop_size=crop_size,max_scale=max_scale,min_scale=min_scale,ignore_label=ignore_label,num_classes=num_classes,num_readers=2,num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)

?10、模型推理

使用訓練的網絡對模型推理結果進行展示。

import cv2
import matplotlib.pyplot as pltnet = FCN8s(n_class=num_classes)
# 設置超參
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# 推理效果展示(上方為輸入圖片,下方為推理效果圖片)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):img_lst.append(show_images[i])mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):plt.subplot(2, 4, i + 1)plt.imshow(img_lst[i].transpose(1, 2, 0))plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 4, i + 5)plt.imshow(res[i])plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

?【運行結果】自己電腦沒運行出來

打卡

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/40099.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/40099.shtml
英文地址,請注明出處:http://en.pswp.cn/web/40099.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【楚怡杯】職業院校技能大賽 “Python程序開發”賽項樣題三

Python程序開發實訓 &#xff08;時量&#xff1a;240分鐘&#xff09; 中國XX 實訓說明 注意事項 1. 請根據提供的實訓環境&#xff0c;檢查所列的硬件設備、軟件清單、材料清單是否齊全&#xff0c;計算機設備是否能正常使用。 2. 實訓結束前&#xff0c;在實訓平臺提供的…

從數據到智能,英智私有大模型助力企業實現數智化發展

在數字化時代&#xff0c;數據已經成為企業最重要的資源。如何將這些數據轉化為實際的業務價值&#xff0c;是每個企業面臨的重要課題。英智利用業界領先的清洗、訓練和微調技術&#xff0c;對企業數據進行深度挖掘和分析&#xff0c;定制符合企業業務場景的私有大模型&#xf…

篩選有合并單元格的數據

我們經常會使用合并單元格&#xff0c;比如下面表格&#xff0c;因為一個部門中會有不同的員工&#xff0c;就會出現如下表格&#xff1a; 但是當按部門去篩選的時候&#xff0c;會發現并不是我們預期的結果&#xff0c;部門列有空值&#xff0c;每個部門只有第一行數據可以被…

虛幻引擎 快速的色度摳圖 Chroma Key 算法

快就完了 ColorTolerance_PxRange為容差&#xff0c;這里是0-255的輸入&#xff0c;也就是px單位&#xff0c;直接用0-1可以更快 Key為目標顏色

PySide6 實現資源的加載:深入解析與實戰案例

目錄 1. 引言 2. 加載內置資源 3. 使用自定義資源文件&#xff08;.qrc&#xff09; 創建.qrc文件 編譯.qrc文件 加載資源 4. 動態加載UI文件 使用Qt Designer設計UI 加載UI文件 5. 注意事項與最佳實踐 6. 結論 在開發基于PySide6的桌面應用程序時&…

什么是 DDoS 攻擊及如何防護DDOS攻擊

自進入互聯網時代&#xff0c;網絡安全問題就一直困擾著用戶&#xff0c;尤其是DDOS攻擊&#xff0c;一直威脅著用戶的業務安全。而高防IP被廣泛用于增強網絡防護能力。今天我們就來了解下關于DDOS攻擊&#xff0c;以及可以防護DDOS攻擊的高防IP該如何正確選擇使用。 一、什么是…

個人引導頁+音樂炫酷播放器(附加源碼)

個人引導頁音樂炫酷播放器 效果圖部分源碼完整源碼領取下期更新內容 效果圖 部分源碼 //網站動態標題開始 var OriginTitile document.title, titleTime; document.addEventListener("visibilitychange", function() {if (document.hidden) {document.title "…

極客時間 - 《Linux 性能優化實戰》

極客時間 - 《Linux 性能優化實戰》原文鏈接&#xff1a;https://time.geekbang.org/column/intro/100020901 02 | 基礎篇&#xff1a;到底應該怎么理解“平均負載”&#xff1f;在Linux系統中&#xff0c;當一個進程啟動時&#xff0c;操作系統會為該進程申請哪些資源&#x…

Python學習從0開始——Kaggle實踐可視化001

Python學習從0開始——Kaggle實踐可視化001 一、創建和加載數據集二、數據預處理1.按name檢查&#xff0c;處理重復值&#xff08;查重&#xff09;2.查看存在缺失值的列并處理&#xff08;缺失值處理&#xff09;2.1按行或列查看2.2無法推測的數據2.3可由其它列推測的數據 3.拆…

QT實現GIF動圖顯示(小白版,可直接copy使用)

需要你自己提前設置好動圖的位置&#xff0c;本例中存放于"/Users/PLA/PLA/PLA.gif widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QMovie> #include <QLabel>class Widget : public QWidget {Q_OBJECTpublic:explicit Wid…

mysql數據表時間字段自動存時間

時間字段自動存時間&#xff0c;不用通過插入語句存當前操作時間&#xff1a; created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 創建時間,

深入分析 Android BroadcastReceiver (九)

文章目錄 深入分析 Android BroadcastReceiver (九)1. Android 廣播機制的擴展應用與高級優化1.1 廣播機制的擴展應用1.1.1 示例&#xff1a;有序廣播1.1.2 示例&#xff1a;粘性廣播1.1.3 示例&#xff1a;局部廣播 1.2 廣播機制的高級優化1.2.1 示例&#xff1a;使用 Pending…

空調計費系統是什么,你知道嗎

空調計費系統是一種通過對使用空調的時間和能源消耗進行監測和計量來進行費用計算的系統。它廣泛應用于各種場所&#xff0c;如家庭、辦公室、商場等&#xff0c;為用戶提供了方便、準確的能源使用管理和費用控制。 可實現功能 智能計費&#xff1a;中央空調分戶計費系統通過智…

SOLIDWORKS分期許可(訂閱形式),降低前期的投入成本!

SOLIDWORKS 分期許可使您能夠降低前期軟件成本&#xff0c;同時提供對 SOLIDWORKS 新版本和升級程序的即時訪問&#xff0c;以及在每個期限結束時調整產品的靈活性&#xff0c;幫助您跟上市場需求和競爭壓力的步伐。 目 錄&#xff1a; ★ 1 什么是SOLIDWORKS分期許可 ★ 2 …

gen_region_line 生成直線

gen_region_line (Operator) Name 名稱 gen_region_line — Store input lines as regions.將輸入行存儲為region。 生成直線&#xff0c;直線區域 Signature 簽名 gen_region_line( : RegionLines : BeginRow, BeginCol, EndRow, EndCol : ) Description 描述 運算符ge…

【LLM大模型】程序員為什么要學習大模型應用開發?

0 prompt engineer 就是prompt工程師它的底層透視。 1 學習大模型的重要性 底層邏輯 人工智能大潮已來&#xff0c;不加入就可能被淘汰。就好像現在職場里誰不會用PPT和excel一樣&#xff0c;基本上你見不到。你問任何一個人問他會不會用PPT&#xff0c;他都會說會用&#…

請查收!模擬電路精選書單一份(可下載)

在電子工程的廣闊天地中&#xff0c;模擬電路設計是一門藝術&#xff0c;也是一種科學。它要求設計師不僅要有深厚的理論知識&#xff0c;還要有精湛的實踐技能。隨著技術的發展&#xff0c;模擬電路設計領域不斷涌現新的理論、技術和工具&#xff0c;這使得學習和掌握模擬設計…

css使用偽元素after或者before的時候想要給after設置z-index無效

css使用偽元素after或者before的時候想要給after或者before設置一個層級關系&#xff0c;使該偽類寫入的樣式在box的下面&#xff0c;發現給box設置z-index無效&#xff0c; 需要找到父級元素&#xff0c;在父級元素上設置z-index值并且將偽類設置z-index:-1

開放式耳機哪個牌子好?五款優質產品推薦,老司機帶飛!

后臺有粉絲滴滴我說&#xff0c;還想再多分享一些耳機的測評或者選購指南&#xff0c;開放式耳機確實越來越火了&#xff0c;市面上的品牌從十幾塊到幾千塊的開放式耳機也比比皆是&#xff0c;但是要選擇適合自己的一款開放式耳機確實還挺難的&#xff0c;所以作為耳機測評師這…

深入解析大型語言模型:從訓練到部署大模型

簡介 隨著數據科學領域的深入發展&#xff0c;大型語言模型——這種能夠處理和生成復雜自然語言的精密人工智能系統—逐漸引發了更大的關注。 LLMs是自然語言處理&#xff08;NLP&#xff09;中最令人矚目的突破之一。這些模型有潛力徹底改變從客服到科學研究等各種行業&…