昇思25天學習打卡營第16天 | Vision Transformer圖像分類

昇思25天學習打卡營第16天 | Vision Transformer圖像分類

文章目錄

  • 昇思25天學習打卡營第16天 | Vision Transformer圖像分類
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模塊
        • Encoder模塊
      • ViT模型輸入
    • 模型構建
      • Multi-Head Attention模塊
      • Encoder模塊
      • Patch Embedding模塊
      • ViT網絡
    • 總結
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV領域的融合,可以在不依賴于卷積操作的情況下在圖像分類任務上達到很好的效果。

ViT模型的主體結構是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模塊構成,包括多頭注意力(Multi-Head Attention)層,Feed Forward層,Normalization層和殘差連接(Residual Connection)。
encoder-decoder
多頭注意力結構基于自注意力機制(Self-Attention),是多個Self-Attention的并行組成。

Attention模塊

Attention的核心在于為輸入向量的每個單詞學習一個權重。

  1. 最初的輸入向量首先經過Embedding層映射為Q(Query),K(Key),V(Value)三個向量。
  2. 通過將Q和所有K進行點乘初一維度平方根,得到向量間的相似度,通過softmax獲取每詞向量之間的關系權重。
  3. 利用關系權重對詞向量的V加權求和,得到自注意力值。
    self-attention
    多頭注意力機制只是對self-attention的并行化:
    multi-head-attention
Encoder模塊

ViT中的Encoder相對于標準Transformer,主要在于將Normolization放在self-attention和Feed Forward之前,其他結構與標準Transformer相同。
vit-encoder

ViT模型輸入

傳統Transformer主要應用于自然語言處理的一維詞向量,而圖像時二維矩陣的堆疊。
在ViT中:

  1. 通過卷積將輸入圖像在每個channel上劃分為 16 × 16 16\times 16 16×16個patch。如果輸入 224 × 224 224\times224 224×224的圖像,則每一個patch的大小為 14 × 14 14\times 14 14×14
  2. 將每一個patch拉伸為一個一維向量,得到近似詞向量堆疊的效果。如將 14 × 14 14\times14 14×14展開為 196 196 196的向量。
    這一部分Patch Embedding用來替換Transformer中Word Embedding,用作網絡中的圖像輸入。

模型構建

Multi-Head Attention模塊

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Encoder模塊

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

Patch Embedding模塊

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

ViT網絡

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

總結

這一節對Transformer進行介紹,包括Attention機制、并行化的Attention以及Encoder模塊。由于傳統Transformer主要作用于一維的詞向量,因此二維圖像需要被轉換為類似的一維詞向量堆疊,在ViT中通過將Patch Embedding解決這一問題,并用來代替傳統Transformer中的Word Embedding作為網絡的輸入。

打卡

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/46145.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/46145.shtml
英文地址,請注明出處:http://en.pswp.cn/web/46145.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

工業三防平板助力工廠生產數據實時管理

在當今高度數字化和智能化的工業生產環境中,工業三防平板正逐漸成為工廠實現生產數據實時管理的得力助手。這種創新的技術設備不僅能夠在惡劣的工業環境中穩定運行,還為工廠的生產流程優化、效率提升和質量控制帶來了前所未有的機遇。 工業生產場景通常充…

機器學習——數據預處理和特征工程(sklearn)

目錄 一、數據挖掘流程 1. 獲取數據 2. 數據預處理 3. 特征工程 4. 建模,測試模型并預測出結果 5. 驗證模型效果 二、sklearn中的相關包 1.sklearn.preprocessing 2.sklearn.Impute 3.sklearn.feature_selection 4.sklearn.decomposition 三、數據預處理…

【網絡安全】PostMessage:分析JS實現XSS

未經許可,不得轉載。 文章目錄 前言示例正文 前言 PostMessage是一個用于在網頁間安全地發送消息的瀏覽器 API。它允許不同的窗口(例如,來自同一域名下的不同頁面或者不同域名下的跨域頁面)進行通信,而無需通過服務器…

【Arduino IDE】安裝及開發環境、ESP32庫

一、Arduino IDE下載 二、Arduino IDE安裝 三、ESP32庫 四、Arduino-ESP32庫配置 五、新建ESP32-S3N15R8工程文件 樂鑫官網 Arduino官方下載地址 Arduino官方社區 Arduino中文社區 一、Arduino IDE下載 ESP-IDF、MicroPython和Arduino是三種不同的開發框架,各自適…

定制開發AI智能名片商城微信小程序在私域流量池構建中的應用與策略

摘要 在數字經濟蓬勃發展的今天,私域流量已成為企業競爭的新戰場。定制開發AI智能名片商城微信小程序,作為私域流量池構建的創新工具,正以其獨特的優勢助力企業實現用戶資源的深度挖掘與高效轉化。本文深入探討了定制開發AI智能名片商城微信…

.NET Framework、.NET Core 、 .NET 5、.NET 6和.NET 7 和.NET8 簡介及區別

簡述 在軟件開發的宇宙中,.NET是一個不斷擴展的星系,每個版本都像是一顆獨特的星球,擁有自己的特性和環境。作為技術經理,站在選擇的十字路口,您需要一張詳盡的星圖來導航。本文將作為您的向導,帶您穿越從.…

