LLM:MoE原理與實現探索

文章目錄

  • 前言
  • 一、Deepseek Moe
  • 二. Moe架構
    • 1. Expert
    • 2. Gate
    • 3. MoE Module
  • 三、Auxiliary Loss
  • 總結


前言

MoE(Mixture of Experts) 已經逐漸在LLM中廣泛應用,其工程部署相關目前也有了越來越多的支持,本文主要記錄一下MoE的基本模塊構造與原理。以Deepseek中的MoE構造為例。


MoE本質上可以看作一個MLP層,不過是對每個token都不一樣的MLP層,假設一個MoE模塊存在64個expert,相當于有64個并行MLP層,當有一個token送入進行處理時,會自動選擇當前最合適的幾個expert進行運算,而不是一個固定個MLP。

一、Deepseek Moe

模型整體架構如下

Transformer((embed): ParallelEmbedding()(layers): ModuleList((0): Block((attn): MLA((wq): ColumnParallelLinear()(wkv_a): Linear()(kv_norm): RMSNorm()(wkv_b): ColumnParallelLinear()(wo): RowParallelLinear())(ffn): MLP((w1): ColumnParallelLinear()(w2): RowParallelLinear()(w3): ColumnParallelLinear())(attn_norm): RMSNorm()(ffn_norm): RMSNorm())(1-26): 26 x Block((attn): MLA((wq): ColumnParallelLinear()(wkv_a): Linear()(kv_norm): RMSNorm()(wkv_b): ColumnParallelLinear()(wo): RowParallelLinear())(ffn): MoE((gate): Gate()(experts): ModuleList((0-63): 64 x Expert((w1): Linear()(w2): Linear()(w3): Linear()))(shared_experts): MLP((w1): ColumnParallelLinear()(w2): RowParallelLinear()(w3): ColumnParallelLinear()))(attn_norm): RMSNorm()(ffn_norm): RMSNorm()))(norm): RMSNorm()(head): ColumnParallelLinear()
)

二. Moe架構

從上面的架構可以發現, moe作為一個模塊替換了原始Transformer架構mlp模塊,主要包含一個Gate和Expert。

1. Expert

Expert代碼如下:


