TorchRec中的輸入和輸出格式
文章目錄
- TorchRec中的輸入和輸出格式
- 前言
- 一、JaggedTensor
- 1.1 核心概念
- 1.2 核心屬性,也就是參數
- 1.3 關鍵操作與方法
- 二、KeyedJaggedTensor
- 2.1 核心概念
- 2.2 核心屬性,也就是參數
- 3、KeyedTensor
- 總結
前言
- TorchRec具有其特定的輸入輸出格式,跟Torch中的Tonsor有些不同,下邊就讓我們來了解他們。
一、JaggedTensor
- JaggedTensor 通過長度、值和偏移量來表示稀疏特征。它之所以被稱為“jagged”,是因為它可以有效地表示可變長度序列的數據。
- 相比之下,規范的 torch.Tensor 假設每個序列具有相同的長度,但這在真實世界數據中通常不是這種情況。
- JaggedTensor 有助于表示此類數據而無需填充,從而使其非常高效。
1.1 核心概念
- JaggedTensor表示一個包含多個不等長序列的2D張量,例如:
- 用戶歷史點擊序列:[[item1, item2], [item3], [item4, item5, item6], …]
- 文本分詞后的句子:[[tokenA, tokenB], [tokenC], …]
1.2 核心屬性,也就是參數
- values (Tensor): 一個 1D 張量,包含每個實體的實際值,連續存儲,形狀為[total_values]
- lengths (Optional[Tensor]): 一個整數列表,表示每個實體的元素數量,形狀為[B]
- offsets (Optional[Tensor]): 一個整數列表,表示扁平化值張量中每個序列的起始索引。這些提供了長度的替代方案。形狀為[B+1]
演示代碼如下:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor# 方式1 - 使用lengths
values = torch.tensor([1, 2, 3, 4, 5])
lengths = torch.tensor([2, 1, 2]) # 三個序列的長度分別為2,1,2
jt = JaggedTensor(lengths=lengths, values=values)
# 意思就是有三個張量[[1, 2],[3],[4, 5]]# 方式2 - 使用offsets
values = torch.tensor([10, 20, 30, 40])
offsets = torch.tensor([1, 3]) # 兩個序列的偏移量
jt = JaggedTensor(offsets=offsets, values=values)
# 意思就是有兩個張量[[10],[20, 30, 40]]
1.3 關鍵操作與方法
- 操作:
- 拼接
torchrec.sparse.jagged_ops.concat(jt1, jt2)
- 分塊
jt.split(split_size)
- 聚合
jt.sum(dim=1) 或 jt.mean(dim=1)
- 方法
- 與稠密張量互轉()
# 轉Padded Tensor padded, mask = jt.to_padded_tensor(padding_value=0, max_length=5) # 轉Packed Tensor (類似PyTorch的PackedSequence) packed = jt.to_packed_tensor()
- 嵌入表查詢
# 轉Padded Tensor embedding_bag = torch.nn.EmbeddingBag(num_embeddings=100, embedding_dim=16) embeddings = embedding_bag(jt.values, offsets=jt.offsets())
二、KeyedJaggedTensor
- KeyedJaggedTensor 通過引入鍵(通常是特征名稱)來標記不同的特征組(例如,用戶特征和項目特征),從而擴展了 JaggedTensor 的功能。
- 這是 EmbeddingBagCollection 和 EmbeddingCollection 的 forward 中使用的數據類型,因為它們用于表示表中的多個特征。
2.1 核心概念
-
KeyedJaggedTensor 是用于管理多個變長特征序列的高效數據結構,核心場景包括:
- 多特征推薦系統:同時處理用戶歷史點擊(click_ids)、搜索詞(search_ids)、收藏商品(favor_ids)等不同特征
2.2 核心屬性,也就是參數
- keys (List[str]): 特征名稱列表,如 [“click”, “search”]
- values (Tensor): 一個 1D 張量,所有特征值的展平拼接,連續存儲,形狀為[total_values]
- lengths (Optional[Tensor]): 一個整數列表,表示每個特征在樣本中的長度,按 keys 順序排列
- offsets (Optional[Tensor]): 一個整數列表,表示扁平化值張量中每個特征列的起始索引。形狀為[B+1]
演示代碼如下:
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, JaggedTensor# 方式一:從多個JaggedTensor構建
# 創建兩個特征的JaggedTensor
click_jt = JaggedTensor(lengths=torch.tensor([2, 1]),values=torch.tensor([101, 202, 303])
)
search_jt = JaggedTensor(lengths=torch.tensor([1, 3]),values=torch.tensor([401, 402, 403, 404])
)# 合并為KeyedJaggedTensor
kjt = KeyedJaggedTensor.from_jagged_tensors(keys=["click", "search"],tensors=[click_jt, search_jt]
)print(kjt)
# KeyedJaggedTensor({
# "click": JaggedTensor([[101, 202], [303]]),
# "search": JaggedTensor([[401], [402, 403, 404]])
# })# 方式二:從原始數據直接構建
kjt = KeyedJaggedTensor(keys=["click", "search"],values=torch.tensor([101, 202, 303, 401, 402, 403, 404]),lengths=torch.tensor([2, 1, 1, 3]), # click_lengths + search_lengthsoffsets=None # 自動生成
)
# 跟上邊是一樣的效果
3、KeyedTensor
- torch.Tensor 的包裝器,允許通過鍵訪問張量值。
總結
- 本節我們學習TorchRec中的數據類型,了解他的輸入輸出格式。