項目概述
2025 騰訊廣告算法大賽 Baseline,一個簡單的序列推薦系統,主要用于建模用戶和物品的交互序列,并利用多模態特征(文本、圖像等 embedding)來提升推薦效果。
核心文件功能
1. main.py - 主訓練腳本
- 負責模型訓練的整體流程
- 包含參數解析、數據加載、模型初始化、訓練循環等
- 支持斷點續訓和僅推理模式
- 使用 TensorBoard 記錄訓練日志
main.py 代碼
import argparse
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MyDataset
from model import BaselineModel


def get_args():
    """Parse command-line arguments for training / inference.

    Returns:
        argparse.Namespace with training hyper-parameters, model-construction
        options, and the list of active multimodal embedding feature IDs.
    """
    parser = argparse.ArgumentParser()

    # Train params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline Model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # MMemb Feature ID
    parser.add_argument(
        '--mm_emb_id', nargs='+', default=['81'], type=str,
        choices=[str(s) for s in range(81, 87)]
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Output locations come from environment variables set by the launcher.
    # NOTE(review): os.environ.get returns None when a variable is unset, which
    # makes Path(None) raise — presumably the platform guarantees these; confirm.
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))

    # global dataset
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    # Fractional lengths require torch >= 1.13.
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0,
        collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0,
        collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    # Xavier-init every weight matrix; 1-D parameters (biases, LayerNorm) make
    # xavier_normal_ raise and are deliberately skipped.
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass

    # Row 0 of every embedding table is the padding index and must stay zero.
    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0
    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1

    if args.state_dict_path is not None:
        try:
            model.load_state_dict(
                torch.load(args.state_dict_path, map_location=torch.device(args.device))
            )
            # Resume epoch is parsed out of an "epoch=<n>." token in the filename.
            # NOTE(review): the checkpoints saved below are named
            # "global_step{...}" and contain no "epoch=" token — resuming from
            # them will fail to parse; confirm intended naming.
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[: tail.find('.')]) + 1
        except Exception:  # fix: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!')

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type,
                seq_feat, pos_feat, neg_feat
            )
            pos_labels = torch.ones(pos_logits.shape, device=args.device)
            neg_labels = torch.zeros(neg_logits.shape, device=args.device)
            optimizer.zero_grad()
            # Only positions whose next token is an item (type 1) contribute to
            # the loss; np.where on the CPU tensor yields the index tuple.
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])

            log_json = json.dumps(
                {'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()}
            )
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)

            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step += 1

            # L2 regularization on the item embedding table only; added after
            # logging so the reported loss stays the plain BCE value.
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss_sum = 0
        # fix: validation previously ran with autograd enabled, building useless
        # graphs; no_grad changes no computed value, only memory/time.
        with torch.no_grad():
            for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
                seq = seq.to(args.device)
                pos = pos.to(args.device)
                neg = neg.to(args.device)
                pos_logits, neg_logits = model(
                    seq, pos, neg, token_type, next_token_type, next_action_type,
                    seq_feat, pos_feat, neg_feat
                )
                pos_labels = torch.ones(pos_logits.shape, device=args.device)
                neg_labels = torch.zeros(neg_logits.shape, device=args.device)
                indices = np.where(next_token_type == 1)
                loss = bce_criterion(pos_logits[indices], pos_labels[indices])
                loss += bce_criterion(neg_logits[indices], neg_labels[indices])
                valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)
        writer.add_scalar('Loss/valid', valid_loss_sum, global_step)

        save_dir = Path(
            os.environ.get('TRAIN_CKPT_PATH'),
            f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}"
        )
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()
2. model.py - 核心模型實現
BaselineModel
- 主推薦模型
基于 Transformer 的序列推薦模型,具有以下特點:
模型架構:
- 使用 `FlashMultiHeadAttention` 實現高效的多頭注意力機制
- 采用 `PointWiseFeedForward` 作為前饋網絡
- 支持多種特征類型:稀疏特征、數組特征、連續特征、多模態 embedding 特征
特征處理:
- 用戶特征:稀疏特征 (103,104,105,109)、數組特征 (106,107,108,110)
- 物品特征:稀疏特征 (100,117,111 等)、多模態 embedding 特征 (81-86)
- 通過 `feat2emb` 方法將不同類型特征轉換為統一的 embedding 表示
核心方法:
- `log2feats`:將用戶序列轉換為特征表示
- `forward`:訓練時計算正負樣本的 logits
- `predict`:推理時生成用戶表征
- `save_item_emb`:保存物品 embedding 用于檢索
model.py 代碼
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dataset import save_emb


class FlashMultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention that uses PyTorch 2.0's fused
    scaled_dot_product_attention (Flash Attention) when available, and falls
    back to a manual softmax-attention implementation otherwise."""

    def __init__(self, hidden_units, num_heads, dropout_rate):
        super(FlashMultiHeadAttention, self).__init__()

        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate

        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"

        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        """Compute attention.

        Args:
            query/key/value: tensors of shape (batch, seq_len, hidden_units).
            attn_mask: optional boolean mask of shape (batch, seq_len, seq_len);
                True marks positions that may attend.

        Returns:
            (output, None) — the trailing None keeps the return shape compatible
            with torch.nn.MultiheadAttention, which also returns weights.
        """
        batch_size, seq_len, _ = query.size()

        # Q, K, V projections
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # reshape to multi-head layout: (batch, heads, seq, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch 2.0+: built-in Flash Attention.
            # fix: the original called attn_mask.unsqueeze(1) unconditionally and
            # crashed with AttributeError when attn_mask was None, even though
            # the parameter is optional (the fallback branch below guards it).
            attn_output = F.scaled_dot_product_attention(
                Q, K, V,
                dropout_p=self.dropout_rate if self.training else 0.0,
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
            )
        else:
            # Fallback: standard attention.
            scale = (self.head_dim) ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                # unsqueeze(1) broadcasts the mask over the head dimension.
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)

        # back to (batch, seq, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)

        # final output projection
        output = self.out_linear(attn_output)
        return output, None


class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward network implemented with 1x1 convolutions
    (equivalent to two per-position linear layers with ReLU in between)."""

    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d expects (N, C, Length), so transpose channels/time around the convs.
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        return outputs


