特征交叉系列:DCN-Mix 混合低秩交叉網絡理論和實踐

DCN-Mix和DCN-V2的關系

DCN-Mix(a mixture of low-rank DCN)是基于DCN-V2的改進版,它提出使用矩陣分解降低DCN-V2的時間空間復雜度,又引入多次矩陣分解來達到類似混合專家網絡MOE的效果從而提升交叉層的表征能力,若讀者對DCN-V2不甚了解可以參考上一節[特征交叉系列:Deep&Cross(DCN-V2)理論和實踐]做知識鋪墊。


DCN-V2權重矩陣的低秩性和矩陣分解

在DCN-V2中核心的參數是交叉層的權重矩陣W,該參數是M×M的方陣,其中M是所有輸入embedding拼接后的向量總長度,每一層交叉之間W不共享,W矩陣需要學習的參數數量能占到所有參數量的70%以上,而進一步作者發現隨著網絡的訓練,W矩陣的奇異值出現快速下降呈現出低秩特性,代表該矩陣存在信息冗余,因此可以考慮通過矩陣分解來進行特征提取和信息壓縮。
在PyTorch中可以通過torch.linalg.svd計算出矩陣的奇異值,例如

>>> a = torch.tensor([[1, 1], [1, 1.1]])
>>> u, s, v = torch.linalg.svd(a)
>>> print(s)
tensor([2.0512, 0.0488])

其中s是對角陣,斜對角線上的值就是奇異值,a矩陣的第二行幾乎可以從第一行線性變換而來,因此s各位置上的奇異值差距極大,第一個奇異值基本攜帶了全部的矩陣信息。
在DCN-V2的訓練代碼里面,打印出第一個交叉層初始化的W矩陣和訓練早停后W矩陣的奇異值,奇異值的長度和輸入長度M一致,代碼如下

# 初始化時
model = DCN(field_num=10, feat_dim=72, emb_num=16, order_num=2, dropout=0.1, method='parallel').to(DEVICE)
init_s = torch.linalg.svd(model.cross_net.cell_list[0].w)[1].cpu().detach().numpy().tolist()
# 早停時
if early_stop_flag:train_s = torch.linalg.svd(model.cross_net.cell_list[0].w)[1].cpu().detach().numpy().tolist()break

奇異值列表中元素大小逐個遞減,對init_s和train_s分別做最大最小歸一化,要求第一個奇異值歸因化為1,

init_s = [(x - min(init_s)) / (max(init_s) - min(init_s)) for x in init_s]
train_s = [(x - min(train_s)) / (max(train_s) - min(train_s)) for x in train_s]

然后做圖看一下初始矩陣的奇異值和收斂后的奇異值的各個位置元素的大小情況

import matplotlib.pylab as plt
plt.scatter(list(range(len(init_s))), init_s, label='init', s=3)
plt.scatter(list(range(len(train_s))), train_s, label='learned', s=3)
plt.legend(loc=0)
plt.show()

init和learned奇異值下降對比

相比于初始化階段(藍線),模型收斂后(橙線)的W矩陣奇異值急速下降,說明頭部的奇異值已經攜帶了大部分矩陣信息,W矩陣可以考慮做壓縮。
在論文中作者將W分解為U,V兩個矩陣的相乘,其中U,V都是維度為[M, R]的二維矩陣,M和輸入等長,R<=M/2,公式如下

矩陣分解

此時一個交叉權重的參數數量由M平方降低為2×MR。


DCN-Mix的混合專家網絡

DCN-Mix使用矩陣UV分解來逼近原始的交叉矩陣W,受到MOE(Mixture of Experts)混合專家網絡的啟發,作者對W進行多次矩陣分解,單個矩陣分解相當于單個專家網絡(Expert)在子空間學習特征交叉,再引入門控機制(Gate)對多個子空間的交叉結果進行自適應地融合,從而提高交叉層的表達能力,DCN結合MOE的示意圖如下

MOE示意圖

其中該層的輸入Input x分別進入n個Expert專家網絡,專家網絡中包含UV矩陣相乘,同時Input x輸入給一個門控網絡Gate+Softmax輸出n個權重標量,最后Input x會和加權求和的專家網絡結果做殘差連接。
將矩陣分解和MOE結合起來形成最終的交叉層公式如下

結合MOE的矩陣分解交叉層

