Google的MLP-MIXer的復現(pytorch實現)

Google的MLP-MIXer的復現(pytorch實現)

該模型原論文實現用的jax框架實現,先貼出原論文的代碼實現:

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.from typing import Any, Optionalimport einops
import flax.linen as nn
import jax
import jax.numpy as jnpclass MlpBlock(nn.Module):mlp_dim: int@nn.compactdef __call__(self, x):y = nn.Dense(self.mlp_dim)(x)y = nn.gelu(y)return nn.Dense(x.shape[-1])(y)class MixerBlock(nn.Module):"""Mixer block layer."""tokens_mlp_dim: intchannels_mlp_dim: int@nn.compactdef __call__(self, x):y = nn.LayerNorm()(x)y = jnp.swapaxes(y, 1, 2)y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) #  (32, 512, 196)y = jnp.swapaxes(y, 1, 2)x = x + yy = nn.LayerNorm()(x)return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)class MlpMixer(nn.Module):"""Mixer architecture."""patches: Anynum_classes: intnum_blocks: inthidden_dim: inttokens_mlp_dim: intchannels_mlp_dim: intmodel_name: Optional[str] = None@nn.compactdef __call__(self, inputs, *, train):del trainx = nn.Conv(self.hidden_dim, self.patches.size,strides=self.patches.size, name='stem')(inputs)x = einops.rearrange(x, 'n h w c -> n (h w) c')  # 從(32,512,14,14)變成了(32,196,512)for _ in range(self.num_blocks):x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)x = nn.LayerNorm(name='pre_head_layer_norm')(x)x = jnp.mean(x, axis=1)if self.num_classes:x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,name='head')(x)return xmodel_params = {'patches': {'size': (16, 16), 'stride': (16, 16)}, # 這里需要一個描述patch大小和步長的對象,例如Flax的stem模塊初始化參數'num_classes': 10,  # 分類任務的類別數'num_blocks': 8,  # Mixer Block的重復次數'hidden_dim': 512,  # 隱藏層維度'tokens_mlp_dim': 256,  # token mixing的MLP維度'channels_mlp_dim': 2048,  # channel mixing的MLP維度
}# 準備輸入數據,例如一批32張圖片,每張圖片尺寸為512x14x14(假設已經按要求預處理)# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)input_data = jnp.ones((4096, 224, 224, 3))  # 示例輸入數據
# 調用模型進行前向傳播
output = model(input_data)print("Output shape:", output)  # 打印輸出形狀,預期是(32, 10)如果num_classes=10

該模型的總體框架圖如下所示:

在這里插入圖片描述

對該框架的講解,網上已經很多了,就不在此贅述。

實現的pytorch代碼如下所示:

class MlpBlock(nn.Module):def __init__(self, in_mlp_dim=196, out_mlp_dim=256):super(MlpBlock, self).__init__()self.mlp_dim = out_mlp_dimself.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim)  # 若輸入的向量為[32,196, 512]則輸入的也應該是512,輸出可以自己定self.gelu = nn.GELU()self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)def forward(self, x):y = self.dense1(x)y = self.gelu(y)y = self.dense2(y)return yclass MixerBlock(nn.Module):def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):super(MixerBlock, self).__init__()self.batch_size = batch_sizeself.norm1 = nn.LayerNorm(512)  # 對512維的做歸一化,默認給最后一個維度做歸一化self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)self.norm2 = nn.LayerNorm(512)      # 對512維的做歸一化self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)def forward(self, x):y = self.norm1(x)y = y.permute(0, 2, 1)y = self.token_Mixing(y)y = y.permute(0, 2, 1)x = x + yy = self.norm2(x)return x + self.channel_mixing(y)class MlpMixer(nn.Module):def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):super(MlpMixer, self).__init__()self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)self.mixer_block_1 = MixerBlock()self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])self.pre_head_norm = nn.LayerNorm(hidden_dim)self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()def forward(self, x):x = self.stem(x)b, c, h, w = x.shapex = x.view(b, c, -1).permute(0, 2, 1)for mixer_block in self.mixer_blocks:x = mixer_block(x)x = self.pre_head_norm(x)x = x.mean(dim=1)x = self.head(x)return x# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224)  # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)

在將flax框架的代碼改為pytorch實現的時候,還是踩了不少的坑,在此講一下,希望后面做的人,可以避免。