class BaselineModel(torch.nn.Module):
    """Transformer-based sequential recommender.

    Args:
        user_num: number of users
        item_num: number of items
        feat_statistics: feature statistics; key is feature ID, value is the
            number of distinct values of that feature
        feat_types: feature type mapping; key is a type name, value is the list
            of feature IDs of that type, covering user/item sparse, array,
            emb and continual types
        args: global arguments

    Attributes:
        user_num: number of users
        item_num: number of items
        dev: device
        norm_first: whether to apply LayerNorm before attention (pre-norm)
        maxlen: maximum sequence length
        item_emb: item embedding table
        user_emb: user embedding table
        sparse_emb: embedding tables for sparse features
        emb_transform: linear projections for multimodal embedding features
        userdnn: fully-connected layer applied to concatenated user features
        itemdnn: fully-connected layer applied to concatenated item features
    """

    def __init__(self, user_num, item_num, feat_statistics, feat_types, args):
        super(BaselineModel, self).__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first
        self.maxlen = args.maxlen
        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch

        # +1 everywhere: index 0 is reserved for padding.
        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.user_emb = torch.nn.Embedding(self.user_num + 1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(2 * args.maxlen + 1, args.hidden_units, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.sparse_emb = torch.nn.ModuleDict()
        self.emb_transform = torch.nn.ModuleDict()

        self.attention_layernorms = torch.nn.ModuleList()  # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self._init_feat_info(feat_statistics, feat_types)

        # Input widths of the user/item DNNs: one hidden_units slot per sparse or
        # array feature plus the base id embedding, one scalar per continual
        # feature, and one hidden_units slot per projected multimodal embedding.
        userdim = args.hidden_units * (len(self.USER_SPARSE_FEAT) + 1 + len(self.USER_ARRAY_FEAT)) + len(
            self.USER_CONTINUAL_FEAT
        )
        itemdim = (
            args.hidden_units * (len(self.ITEM_SPARSE_FEAT) + 1 + len(self.ITEM_ARRAY_FEAT))
            + len(self.ITEM_CONTINUAL_FEAT)
            + args.hidden_units * len(self.ITEM_EMB_FEAT)
        )

        self.userdnn = torch.nn.Linear(userdim, args.hidden_units)
        self.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            # Optimization: FlashAttention replaces standard attention.
            new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

        for k in self.USER_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.USER_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_EMB_FEAT:
            self.emb_transform[k] = torch.nn.Linear(self.ITEM_EMB_FEAT[k], args.hidden_units)

    def _init_feat_info(self, feat_statistics, feat_types):
        """Group feature statistics (cardinalities) by feature type into
        separate dicts, so the sparse-feature embedding tables can be declared.

        Args:
            feat_statistics: feature statistics; key is feature ID, value is the
                number of distinct values of that feature
            feat_types: feature type mapping; key is a type name, value is the
                list of feature IDs of that type, covering user/item sparse,
                array, emb and continual types
        """
        self.USER_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['user_sparse']}
        self.USER_CONTINUAL_FEAT = feat_types['user_continual']
        self.ITEM_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['item_sparse']}
        self.ITEM_CONTINUAL_FEAT = feat_types['item_continual']
        self.USER_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['user_array']}
        self.ITEM_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['item_array']}
        # Dimensions of the different multimodal embedding features.
        EMB_SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
        self.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] for k in feat_types['item_emb']}

    def feat2tensor(self, seq_feature, k):
        """Convert one feature column of a batch of sequences to a tensor.

        Args:
            seq_feature: list of feature sequences; each element is a sequence of
                per-step feature dicts, shape [batch_size, maxlen]
            k: feature ID

        Returns:
            batch_data: tensor of feature values, shape
                [batch_size, maxlen, max_array_len] for array features,
                [batch_size, maxlen] otherwise
        """
        batch_size = len(seq_feature)

        if k in self.ITEM_ARRAY_FEAT or k in self.USER_ARRAY_FEAT:
            # Array feature: pad each array to the batch-wide max length first.
            max_array_len = 0
            max_seq_len = 0

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                max_seq_len = max(max_seq_len, len(seq_data))
                max_array_len = max(max_array_len, max(len(item_data) for item_data in seq_data))

            batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                for j, item_data in enumerate(seq_data):
                    actual_len = min(len(item_data), max_array_len)
                    batch_data[i, j, :actual_len] = item_data[:actual_len]

            return torch.from_numpy(batch_data).to(self.dev)
        else:
            # Sparse feature: convert directly.
            # NOTE(review): the row assignment assumes every sequence in the
            # batch has length max_seq_len (collate presumably pads); confirm.
            max_seq_len = max(len(seq_feature[i]) for i in range(batch_size))
            batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                batch_data[i] = seq_data

            return torch.from_numpy(batch_data).to(self.dev)

    def feat2emb(self, seq, feature_array, mask=None, include_user=False):
        """Embed a sequence of IDs together with its features.

        Args:
            seq: sequence IDs
            feature_array: feature list; each element is a per-step feature dict
            mask: token-type mask, 1 for item, 2 for user
            include_user: whether to process user features. Kept off in two
                cases: 1) embedding positive/negative samples during training
                (both are items); 2) generating the candidate item embeddings.

        Returns:
            seqs_emb: embedding of the sequence
        """
        seq = seq.to(self.dev)
        # pre-compute embedding
        if include_user:
            user_mask = (mask == 2).to(self.dev)
            item_mask = (mask == 1).to(self.dev)
            # Masking zeroes out IDs of the other token type; index 0 is padding.
            user_embedding = self.user_emb(user_mask * seq)
            item_embedding = self.item_emb(item_mask * seq)
            item_feat_list = [item_embedding]
            user_feat_list = [user_embedding]
        else:
            item_embedding = self.item_emb(seq)
            item_feat_list = [item_embedding]

        # batch-process all feature types
        all_feat_types = [
            (self.ITEM_SPARSE_FEAT, 'item_sparse', item_feat_list),
            (self.ITEM_ARRAY_FEAT, 'item_array', item_feat_list),
            (self.ITEM_CONTINUAL_FEAT, 'item_continual', item_feat_list),
        ]

        if include_user:
            all_feat_types.extend(
                [
                    (self.USER_SPARSE_FEAT, 'user_sparse', user_feat_list),
                    (self.USER_ARRAY_FEAT, 'user_array', user_feat_list),
                    (self.USER_CONTINUAL_FEAT, 'user_continual', user_feat_list),
                ]
            )

        # batch-process each feature type
        for feat_dict, feat_type, feat_list in all_feat_types:
            if not feat_dict:
                continue
            for k in feat_dict:
                tensor_feature = self.feat2tensor(feature_array, k)
                if feat_type.endswith('sparse'):
                    feat_list.append(self.sparse_emb[k](tensor_feature))
                elif feat_type.endswith('array'):
                    # Sum-pool the embeddings of the array elements.
                    feat_list.append(self.sparse_emb[k](tensor_feature).sum(2))
                elif feat_type.endswith('continual'):
                    feat_list.append(tensor_feature.unsqueeze(2))

        for k in self.ITEM_EMB_FEAT:
            # collect all data to numpy, then batch-convert
            batch_size = len(feature_array)
            emb_dim = self.ITEM_EMB_FEAT[k]
            seq_len = len(feature_array[0])

            # pre-allocate tensor
            batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)

            # fix: the inner loop variable was named `seq`, shadowing the `seq`
            # parameter of this method.
            for i, feats_seq in enumerate(feature_array):
                for j, item in enumerate(feats_seq):
                    if k in item:
                        batch_emb_data[i, j] = item[k]

            # batch-convert and transfer to GPU
            tensor_feature = torch.from_numpy(batch_emb_data).to(self.dev)
            item_feat_list.append(self.emb_transform[k](tensor_feature))

        # merge features
        all_item_emb = torch.cat(item_feat_list, dim=2)
        all_item_emb = torch.relu(self.itemdnn(all_item_emb))
        if include_user:
            all_user_emb = torch.cat(user_feat_list, dim=2)
            all_user_emb = torch.relu(self.userdnn(all_user_emb))
            seqs_emb = all_item_emb + all_user_emb
        else:
            seqs_emb = all_item_emb
        return seqs_emb

    def log2feats(self, log_seqs, mask, seq_feature):
        """Run the sequence through the transformer stack.

        Args:
            log_seqs: sequence IDs
            mask: token-type mask, 1 for item tokens, 2 for user tokens
            seq_feature: feature list; each element is a per-step feature dict

        Returns:
            seqs_emb: sequence embedding, shape [batch_size, maxlen, hidden_units]
        """
        batch_size = log_seqs.shape[0]
        maxlen = log_seqs.shape[1]
        seqs = self.feat2emb(log_seqs, seq_feature, mask=mask, include_user=True)
        seqs *= self.item_emb.embedding_dim**0.5
        poss = torch.arange(1, maxlen + 1, device=self.dev).unsqueeze(0).expand(batch_size, -1).clone()
        # Zero out positions of padding tokens so they hit pos_emb's padding row.
        poss *= log_seqs != 0
        seqs += self.pos_emb(poss)
        seqs = self.emb_dropout(seqs)

        # Causal mask AND padding mask.
        maxlen = seqs.shape[1]
        ones_matrix = torch.ones((maxlen, maxlen), dtype=torch.bool, device=self.dev)
        attention_mask_tril = torch.tril(ones_matrix)
        attention_mask_pad = (mask != 0).to(self.dev)
        attention_mask = attention_mask_tril.unsqueeze(0) & attention_mask_pad.unsqueeze(1)

        for i in range(len(self.attention_layers)):
            if self.norm_first:
                # pre-norm residual blocks
                x = self.attention_layernorms[i](seqs)
                mha_outputs, _ = self.attention_layers[i](x, x, x, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            else:
                # post-norm residual blocks
                mha_outputs, _ = self.attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = self.attention_layernorms[i](seqs + mha_outputs)
                seqs = self.forward_layernorms[i](seqs + self.forward_layers[i](seqs))

        log_feats = self.last_layernorm(seqs)

        return log_feats

    def forward(
        self, user_item, pos_seqs, neg_seqs, mask, next_mask, next_action_type, seq_feature, pos_feature, neg_feature
    ):
        """Training-time forward pass: compute logits of positive and negative samples.

        Args:
            user_item: user sequence IDs
            pos_seqs: positive-sample sequence IDs
            neg_seqs: negative-sample sequence IDs
            mask: token-type mask, 1 for item tokens, 2 for user tokens
            next_mask: next-token type mask, 1 for item tokens, 2 for user tokens
            next_action_type: next-token action type, 0 = exposure, 1 = click
            seq_feature: feature list for the sequence
            pos_feature: feature list for positive samples
            neg_feature: feature list for negative samples

        Returns:
            pos_logits: positive-sample logits, shape [batch_size, maxlen]
            neg_logits: negative-sample logits, shape [batch_size, maxlen]
        """
        log_feats = self.log2feats(user_item, mask, seq_feature)
        # Loss is only defined where the next token is an item.
        loss_mask = (next_mask == 1).to(self.dev)

        pos_embs = self.feat2emb(pos_seqs, pos_feature, include_user=False)
        neg_embs = self.feat2emb(neg_seqs, neg_feature, include_user=False)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)
        pos_logits = pos_logits * loss_mask
        neg_logits = neg_logits * loss_mask

        return pos_logits, neg_logits

    def predict(self, log_seqs, seq_feature, mask):
        """Compute the user-sequence representation for inference.

        Args:
            log_seqs: user sequence IDs
            seq_feature: feature list; each element is a per-step feature dict
            mask: token-type mask, 1 for item tokens, 2 for user tokens

        Returns:
            final_feat: user representation, shape [batch_size, hidden_units]
        """
        log_feats = self.log2feats(log_seqs, mask, seq_feature)

        # Last position summarizes the whole (left-padded) sequence.
        final_feat = log_feats[:, -1, :]

        return final_feat

    def save_item_emb(self, item_ids, retrieval_ids, feat_dict, save_path, batch_size=1024):
        """Generate candidate-pool item embeddings for retrieval.

        Args:
            item_ids: candidate item IDs (re-id form)
            retrieval_ids: candidate item IDs (retrieval IDs, numbered from 0,
                used by the retrieval script)
            feat_dict: item feature dict over the training set; key is feature
                ID, value is the feature value
            save_path: output path
            batch_size: batch size
        """
        all_embs = []

        for start_idx in tqdm(range(0, len(item_ids), batch_size), desc="Saving item embeddings"):
            end_idx = min(start_idx + batch_size, len(item_ids))

            item_seq = torch.tensor(item_ids[start_idx:end_idx], device=self.dev).unsqueeze(0)
            # NOTE(review): feat_dict is indexed by the positional index i, not
            # by item_ids[i] — confirm feat_dict is keyed by retrieval position.
            batch_feat = []
            for i in range(start_idx, end_idx):
                batch_feat.append(feat_dict[i])

            batch_feat = np.array(batch_feat, dtype=object)

            batch_emb = self.feat2emb(item_seq, [batch_feat], include_user=False).squeeze(0)

            all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))

        # Concatenate all batches and save.
        final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-1, 1)
        final_embs = np.concatenate(all_embs, axis=0)
        save_emb(final_embs, Path(save_path, 'embedding.fbin'))
        save_emb(final_ids, Path(save_path, 'id.u64bin'))
