【深度學習】多目標融合算法(五):定制門控網絡CGC(Customized Gate Control)

目錄

一、引言

二、CGC(Customized Gate Control,定制門控網絡)

2.1 技術原理

2.2?技術優缺點

2.3 業務代碼實踐

2.3.1 業務場景與建模

2.3.2 模型代碼實現

2.3.3 模型訓練與推理測試

2.3.4 打印模型結構?

三、總結


一、引言

上一篇我們講了MMoE多任務網絡,通過對每一個任務塔建立Gate門控,對專家網絡進行加權平均,Gate門控起到了對多個共享專家重要度篩選的作用。在每輪反向傳播時,每個任務tower分別更新對應Gate的參數,以及共享專家的參數。模型主要起到了多目標任務平衡的作用。

今天我們重點將CGC(Customized Gate Control)定制門控網絡,核心思想是在MMoE基礎上,為每一個任務tower定制獨享專家,實用任務獨享專家與共享專家共同決定任務Tower的輸入,相比于MMoE僅用Gate門控表征任務Tower的方法,CGC引入獨享專家,對任務表征更加全面,又通過共享專家保證關聯性。

二、CGC(Customized Gate Control,定制門控網絡)

2.1 技術原理

CGC(Customized Gate Control)全稱為定制門控網絡,主要由多個任務塔、對應多組獨享專家網絡,對應多個門控網絡以及一組共享專家網絡,專家網絡組內可以包含多個專家MLP。核心原理:樣本input分別輸入共享專家MLP、獨立專家MLP、獨立專家對應門控網絡,門控網絡輸出為經過softmax的權重分布,維度對應共享專家數num_shared_experts和獨立專家數num_task_experts的和,通過對獨立專家輸出和共享專家輸出采用Gate門控加權平均后, 輸入到對應的任務Tower。每個任務Tower輸入自己對應的獨享專家、共享專家、門控加權平均的輸入。反向傳播時,每個任務更新自己獨享專家、獨享門控以及共享專家的參數。

  • 共享專家網絡:樣本數據分別輸入num_shared_experts個專家網絡進行推理,每個共享專家網絡實際上是一個多層感知機(MLP),輸入維度為x,輸出維度為output_experts_dim。
  • 獨享專家網絡:樣本數據分別輸入num_task_experts個專家網絡進行推理,每個共享專家網絡實際上是一個多層感知機(MLP),輸入維度為x,輸出維度為output_experts_dim。
  • 門控網絡:樣本數據輸出各自任務對應的門控網絡,每個門控網絡可以是一個多層感知機,也可以是一個雙層的交叉,主要是為了輸出專家網絡的加權平均權重。
  • 任務網絡:對于每一個Task,將各自對應num_shared_experts個共享專家和num_task_experts個獨立專家,基于對應gate門控網絡的softmax加權平均,作為各自Task的輸入,所有Task的輸入統一維度均為output_experts_dim。

2.2?技術優缺點

相較于MMoE網絡,CGC為每一個任務tower定制獨享專家,實用任務獨享專家與共享專家共同決定任務Tower的輸入,相比于MMoE僅用Gate門控表征任務Tower的方法,CGC引入獨享專家,對任務表征更加全面,又通過共享專家保證關聯性。

優點:

  • 切斷任務tower與其他任務獨享專家的聯系,使得獨享專家能夠更專注的學習本任務內的知識與信息。比如切斷互動塔與點擊專家的聯系,只和互動專家同時迭代,讓互動目標的學習更加純粹。
  • 獨享專家只受對應任務梯度的影響,不受其他任務梯度的影響,而共享專家可以被多個任務梯度同時更新。
  • 本質上,CGC就是在MMoE上新增了獨享專家,MMoE僅有共享專家。

缺點:?

  • 相較于PLE、SNR等,沒有學習到專家與專家之間的相互關系,層級堆疊不夠。
  • 相較于DeepSeekMoE的路由方法,CGC還是過于定制化與單一話,專家組合不足。

2.3 業務代碼實踐

2.3.1 業務場景與建模

