手寫系列——MoE網絡

參考:

MOE原理解釋及從零實現一個MOE(專家混合模型)_moe代碼-CSDN博客

MoE環游記:1、從幾何意義出發 - 科學空間|Scientific Spaces?

深度學習之圖像分類(二十八)-- Sparse-MLP(MoE)網絡詳解_sparse moe-CSDN博客

深度學習之圖像分類(二十九)-- Sparse-MLP網絡詳解_sparse mlp-CSDN博客?

?

代碼如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 超參數設置
num_experts = 4      # 專家數量
top_k = 2            # 激活專家數
# input_dim = 3072     # CIFAR-10圖像展平后維度(32x32x3)
input_dim = 64 * 8 * 8
hidden_dim = 512     # 專家網絡隱藏層維度
num_classes = 10     # 分類類別數# MoE層實現(文獻[5][7])
class SparseMoE(nn.Module):def __init__(self):super().__init__()self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim)) for _ in range(num_experts)])self.gate = nn.Sequential(nn.Linear(input_dim, num_experts),nn.Softmax(dim=1))# 負載均衡參數(文獻[4][7])self.balance_loss_weight = 0.01self.register_buffer('expert_counts', torch.zeros(num_experts))def forward(self, x):# 門控計算gate_scores = self.gate(x)  # [B, num_experts]# Top-k選擇(文獻[5])topk_scores, topk_indices = torch.topk(gate_scores, top_k, dim=1)mask = F.one_hot(topk_indices, num_experts).float().sum(dim=1)# 專家輸出聚合expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)selected_experts = expert_outputs.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, hidden_dim))  # [B, 2, H]# print(f"專家輸出維度: {expert_outputs.shape}")# print(f"選擇索引維度: {topk_indices.shape}")# print(f"選擇專家維度: {selected_experts.shape}")weighted_outputs = (selected_experts  * topk_scores.unsqueeze(-1)).sum(dim=1)# 更新專家使用統計self.expert_counts += mask.sum(dim=0)return weighted_outputsdef balance_loss(self):# 計算負載均衡損失(文獻[4][7])expert_probs = self.expert_counts / self.expert_counts.sum()balance_loss = torch.std(expert_probs) * self.balance_loss_weightself.expert_counts.zero_()  # 重置計數器return balance_loss# 完整模型架構(文獻[2][6])
class MoEImageClassifier(nn.Module):def __init__(self):super().__init__()self.feature_extractor = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.moe_layer = SparseMoE()self.classifier = nn.Linear(hidden_dim, num_classes)def forward(self, x):x = self.feature_extractor(x)x = x.view(x.size(0), -1)  # 展平特征x = self.moe_layer(x)return self.classifier(x)# 數據預處理(文獻[2])
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)# 訓練流程
model = MoEImageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(10):for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)main_loss = criterion(outputs, labels)balance_loss = model.moe_layer.balance_loss()total_loss = main_loss + balance_losstotal_loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/10], Loss: {total_loss.item():.4f}')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/71837.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/71837.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/71837.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux的基礎指令和環境部署,項目部署實戰(下)

目錄 上一篇:Linxu的基礎指令和環境部署,項目部署實戰(上)-CSDN博客 1. 搭建Java部署環境 1.1 apt apt常用命令 列出所有的軟件包 更新軟件包數據庫 安裝軟件包 移除軟件包 1.2 JDK 1.2.1. 更新 1.2.2. 安裝openjdk&am…

【藍橋杯】第十五屆省賽大學真題組真題解析

