昇思MindSpore學習筆記6-06計算機視覺--Vision Transormer圖像分類

摘要:

????????記錄MindSpore?AI框架使用ViT模型在ImageNet圖像數據分類上進行訓練、驗證、推理的過程和方法。包括環境準備、下載數據集、數據集加載、模型解析與構建、模型訓練與推理等。

一、

1. ViT模型

Vision Transformer

自注意結構模型

Self-Attention

????????Transformer模型

????????????????能夠訓練具有超過100B規模的參數模型

領域

????????自然語言處理

????????計算機視覺

不依賴卷積操作

2.模型結構

ViT模型主體結構

從下往上

最下面主輸入數據集

????????原圖像劃分為多個patch(圖像塊)

????????????????二維patch(不考慮channel)轉換為一維向量

中間backbone基于Transformer模型Encoder部分

????????Multi-head Attention結構

????????部分結構順序有調整

????????????????Normalization位置不同

上面Blocks堆疊后接全連接層Head

附加輸入類別向量

輸出識別分類結果

二、環境準備

確保安裝了Python環境和MindSpore

%%capture captured_output
# 實驗環境已經預裝了mindspore==2.2.14,如需更換mindspore版本,可更改下面mindspore的版本號
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看當前 mindspore 版本
!pip show mindspore

輸出:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

三、數據準備

1.下載、解壓數據集

下載源

http://image-net.org

ImageNet數據集

本案例應用數據集是從ImageNet篩選的子集。

from download import download
?
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
?
path = download(dataset_url, path, kind="zip", replace=True)

輸出:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

2.數據集路徑結構

.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/

3.加載數據集

import os
?
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
?
?
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
?
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
?
trans_train = [transforms.RandomCropDecodeResize(size=224,scale=(0.08, 1.0),ratio=(0.75, 1.333)),transforms.RandomHorizontalFlip(prob=0.5),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

四、模型解析

1.Transformer基本原理

Transformer模型

基于Attention機制的編碼器-解碼器型結構

模型結構圖:

多個Encoder和Decoder模塊所組成

Encoder和Decoder詳細結構圖:

Encoder與Decoder結構組成

多頭注意力Multi-Head Attention層

????基于自注意力Self-Attention機制

????多個Self-Attention并行組成

Feed Forward層

Normaliztion層

殘差連接(Residual Connection),圖中的“Add”

2.Attention模塊

Self-Attention核心內容

為輸入向量的每個單詞學習一個權重

????????給定查詢向量Query

????????計算Query和各個Key的相似性或者相關性

????????????????得到注意力分布

????????????????得到每個Key對應Value的權重系數

????????對Value進行加權求和得到最終的Attention數值。

Self-Attention機制:

(1) 最初的輸入向量

經過Embedding層

????????映射成dim x 3

????????分割成三個向量

????????????????Q(Query)

????????????????K(Key)

????????????????V(Value)

輸入向量為一個一維向量序列(x1,x2,x3)

每個一維向量經過Embedding層映射出Q、K、V三個向量

????????只是Embedding矩陣不同

????????矩陣參數通過學習得到

向量之間關聯

通過Q、K、V三個矩陣可計算

其中兩個向量點乘獲得權重

另一個向量承載權重向加的結果

(2) 自注意力機制的自注意主要體現

Q、K、V來源于其自身

自注意過程

????????提取輸入的不同順序的向量的聯系與特征

????????通過不同順序向量之間的聯系緊密性表現

????????????????Q與K乘積經過Softmax的結果

獲取Q,K,V向量間權重

????????Q、K點乘

????????除以維度的平方根

????????Softmax處理所有向量的結果

(3) 全局自注意

向量V與Q、K經過Softmax結果

????????weight sum

每一組Q、K、V最后都有一個V輸出

當前向量結合其他向量關聯權重得到結果

Self-Attention全部過程:

多頭注意力機制

分割self-Attention處理的向量為多個Head部分處理

????????并行加速

????????保持參數總量不變

同樣的query, key和value映射為高維空間(Q,K,V)

????????不同子空間(Q_0,K_0,V_0)

????????分開計算自注意力

????????最后再合并不同子空間中的注意力信息。

同一個輸入向量

多個注意力機制可以并行加速處理

處理時更充分的分析和利用了向量特征

下圖中ai和aj是同一個向量分割而得

以下是Multi-Head Attention代碼:

from mindspore import nn, ops
?
class Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()
?self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)
?self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)
?def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)
?return out