3. dataset.py - 數據處理
MyDataset
- 訓練數據集
- 處理用戶行為序列數據,支持用戶和物品交替出現的序列格式
- 實現高效的數據加載,使用文件偏移量進行隨機訪問
- 支持多種特征類型的 padding 和缺失值填充
- 實現負采樣機制用于訓練
MyTestDataset
- 測試數據集
- 繼承自訓練數據集,專門用于推理階段
- 處理冷啟動問題(訓練時未見過的特征值)
dataset.py 代碼
import json
import pickle
import struct
from pathlib import Pathimport numpy as np
import torch
from tqdm import tqdm


class MyDataset(torch.utils.data.Dataset):
    """User behavior sequence dataset.

    Sequences interleave user tokens and item tokens; each __getitem__ builds
    the left-padded model inputs (ids, token types, features) plus positive and
    negative training targets for one user.

    Args:
        data_dir: directory containing the data files
        args: global arguments

    Attributes:
        data_dir: directory containing the data files
        maxlen: maximum sequence length
        item_feat_dict: item feature dictionary
        mm_emb_ids: active multimodal embedding feature IDs
        mm_emb_dict: multimodal embedding dictionary
        itemnum: number of items
        usernum: number of users
        indexer_i_rev: item index mapping (re-id -> item_id)
        indexer_u_rev: user index mapping (re-id -> user_id)
        indexer: full index dictionary
        feature_default_value: default value per feature ID
        feature_types: feature IDs grouped into user/item sparse, array, emb and continual types
        feat_statistics: feature statistics (vocabulary size per user/item feature)
    """

    def __init__(self, data_dir, args):
        """Initialize the dataset: open the sequence file, load feature dicts and the indexer."""
        super().__init__()
        self.data_dir = Path(data_dir)
        self._load_data_and_offsets()
        self.maxlen = args.maxlen
        self.mm_emb_ids = args.mm_emb_id

        # NOTE(review): handle from json.load(open(...)) is never closed —
        # harmless for a long-lived process, but a `with` block would be cleaner.
        self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
        self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
        with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
            indexer = pickle.load(ff)
        self.itemnum = len(indexer['i'])
        self.usernum = len(indexer['u'])
        # Reverse maps: re-id -> original id.
        self.indexer_i_rev = {v: k for k, v in indexer['i'].items()}
        self.indexer_u_rev = {v: k for k, v in indexer['u'].items()}
        self.indexer = indexer

        self.feature_default_value, self.feature_types, self.feat_statistics = self._init_feat_info()

    def _load_data_and_offsets(self):
        """Open the user-sequence file and load precomputed per-line byte offsets.

        The file handle is kept open for the dataset's lifetime so __getitem__
        can seek() directly to a user's line for fast random access I/O.
        """
        self.data_file = open(self.data_dir / "seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _load_user_data(self, uid):
        """Load a single user's data from the sequence file.

        Args:
            uid: user ID (re-id)

        Returns:
            data: user sequence, a list of
                (user_id, item_id, user_feat, item_feat, action_type, timestamp)
        """
        self.data_file.seek(self.seq_offsets[uid])
        line = self.data_file.readline()
        data = json.loads(line)
        return data

    def _random_neq(self, l, r, s):
        """Sample a random item id in [l, r) not contained in `s` (negative sampling).

        Rejection-samples until the candidate is both unseen in the user's
        history and present in item_feat_dict (i.e. has features).

        Args:
            l: lower bound (inclusive)
            r: upper bound (exclusive)
            s: set of item ids to exclude

        Returns:
            t: sampled item id
        """
        t = np.random.randint(l, r)
        while t in s or str(t) not in self.item_feat_dict:
            t = np.random.randint(l, r)
        return t

    def __getitem__(self, uid):
        """Build padded model inputs and training targets for one user.

        Args:
            uid: user ID (re-id)

        Returns:
            seq: token ids of the user sequence
            pos: positive sample ids (the next truly visited item)
            neg: negative sample ids
            token_type: type of each token, 1 = item, 2 = user
            next_token_type: type of the next token, 1 = item, 2 = user
            next_action_type: action type of the next token
            seq_feat: per-position feature dicts (feature ID -> value)
            pos_feat: positive-sample feature dicts
            neg_feat: negative-sample feature dicts
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        # Flatten records into (id, feat, token_type, action_type) tokens;
        # user tokens go to the front, item tokens are appended in order.
        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, action_type, _ = record_tuple
            if u and user_feat:
                ext_user_sequence.insert(0, (u, user_feat, 2, action_type))
            if i and item_feat:
                ext_user_sequence.append((i, item_feat, 1, action_type))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        pos = np.zeros([self.maxlen + 1], dtype=np.int32)
        neg = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_action_type = np.zeros([self.maxlen + 1], dtype=np.int32)

        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        pos_feat = np.empty([self.maxlen + 1], dtype=object)
        neg_feat = np.empty([self.maxlen + 1], dtype=object)

        nxt = ext_user_sequence[-1]
        idx = self.maxlen

        # Items the user interacted with; negatives must avoid these.
        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # Left-padding: walk the sequence backwards, filling positions maxlen..0.
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_, act_type = record_tuple
            next_i, next_feat, next_type, next_act_type = nxt
            feat = self.fill_missing_feat(feat, i)
            next_feat = self.fill_missing_feat(next_feat, next_i)
            seq[idx] = i
            token_type[idx] = type_
            next_token_type[idx] = next_type
            if next_act_type is not None:
                next_action_type[idx] = next_act_type
            seq_feat[idx] = feat
            # Only item tokens yield a (positive, negative) training pair.
            if next_type == 1 and next_i != 0:
                pos[idx] = next_i
                pos_feat[idx] = next_feat
                neg_id = self._random_neq(1, self.itemnum + 1, ts)
                neg[idx] = neg_id
                neg_feat[idx] = self.fill_missing_feat(self.item_feat_dict[str(neg_id)], neg_id)
            nxt = record_tuple
            idx -= 1
            if idx == -1:
                break

        # Elementwise `== None` on the object arrays: unfilled (padding) slots
        # are replaced by the shared default-value dict.
        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        pos_feat = np.where(pos_feat == None, self.feature_default_value, pos_feat)
        neg_feat = np.where(neg_feat == None, self.feature_default_value, neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat

    def __len__(self):
        """Return the dataset length, i.e. the number of users.

        Returns:
            usernum: number of users
        """
        return len(self.seq_offsets)

    def _init_feat_info(self):
        """Initialize feature metadata: default values, type grouping and statistics.

        Returns:
            feat_default_value: default value per feature ID
            feat_types: feature IDs grouped by type name
            feat_statistics: vocabulary size per sparse/array feature ID
        """
        feat_default_value = {}
        feat_statistics = {}
        feat_types = {}
        feat_types['user_sparse'] = ['103', '104', '105', '109']
        feat_types['item_sparse'] = [
            '100',
            '117',
            '111',
            '118',
            '101',
            '102',
            '119',
            '120',
            '114',
            '112',
            '121',
            '115',
            '122',
            '116',
        ]
        feat_types['item_array'] = []
        feat_types['user_array'] = ['106', '107', '108', '110']
        feat_types['item_emb'] = self.mm_emb_ids
        feat_types['user_continual'] = []
        feat_types['item_continual'] = []

        for feat_id in feat_types['user_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_emb']:
            # Zero vector with the same dimensionality as the loaded embeddings.
            feat_default_value[feat_id] = np.zeros(
                list(self.mm_emb_dict[feat_id].values())[0].shape[0], dtype=np.float32
            )

        return feat_default_value, feat_types, feat_statistics

    def fill_missing_feat(self, feat, item_id):
        """Fill default values for features missing from a raw feature dict.

        Args:
            feat: raw feature dict (may be None)
            item_id: item ID, used to look up multimodal embeddings

        Returns:
            filled_feat: feature dict with every known feature ID present
        """
        if feat == None:
            feat = {}
        filled_feat = {}
        for k in feat.keys():
            filled_feat[k] = feat[k]

        all_feat_ids = []
        for feat_type in self.feature_types.values():
            all_feat_ids.extend(feat_type)
        missing_fields = set(all_feat_ids) - set(feat.keys())
        for feat_id in missing_fields:
            filled_feat[feat_id] = self.feature_default_value[feat_id]
        # Attach the item's multimodal embedding when one exists for it.
        for feat_id in self.feature_types['item_emb']:
            if item_id != 0 and self.indexer_i_rev[item_id] in self.mm_emb_dict[feat_id]:
                if type(self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]) == np.ndarray:
                    filled_feat[feat_id] = self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]

        return filled_feat

    @staticmethod
    def collate_fn(batch):
        """Stack a list of __getitem__ outputs into batch tensors/lists.

        Args:
            batch: list of __getitem__ results

        Returns:
            seq: user sequence ids, torch.Tensor
            pos: positive sample ids, torch.Tensor
            neg: negative sample ids, torch.Tensor
            token_type: token types, torch.Tensor
            next_token_type: next-token types, torch.Tensor
            seq_feat: per-position feature dicts, list
            pos_feat: positive-sample feature dicts, list
            neg_feat: negative-sample feature dicts, list
        """
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        pos = torch.from_numpy(np.array(pos))
        neg = torch.from_numpy(np.array(neg))
        token_type = torch.from_numpy(np.array(token_type))
        next_token_type = torch.from_numpy(np.array(next_token_type))
        next_action_type = torch.from_numpy(np.array(next_action_type))
        seq_feat = list(seq_feat)
        pos_feat = list(pos_feat)
        neg_feat = list(neg_feat)
        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat


class MyTestDataset(MyDataset):
    """Test-time dataset: same sequence layout as MyDataset, but no training targets."""

    def __init__(self, data_dir, args):
        super().__init__(data_dir, args)

    def _load_data_and_offsets(self):
        # Same mechanism as the parent class, reading the predict-split files instead.
        self.data_file = open(self.data_dir / "predict_seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _process_cold_start_feat(self, feat):
        """Handle cold-start features.

        Feature values unseen during training arrive as strings; by default they
        are mapped to 0. This policy can be replaced with something smarter.
        """
        processed_feat = {}
        for feat_id, feat_value in feat.items():
            if type(feat_value) == list:
                value_list = []
                for v in feat_value:
                    if type(v) == str:
                        value_list.append(0)
                    else:
                        value_list.append(v)
                processed_feat[feat_id] = value_list
            elif type(feat_value) == str:
                processed_feat[feat_id] = 0
            else:
                processed_feat[feat_id] = feat_value
        return processed_feat

    def __getitem__(self, uid):
        """Build padded inference inputs for one user.

        Args:
            uid: line number of the user within self.data_file

        Returns:
            seq: token ids of the user sequence
            token_type: type of each token, 1 = item, 2 = user
            seq_feat: per-position feature dicts (feature ID -> value)
            user_id: original user id (e.g. user_xxxxxx), for matching answers later
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        ext_user_sequence = []
        # NOTE(review): `user_id` is only assigned inside `if u:` — a sequence
        # with no user token would raise NameError at the return; confirm the
        # data guarantees one user token per sequence.
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, _, _ = record_tuple
            if u:
                if type(u) == str:  # a string means this is the original user_id
                    user_id = u
                else:  # an int means this is a re-id; map it back
                    user_id = self.indexer_u_rev[u]
            if u and user_feat:
                if type(u) == str:
                    u = 0
                if user_feat:
                    user_feat = self._process_cold_start_feat(user_feat)
                ext_user_sequence.insert(0, (u, user_feat, 2))

            if i and item_feat:
                # Items unseen during training keep their creative_id (which is
                # far larger than the training itemnum) rather than being 0; map
                # those out-of-vocabulary ids to 0 here.
                if i > self.itemnum:
                    i = 0
                if item_feat:
                    item_feat = self._process_cold_start_feat(item_feat)
                ext_user_sequence.append((i, item_feat, 1))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)

        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # Left-padding: walk the sequence backwards, filling positions maxlen..0.
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_ = record_tuple
            feat = self.fill_missing_feat(feat, i)
            seq[idx] = i
            token_type[idx] = type_
            seq_feat[idx] = feat
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)

        return seq, token_type, seq_feat, user_id

    def __len__(self):
        """Number of users in the predict split.

        NOTE(review): re-reads the offsets pickle on every call instead of using
        len(self.seq_offsets) like the parent — correct but wasteful.

        Returns:
            len(self.seq_offsets): number of users
        """
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            temp = pickle.load(f)
        return len(temp)

    @staticmethod
    def collate_fn(batch):
        """Stack a list of __getitem__ outputs into a batch.

        Args:
            batch: list of __getitem__ results

        Returns:
            seq: user sequence ids, torch.Tensor
            token_type: token types, torch.Tensor
            seq_feat: per-position feature dicts, list
            user_id: user id strings
        """
        seq, token_type, seq_feat, user_id = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        token_type = torch.from_numpy(np.array(token_type))
        seq_feat = list(seq_feat)
        return seq, token_type, seq_feat, user_id


