門控MLP(Qwen3MLP)與稀疏混合專家(Qwen3MoeSparseMoeBlock)模塊解析

Qwen3MLP

Qwen3MLP是基于門控機制的MLP模塊,采用了類似門控線性單元(GLU)的結構。它通過三個線性變換層(gate_proj、up_proj和down_proj)和SiLU激活函數,先將輸入從隱藏維度擴展到中間維度,經過門控計算后再投影回原始維度。該模塊保持了輸入輸出形狀的一致性,演示了如何逐步執行前向傳播并驗證計算正確性,展示了Transformer模型中常用的前饋神經網絡結構。
具體代碼與測試如下:

import torch
import torch.nn as nn
from transformers.activations import ACT2FNclass Qwen3MLP(nn.Module):def __init__(self, config):super().__init__()self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act] # siludef forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj# 模擬配置類
class MockConfig:def __init__(self):self.hidden_size = 1024self.intermediate_size = 2048self.hidden_act = "silu"# 完整示例
if __name__ == "__main__":# 1. 創建配置對象config = MockConfig()# 2. 初始化Qwen3MLP模塊mlp = Qwen3MLP(config)# 3. 創建測試輸入數據batch_size = 2seq_length = 8hidden_size = config.hidden_size  # 1024# 輸入張量形狀: (batch_size, seq_length, hidden_size)input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MLP 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - intermediate_size: {config.intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"\n輸入張量形狀: {input_tensor.shape}")# 4. 前向傳播with torch.no_grad():output_tensor = mlp(input_tensor)print(f"輸出張量形狀: {output_tensor.shape}")# 5. 驗證輸出形狀與輸入形狀一致assert output_tensor.shape == input_tensor.shape, \f"輸出形狀 {output_tensor.shape} 與輸入形狀 {input_tensor.shape} 不一致"print("\n=== MLP 層內部組件 ===")print(f"gate_proj 權重形狀: {mlp.gate_proj.weight.shape}")print(f"up_proj 權重形狀: {mlp.up_proj.weight.shape}")print(f"down_proj 權重形狀: {mlp.down_proj.weight.shape}")# 6. 逐步計算過程演示print("\n=== 前向傳播步驟 ===")with torch.no_grad():# 第一步: 門控投影gate_output = mlp.gate_proj(input_tensor)print(f"1. gate_proj 輸出形狀: {gate_output.shape}")# 第二步: 激活函數gate_activated = mlp.act_fn(gate_output)print(f"2. 激活函數后形狀: {gate_activated.shape}")# 第三步: 上投影up_output = mlp.up_proj(input_tensor)print(f"3. up_proj 輸出形狀: {up_output.shape}")# 第四步: 門控線性單元 (GLU)glu_output = gate_activated * up_outputprint(f"4. GLU 輸出形狀: {glu_output.shape}")# 第五步: 下投影final_output = mlp.down_proj(glu_output)print(f"5. down_proj 輸出形狀: {final_output.shape}")# 驗證與直接調用forward的結果一致direct_output = mlp(input_tensor)assert torch.allclose(final_output, direct_output, atol=1e-6), "逐步計算結果與直接調用不一致"print("? 逐步計算結果與直接調用結果一致")print("\n=== 示例完成 ===")print(f"MLP 成功處理了形狀為 {input_tensor.shape} 的輸入,輸出形狀為 {output_tensor.shape}")
=== Qwen3MLP 示例 ===
配置信息:- hidden_size: 1024- intermediate_size: 2048- activation: silu輸入張量形狀: torch.Size([2, 8, 1024])
輸出張量形狀: torch.Size([2, 8, 1024])=== MLP 層內部組件 ===
gate_proj 權重形狀: torch.Size([2048, 1024])
up_proj 權重形狀: torch.Size([2048, 1024])
down_proj 權重形狀: torch.Size([1024, 2048])=== 前向傳播步驟 ===
1. gate_proj 輸出形狀: torch.Size([2, 8, 2048])
2. 激活函數后形狀: torch.Size([2, 8, 2048])
3. up_proj 輸出形狀: torch.Size([2, 8, 2048])
4. GLU 輸出形狀: torch.Size([2, 8, 2048])
5. down_proj 輸出形狀: torch.Size([2, 8, 1024])
? 逐步計算結果與直接調用結果一致=== 示例完成 ===
MLP 成功處理了形狀為 torch.Size([2, 8, 1024]) 的輸入,輸出形狀為 torch.Size([2, 8, 1024])