Transformer Encoder

多結構拼接形成Transformer基礎結構

Self-Attention

Feed Forward

Residual Connection

Feed Forward,Residual Connection結構代碼:

from typing import Optional, Dict
?
class FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)
?def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)
?return x
?
class ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = cell
?def construct(self, x):"""ResidualCell construct."""return self.cell(x) + x

Self-Attention構建ViT模型中的TransformerEncoder部分:

ViT模型Transformer不同

Normalization放在Self-Attention和Feed Forward之前

其他結構不變

Transformer結構圖

多個子encoder堆疊構建模型編碼器

ViT模型配置超參數num_layers

????????確定堆疊層數

Residual Connection,Normalization的結構

保證信息經過深層處理不退化

增強模型泛化能力

TransformerEncoder結構和多層感知器(MLP)結合

構成了ViT模型的backbone部分

class TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []
?for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)
?feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)
?layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)
?def construct(self, x):"""Transformer construct."""return self.layers(x)

ViT模型的輸入

傳統的Transformer結構

處理自然語言領域的詞向量

(Word Embedding or Word Vector),

詞向量是一維向量堆疊

圖片是二維矩陣堆疊,

多頭注意力機制處理一維詞向量堆疊時會提取詞向量之間的聯系也就是上下文語義

ViT模型中:

輸入圖像每個channel卷積操作劃分1616個patch

????????一幅輸入224 x 224的圖像卷積處理

????????????????得到16 x 16個patch

????????????????每一個patch的大小就是14?x 14

每個patch矩陣拉伸成為一維向量

獲得近似詞向量堆疊的效果

????????14 x 14patch轉換為長度196的向量

圖像輸入網絡經過的第一步處理。

Patch Embedding代碼:

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4
?def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()
?self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
?def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))
?return x

輸入圖像劃分patch后

????????經過pos_embedding

????????????????class_embedding兩個過程。

class_embedding借鑒BERT模型用于文本分類

每一個word vector之前增加一個類別值

196維向量加上class_embedding變為197維

class_embedding是一個可以學習的參數

經過網絡的不斷訓練,輸出向量的第一個維度的輸出來決定最后的輸出類別;

輸入16 x 16patch

輸出16x16個class_embedding進行分類。

pos_embedding也是一組可以學習的參數

????????加入patch矩陣

pos_embedding有4種方案

????????采用一維pos_embedding

????????由于class_embedding是加在pos_embedding之前

????????所以pos_embedding維度會比patch拉伸后的維度加1。

五、整體構建ViT

構建ViT模型代碼

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
?
?
def init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)
?
?
class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()
?self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patches
?self.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)
?self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)
?self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)
?def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embedding
?x = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)
?return x

整體流程圖如下所示:

六、模型訓練與推理

1.模型訓練

模型開始訓練

設定損失函數

????????優化器

????????回調函數

調整epoch_size

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
?
# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
?
# construct model
network = ViT()
?
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
?
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
?
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),max_lr=0.00005,total_step=epoch_size * step_size,step_per_epoch=step_size,decay_epoch=10)
?
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
?
?
# define loss function
class CrossEntropySmooth(LossBase):"""CrossEntropy."""
?def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
?def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return loss
?
?
network_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)
?
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
?
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
?
# train model
model.train(epoch_size,dataset_train,callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],dataset_sink_mode=False,)

輸出:

Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms

2.模型驗證

模型驗證

ImageFolderDataset接口用于讀取數據集

CrossEntropySmooth接口用于損失函數實例化

Model等接口用于編譯模型

步驟:

數據增強

定義ViT網絡結構

加載預訓練模型參數

設置損失函數

設置評價指標

????????Top_1_Accuracy輸出最大值為預測結果

????????Top_5_Accuracy輸出前5的值為預測結果

????????兩個指標的值越大,代表模型準確率越高

編譯模型