def save_emb(emb, save_path):
    """Save an embedding matrix as a binary file.

    Layout: two uint32 header values (num_points, num_dimensions) followed by
    the raw array bytes.

    Args:
        emb: embeddings to save, shape [num_points, num_dimensions]
        save_path: output path
    """
    num_points = emb.shape[0]  # number of data points
    num_dimensions = emb.shape[1]  # vector dimensionality
    print(f'saving {save_path}')
    with open(Path(save_path), 'wb') as f:
        f.write(struct.pack('II', num_points, num_dimensions))
        emb.tofile(f)


def load_mm_emb(mm_path, feat_ids):
    """Load multimodal embedding features.

    Feature '81' is stored as a single pickle; the other feature IDs are
    directories of JSON-lines files with 'anonymous_cid' and 'emb' fields.

    Args:
        mm_path: root path of the multimodal embeddings
        feat_ids: multimodal feature IDs to load

    Returns:
        mm_emb_dict: {feature ID: {item ID: embedding ndarray}}
    """
    SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
    mm_emb_dict = {}
    for feat_id in tqdm(feat_ids, desc='Loading mm_emb'):
        shape = SHAPE_DICT[feat_id]
        emb_dict = {}
        if feat_id != '81':
            try:
                base_path = Path(mm_path, f'emb_{feat_id}_{shape}')
                for json_file in base_path.glob('*.json'):
                    with open(json_file, 'r', encoding='utf-8') as file:
                        for line in file:
                            data_dict_origin = json.loads(line.strip())
                            insert_emb = data_dict_origin['emb']
                            if isinstance(insert_emb, list):
                                insert_emb = np.array(insert_emb, dtype=np.float32)
                            data_dict = {data_dict_origin['anonymous_cid']: insert_emb}
                            emb_dict.update(data_dict)
            except Exception as e:
                # Best-effort load: a malformed file is reported and skipped.
                print(f"transfer error: {e}")
        if feat_id == '81':
            with open(Path(mm_path, f'emb_{feat_id}_{shape}.pkl'), 'rb') as f:
                emb_dict = pickle.load(f)
        mm_emb_dict[feat_id] = emb_dict
        print(f'Loaded #{feat_id} mm_emb')
    return mm_emb_dict