相比于DCN-V2,等號左側的哈達瑪積部分改為了一個Σ加權求和的UV矩陣逼近,而右側的殘差連接放到最后和MOE的結果一起做殘差連接。


DCN-Mix在PyTorch下的實踐

本次實踐的數據集和上一篇特征交叉系列:完全理解FM因子分解機原理和代碼實戰一致,采用用戶的購買記錄流水作為訓練數據,用戶側特征是年齡,性別,會員年限等離散特征,商品側特征采用商品的二級類目,產地,品牌三個離散特征,隨機構造負樣本,一共有10個特征域,全部是離散特征,對于枚舉值過多的特征采用hash分箱,得到一共72個特征。
DCN-Mix的PyTorch代碼實現如下

class Embedding(nn.Module):def __init__(self, feat_num, emb_num):super(Embedding, self).__init__()self.embedding = nn.Embedding(feat_num, emb_num)nn.init.xavier_normal_(self.embedding.weight.data)def forward(self, x):# [None, filed_num] => [None, filed_num, emb_num] => [None, filed_num * emb_num]return self.embedding(x).flatten(1)class DNN(nn.Module):def __init__(self, input_num, hidden_nums, dropout=0.1):super(DNN, self).__init__()layers = []input_num = input_numfor hidden_num in hidden_nums:layers.append(nn.Linear(input_num, hidden_num))layers.append(nn.BatchNorm1d(hidden_num))layers.append(nn.ReLU())layers.append(nn.Dropout(p=dropout))input_num = hidden_numself.mlp = nn.Sequential(*layers)for layer in self.mlp:if isinstance(layer, nn.Linear):nn.init.xavier_normal_(layer.weight.data)def forward(self, x):return self.mlp(x)class CrossCell(nn.Module):"""一個交叉單元"""def __init__(self, input_num, r):super(CrossCell, self).__init__()self.v = nn.Parameter(torch.randn(input_num, r))self.u = nn.Parameter(torch.randn(input_num, r))self.b = nn.Parameter(torch.randn(input_num, 1))nn.init.xavier_normal_(self.v.data)nn.init.xavier_normal_(self.u.data)def forward(self, x0, xi):# [None, emb_num] => [None, emb_num, 1]xi = xi.unsqueeze(2)x0 = x0.unsqueeze(2)# [r, input_num] * [None, emb_num, 1] => [None, r, 1]# [input_num, r] * [None, r, 1] => [None, emb_num, 1]xii = (torch.matmul(self.u, torch.matmul(self.v.t(), xi)) + self.b) * x0return xii  # [None, emb_num, 1]class MOECrossCell(nn.Module):def __init__(self, input_num, r, k):super(MOECrossCell, self).__init__()self.k = kself.cross_cell = nn.ModuleList([CrossCell(input_num, r) for i in range(self.k)])self.gate = nn.Linear(input_num, self.k)nn.init.xavier_normal_(self.gate.weight.data)def forward(self, x0, xi):# [None, emb_num] => [None, emb_num, 1]xii = xi.unsqueeze(2)export_out = []for i in range(self.k):cross_out = self.cross_cell[i](x0, xi)# [[None, emb_num, 1], [None, emb_num, 1], [None, emb_num, 1], [None, emb_num, 1]]export_out.append(cross_out)export_out = torch.concat(export_out, dim=2)  # [None, emb_num, 4]# [None, k] => [None, 1, k]gate_out = self.gate(xi).softmax(dim=1).unsqueeze(dim=1)# [None, emb_num, 4] * [None, 1, k] = [None, emb_num, k] => [None, emb_num, 1]out = torch.sum(export_out * gate_out, dim=2, keepdim=True)out = out + xii  # [None, emb_num, 1]return out.squeeze(2)class CrossNet(nn.Module):def __init__(self, order_num, input_num, r, k):super(CrossNet, self).__init__()self.order = order_numself.cell_list = nn.ModuleList([MOECrossCell(input_num, r, k) for i in range(order_num)])def forward(self, x0):xi = x0for i in range(self.order):xi = self.cell_list[i](x0=x0, xi=xi)return xiclass DCN(nn.Module):def __init__(self, field_num, feat_dim, emb_num, order_num, r=16, k=4, dropout=0.1, method='parallel',hidden_nums=(128, 64, 32)):super(DCN, self).__init__()input_num = field_num * emb_numself.embedding = Embedding(feat_num=feat_dim, emb_num=emb_num)self.dnn = DNN(input_num=input_num, hidden_nums=hidden_nums, dropout=dropout)self.cross_net = CrossNet(order_num=order_num, input_num=input_num, r=r, k=k)if method not in ('parallel', 'stacked'):raise ValueError('unknown combine type: ' + method)self.method = methodlinear_dim = hidden_nums[-1]if self.method == 'parallel':linear_dim = linear_dim + input_numself.linear = nn.Linear(linear_dim, 1)nn.init.xavier_normal_(self.linear.weight.data)def forward(self, x):emb = self.embedding(x)  # [None, field * emb_num]cross_out = self.cross_net(emb)  # [None, input_num]if self.method == 'parallel':dnn_out = self.dnn(emb)  # [None, input_num]out = torch.concat([cross_out, dnn_out], dim=1)else:out = self.dnn(cross_out)  # [None, input_num]out = self.linear(out)return torch.sigmoid(out).squeeze(dim=1)

