Google的MLP-MIXer的復現(pytorch實現)
該模型原論文實現用的jax框架實現,先貼出原論文的代碼實現:
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.from typing import Any, Optionalimport einops
import flax.linen as nn
import jax
import jax.numpy as jnpclass MlpBlock(nn.Module):mlp_dim: int@nn.compactdef __call__(self, x):y = nn.Dense(self.mlp_dim)(x)y = nn.gelu(y)return nn.Dense(x.shape[-1])(y)class MixerBlock(nn.Module):"""Mixer block layer."""tokens_mlp_dim: intchannels_mlp_dim: int@nn.compactdef __call__(self, x):y = nn.LayerNorm()(x)y = jnp.swapaxes(y, 1, 2)y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) # (32, 512, 196)y = jnp.swapaxes(y, 1, 2)x = x + yy = nn.LayerNorm()(x)return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)class MlpMixer(nn.Module):"""Mixer architecture."""patches: Anynum_classes: intnum_blocks: inthidden_dim: inttokens_mlp_dim: intchannels_mlp_dim: intmodel_name: Optional[str] = None@nn.compactdef __call__(self, inputs, *, train):del trainx = nn.Conv(self.hidden_dim, self.patches.size,strides=self.patches.size, name='stem')(inputs)x = einops.rearrange(x, 'n h w c -> n (h w) c') # 從(32,512,14,14)變成了(32,196,512)for _ in range(self.num_blocks):x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)x = nn.LayerNorm(name='pre_head_layer_norm')(x)x = jnp.mean(x, axis=1)if self.num_classes:x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,name='head')(x)return xmodel_params = {'patches': {'size': (16, 16), 'stride': (16, 16)}, # 這里需要一個描述patch大小和步長的對象,例如Flax的stem模塊初始化參數'num_classes': 10, # 分類任務的類別數'num_blocks': 8, # Mixer Block的重復次數'hidden_dim': 512, # 隱藏層維度'tokens_mlp_dim': 256, # token mixing的MLP維度'channels_mlp_dim': 2048, # channel mixing的MLP維度
}# 準備輸入數據,例如一批32張圖片,每張圖片尺寸為512x14x14(假設已經按要求預處理)# 初始化模型
seed=0
key = jax.random.PRNGKey(seed)
model = MlpMixer.apply(key, **model_params)input_data = jnp.ones((4096, 224, 224, 3)) # 示例輸入數據
# 調用模型進行前向傳播
output = model(input_data)print("Output shape:", output) # 打印輸出形狀,預期是(32, 10)如果num_classes=10
該模型的總體框架圖如下所示:
對該框架的講解,網上已經很多了,就不在此贅述。
實現的pytorch代碼如下所示:
class MlpBlock(nn.Module):def __init__(self, in_mlp_dim=196, out_mlp_dim=256):super(MlpBlock, self).__init__()self.mlp_dim = out_mlp_dimself.dense1 = nn.Linear(in_mlp_dim, out_mlp_dim) # 若輸入的向量為[32,196, 512]則輸入的也應該是512,輸出可以自己定self.gelu = nn.GELU()self.dense2 = nn.Linear(out_mlp_dim, in_mlp_dim)def forward(self, x):y = self.dense1(x)y = self.gelu(y)y = self.dense2(y)return yclass MixerBlock(nn.Module):def __init__(self, tokens_mlp_dim=256, channels_mlp_dim=2048, batch_size=32):super(MixerBlock, self).__init__()self.batch_size = batch_sizeself.norm1 = nn.LayerNorm(512) # 對512維的做歸一化,默認給最后一個維度做歸一化self.token_Mixing = MlpBlock(out_mlp_dim=tokens_mlp_dim)self.norm2 = nn.LayerNorm(512) # 對512維的做歸一化self.channel_mixing = MlpBlock(in_mlp_dim=512, out_mlp_dim=channels_mlp_dim)def forward(self, x):y = self.norm1(x)y = y.permute(0, 2, 1)y = self.token_Mixing(y)y = y.permute(0, 2, 1)x = x + yy = self.norm2(x)return x + self.channel_mixing(y)class MlpMixer(nn.Module):def __init__(self, patches, num_classes, num_blocks, hidden_dim, tokens_mlp_dim, channels_mlp_dim):super(MlpMixer, self).__init__()self.stem = nn.Conv2d(3, hidden_dim, kernel_size=patches, stride=patches)self.mixer_block_1 = MixerBlock()self.mixer_blocks = nn.ModuleList([MixerBlock(tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])self.pre_head_norm = nn.LayerNorm(hidden_dim)self.head = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()def forward(self, x):x = self.stem(x)b, c, h, w = x.shapex = x.view(b, c, -1).permute(0, 2, 1)for mixer_block in self.mixer_blocks:x = mixer_block(x)x = self.pre_head_norm(x)x = x.mean(dim=1)x = self.head(x)return x# model = MlpMixer(16, 10, 6, 512, 256, 2048)
# input_tensor = torch.randn(32, 3, 224, 224) # (batch size, num_patches, input_dim)
# output = model(input_tensor)
# print(output)
在將flax框架的代碼改為pytorch實現的時候,還是踩了不少的坑,在此講一下,希望后面做的人,可以避免。
1.在flax框架的nn.linear層中沒有輸入維度,只有一個輸出維度。
2.在處理兩個差異的時候,如輸入維度[32,196,512],其中代表的意思分別為batch_size為32,196為圖片在經過patch之后的224*224輸入之后經過patch=16,變為14 * 14即196,512會在二維卷積處理之后輸出的channel類似。
1.在flax框架的nn.linear層中沒有輸入維度,只有一個輸出維度。
2.在處理兩個差異的時候,如輸入維度[32,196,512],其中代表的意思分別為batch_size為32,196為圖片在經過patch之后的224*224輸入之后經過patch=16,變為14 * 14即196,512會在二維卷積處理之后輸出的channel類似。
在nn.linear那兒的in_channel與第三個維度保持一致,就可以不必將其三維的轉換為二維的。同時在對layernorm那兒轉換的時候,默認也是對最后一個維度進行正則化。