Qwen3MoeSparseMoeBlock

Qwen3 模型的稀疏混合專家(Sparse MoE)模塊,核心是通過“路由機制+多專家并行計算”提升模型在大參數量下的效率與能力。

Qwen3MoeSparseMoeBlock 處理輸入的流程可分為 路由計算→專家選擇→并行計算→結果聚合 四步:

1. 路由計算:為每個 token 選專家
  • 輸入 hidden_states(形狀 [batch_size, seq_length, hidden_size])先展平為 [batch*seq, hidden_size]
  • self.gate(線性層)生成 router_logits(每個 token 對 8 個專家的“匹配分數”);
  • 通過 softmax+topk,為每個 token 選 num_experts_per_tok=2 個“最匹配專家”,并得到歸一化的路由權重(決定每個專家對 token 的貢獻占比)。
2. 專家選擇:標記活躍專家

通過 one_hot 編碼生成 expert_mask,標記“哪些專家被哪些 token 選中”;再通過 expert_hit 篩選出至少被一個 token 選中的活躍專家(示例中 8 個專家都有 token 命中)。

3. 并行計算:專家各自處理 token

對每個活躍專家,執行:

  • 篩選出“屬于當前專家”的 token(通過 expert_mask 定位);
  • 調用該專家的 Qwen3MoeMLP 層(結構同普通 MLP,但參數量僅服務部分 token),完成“門控投影→激活→上投影→下投影”的計算;
  • 用路由權重對專家輸出加權(確保不同專家的貢獻按匹配度分配)。
4. 結果聚合:合并所有專家輸出

通過 index_add_ 將每個專家處理后的 token 結果,按原始位置合并,最終還原為 [batch_size, seq_length, hidden_size] 的輸出。


具體代碼與測試如下:

import torch.nn as nn
from transformers.activations import ACT2FN
import torch.nn.functional as Fclass Qwen3MoeMLP(nn.Module):def __init__(self, config, intermediate_size=None):super().__init__()self.config = configself.hidden_size = config.hidden_size  # 512self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size# 256self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# 512, 256self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # 512, 256self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # 256, 512self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_projclass Qwen3MoeSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.num_experts = config.num_experts # 8self.top_k = config.num_experts_per_tok # 2self.norm_topk_prob = config.norm_topk_prob # Trueself.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # 512 -> 8self.experts = nn.ModuleList([Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]) #  512 -> 256 -> 512def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape  # 2, 6, 512hidden_states = hidden_states.view(-1, hidden_dim) # 2, 6, 512 -> 12, 512router_logits = self.gate(hidden_states) # 12 8routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # 12 8routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # 12 2if self.norm_topk_prob:  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)routing_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) # 12 512expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# 12 2 8    8 2 12 print("expert_mask: \n",expert_mask)expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() # 8print("expert hit: \n",expert_hit)for expert_idx in expert_hit:expert_layer = self.experts[expert_idx]  # Qwen3MoeMLPidx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # 4 4 if expert_idx == 0:print("expert_mask[expert_idx].squeeze(0):",expert_mask[expert_idx].squeeze(0))print("idx:",idx)print("top_x:",top_x)print("hidden_states[None, top_x]:",hidden_states[None, top_x].shape)current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)   # 1, 4, 512 -> 4, 512current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]# 4, 512 * 4, 512 -> 4, 512final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # 2, 6, 512return final_hidden_states, router_logits
class MockConfig:def __init__(self):self.hidden_size = 512self.moe_intermediate_size = 256self.hidden_act = "silu"self.num_experts = 8self.num_experts_per_tok = 2self.norm_topk_prob = Trueimport numpy as np
import random# 設置隨機種子以確保可重復性
def set_random_seed(seed=42):"""設置所有隨機種子以確保結果可重復"""torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False
# 完整示例
if __name__ == "__main__":set_random_seed(42)config = MockConfig()moe_block = Qwen3MoeSparseMoeBlock(config)batch_size = 2seq_length = 6hidden_size = config.hidden_size  # 512input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MoeSparseMoeBlock 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - moe_intermediate_size: {config.moe_intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"  - num_experts: {config.num_experts}")print(f"  - num_experts_per_tok: {config.num_experts_per_tok}")print(f"  - norm_topk_prob: {config.norm_topk_prob}")print(f"\n輸入張量形狀: {input_tensor.shape}")with torch.no_grad():output_tensor, router_logits = moe_block(input_tensor)print(f"輸出張量形狀: {output_tensor.shape}")print(f"路由邏輯形狀: {router_logits.shape}")
=== Qwen3MoeSparseMoeBlock 示例 ===
配置信息:- hidden_size: 512- moe_intermediate_size: 256- activation: silu- num_experts: 8- num_experts_per_tok: 2- norm_topk_prob: True輸入張量形狀: torch.Size([2, 6, 512])
expert_mask: tensor([[[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]],[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]]])
expert hit: tensor([[0],[1],[2],[3],[4],[5],[6],[7]])
expert_mask[expert_idx].squeeze(0): tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
idx: tensor([0, 0, 0, 1])
top_x: tensor([ 0,  2, 10,  6])
hidden_states[None, top_x]: torch.Size([1, 4, 512])
輸出張量形狀: torch.Size([2, 6, 512])
路由邏輯形狀: torch.Size([12, 8])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/96074.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/96074.shtml
英文地址,請注明出處:http://en.pswp.cn/web/96074.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