class Expert(nn.Module):"""Expert layer for Mixture-of-Experts (MoE) models.Attributes:w1 (nn.Module): Linear layer for input-to-hidden transformation.w2 (nn.Module): Linear layer for hidden-to-output transformation.w3 (nn.Module): Additional linear layer for feature transformation."""def __init__(self, dim: int, inter_dim: int):"""Initializes the Expert layer.Args:dim (int): Input and output dimensionality.inter_dim (int): Hidden layer dimensionality."""super().__init__()self.w1 = Linear(dim, inter_dim)self.w2 = Linear(inter_dim, dim)self.w3 = Linear(dim, inter_dim)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Forward pass for the Expert layer.Args:x (torch.Tensor): Input tensor.Returns:torch.Tensor: Output tensor after expert computation."""return self.w2(F.silu(self.w1(x)) * self.w3(x))

可以看到Expert本質上就是一個簡單的線性層

2. Gate

如何在當前MoE模塊中選擇合適的Expert則是通過Gate操作來完成的

Gate代碼如下:


class Gate(nn.Module):"""Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.Attributes:dim (int): Dimensionality of input features.topk (int): Number of top experts activated for each input.n_groups (int): Number of groups for routing.topk_groups (int): Number of groups to route inputs to.score_func (str): Scoring function ('softmax' or 'sigmoid').route_scale (float): Scaling factor for routing weights.weight (torch.nn.Parameter): Learnable weights for the gate.bias (Optional[torch.nn.Parameter]): Optional bias term for the gate."""def __init__(self, args: ModelArgs):"""Initializes the Gate module.Args:args (ModelArgs): Model arguments containing gating parameters."""super().__init__()self.dim = args.dimself.topk = args.n_activated_expertsself.n_groups = args.n_expert_groupsself.topk_groups = args.n_limited_groupsself.score_func = args.score_funcself.route_scale = args.route_scaleself.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else Nonedef forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:"""Forward pass for the gating mechanism.Args:x (torch.Tensor): Input tensor.Returns:Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices."""scores = linear(x, self.weight)if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)else:scores = scores.sigmoid()original_scores = scoresif self.bias is not None:scores = scores + self.biasif self.n_groups > 1:scores = scores.view(x.size(0), self.n_groups, -1)if self.bias is None:group_scores = scores.amax(dim=-1)else:group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)indices = group_scores.topk(self.topk_groups, dim=-1)[1]mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)indices = torch.topk(scores, self.topk, dim=-1)[1]weights = original_scores.gather(1, indices)if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)weights *= self.route_scalereturn weights.type_as(x), indices

Gate函數通過輸入特征向量,輸出分數

請添加圖片描述

基于當前分數根據情況確定是否采用分組路由, 若采用分組路由,則會將多個expert分組,然后取每個組的最高分

請添加圖片描述


3. MoE Module

有了Gate計算的分數,就可以進行MoE的計算了,MoE模塊如下

class MoE(nn.Module):"""Mixture-of-Experts (MoE) module.Attributes:dim (int): Dimensionality of input features.n_routed_experts (int): Total number of experts in the model.n_local_experts (int): Number of experts handled locally in distributed systems.n_activated_experts (int): Number of experts activated for each input.gate (nn.Module): Gating mechanism to route inputs to experts.experts (nn.ModuleList): List of expert modules.shared_experts (nn.Module): Shared experts applied to all inputs."""def __init__(self, args: ModelArgs):"""Initializes the MoE module.Args:args (ModelArgs): Model arguments containing MoE parameters."""super().__init__()self.dim = args.dimassert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"self.n_routed_experts = args.n_routed_expertsself.n_local_experts = args.n_routed_experts // world_sizeself.n_activated_experts = args.n_activated_expertsself.experts_start_idx = rank * self.n_local_expertsself.experts_end_idx = self.experts_start_idx + self.n_local_expertsself.gate = Gate(args)self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Forward pass for the MoE module.Args:x (torch.Tensor): Input tensor.Returns:torch.Tensor: Output tensor after expert routing and computation."""shape = x.size()x = x.view(-1, self.dim)weights, indices = self.gate(x)y = torch.zeros_like(x)counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()for i in range(self.experts_start_idx, self.experts_end_idx):if counts[i] == 0:continueexpert = self.experts[i]idx, top = torch.where(indices == i)y[idx] += expert(x[idx]) * weights[idx, top, None]z = self.shared_experts(x)if world_size > 1:dist.all_reduce(y)return (y + z).view(shape)

為了方便理解,繪制了個草圖

請添加圖片描述

通過gate函數可以獲得每個token所選的專家索引與對應的score
如上圖中畫了一個簡單的當topk=3時的索引矩陣,針對每個token,選擇三個expert進行處理,每個expert對應了一個分數用于加權。

循環所有expert,基于indices矩陣找出需要第i個expert處理的所有token ——> x[idx] ,經過expert處理后加權賦予y 作為新的特征。

獲得y后,在將x送入shared_expert 獲得z, 最后兩者相加獲得最終MoE的特征。
請添加圖片描述

三、Auxiliary Loss

這里還需要補充一點,在使用MoE的同時,需要使用Auxiliary Loss。在混合專家 (MoE, Mixture of Experts) 模型中,Auxiliary Loss(輔助損失) 的主要作用是 負載均衡,即:

平衡專家的負載: 避免所有的 token 都被分配到少數幾個專家上,導致這些專家過度繁忙,而其他專家閑置。

穩定訓練過程: 防止路由器(Router,負責將 token 分配給專家)學到某種不健康的分布,例如所有 token 都集中到一個專家,或 token 分布極其稀疏。

每個 token 的路由器分數(router_logits)表示該 token 應該路由到哪些專家的偏好。
為了負載均衡,需要統計:
每個專家獲得分配的 token 的比例(fraction_tokens_per_expert)。
路由概率(softmax(router_logits))中每個專家的總概率比例(fraction_prob_per_expert)。

請添加圖片描述

關于aux loss的直觀理解如下:
請添加圖片描述


總結

MoE作為目前大模型在結構推理上的一大創新還是很厲害的,比較現在的LLM結構同質化蠻嚴重的。MoE確實給LLM未來的發展帶來更多可能性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/92894.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/92894.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/92894.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于領域事件驅動的微服務架構設計與實踐