4. model_rqvae.py - 多模態特征壓縮
實現了 RQ-VAE(Residual Quantized Variational AutoEncoder)框架,用于將高維多模態 embedding 轉換為離散的語義 ID:
核心組件:
- RQEncoder / RQDecoder:編碼器和解碼器
- VQEmbedding:向量量化模塊,支持 K-means 初始化
- RQ:殘差量化器,實現多級量化
- RQVAE:完整的 RQ-VAE 模型
量化方法:
- 支持標準 K-means 和平衡 K-means 聚類
- 使用余弦距離或 L2 距離進行向量量化
- 通過殘差量化實現更精確的特征表示
model_rqvae.py 代碼
"""
選手可參考以下流程,使用提供的 RQ-VAE 框架代碼將多模態emb數據轉換為Semantic Id:
1. 使用 MmEmbDataset 讀取不同特征 ID 的多模態emb數據.
2. 訓練 RQ-VAE 模型, 訓練完成后將數據轉換為Semantic Id.
3. 參照 Item Sparse 特征格式處理Semantic Id,作為新特征加入Baseline模型訓練.
"""import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans# class MmEmbDataset(torch.utils.data.Dataset):
# """
# Build Dataset for RQ-VAE Training# Args:
# data_dir = os.environ.get('TRAIN_DATA_PATH')
# feature_id = MM emb ID
# """# def __init__(self, data_dir, feature_id):
# super().__init__()
# self.data_dir = Path(data_dir)
# self.mm_emb_id = [feature_id]
# self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_id)# self.mm_emb = self.mm_emb_dict[self.mm_emb_id[0]]
# self.tid_list, self.emb_list = list(self.mm_emb.keys()), list(self.mm_emb.values())
# self.emb_list = [torch.tensor(emb, dtype=torch.float32) for emb in self.emb_list]# assert len(self.tid_list) == len(self.emb_list)
# self.item_cnt = len(self.tid_list)# def __getitem__(self, index):
# tid = torch.tensor(self.tid_list[index], dtype=torch.long)
# emb = self.emb_list[index]
# return tid, emb# def __len__(self):
# return self.item_cnt# @staticmethod
# def collate_fn(batch):
# tid, emb = zip(*batch)# tid_batch, emb_batch = torch.stack(tid, dim=0), torch.stack(emb, dim=0)
# return tid_batch, emb_batch## Kmeans
def kmeans(data, n_clusters, kmeans_iters):
    """Cluster a tensor with sklearn KMeans and return (centers, labels) as tensors.

    n_init="auto": sklearn uses 10 restarts when n_clusters <= 10, else 1.

    Args:
        data: input tensor, shape [num_samples, dim]
        n_clusters: number of clusters
        kmeans_iters: maximum number of k-means iterations

    Returns:
        (cluster centers tensor, per-sample label tensor)
    """
    km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")
    # sklearn only supports CPU, so move the tensor off any accelerator first.
    np_data = data.detach().cpu().numpy()
    km.fit(np_data)
    return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)


