【MindSpore學習打卡】應用實踐-計算機視覺-深入解析 Vision Transformer(ViT):從原理到實踐

在近年來的深度學習領域,Transformer模型憑借其在自然語言處理(NLP)中的卓越表現,迅速成為研究熱點。尤其是基于自注意力(Self-Attention)機制的模型,更是推動了NLP的飛速發展。然而,隨著研究的深入,Transformer模型不僅在NLP領域大放異彩,還被引入到計算機視覺領域,形成了Vision Transformer(ViT)。ViT模型在不依賴傳統卷積神經網絡(CNN)的情況下,依然能夠在圖像分類任務中取得優異的效果。本文將深入解析ViT模型的結構、特點,并通過代碼示例展示如何使用MindSpore框架實現ViT模型的訓練、驗證和推理。

ViT模型結構

ViT模型的主體結構基于Transformer模型的編碼器(Encoder)部分,其整體結構如下圖所示:

vit-architecture

模型特點

為什么要使用Patch Embedding?

在傳統的Transformer模型中,輸入通常是一維的詞向量序列,而圖像數據是二維的像素矩陣。為了將圖像數據轉換為Transformer可以處理的形式,我們需要將圖像劃分為多個小塊(patch),并將每個patch轉換為一維向量。這一過程稱為Patch Embedding。通過這種方式,我們可以將圖像數據轉換為類似于詞向量的形式,從而利用Transformer模型處理圖像數據。
為什么要使用位置編碼(Position Embedding)?

由于Transformer模型在處理輸入序列時不考慮順序信息,因此在圖像數據中,patch之間的空間關系可能會丟失。為了解決這個問題,我們引入了位置編碼(Position Embedding),它為每個patch增加了位置信息,使得模型能夠識別不同patch之間的空間關系。這對于保留圖像的空間結構信息非常重要。

  1. Patch Embedding:輸入圖像被劃分為多個patch(圖像塊),然后將每個二維patch轉換為一維向量,并加上類別向量和位置向量作為模型輸入。
  2. Transformer Encoder:模型主體的Block結構基于Transformer的Encoder部分,主要結構是多頭注意力(Multi-Head Attention)和前饋神經網絡(Feed Forward)。
  3. 分類頭(Head):在Transformer Encoder堆疊后接一個全連接層,用于分類。

環境準備與數據讀取

開始實驗之前,請確保本地已經安裝了Python環境和MindSpore。

首先下載本案例的數據集,該數據集是從ImageNet中篩選出來的子集。數據集路徑結構如下:

.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/
from download import download
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms# 下載數據集
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)trans_train = [transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),transforms.RandomHorizontalFlip(prob=0.5),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

Transformer基本原理

Transformer模型源于2017年的一篇文章,其主要結構為多個編碼器和解碼器模塊。編碼器和解碼器由多頭注意力(Multi-Head Attention)、前饋神經網絡(Feed Forward)、歸一化層(Normalization)和殘差連接(Residual Connection)組成。

Self-Attention機制

Self-Attention機制是Transformer的核心,其主要步驟如下:

  1. 輸入向量映射:將輸入向量映射成Query(Q)、Key(K)、Value(V)三個向量。
  2. 計算注意力權重:通過點乘計算Query和Key的相似性,并通過Softmax函數歸一化。
  3. 加權求和:使用注意力權重對Value進行加權求和,得到最終的Attention輸出。

以下是Self-Attention的代碼實現:

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self, dim: int, num_heads: int = 8, keep_prob: float = 1.0, attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Transformer Encoder

為什么要使用殘差連接(Residual Connection)和歸一化層(Normalization Layer)?

在深層神經網絡中,隨著層數的增加,梯度消失和梯度爆炸的問題變得越來越嚴重。殘差連接通過在每一層加上輸入的跳躍連接,可以有效緩解這些問題,確保信息能夠順利傳遞。此外,歸一化層(如LayerNorm)可以加速模型的訓練,并提高模型的穩定性和泛化能力。這些技術的結合,使得Transformer模型能夠在更深的層次上進行有效的訓練。

Transformer Encoder由多層Self-Attention和前饋神經網絡(Feed Forward)組成,通過殘差連接和歸一化層增強模型的訓練效果和泛化能力。

class FeedForward(nn.Cell):def __init__(self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, activation: nn.Cell = nn.GELU, keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self, dim: int, num_layers: int, num_heads: int, mlp_dim: int, keep_prob: float = 1., attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim, num_heads=num_heads, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim, hidden_features=mlp_dim, activation=activation, keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])), ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):return self.layers(x)

ViT模型的輸入

ViT模型通過將輸入圖像劃分為多個patch,將每個patch轉換為一維向量,并加上類別向量和位置向量作為模型輸入。以下是Patch Embedding的代碼實現:

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 768, input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

整體構建ViT

以下代碼構建了一個完整的ViT模型:

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self, image_size: int = 224, input_channels: int = 3, patch_size: int = 16, embed_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, keep_prob: float = 1.0, attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: Optional[nn.Cell] = nn.LayerNorm, pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim, input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0), shape=(1, 1, embed_dim), dtype=ms.float32, name='cls', requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0), shape=(1, num_patches + 1, embed_dim), dtype=ms.float32, name='pos_embedding', requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim, num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob, drop_path_keep_prob=drop_path_keep_prob, activation=activation, norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