引言&#xff1a;為什么你的微服務總是"牽一發而動全身"&#xff1f; 在復雜的業務系統中&#xff0c;你是否遇到過這樣的困境&#xff1a;修改一個訂單服務&#xff0c;卻導致支付服務異常&#xff1b;調整庫存邏輯&#xff0c;用戶服務開始報錯。這種"蝴蝶效應…

如何使用curl編程來下載文件

libcurl 是一個功能強大的跨平臺網絡傳輸庫&#xff0c;支持多種協議。 本篇來介紹libcul的C語言編程&#xff0c;實現一個文件下載的功能。 1 curl基礎介紹 1.1 核心數據結構 1.1.1 CURL句柄 CURL是libcurl 的核心句柄&#xff0c;每個請求對應一個 CURL 實例&#xff0c;…

大語言模型提示工程與應用:ChatGPT提示工程技術指南

ChatGPT提示工程 學習目標 在本課程中&#xff0c;我們將學習更多關于ChatGPT的最新提示工程技術。 相關知識點 ChatGPT提示工程 學習內容 1 ChatGPT提示工程 ChatGPT是OpenAI研發的新型對話模型&#xff0c;具備多輪對話能力。該模型通過人類反饋強化學習(RLHF)訓練&am…

能力評估:如何系統評估你的技能和經驗

能力評估&#xff1a;如何系統評估你的技能和經驗 作為一名38歲的互聯網研發老兵&#xff0c;你已經積累了豐富的經驗&#xff0c;包括技術深度、項目管理、團隊協作等。但能力評估不是一次性事件&#xff0c;而是持續過程&#xff0c;幫助你識別優勢、短板&#xff0c;并為職業…

鴻蒙開發中所有自定義裝飾器的完整案例解析--涵蓋 16 個核心裝飾器的詳細用法和實戰場景

以下是鴻蒙開發中 所有自定義裝飾器的完整案例解析 和 終極總結指南&#xff0c;涵蓋 16 個核心裝飾器的詳細用法和實戰場景&#xff1a; 一、終極總結表&#xff1a;16大裝飾器全景圖 裝飾器類別V1V2核心作用典型場景Component組件定義??創建標準組件業務UI組件ComponentV2…

【C++】哈希表的實現(unordered_map和unordered_set的底層)

文章目錄 目錄 文章目錄 前言 一、unordered_set和unordered_map介紹 二、哈希表的介紹 三、哈希沖突的解決方法 1.開放定址法 2.鏈地址法 四、兩種哈希表代碼實現 總結 前言 前面我們學習了紅黑樹&#xff0c;紅黑樹就是map和set的底層&#xff0c;本篇文章帶來的是unordered…

歐拉公式的意義

歐拉公式的意義 歐拉公式&#xff08;Euler’s Formula&#xff09;是數學中最重要的公式之一&#xff0c;它將復數、指數函數和三角函數緊密聯系在一起。其基本形式為&#xff1a; eiθcos?θisin?θ e^{i\theta} \cos \theta i \sin \theta eiθcosθisinθ 當 θπ\thet…

Linux Docker 運行SQL Server

在Linux操作系統&#xff0c;已安裝docker&#xff0c;現在以docker compose方式&#xff0c;安裝一個最新版SQL Server 2022的數據庫。 # 建個目錄&#xff08;請不要照抄&#xff0c;我的數據盤在/data&#xff0c;你可以改為/opt&#xff09; mkdir /data/sqlserver# 進入目…

C++:stack_queue(2)實現底層

文章目錄一.容器適配器1. 本質&#xff1a;2. 接口&#xff1a;3. 迭代器&#xff1a;4. 功能&#xff1a;二.deque的簡單介紹1.概念與特性2.結構與底層邏輯2.1 雙端隊列&#xff08;deque&#xff09;結構&#xff1a;2.2 deque的內部結構2.3 deque的插入與刪除操作&#xff1…

Lightroom 安卓版 + Windows 版 + Mac 版全適配,編輯管理一站式,專業攝影后期教程

軟件是啥樣的? Adobe Lightroom 這軟件&#xff0c;在安卓手機、Windows 電腦和 Mac 電腦上都能用。不管是喜歡拍照的人&#xff0c;還是專門搞攝影的&#xff0c;用它都挺方便&#xff0c;能一站式搞定照片編輯、整理和分享這些事兒。 ****下載地址 分享文件&#xff1a;【Li…

