摘要:
????????記錄MindSpore?AI框架使用ViT模型在ImageNet圖像數據分類上進行訓練、驗證、推理的過程和方法。包括環境準備、下載數據集、數據集加載、模型解析與構建、模型訓練與推理等。
一、概念
1. ViT模型
Vision Transformer
自注意結構模型
Self-Attention
????????Transformer模型
????????????????能夠訓練具有超過100B規模的參數模型
領域
????????自然語言處理
????????計算機視覺
不依賴卷積操作
2.模型結構
ViT模型主體結構
從下往上
最下面主輸入數據集
????????原圖像劃分為多個patch(圖像塊)
????????????????二維patch(不考慮channel)轉換為一維向量
中間backbone基于Transformer模型Encoder部分
????????Multi-head Attention結構
????????部分結構順序有調整
????????????????Normalization位置不同
上面Blocks堆疊后接全連接層Head
附加輸入類別向量
輸出識別分類結果
二、環境準備
確保安裝了Python環境和MindSpore
%%capture captured_output
# 實驗環境已經預裝了mindspore==2.2.14,如需更換mindspore版本,可更改下面mindspore的版本號
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看當前 mindspore 版本
!pip show mindspore
輸出:
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
三、數據準備
1.下載、解壓數據集
下載源
http://image-net.org
ImageNet數據集
本案例應用數據集是從ImageNet篩選的子集。
from download import download
?
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
?
path = download(dataset_url, path, kind="zip", replace=True)
輸出:
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
2.數據集路徑結構:
.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/
3.加載數據集
import os
?
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
?
?
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
?
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
?
trans_train = [transforms.RandomCropDecodeResize(size=224,scale=(0.08, 1.0),ratio=(0.75, 1.333)),transforms.RandomHorizontalFlip(prob=0.5),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
四、模型解析
1.Transformer基本原理
Transformer模型
基于Attention機制的編碼器-解碼器型結構
模型結構圖:
多個Encoder和Decoder模塊所組成
Encoder和Decoder詳細結構圖:
Encoder與Decoder結構組成
多頭注意力Multi-Head Attention層
????基于自注意力Self-Attention機制
????多個Self-Attention并行組成
Feed Forward層
Normaliztion層
殘差連接(Residual Connection),圖中的“Add”
2.Attention模塊
Self-Attention核心內容
為輸入向量的每個單詞學習一個權重
????????給定查詢向量Query
????????計算Query和各個Key的相似性或者相關性
????????????????得到注意力分布
????????????????得到每個Key對應Value的權重系數
????????對Value進行加權求和得到最終的Attention數值。
Self-Attention機制:
(1) 最初的輸入向量
經過Embedding層
????????映射成dim x 3
????????分割成三個向量
????????????????Q(Query)
????????????????K(Key)
????????????????V(Value)
輸入向量為一個一維向量序列(x1,x2,x3)
每個一維向量經過Embedding層映射出Q、K、V三個向量
????????只是Embedding矩陣不同
????????矩陣參數通過學習得到
向量之間關聯
通過Q、K、V三個矩陣可計算
其中兩個向量點乘獲得權重
另一個向量承載權重向加的結果
(2) 自注意力機制的自注意主要體現
Q、K、V來源于其自身
自注意過程
????????提取輸入的不同順序的向量的聯系與特征
????????通過不同順序向量之間的聯系緊密性表現
????????????????Q與K乘積經過Softmax的結果
獲取Q,K,V向量間權重
????????Q、K點乘
????????除以維度的平方根
????????Softmax處理所有向量的結果
(3) 全局自注意
向量V與Q、K經過Softmax結果
????????weight sum
每一組Q、K、V最后都有一個V輸出
當前向量結合其他向量關聯權重得到結果
Self-Attention全部過程:
多頭注意力機制
分割self-Attention處理的向量為多個Head部分處理
????????并行加速
????????保持參數總量不變
同樣的query, key和value映射為高維空間(Q,K,V)
????????不同子空間(Q_0,K_0,V_0)
????????分開計算自注意力
????????最后再合并不同子空間中的注意力信息。
同一個輸入向量
多個注意力機制可以并行加速處理
處理時更充分的分析和利用了向量特征
下圖中ai和aj是同一個向量分割而得
以下是Multi-Head Attention代碼:
from mindspore import nn, ops
?
class Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()
?self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)
?self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)
?def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)
?return out
Transformer Encoder
多結構拼接形成Transformer基礎結構
Self-Attention
Feed Forward
Residual Connection
Feed Forward,Residual Connection結構代碼:
from typing import Optional, Dict
?
class FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)
?def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)
?return x
?
class ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = cell
?def construct(self, x):"""ResidualCell construct."""return self.cell(x) + x
Self-Attention構建ViT模型中的TransformerEncoder部分:
ViT模型Transformer不同
Normalization放在Self-Attention和Feed Forward之前
其他結構不變
Transformer結構圖
多個子encoder堆疊構建模型編碼器
ViT模型配置超參數num_layers
????????確定堆疊層數
Residual Connection,Normalization的結構
保證信息經過深層處理不退化
增強模型泛化能力
TransformerEncoder結構和多層感知器(MLP)結合
構成了ViT模型的backbone部分
class TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []
?for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)
?feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)
?layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)
?def construct(self, x):"""Transformer construct."""return self.layers(x)
ViT模型的輸入
傳統的Transformer結構
處理自然語言領域的詞向量
(Word Embedding or Word Vector),
詞向量是一維向量堆疊
圖片是二維矩陣堆疊,
多頭注意力機制處理一維詞向量堆疊時會提取詞向量之間的聯系也就是上下文語義
ViT模型中:
輸入圖像每個channel卷積操作劃分1616個patch
????????一幅輸入224 x 224的圖像卷積處理
????????????????得到16 x 16個patch
????????????????每一個patch的大小就是14?x 14
每個patch矩陣拉伸成為一維向量
獲得近似詞向量堆疊的效果
????????14 x 14patch轉換為長度196的向量
圖像輸入網絡經過的第一步處理。
Patch Embedding代碼:
class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4
?def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()
?self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
?def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))
?return x
輸入圖像劃分patch后
????????經過pos_embedding
????????????????class_embedding兩個過程。
class_embedding借鑒BERT模型用于文本分類
每一個word vector之前增加一個類別值
196維向量加上class_embedding變為197維
class_embedding是一個可以學習的參數
經過網絡的不斷訓練,輸出向量的第一個維度的輸出來決定最后的輸出類別;
輸入16 x 16patch
輸出16x16個class_embedding進行分類。
pos_embedding也是一組可以學習的參數
????????加入patch矩陣
pos_embedding有4種方案
????????采用一維pos_embedding
????????由于class_embedding是加在pos_embedding之前
????????所以pos_embedding維度會比patch拉伸后的維度加1。
五、整體構建ViT
構建ViT模型代碼
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
?
?
def init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)
?
?
class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()
?self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patches
?self.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)
?self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)
?self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)
?def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embedding
?x = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)
?return x
整體流程圖如下所示:
六、模型訓練與推理
1.模型訓練
模型開始訓練
設定損失函數
????????優化器
????????回調函數
調整epoch_size
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
?
# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
?
# construct model
network = ViT()
?
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
?
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
?
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),max_lr=0.00005,total_step=epoch_size * step_size,step_per_epoch=step_size,decay_epoch=10)
?
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
?
?
# define loss function
class CrossEntropySmooth(LossBase):"""CrossEntropy."""
?def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
?def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return loss
?
?
network_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)
?
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
?
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
?
# train model
model.train(epoch_size,dataset_train,callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],dataset_sink_mode=False,)
輸出:
Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms
2.模型驗證
模型驗證
ImageFolderDataset接口用于讀取數據集
CrossEntropySmooth接口用于損失函數實例化
Model等接口用于編譯模型
步驟:
數據增強
定義ViT網絡結構
加載預訓練模型參數
設置損失函數
設置評價指標
????????Top_1_Accuracy輸出最大值為預測結果
????????Top_5_Accuracy輸出前5的值為預測結果
????????兩個指標的值越大,代表模型準確率越高
編譯模型
驗證
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
?
trans_val = [transforms.Decode(),transforms.Resize(224 + 32),transforms.CenterCrop(224),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
?
# construct model
network = ViT()
?
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
?
network_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)
?
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
?
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
?
# evaluate model
result = model.eval(dataset_val)
print(result)
輸出:
{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}
3.模型推理
推理圖片數據預處理
resize
normalize
匹配訓練輸入數據
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
?
trans_infer = [transforms.Decode(),transforms.Resize([224, 224]),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_infer = dataset_infer.map(operations=trans_infer,input_columns=["image"],num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
模型推理
調用模型predict方法
index2label獲取對應標簽
自定義show_result接口在對應圖片上寫結果
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
?
?
class Color(Enum):"""dedine enum color."""red = (0, 0, 255)green = (0, 255, 0)blue = (255, 0, 0)cyan = (255, 255, 0)yellow = (0, 255, 255)magenta = (255, 0, 255)white = (255, 255, 255)black = (0, 0, 0)
?
?
def check_file_exist(file_name: str):"""check_file_exist."""if not os.path.isfile(file_name):raise FileNotFoundError(f"File `{file_name}` does not exist.")
?
?
def color_val(color):"""color_val."""if isinstance(color, str):return Color[color].valueif isinstance(color, Color):return color.valueif isinstance(color, tuple):assert len(color) == 3for channel in color:assert 0 <= channel <= 255return colorif isinstance(color, int):assert 0 <= color <= 255return color, color, colorif isinstance(color, np.ndarray):assert color.ndim == 1 and color.size == 3assert np.all((color >= 0) & (color <= 255))color = color.astype(np.uint8)return tuple(color)raise TypeError(f'Invalid type for color: {type(color)}')
?
?
def imread(image, mode=None):"""imread."""if isinstance(image, pathlib.Path):image = str(image)
?if isinstance(image, np.ndarray):passelif isinstance(image, str):check_file_exist(image)image = Image.open(image)if mode:image = np.array(image.convert(mode))else:raise TypeError("Image must be a `ndarray`, `str` or Path object.")
?return image
?
?
def imwrite(image, image_path, auto_mkdir=True):"""imwrite."""if auto_mkdir:dir_name = os.path.abspath(os.path.dirname(image_path))if dir_name != '':dir_name = os.path.expanduser(dir_name)os.makedirs(dir_name, mode=777, exist_ok=True)
?image = Image.fromarray(image)image.save(image_path)
?
?
def imshow(img, win_name='', wait_time=0):"""imshow"""cv2.imshow(win_name, imread(img))if wait_time == 0: # prevent from hanging if windows was closedwhile True:ret = cv2.waitKey(1)
?closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1# if user closed window or if some key pressedif closed or ret != -1:breakelse:ret = cv2.waitKey(wait_time)
?
?
def show_result(img: str,result: Dict[int, float],text_color: str = 'green',font_scale: float = 0.5,row_width: int = 20,show: bool = False,win_name: str = '',wait_time: int = 0,out_file: Optional[str] = None) -> None:"""Mark the prediction results on the picture."""img = imread(img, mode="RGB")img = img.copy()x, y = 0, row_widthtext_color = color_val(text_color)for k, v in result.items():if isinstance(v, float):v = f'{v:.2f}'label_text = f'{k}: {v}'cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,font_scale, text_color)y += row_widthif out_file:show = Falseimwrite(img, out_file)
?if show:imshow(img, win_name, wait_time)
?
?
def index2label():"""Dictionary output for image numbers and categories of the ImageNet dataset."""metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")meta = io.loadmat(metafile, squeeze_me=True)['synsets']
?nums_children = list(zip(*meta))[4]meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
?_, wnids, classes = list(zip(*meta))[:3]clssname = [tuple(clss.split(', ')) for clss in classes]wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
?mapping = {}for index, (_, class_name) in enumerate(wind2class_name):mapping[index] = class_name[0]return mapping
?
?
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):image = image["image"]image = ms.Tensor(image)prob = model.predict(image)label = np.argmax(prob.asnumpy(), axis=1)mapping = index2label()output = {int(label): mapping[int(label)]}print(output)show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",result=output,out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
輸出:
{236: 'Doberman'}
推理過程完成后
推理文件夾下找圖片推理結果