模型訓練與推理

模型訓練

模型訓練前,需要設定損失函數、優化器和回調函數。以下是訓練ViT模型的代碼:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train# 定義超參數
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()# 構建模型
network = ViT()# 加載預訓練模型參數
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)# 定義學習率
lr = nn.cosine_decay_lr(min_lr=float(0), max_lr=0.00005, total_step=epoch_size * step_size, step_per_epoch=step_size, decay_epoch=10)# 定義優化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)# 定義損失函數
class CrossEntropySmooth(LossBase):def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return lossnetwork_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)# 設置檢查點
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)# 初始化模型
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")# 訓練模型
model.train(epoch_size, dataset_train, callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)], dataset_sink_mode=False)

在這里插入圖片描述

模型驗證

模型驗證過程主要應用了ImageFolderDataset,CrossEntropySmooth和Model等接口。以下是驗證ViT模型的代碼:

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)trans_val = [transforms.Decode(),transforms.Resize(224 + 32),transforms.CenterCrop(224),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)# 構建模型
network = ViT()# 加載預訓練模型參數
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)# 定義評價指標
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(), 'Top_5_Accuracy': train.Top5CategoricalAccuracy()}if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")# 驗證模型
result = model.eval(dataset_val)
print(result)

模型推理

在進行模型推理之前,首先要定義一個對推理圖片進行數據預處理的方法。以下是推理ViT模型的代碼:

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)trans_infer = [transforms.Decode(),transforms.Resize([224, 224]),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"], num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)# 讀取推理數據
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):image = image["image"]image = ms.Tensor(image)prob = model.predict(image)label = np.argmax(prob.asnumpy(), axis=1)mapping = index2label()output = {int(label): mapping[int(label)]}print(output)show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG", result=output, out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

在這里插入圖片描述
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/40316.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/40316.shtml
英文地址,請注明出處:http://en.pswp.cn/web/40316.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

歐拉openEuler 22.03 LTS-部署k8sv1.03.1

1.設置ip # vi /etc/sysconfig/network-scripts/ifcfg-ens32 TYPEEthernet PROXY_METHODnone BROWSER_ONLYno BOOTPROTOstatic DEFROUTEyes IPV4_FAILURE_FATALno #IPV6INITyes #IPV6_AUTOCONFyes #IPV6_DEFROUTEyes #IPV6_FAILURE_FATALno #IPV6_ADDR_GEN_MODEeui64 NAMEens1…

物聯網數據解析實戰:掌握CJSON庫核心函數,精準處理JSON數據

物聯網數據解析實戰:掌握CJSON庫核心函數,精準處理JSON數據 CJSON庫是一個輕量級的JSON解析庫,專為C語言設計,適用于嵌入式系統和物聯網應用。它提供了簡單易用的API,使得開發者能夠輕松地解析和生成JSON數據。在本教…

部署Gunicorn + Flask應用到Docker

部署Gunicorn Flask應用到Docker中涉及幾個步驟,下面是一個基本的指南: 1. 創建Flask應用 首先,確保你有一個可用的Flask應用。這里有一個簡單的示例: from flask import Flask app Flask(__name__)app.route(/) def hello_w…

pandas,dataframe使用筆記

目錄 新建一個dataframe不帶列名帶列名 dataframe添加一行內容查看dataframe某列的數據類型新建dataframe時設置了列名,則數據類型為object dataframe的保存保存為csv文件保存為excel文件 dataframe屬于pandas 新建一個dataframe 不帶列名 df pd.DataFrame() 帶…

GuLi商城-商品服務-API-品牌管理-效果優化與快速顯示開關

<template><div class"mod-config"><el-form :inline"true" :model"dataForm" keyup.enter.native"getDataList()"><el-form-item><el-input v-model"dataForm.key" placeholder"參數名&qu…

華為交換機 LACP協議

華為交換機支持的LACP協議&#xff0c;即鏈路聚合控制協議&#xff0c;是一種基于IEEE 802.3ad標準的動態鏈路聚合與解聚合的協議。它允許設備根據自身配置自動形成聚合鏈路并啟動聚合鏈路收發數據。 在LACP模式下&#xff0c;鏈路聚合組能夠自動調整鏈路聚合&#xff0c;維護…

java集合(1)

目錄 一.集合概述 二. 集合體系概述 1. Collection接口 1.1 List接口 1.2 Set接口 2. Map接口 三. ArrayList 1.ArrayList常用方法 2.ArrayList遍歷 2.1 for循環 2.2 增強for循環 2.3 迭代器遍歷 一.集合概述 我們經常需要存儲一些數據類型相同的元素,之前我們學過…

Java 基礎語法

Java 是一種面向對象的編程語言&#xff0c;具有簡單、健壯、安全、跨平臺等特點。下面是Java基礎語法的詳細介紹&#xff0c;并附帶一些示例說明&#xff1a; ### 1. 變量和數據類型 Java 中的變量用于存儲數據&#xff0c;必須先聲明后使用。Java 的數據類型分為基本數據類…