1.在flax框架的nn.linear層中沒有輸入維度,只有一個輸出維度。

2.在處理兩個差異的時候,如輸入維度[32,196,512],其中代表的意思分別為batch_size為32,196為圖片在經過patch之后的224*224輸入之后經過patch=16,變為14 * 14即196,512會在二維卷積處理之后輸出的channel類似。

1.在flax框架的nn.linear層中沒有輸入維度,只有一個輸出維度。

2.在處理兩個差異的時候,如輸入維度[32,196,512],其中代表的意思分別為batch_size為32,196為圖片在經過patch之后的224*224輸入之后經過patch=16,變為14 * 14即196,512會在二維卷積處理之后輸出的channel類似。

在nn.linear那兒的in_channel與第三個維度保持一致,就可以不必將其三維的轉換為二維的。同時在對layernorm那兒轉換的時候,默認也是對最后一個維度進行正則化。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/16474.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/16474.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/16474.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

GEC210編譯環境搭建

一、下載編譯工具鏈 下載:點擊跳轉 二、解壓到 /usr/local/arm 目錄 sudo mv gec210.zip /usr/local/arm cd /usr/local/arm sudo unzip gec210.zip 三、添加到環境變量 PATH/usr/local/arm/arm-cortex_a8-linux-gnueabi-4.7.3/bin:$PATH 四、測試驗證 在終端…

python數據分析-基于數據挖掘對APP評分的預測

前言 當我們談論關于APP用戶分析與電子商務之間的聯系時,機器學習在這兩個領域的應用變得至關重要。App用戶分析和電子商務之間存在著密切的關聯,因為用戶行為和偏好的深入理解對于提高用戶體驗、增加銷售以及優化產品功能至關重要。故本文基于K-近鄰模…

OFDM 802.11a的FPGA實現(二十)使用AXI-Stream FIFO進行跨時鐘(含代碼)

目錄 1.前言 2.AXI-Stream FIFO時序 3.AXI-Stream FIFO配置信息 4.時鐘控制模塊MMCM 5.ModelSim仿真 6.總結 1.前言 至此,通過前面的文章講解,對于OFDM 802.11a的發射基帶的一個完整的PPDU幀的所有處理已經全部完成,其結構如下圖所示&…

opencv-C++ VS2019配置安裝

最新opencv-c安裝及配置教程(VS2019 C & opencv4.4.0)_c opencv配置-CSDN博客

夜雨觸花感懷

夜雨觸花感懷 雨落有軌跡,業成無坦途。 ?雞毛飛虛空,尋德問心路。 ?恰如求耕耘,大話量寸土。 ?好吃品五味,難得評真俗。

CAN總線簡介

1. CAN總線概述 1.1 CAN定義與歷史背景 CAN,全稱為Controller Area Network,是一種基于消息廣播的串行通信協議。它最初由德國Bosch公司在1983年為汽車行業開發,目的是實現汽車內部電子控制單元(ECUs)之間的可靠通信。…

用Vuex存儲可配置下載的ip地址(用XML進行ajax請求配置文件)

1.在public文件夾下創建一個名為Configuration的文件在創建一個Configuration.txt里面就放IP地址(這里的名字可以隨便命名一定性的被人解讀文件含義) 例如: http://172.171.208.1:80032.在store文件夾中創建一個名為 ajaxModule.js 的 Vuex …

2. CSS選擇器與偽類

2.1 基本選擇器回顧 在開始介紹CSS3選擇器之前&#xff0c;我們先回顧一下CSS的基本選擇器。這些選擇器是所有CSS開發的基礎。 2.1.1 元素選擇器 元素選擇器用于選中指定類型的HTML元素。 /* 選中所有的<p>元素 */ p {color: blue; }2.1.2 類選擇器 類選擇器用于選中…

03自動輔助導航駕駛NOP其實就是NOA

蔚來NOP是什么意思&#xff1f;蔚來NOP是啥 蔚來NOP的意思就是NavigateonPilot智能輔助導航駕駛&#xff0c;也就是大家俗稱的高階輔助駕駛&#xff0c;在車主設定好導航路線&#xff0c;并且符合開啟NOP條件的前提下&#xff0c;蔚來NOP可以代替駕駛員完成從A點到B點的智能輔助…