驗證

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
?
trans_val = [transforms.Decode(),transforms.Resize(224 + 32),transforms.CenterCrop(224),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
?
# construct model
network = ViT()
?
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
?
network_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)
?
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
?
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
?
# evaluate model
result = model.eval(dataset_val)
print(result)

輸出:

{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}

3.模型推理

推理圖片數據預處理

resize

normalize

匹配訓練輸入數據

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
?
trans_infer = [transforms.Decode(),transforms.Resize([224, 224]),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]
?
dataset_infer = dataset_infer.map(operations=trans_infer,input_columns=["image"],num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

模型推理

調用模型predict方法

index2label獲取對應標簽

自定義show_result接口在對應圖片上寫結果

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
?
?
class Color(Enum):"""dedine enum color."""red = (0, 0, 255)green = (0, 255, 0)blue = (255, 0, 0)cyan = (255, 255, 0)yellow = (0, 255, 255)magenta = (255, 0, 255)white = (255, 255, 255)black = (0, 0, 0)
?
?
def check_file_exist(file_name: str):"""check_file_exist."""if not os.path.isfile(file_name):raise FileNotFoundError(f"File `{file_name}` does not exist.")
?
?
def color_val(color):"""color_val."""if isinstance(color, str):return Color[color].valueif isinstance(color, Color):return color.valueif isinstance(color, tuple):assert len(color) == 3for channel in color:assert 0 <= channel <= 255return colorif isinstance(color, int):assert 0 <= color <= 255return color, color, colorif isinstance(color, np.ndarray):assert color.ndim == 1 and color.size == 3assert np.all((color >= 0) & (color <= 255))color = color.astype(np.uint8)return tuple(color)raise TypeError(f'Invalid type for color: {type(color)}')
?
?
def imread(image, mode=None):"""imread."""if isinstance(image, pathlib.Path):image = str(image)
?if isinstance(image, np.ndarray):passelif isinstance(image, str):check_file_exist(image)image = Image.open(image)if mode:image = np.array(image.convert(mode))else:raise TypeError("Image must be a `ndarray`, `str` or Path object.")
?return image
?
?
def imwrite(image, image_path, auto_mkdir=True):"""imwrite."""if auto_mkdir:dir_name = os.path.abspath(os.path.dirname(image_path))if dir_name != '':dir_name = os.path.expanduser(dir_name)os.makedirs(dir_name, mode=777, exist_ok=True)
?image = Image.fromarray(image)image.save(image_path)
?
?
def imshow(img, win_name='', wait_time=0):"""imshow"""cv2.imshow(win_name, imread(img))if wait_time == 0:  # prevent from hanging if windows was closedwhile True:ret = cv2.waitKey(1)
?closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1# if user closed window or if some key pressedif closed or ret != -1:breakelse:ret = cv2.waitKey(wait_time)
?
?
def show_result(img: str,result: Dict[int, float],text_color: str = 'green',font_scale: float = 0.5,row_width: int = 20,show: bool = False,win_name: str = '',wait_time: int = 0,out_file: Optional[str] = None) -> None:"""Mark the prediction results on the picture."""img = imread(img, mode="RGB")img = img.copy()x, y = 0, row_widthtext_color = color_val(text_color)for k, v in result.items():if isinstance(v, float):v = f'{v:.2f}'label_text = f'{k}: {v}'cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,font_scale, text_color)y += row_widthif out_file:show = Falseimwrite(img, out_file)
?if show:imshow(img, win_name, wait_time)
?
?
def index2label():"""Dictionary output for image numbers and categories of the ImageNet dataset."""metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")meta = io.loadmat(metafile, squeeze_me=True)['synsets']
?nums_children = list(zip(*meta))[4]meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
?_, wnids, classes = list(zip(*meta))[:3]clssname = [tuple(clss.split(', ')) for clss in classes]wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
?mapping = {}for index, (_, class_name) in enumerate(wind2class_name):mapping[index] = class_name[0]return mapping
?
?
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):image = image["image"]image = ms.Tensor(image)prob = model.predict(image)label = np.argmax(prob.asnumpy(), axis=1)mapping = index2label()output = {int(label): mapping[int(label)]}print(output)show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",result=output,out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

輸出:

{236: 'Doberman'}

推理過程完成后