C++ 仿QT信號槽二

// 實現原理 // 每個signal映射到bitset位&#xff0c;全集 // 每個slot做為signal的bitset子集 // signal全集觸發&#xff0c;標志位有效 // flip將觸發事件隊列前置 // slot檢測智能指針全集觸發的標志位&#xff0c;主動運行子集綁定的函數 // 下一幀對bitset全集進行觸發清…

【C++】 解決 C++ 語言報錯:Segmentation Fault

文章目錄 引言 段錯誤&#xff08;Segmentation Fault&#xff09;是 C 編程中常見且令人頭疼的錯誤之一。段錯誤通常發生在程序試圖訪問未被允許的內存區域時&#xff0c;導致程序崩潰。本文將深入探討段錯誤的產生原因、檢測方法及其預防和解決方案&#xff0c;幫助開發者在…

Lex Fridman Podcast with Andrej Karpathy

我不太喜歡Lex Fridman的聲音&#xff0c;總覺得那讓人昏昏欲睡&#xff0c; 但無奈他采訪的人都太大牌了&#xff0c;只能去聽。但是聽著聽著&#xff0c;就會覺得有深度的采訪這些人&#xff0c;似乎也只有他這種由研究員背景的人能干&#xff0c; 另&#xff0c;他提的問題確…

4.2 投影

一、投影和投影矩陣 我們以下面兩個問題開始&#xff0c;問題一是為了展示投影是很容易視覺化的&#xff0c;問題二是關于 “投影矩陣”&#xff08;projection matrices&#xff09;—— 對稱矩陣且 P 2 P P^2P P2P。 b \boldsymbol b b 的投影是 P b P\boldsymbol b Pb。…

android的dump_processe中anon和swap字段的含義是什么?計算進程占用內存大小是否可以用這兩個字段相加?

在Android系統中&#xff0c;dump_processes 命令或類似機制&#xff08;如通過 adb shell dumpsys&#xff09;的輸出中&#xff0c;可能會包含與進程內存使用相關的信息&#xff0c;但通常不直接以 anon 和 swap 作為字段名。不過&#xff0c;基于您的提問&#xff0c;我可以…

嵌入式學習——硬件(Linux內核驅動編程LED、蜂鳴器、按鍵)——day59

1. 編寫LED驅動&#xff08;初始化所有子設備號&#xff09; #include <linux/init.h> #include <linux/module.h> #include <linux/kernel.h> #include <linux/fs.h> #include <asm/uaccess.h> #include <asm/io.h>#define GPBCON (0x5…

2024年7月5日 (周五) 葉子游戲新聞

老板鍵工具來喚去: 它可以為常用程序自定義快捷鍵&#xff0c;實現一鍵喚起、一鍵隱藏的 Windows 工具&#xff0c;并且支持窗口動態綁定快捷鍵&#xff08;無需設置自動實現&#xff09;。 卸載工具 HiBitUninstaller: Windows上的軟件卸載工具 《樂高地平線大冒險》為何不登陸…

江漢大學劉春萌同學整理的wifi模塊 上傳mqtt實驗步驟

一.固件燒錄 1.打開安信可官網 2.點擊wifi模組系列的ESP8266 3.點擊各類固件后選擇固件號1471下載 4.打開燒錄工具將下載的二進制文件導入并將后面的起始地址寫為0x00000,下面勾選40mhz QIO 8Mbit點擊start下載即可 二.本地部署mqtt服務器(windows) 1.下載mosquitto后有一個m…

Java并發編程知識整理筆記

目錄 ?1. 什么是線程和進程&#xff1f; 線程與進程有什么區別&#xff1f; 那什么是上下文切換&#xff1f; 進程間怎么通信&#xff1f; 什么是用戶線程和守護線程&#xff1f; 2. 并行和并發的區別&#xff1f; 3. 創建線程的幾種方式&#xff1f; Runnable接口和C…

微博視頻下載

video_urls 獲取xpath://video/src|//video/autoplay # !/usr/bin/python3 # -*- coding:utf-8 -*- """ author: JHC000abcgmail.com file: demo1.py time: 2024/6/3 18:00 desc:""" import os import re import requests from urllib.parse im…

Qt實現流動的管道效果代碼示例

在現代圖形用戶界面&#xff08;GUI&#xff09;應用程序中&#xff0c;動態效果可以顯著增強用戶體驗。本文將介紹如何使用Qt框架實現一個流動的管道效果。我們將通過自定義QWidget來繪制管道&#xff0c;并使用定時器來實現流動效果。 1. 準備工作 首先&#xff0c;確保你已…

LeetCode.68文本左右對齊

問題描述 給定一個單詞數組 words 和一個長度 maxWidth &#xff0c;重新排版單詞&#xff0c;使其成為每行恰好有 maxWidth 個字符&#xff0c;且左右兩端對齊的文本。 你應該使用 “貪心算法” 來放置給定的單詞&#xff1b;也就是說&#xff0c;盡可能多地往每行中放置單詞…