Qwen3MLP
Qwen3MLP是基于門控機制的MLP模塊,采用了類似門控線性單元(GLU)的結構。它通過三個線性變換層(gate_proj、up_proj和down_proj)和SiLU激活函數,先將輸入從隱藏維度擴展到中間維度,經過門控計算后再投影回原始維度。該模塊保持了輸入輸出形狀的一致性,演示了如何逐步執行前向傳播并驗證計算正確性,展示了Transformer模型中常用的前饋神經網絡結構。
具體代碼與測試如下:
import torch
import torch.nn as nn
from transformers.activations import ACT2FNclass Qwen3MLP(nn.Module):def __init__(self, config):super().__init__()self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act] # siludef forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj# 模擬配置類
class MockConfig:def __init__(self):self.hidden_size = 1024self.intermediate_size = 2048self.hidden_act = "silu"# 完整示例
if __name__ == "__main__":# 1. 創建配置對象config = MockConfig()# 2. 初始化Qwen3MLP模塊mlp = Qwen3MLP(config)# 3. 創建測試輸入數據batch_size = 2seq_length = 8hidden_size = config.hidden_size # 1024# 輸入張量形狀: (batch_size, seq_length, hidden_size)input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MLP 示例 ===")print(f"配置信息:")print(f" - hidden_size: {config.hidden_size}")print(f" - intermediate_size: {config.intermediate_size}")print(f" - activation: {config.hidden_act}")print(f"\n輸入張量形狀: {input_tensor.shape}")# 4. 前向傳播with torch.no_grad():output_tensor = mlp(input_tensor)print(f"輸出張量形狀: {output_tensor.shape}")# 5. 驗證輸出形狀與輸入形狀一致assert output_tensor.shape == input_tensor.shape, \f"輸出形狀 {output_tensor.shape} 與輸入形狀 {input_tensor.shape} 不一致"print("\n=== MLP 層內部組件 ===")print(f"gate_proj 權重形狀: {mlp.gate_proj.weight.shape}")print(f"up_proj 權重形狀: {mlp.up_proj.weight.shape}")print(f"down_proj 權重形狀: {mlp.down_proj.weight.shape}")# 6. 逐步計算過程演示print("\n=== 前向傳播步驟 ===")with torch.no_grad():# 第一步: 門控投影gate_output = mlp.gate_proj(input_tensor)print(f"1. gate_proj 輸出形狀: {gate_output.shape}")# 第二步: 激活函數gate_activated = mlp.act_fn(gate_output)print(f"2. 激活函數后形狀: {gate_activated.shape}")# 第三步: 上投影up_output = mlp.up_proj(input_tensor)print(f"3. up_proj 輸出形狀: {up_output.shape}")# 第四步: 門控線性單元 (GLU)glu_output = gate_activated * up_outputprint(f"4. GLU 輸出形狀: {glu_output.shape}")# 第五步: 下投影final_output = mlp.down_proj(glu_output)print(f"5. down_proj 輸出形狀: {final_output.shape}")# 驗證與直接調用forward的結果一致direct_output = mlp(input_tensor)assert torch.allclose(final_output, direct_output, atol=1e-6), "逐步計算結果與直接調用不一致"print("? 逐步計算結果與直接調用結果一致")print("\n=== 示例完成 ===")print(f"MLP 成功處理了形狀為 {input_tensor.shape} 的輸入,輸出形狀為 {output_tensor.shape}")
=== Qwen3MLP 示例 ===
配置信息:- hidden_size: 1024- intermediate_size: 2048- activation: silu輸入張量形狀: torch.Size([2, 8, 1024])
輸出張量形狀: torch.Size([2, 8, 1024])=== MLP 層內部組件 ===
gate_proj 權重形狀: torch.Size([2048, 1024])
up_proj 權重形狀: torch.Size([2048, 1024])
down_proj 權重形狀: torch.Size([1024, 2048])=== 前向傳播步驟 ===
1. gate_proj 輸出形狀: torch.Size([2, 8, 2048])
2. 激活函數后形狀: torch.Size([2, 8, 2048])
3. up_proj 輸出形狀: torch.Size([2, 8, 2048])
4. GLU 輸出形狀: torch.Size([2, 8, 2048])
5. down_proj 輸出形狀: torch.Size([2, 8, 1024])
? 逐步計算結果與直接調用結果一致=== 示例完成 ===
MLP 成功處理了形狀為 torch.Size([2, 8, 1024]) 的輸入,輸出形狀為 torch.Size([2, 8, 1024])
Qwen3MoeSparseMoeBlock
Qwen3 模型的稀疏混合專家(Sparse MoE)模塊,核心是通過“路由機制+多專家并行計算”提升模型在大參數量下的效率與能力。
Qwen3MoeSparseMoeBlock
處理輸入的流程可分為 路由計算→專家選擇→并行計算→結果聚合 四步:
1. 路由計算:為每個 token 選專家
- 輸入
hidden_states
(形狀[batch_size, seq_length, hidden_size]
)先展平為[batch*seq, hidden_size]
; - 用
self.gate
(線性層)生成router_logits
(每個 token 對 8 個專家的“匹配分數”); - 通過
softmax
+topk
,為每個 token 選num_experts_per_tok=2
個“最匹配專家”,并得到歸一化的路由權重(決定每個專家對 token 的貢獻占比)。
2. 專家選擇:標記活躍專家
通過 one_hot
編碼生成 expert_mask
,標記“哪些專家被哪些 token 選中”;再通過 expert_hit
篩選出至少被一個 token 選中的活躍專家(示例中 8 個專家都有 token 命中)。
3. 并行計算:專家各自處理 token
對每個活躍專家,執行:
- 篩選出“屬于當前專家”的 token(通過
expert_mask
定位); - 調用該專家的
Qwen3MoeMLP
層(結構同普通 MLP,但參數量僅服務部分 token),完成“門控投影→激活→上投影→下投影”的計算; - 用路由權重對專家輸出加權(確保不同專家的貢獻按匹配度分配)。
4. 結果聚合:合并所有專家輸出
通過 index_add_
將每個專家處理后的 token 結果,按原始位置合并,最終還原為 [batch_size, seq_length, hidden_size]
的輸出。
具體代碼與測試如下:
import torch.nn as nn
from transformers.activations import ACT2FN
import torch.nn.functional as Fclass Qwen3MoeMLP(nn.Module):def __init__(self, config, intermediate_size=None):super().__init__()self.config = configself.hidden_size = config.hidden_size # 512self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size# 256self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# 512, 256self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # 512, 256self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # 256, 512self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_projclass Qwen3MoeSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.num_experts = config.num_experts # 8self.top_k = config.num_experts_per_tok # 2self.norm_topk_prob = config.norm_topk_prob # Trueself.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # 512 -> 8self.experts = nn.ModuleList([Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]) # 512 -> 256 -> 512def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape # 2, 6, 512hidden_states = hidden_states.view(-1, hidden_dim) # 2, 6, 512 -> 12, 512router_logits = self.gate(hidden_states) # 12 8routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # 12 8routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # 12 2if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True)routing_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) # 12 512expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# 12 2 8 8 2 12 print("expert_mask: \n",expert_mask)expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() # 8print("expert hit: \n",expert_hit)for expert_idx in expert_hit:expert_layer = self.experts[expert_idx] # Qwen3MoeMLPidx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # 4 4 if expert_idx == 0:print("expert_mask[expert_idx].squeeze(0):",expert_mask[expert_idx].squeeze(0))print("idx:",idx)print("top_x:",top_x)print("hidden_states[None, top_x]:",hidden_states[None, top_x].shape)current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) # 1, 4, 512 -> 4, 512current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]# 4, 512 * 4, 512 -> 4, 512final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # 2, 6, 512return final_hidden_states, router_logits
class MockConfig:def __init__(self):self.hidden_size = 512self.moe_intermediate_size = 256self.hidden_act = "silu"self.num_experts = 8self.num_experts_per_tok = 2self.norm_topk_prob = Trueimport numpy as np
import random# 設置隨機種子以確保可重復性
def set_random_seed(seed=42):"""設置所有隨機種子以確保結果可重復"""torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False
# 完整示例
if __name__ == "__main__":set_random_seed(42)config = MockConfig()moe_block = Qwen3MoeSparseMoeBlock(config)batch_size = 2seq_length = 6hidden_size = config.hidden_size # 512input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MoeSparseMoeBlock 示例 ===")print(f"配置信息:")print(f" - hidden_size: {config.hidden_size}")print(f" - moe_intermediate_size: {config.moe_intermediate_size}")print(f" - activation: {config.hidden_act}")print(f" - num_experts: {config.num_experts}")print(f" - num_experts_per_tok: {config.num_experts_per_tok}")print(f" - norm_topk_prob: {config.norm_topk_prob}")print(f"\n輸入張量形狀: {input_tensor.shape}")with torch.no_grad():output_tensor, router_logits = moe_block(input_tensor)print(f"輸出張量形狀: {output_tensor.shape}")print(f"路由邏輯形狀: {router_logits.shape}")
=== Qwen3MoeSparseMoeBlock 示例 ===
配置信息:- hidden_size: 512- moe_intermediate_size: 256- activation: silu- num_experts: 8- num_experts_per_tok: 2- norm_topk_prob: True輸入張量形狀: torch.Size([2, 6, 512])
expert_mask: tensor([[[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]],[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]]])
expert hit: tensor([[0],[1],[2],[3],[4],[5],[6],[7]])
expert_mask[expert_idx].squeeze(0): tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
idx: tensor([0, 0, 0, 1])
top_x: tensor([ 0, 2, 10, 6])
hidden_states[None, top_x]: torch.Size([1, 4, 512])
輸出張量形狀: torch.Size([2, 6, 512])
路由邏輯形狀: torch.Size([12, 8])