推理文件夾下找圖片推理結果

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/43985.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/43985.shtml
英文地址,請注明出處:http://en.pswp.cn/web/43985.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL(基礎篇)

DDL (Data Definition Language) 數據定義語言&#xff0c;用來定義數據庫對象(數據庫&#xff0c;表&#xff0c; 字段) DML (Data Manipulation Languag) 數據操作語言&#xff0c;用來對數據庫表中的數據進行增刪改 DQL (Data Query Language) 數據查詢語言&#xff0c;用…

前綴,中綴,后綴表達式

前綴表達式 前綴表達式&#xff08;也稱為波蘭式&#xff09;是一種將運算符放在操作數之前的表示數學表達式的方法。在前綴表達式中&#xff0c;操作符出現在它們所操作的操作數之前。 例如&#xff0c;將中綴表達式5 3轉換為前綴表達式&#xff0c;可以寫成 5 3。在這個例…

9 個讓 Python 性能更高的小技巧,你掌握了嗎?

我們經常聽到 “Python 太慢了”&#xff0c;“Python 性能不行”這樣的觀點。但是&#xff0c;只要掌握一些編程技巧&#xff0c;就能大幅提升 Python 的運行速度。 今天就讓我們一起來看下讓 Python 性能更高的 9 個小技巧 python學習資料分享&#xff08;無償&#xff09;…

數據(圖像)增廣

一、數據增強 1、增加一個已有數據集&#xff0c;使得有更多的多樣性&#xff0c;比如加入不同的背景噪音、改變圖片的顏色和形狀。 2、增強數據是在線生成的 3、增強類型&#xff1a; &#xff08;1&#xff09;翻轉 &#xff08;2&#xff09;切割 &#xff08;3&#xf…

金龍魚:只是躺槍?

中儲糧罐車運輸油罐混用事件持續發酵&#xff0c;食用油板塊集體躺槍。 消費者憤怒的火&#xff0c;怕是會讓食用油企們一點就著。 今天&#xff0c;我們聊聊“油”茅——金龍魚。 一邊是業內人士指出&#xff0c;油罐混用的現象普遍存在&#xff0c;另一邊是金龍魚回應稱&am…

2972.力扣每日一題7/11 Java(擊敗100%)

博客主頁&#xff1a;音符猶如代碼系列專欄&#xff1a;算法練習關注博主&#xff0c;后期持續更新系列文章如果有錯誤感謝請大家批評指出&#xff0c;及時修改感謝大家點贊&#x1f44d;收藏?評論? 目錄 解題思路 解題方法 時間復雜度 空間復雜度 Code 解題思路 該問…

RISC-V主要指令集介紹及規則

推薦資料 RISC-V Reader / RISC-V開放架構設計之道&#xff0c;適合新手閱讀。 概述 RISC-V的模塊化到底是如何實現的呢&#xff1f; 核心部分&#xff1a;RV32I&#xff0c;代表32位字長的整型指令集&#xff08;Integer&#xff09;&#xff0c;包含了許多整型指令如load…

在C++項目中添加錄像功能:從攝像頭捕獲到視頻文件的保存

在C項目中添加錄像功能&#xff1a;從攝像頭捕獲到視頻文件的保存 在這篇博客中&#xff0c;我們將介紹如何在一個現有的C項目中添加錄像功能&#xff0c;具體包括如何從攝像頭捕獲圖像并將其保存為視頻文件。我們將使用OpenCV庫來處理圖像捕獲和視頻寫入。 目錄 引言準備工…

Python學習筆記35:進階篇(二十四)pygame的使用之音頻文件播放

前言 基礎模塊的知識通過這么長時間的學習已經有所了解&#xff0c;更加深入的話需要通過完成各種項目&#xff0c;在這個過程中逐漸學習&#xff0c;成長。 我們的下一步目標是完成python crash course中的外星人入侵項目&#xff0c;這是一個2D游戲項目。在這之前&#xff…

元組列表之案例

1.列表推導式 基本語法&#xff1a; [表達式 for語句1 if 語句1 for語句2 if語句2 ........ ] 1.零到九的平方列表 a [i*i for i in range(10)] print(a) 2.for 循環前面加if else #如果是偶數乘以2&#xff0c;如果是奇數直接輸出 a [i*2 if i%2 0 else i for i in ran…

