推薦系統(十七):在TensorFlow中用戶特征和商品特征是如何Embedding的?

在前面幾篇關于推薦模型的文章中,筆者均給出了示例代碼,有讀者反饋——想知道在 TensorFlow 中用戶特征和商品特征是如何 Embedding 的?因此,筆者特意寫作此文加以解答。

1. 何為 Embedding ?

關于 Embedding,筆者很久之前寫過一篇文章《推薦系統(十一):推薦系統中的 Embedding》,現在看來,差強人意,不過,對 Embedding 的概念解讀還是不錯的,只是缺乏代碼案例解讀。在本文中,筆者將基于 TensorFlow 來做解讀,讓讀者加深理解。

如下圖所示,為一個極簡(CTR 和 CVR 共享了交互層) “雙塔模型”(詳見文章《推薦系統(十五):基于雙塔模型的多目標商品召回/推薦系統》),簡單解讀一下:

  1. User Feature 和 Item Feature 先經過 Embedding Layer 處理,得到特征的 Embedding;
  2. User Feature Embedding 和 Item Feature Embedding 經過 Concat Layer 連接后輸入到 DNN 網絡;這樣直接 Concat 得到的 Embedding 結果被稱為 User 和 Item 的 “表示(Representation)”,顯然,這種 “表示” 比較粗糙;
  3. 經過 MLP 處理,得到 User Vector 和 Item Vector,相較于上一步的 “表示形式”,User Vector 和 Item Vector 要 “精細” 得多,是真正意義上的 User Embedding 和 Item Embedding。
  4. User Embedding 和 Item Embedding 計算內積后經過 Sigmoid 函數處理(即圖中的 Prediction),即可得到一個 0~1 之間的數值,即概率。
  5. 對于商品點擊(1-點擊,0-未點擊)和商品轉化(1-轉化,0-未轉化)這種二分類問題,結合模型預測的概率和樣本 Label,很容易計算出損失(二分類問題一般采用交叉墑損失)。
  6. 對于 CTR 和 CVR 這種多任務場景,需要將 CTR Loss 和 CVR Loss 加權融合作為最終的損失,進而指導訓練模型。
    在這里插入圖片描述

2.特征工程中的 Embedding

2.1 ID 類特征

在 User Feature 和 Item Feature 中,User ID 和 Item ID 是最為重點的特征之一,是典型的 “高維稀疏” 特征。直接以原始數據形式輸入模型是不行的,必須經過 Embedding Layer 的處理。在此,以 Item ID 為例,Embedding 處理的代碼如下:

# 模擬生成商品特征,其中 item_id 取值[1, 10000]
num_items = 10000
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 基于 TensorFlow 對原始的 item_id 進行 Embedding 處理,分為兩步
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)

1. 分類列的創建:categorical_column_with_identity

item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
  • 功能:將輸入的整數 item_id 直接映射為分類標識。例如,若 num_items=1000,則輸入的 item_id 必須是 [0,
    1, 2, …, 999] 范圍內的整數。
  • 本質:這類似于對 item_id 做 One-Hot 編碼(但底層實現更高效,不顯式生成稀疏矩陣)。

2.嵌入列的創建:embedding_column

item_id_emb = feature_column.embedding_column(item_id, dimension=8)
  • 功能:將高維稀疏的分類 ID(如 num_items=1000 維的 One-Hot 向量)映射到低維稠密的連續向量空間(維度為 8)。
  • 關鍵點:嵌入矩陣的維度是 [num_items, 8],即每個 item_id 對應一個 8 維向量。這個嵌入矩陣是一個可訓練參數,初始值隨機(如 Glorot 初始化),通過神經網絡的反向傳播逐步優化。

3.嵌入向量的訓練過程

  • 何時生成:嵌入矩陣的值并非預先計算,而是在模型訓練時動態學習。
  • 如何學習
    1-輸入數據中的 item_id 會觸發嵌入層查找對應的 8 維向量。
    2-在反向傳播時,優化器(如Adam)根據損失函數的梯度調整嵌入矩陣的值。
    3-模型通過最小化損失函數,迫使相似的 item_id 在嵌入空間中靠近,從而捕捉潛在語義關系(如用戶行為中的物品相似性)。

4. 嵌入層的底層實現

當你在模型中調用 item_id_emb 時,TensorFlow 會隱式完成以下操作:

# 偽代碼解釋
embedding_matrix = tf.Variable(  # 可訓練參數initial_value=tf.random.uniform([num_items, 8]), name="item_id_embedding"
)
# 根據輸入的item_id查找嵌入向量
item_id_emb = tf.nn.embedding_lookup(embedding_matrix, input_item_ids)