產線相機問題分析思路

現象:復現問題 原因:問題分析、溯源,定位根本原因; 方案:提出解決方案、規避措施 驗證:導入、驗證方案是否可行(先小批量、再大批量);一. 現象產線反饋4pcs預覽又臟污、劃…

【開關電源篇】EMI輸入電路-超簡單解讀

1. 輸入電路主要包含哪些元件?濾波設計需遵循什么原則? 輸入電路是電子設備(如開關電源)的“入口”,核心作用是抑制電磁干擾(EMI)、保護后級電路,其設計直接影響設備的穩定性和電磁…

勝券POS:打造智能移動終端,讓零售智慧運營觸手可及

零售企業運營中依然存在重重挑戰:收銀臺前的長隊消磨著顧客的耐心,倉庫里的庫存盤點不斷侵蝕著員工的精力,導購培訓的成本長期居高不下卻收效甚微……面對這些痛點,零售企業或許都在等待一個破局的答案。百勝軟件勝券POS&#xff…

(回溯/組合)Leetcode77組合+39組合總和+216組合總和III

為什么不能暴力,因為不知道要循環多少次,如果長度為n,難道要循環n次么,回溯的本質還是暴力,但是是可以知道多少層的暴力 之所以要pop是因為回溯相當于一個樹形結構,要pop進行第二個分支 剪枝:…

07 下載配置很完善的yum軟件源

文章目錄前言ping 測試網絡排查原因排查虛擬機的虛擬網絡是否開啟檢查net8虛擬網絡和Centos 7的ip地址是否在一個局域網點擊虛擬網絡編輯器點擊更改設置記錄net8的虛擬網絡地址ip a記錄Centos 7的ip地址比較net8和Centos 7的ip地址是否在一個網段解決問題問題解決辦法修改net8的…

SpringBoot中添加健康檢查服務

問題 今天需要給一個Spring工程添加健康檢查。 pom.xml <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-actuator</artifactId> </dependency>application.yml management:endpoints:web:e…

AI工具深度測評與選型指南 - AI工具測評框架及方法論

目錄引言&#xff1a;AI工具爆發期的機遇與挑戰一、從AI模型到AI工具&#xff1a;核心認知與生態解析1.1 DeepSeek&#xff1a;快速出圈的國產大模型代表1.2 大模型的核心能力與類型劃分1.2.1 大模型的三層能力與“雙系統”類比1.2.2 生成模型與推理模型的核心差異1.3 AI工具與…

Spring Cloud Alibaba快速入門02-Nacos(中)

文章目錄實現注冊中心-服務發現模擬掉線遠程調用1.訂單和商品模塊的接口商品服務訂單服務2.抽取實體類3.訂單服務拿到需要調用服務的ip和端口負載均衡步驟1步驟2步驟3步驟4面試題&#xff1a;注冊中心宕機&#xff0c;遠程調用還能成功嗎&#xff1f;1、調用過;遠程調用不在依賴…

【Python】數據可視化之熱力圖

熱力圖&#xff08;Heatmap&#xff09;是一種通過顏色深淺來展示數據分布、密度和強度等信息的可視化圖表。它通過對色塊著色來反映數據特征&#xff0c;使用戶能夠直觀地理解數據模式&#xff0c;發現規律&#xff0c;并作出決策。 目錄 基本原理 sns.heatmap 代碼實現 基…

如何 正確使用 nrm 工具 管理鏡像源