在CrossCell模塊中完成了一個給予UV逼近的交叉操作,在MOECrossCell模塊中完成了MOE和殘差連接,其中export_out和gate_out分別為專家網絡的輸出和門控機制的權重。
本例全部是離散分箱變量,所有有值的特征都是1,因此只要輸入有值位置的索引即可,一條輸入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

x的長度為10代表10個特征域,每個域的值是特征的全局位置索引,從0到71,一共72個特征。


DCN-Mix調參和效果對比

對階數(order_num)和融合策略(method)這兩個參數進行調參,分別嘗試1~4層交叉層,stacked和parallel兩種策略,采用10次驗證集AUC不上升作為早停條件,驗證集的平均AUC如下

DCN調參AUC并行parallel串行stacked
1層交叉(2階)0.63450.6321
2層交叉(3階)0.63280.6323
3層交叉(4階)0.63310.6333
4層交叉(5階)0.63400.6331

結論依舊是parallel效果好于stacked,其中一層交叉的并行parallel達到驗證集最優AUC為0.6345。
再對比一下之前文章中實踐的FM,FFM,PNN,DCN-V2等一系列算法,驗證集AUC和參數規模如下

算法AUC參數量
FM0.6274361
FFM0.63172953
PNN*0.634229953
DeepFM0.632212746
NFM0.632910186
DCN-parallel-30.6348110017
DCN-stacked-30.6344109857
DCN-Mix-parallel-10.634554501
DCN-Mix-stacked-30.633397869

使用矩陣分解逼近策略的DCN-Mix略低于原生的DCN-V2,但是還是超越一眾FM系列的算法,其中以同樣是三層交叉的stacked DCN為例,DCN-Mix的參數量相比于DCN-V2有所降低,也印證了論文中提到的“在模型效果和部署延遲之間找到一個平衡”。

最后的最后

感謝你們的閱讀和喜歡,我收藏了很多技術干貨,可以共享給喜歡我文章的朋友們,如果你肯花時間沉下心去學習,它們一定能幫到你。

因為這個行業不同于其他行業,知識體系實在是過于龐大,知識更新也非常快。作為一個普通人,無法全部學完,所以我們在提升技術的時候,首先需要明確一個目標,然后制定好完整的計劃,同時找到好的學習方法,這樣才能更快的提升自己。

這份完整版的大模型 AI 學習資料已經上傳CSDN,朋友們如果需要可以微信掃描下方CSDN官方認證二維碼免費領取【保證100%免費

一、全套AGI大模型學習路線

AI大模型時代的學習之旅:從基礎到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型報告合集

這套包含640份報告的合集,涵蓋了AI大模型的理論研究、技術實現、行業應用等多個方面。無論您是科研人員、工程師,還是對AI大模型感興趣的愛好者,這套報告合集都將為您提供寶貴的信息和啟示。

img

三、AI大模型經典PDF籍

隨著人工智能技術的飛速發展,AI大模型已經成為了當今科技領域的一大熱點。這些大型預訓練模型,如GPT-3、BERT、XLNet等,以其強大的語言理解和生成能力,正在改變我們對人工智能的認識。 那以下這些PDF籍就是非常不錯的學習資源。

img

四、AI大模型商業化落地方案

img

五、面試資料

我們學習AI大模型必然是想找到高薪的工作,下面這些面試題都是總結當前最新、最熱、最高頻的面試題,并且每道題都有詳細的答案,面試前刷完這套面試題資料,小小offer,不在話下。
在這里插入圖片描述