## Balanced Kmeans
class BalancedKmeans(torch.nn.Module):
    """Capacity-constrained k-means: each cluster holds at most N // num_clusters samples."""

    def __init__(self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str):
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device

        self._codebook = None

    def _compute_distances(self, data):
        # Pairwise L2 distances between samples and current codebook entries.
        return torch.cdist(data, self._codebook)

    def _assign_clusters(self, dist):
        n_samples = dist.shape[0]
        capacity = n_samples // self.num_clusters
        labels = torch.zeros(n_samples, dtype=torch.long, device=self.device)
        fill = torch.zeros(self.num_clusters, dtype=torch.long, device=self.device)

        # Greedy: each sample (in index order) takes its nearest cluster that
        # still has room, enforcing the per-cluster capacity.
        preference = torch.argsort(dist, dim=-1)
        for row in range(n_samples):
            for rank in range(self.num_clusters):
                candidate = preference[row, rank]
                if fill[candidate] < capacity:
                    labels[row] = candidate
                    fill[candidate] += 1
                    break

        return labels

    def _update_codebook(self, data, samples_labels):
        # New center = mean of members; empty clusters keep their old center.
        centers = []
        for cid in range(self.num_clusters):
            members = data[samples_labels == cid]
            centers.append(members.mean(dim=0) if len(members) > 0 else self._codebook[cid])
        return torch.stack(centers)

    def fit(self, data):
        """Run balanced k-means; returns (codebook, labels of the last iteration)."""
        num_emb, _ = data.shape
        data = data.to(self.device)

        # Initialize the codebook from randomly chosen samples.
        chosen = torch.randperm(num_emb)[: self.num_clusters]
        self._codebook = data[chosen].clone()

        for _ in range(self.kmeans_iters):
            samples_labels = self._assign_clusters(self._compute_distances(data))
            refreshed = self._update_codebook(data, samples_labels)
            # Converged: keep the previous codebook (matches the update-after-check order).
            if torch.norm(refreshed - self._codebook) < self.tolerance:
                break
            self._codebook = refreshed

        return self._codebook, samples_labels

    def predict(self, data):
        """Assign capacity-constrained labels to new data with the fitted codebook."""
        data = data.to(self.device)
        return self._assign_clusters(self._compute_distances(data))


