MindSpore 25-Day Learning Camp, Day 23 | Garbage Classification Based on MobileNetV2

Garbage Classification Based on MobileNetV2

1. Experiment Objectives

  • Become familiar with writing a garbage-classification application in Python;
  • Learn the basic use of the Linux operating system;
  • Master the basic use of the atc command for model conversion.

2. MobileNetV2 Model Principles

MobileNet is a lightweight CNN proposed by a Google team in 2017 for mobile, embedded, and IoT devices. Compared with traditional convolutional neural networks, MobileNet uses depthwise separable convolutions to greatly reduce the number of parameters and the amount of computation, at the cost of a small drop in accuracy. It also introduces a width multiplier α and a resolution multiplier ρ so the model can be scaled to different application scenarios.

Because the ReLU activation in MobileNet loses a large amount of information when processing low-dimensional features, MobileNetV2 designs the network with inverted residual blocks and linear bottlenecks, which improves accuracy and makes the optimized model even smaller.
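The parameter savings from depthwise separable convolution can be checked with a quick back-of-the-envelope calculation. The layer sizes below are illustrative examples, not values from the paper:

```python
# Parameter count of a standard 3x3 convolution vs. a depthwise
# separable convolution (depthwise 3x3 + pointwise 1x1), ignoring bias.
def standard_conv_params(k, c_in, c_out):
    # one k x k kernel per (input channel, output channel) pair
    return k * k * c_in * c_out

def depthwise_separable_params(k, c_in, c_out):
    # depthwise: one k x k kernel per input channel;
    # pointwise: a 1x1 convolution that mixes channels
    return k * k * c_in + c_in * c_out

std = standard_conv_params(3, 32, 64)
sep = depthwise_separable_params(3, 32, 64)
print(std, sep)  # 18432 2336, roughly an 8x reduction
```

For a 3x3 kernel the separable version costs close to 1/9 of the standard convolution once the channel count is large, which is where most of MobileNet's savings come from.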
[Figure: inverted residual block structure]

In the figure, the inverted residual block first uses a 1x1 convolution to increase the channel dimension, then a 3x3 depthwise convolution, and finally a 1x1 convolution to reduce the dimension again. This is the reverse of the residual block, which first reduces the dimension with a 1x1 convolution, then applies a 3x3 convolution, and finally restores the dimension with a 1x1 convolution.
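The channel flow through one inverted residual block can be sketched with example numbers; an expand ratio of 6 matches the network's typical setting, but the channel counts here are purely illustrative:

```python
# Channel widths through an inverted residual block with expand ratio t:
# 1x1 expand -> 3x3 depthwise (same width) -> 1x1 linear projection.
def inverted_residual_channels(c_in, c_out, t):
    hidden = c_in * t
    return [c_in, hidden, hidden, c_out]

print(inverted_residual_channels(24, 24, 6))  # [24, 144, 144, 24]
```

A classic residual block is narrow in the middle instead, e.g. [256, 64, 64, 256], which is why the MobileNetV2 block is called "inverted".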

  • Note:
    See the MobileNetV2 paper for details.

3. Experiment Environment

This case supports win_x86 and Linux, and can run on CPU, GPU, or Ascend.

Before starting the hands-on work, make sure MindSpore is installed correctly. For environment setup on different platforms, see the "MindSpore Environment Setup Lab Manual".

4. Data Processing

4.1 Data Preparation

By default, the MobileNetV2 code manages its dataset in ImageFolder format, with each class of images in its own folder. The dataset structure is as follows:

└─ImageFolder
    ├─train
    │   ├─class1Folder
    │   └─......
    └─eval
        ├─class1Folder
        └─......
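A layout like this can also be created programmatically. The class names below are hypothetical placeholders, not the dataset's actual 26 classes:

```python
import os

# Hypothetical classes for illustration; the real dataset ships
# with its own class folders inside data_en.
classes = ['Cardboard', 'Battery']
for split in ('train', 'eval'):
    for cls in classes:
        os.makedirs(os.path.join('ImageFolder', split, cls), exist_ok=True)

print(sorted(os.listdir(os.path.join('ImageFolder', 'train'))))
```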
from download import download

# Download the data_en dataset
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MindStudio-pc/data_en.zip"
path = download(url, "./", kind="zip", replace=True)

# Download the pretrained weights
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/ComputerVision/mobilenetV2-200_1067.zip"
path = download(url, "./", kind="zip", replace=True)

4.2 Data Loading

import math
import os
import random
import time

import numpy as np
from matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image

import mindspore as ms
import mindspore.nn as nn
from mindspore import ops as P
from mindspore.ops import add
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
from mindspore import set_context, load_checkpoint, save_checkpoint, export
from mindspore.train import Model
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig

os.environ['GLOG_v'] = '3'                # Log level: 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG)
os.environ['GLOG_logtostderr'] = '0'      # 0: log to file, 1: log to screen
os.environ['GLOG_log_dir'] = '../../log'  # log directory
os.environ['GLOG_stderrthreshold'] = '2'  # also echo to screen: 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG)

set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  # run in graph mode on CPU
Configure the parameters used later for training, validation, and inference:
# Garbage classification dataset labels and the dictionary used for label mapping.
garbage_classes = {
    '干垃圾': ['貝殼', '打火機', '舊鏡子', '掃把', '陶瓷碗', '牙刷', '一次性筷子', '臟污衣服'],
    '可回收物': ['報紙', '玻璃制品', '籃球', '塑料瓶', '硬紙板', '玻璃瓶', '金屬制品', '帽子', '易拉罐', '紙張'],
    '濕垃圾': ['菜葉', '橙皮', '蛋殼', '香蕉皮'],
    '有害垃圾': ['電池', '藥片膠囊', '熒光燈', '油漆桶']
}

class_cn = ['貝殼', '打火機', '舊鏡子', '掃把', '陶瓷碗', '牙刷', '一次性筷子', '臟污衣服',
            '報紙', '玻璃制品', '籃球', '塑料瓶', '硬紙板', '玻璃瓶', '金屬制品', '帽子', '易拉罐', '紙張',
            '菜葉', '橙皮', '蛋殼', '香蕉皮',
            '電池', '藥片膠囊', '熒光燈', '油漆桶']
class_en = ['Seashell', 'Lighter', 'Old Mirror', 'Broom', 'Ceramic Bowl', 'Toothbrush', 'Disposable Chopsticks', 'Dirty Cloth',
            'Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard', 'Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper',
            'Vegetable Leaf', 'Orange Peel', 'Eggshell', 'Banana Peel',
            'Battery', 'Tablet capsules', 'Fluorescent lamp', 'Paint bucket']

index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5,
            'Disposable Chopsticks': 6, 'Dirty Cloth': 7, 'Newspaper': 8, 'Glassware': 9, 'Basketball': 10,
            'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16,
            'Paper': 17, 'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,
            'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25}

# Training hyperparameters
config = EasyDict({
    "num_classes": 26,
    "image_height": 224,
    "image_width": 224,
    # "data_split": [0.9, 0.1],
    "backbone_out_channels": 1280,
    "batch_size": 16,
    "eval_batch_size": 8,
    "epochs": 10,
    "lr_max": 0.05,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "save_ckpt_epochs": 1,
    "dataset_path": "./data_en",
    "class_index": index_en,
    "pretrained_ckpt": "./mobilenetV2-200_1067.ckpt"
})
Data preprocessing

Use the ImageFolderDataset API to read the garbage classification dataset and process it as a whole.

When reading the dataset, specify the training set and the test set. First normalize the whole dataset and reorder the image channels. Then apply RandomCropDecodeResize, RandomHorizontalFlip, RandomColorAdjust, and shuffle to the training set to enrich the training data; apply Decode, Resize, and CenterCrop to the test set; finally return the processed dataset.

def create_dataset(dataset_path, config, training=True, buffer_size=1000):
    """Create a train or eval dataset.

    Args:
        dataset_path (string): the path of the dataset.
        config (struct): the train and eval config for different platforms.

    Returns:
        train_dataset or val_dataset
    """
    data_path = os.path.join(dataset_path, 'train' if training else 'test')
    ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)
    resize_height = config.image_height
    resize_width = config.image_width

    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255],
                               std=[0.229*255, 0.224*255, 0.225*255])
    change_swap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    if training:
        crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
        horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
        color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)

        train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]
        train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)
        train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        train_ds = train_ds.shuffle(buffer_size=buffer_size)
        ds = train_ds.batch(config.batch_size, drop_remainder=True)
    else:
        decode_op = C.Decode()
        resize_op = C.Resize((int(resize_width/0.875), int(resize_width/0.875)))
        center_crop = C.CenterCrop(resize_width)

        eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
        eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)
        eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)

    return ds
Display some of the processed data:

ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
print(ds.get_dataset_size())
data = next(ds.create_dict_iterator(output_numpy=True))
images = data['image']
labels = data['label']

for i in range(1, 5):
    plt.subplot(2, 2, i)
    plt.imshow(np.transpose(images[i], (1, 2, 0)))
    plt.title('label: %s' % class_en[labels[i]])
    plt.xticks([])
plt.show()

[Figure: sample preprocessed images with their labels]

5. Building the MobileNetV2 Model

When defining the modules of the MobileNetV2 network with MindSpore, each module must inherit from mindspore.nn.Cell. Cell is the base class of all neural network layers (Conv2d, etc.).

The layers of the network are declared in advance in the __init__ method, and the forward pass is then built by defining the construct method. The original model uses ReLU6 as its activation function and a global average pooling layer as its pooling module.

__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class GlobalAvgPooling(nn.Cell):
    """Global avg pooling definition.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def construct(self, x):
        x = P.mean(x, (2, 3))
        return x


class ConvBNReLU(nn.Cell):
    """Convolution/Depthwise fused with Batchnorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group. 1 for plain convolution, input channel for depthwise. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        in_channels = in_planes
        out_channels = out_planes
        if groups == 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)
        else:
            out_channels = in_planes
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',
                             padding=padding, group=in_channels)
        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class InvertedResidual(nn.Cell):
    """MobileNetV2 inverted residual block definition.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        expand_ratio (int): Expand ratio of input channel.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        self.cast = P.Cast()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            return P.add(identity, x)
        return x


class MobileNetV2Backbone(nn.Cell):
    """MobileNetV2 backbone.

    Args:
        width_mult (float): Channel multiplier, rounded to a multiple of round_nearest. Default: 1.
        inverted_residual_setting (list): Inverted residual settings. Default: None.
        round_nearest (int): Round channels to this multiple. Default: 8.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MobileNetV2Backbone()
    """

    def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8,
                 input_channel=32, last_channel=1280):
        super(MobileNetV2Backbone, self).__init__()
        block = InvertedResidual
        # setting of inverted residual blocks
        self.cfgs = inverted_residual_setting
        if inverted_residual_setting is None:
            self.cfgs = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
        self.features = nn.SequentialCell(features)
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        return x

    def _initialize_weights(self):
        """Initialize weights."""
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                          m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))

    @property
    def get_features(self):
        return self.features


class MobileNetV2Head(nn.Cell):
    """MobileNetV2 classification head.

    Args:
        input_channel (int): Number of backbone output channels. Default: 1280.
        num_classes (int): Number of classes. Default: 1000.
        has_dropout (bool): Whether dropout is used. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MobileNetV2Head(num_classes=1000)
    """

    def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):
        super(MobileNetV2Head, self).__init__()
        # mobilenet head
        head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)]
                if not has_dropout else
                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])
        self.head = nn.SequentialCell(head)
        self.need_activation = True
        if activation == "Sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "Softmax":
            self.activation = nn.Softmax()
        else:
            self.need_activation = False
        self._initialize_weights()

    def construct(self, x):
        x = self.head(x)
        if self.need_activation:
            x = self.activation(x)
        return x

    def _initialize_weights(self):
        """Initialize weights."""
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

    @property
    def get_head(self):
        return self.head


class MobileNetV2(nn.Cell):
    """MobileNetV2 architecture.

    Args:
        num_classes (int): Number of classes. Default: 1000.
        width_mult (float): Channel multiplier. Default: 1.
        has_dropout (bool): Whether dropout is used. Default: False.
        inverted_residual_setting (list): Inverted residual settings. Default: None.
        round_nearest (int): Round channels to this multiple. Default: 8.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MobileNetV2(num_classes=26)
    """

    def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None,
                 round_nearest=8, input_channel=32, last_channel=1280):
        super(MobileNetV2, self).__init__()
        backbone = MobileNetV2Backbone(width_mult=width_mult,
                                       inverted_residual_setting=inverted_residual_setting,
                                       round_nearest=round_nearest, input_channel=input_channel,
                                       last_channel=last_channel)
        self.backbone = backbone.get_features
        self.head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=num_classes,
                                    has_dropout=has_dropout).get_head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


class MobileNetV2Combine(nn.Cell):
    """MobileNetV2Combine architecture.

    Args:
        backbone (Cell): The feature-extraction layers.
        head (Cell): The fully connected layers.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MobileNetV2Combine(backbone, head)
    """

    def __init__(self, backbone, head):
        super(MobileNetV2Combine, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.head = head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


def mobilenet_v2(backbone, head):
    return MobileNetV2Combine(backbone, head)

6. Training and Testing the MobileNetV2 Model

Training strategy

Models are often trained with a static learning rate such as 0.01. As the number of training steps grows, the model gradually converges, and the updates to the weight parameters should shrink accordingly to reduce jitter late in training. Training can therefore use a learning rate that decays over time. Common decay strategies are:

  • polynomial decay/square decay;
  • cosine decay;
  • exponential decay;
  • stage decay.

Here cosine decay is used:

def cosine_decay(total_steps, lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_steps=0):
    """Applies cosine decay to generate a learning rate array.

    Args:
        total_steps (int): total steps in training.
        lr_init (float): initial learning rate.
        lr_end (float): final learning rate.
        lr_max (float): max learning rate.
        warmup_steps (int): total steps in warmup epochs.

    Returns:
        list, learning rate array.
    """
    lr_init, lr_end, lr_max = float(lr_init), float(lr_end), float(lr_max)
    decay_steps = total_steps - warmup_steps
    lr_all_steps = []
    inc_per_step = (lr_max - lr_init) / warmup_steps if warmup_steps else 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + inc_per_step * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_all_steps.append(lr)
    return lr_all_steps

During training, checkpoints can be added to save the model parameters, both for inference and for resuming after an interruption. Typical scenarios:

  • Inference after training
  1. Save the model parameters after training finishes, for inference or prediction.
  2. During training, validate accuracy in real time and keep the parameters with the highest accuracy for prediction.
  • Retraining
  1. For long training tasks, save checkpoints during training so that an abnormal exit does not force training to restart from the initial state.
  2. Fine-tuning: train a model and save its parameters, then train it further on a second, similar task.

Here a MobileNetV2 pretrained on ImageNet is loaded and fine-tuned: only the replaced FC layer at the end is trained, and checkpoints are saved during training.
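The "keep the most accurate parameters" scenario above boils down to tracking the best validation metric across epochs. A framework-free sketch, with made-up accuracy values:

```python
# Return the (epoch, accuracy) pair with the best validation accuracy;
# in practice you would save a checkpoint file at that epoch.
def best_checkpoint(accuracies):
    best_epoch, best_acc = 0, float('-inf')
    for epoch, acc in enumerate(accuracies, start=1):
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
    return best_epoch, best_acc

print(best_checkpoint([0.71, 0.78, 0.77, 0.82, 0.80]))  # (4, 0.82)
```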

def switch_precision(net, data_type):
    if ms.get_context('device_target') == "Ascend":
        net.to_float(data_type)
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.to_float(ms.float32)
Model training and testing

Before formal training, define the training function, read the data, instantiate the model, and define the optimizer and the loss function.

First, a brief introduction to the loss function and the optimizer:

  • Loss function: also called the objective function, it measures how far the predictions are from the actual values. Deep learning shrinks the loss value through repeated iteration, and a well-chosen loss function can effectively improve model performance.

  • Optimizer: minimizes the loss function, thereby improving the model during training.

Once a loss function is defined, its gradient with respect to the weights can be computed. The gradient tells the optimizer in which direction to adjust the weights to improve the model.

Before training MobileNetV2, the parameters of the MobileNetV2Backbone layers are frozen so that their weights are not updated during training; only the parameters of the MobileNetV2Head module are updated.

MindSpore provides loss functions such as SoftmaxCrossEntropyWithLogits, L1Loss, and MSELoss. Here SoftmaxCrossEntropyWithLogits is used.

During training and testing, the loss value is printed. It fluctuates, but overall it decreases and the accuracy rises. The exact loss values differ from run to run because of randomness.

After each epoch is printed, the model computes test accuracy on the test set; the printed accuracy shows that the predictive ability of the MobileNetV2 model keeps improving.
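How a gradient steers the optimizer can be seen on a toy loss f(w) = (w - 3)^2, minimized by plain gradient descent; all numbers here are illustrative:

```python
# Gradient descent on a one-parameter loss: step against the gradient.
def loss(w):
    return (w - 3.0) ** 2

def grad(w):
    return 2.0 * (w - 3.0)

w, lr = 0.0, 0.1
for _ in range(50):
    w -= lr * grad(w)  # optimizer step: move opposite the gradient

print(round(w, 3))  # converges toward the minimum at w = 3
```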

from mindspore.amp import FixedLossScaleManager
import time

LOSS_SCALE = 1024

train_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
step_size = train_dataset.get_dataset_size()

backbone = MobileNetV2Backbone()  # last_channel=config.backbone_out_channels

# Freeze parameters of backbone. You can comment these two lines.
for param in backbone.get_parameters():
    param.requires_grad = False

# load parameters from the pretrained model
load_checkpoint(config.pretrained_ckpt, backbone)

head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)

# define loss, optimizer, and model
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(LOSS_SCALE, drop_overflow_update=False)
lrs = cosine_decay(config.epochs * step_size, lr_max=config.lr_max)
opt = nn.Momentum(network.trainable_params(), lrs, config.momentum, config.weight_decay,
                  loss_scale=LOSS_SCALE)

# Training loop.
def train_loop(model, dataset, loss_fn, optimizer):
    # forward computation
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss

    # mindspore.value_and_grad builds grad_fn, which returns the loss and the gradients.
    # grad_position is None because we differentiate w.r.t. the trainable parameters.
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    # one-step training
    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)
        if batch % 10 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

# Evaluation loop.
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

print("============== Starting Training ==============")
# For time reasons, training here runs only 2 epochs; adjust as needed.
epoch_begin_time = time.time()
epochs = 2
for t in range(epochs):
    begin_time = time.time()
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(network, train_dataset, loss, opt)
    ms.save_checkpoint(network, "save_mobilenetV2_model.ckpt")
    end_time = time.time()
    times = end_time - begin_time
    print(f"per epoch time: {times}s")
    test_loop(network, eval_dataset, loss)
epoch_end_time = time.time()
times = epoch_end_time - epoch_begin_time
print(f"total time:  {times}s")
print("============== Training Success ==============")

[Figure: training loss and test accuracy output]

7. Model Inference

Load the model checkpoint for inference. When loading data with the load_checkpoint API, pass the parameters into the original network, not into a training network wrapped with an optimizer and a loss function.

CKPT = "save_mobilenetV2_model.ckpt"

def image_process(image):
    """Process one image at a time.

    Args:
        image: shape (H, W, C)
    """
    mean = [0.485*255, 0.456*255, 0.406*255]
    std = [0.229*255, 0.224*255, 0.225*255]
    image = (np.array(image) - mean) / std
    image = image.transpose((2, 0, 1))
    img_tensor = Tensor(np.array([image], np.float32))
    return img_tensor

def infer_one(network, image_path):
    image = Image.open(image_path).resize((config.image_height, config.image_width))
    logits = network(image_process(image))
    pred = np.argmax(logits.asnumpy(), axis=1)[0]
    print(image_path, class_en[pred])

def infer():
    backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
    head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
    network = mobilenet_v2(backbone, head)
    load_checkpoint(CKPT, network)
    for i in range(91, 100):
        infer_one(network, f'data_en/test/Cardboard/000{i}.jpg')

infer()

[Figure: inference results on test images]

8. Exporting AIR/GEIR/ONNX Model Files

Export an AIR model file for subsequent model conversion and inference on the Atlas 200 DK. This currently requires a MindSpore + Ascend environment.
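For reference, a typical atc invocation for the exported AIR file might look like the following. This is a hedged sketch: the exact flags, in particular --soc_version, depend on your CANN toolkit version and target hardware.

```shell
# Hypothetical atc command: convert mobilenetv2.air to an offline .om model.
# --framework=1 selects MindSpore as the source framework.
atc --model=mobilenetv2.air \
    --framework=1 \
    --output=mobilenetv2 \
    --soc_version=Ascend310
```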

backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)
load_checkpoint(CKPT, network)

input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)
# export(network, Tensor(input), file_name='mobilenetv2.air', file_format='AIR')
# export(network, Tensor(input), file_name='mobilenetv2.pb', file_format='GEIR')
export(network, Tensor(input), file_name='mobilenetv2.onnx', file_format='ONNX')