目錄 nrm 是啥&#xff1f; nrm 的安裝 查看你當前已有的鏡像源 怎么切換到目標鏡像源 添加鏡像源 刪除鏡像源 測試鏡像源速度 nrm 是啥&#xff1f; 鏡像源&#xff1a;可以理解為&#xff0c;你訪問或下載某jar包或依賴的倉庫。 nrm&#xff08;Node Registry Manag…

關于對逾期提醒的定時任務~改進完善

Spring Boot 中實現到期提醒任務的定時Job詳解在金融或借貸系統中&#xff0c;到期提醒是常見的功能需求。通過定時任務&#xff0c;可以定期掃描即將到期的借款記錄&#xff0c;并生成或更新提醒信息。本文基于提供的三個JobHandler類&#xff08;FarExpireRemindJob、MidExpi…

springboot配置請求日志

springboot配置請求日志 一般情況下&#xff0c;接口請求都需要日志記錄&#xff0c;Java springboot中的日志記錄相對復雜一點 經過實踐&#xff0c;以下方案可行&#xff0c;記錄一下完整過程 一、創建日志數據模型 創建實體類&#xff0c;也就是日志文件中要記錄的數據格式 …

Redis(50) Redis哨兵如何與客戶端進行交互?

Redis 哨兵&#xff08;Sentinel&#xff09;不僅負責監控和管理 Redis 主從復制集群的高可用性&#xff0c;還需要與客戶端進行有效的交互來實現故障轉移后的透明連接切換。下面詳細探討 Redis 哨兵如何與客戶端進行交互&#xff0c;并結合代碼示例加以說明。 哨兵與客戶端的交…

【.Net技術棧梳理】04-核心框架與運行時(線程處理)

文章目錄1. 線程管理1.1 線程的核心概念&#xff1a;System.Threading.Thread1.2 現代線程管理&#xff1a;System.Threading.Tasks.Task 和 Task Parallel Library (TPL)1.3 狀態管理和異常處理1.4 協調任務&#xff1a;async/await 模式2. 線程間通信2.1 共享內存與競態條件2…

(JVM)四種垃圾回收算法

在 JVM 中&#xff0c;垃圾回收&#xff08;GC&#xff09;是核心機制之一。為了提升性能與內存利用率&#xff0c;JVM 采用了多種垃圾回收算法。本文總結了 四種常見的 GC 算法&#xff0c;并結合其優缺點與應用場景進行說明。1. 標記-清除&#xff08;Mark-Sweep&#xff09;…

論文閱讀:VGGT Visual Geometry Grounded Transformer

論文閱讀&#xff1a;VGGT: Visual Geometry Grounded Transformer 今天介紹一篇 CVPR 2025 的 best paper&#xff0c;這篇文章是牛津大學的 VGG 團隊的工作&#xff0c;主要圍繞著 3D 視覺中的各種任務&#xff0c;這篇文章提出了一種多任務統一的架構&#xff0c;實現一次輸…

python編程:一文掌握pypiserver的詳細使用

更多內容請見: python3案例和總結-專欄介紹和目錄 文章目錄 一、 pypiserver 概述 1.1 pypiserver是什么? 1.2 核心特性 1.3 典型應用場景 1.4 pypiserver優缺點 二、 安裝與基本使用 2.1 安裝 pypiserver 2.2 快速啟動(最簡模式) 2.3 使用私有服務器安裝包 2.4 向私有服務…

Git reset 回退版本

- 第 121 篇 - Date: 2025 - 09 - 06 Author: 鄭龍浩&#xff08;仟墨&#xff09; 文章目錄Git reset 回退版本1 介紹三種命令區別3 驗證三種的區別3 如果不小心git reset --hard將「工作區」和「暫存區」中的內容刪除&#xff0c;剛才的記錄找不到了&#xff0c;怎么辦呢&…

ARM 基礎(2)

ARM內核工作模式及其切換條件用戶模式(User Mode, usr) 權限最低&#xff0c;運行普通應用程序。只能通過異常被動切換到其他模式。快速中斷模式(FIQ Mode, fiq) 處理高速外設中斷&#xff0c;專用寄存器減少上下文保存時間&#xff0c;響應周期約4個時鐘周期。觸發條件為FIQ中…

Flutter 性能優化

Flutter 性能優化是一個系統性的工程&#xff0c;涉及多個層面。 一、性能分析工具&#xff08;Profiling Tools&#xff09; 在開始優化前&#xff0c;必須使用工具定位瓶頸。切忌盲目優化。 1. DevTools 性能視圖 DevTools 性能視圖 (Performance View) 作用&#xff1a;…