在商品推薦系統中,粗排和精排環節的知識蒸餾方法主要通過復雜模型(Teacher)指導簡單模型(Student)的訓練,以提升粗排效果及與精排的一致性。本文將以淘寶的一篇論文《Privileged Features Distillation at Taobao Recommendations》中介紹的 PFD(Privileged Features Distillation)方法為例實現一個Demo,幫助讀者學習知識蒸餾。
1.知識蒸餾方法概述
知識蒸餾誕生至今,早已不局限于粗排,而是在粗排和精排均有應用。粗排和精排的知識蒸餾核心在于通過不同形式的知識遷移(logits、排序結果、特征)提升模型效果與一致性。粗排側重從精排獲取排序偏好,而精排側重模型壓縮。實際應用中需結合業務場景選擇蒸餾策略,并權衡性能與效果。本節簡要介紹一下知識蒸餾方法。
一、粗排環節的典型蒸餾方法
粗排需平衡性能和效果,通常以精排為Teacher進行知識遷移,主要方法包括:
(1)Logits蒸餾
- 原理:利用精排模型的輸出logits(未歸一化的預測值)作為軟標簽(soft label),指導粗排模型學習。通過引入溫度系數(Temperature Scaling)調整軟標簽的分布,增強非主導類別的信息傳遞。
- 損失函數:粗排模型的損失由兩部分組成: Hard Loss:基于真實標簽的交叉熵損失; Soft Loss:基于精排輸出logits的KL散度或MSE損失。
- 應用:美團、愛奇藝等采用兩階段訓練,先訓練精排Teacher,再固定其參數指導粗排Student56。
(2)排序結果蒸餾
-
原理:直接利用精排輸出的有序列表信息,構造粗排的訓練樣本。常見方法包括:
1. Point-wise:將精排Top-K結果作為正樣本,其余作為負樣本,并引入位置權重。2. Pair-wise:從精排列表中隨機抽取商品對,學習偏序關系(如BPR損失)。、3. List-wise:通過NDCG等指標對齊粗排與精排的整體排序。
-
優勢:緩解樣本選擇偏差,增強粗排對精排排序偏好的擬合。
(3)特征蒸餾
- 原理:遷移精排模型的中間層特征,要求粗排和精排的網絡結構部分對齊。例如: 隱層特征對齊:通過MSE損失約束粗排與精排的隱層輸出(如淘寶的 PFD(Privileged Features Distillation) 方法)。
- 優勢特征蒸餾:將精排使用的交叉特征等“特權特征”遷移到粗排(如用戶與商品的交互特征)。
- 應用:淘寶在 KDD 2020 提出的 PFD 方法中,精排 Teacher 使用交叉特征,粗排 Student 僅用基礎特征,通過蒸餾提升效果。
二、精排環節的典型蒸餾方法
精排蒸餾主要用于模型壓縮,將復雜模型(如集成模型)的能力遷移至輕量級模型:
(1)Logits蒸餾
- 原理:與粗排類似,使用復雜精排模型的 logits 指導輕量級 Student 模型訓練。例如: 阿里 Rocket Launching 框架:Teacher 和 Student 共享 Embedding 層,聯合訓練并通過 logits 對齊。
- 改進:愛奇藝雙 DNN 模型進一步約束 Student 隱層與 Teacher 隱層的激活值相似性。
(2)多目標蒸餾
- 原理:將精排的多任務輸出(如CTR、CVR)遷移至 Student。例如: 騰訊在 SIGIR 2021 提出通過 KL 散度對齊多任務 logits,提升粗排/召回模型的多目標一致性。
- 損失設計:結合多任務損失和蒸餾損失,如加權交叉熵或對比學習損失。
三、關鍵技術與實踐
(1)溫度系數(Temperature)
調節 softmax 輸出的平滑度,溫度值越大,分布越平滑,幫助 Student 學習 Teacher 的暗知識(Dark Knowledge)。
(2)兩階段訓練 vs 聯合訓練
- 兩階段:先獨立訓練 Teacher,再固定其參數指導 Student(穩定性高)。
- 聯合訓練:Teacher 和 Student 同步更新(減少耗時,但需設計梯度阻斷防止相互干擾)。
(3)實際應用案例
- 美團:通過對比學習強化粗排與精排的特征對齊,粗排CTR提升 0.15%。
- 淘寶:優勢特征蒸餾使粗排 CTR 提升 5%,精排CVR提升 2.3%。
- 騰訊音樂:多目標蒸餾在粗排階段實現閱讀時長與點擊率的聯合優化。
2. PFD(Privileged Features Distillation)方法介紹
PFD(Privileged Features Distillation)方法出自論文《Privileged Features Distillation at Taobao Recommendations》。論文中描述:在離線環境下同時訓練兩個模型:一個學生模型以及一個教師模型。其中學生模型和原始模型完全相同,而教師模型額外利用了優勢特征, 其準確率也因此更高。通過將教師模型蒸餾出的知識(Knowlege, 本文特指教師模型中最后一層的輸出)傳遞給學生模型,可以輔助其訓練以進一步提升準確率。在線上服務時,我們只抽取學生模型進行部署,因為輸入不依賴于優勢特征,離線、在線的一致性得以保證。在 PFD 中,所有的優勢特征都被統一到教師模型作為輸入,加入更多的優勢特征往往能帶來模型更高的準確度。
PFD 不同于常見的模型蒸餾(Model Disitillation, 簡稱 MD)。 在 MD 中,教師模型和學生模型處理同樣的輸入特征,其中教師模型會比學生模型更為復雜, 比如,教師模型會用更深的網絡結構來指導使用淺層網絡的學生模型進行學習。在 PFD 中,教師和學生模型會使用相同網絡結構,而處理不同的輸入特征。MD 和 PFD 兩者的差異如下圖所示。
如上圖所示:模型蒸餾(Model Distill, 簡稱 MD)與優勢特征蒸餾(PFD)對比; 在 MD 中,知識(Knowledge)是從更復雜的模型中蒸餾出來,而在 PFD 中,知識是從優勢特征中蒸餾出來。
由此可見,我們可以訓練一個使用了復雜特征(如交叉特征)的模型作為老師,指導訓練一個僅使用簡單特征的學生模型,從而實現提升模型效果,而又不增加線上耗時(線上使用交叉特征等復雜特征通常會導致耗時大幅增加,因此,在粗排環節幾乎不直接使用交叉特征)。
3.基于 Wide&Deep 指導訓練 TowTower 模型
基于 PFD 方法的原理,在本節我們將實現一個知識蒸餾的 Demo。其中,Teacher 模型基于 Wide&Deep 模型;Student 模型則采用簡單的“雙塔模型”。為了簡單起見,Wide&Deep 模型和 “雙塔模型” 均為單目標(CTR )模型。
3.1 模擬數據構造
"""
Part-1:模擬數據構造本部分模擬真實場景,人工構造用戶數據、商品數據、用戶-商品交互數據(點擊、轉化),并進行必要的預處
"""
# 設置隨機種子保證可復現性
np.random.seed(42)
tf.random.set_seed(42)# 生成用戶、商品和交互數據
num_users = 100
num_items = 200
num_interactions = 1000# 用戶特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互數據
# 包括:點擊和轉化(購買)數據
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 點擊標簽。0: 未點擊, 1: 點擊。在真實場景中可通過客戶端埋點上報獲得用戶的點擊行為數據click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用戶特征、商品特征和交互數據
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 劃分數據集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)
3.2 特征工程
代碼如下,相較于 Student 模型,作為 Teacher 的 Wide&Deep 模型采用了更多的特征,特別是交叉特征。
"""
Part-2:特征工程本部分對原始用戶數據、商品數據、用戶-商品交互數據進行分類處理,加工為模型訓練需要的特征1.數值型特征:如用戶年齡、價格,少數場景下可直接使用,但最好進行標準化,從而消除量綱差異2.類別型特征:需要進行 Embedding 處理3.交叉特征:由于維度高,需要哈希技巧處理高維組合特征
"""
# 用戶特征處理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征處理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)"""
交叉特征預處理
"""
# 使用TensorFlow的交叉特征(crossed_column)定義了Wide部分的特征列,主要用于捕捉用戶與商品特征之間的組合效應
# 將用戶ID(user_id)和商品ID(item_id)組合成一個新特征,捕捉**“特定用戶對特定商品的偏好”**
# 用戶ID和商品ID的組合總數可能非常大(num_users * num_items),直接編碼會導致維度爆炸。
# hash_bucket_size=10000:使用哈希函數將組合映射到固定數量的桶(10,000個),控制內存和計算開銷,適用于稀疏高維特征(如用戶-商品對)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)"""
特征列定義
"""
# ESMM 模型相關特征列定義
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相關特征列定義
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]
3.3 模型架構設計
Teacher 模型:采用 Wide&Deep 模型(模擬精排模型);Student 模型:采用普通 “雙塔模型”(模擬粗排模型)。
"""
Part-3:模型架構設計
"""
# 教師模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:線性模型,擅長記憶(Memorization),通過交叉特征捕捉明確的特征組合模式(如用戶A常點擊商品B)。Deep部分:深度神經網絡,擅長泛化(Generalization),通過嵌入向量學習特征的潛在關系(如女性用戶與服裝品類的關聯)。結合優勢:同時處理稀疏特征(如用戶ID、商品ID)和密集特征(如價格、年齡),平衡記憶與泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(線性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神經網絡)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:預測CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:預測CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 將Wide和Deep的logits相加,通過Sigmoid輸出點擊概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 學生模型:采用普通雙塔模型
class TowTowerStudent(tf.keras.Model):"""普通雙塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征處理層self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 雙塔結構user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR預測# 點積交互(即用戶Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}
3.4 知識蒸餾實現
本質上就是用 Teacher 模型指導 Student 模型訓練。使得 Student 模型的預測結果逼近 Teacher 模型。
"""
Part-4:知識蒸餾實現
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 溫度參數:典型取值2-5之間self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理時直接使用學生模型return self.student(inputs)def train_step(self, data):# 解包數據x, y = data# 教師模型前向傳播(僅推理)teacher_predictions = self.teacher(x, training=False) # 凍結教師模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape實現動態梯度計算with tf.GradientTape() as tape:# 學生模型前向傳播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 計算學生損失# 學生損失(student_loss):直接擬合真實標簽# y['ctr_logits'] = labels['click_label'],在輸入數據時有定義student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 計算蒸餾損失distillation_loss_ctr = self.distillation_loss_fn(# 蒸餾損失(distillation_loss):學習教師模型的軟標簽分布teacher_ctr / self.temperature, # 教師輸出軟化student_ctr / self.temperature # 學生輸出對齊)# 總損失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 計算梯度并更新(僅更新學生參數)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指標self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}
3.5 模型訓練與評估
- 第一步:數據準備;
- 第二步:模型初始化;
- 第三步:編譯、訓練 Teacher 模型;
- 第四步:編譯、訓練 Student 模型;
- 第五步:評估、可視化效果
"""
Part-5:模型訓練與評估
"""
# 數據輸入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 這里做了一個映射,主要為了對齊學生模型和教師模型的輸出,從而便于計算損失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 轉換數據集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 編譯教師模型(先單獨訓練)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3] # 可選:設置不同任務的損失權重
)# 訓練教師模型
print("訓練教師模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 編譯蒸餾模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 訓練學生模型(帶蒸餾)
print("訓練學生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可視化訓練過程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
3.6 模型服務化與測試
保存訓練好的學生模型,在另一個工程中可以加載這個模型,并執行預測。
"""
Part-6:模型服務化(示例)
"""
# 保存學生模型
student.save('esmm_student_model')# 加載模型進行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型輸入層名稱
loaded_model.summary()# 示例預測:從 test_features 數據框中提取第一行數據
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"預測結果:CTR={predictions['ctr_logits'][0][0]:.3f}")
3.7 知識蒸餾完整代碼
完整代碼如下:
import tensorflow as tftf.config.set_visible_devices([], 'GPU') # 禁用GPU設備
from tensorflow import feature_column
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler"""
Part-1:模擬數據構造本部分模擬真實場景,人工構造用戶數據、商品數據、用戶-商品交互數據(點擊、轉化),并進行必要的預處
"""
# 設置隨機種子保證可復現性
np.random.seed(42)
tf.random.set_seed(42)# 生成用戶、商品和交互數據
num_users = 100
num_items = 200
num_interactions = 1000# 用戶特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互數據
# 包括:點擊和轉化(購買)數據
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 點擊標簽。0: 未點擊, 1: 點擊。在真實場景中可通過客戶端埋點上報獲得用戶的點擊行為數據click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用戶特征、商品特征和交互數據
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 劃分數據集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)"""
Part-2:特征工程本部分對原始用戶數據、商品數據、用戶-商品交互數據進行分類處理,加工為模型訓練需要的特征1.數值型特征:如用戶年齡、價格,少數場景下可直接使用,但最好進行標準化,從而消除量綱差異2.類別型特征:需要進行 Embedding 處理3.交叉特征:由于維度高,需要哈希技巧處理高維組合特征
"""
# 用戶特征處理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征處理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)"""
交叉特征預處理
"""
# 使用TensorFlow的交叉特征(crossed_column)定義了Wide部分的特征列,主要用于捕捉用戶與商品特征之間的組合效應
# 將用戶ID(user_id)和商品ID(item_id)組合成一個新特征,捕捉**“特定用戶對特定商品的偏好”**
# 用戶ID和商品ID的組合總數可能非常大(num_users * num_items),直接編碼會導致維度爆炸。
# hash_bucket_size=10000:使用哈希函數將組合映射到固定數量的桶(10,000個),控制內存和計算開銷,適用于稀疏高維特征(如用戶-商品對)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)"""
特征列定義
"""
# ESMM 模型相關特征列定義
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相關特征列定義
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]"""
Part-3:模型架構設計
"""
# 教師模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:線性模型,擅長記憶(Memorization),通過交叉特征捕捉明確的特征組合模式(如用戶A常點擊商品B)。Deep部分:深度神經網絡,擅長泛化(Generalization),通過嵌入向量學習特征的潛在關系(如女性用戶與服裝品類的關聯)。結合優勢:同時處理稀疏特征(如用戶ID、商品ID)和密集特征(如價格、年齡),平衡記憶與泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(線性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神經網絡)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:預測CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:預測CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 將Wide和Deep的logits相加,通過Sigmoid輸出點擊概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 學生模型:采用普通雙塔模型
class TowTowerStudent(tf.keras.Model):"""普通雙塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征處理層self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 雙塔結構user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR預測# 點積交互(即用戶Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}"""
Part-4:知識蒸餾實現
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 溫度參數:典型取值2-5之間self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理時直接使用學生模型return self.student(inputs)def train_step(self, data):# 解包數據x, y = data# 教師模型前向傳播(僅推理)teacher_predictions = self.teacher(x, training=False) # 凍結教師模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape實現動態梯度計算with tf.GradientTape() as tape:# 學生模型前向傳播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 計算學生損失# 學生損失(student_loss):直接擬合真實標簽# y['ctr_logits'] = labels['click_label'],在輸入數據時有定義student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 計算蒸餾損失distillation_loss_ctr = self.distillation_loss_fn(# 蒸餾損失(distillation_loss):學習教師模型的軟標簽分布teacher_ctr / self.temperature, # 教師輸出軟化student_ctr / self.temperature # 學生輸出對齊)# 總損失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 計算梯度并更新(僅更新學生參數)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指標self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}"""
Part-5:模型訓練與評估
"""
# 數據輸入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 這里做了一個映射,主要為了對齊學生模型和教師模型的輸出,從而便于計算損失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 轉換數據集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 編譯教師模型(先單獨訓練)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3] # 可選:設置不同任務的損失權重
)# 訓練教師模型
print("訓練教師模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 編譯蒸餾模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 訓練學生模型(帶蒸餾)
print("訓練學生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可視化訓練過程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()"""
Part-6:模型服務化(示例)
"""
# 保存學生模型
student.save('esmm_student_model')# 加載模型進行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型輸入層名稱
loaded_model.summary()# 示例預測:從 test_features 數據框中提取第一行數據
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"預測結果:CTR={predictions['ctr_logits'][0][0]:.3f}")
3.8 運行效果
Teacher 模型訓練過程:
訓練教師模型...
Epoch 1/5
2025-03-30 21:41:55.398982: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
25/25 [==============================] - 2s 13ms/step - loss: 0.5838 - accuracy: 0.5013 - val_loss: 0.5115 - val_accuracy: 0.4850
Epoch 2/5
25/25 [==============================] - 0s 3ms/step - loss: 0.5049 - accuracy: 0.5013 - val_loss: 0.5101 - val_accuracy: 0.4850
Epoch 3/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5037 - accuracy: 0.5013 - val_loss: 0.5093 - val_accuracy: 0.4850
Epoch 4/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5026 - accuracy: 0.5013 - val_loss: 0.5085 - val_accuracy: 0.4850
Epoch 5/5
25/25 [==============================] - 0s 5ms/step - loss: 0.5014 - accuracy: 0.5013 - val_loss: 0.5077 - val_accuracy: 0.4850
Student 模型訓練過程:
訓練學生模型...
Epoch 1/10
25/25 [==============================] - 2s 11ms/step - accuracy: 0.4975 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 2/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5038 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 3/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5063 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 4/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5050 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 5/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 6/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 7/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 8/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 9/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 10/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5088 - val_loss: 0.0000e+00 - val_accuracy: 0.4900
模型結構及預測示例:
Model: "tow_tower_student"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense_features_2 (DenseFeat multiple 23706 ures) dense_features_3 (DenseFeat multiple 1620 ures) sequential_1 (Sequential) (None, 32) 4000 sequential_2 (Sequential) (None, 32) 2976 dense_8 (Dense) multiple 2 =================================================================
Total params: 32,304
Trainable params: 32,304
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 178ms/step
預測結果:CTR=0.526
可視化訓練過程:
3.9 模型預測
在另一個工程中加載通過蒸餾訓練好的 Student 模型,并執行預測,代碼示例如下:
# 導入必要的庫
import tensorflow as tf
import pandas as pd
import numpy as np# 人工構造數據
num_users = 100
num_items = 200# 重新生成新的樣本,模擬真實數據進行預測
def generate_new_samples(num_samples=5):new_samples = []for _ in range(num_samples):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)user_age = np.random.randint(18, 65)user_gender = np.random.choice(['male', 'female'])user_occupation = np.random.choice(['student', 'worker', 'teacher'])city_code = np.random.randint(1, 2856)device_type = np.random.randint(0, 5)item_category = np.random.choice(['electronics', 'books', 'clothing'])item_brand = np.random.choice(['brandA', 'brandB', 'brandC'])item_price = np.random.randint(1, 199)new_samples.append({'user_id': user_id,'user_age': user_age,'user_gender': user_gender,'user_occupation': user_occupation,'city_code': city_code,'device_type': device_type,'item_id': item_id,'item_category': item_category,'item_brand': item_brand,'item_price': item_price})return pd.DataFrame(new_samples)# 生成并打印預覽新的樣本數據
new_samples = generate_new_samples(num_samples=5)
# 設置display.max_columns為None,強制顯示全部列:
pd.set_option('display.max_columns', None)
print("\nGenerated New Samples:\n", new_samples)# 準備輸入數據
input_dict = {'user_id': tf.convert_to_tensor(new_samples['user_id'].values, dtype=tf.int64),'user_age': tf.convert_to_tensor(new_samples['user_age'].values, dtype=tf.int64),'user_gender': tf.convert_to_tensor(new_samples['user_gender'].values, dtype=tf.string),'user_occupation': tf.convert_to_tensor(new_samples['user_occupation'].values, dtype=tf.string),'city_code': tf.convert_to_tensor(new_samples['city_code'].values, dtype=tf.int64),'device_type': tf.convert_to_tensor(new_samples['device_type'].values, dtype=tf.int64),'item_id': tf.convert_to_tensor(new_samples['item_id'].values, dtype=tf.int64),'item_category': tf.convert_to_tensor(new_samples['item_category'].values, dtype=tf.string),'item_brand': tf.convert_to_tensor(new_samples['item_brand'].values, dtype=tf.string),'item_price': tf.convert_to_tensor(new_samples['item_price'].values, dtype=tf.int64)
}# 加載模型進行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')
# 明確使用默認簽名
predict_fn = loaded_model.signatures['serving_default']
predictions = predict_fn(**input_dict)# 提取并打印預測結果
# 預測結果是一個 CTCVR 綜合分
predicted_ctr = predictions['ctr_logits'].numpy().flatten()
new_samples['ctr_prob'] = predicted_ctr
print("\nPrediction Results:")
for idx, row in new_samples.iterrows():print(f"Item ID: {row['item_id']} | CTR Final Score: {row['ctr_prob']:.4f}")
運行結果如下:
Generated New Samples:user_id user_age user_gender user_occupation city_code device_type \
0 34 49 female teacher 843 0
1 15 30 female student 564 3
2 26 37 male teacher 2229 0
3 31 35 male worker 2494 0
4 41 57 female student 1668 3 item_id item_category item_brand item_price
0 147 electronics brandA 127
1 196 clothing brandC 190
2 1 books brandA 1
3 150 clothing brandA 5
4 128 electronics brandA 156
Metal device set to: Apple M1 ProPrediction Results:
Item ID: 147 | CTR Final Score: 0.5263
Item ID: 196 | CTR Final Score: 0.5263
Item ID: 1 | CTR Final Score: 0.5263
Item ID: 150 | CTR Final Score: 0.4793
Item ID: 128 | CTR Final Score: 0.5263