## Base RQVAE
class RQEncoder(torch.nn.Module):
    """MLP encoder: input_dim -> hidden_channels... -> latent_dim, ReLU after every layer."""

    def __init__(self, input_dim: int, hidden_channels: list, latent_dim: int):
        super().__init__()
        # One (Linear, ReLU) stage per consecutive dimension pair.
        dims = [input_dim] + list(hidden_channels) + [latent_dim]
        self.stages = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(d_in, d_out), torch.nn.ReLU())
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        out = x
        for block in self.stages:
            out = block(out)
        return out


class RQDecoder(torch.nn.Module):
    """MLP decoder: latent_dim -> hidden_channels... -> output_dim, ReLU after every layer."""

    def __init__(self, latent_dim: int, hidden_channels: list, output_dim: int):
        super().__init__()
        dims = [latent_dim] + list(hidden_channels) + [output_dim]
        self.stages = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(d_in, d_out), torch.nn.ReLU())
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        out = x
        for block in self.stages:
            out = block(out)
        return out


## Generate semantic id
class VQEmbedding(torch.nn.Embedding):
    """A single vector-quantization level backed by an nn.Embedding table.

    The codebook used for nearest-neighbour search is (re)built from the input
    on every forward pass, via k-means ('kmeans'), balanced k-means ('bkmeans')
    or random initialization (any other value).

    NOTE(review): the quantized output comes from the inherited embedding
    weight (super().forward), which is never synchronized with self.codebook —
    confirm this is intentional.
    """

    def __init__(
        self,
        num_clusters,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super(VQEmbedding, self).__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device

    def _create_codebook(self, data):
        # Build the codebook from the batch and store it as a fresh Parameter.
        if self.kmeans_method == 'kmeans':
            raw, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == 'bkmeans':
            clusterer = BalancedKmeans(
                num_clusters=self.num_clusters,
                kmeans_iters=self.kmeans_iters,
                tolerance=1e-4,
                device=self.device,
            )
            raw, _ = clusterer.fit(data)
        else:
            raw = torch.randn(self.num_clusters, self.codebook_emb_dim)
        raw = raw.to(self.device)
        assert raw.shape == (self.num_clusters, self.codebook_emb_dim)
        self.codebook = torch.nn.Parameter(raw)

    @torch.no_grad()
    def _compute_distances(self, data):
        # Distance from each row of `data` to every codebook entry.
        codebook_t = self.codebook.t()
        assert codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == 'cosine':
            unit_data = F.normalize(data, p=2, dim=-1)
            unit_codebook = F.normalize(codebook_t, p=2, dim=0)
            return 1 - torch.mm(unit_data, unit_codebook)
        # l2, via the ||a||^2 + ||b||^2 - 2ab expansion
        sq_data = data.pow(2).sum(dim=-1, keepdim=True)
        sq_codebook = codebook_t.pow(2).sum(dim=0, keepdim=True)
        return torch.addmm(sq_data + sq_codebook, data, codebook_t, beta=1.0, alpha=-2.0)

    @torch.no_grad()
    def _create_semantic_id(self, data):
        # Semantic id = index of the closest codebook entry.
        return torch.argmin(self._compute_distances(data), dim=-1)

    def _update_emb(self, _semantic_id):
        # Look the semantic ids up in the inherited embedding table.
        return super().forward(_semantic_id)

    def forward(self, data):
        self._create_codebook(data)
        _semantic_id = self._create_semantic_id(data)
        return self._update_emb(_semantic_id), _semantic_id


## Residual Quantizer
class RQ(torch.nn.Module):
    """Residual quantizer: a stack of VQEmbedding levels applied to residuals.

    Args:
        num_codebooks, codebook_size, codebook_emb_dim -> codebook construction
        shared_codebook -> whether all levels use codebook_size[0]
        kmeans_method, kmeans_iters -> codebook initialization
        distances_method -> how semantic ids are assigned
        loss_beta -> commitment-term weight in the RQ-VAE loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook

        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        # NOTE(review): when shared_codebook is set, this still builds
        # num_codebooks distinct VQEmbedding modules (all sized
        # codebook_size[0]) rather than reusing one module — confirm intended.
        if self.shared_codebook:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[0],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for _ in range(self.num_codebooks)
                ]
            )
        else:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    def quantize(self, data):
        """Quantize `data` level by level on the residuals.

        i-th level: input[i] (i.e. res[i-1]) = VQ[i] + res[i]
        vq_emb_list: [vq1, vq1+vq2, ...]
        res_emb_list: [res1, res2, ...]
        semantic_id_list: [vq1_sid, vq2_sid, ...]

        Returns:
            vq_emb_list[0] -> [batch_size, codebook_emb_dim]
            semantic_id_list -> [batch_size, num_codebooks]

        NOTE(review): `res_emb` and `vq_emb_aggre` are mutated in place and the
        lists append references to those same tensors, so after the loop every
        element of res_emb_list (and of vq_emb_list) aliases the final tensor.
        Confirm this is intended; appending .clone() would keep per-level values.
        """
        res_emb = data.detach().clone()

        vq_emb_list, res_emb_list = [], []
        semantic_id_list = []
        vq_emb_aggre = torch.zeros_like(data)

        for i in range(self.num_codebooks):
            vq_emb, _semantic_id = self.vqmodules[i](res_emb)

            res_emb -= vq_emb
            vq_emb_aggre += vq_emb

            res_emb_list.append(res_emb)
            vq_emb_list.append(vq_emb_aggre)
            semantic_id_list.append(_semantic_id.unsqueeze(dim=-1))

        semantic_id_list = torch.cat(semantic_id_list, dim=-1)
        return vq_emb_list, res_emb_list, semantic_id_list

    def _rqvae_loss(self, vq_emb_list, res_emb_list):
        """VQ loss summed over levels: codebook term + loss_beta * commitment term."""
        rqvae_loss_list = []
        for idx, quant in enumerate(vq_emb_list):
            # stop gradient
            loss1 = (res_emb_list[idx].detach() - quant).pow(2.0).mean()
            loss2 = (res_emb_list[idx] - quant.detach()).pow(2.0).mean()
            partial_loss = loss1 + self.loss_beta * loss2
            rqvae_loss_list.append(partial_loss)

        rqvae_loss = torch.sum(torch.stack(rqvae_loss_list))
        return rqvae_loss

    def forward(self, data):
        """Quantize and compute the RQ loss in one pass."""
        vq_emb_list, res_emb_list, semantic_id_list = self.quantize(data)
        rqvae_loss = self._rqvae_loss(vq_emb_list, res_emb_list)

        return vq_emb_list, semantic_id_list, rqvae_loss


class RQVAE(torch.nn.Module):
    """Full RQ-VAE: MLP encoder -> residual quantizer -> MLP decoder."""

    def __init__(
        self,
        input_dim: int,
        hidden_channels: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
        # Decoder mirrors the encoder (hidden channels reversed).
        self.decoder = RQDecoder(latent_dim, hidden_channels[::-1], input_dim).to(device)
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            device,
        ).to(device)

    def encode(self, x):
        """Encode the input into the latent space."""
        return self.encoder(x)

    def decode(self, z_vq):
        """Decode a quantized latent; a list input uses its last (full aggregate) entry."""
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(self, x_hat, x_gt, rqvae_loss):
        """Reconstruction MSE plus the quantizer loss."""
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def _get_codebook(self, x_gt):
        """Return only the semantic ids for `x_gt` (no reconstruction)."""
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        return semantic_id_list

    def forward(self, x_gt):
        """Full pass: encode, quantize, decode, and compute all losses."""
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss
5. run.sh - 運行腳本
簡單的 bash 腳本,用于啟動訓練程序。
run.sh 代碼
#!/bin/bash
# Entry script: launches the training program from the runtime workspace.

# show ${RUNTIME_SCRIPT_DIR}
echo ${RUNTIME_SCRIPT_DIR}
# enter train workspace
cd ${RUNTIME_SCRIPT_DIR}
# write your code below
python -u main.py
技術特點
- 高效注意力機制:使用 Flash Attention 優化計算效率
- 多模態融合:支持文本、圖像等多種模態的 embedding 特征
- 特征工程:支持稀疏、密集、數組等多種特征類型
- 序列建模:同時建模用戶和物品的交互序列
- 可擴展性:支持大規模物品庫的 embedding 保存和檢索
數據流程
- 訓練階段:讀取用戶序列 → 特征 embedding → Transformer 編碼 → 計算正負樣本 loss
- 推理階段:生成用戶表征 → 保存物品 embedding → 進行向量檢索推薦
- 多模態處理:原始 embedding → RQ-VAE 壓縮 → 語義 ID → 作為新特征加入模型