5. 嵌入的優勢

  • 降維:將高維稀疏特征壓縮為低維稠密向量(例如從1000 維的 One-Hot 降到 8 維)。
  • 語義學習:模型自動學習嵌入空間中的幾何關系(如相似物品的向量距離更近)。
  • 泛化性:即使某些 item_id 在訓練數據中出現次數少,其嵌入向量仍可通過相似物品的梯度更新得到合理表示。

6. 完整流程示例

假設你的模型是一個推薦系統,處理流程如下:

  • 輸入層:接收原始特征(如 {‘item_id’: 5})。
  • 特征轉換:通過 item_id_emb 將 item_id=5 轉換為一個 8 維向量。
  • 神經網絡:將嵌入向量輸入全連接層(如 DNN)、激活函數等后續結構。
  • 訓練:通過損失函數(如點擊率預測的交叉熵)反向傳播,更新嵌入矩陣和其他權重。

2.2 類別特征

以用戶性別為例:

# 模擬生成用戶特征,其中用戶性別是可以枚舉的類別特征:male,female
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),  # 城市編碼,中國有 2856 個城市'device_type': np.random.randint(0, 5, size=num_users)  # 設備類型(0=Android,1=iOS等)
}
# 對性別特征進行 Embedding 處理
user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)

1. 定義分類特征列

代碼如下:

user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
  • 作用:將字符串類型的性別特征(如 ‘male’ 或 ‘female’)映射為整數索引。
  • 細節: 輸入特征名為 ‘user_gender’,詞匯表為 [‘male’, ‘female’]。 模型會根據詞匯表將 ‘male’ 編碼為
    0,‘female’ 編碼為 1。 如果輸入的值不在詞匯表中(如 ‘unknown’),默認會被映射為 -1(可通過
    num_oov_buckets 參數調整)。

2. 創建嵌入列(Embedding Column)

如下代碼:

user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)
  • 作用:將稀疏的整數索引轉換為密集的低維向量(嵌入向量)。
  • 細節
    1.嵌入矩陣的維度:嵌入矩陣的形狀為 (vocab_size, embedding_dimension),即 (2, 2)。 行數 2:對應詞匯表中的兩個類別(male 和 female);列數 2:指定的嵌入維度 dimension=2。
    2.嵌入初始化:嵌入向量的初始值默認通過隨機均勻分布生成(可通過 initializer 參數自定義)。
    3.訓練過程:嵌入向量會在模型訓練時通過反向傳播自動優化,學習與任務相關的語義表示。

2.3 數值特征

數值特征是一種簡單的特征,按照常理,可以直接用原始數據進行模型訓練和預測,然而,由于不同類型的數值特征存在 “量綱差異”,從而使得不同類型的數值特征 “不可比較”(如年齡數值區間(0~150),價格區間(0~10000000)),因此,數值特征也需要處理,比如標準化/歸一化,好處如下:

  • 統一特征尺度,避免梯度下降因不同特征量綱而震蕩。
  • 所有特征在相同尺度下,模型權重更新更均衡。
  • L1/L2正則化對所有特征施加相似強度的懲罰。

以用戶年齡為例:

scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')

1. 數據標準化處理

代碼如下:

scaler_age = StandardScaler()
  • 作用:創建一個標準化處理器,用于對數值型特征(如年齡)進行均值方差標準化(Z-Score標準化)。
  • 細節:StandardScaler 是 scikit-learn 庫中的標準化工具,核心操作為:標準化值 =(原始值?均值)/ 標準差 標準化后,數據分布均值為 0,標準差為 1,消除量綱差異。適用于數值范圍大、分布不均衡的特征(如年齡范圍可能從 0 到 100)。

2.應用標準化到年齡列

代碼如下:

df['user_age'] = scaler_age.fit_transform(df[['user_age']])
  • 作用:對 DataFrame 中的 user_age 列進行擬合和轉換,實現標準化。
  • 細節
    1. fit_transform 兩步合并。fit:計算 user_age 列的均值(μ)和標準差(σ)。transform:使用公式:(X?μ)/ σ,對所有樣本進行標準化。
    2. 示例:假設原始年齡數據為 [20, 30, 40],均值為 30,標準差為 8.16,標準化后為 [-1.22, 0, 1.22]。
    3. 存儲參數:scaler_age 對象會保存計算出的 μ 和 σ,便于后續對新數據(如測試集)使用 transform 而非重新擬合。

