2025 Tencent Advertising Algorithm Competition: Baseline Project Walkthrough

Project Overview

The baseline for the 2025 Tencent Advertising Algorithm Competition is a simple sequential recommendation system. It models user-item interaction sequences and leverages multimodal features (text, image, and other embeddings) to improve recommendation quality.

Core Files

1. main.py - Main Training Script

  • Drives the overall model-training workflow
  • Handles argument parsing, data loading, model initialization, and the training loop
  • Supports resuming from checkpoints and an inference-only mode
  • Logs training metrics with TensorBoard
main.py code:

```python
import argparse
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MyDataset
from model import BaselineModel


def get_args():
    parser = argparse.ArgumentParser()

    # Train params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline Model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # MMemb Feature ID
    parser.add_argument('--mm_emb_id', nargs='+', default=['81'], type=str, choices=[str(s) for s in range(81, 87)])

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))
    # global dataset
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass

    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0

    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1

    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6 :]
            epoch_start_idx = int(tail[: tail.find('.')]) + 1
        except:
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!')

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(
                neg_logits.shape, device=args.device
            )
            optimizer.zero_grad()
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            log_json = json.dumps(
                {'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()}
            )
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)

            writer.add_scalar('Loss/train', loss.item(), global_step)

            global_step += 1

            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss_sum = 0
        for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(
                neg_logits.shape, device=args.device
            )
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)
        writer.add_scalar('Loss/valid', valid_loss_sum, global_step)

        save_dir = Path(os.environ.get('TRAIN_CKPT_PATH'), f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}")
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()
```

2. model.py - Core Model Implementation

BaselineModel - Main Recommendation Model

A Transformer-based sequential recommendation model with the following characteristics:

Model Architecture

  • Uses FlashMultiHeadAttention for efficient multi-head attention
  • Uses PointWiseFeedForward as the feed-forward network
  • Supports several feature types: sparse features, array features, continual features, and multimodal embedding features

Feature Processing

  • User features: sparse features (103, 104, 105, 109) and array features (106, 107, 108, 110)
  • Item features: sparse features (100, 117, 111, etc.) and multimodal embedding features (81-86)
  • The feat2emb method converts the different feature types into a unified embedding representation

Core Methods

  • log2feats: converts a user sequence into feature representations
  • forward: computes positive/negative-sample logits during training
  • predict: produces the user representation at inference time
  • save_item_emb: saves item embeddings for retrieval
model.py code:

```python
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dataset import save_emb


class FlashMultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_units, num_heads, dropout_rate):
        super(FlashMultiHeadAttention, self).__init__()

        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate

        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"

        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        batch_size, seq_len, _ = query.size()

        # Compute Q, K, V
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Reshape into multi-head format
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch 2.0+: use the built-in Flash Attention
            attn_output = F.scaled_dot_product_attention(
                Q, K, V, dropout_p=self.dropout_rate if self.training else 0.0, attn_mask=attn_mask.unsqueeze(1)
            )
        else:
            # Fall back to standard attention
            scale = (self.head_dim) ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)

        # Reshape back to the original format
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)

        # Final linear projection
        output = self.out_linear(attn_output)
        return output, None


class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1D requires (N, C, Length)
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        return outputs


class BaselineModel(torch.nn.Module):
    """
    Args:
        user_num: number of users
        item_num: number of items
        feat_statistics: feature statistics; key is feature ID, value is feature cardinality
        feat_types: type of each feature; key is type name, value is the list of feature IDs it
            contains, covering user/item sparse, array, emb, and continual types
        args: global arguments

    Attributes:
        user_num: number of users
        item_num: number of items
        dev: device
        norm_first: whether to normalize before attention (pre-norm)
        maxlen: maximum sequence length
        item_emb: item embedding table
        user_emb: user embedding table
        sparse_emb: embedding tables for sparse features
        emb_transform: linear projections for multimodal features
        userdnn: fully connected layer applied to the concatenated user features
        itemdnn: fully connected layer applied to the concatenated item features
    """

    def __init__(self, user_num, item_num, feat_statistics, feat_types, args):
        super(BaselineModel, self).__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first
        self.maxlen = args.maxlen
        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch

        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.user_emb = torch.nn.Embedding(self.user_num + 1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(2 * args.maxlen + 1, args.hidden_units, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.sparse_emb = torch.nn.ModuleDict()
        self.emb_transform = torch.nn.ModuleDict()

        self.attention_layernorms = torch.nn.ModuleList()  # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self._init_feat_info(feat_statistics, feat_types)

        userdim = args.hidden_units * (len(self.USER_SPARSE_FEAT) + 1 + len(self.USER_ARRAY_FEAT)) + len(
            self.USER_CONTINUAL_FEAT
        )
        itemdim = (
            args.hidden_units * (len(self.ITEM_SPARSE_FEAT) + 1 + len(self.ITEM_ARRAY_FEAT))
            + len(self.ITEM_CONTINUAL_FEAT)
            + args.hidden_units * len(self.ITEM_EMB_FEAT)
        )

        self.userdnn = torch.nn.Linear(userdim, args.hidden_units)
        self.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            # Optimization: FlashAttention instead of standard attention
            new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

        for k in self.USER_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.USER_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_EMB_FEAT:
            self.emb_transform[k] = torch.nn.Linear(self.ITEM_EMB_FEAT[k], args.hidden_units)

    def _init_feat_info(self, feat_statistics, feat_types):
        """
        Group feature statistics (cardinalities) by feature type into separate dicts,
        which makes it easy to declare the sparse-feature embedding tables.

        Args:
            feat_statistics: feature statistics; key is feature ID, value is feature cardinality
            feat_types: type of each feature; key is type name, value is the list of feature IDs it contains
        """
        self.USER_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['user_sparse']}
        self.USER_CONTINUAL_FEAT = feat_types['user_continual']
        self.ITEM_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['item_sparse']}
        self.ITEM_CONTINUAL_FEAT = feat_types['item_continual']
        self.USER_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['user_array']}
        self.ITEM_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['item_array']}
        EMB_SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
        # Records the dimensionality of each multimodal feature
        self.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] for k in feat_types['item_emb']}

    def feat2tensor(self, seq_feature, k):
        """
        Args:
            seq_feature: list of sequence features; each element is the feature dict at that step,
                shaped [batch_size, maxlen]
            k: feature ID

        Returns:
            batch_data: tensor of feature values, shaped [batch_size, maxlen, max_array_len (if array)]
        """
        batch_size = len(seq_feature)

        if k in self.ITEM_ARRAY_FEAT or k in self.USER_ARRAY_FEAT:
            # Array-type features need padding before conversion to a tensor
            max_array_len = 0
            max_seq_len = 0

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                max_seq_len = max(max_seq_len, len(seq_data))
                max_array_len = max(max_array_len, max(len(item_data) for item_data in seq_data))

            batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                for j, item_data in enumerate(seq_data):
                    actual_len = min(len(item_data), max_array_len)
                    batch_data[i, j, :actual_len] = item_data[:actual_len]

            return torch.from_numpy(batch_data).to(self.dev)
        else:
            # Sparse-type features convert directly
            max_seq_len = max(len(seq_feature[i]) for i in range(batch_size))
            batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                batch_data[i] = seq_data

            return torch.from_numpy(batch_data).to(self.dev)

    def feat2emb(self, seq, feature_array, mask=None, include_user=False):
        """
        Args:
            seq: sequence IDs
            feature_array: feature list; each element is the feature dict at that step
            mask: mask, 1 for item, 2 for user
            include_user: whether to process user features. Disabled in two cases:
                1) converting positive/negative-sample features during training (both are items);
                2) generating the candidate-pool item embeddings.

        Returns:
            seqs_emb: embedding of the sequence features
        """
        seq = seq.to(self.dev)
        # Pre-compute embeddings
        if include_user:
            user_mask = (mask == 2).to(self.dev)
            item_mask = (mask == 1).to(self.dev)
            user_embedding = self.user_emb(user_mask * seq)
            item_embedding = self.item_emb(item_mask * seq)
            item_feat_list = [item_embedding]
            user_feat_list = [user_embedding]
        else:
            item_embedding = self.item_emb(seq)
            item_feat_list = [item_embedding]

        # Batch-process all feature types
        all_feat_types = [
            (self.ITEM_SPARSE_FEAT, 'item_sparse', item_feat_list),
            (self.ITEM_ARRAY_FEAT, 'item_array', item_feat_list),
            (self.ITEM_CONTINUAL_FEAT, 'item_continual', item_feat_list),
        ]
        if include_user:
            all_feat_types.extend(
                [
                    (self.USER_SPARSE_FEAT, 'user_sparse', user_feat_list),
                    (self.USER_ARRAY_FEAT, 'user_array', user_feat_list),
                    (self.USER_CONTINUAL_FEAT, 'user_continual', user_feat_list),
                ]
            )

        # Batch-process each feature type
        for feat_dict, feat_type, feat_list in all_feat_types:
            if not feat_dict:
                continue
            for k in feat_dict:
                tensor_feature = self.feat2tensor(feature_array, k)
                if feat_type.endswith('sparse'):
                    feat_list.append(self.sparse_emb[k](tensor_feature))
                elif feat_type.endswith('array'):
                    feat_list.append(self.sparse_emb[k](tensor_feature).sum(2))
                elif feat_type.endswith('continual'):
                    feat_list.append(tensor_feature.unsqueeze(2))

        for k in self.ITEM_EMB_FEAT:
            # Collect all data into numpy first, then batch-convert
            batch_size = len(feature_array)
            emb_dim = self.ITEM_EMB_FEAT[k]
            seq_len = len(feature_array[0])

            # Pre-allocate the buffer
            batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)
            for i, seq in enumerate(feature_array):
                for j, item in enumerate(seq):
                    if k in item:
                        batch_emb_data[i, j] = item[k]

            # Batch-convert and transfer to GPU
            tensor_feature = torch.from_numpy(batch_emb_data).to(self.dev)
            item_feat_list.append(self.emb_transform[k](tensor_feature))

        # Merge features
        all_item_emb = torch.cat(item_feat_list, dim=2)
        all_item_emb = torch.relu(self.itemdnn(all_item_emb))
        if include_user:
            all_user_emb = torch.cat(user_feat_list, dim=2)
            all_user_emb = torch.relu(self.userdnn(all_user_emb))
            seqs_emb = all_item_emb + all_user_emb
        else:
            seqs_emb = all_item_emb
        return seqs_emb

    def log2feats(self, log_seqs, mask, seq_feature):
        """
        Args:
            log_seqs: sequence IDs
            mask: token-type mask, 1 for item token, 2 for user token
            seq_feature: list of sequence features; each element is the feature dict at that step

        Returns:
            seqs_emb: sequence embedding, shaped [batch_size, maxlen, hidden_units]
        """
        batch_size = log_seqs.shape[0]
        maxlen = log_seqs.shape[1]
        seqs = self.feat2emb(log_seqs, seq_feature, mask=mask, include_user=True)
        seqs *= self.item_emb.embedding_dim**0.5
        poss = torch.arange(1, maxlen + 1, device=self.dev).unsqueeze(0).expand(batch_size, -1).clone()
        poss *= log_seqs != 0
        seqs += self.pos_emb(poss)
        seqs = self.emb_dropout(seqs)

        maxlen = seqs.shape[1]
        ones_matrix = torch.ones((maxlen, maxlen), dtype=torch.bool, device=self.dev)
        attention_mask_tril = torch.tril(ones_matrix)
        attention_mask_pad = (mask != 0).to(self.dev)
        attention_mask = attention_mask_tril.unsqueeze(0) & attention_mask_pad.unsqueeze(1)

        for i in range(len(self.attention_layers)):
            if self.norm_first:
                x = self.attention_layernorms[i](seqs)
                mha_outputs, _ = self.attention_layers[i](x, x, x, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            else:
                mha_outputs, _ = self.attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = self.attention_layernorms[i](seqs + mha_outputs)
                seqs = self.forward_layernorms[i](seqs + self.forward_layers[i](seqs))

        log_feats = self.last_layernorm(seqs)

        return log_feats

    def forward(
        self, user_item, pos_seqs, neg_seqs, mask, next_mask, next_action_type, seq_feature, pos_feature, neg_feature
    ):
        """
        Called during training to compute logits for positive and negative samples.

        Args:
            user_item: user sequence IDs
            pos_seqs: positive-sample sequence IDs
            neg_seqs: negative-sample sequence IDs
            mask: token-type mask, 1 for item token, 2 for user token
            next_mask: next-token type mask, 1 for item token, 2 for user token
            next_action_type: next-token action type, 0 for impression, 1 for click
            seq_feature: list of sequence features; each element is the feature dict at that step
            pos_feature: list of positive-sample features
            neg_feature: list of negative-sample features

        Returns:
            pos_logits: positive-sample logits, shaped [batch_size, maxlen]
            neg_logits: negative-sample logits, shaped [batch_size, maxlen]
        """
        log_feats = self.log2feats(user_item, mask, seq_feature)
        loss_mask = (next_mask == 1).to(self.dev)

        pos_embs = self.feat2emb(pos_seqs, pos_feature, include_user=False)
        neg_embs = self.feat2emb(neg_seqs, neg_feature, include_user=False)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)
        pos_logits = pos_logits * loss_mask
        neg_logits = neg_logits * loss_mask

        return pos_logits, neg_logits

    def predict(self, log_seqs, seq_feature, mask):
        """
        Compute the user-sequence representation.

        Args:
            log_seqs: user sequence IDs
            seq_feature: list of sequence features; each element is the feature dict at that step
            mask: token-type mask, 1 for item token, 2 for user token

        Returns:
            final_feat: user-sequence representation, shaped [batch_size, hidden_units]
        """
        log_feats = self.log2feats(log_seqs, mask, seq_feature)

        final_feat = log_feats[:, -1, :]

        return final_feat

    def save_item_emb(self, item_ids, retrieval_ids, feat_dict, save_path, batch_size=1024):
        """
        Generate candidate-pool item embeddings for retrieval.

        Args:
            item_ids: candidate item IDs (re-id form)
            retrieval_ids: candidate item IDs (retrieval IDs, numbered from 0, used by the retrieval script)
            feat_dict: feature dict of all items in the training set; key is feature ID, value is feature value
            save_path: save path
            batch_size: batch size
        """
        all_embs = []

        for start_idx in tqdm(range(0, len(item_ids), batch_size), desc="Saving item embeddings"):
            end_idx = min(start_idx + batch_size, len(item_ids))

            item_seq = torch.tensor(item_ids[start_idx:end_idx], device=self.dev).unsqueeze(0)
            batch_feat = []
            for i in range(start_idx, end_idx):
                batch_feat.append(feat_dict[i])

            batch_feat = np.array(batch_feat, dtype=object)

            batch_emb = self.feat2emb(item_seq, [batch_feat], include_user=False).squeeze(0)

            all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))

        # Concatenate all batches and save
        final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-1, 1)
        final_embs = np.concatenate(all_embs, axis=0)
        save_emb(final_embs, Path(save_path, 'embedding.fbin'))
        save_emb(final_ids, Path(save_path, 'id.u64bin'))
```

3. dataset.py - Data Processing

MyDataset - training dataset
  • Processes user-behavior sequences in a format where user and item tokens interleave
  • Loads data efficiently, using precomputed file offsets for random access
  • Pads sequences and fills in missing values for all feature types
  • Implements negative sampling for training
MyTestDataset - test dataset
  • Inherits from the training dataset and is used only at inference time
  • Handles cold start (feature values unseen during training)
dataset.py code:

```python
import json
import pickle
import struct
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm


class MyDataset(torch.utils.data.Dataset):
    """
    User-sequence dataset.

    Args:
        data_dir: data file directory
        args: global arguments

    Attributes:
        data_dir: data file directory
        maxlen: maximum length
        item_feat_dict: item feature dict
        mm_emb_ids: activated mm_emb feature IDs
        mm_emb_dict: multimodal feature dict
        itemnum: number of items
        usernum: number of users
        indexer_i_rev: item index dict (reid -> item_id)
        indexer_u_rev: user index dict (reid -> user_id)
        indexer: index dict
        feature_default_value: feature default values
        feature_types: feature types, split into user/item sparse, array, emb, continual
        feat_statistics: feature statistics, i.e. user/item feature cardinalities
    """

    def __init__(self, data_dir, args):
        """Initialize the dataset."""
        super().__init__()
        self.data_dir = Path(data_dir)
        self._load_data_and_offsets()
        self.maxlen = args.maxlen
        self.mm_emb_ids = args.mm_emb_id

        self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
        self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
        with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
            indexer = pickle.load(ff)
        self.itemnum = len(indexer['i'])
        self.usernum = len(indexer['u'])
        self.indexer_i_rev = {v: k for k, v in indexer['i'].items()}
        self.indexer_u_rev = {v: k for k, v in indexer['u'].items()}
        self.indexer = indexer

        self.feature_default_value, self.feature_types, self.feat_statistics = self._init_feat_info()

    def _load_data_and_offsets(self):
        """Load user-sequence data and the precomputed file offset of each line, for fast random-access I/O."""
        self.data_file = open(self.data_dir / "seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _load_user_data(self, uid):
        """
        Load one user's data from the data file.

        Args:
            uid: user ID (reid)

        Returns:
            data: user-sequence data, formatted as
                [(user_id, item_id, user_feat, item_feat, action_type, timestamp)]
        """
        self.data_file.seek(self.seq_offsets[uid])
        line = self.data_file.readline()
        data = json.loads(line)
        return data

    def _random_neq(self, l, r, s):
        """
        Draw a random integer not in sequence s, used for negative sampling during training.

        Args:
            l: lower bound of the random integer
            r: upper bound of the random integer
            s: the sequence

        Returns:
            t: a random integer not in s
        """
        t = np.random.randint(l, r)
        while t in s or str(t) not in self.item_feat_dict:
            t = np.random.randint(l, r)
        return t

    def __getitem__(self, uid):
        """
        Fetch one user's data, apply padding, and produce the format the model expects.

        Args:
            uid: user ID (reid)

        Returns:
            seq: user sequence IDs
            pos: positive-sample IDs (the next actually visited item)
            neg: negative-sample IDs
            token_type: sequence token types, 1 for item, 2 for user
            next_token_type: next-token types, 1 for item, 2 for user
            seq_feat: sequence features; each element is a dict of feature ID -> value
            pos_feat: positive-sample features; each element is a dict of feature ID -> value
            neg_feat: negative-sample features; each element is a dict of feature ID -> value
        """
        user_sequence = self._load_user_data(uid)  # Load user data lazily

        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, action_type, _ = record_tuple
            if u and user_feat:
                ext_user_sequence.insert(0, (u, user_feat, 2, action_type))
            if i and item_feat:
                ext_user_sequence.append((i, item_feat, 1, action_type))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        pos = np.zeros([self.maxlen + 1], dtype=np.int32)
        neg = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_action_type = np.zeros([self.maxlen + 1], dtype=np.int32)

        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        pos_feat = np.empty([self.maxlen + 1], dtype=object)
        neg_feat = np.empty([self.maxlen + 1], dtype=object)

        nxt = ext_user_sequence[-1]
        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # Left-padding: walk backwards and fill the sequence up to length maxlen + 1
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_, act_type = record_tuple
            next_i, next_feat, next_type, next_act_type = nxt
            feat = self.fill_missing_feat(feat, i)
            next_feat = self.fill_missing_feat(next_feat, next_i)
            seq[idx] = i
            token_type[idx] = type_
            next_token_type[idx] = next_type
            if next_act_type is not None:
                next_action_type[idx] = next_act_type
            seq_feat[idx] = feat
            if next_type == 1 and next_i != 0:
                pos[idx] = next_i
                pos_feat[idx] = next_feat
                neg_id = self._random_neq(1, self.itemnum + 1, ts)
                neg[idx] = neg_id
                neg_feat[idx] = self.fill_missing_feat(self.item_feat_dict[str(neg_id)], neg_id)
            nxt = record_tuple
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        pos_feat = np.where(pos_feat == None, self.feature_default_value, pos_feat)
        neg_feat = np.where(neg_feat == None, self.feature_default_value, neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat

    def __len__(self):
        """
        Returns:
            usernum: number of users
        """
        return len(self.seq_offsets)

    def _init_feat_info(self):
        """
        Initialize feature info: default values and feature types.

        Returns:
            feat_default_value: default value per feature ID
            feat_types: feature types; key is type name, value is the list of feature IDs it contains
        """
        feat_default_value = {}
        feat_statistics = {}
        feat_types = {}
        feat_types['user_sparse'] = ['103', '104', '105', '109']
        feat_types['item_sparse'] = [
            '100',
            '117',
            '111',
            '118',
            '101',
            '102',
            '119',
            '120',
            '114',
            '112',
            '121',
            '115',
            '122',
            '116',
        ]
        feat_types['item_array'] = []
        feat_types['user_array'] = ['106', '107', '108', '110']
        feat_types['item_emb'] = self.mm_emb_ids
        feat_types['user_continual'] = []
        feat_types['item_continual'] = []

        for feat_id in feat_types['user_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_emb']:
            feat_default_value[feat_id] = np.zeros(
                list(self.mm_emb_dict[feat_id].values())[0].shape[0], dtype=np.float32
            )

        return feat_default_value, feat_types, feat_statistics

    def fill_missing_feat(self, feat, item_id):
        """
        Fill default values for features missing from the raw data.

        Args:
            feat: feature dict
            item_id: item ID

        Returns:
            filled_feat: feature dict after filling
        """
        if feat == None:
            feat = {}
        filled_feat = {}
        for k in feat.keys():
            filled_feat[k] = feat[k]

        all_feat_ids = []
        for feat_type in self.feature_types.values():
            all_feat_ids.extend(feat_type)
        missing_fields = set(all_feat_ids) - set(feat.keys())
        for feat_id in missing_fields:
            filled_feat[feat_id] = self.feature_default_value[feat_id]
        for feat_id in self.feature_types['item_emb']:
            if item_id != 0 and self.indexer_i_rev[item_id] in self.mm_emb_dict[feat_id]:
                if type(self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]) == np.ndarray:
                    filled_feat[feat_id] = self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]

        return filled_feat

    @staticmethod
    def collate_fn(batch):
        """
        Args:
            batch: data returned by multiple __getitem__ calls

        Returns:
            seq: user sequence IDs, as torch.Tensor
            pos: positive-sample IDs, as torch.Tensor
            neg: negative-sample IDs, as torch.Tensor
            token_type: sequence token types, as torch.Tensor
            next_token_type: next-token types, as torch.Tensor
            seq_feat: sequence features, as list
            pos_feat: positive-sample features, as list
            neg_feat: negative-sample features, as list
        """
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        pos = torch.from_numpy(np.array(pos))
        neg = torch.from_numpy(np.array(neg))
        token_type = torch.from_numpy(np.array(token_type))
        next_token_type = torch.from_numpy(np.array(next_token_type))
        next_action_type = torch.from_numpy(np.array(next_action_type))
        seq_feat = list(seq_feat)
        pos_feat = list(pos_feat)
        neg_feat = list(neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat


class MyTestDataset(MyDataset):
    """Test dataset."""

    def __init__(self, data_dir, args):
        super().__init__(data_dir, args)

    def _load_data_and_offsets(self):
        self.data_file = open(self.data_dir / "predict_seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _process_cold_start_feat(self, feat):
        """
        Handle cold-start features. Feature values unseen in the training set arrive as
        strings and are mapped to 0 by default; a better scheme can be substituted here.
        """
        processed_feat = {}
        for feat_id, feat_value in feat.items():
            if type(feat_value) == list:
                value_list = []
                for v in feat_value:
                    if type(v) == str:
                        value_list.append(0)
                    else:
                        value_list.append(v)
                processed_feat[feat_id] = value_list
            elif type(feat_value) == str:
                processed_feat[feat_id] = 0
            else:
                processed_feat[feat_id] = feat_value
        return processed_feat

    def __getitem__(self, uid):
        """
        Fetch one user's data, apply padding, and produce the format the model expects.

        Args:
            uid: the user's line number in self.data_file

        Returns:
            seq: user sequence IDs
            token_type: sequence token types, 1 for item, 2 for user
            seq_feat: sequence features; each element is a dict of feature ID -> value
            user_id: user_id, e.g. user_xxxxxx, for matching against the answers later
        """
        user_sequence = self._load_user_data(uid)  # Load user data lazily

        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, _, _ = record_tuple
            if u:
                if type(u) == str:  # a string means this is the user_id
                    user_id = u
                else:  # an int means this is the re_id
                    user_id = self.indexer_u_rev[u]
            if u and user_feat:
                if type(u) == str:
                    u = 0
                if user_feat:
                    user_feat = self._process_cold_start_feat(user_feat)
                ext_user_sequence.insert(0, (u, user_feat, 2))

            if i and item_feat:
                # Items unseen during training keep their creative_id rather than being zeroed
                # outright; creative_id is far larger than the training-time itemnum
                if i > self.itemnum:
                    i = 0
                if item_feat:
                    item_feat = self._process_cold_start_feat(item_feat)
                ext_user_sequence.append((i, item_feat, 1))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)

        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_ = record_tuple
            feat = self.fill_missing_feat(feat, i)
            seq[idx] = i
            token_type[idx] = type_
            seq_feat[idx] = feat
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)

        return seq, token_type, seq_feat, user_id

    def __len__(self):
        """
        Returns:
            len(self.seq_offsets): number of users
        """
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            temp = pickle.load(f)
        return len(temp)

    @staticmethod
    def collate_fn(batch):
        """
        Assemble the data returned by multiple __getitem__ calls into one batch.

        Args:
            batch: data returned by multiple __getitem__ calls

        Returns:
            seq: user sequence IDs, as torch.Tensor
            token_type: sequence token types, as torch.Tensor
            seq_feat: sequence features, as list
            user_id: user_id, str
        """
        seq, token_type, seq_feat, user_id = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        token_type = torch.from_numpy(np.array(token_type))
        seq_feat = list(seq_feat)
        return seq, token_type, seq_feat, user_id


def save_emb(emb, save_path):
    """
    Save an embedding as a binary file.

    Args:
        emb: the embedding to save, shaped [num_points, num_dimensions]
        save_path: save path
    """
    num_points = emb.shape[0]  # number of data points
    num_dimensions = emb.shape[1]  # vector dimensionality
    print(f'saving {save_path}')
    with open(Path(save_path), 'wb') as f:
        f.write(struct.pack('II', num_points, num_dimensions))
        emb.tofile(f)


def load_mm_emb(mm_path, feat_ids):
    """
    Load multimodal feature embeddings.

    Args:
        mm_path: path to the multimodal embeddings
        feat_ids: list of multimodal feature IDs to load

    Returns:
        mm_emb_dict: dict of multimodal embeddings; key is feature ID,
            value is a dict mapping item ID -> embedding
    """
    SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
    mm_emb_dict = {}
    for feat_id in tqdm(feat_ids, desc='Loading mm_emb'):
        shape = SHAPE_DICT[feat_id]
        emb_dict = {}
        if feat_id != '81':
            try:
                base_path = Path(mm_path, f'emb_{feat_id}_{shape}')
                for json_file in base_path.glob('*.json'):
                    with open(json_file, 'r', encoding='utf-8') as file:
                        for line in file:
                            data_dict_origin = json.loads(line.strip())
                            insert_emb = data_dict_origin['emb']
                            if isinstance(insert_emb, list):
                                insert_emb = np.array(insert_emb, dtype=np.float32)
                            data_dict = {data_dict_origin['anonymous_cid']: insert_emb}
                            emb_dict.update(data_dict)
            except Exception as e:
                print(f"transfer error: {e}")
        if feat_id == '81':
            with open(Path(mm_path, f'emb_{feat_id}_{shape}.pkl'), 'rb') as f:
                emb_dict = pickle.load(f)
        mm_emb_dict[feat_id] = emb_dict
        print(f'Loaded #{feat_id} mm_emb')
    return mm_emb_dict
```

4. model_rqvae.py - Multimodal Feature Compression

Implements an RQ-VAE (Residual Quantized Variational AutoEncoder) framework that converts high-dimensional multimodal embeddings into discrete semantic IDs:

Core Components

  • RQEncoder/RQDecoder: the encoder and decoder
  • VQEmbedding: vector-quantization module with K-means initialization support
  • RQ: residual quantizer implementing multi-level quantization
  • RQVAE: the complete RQ-VAE model

量化方法

  • 支持標準 K-means 和平衡 K-means 聚類
  • 使用余弦距離或 L2 距離進行向量量化
  • 通過殘差量化實現更精確的特征表示
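The residual idea can be sketched in a few lines: each level quantizes the residual left by the previous level, so the running sum of codewords approximates the input ever more closely, and the per-level codeword indices form the semantic ID. This toy numpy sketch uses fixed scalar codebooks as a stand-in for the learned, K-means-initialized vector codebooks in model_rqvae.py; all sizes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.uniform(-1.0, 1.0, size=(1000, 4)).astype(np.float32)


def vq(residual, codebook):
    # nearest codeword per component (scalar VQ; the real model quantizes whole vectors)
    ids = np.abs(residual[..., None] - codebook[None, None, :]).argmin(-1)
    return ids, codebook[ids]


# each level's codebook covers the residual range left by the previous level
codebooks = [
    np.linspace(-1.0, 1.0, 9, dtype=np.float32),
    np.linspace(-0.125, 0.125, 9, dtype=np.float32),
    np.linspace(-0.015625, 0.015625, 9, dtype=np.float32),
]

residual = data.copy()
approx = np.zeros_like(data)
errors = []
for codebook in codebooks:
    ids, quantized = vq(residual, codebook)  # `ids` is this level's semantic ID
    approx += quantized                      # running sum of codewords approximates `data`
    residual -= quantized                    # the next level quantizes what is left
    errors.append(float(((data - approx) ** 2).mean()))
```

Running this, `errors` shrinks level by level, which is the whole point of stacking codebooks instead of using one large one.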
model_rqvae.py code
"""
選手可參考以下流程,使用提供的 RQ-VAE 框架代碼將多模態emb數據轉換為Semantic Id:
1. 使用 MmEmbDataset 讀取不同特征 ID 的多模態emb數據.
2. 訓練 RQ-VAE 模型, 訓練完成后將數據轉換為Semantic Id.
3. 參照 Item Sparse 特征格式處理Semantic Id,作為新特征加入Baseline模型訓練.
"""import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans

# class MmEmbDataset(torch.utils.data.Dataset):
#     """
#     Build Dataset for RQ-VAE Training
#
#     Args:
#         data_dir = os.environ.get('TRAIN_DATA_PATH')
#         feature_id = MM emb ID
#     """
#
#     def __init__(self, data_dir, feature_id):
#         super().__init__()
#         self.data_dir = Path(data_dir)
#         self.mm_emb_id = [feature_id]
#         self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_id)
#
#         self.mm_emb = self.mm_emb_dict[self.mm_emb_id[0]]
#         self.tid_list, self.emb_list = list(self.mm_emb.keys()), list(self.mm_emb.values())
#         self.emb_list = [torch.tensor(emb, dtype=torch.float32) for emb in self.emb_list]
#
#         assert len(self.tid_list) == len(self.emb_list)
#         self.item_cnt = len(self.tid_list)
#
#     def __getitem__(self, index):
#         tid = torch.tensor(self.tid_list[index], dtype=torch.long)
#         emb = self.emb_list[index]
#         return tid, emb
#
#     def __len__(self):
#         return self.item_cnt
#
#     @staticmethod
#     def collate_fn(batch):
#         tid, emb = zip(*batch)
#
#         tid_batch, emb_batch = torch.stack(tid, dim=0), torch.stack(emb, dim=0)
#         return tid_batch, emb_batch


## Kmeans
def kmeans(data, n_clusters, kmeans_iters):
    """auto init: n_init = 10 if n_clusters <= 10 else 1"""
    km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")

    # sklearn only supports cpu
    data_cpu = data.detach().cpu()
    np_data = data_cpu.numpy()
    km.fit(np_data)
    return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)


## Balanced Kmeans
class BalancedKmeans(torch.nn.Module):
    def __init__(self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str):
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device
        self._codebook = None

    def _compute_distances(self, data):
        return torch.cdist(data, self._codebook)

    def _assign_clusters(self, dist):
        samples_cnt = dist.shape[0]
        samples_labels = torch.zeros(samples_cnt, dtype=torch.long, device=self.device)
        clusters_cnt = torch.zeros(self.num_clusters, dtype=torch.long, device=self.device)

        sorted_indices = torch.argsort(dist, dim=-1)
        for i in range(samples_cnt):
            for j in range(self.num_clusters):
                cluster_idx = sorted_indices[i, j]
                if clusters_cnt[cluster_idx] < samples_cnt // self.num_clusters:
                    samples_labels[i] = cluster_idx
                    clusters_cnt[cluster_idx] += 1
                    break

        return samples_labels

    def _update_codebook(self, data, samples_labels):
        _new_codebook = []
        for i in range(self.num_clusters):
            cluster_data = data[samples_labels == i]
            if len(cluster_data) > 0:
                _new_codebook.append(cluster_data.mean(dim=0))
            else:
                _new_codebook.append(self._codebook[i])
        return torch.stack(_new_codebook)

    def fit(self, data):
        num_emb, codebook_emb_dim = data.shape
        data = data.to(self.device)

        # initialize codebook
        indices = torch.randperm(num_emb)[: self.num_clusters]
        self._codebook = data[indices].clone()

        for _ in range(self.kmeans_iters):
            dist = self._compute_distances(data)
            samples_labels = self._assign_clusters(dist)
            _new_codebook = self._update_codebook(data, samples_labels)
            if torch.norm(_new_codebook - self._codebook) < self.tolerance:
                break

            self._codebook = _new_codebook

        return self._codebook, samples_labels

    def predict(self, data):
        data = data.to(self.device)
        dist = self._compute_distances(data)
        samples_labels = self._assign_clusters(dist)
        return samples_labels


## Base RQVAE
class RQEncoder(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_channels: list, latent_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        in_dim = input_dim

        for out_dim in hidden_channels:
            stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, latent_dim), torch.nn.ReLU()))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


class RQDecoder(torch.nn.Module):
    def __init__(self, latent_dim: int, hidden_channels: list, output_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        in_dim = latent_dim

        for out_dim in hidden_channels:
            stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, output_dim), torch.nn.ReLU()))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


## Generate semantic id
class VQEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_clusters,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super(VQEmbedding, self).__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device

    def _create_codebook(self, data):
        if self.kmeans_method == 'kmeans':
            _codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == 'bkmeans':
            BKmeans = BalancedKmeans(
                num_clusters=self.num_clusters, kmeans_iters=self.kmeans_iters, tolerance=1e-4, device=self.device
            )
            _codebook, _ = BKmeans.fit(data)
        else:
            _codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        _codebook = _codebook.to(self.device)
        assert _codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        self.codebook = torch.nn.Parameter(_codebook)

    @torch.no_grad()
    def _compute_distances(self, data):
        _codebook_t = self.codebook.t()
        assert _codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == 'cosine':
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(_codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        # l2
        else:
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            _codebook_t_norm_sq = _codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(data_norm_sq + _codebook_t_norm_sq, data, _codebook_t, beta=1.0, alpha=-2.0)
        return distances

    @torch.no_grad()
    def _create_semantic_id(self, data):
        distances = self._compute_distances(data)
        _semantic_id = torch.argmin(distances, dim=-1)
        return _semantic_id

    def _update_emb(self, _semantic_id):
        update_emb = super().forward(_semantic_id)
        return update_emb

    def forward(self, data):
        self._create_codebook(data)
        _semantic_id = self._create_semantic_id(data)
        update_emb = self._update_emb(_semantic_id)
        return update_emb, _semantic_id


## Residual Quantizer
class RQ(torch.nn.Module):
    """
    Args:
        num_codebooks, codebook_size, codebook_emb_dim -> Build codebook
        if_shared_codebook -> If use same codebook
        kmeans_method, kmeans_iters -> Initialize codebook
        distances_method -> Generate semantic_id
        loss_beta -> Calculate RQ-VAE loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        if self.shared_codebook:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[0],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for _ in range(self.num_codebooks)
                ]
            )
        else:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    def quantize(self, data):
        """
        Exa:
            i-th quantize: input[i] ( i.e. res[i-1] ) = VQ[i] + res[i]
            vq_emb_list: [vq1, vq1+vq2, ...]
            res_emb_list: [res1, res2, ...]
            semantic_id_list: [vq1_sid, vq2_sid, ...]

        Returns:
            vq_emb_list[0] -> [batch_size, codebook_emb_dim]
            semantic_id_list -> [batch_size, num_codebooks]
        """
        res_emb = data.detach().clone()

        vq_emb_list, res_emb_list = [], []
        semantic_id_list = []
        vq_emb_aggre = torch.zeros_like(data)

        for i in range(self.num_codebooks):
            vq_emb, _semantic_id = self.vqmodules[i](res_emb)

            res_emb -= vq_emb
            vq_emb_aggre += vq_emb

            res_emb_list.append(res_emb)
            vq_emb_list.append(vq_emb_aggre)
            semantic_id_list.append(_semantic_id.unsqueeze(dim=-1))

        semantic_id_list = torch.cat(semantic_id_list, dim=-1)
        return vq_emb_list, res_emb_list, semantic_id_list

    def _rqvae_loss(self, vq_emb_list, res_emb_list):
        rqvae_loss_list = []
        for idx, quant in enumerate(vq_emb_list):
            # stop gradient
            loss1 = (res_emb_list[idx].detach() - quant).pow(2.0).mean()
            loss2 = (res_emb_list[idx] - quant.detach()).pow(2.0).mean()
            partial_loss = loss1 + self.loss_beta * loss2
            rqvae_loss_list.append(partial_loss)

        rqvae_loss = torch.sum(torch.stack(rqvae_loss_list))
        return rqvae_loss

    def forward(self, data):
        vq_emb_list, res_emb_list, semantic_id_list = self.quantize(data)
        rqvae_loss = self._rqvae_loss(vq_emb_list, res_emb_list)
        return vq_emb_list, semantic_id_list, rqvae_loss


class RQVAE(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_channels: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
        self.decoder = RQDecoder(latent_dim, hidden_channels[::-1], input_dim).to(device)
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            device,
        ).to(device)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z_vq):
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(self, x_hat, x_gt, rqvae_loss):
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def _get_codebook(self, x_gt):
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        return semantic_id_list

    def forward(self, x_gt):
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss

5. run.sh - Launch Script

A simple bash script that starts the training program.

run.sh code
#!/bin/bash

# show ${RUNTIME_SCRIPT_DIR}
echo ${RUNTIME_SCRIPT_DIR}
# enter train workspace
cd ${RUNTIME_SCRIPT_DIR}

# write your code below
python -u main.py

Technical Highlights

  1. Efficient attention: uses Flash Attention to cut computation cost
  2. Multimodal fusion: supports embedding features from multiple modalities such as text and images
  3. Feature engineering: handles sparse, dense, and array feature types
  4. Sequence modeling: models user and item interaction sequences jointly
  5. Scalability: supports saving and retrieving embeddings for large item catalogs

Data Flow

  1. Training: read user sequences → feature embedding → Transformer encoding → loss over positive and negative samples
  2. Inference: generate user representations → save item embeddings → recommend via vector retrieval
  3. Multimodal processing: raw embeddings → RQ-VAE compression → semantic IDs → fed to the model as new features
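The retrieval step in the inference flow reduces to a nearest-neighbor search over the saved item embeddings. A minimal brute-force numpy sketch of that step is below (all shapes and data here are illustrative; at competition scale an ANN library would replace the exhaustive scoring):

```python
import numpy as np

rng = np.random.default_rng(0)
item_emb = rng.standard_normal((10_000, 32)).astype(np.float32)  # stand-in for saved item embeddings
user_emb = rng.standard_normal(32).astype(np.float32)            # stand-in for one user representation


def topk_items(user_emb, item_emb, k=10):
    # score every item by inner product, then return the k best indices, best first
    scores = item_emb @ user_emb
    top = np.argpartition(scores, -k)[-k:]
    return top[np.argsort(scores[top])[::-1]]


recs = topk_items(user_emb, item_emb)
```

`argpartition` keeps the selection O(n) instead of sorting all items; only the final k candidates are sorted for ranking.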

Source: http://www.pswp.cn/bicheng/91573.shtml