這份完整版的大模型 AI 學習資料已經上傳CSDN,朋友們如果需要可以微信掃描下方CSDN官方認證二維碼免費領取【保證100%免費

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/23846.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/23846.shtml
英文地址,請注明出處:http://en.pswp.cn/web/23846.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux shell腳本啟動springboot服務

1.腳本代碼 xx.sh&#xff0c;自己隨意命名 #!/bin/bash# 設置變量 JAR_NAME"xssq-1.0.0.jar" JAR_PATH"./$JAR_NAME" PID0#檢查程序是否在運行 is_exist(){PIDps -ef|grep $JAR_NAME|grep -v grep|awk {print $2} #如果不存在返回1&#xff0c;存在返回0…

評價GPT-4的方案

評價GPT-4的方案 引言: 隨著人工智能技術的不斷發展,自然語言處理領域取得了顯著的突破。其中,GPT-4作為最新的大型語言模型之一,備受關注。本方案旨在對GPT-4進行全面評價,包括其技術特點、性能表現、應用場景以及潛在的影響等方面。 一、技術特點 1. 模型規模和參數數…

微信小程序使用自定義tabbar被組件遮擋調試層級沒有用

在我自定義使用tabbar的時候&#xff0c;發現使用vant weapp環形進度條的時候把tabbar給遮擋了&#xff0c;查看了文章說沒什么好的解決辦法&#xff0c;但是也有&#xff0c;鏈接在此 我是直接修改的自定義組件的標簽view標簽和image標簽都使用cover- image和cover-view代替就…

部署kubesphere報錯

安裝kubesphere報錯命名空間terminted [rootk8smaster ~]# kubectl apply -f kubesphere-installer.yaml Warning: apiextensions.k8s.io/v1beta1 CustomResourceDefinition is deprecated in v1.16, unavailable in v1.22; use apiextensions.k8s.io/v1 CustomResourceDefini…

618科技好物清單:物超所值的產品推薦,總有一款適合你!

隨著科技的不斷發展&#xff0c;我們生活中涌現出了越來越多的科技創新產品。這些產品不僅讓我們的生活變得更加便捷&#xff0c;還提升了我們的生活品質。而在即將到來的618購物節&#xff0c;正是我們購買這些物超所值科技好物的絕佳時機。 本文將為您推薦一些在618期間值得關…

軟光敏的程序實現

軟光敏的程序實現通常涉及到使用攝像頭或其他圖像捕捉設備的內部sensor來感應環境光線&#xff0c;并結合軟件算法來控制補光燈或其他相關設備的開關。以下是一個簡化的軟光敏程序實現的示例流程&#xff0c;使用偽代碼來描述&#xff1a; pseudo 初始化攝像頭 while 攝像頭開…

每天一個數據分析題(三百五十五)-業務分析報告

業務分析報告的主要作用是將業務分析報表中發現的業務問題進行匯總說明&#xff0c;并進一步提出解決問題的建議&#xff0c;以幫助閱讀者做出正確的決策判斷。業務分析報告撰寫的注意事項中正確的是&#xff1f; A. 條理清晰、結構完整 B. 論點明確 C. 圖、表、文字相結合 …

英偉達的數字孿生地球是什么

1 英偉達的數字孿生地球 Earth-2是一個全棧式開放平臺&#xff0c;包含&#xff1a;ICON 和 IFS 等數值模型的物理模擬&#xff1b;多種機器學習模型&#xff0c;例如 FourCastNet、GraphCast 和通過 NVIDIA Modulus 實現的深度學習天氣預測 (DLWP)&#xff1b;以及通過 NVIDI…

Go理論-面試題

面向對象&#xff1f; 面向對象是一種方法論。一種非常實用的系統化軟件開發方法。 三大特點&#xff1a;封裝、繼承、多態 Go和Java的區別 Go不允許重載&#xff0c;Java允許Java允許多態&#xff0c;Go沒有&#xff08;但可以通過接口實現&#xff09;Go語言的繼承通過匿…

手撕設計模式——克隆對象之原型模式

1.業務需求 ? 大家好&#xff0c;我是菠菜啊&#xff0c;前倆天有點忙&#xff0c;今天繼續更新了。今天給大家介紹克隆對象——原型模式。老規矩&#xff0c;在介紹這期之前&#xff0c;我們先來看看這樣的需求&#xff1a;《西游記》中每次孫悟空拔出一撮猴毛吹一下&#x…