我們還是以小紅書推薦場景為例,針對一個視頻,用戶可以點紅心(互動),也可以點擊視頻進行播放(點擊),針對互動和點擊兩個目標進行多目標建模

我們構建一個100維特征輸入,1組共享專家網絡(含2個共享專家),2組獨享專家網絡(各含2個獨享專家),2個門控,2個任務塔的CGC網絡,用于建模多目標學習問題,模型架構圖如下:

??????????????

如架構圖所示,其中有幾個注意的點:

  • num_shared_experts+num_task_expertsGate的維度等于共享專家的維度加上任務獨享專家的維度。
  • output_experts_dim:共享專家、獨享專家網絡的輸出維度和task網絡的輸入維度相同,task網絡承接的是專家網絡各維度的加權平均值,experts網絡與task網絡是直接對應關系。
  • Softmax:Gate門控網絡對共享專家和獨享專家的偏好權重采用Softmax歸一化,保證專家網絡加權平均后值域相同

2.3.2 模型代碼實現

基于pytorch,實現上述CGC網絡架構,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass CGCModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts):super(CGCModel, self).__init__()# 初始化函數外使用初始化變量需要賦值,否則默認使用全局變量# 初始化函數內使用初始化變量不需要賦值 self.num_shared_experts = num_shared_expertsself.num_task_experts = num_task_expertsself.output_experts_dim = output_experts_dim# 初始化共享專家self.shared_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_shared_experts)])# 初始化任務1專家self.task1_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化任務2專家self.task2_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化門控網絡任務1self.gating1_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 初始化門控網絡任務2self.gating2_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 定義任務1的輸出層self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定義任務2的輸出層self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) def forward(self, x):gates1 = self.gating1_network_2(x)gates2 = self.gating2_network_2(x)#定義專家網絡輸出作為任務塔輸入batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)for i in range(self.num_shared_experts):task1_inputs += self.shared_experts_2[i](x) * gates1[:, i].unsqueeze(1) + self.task1_experts_2[i](x) * gates1[:, i+self.num_shared_experts].unsqueeze(1)task2_inputs += self.shared_experts_2[i](x) * gates2[:, i].unsqueeze(1) + self.task2_experts_2[i](x) * gates2[:, i+self.num_shared_experts].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 實例化模型對象
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 1
output_task2_dim = 1
num_shared_experts = 2
num_task_experts = 2# 構造虛擬樣本數據
torch.manual_seed(42)  # 設置隨機種子以保證結果可重復
input_dim = 100
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假設任務1的輸出維度為1
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假設任務2的輸出維度為1# 創建數據加載器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = CGCModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts)# 定義損失函數和優化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練循環
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向傳播: 獲取預測值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 計算每個任務的損失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向傳播和優化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型預測
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 構造一個測試樣本pred_task1, pred_task2 = model(test_input)print(f'互動目標預測結果: {pred_task1}')print(f'點擊目標預測結果: {pred_task2}')

相比于上一篇MMoE中的代碼,CGC復雜了很多,新增了2組獨享專家,且在門控與獨享、共享專家加權平均計算的時候需要進行處理,很容易出問題。

2.3.3 模型訓練與推理測試

運行上述代碼,模型啟動訓練,Loss逐漸收斂,測試結果如下:

2.3.4 打印模型結構????????

三、總結

本文詳細介紹了CGC多任務模型的算法原理、算法優勢,他是下一篇PLE多層多任務模型的基礎,并以小紅書業務場景為例,構建CGC網絡結構并使用pytorch代碼實現對應的網絡結構、訓練流程。相比于MMoE,CGC新增獨享專家網絡,通過gate門控的串聯,切斷任務Tower與其他任務獨享專家的聯系,使得獨享專家能夠更專注的學習本任務內的知識與信息。

如果您還有時間,歡迎閱讀本專欄的其他文章:

【深度學習】多目標融合算法(一):樣本Loss加權(Sample Loss Reweight)

【深度學習】多目標融合算法(二):底部共享多任務模型(Shared-Bottom Multi-task Model)????????

【深度學習】多目標融合算法(三):混合專家網絡MOE(Mixture-of-Experts)?