什么是生成器函數?

生成器函數&#xff08;Generator Function&#xff09;是 JavaScript 中一種特殊的函數&#xff0c;它可以在執行過程中暫停并在之后恢復執行。生成器函數使用 function* 語法定義&#xff0c;并且內部使用 yield 表達式來暫停函數執行并返回一個值。每次調用生成器函數返回的…

rabbitmq集群創建admin用戶之后,提示can access virtual hosts是No access狀態

問題描述&#xff1a; 因業務需要使用的rabbitmq是3.7.8版本的&#xff0c;rabbitmq在3.3.0之后就允許使用guest賬號的權限了&#xff0c;所以需要創建一個administrator標簽的用戶。 如下操作創建的用戶&#xff1a; 創建完成之后就提示如下的報錯&#xff1a; 注&#xff1a…

php表單提交并自動發送郵件給某個郵箱(示例源碼下載)

只需要將以下代碼內容進行復制即可用到自己的程序/API接口中&#xff1a; <?php if(!empty($_POST[is_post]) && $_POST[is_post]1){$url "https://www.aoksend.com/index/api/send_email";$name $_POST[name];$email $_POST[email];$subject $_POS…

探索Mojo模型:解鎖機器學習模型的可解釋性之旅

探索Mojo模型&#xff1a;解鎖機器學習模型的可解釋性之旅 在人工智能和機器學習領域&#xff0c;模型的可解釋性是一個至關重要的議題。隨著模型變得越來越復雜&#xff0c;理解模型的決策過程成為了一個挑戰。Mojo模型作為一種模型序列化格式&#xff0c;提供了一種方法來部…

Python 給存入 Redis 的鍵值對設置過期時間

Redis 是一種內存中的數據存儲系統&#xff0c;與許多傳統數據庫相比&#xff0c;它具有一些優勢&#xff0c;其中之一就是可以設置數據的過期時間。通過 Redis 的過期時間設置&#xff0c;可以為存儲在 Redis 中的數據設置一個特定的生存時間。一旦數據到達過期時間&#xff0…

mybatis日志記錄方案

首先對指定表進行監控 對表進行監控,那么就要使用的是statementInterceptor 攔截器 使用攔截器那么就要寫intercepts寫攔截條件進行攔截 監控只對與增刪改 查詢不進行監控 對于字段的監控,是誰修改了字段,那么就進行報警,或者提醒 消息提醒使用釘釘機器人進行消息提醒 P…

軟鏈接node_modules

公司項目很多微應用的子項目公用同一套模板&#xff0c;也就會使用同一個node_modules 1.先創建3個同樣的項目,并安裝一個其中的一個node_modules給他丟到外邊 2.win r -------> cmd --------> ctrlshift enter(已管理員身份打開cmd) 3.在窗口分別執行以下代碼…

視頻減小技巧:十大頂級視頻壓縮軟件

視頻壓縮軟件會盡可能地壓縮視頻&#xff0c;以便上傳到各個網站。通常&#xff0c;4K 或更高質量的視頻體積更大。壓縮軟件有助于壓縮體積。在這里&#xff0c;我們來討論一下 10 款最佳視頻壓縮軟件。 十大頂級視頻壓縮軟件 1. 奇客壓縮寶 奇客壓縮寶是由Geekersoft公司開發…

基于SpringBoot+MySQL的租房項目+文檔

&#x1f497;博主介紹&#x1f497;&#xff1a;?在職Java研發工程師、專注于程序設計、源碼分享、技術交流、專注于Java技術領域和畢業設計? 溫馨提示&#xff1a;文末有 CSDN 平臺官方提供的老師 Wechat / QQ 名片 :) Java精品實戰案例《700套》 2025最新畢業設計選題推薦…

數據庫系統中的Undo和Redo

在數據庫管理系統&#xff08;DBMS&#xff09;中&#xff0c;undo 和 redo 是兩種用于事務管理和故障恢復的重要機制。它們主要涉及事務的提交、回滾以及系統故障后的數據恢復。 Undo&#xff08;撤銷&#xff09; 作用&#xff1a;undo 用于撤銷未提交事務所做的修改&#…