3.創建數值特征列

代碼如下:

user_age = feature_column.numeric_column('user_age')
  • 作用:定義 TensorFlow 模型可接收的數值型特征列,將標準化后的年齡值直接輸入模型。
  • 細節
    1. 輸入數據類型:該列接收的是連續數值(如標準化后的 -1.22、0、1.22)。
    2. 模型中的處理:在訓練時,每個樣本的 user_age 值會以浮點數形式直接傳遞給神經網絡,無需進一步編碼。
    3. 參數擴展性: 可結合其他參數增強特征(例如 normalizer_fn 可添加自定義歸一化,但此處已提前標準化,通常不再需要)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/73831.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/73831.shtml
英文地址,請注明出處:http://en.pswp.cn/web/73831.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

c++第三課(基礎c)

1.前文 2.break 3.continue 4.return 0 1.前文 上次寫文章到現在&#xff0c;有足足這么多天&#xff08;我也不知道&#xff0c;自己去數吧&#xff09; 開始吧 2.break break是結束循環的意思 舉個栗子 #include<bits/stdc.h> using namespace std; int main(…

關于ArcGIS中加載影像數據,符號系統中渲染參數的解析

今天遇到一個很有意思的問題&#xff0c;故記錄下來&#xff0c;以作參考和后續的研究。歡迎隨時溝通交流。如果表達錯誤或誤導&#xff0c;請各位指正。 正文 當我們拿到一幅成果影像數據的時候&#xff0c;在不同的GIS軟件中會有不同效果呈現&#xff0c;但這其實是影像是…

北森測評的經驗

測評經驗記錄 首先聲明&#xff0c;北森測評就是垃圾&#xff0c;把行測拿過來就能評測能力了&#xff1f;直接去參加公務員考試更好。網上2024年的題庫 評測分為 閱讀理解數學計算圖形題性格測試 圖形題 總結的經驗如下 圖形推理題 一組圖形&#xff0c;推測另一組圖形最…

Java/Scala是什么

Java 和 Scala 是兩種運行在 ?JVM&#xff08;Java 虛擬機&#xff09;? 上的編程語言&#xff0c;雖然共享相同的運行時環境&#xff0c;但它們在設計哲學、語法特性和適用場景上有顯著差異。以下是兩者的詳細解析&#xff1a; ?1. Java ?核心特性 ?面向對象&#xff1…

SQL Server 備份相關信息查看

目錄標題 一、統計每個數據庫在不同備份目錄和備份類型下的備份次數&#xff0c;以及最后一次備份的時間整體功能詳細解釋 二、查詢所有完整數據庫備份的信息&#xff0c;包括備份集 ID、數據庫名稱、備份開始時間和備份文件的物理設備名稱&#xff0c;并按備份開始時間降序排列…

CANoe入門——CANoe的診斷模塊,調用CAPL進行uds診斷

目錄 一、診斷窗口介紹 二、診斷數據庫文件管理 三、添加基礎診斷描述文件&#xff08;若沒有CDD/ODX/PDX文件&#xff09;并使用對應的診斷功能進行UDS診斷 3.1、添加基礎診斷描述文件 3.2、基于基礎診斷&#xff0c;使用診斷控制臺進行UDS診斷 3.2.1、生成基礎診斷 3.…

【數據結構】二叉樹的遞歸

數據結構系列三&#xff1a;二叉樹(二) 一、遞歸的原理 1.全訪問 2.主角 3.返回值 4.執等 二、遞歸的化關系思路 三、遞歸的方法設計 一、遞歸的原理 1.全訪問 方法里調用方法自己&#xff0c;就會形成調用方法本身的一層一層全新相同的調用&#xff0c;方法的形參設置…

Imgui處理glfw的鼠標鍵盤的方法

在Imgui初始化時&#xff0c;會重新接手glfw的鍵盤鼠標事件。也就是遇到glfw的鍵盤鼠標事件時&#xff0c;imgui先會運行自己的處理過程&#xff0c;然后再去處理用戶自己注冊的glfw的鍵盤鼠標事件。 看imgui_impl_glfw.cpp源碼的安裝回調函數部分代碼 void ImGui_ImplGlfw_In…

【LVS】負載均衡群集部署(DR模式)

部署前IP分配 DR服務器&#xff1a;192.168.166.101 vip&#xff1a;192.168.166.100 Web服務器1&#xff1a;192.168.166.104 vip&#xff1a;192.168.166.100 Web服務器2&#xff1a;192.168.166.107 vip&#xff1a;192.168.166.100 NFS服務器&#xff1a;192.168.166.108 …