?【深度學習】多目標融合算法(四):多門混合專家網絡MMOE(Multi-gate Mixture-of-Experts)???????

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/73852.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/73852.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/73852.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

在線pdf處理網站合集

1、PDF24 Tools:https://tools.pdf24.org/zh/ 2、PDF派:https://www.pdfpai.com/ 3、ALL TO ALL:https://www.alltoall.net/ 4、CleverPDF:https://www.cleverpdf.com/cn 5、Doc Small:https://docsmall.com/ 6、Aconv…

網絡編程-實現客戶端通信

#include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/socket.h> #include <netinet/in.h> #include <sys/select.h>#define MAX_CLIENTS 2 // 最大客戶端連接數 #define BUFFER_SI…

力扣100二刷——圖論、回溯

第二次刷題不在idea寫代碼&#xff0c;而是直接在leetcode網站上寫&#xff0c;“逼”自己掌握常用的函數。 標志掌握程度解釋辦法?Fully 完全掌握看到題目就有思路&#xff0c;編程也很流利??Basically 基本掌握需要稍作思考&#xff0c;或者看到提示方法后能解答???Sl…

【大模型實戰篇】多模態推理模型Skywork-R1V

1. 背景介紹 近期昆侖萬維開源的Skywork R1V模型&#xff0c;是基于InternViT-6B-448px-V2_5以及deepseek-ai/DeepSeek-R1-Distill-Qwen-32B 通過強化學習得到。當然語言模型也可以切換成QwQ-32B。因此該模型最終的參數量大小為38B。 該模型具備多模態推理能力&#xf…

識別并脫敏上傳到deepseek/chatgpt的文本文件中的護照信息

本文將介紹一種簡單高效的方法解決用戶在上傳文件到DeepSeek、ChatGPT&#xff0c;文心一言&#xff0c;AI等大語言模型平臺過程中的護照號識別和脫敏問題。 DeepSeek、ChatGPT&#xff0c;Qwen&#xff0c;Claude等AI平臺工具快速的被接受和使用&#xff0c;用戶每天上傳的文…

數據驅動進化:AI Agent如何重構手機交互范式?

如果說AIGC拉開了內容生成的序幕&#xff0c;那么AI Agent則標志著AI從“工具”向“助手”的跨越式進化。它不再是簡單的問答機器&#xff0c;而是一個能夠感知環境、規劃任務并自主執行的智能體&#xff0c;更像是虛擬世界中的“全能員工”。 正如行業所熱議的&#xff1a;“大…

【AI News | 20250319】每日AI進展

AI Repos 1、XianyuAutoAgent 實現了 24 小時自動化值守的 AI 智能客服系統&#xff0c;支持多專家協同決策、智能議價和上下文感知對話&#xff0c;讓我們店鋪管理更輕松。主要功能&#xff1a; 智能對話引擎&#xff0c;支持上下文感知和專家路由階梯降價策略&#xff0c;自…

nginx中間件部署

中間件部署流程 ~高級權限賬戶安裝必要的插件 -> 普通權限賬戶安裝所需要的服務 -> 高級權限賬戶開啟并設置開機自啟所安裝的服務 -> iptables放行所需要的服務 普通權限賬戶安裝NGINX中間件 1、擁有高級權限的賬戶安裝必要的插件 sudo yum install -y gcc-c make…

C語言自定義類型【結構體】詳解,【結構體內存怎么計算】 詳解 【熱門考點】:結構體內存對齊

引言 詳細講解什么是結構體&#xff0c;結構體的運用&#xff0c; 詳細介紹了結構體在內存中占幾個字節的計算。 【熱門考點】&#xff1a;結構體內存對齊 介紹了&#xff1a;結構體傳參 一、什么是結構體&#xff1f; 結構是?些值的集合&#xff0c;這些值稱為成員變量。結構…

前端應用更新通知機制全解析:構建智能化版本更新策略

引言&#xff1a;數字時代的更新挑戰 在持續交付的現代軟件開發模式下&#xff0c;前端應用平均每周產生2-3次版本迭代。但據Google研究報告顯示&#xff0c;38%的用戶在遇到功能異常時仍在使用過期版本的應用。如何優雅地實現版本更新通知&#xff0c;已成為提升用戶體驗的關…