office卸載不干凈?Office356卸載不干凈,office強力卸載軟件下載

微軟官方認可的卸載工具&#xff0c;支持徹底清除Office組件及注冊表殘留。需要以管理員身份運行&#xff0c;選擇“移除Office”功能并確認操作。 Office Tool Plus安裝地址獲取 點擊這里獲取&#xff1a;Office Tool Plus 1、雙擊打開軟件 image 2、選擇左右的工具箱&…

互聯網企業慢性死亡的招聘視角分析:從崗位割裂看戰略短視

內容簡介&#xff1a; 一個獵頭和HR的簡單拒絕&#xff0c;揭示了中國互聯網企業人才觀念的深層問題。通過分析崗位過度細分現象&#xff0c;本文探討了戰略短視、內斗文化和核心競爭力缺失如何導致企業慢性死亡&#xff0c;并提出了系統性的解決方案。#互聯網企業 #人才招聘 #…

OpenBMC中phosphor-dbus-interfaces深度解析:架構、原理與應用實踐

引言 在OpenBMC生態系統中&#xff0c;phosphor-dbus-interfaces作為D-Bus接口定義的核心組件&#xff0c;扮演著系統各模塊間通信"契約"的關鍵角色。本文將基于OpenBMC源碼&#xff0c;從架構設計、實現原理到實際應用三個維度&#xff0c;全面剖析這一基礎組件的技…

駕駛場景玩手機識別準確率↑32%:陌訊動態特征融合算法實戰解析

原創聲明本文為原創技術解析文章&#xff0c;核心技術參數與架構設計參考自《陌訊技術白皮書》&#xff0c;轉載請注明出處。一、行業痛點&#xff1a;駕駛場景行為識別的現實挑戰根據交通運輸部道路運輸司發布的《駕駛員不安全行為研究報告》顯示&#xff0c;駕駛過程中使用手…

Mysql——單表最多數據量多少需要分表

目錄 一、MySql單表最多數據量多少需要分表 1.1、阿里開發公約 1.2、一個三層的B+樹,它最多可以存儲多少數據量 1.3、示例 1.3.1、示例表中一行的數據占多少字節數 1.3.2、示例表中一頁里面最多可以存多少條記錄 1.3.3、按示例表計算,一個三層的B+樹,可以放多少條100字節的數…

scikit-learn/sklearn學習|嶺回歸解讀

【1】引言 前序學習進程中&#xff0c;對用scikit-learn表達線性回歸進行了初步解讀。 線性回歸能夠將因變量yyy表達成由自變量xxx、線性系數矩陣www和截距bbb組成的線性函數式&#xff1a; y∑i1nwi?xibwTxby\sum_{i1}^{n}w_{i}\cdot x_{i}bw^T{x}byi1∑n?wi??xi?bwTxb實…

基于Django的圖書館管理系統的設計與實現

基于Django的圖書館管理系統的設計與實現、

ComfyUI版本更新---解決ComfyUI的節點不兼容問題

前言&#xff1a; 新版本的COMFYUI與節點容易出現不兼容的問題,會導致整個系統崩掉。 目錄 一、前期準備工作&#xff1a;虛擬環境配置 為什么需要虛擬環境&#xff1f; 具體操作步驟 二、常見問題解決方案 1、工作流輸入輸出圖像不顯示問題 2、工作流不能拖動&#xff0…

生產管理ERP系統|物聯及生產管理ERP系統|基于SprinBoot+vue的制造裝備物聯及生產管理ERP系統設計與實現(源碼+數據庫+文檔)

生產管理ERP系統 目錄 基于SprinBootvue的制造裝備物聯及生產管理ERP系統設計與實現 一、前言 二、系統設計 三、系統功能設計 四、數據庫設計 五、核心代碼 六、論文參考 七、最新計算機畢設選題推薦 八、源碼獲取&#xff1a; 博主介紹&#xff1a;??大廠碼農|畢…

Numpy科學計算與數據分析:Numpy數組創建與應用入門

Numpy數組創建實戰 學習目標 通過本課程的學習&#xff0c;學員將掌握使用Numpy庫創建不同類型的數組的方法&#xff0c;包括一維數組、多維數組、全零數組、全一陣列、空數組等。本課程將通過理論講解與實踐操作相結合的方式&#xff0c;幫助學員深入理解Numpy數組的創建過程…