【藍橋杯】第十五屆省賽大學真題組真題解析 一、智能停車系統 1、知識點 (1)flex-wrap 控制子元素的換行方式 屬性值有: no-wrap不換行wrap伸縮容器不夠則自動往下換行wrap-reverse伸縮容器不夠則自動往上換行 (2&#xff0…

flink operator v1.10對接華為云對象存儲OBS

1 概述 flink operator及其flink集群,默認不直接支持華為云OBS,需要在這些java程序的插件目錄放一個jar包,以及修改flink配置后,才能支持集成華為云OBS。 相關鏈接參考: https://support.huaweicloud.com/bestpracti…

免費PDF工具

Smallpdf.com - A Free Solution to all your PDF Problems Smallpdf - the platform that makes it super easy to convert and edit all your PDF files. Solving all your PDF problems in one place - and yes, free. https://smallpdf.com/#rappSmallpdf.com-解決您所有PD…

去中心化技術P2P框架

中心化網絡與去中心化網絡 1. 中心化網絡 在傳統的中心化網絡中,所有客戶端都通過一個中心服務器進行通信。這種網絡拓撲結構通常是一個星型結構,其中服務器作為中心節點,每個客戶端只能與服務器通信。如果客戶端之間需要通信,必須…

muduo源碼閱讀:linux timefd定時器

?timerfd timerfd 是Linux一個定時器接口,它基于文件描述符工作,并通過該文件描述符的可讀事件進行超時通知。可以方便地與select、poll和epoll等I/O多路復用機制集成,從而在沒有處理事件時阻塞程序執行,實現高效的零輪詢編程模…

Pinia 3.0 正式發布:全面擁抱 Vue 3 生態,升級指南與實戰教程

一、重大版本更新解析 2024年2月11日,Vue 官方推薦的狀態管理庫 Pinia 迎來 3.0 正式版發布,本次更新標志著其全面轉向 Vue 3 技術生態。以下是開發者需要重點關注的升級要點: 1.1 核心變更說明 特性3.0 版本要求兼容性說明Vue 支持Vue 3.…

【圖像處理 --- Sobel 邊緣檢測的詳解】

Sobel 邊緣檢測的詳解 目錄 Sobel 邊緣檢測的詳解1. 梯度計算2. 梯度大小3. 梯度方向4. 非極大值抑制5. 雙閾值處理6. 在 MATLAB 中實現 Sobel 邊緣檢測7.運行結果展示8.關鍵參數解釋9.實驗與驗證 Sobel 邊緣檢測是一種經典的圖像處理算法,用于檢測圖像中的邊緣。它…

LeetCode 熱題100 15. 三數之和

LeetCode 熱題100 | 15. 三數之和 大家好,今天我們來解決一道經典的算法題——三數之和。這道題在 LeetCode 上被標記為中等難度,要求我們從一個整數數組中找到所有不重復的三元組,使得三元組的和為 0。下面我將詳細講解解題思路&#xff0c…

基因組組裝中的術語1——from HGP

Initial sequencing and analysis of the human genome | Nature 1,分層鳥槍法測序hierarchical shotgun sequencing

安全開發-環境選擇

文章目錄 個人心得虛擬機選擇ubuntu 22.04python環境選擇conda下載使用: 個人心得 在做開發時配置一個專門的環境可以使我們在開發中的效率顯著提升,可以避免掉很多環境沖突的報錯。尤其是python各種版本沖突,還有做滲透工具不要選擇windows…

數字體驗驅動用戶參與增效路徑

內容概要 在數字化轉型深化的當下,數字內容體驗已成為企業與用戶建立深度連接的核心切入點。通過個性化推薦引擎與智能數據分析系統的協同運作,企業能夠實時捕捉用戶行為軌跡,構建精準的用戶行為深度洞察模型。這一模型不僅支撐內容分發的動…

Python 字符串(str)全方位剖析:從基礎入門、方法詳解到跨語言對比與知識拓展

Python 字符串(str)全方位剖析:從基礎入門、方法詳解到跨語言對比與知識拓展 本文將深入探討 Python 中字符串(str)的相關知識,涵蓋字符串的定義、創建、基本操作、格式化等內容。同時,會將 Py…

使用C++實現簡單的TCP服務器和客戶端

使用C實現簡單的TCP服務器和客戶端 介紹準備工作1. TCP服務器實現代碼結構解釋 2. TCP客戶端實現代碼結構解釋 3. 測試1.編譯:2.運行 結語 介紹 本文將通過一個簡單的例子,介紹如何使用C實現一個基本的TCP服務器和客戶端。這個例子展示了如何創建服務器…

Java Web開發實戰與項目——Spring Boot與Spring Cloud微服務項目實戰

企業級應用中,微服務架構已經成為一種常見的開發模式。Spring Boot與Spring Cloud提供了豐富的工具和組件,幫助開發者快速構建、管理和擴展微服務應用。本文將通過一個實際的微服務項目,展示如何使用Spring Boot與Spring Cloud構建微服務架構…

VMware建立linux虛擬機

本文適用于初學者,幫助初學者學習如何創建虛擬機,了解在創建過程中各個選項的含義。 環境如下: CentOS版本: CentOS 7.9(2009) 軟件: VMware Workstation 17 Pro 17.5.0 build-22583795 1.配…

Linux8-互斥鎖、信號量

一、前情回顧 void perror(const char *s);功能:參數: 二、資源競爭 1.多線程訪問臨界資源時存在資源競爭(存在資源競爭、造成數據錯亂) 臨界資源:多個線程可以同時操作的資源空間(全局變量、共享內存&a…

LD_PRELOAD 繞過 disable_function 學習

借助這位師傅的文章來學習通過LD_PRELOAD來繞過disable_function的原理 【PHP繞過】LD_PRELOAD bypass disable_functions_phpid繞過-CSDN博客 感謝這位師傅的貢獻 介紹 靜態鏈接: (1)舉個情景來幫助理解: 假設你要搬家&#x…

【無人集群系列---無人機集群編隊算法】

【無人集群系列---無人機集群編隊算法】 一、核心目標二、主流編隊控制方法1. 領航-跟隨法(Leader-Follower)2. 虛擬結構法(Virtual Structure)3. 行為法(Behavior-Based)4. 人工勢場法(Artific…

Oracle Fusion Middleware更改weblogic密碼

前言 當用戶忘記weblogic密碼時,且無法登錄到web界面中,需要使用服務器命令更改密碼 更改方式 1、備份 首先進入 weblogic 安裝目錄,備份三個文件:boot.properties,DefaultAuthenticatorInit.ldift,Def…