Apache DolphinScheduler:一個可視化大數據工作流調度平臺

Apache DolphinScheduler&#xff08;海豚調度&#xff09;是一個分布式易擴展的可視化工作流任務調度開源系統&#xff0c;適用于企業級場景&#xff0c;提供了一個可視化操作任務、工作流和全生命周期數據處理過程的解決方案。 Apache DolphinScheduler 旨在解決復雜的大數據…

[藍橋杯 2023 省 B] 飛機降落

[藍橋杯 2023 省 B] 飛機降落 題目描述 N N N 架飛機準備降落到某個只有一條跑道的機場。其中第 i i i 架飛機在 T i T_{i} Ti? 時刻到達機場上空&#xff0c;到達時它的剩余油料還可以繼續盤旋 D i D_{i} Di? 個單位時間&#xff0c;即它最早可以于 T i T_{i} Ti? 時刻…

使用Trae 生成的React版的貪吃蛇

使用Trae 生成的React版的貪吃蛇 首先你想用這個貪吃蛇&#xff0c;你需要先安裝Trae Trae 官方地址 他有兩種模式 chat builder 我使用的是builder模式,雖然是Alpha.還是可以用。 接下來就是按著需求傻瓜式的操作生成代碼 他生成的代碼不完全正確&#xff0c;比如沒有引入…

AI大模型:(一)1.大模型的發展與局限

說起AI大模型不得不說下機器學習的發展史&#xff0c;機器學習包括傳統機器學習、深度學習&#xff0c;而大模型&#xff08;Large Models&#xff09;屬于機器學習中的深度學習&#xff08;Deep Learning&#xff09;領域&#xff0c;具體來說&#xff0c;它們通常基于神經網絡…

rust學習筆記17-異常處理

今天聊聊rust中異常錯誤處理 1. 基礎類型&#xff1a;Result 和 Option&#xff0c;之前判斷空指針就用到過 Option<T> 用途&#xff1a;表示值可能存在&#xff08;Some(T)&#xff09;或不存在&#xff08;None&#xff09;&#xff0c;適用于無需錯誤信息的場景。 f…

Python:單繼承方法的重寫

繼承&#xff1a;讓類和類之間轉變為父子關系&#xff0c;子類默認繼承父類的屬性和方法 單繼承&#xff1a; class Person:def eat(self):print("eat")def sing(self):print("sing") class Girl(Person):pass#占位符&#xff0c;代碼里面類下面不寫任何東…

記錄一下aes加密與解密

該文章只做拓展后續會更新&#xff1b;如有出錯請指出 首先需要先引入相關依賴 crypto-js 然后直接開始存儲 export function aesEncrypt(message: string, key: string) {return aes.encrypt(message, key).toString(); } 之后是解密方式 function decrypt(content: any, key…

[免費]直接整篇翻譯pdf工具-支持多種語言

<閑來沒事寫篇博客填補中文知識庫漏洞> 如題&#xff0c;[免費][本地]工具基于開源倉庫&#xff1a; 工具 是python&#xff01;太好了&#xff0c;所以各個平臺都可以&#xff0c;我這里基于windows. 1. 先把github代碼下載下來&#xff1a; git clone https://githu…

UI設計中的用戶反饋機制:提升交互體驗的關鍵

hello寶子們...我們是艾斯視覺擅長ui設計和前端數字孿生、大數據、三維建模、三維動畫10年經驗!希望我的分享能幫助到您!如需幫助可以評論關注私信我們一起探討!致敬感謝感恩! 在數字化產品泛濫的今天&#xff0c;用戶與界面的每一次交互都在無形中塑造著他們對產品的認知。一個…

Hessian 矩陣是什么

Hessian 矩陣是什么 目錄 Hessian 矩陣是什么Hessian 矩陣的性質及舉例說明**1. 對稱性****2. 正定性決定極值類型****特征值為 2(正),因此原點 ( 0 , 0 ) (0, 0) (0,0) 是極小值點。****3. 牛頓法中的應用****4. 特征值與曲率方向****5. 機器學習中的實際意義**一、定義與…