AIoTedge智能物聯網邊緣計算平臺:引領未來智能邊緣技術

引言 隨著物聯網技術的飛速發展,我們正步入一個萬物互聯的時代。AIoTedge智能物聯網邊緣計算平臺,以其創新的邊云協同架構,為智能設備和系統提供了強大的數據處理和智能決策能力,開啟了智能物聯網的新篇章。 智能邊緣計算平臺的核…

LLaMA-Factory

文章目錄 一、關于 LLaMA-Factory項目特色性能指標 二、如何使用1、安裝 LLaMA Factory2、數據準備3、快速開始4、LLaMA Board 可視化微調5、構建 DockerCUDA 用戶:昇騰 NPU 用戶:不使用 Docker Compose 構建CUDA 用戶:昇騰 NPU 用戶&#xf…

【Java項目筆記】01項目介紹

一、技術框架 1.后端服務 Spring Boot為主體框架 Spring MVC為Web框架 MyBatis、MyBatis Plus為持久層框架,負責數據庫的讀寫 阿里云短信服務 2.存儲服務 MySql redis緩存數據 MinIO為對象存儲,存儲非結構化數據(圖片、視頻、音頻&a…

推薦一款處理TCP數據的架構--EasyTcp4Net

EasyTcp4Net是一個基于c# Pipe,ReadonlySequence的高性能Tcp通信庫,旨在提供穩定,高效,可靠的tcp通訊服務。 基礎的消息通訊 重試機制 超時機制 SSL加密通信支持 KeepAlive 流量背壓控制 粘包和斷包處理 (支持固定頭處理,固定長度處理,固定字符處理) 日志支持Pipe &…

Spring MVC 的常用注解

RequestMapping 和 RestController注解 上面兩個注解,是Spring MCV最常用的注解。 RequestMapping , 他是用來注冊接口的路由映射。 路由映射:當一個用戶訪問url時,將用戶的請求對應到某個方法或類的過程叫做路由映射。 Reques…

定制QCustomPlot 帶有ListView的QCustomPlot 全網唯一份

定制QCustomPlot 帶有ListView的QCustomPlot 文章目錄 定制QCustomPlot 帶有ListView的QCustomPlot摘要需求描述實現關鍵字: Qt、 QCustomPlot、 魔改、 定制、 控件 摘要 先上效果,是你想要的,再看下面的分解,順便點贊搜藏一下;不是直接右上角。 QCustomPlot是一款…

基于springboot+vue+uniapp的駕校預約平臺小程序

開發語言:Java框架:springbootuniappJDK版本:JDK1.8服務器:tomcat7數據庫:mysql 5.7(一定要5.7版本)數據庫工具:Navicat11開發軟件:eclipse/myeclipse/ideaMaven包&#…

認識AOP--小白可看

AOP(Aspect-Oriented Programming,面向切面編程)是一種軟件開發范式,旨在通過橫切關注點(cross-cutting concerns)的方式來解耦系統中的各個模塊。橫切關注點指的是那些不屬于業務邏輯本身,但是…

Apache Sqoop

Apache Sqoop是一個開源工具,用于在Apache Hadoop和關系型數據庫(如MySQL、Oracle、PostgreSQL等)之間進行數據的批量傳輸。其主要功能包括: 1. 數據導入:從關系型數據庫(如MySQL、Oracle等)中將…

WPF設置歡迎屏幕,程序啟動過度動畫

當主窗體加載時間過長,這時候基本都會想添加一個等待操作來響應用戶點擊,提高用戶體驗。下面我記錄兩個方法,一點拙見,僅供參考。 方法1:在App類中使用SplashScreen類。 protected override void OnStartup(StartupEventArgs e)…

35.UART(通用異步收發傳輸器)-RS232(2)

(1)RS232接收模塊visio框圖: (2)接收模塊Verilog代碼編寫: /* 常見波特率: 4800、9600、14400、115200 在系統時鐘為50MHz時,對應計數為: (1/4800) * 10^9 /20 -1 10416 …

【作業】 貪心算法1

Tips:三題尚未完成。 #include <iostream> #include <algorithm> using namespace std; int a[110]; int main(){int n,r,sum0;cin>>n>>r;for(int i0;i<n;i){cin>>a[i];}sort(a0,an);for(int i0;i<n;i){if(i>r){a[i]a[i-r]a[i];}suma[…

[USACO18JAN] Cow at Large P

題解都說了&#xff0c;當統計 u u u為根節點的時候&#xff0c;答案就是滿足以下條件的 i i i的數量&#xff1a; d i ≥ g i d_i≥g_i di?≥gi?且 d f a i < g f a i d_{fa_i}<g_{fa_i} dfai??<gfai??&#xff0c;設這個數量為 a n s ans ans。以下嚴格證明 …

Solana開發資源都有哪些

Solana是一個高性能的區塊鏈平臺&#xff0c;吸引了大量開發者構建去中心化應用&#xff08;dApps&#xff09;。以下是一些有用的Solana開發教程和資源&#xff1a; 官方資源 Solana 官方文檔&#xff1a; Solana Documentation: 這是最全面的資源&#xff0c;包括快速入門、…