pytorch-nn.Module

目錄 1. nn.Module2. nn.Sequential容器3. 網絡參數parameters4. Modules內部管理5. checkpoint6. train/test狀態切換6. 實現自己的網絡層6.1 實現打平操作6.2 實現自己的線性層 7. 代碼 1. nn.Module 是所有nn.類的父類&#xff0c;其中包括nn.Linear nn.BatchNorm2d nn.Con…

每日一練 - OSPF協議驗證機制

01 真題題目 OSPF 只有在 Hello 報文中有驗證信息,OSPF 支持 MD5 密文驗證. A.正確 B.錯誤 02 真題答案 B 03 答案解析 這個陳述是不完全正確的。首先&#xff0c;OSPF確實使用Hello報文來攜帶認證信息&#xff0c;但這不意味著只有Hello報文包含驗證信息。 OSPF的認證機制可…

政府績效考核第三方評估的含義

政府績效考核第三方評估是指由獨立于政府的外部機構&#xff08;如專業評估公司、研究機構或非政府組織&#xff09;對政府部門或其下屬單位的績效進行客觀、公正、系統的評估。其主要目的是通過引入獨立的第三方評估機構&#xff0c;對政府績效進行科學、全面的考核&#xff0…

【AIGC調研系列】Qwen2與llama3對比的優勢

Qwen2與Llama3的對比中&#xff0c;Qwen2展現出了多方面的優勢。首先&#xff0c;從性能角度來看&#xff0c;Qwen2在多個基準測試中表現出色&#xff0c;尤其是在代碼和數學能力上有顯著提升[1][9]。此外&#xff0c;Qwen2還在自然語言理解、知識、多語言等多項能力上均顯著超…

肺結節14問,查出肺結節怎么辦?哪些能用中醫調治消散?快來了解一下吧

近些年&#xff0c;隨著大眾防癌意識的加強&#xff0c;和胸部低劑量CT的普及&#xff0c;肺結節的檢出率也逐年升高&#xff0c;不少患者CT報告上&#xff0c;寫著“肺小結”“肺部磨玻璃結節”的字樣&#xff0c;當你看到這幾個字時&#xff0c;會不會瞬間緊張起來&#xff1…

編程規范-代碼檢測-格式化-規范化提交

適用于vue項目的編程規范 – 在多人開發時統一編程規范至關重要 1、代碼檢測 --Eslint Eslint&#xff1a;一個插件化的 javascript 代碼檢測工具 在 .eslintrc.js 文件中進行配置 // ESLint 配置文件遵循 commonJS 的導出規則&#xff0c;所導出的對象就是 ESLint 的配置對…

簡化電動汽車充電器和光伏逆變器的高壓電流檢測

在任何電氣系統中&#xff0c;電流都是一個至關重要的參數。電動汽車 (EV) 充電系統和太陽能系統都需要檢測電流的大小&#xff0c;以便控制和監測功率轉換、充電和放電。電流傳感器通過監測分流電阻器上的壓降或導體中電流產生的磁場來測量電流。 金屬氧化物半導體場效應晶體…

DBeaver連接MySQL提示“Public Key Retrieval is not allowed“問題的解決方式

問題描述 客戶端root用戶連接數據庫出現出現Public Key Retrieval is not allowed 原因分析&#xff1a; 加上allowPublicKeyRetrievalfalse&#xff1a; 解決方案&#xff1a; allowPublicKeyRetrievaltrue&#xff1a;

Java Web學習筆記14——BOM對象

BOM&#xff1a; 概念&#xff1a;瀏覽器對象模型&#xff08;Browser Object Model&#xff09;&#xff0c;允許JavaScript與瀏覽器對話&#xff0c;JavaScript將瀏覽器的各個組成部分封裝為對象。 組成&#xff1a; Window&#xff1a;瀏覽器窗口對象 介紹&#xff1a;瀏覽…

opencv銳化卷積核的定義和應用(圖像銳化)。

定義銳化卷積核 卷積核&#xff08;Kernel&#xff09;是一個小矩陣&#xff0c;它用于在圖像處理操作中&#xff0c;比如模糊、銳化、邊緣檢測等。卷積核通過卷積操作應用于圖像像素&#xff0c;產生新的圖像。 在銳化操作中&#xff0c;我們通常使用一個 3x3 的卷積核。以下…