深入理解數倉開發(二)數據技術篇之數據同步

1、數據同步 數據同步我們之前在數倉當中使用了多種工具&#xff0c;比如使用 Flume 將日志文件從服務器采集到 Kafka&#xff0c;再通過 Flume 將 Kafka 中的數據采集到 HDFS。使用 MaxWell 實時監聽 MySQL 的 binlog 日志&#xff0c;并將采集到的變更日志&#xff08;json 格…

【二叉樹】:LeetCode:100.相同的數(分治)

&#x1f381;個人主頁&#xff1a;我們的五年 &#x1f50d;系列專欄&#xff1a;初階初階結構刷題 &#x1f389;歡迎大家點贊&#x1f44d;評論&#x1f4dd;收藏?文章 1.問題描述&#xff1a; 2.問題分析&#xff1a; 二叉樹是區分結構的&#xff0c;即左右子樹是不一…

[JDK工具-6] jmap java內存映射工具

文章目錄 1. 介紹2. 主要選項3. 生成java堆轉儲快照 jmap -dump4. 顯示堆詳細信息 jmap -heap pid5. 顯示堆中對象統計信息 jmap -histo pid jmap(Memory Map for Java) 1. 介紹 位置&#xff1a;jdk\bin 作用&#xff1a; jdk安裝后會自帶一些小工具&#xff0c;jmap命令(Mem…

PySide6升級導致的Fatal Python error: could not initialize part 2問題及其解決方法

問題出現 把PySide6從6.6.1升級到6.7.1&#xff0c;結果運行程序的時候就報如下錯誤&#xff1a; Traceback (most recent call last): File "signature_bootstrap.py", line 77, in bootstrap File "signature_bootstrap.py", line 93, in find_inc…

Kafka SASL_SSL集群認證

背景 公司需要對kafka環境進行安全驗證,目前考慮到的方案有Kerberos和SSL和SASL_SSL,最終考慮到安全和功能的豐富度,我們最終選擇了SASL_SSL方案。處于知識積累的角度,記錄一下kafka SASL_SSL安裝部署的步驟。 機器規劃 目前測試環境公搭建了三臺kafka主機服務,現在將詳…

H3CNE-7-TCP和UDP協議

TCP和UDP協議 TCP&#xff1a;可靠傳輸&#xff0c;面向連接 -------- 速度慢&#xff0c;準確性高 UDP&#xff1a;不可靠傳輸&#xff0c;非面向連接 -------- 速度快&#xff0c;但準確性差 面向連接&#xff1a;如果某應用層協議的四層使用TCP端口&#xff0c;那么正式的…

智能家居完結 -- 整體設計

系統框圖 前情提要: 智能家居1 -- 實現語音模塊-CSDN博客 智能家居2 -- 實現網絡控制模塊-CSDN博客 智能家居3 - 實現煙霧報警模塊-CSDN博客 智能家居4 -- 添加接收消息的初步處理-CSDN博客 智能家居5 - 實現處理線程-CSDN博客 智能家居6 -- 配置 ini文件優化設備添加-CS…

【MySQL】聊聊count的相關操作

在平時的操作中&#xff0c;經常使用count進行操作&#xff0c;計算統計的數據。那么具體的原理是如何的&#xff1f;為什么有時候執行count很慢。 count的實現方式 select count(*) from student;對于MyISAM引擎來說&#xff0c;會把一個表的總行數存儲在磁盤上&#xff0c;…

Linux下Vision Mamba環境配置+多CUDA版本切換

上篇文章大致講了下Vision Mamba的相關知識&#xff0c;網上關于Vision Mamba的配置博客太多&#xff0c;筆者主要用來整合下。 筆者在Win10和Linux下分別嘗試配置相關環境。 Win10下配置 失敗 \textcolor{red}{失敗} 失敗&#xff0c;最后出現的問題如下&#xff1a; https://…

基于物聯網架構的電子小票服務系統

1.電子小票物聯網架構 采用感知層、網絡層和應用層的3層物聯網體系架構模型&#xff0c;電子小票物聯網的架構見圖1。 圖1 電子小票物聯網架構 感知層的小票智能硬件能夠取代傳統的小票打印機&#xff0c;在不改變商家原有收銀系統的前提下&#xff0c;采集收音機待打印的購物…