C++Primer學習(14.1 基本概念)

當運算符作用于類類型的運算對象時&#xff0c;可以通過運算符重載重新定義該運算符的含義。明智地使用運算符重載能令我們的程序更易于編寫和閱讀。舉個例子&#xff0c;因為在Sales_item類中定義了輸入、輸出和加法運算符&#xff0c;所以可以通過下述形式輸出兩個Sales_item…

計算機視覺準備八股中

一邊記錄一邊看&#xff0c;這段實習跑路之前運行完3DGAN&#xff0c;弄完潤了&#xff0c;現在開始記憶八股 1.CLIP模型的主要創新點&#xff1a; 圖像和文本兩種不同模態數據之間的深度融合、對比學習、自監督學習 2.等效步長是每一步操作步長的乘積 3.卷積層計算輸入輸出…

基于大語言模型的智能音樂創作系統——從推薦到生成

一、引言&#xff1a;當AI成為音樂創作伙伴 2023年&#xff0c;一款由大語言模型&#xff08;LLM&#xff09;生成的鋼琴曲《量子交響曲》在Spotify沖上熱搜&#xff0c;引發音樂界震動。傳統音樂創作需要數年專業訓練&#xff0c;而現代AI技術正在打破這一壁壘。本文提出一種…

Mysql---鎖篇

1&#xff1a;MySQL 有哪些鎖&#xff1f; 全局鎖 flush tables with read lock 整個數據庫就處于只讀狀態了 unlock tables 釋放全局鎖 全局鎖主要應用于做全庫邏輯備份&#xff0c;這樣在備份數據庫期間&#xff0c;不會因為數據或表結構的更新&#xff0c;而出現備份文件的數…

VLAN綜合實驗二

一.實驗拓撲&#xff1a; 二.實驗需求&#xff1a; 1.內網Ip地址使用172.16.0.0/分配 2.sw1和SW2之間互為備份 3.VRRP/STP/VLAN/Eth-trunk均使用 4.所有Pc均通過DHCP獲取IP地址 5.ISP只能配置IP地址 6.所有…

GEO(生成引擎優化)實施策略全解析:從用戶意圖到效果追蹤

——基于行業實證的AI信源占位方法論 ?一、理解用戶查詢&#xff1a;構建AI語料的核心起點 生成式AI的內容推薦邏輯以用戶意圖為核心&#xff0c;?精準捕捉高頻問題是GEO優化的第一步。企業需通過以下方法挖掘用戶真實需求&#xff1a; ?AI對話日志分析&#xff1a; 分析用…

HTML基礎及進階

目錄 一、HTML基礎 1.什么是HTML 2.常用標簽 &#xff08;1&#xff09;標題標簽&#xff1a;h1-h6數字越小文字會越大&#xff0c;這個標簽會占一整行 &#xff08;2&#xff09;加粗標簽&#xff1a; &#xff08;3&#xff09;換行標簽&#xff1a; &#xff08;4&am…

MSTP與鏈路聚合技術

MSTP&#xff08;多生成樹協議&#xff09; 簡介 MSTP&#xff08;多生成樹協議&#xff09;是Spanning Tree Protocol&#xff08;STP&#xff09;的改進版&#xff0c;支持網絡中使用多條生成樹&#xff0c;并根據用戶需求限制生成樹間的路徑。MSTP將多個VLAN映射到一棵生成…

ModuleNotFoundError: No module named ‘ml_logger.logbook‘

問題 (legion) zhouy24RL-DSlab:~/zhouy24Files/legion/LEGION$ python main.py ML_LOGGER_USER is not set. This is required for online usage. Traceback (most recent call last): File “main.py”, line 7, in from mtrl.app.run import run File “/data/zhouy24File…

c# ftp上傳下載 幫助類

工作中FTP的上傳和下載還是很常用的。如下載打標數據,上傳打標結果等。 這個類常用方法都有了:上傳,下載,判斷文件夾是否存在,創建文件夾,獲取當前目錄下文件列表(不包括文件夾) ,獲取當前目錄下文件列表(不包括文件夾) ,獲取FTP文件列表(包括文件夾), 獲取當前目…

PyTorch 分布式訓練(Distributed Data Parallel, DDP)簡介

PyTorch 分布式訓練&#xff08;Distributed Data Parallel, DDP&#xff09; 一、DDP 核心概念 torch.nn.parallel.DistributedDataParallel 1. DDP 是什么&#xff1f; Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式訓練接口&#xff0c;DistributedDataPara…