# 參考:
# MOE原理解釋及從零實現一個MOE(專家混合模型)_moe代碼-CSDN博客
# MoE環游記:1、從幾何意義出發 - 科學空間|Scientific Spaces
# 深度學習之圖像分類(二十八)-- Sparse-MLP(MoE)網絡詳解_sparse moe-CSDN博客
# 深度學習之圖像分類(二十九)-- Sparse-MLP網絡詳解_sparse mlp-CSDN博客
#
# 代碼如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 超參數設置
num_experts = 4  # number of experts
top_k = 2  # number of experts activated per sample
# Flattened size of the CNN features fed to the MoE layer: 64 channels on an
# 8x8 map (a 32x32 CIFAR-10 image after two 2x2 max-pools).
input_dim = 64 * 8 * 8
hidden_dim = 512  # hidden (and output) width of each expert network
num_classes = 10  # number of target classes

# MoE layer implementation (refs [5][7])
class SparseMoE(nn.Module):
    """Mixture-of-experts layer with top-k gating and a load-balancing penalty.

    Every expert is an MLP ``in -> hidden -> ReLU -> hidden``. A softmax gate
    scores all experts for each sample. Only the ``k`` highest-scoring experts
    are combined, each weighted by its (not renormalised) gate probability.

    Args:
        in_features: Input feature size. Defaults to the module-level
            ``input_dim``.
        hidden_features: Hidden/output size of each expert. Defaults to the
            module-level ``hidden_dim``.
        n_experts: Number of experts. Defaults to the module-level
            ``num_experts``.
        k: Number of experts activated per sample. Defaults to the
            module-level ``top_k``.
        balance_loss_weight: Scale applied to the penalty returned by
            :meth:`balance_loss`.

    Raises:
        ValueError: If ``k`` is not in ``[1, n_experts]``.
    """

    def __init__(self, in_features=None, hidden_features=None, n_experts=None,
                 k=None, balance_loss_weight=0.01):
        super().__init__()
        # Fall back to the module-level hyperparameters so that a bare
        # ``SparseMoE()`` behaves exactly as before.
        self.in_features = input_dim if in_features is None else in_features
        self.hidden_features = hidden_dim if hidden_features is None else hidden_features
        self.n_experts = num_experts if n_experts is None else n_experts
        self.k = top_k if k is None else k
        if not 1 <= self.k <= self.n_experts:
            raise ValueError(f"k must be in [1, {self.n_experts}], got {self.k}")

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.in_features, self.hidden_features),
                nn.ReLU(),
                nn.Linear(self.hidden_features, self.hidden_features),
            )
            for _ in range(self.n_experts)
        ])
        self.gate = nn.Sequential(
            nn.Linear(self.in_features, self.n_experts),
            nn.Softmax(dim=1),
        )
        # Load-balancing parameters (refs [4][7]).
        self.balance_loss_weight = balance_loss_weight
        # How many samples selected each expert since the last balance_loss()
        # call. These are usage statistics only and carry no gradient.
        self.register_buffer('expert_counts', torch.zeros(self.n_experts))
        # Differentiable per-expert sum of sparse gate weights since the last
        # balance_loss() call. It is accumulated only in training mode, so no
        # autograd graph is kept alive during evaluation.
        self._importance = None

    def forward(self, x):
        """Route each sample to its top-k experts and mix their outputs.

        Args:
            x: ``[B, in_features]`` input features.

        Returns:
            ``[B, hidden_features]``: the gate-weighted sum of the selected
            experts' outputs.
        """
        gate_scores = self.gate(x)  # [B, E], each row sums to 1
        # Top-k selection (ref [5]).
        topk_scores, topk_indices = torch.topk(gate_scores, self.k, dim=1)  # [B, k]

        # NOTE: every expert runs on every sample (dense compute); only the
        # combination below is sparse.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # [B, E, H]
        index = topk_indices.unsqueeze(-1).expand(-1, -1, self.hidden_features)
        selected_experts = expert_outputs.gather(1, index)  # [B, k, H]
        weighted_outputs = (selected_experts * topk_scores.unsqueeze(-1)).sum(dim=1)  # [B, H]

        # Usage statistics: number of samples that picked each expert.
        mask = F.one_hot(topk_indices, self.n_experts).float().sum(dim=1)  # [B, E]
        self.expert_counts += mask.sum(dim=0)

        if self.training:
            # Sparse gate: the softmax probability for each selected expert,
            # 0 for experts that were not selected. This keeps the gate's
            # autograd history, so balance_loss() can push the router.
            sparse_gates = torch.zeros_like(gate_scores).scatter(1, topk_indices, topk_scores)
            batch_importance = sparse_gates.sum(dim=0)  # [E]
            if self._importance is None:
                self._importance = batch_importance
            else:
                self._importance = self._importance + batch_importance
        return weighted_outputs

    def balance_loss(self):
        """Return the load-balancing penalty and reset the accumulated statistics.

        The penalty is ``balance_loss_weight * std(p)``. Here ``p`` is each
        expert's share of the sparse gate weight accumulated by training-mode
        ``forward`` calls since the previous call (refs [4][7]).

        A penalty computed from ``expert_counts`` alone has no gradient,
        because those counts come from ``topk`` indices. ``p`` depends on the
        gate output, so minimising this penalty actually moves the router
        towards uniform expert usage.

        Returns:
            A scalar tensor. It is zero when nothing has been accumulated
            (e.g. only eval-mode forwards) or when there is a single expert.
        """
        importance = self._importance
        # Reset both accumulators for the next window.
        self._importance = None
        self.expert_counts.zero_()
        if importance is None or importance.numel() < 2:
            return self.expert_counts.new_zeros(())
        # Clamp guards against 0/0 for an empty batch.
        total = importance.sum().clamp_min(torch.finfo(importance.dtype).tiny)
        expert_share = importance / total
        return torch.std(expert_share) * self.balance_loss_weight


# Full model architecture (refs [2][6])
class MoEImageClassifier(nn.Module):
    """CNN feature extractor followed by a sparse MoE layer and a linear head.

    The two conv/pool stages output 64 channels and halve the spatial size
    twice. For a 3x32x32 image this yields the ``64 * 8 * 8`` features that
    the MoE layer expects as ``input_dim``.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.moe_layer = SparseMoE()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for an image batch ``x``."""
        x = self.feature_extractor(x)
        # Flatten every non-batch dimension. Unlike view(x.size(0), -1), this
        # also works for an empty batch.
        x = torch.flatten(x, 1)
        x = self.moe_layer(x)
        return self.classifier(x)


# Data preprocessing (ref [2])
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per RGB channel: maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# Training loop
num_epochs = 10
# Use the GPU when available; registered buffers (e.g. the MoE expert
# counters) move together with the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MoEImageClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        main_loss = criterion(outputs, labels)
        # Auxiliary load-balancing term from the MoE layer. The call also
        # resets the layer's per-expert usage counters.
        balance_loss = model.moe_layer.balance_loss()
        total_loss = main_loss + balance_loss
        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
        num_batches += 1
    # Report the mean loss over the epoch, not just the last minibatch's loss.
    mean_loss = running_loss / max(num_batches, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')