Paper Code Reading: Understanding the TGN Training-Stage Code

Table of Contents

  • Understanding the TGN Training-Stage Code
    • Paper Information
    • Hand-drawn Code Flow
    • Training Process in Code
          • compute_temporal_embeddings
          • update_memory
          • get_raw_messages
          • get_updated_memory
          • self.message_aggregator.aggregate
          • self.memory_updater.get_updated_memory
          • Memory
          • get_embedding_module
          • GraphAttentionEmbedding
          • TimeEncode
          • NeighborFinder
          • MergeLayer

Understanding the TGN Training-Stage Code

Paper Information

Paper: https://arxiv.org/abs/2006.10637

GitHub: https://github.com/twitter-research/tgn?tab=readme-ov-file

Year: 2020

Hand-drawn Code Flow

[Figure: two hand-drawn diagrams of the training code flow; the images were not preserved.]

Training Process in Code

The training loop calls:

    pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,
                                                        timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

The function compute_edge_probabilities:

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding them
        into the MLP decoder.
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Probabilities for both the positive and negative edges

        source_nodes: list of source node ids
        destination_nodes: list of destination node ids
        negative_nodes: list of negatively sampled node ids
        edge_times: time at which each source node interacted with the corresponding destination node
        edge_idxs: indices of the edges (interactions)
        """
        n_samples = len(source_nodes)
        # compute_temporal_embeddings
        source_node_embedding, destination_node_embedding, negative_node_embedding = \
            self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes,
                                             edge_times, edge_idxs, n_neighbors)

        score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                    torch.cat([destination_node_embedding,
                                               negative_node_embedding])).squeeze(dim=0)
        pos_score = score[:n_samples]
        neg_score = score[n_samples:]

        return pos_score.sigmoid(), neg_score.sigmoid()

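The method stacks [source; source] against [destination; negative]. This lets a single affinity_score call score both the positive and the negative pairs. The result is then split at n_samples.

The plain-Python sketch below reproduces only this batching trick. The dot-product `dot_score` is a hypothetical stand-in for self.affinity_score, which TGN implements as a small learned MLP.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dot_score(a, b):
    # Hypothetical stand-in for self.affinity_score (a learned MLP in TGN).
    return sum(x * y for x, y in zip(a, b))


def edge_probabilities(src_emb, dst_emb, neg_emb, score=dot_score):
    """Score positives and negatives in one pass, then split at n_samples."""
    n_samples = len(src_emb)
    left = src_emb + src_emb    # like torch.cat([src, src], dim=0)
    right = dst_emb + neg_emb   # like torch.cat([dst, neg])
    scores = [score(a, b) for a, b in zip(left, right)]
    pos_score, neg_score = scores[:n_samples], scores[n_samples:]
    return [sigmoid(s) for s in pos_score], [sigmoid(s) for s in neg_score]


src = [[1.0, 0.0], [0.0, 1.0]]
dst = [[2.0, 0.0], [0.0, 3.0]]
neg = [[-1.0, 0.0], [0.0, 0.0]]
pos_prob, neg_prob = edge_probabilities(src, dst, neg)
```

Scoring one stacked batch keeps the decoder call count at one per training step, regardless of how the negatives are sampled.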
compute_temporal_embeddings

This method computes the temporal embeddings of the source, destination and negative nodes.

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                    edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled destinations.
        source_nodes [batch_size]: source ids.
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Temporal embeddings for sources, destinations and negatives
        """
        # n_samples: number of source nodes in the batch
        n_samples = len(source_nodes)
        # nodes: ids of all nodes in this batch, size = 200 * 3 = 600
        nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
        # positives: sources and destinations together; the first 200 are source ids,
        # the next 200 are destination ids
        positives = np.concatenate([source_nodes, destination_nodes])
        # timestamps: shape [200 * 3]; edge_times has shape [batch_size] and holds the time
        # at which each source interacted with its destination
        timestamps = np.concatenate([edge_times, edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:  # update memory at the start of the batch?
                # n_nodes: total number of nodes in the graph (9228)
                # On the first batch of an epoch the message store self.memory.messages is empty.
                # The memory returned here is up to date: it is computed from each node's stored
                # messages (with the "last" aggregator, only the most recent message is used).
                # memory shape = [n_nodes (9228), memory_dimension (172)]
                # last_update shape = [n_nodes (9228)]
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            # ========== Everything below handles per-node information ==========
            # Compute the difference between the time the node's memory was last updated
            # and the time at which we want to compute its embedding.
            # source_time_diffs shape [batch_size]
            source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                source_nodes].long()
            # normalization
            source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
            # destination_time_diffs shape [batch_size]
            destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                destination_nodes].long()
            # normalization
            destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
            negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                negative_nodes].long()
            negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

            # time differences
            time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                                   dim=0)

        # Compute the embeddings using the embedding module
        # self.embedding_module is created by get_embedding_module() (shown below).
        # memory: the memory matrix
        # nodes: ids of the source, destination and negative nodes, concatenated
        # timestamps: 200 * 3 timestamps
        # self.n_layers: number of recursive layers, here 2
        # n_neighbors: number of neighbors to sample, here 10
        # time_diffs: normalized time differences
        # node_embedding shape [600, 172]: combines each node's own features with the
        # features of its neighbors and of the connecting edges
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        # split the embeddings back into the three groups
        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            # memory update
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations (since now we have
                # new messages for them)
                self.update_memory(positives, self.memory.messages)

                assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the memory using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                          source_node_embedding,
                                                                          destination_nodes,
                                                                          destination_node_embedding,
                                                                          edge_times, edge_idxs)
            unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                    destination_node_embedding,
                                                                                    source_nodes,
                                                                                    source_node_embedding,
                                                                                    edge_times, edge_idxs)
            if self.memory_update_at_start:
                # store the raw messages; they are applied at the start of a later batch
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

            if self.dyrep:
                source_node_embedding = memory[source_nodes]
                destination_node_embedding = memory[destination_nodes]
                negative_node_embedding = memory[negative_nodes]

        return source_node_embedding, destination_node_embedding, negative_node_embedding

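The subtle part of memory_update_at_start is the ordering:

- Messages created from batch t are only stored.
- They are folded into memory at the start of a later batch, for every node when computing embeddings.
- They are persisted only for that later batch's positives.

The toy class below is a hypothetical, plain-Python sketch of just this ordering. `ToyMemory`, its scalar memory and its "add the last message" update rule are illustrative assumptions, not TGN's API.

```python
from collections import defaultdict


class ToyMemory:
    """Hypothetical toy model of the memory_update_at_start ordering (not TGN's API)."""

    def __init__(self, n_nodes):
        self.memory = [0.0] * n_nodes
        self.messages = defaultdict(list)  # node -> [(message, timestamp), ...]

    def _apply(self, node):
        # Toy "memory updater": add the last pending message to the stored memory.
        return self.memory[node] + self.messages[node][-1][0]

    def get_updated_memory(self):
        # Updated copy for every node with pending messages; nothing is persisted.
        updated = list(self.memory)
        for node, msgs in self.messages.items():
            if msgs:
                updated[node] = self._apply(node)
        return updated

    def process_batch(self, sources, destinations, t):
        memory = self.get_updated_memory()  # what the embeddings would be computed from
        positives = sources + destinations
        for node in set(positives):  # like update_memory(positives, ...)
            if self.messages[node]:
                self.memory[node] = self._apply(node)
        # same idea as the torch.allclose check in compute_temporal_embeddings
        assert all(self.memory[n] == memory[n] for n in positives)
        for node in positives:  # like clear_messages(positives)
            self.messages[node] = []
        # Messages from this batch are only stored; a later batch applies them.
        for s, d in zip(sources, destinations):
            self.messages[s].append((1.0, t))
            self.messages[d].append((1.0, t))
        return memory


m = ToyMemory(3)
first = m.process_batch([0], [1], 1)   # memory unchanged, messages stored
second = m.process_batch([0], [2], 2)  # batch-1 messages now visible
```

After the second batch, the embeddings see the pending message of node 1. However, node 1's stored memory only changes once node 1 appears in a batch again. This is how the current batch's own interactions are kept out of its embeddings.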
update_memory
    def update_memory(self, nodes, messages):
        # Aggregate messages for the same nodes
        # self.message_aggregator -> LastMessageAggregator
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(
                nodes,
                messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        # Update the memory with the aggregated messages (after aggregation, write the new memory back)
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)
get_raw_messages
    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        # edge_times shape is [200, ]
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        # edge_features shape is [200, 172]
        edge_features = self.edge_raw_features[edge_idxs]

        source_memory = self.memory.get_memory(source_nodes) if not \
            self.use_source_embedding_in_message else source_node_embedding
        destination_memory = self.memory.get_memory(destination_nodes) if \
            not self.use_destination_embedding_in_message else destination_node_embedding

        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        # source_time_delta_encoding [200, 172]
        source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
            source_nodes), -1)

        # source_message shape [200, 688]
        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)
        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)

        for i in range(len(source_nodes)):
            messages[source_nodes[i]].append((source_message[i], edge_times[i]))

        return unique_sources, messages
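Each raw message is the concatenation [source memory, destination memory, edge features, time-delta encoding]. With 172-dimensional memory, edge features and time encoding, it has 4 × 172 = 688 entries, which matches the [200, 688] shape noted above.

The plain-Python sketch below builds the messages and groups them per source node in a defaultdict(list). Lists stand in for tensors, and the time encodings are passed in precomputed. `group_raw_messages` is a hypothetical name.

```python
from collections import defaultdict


def group_raw_messages(source_nodes, source_memory, destination_memory,
                       edge_features, time_encodings, edge_times):
    """Build one raw message per interaction and group them by source node."""
    messages = defaultdict(list)
    for i, src in enumerate(source_nodes):
        message = (source_memory[i] + destination_memory[i]
                   + edge_features[i] + time_encodings[i])  # list concatenation
        messages[src].append((message, edge_times[i]))
    unique_sources = sorted(set(source_nodes))  # np.unique also returns sorted ids
    return unique_sources, messages


dim = 172
srcs = [5, 7, 5]
vecs = [[0.0] * dim for _ in srcs]
unique_sources, messages = group_raw_messages(srcs, vecs, vecs, vecs, vecs, [1.0, 2.0, 3.0])
```

Node 5 appears twice in the batch, so it collects two messages in time order. The aggregator then reduces them to one.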
get_updated_memory
    def get_updated_memory(self, nodes, messages):
        # Aggregate messages for the same nodes
        # nodes: the list range(n_nodes)
        # messages: the message store
        # First aggregate the messages, then update the memory.
        # On the first call all three returned values are empty lists.
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(
                nodes,     # the list range(n_nodes)
                messages)  # the message store

        if len(unique_nodes) > 0:
            # There are two message functions to choose from:
            #
            # class MLPMessageFunction(MessageFunction):
            #     def __init__(self, raw_message_dimension, message_dimension):
            #         super(MLPMessageFunction, self).__init__()
            #         self.mlp = self.layers = nn.Sequential(
            #             nn.Linear(raw_message_dimension, raw_message_dimension // 2),
            #             nn.ReLU(),
            #             nn.Linear(raw_message_dimension // 2, message_dimension),
            #         )
            #
            #     def compute_message(self, raw_messages):
            #         messages = self.mlp(raw_messages)
            #         return messages
            #
            # class IdentityMessageFunction(MessageFunction):
            #     def compute_message(self, raw_messages):
            #         # This is the one used here: it does nothing and returns the raw messages unchanged.
            #         return raw_messages
            unique_messages = self.message_function.compute_message(unique_messages)

        # On the first pass through here this returns all-zero matrices
        # of shape [n_nodes, memory_dimension] and [n_nodes]
        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                     unique_messages,
                                                                                     timestamps=unique_timestamps)

        return updated_memory, updated_last_update
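The active message function decides the input size of the memory updater:

- IdentityMessageFunction passes the 688-dimensional raw message through unchanged.
- MLPMessageFunction maps it through raw_dim → raw_dim // 2 → message_dimension.

The plain-Python sketch below shows both. The random Gaussian weights, the omitted biases and `message_dim=100` are illustrative assumptions, not the repository's initialisation or defaults.

```python
import random

memory_dim = edge_dim = time_dim = 172
raw_dim = 2 * memory_dim + edge_dim + time_dim  # 688


def identity_message(raw_messages):
    # IdentityMessageFunction: messages are passed through unchanged.
    return raw_messages


def make_mlp_message(raw_dim, message_dim, seed=0):
    """Hypothetical MLPMessageFunction: raw_dim -> raw_dim // 2 -> ReLU -> message_dim (no biases)."""
    rng = random.Random(seed)
    hidden = raw_dim // 2
    w1 = [[rng.gauss(0.0, 0.1) for _ in range(raw_dim)] for _ in range(hidden)]
    w2 = [[rng.gauss(0.0, 0.1) for _ in range(hidden)] for _ in range(message_dim)]

    def compute_message(raw_messages):
        out = []
        for x in raw_messages:
            h = [max(0.0, sum(w * v for w, v in zip(row, x))) for row in w1]
            out.append([sum(w * v for w, v in zip(row, h)) for row in w2])
        return out

    return compute_message


raw = [[0.1] * raw_dim]
mlp_message = make_mlp_message(raw_dim, message_dim=100)
```

With the identity function, the GRUCell's input_size has to equal the raw message size (688). With the MLP, it equals the chosen message dimension instead.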
self.message_aggregator.aggregate

The code uses the "last" aggregator by default:

    def get_message_aggregator(aggregator_type, device):
        if aggregator_type == "last":
            return LastMessageAggregator(device=device)
        elif aggregator_type == "mean":
            return MeanMessageAggregator(device=device)
        else:
            raise ValueError("Message aggregator {} not implemented".format(aggregator_type))

LastMessageAggregator code:

    class LastMessageAggregator(MessageAggregator):
        def __init__(self, device):
            super(LastMessageAggregator, self).__init__(device)

        def aggregate(self, node_ids, messages):
            """Only keep the last message for each node"""
            # Deduplicate node ids. range(n_nodes) contains no duplicates, but
            # update_memory(positives, ...) passes the concatenated sources and destinations,
            # which can repeat a node.
            unique_node_ids = np.unique(node_ids)
            unique_messages = []
            unique_timestamps = []

            to_update_node_ids = []

            for node_id in unique_node_ids:  # on the first call: range(n_nodes), i.e. 9228 nodes
                if len(messages[node_id]) > 0:
                    # Each stored entry is the (message, timestamp) tuple built in get_raw_messages:
                    # source_message = torch.cat([source_memory, destination_memory, edge_features,
                    #                             source_time_delta_encoding], dim=1)
                    # stored as (source_message[i], edge_times[i])
                    to_update_node_ids.append(node_id)
                    unique_messages.append(messages[node_id][-1][0])
                    unique_timestamps.append(messages[node_id][-1][1])

            unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
            unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

            return to_update_node_ids, unique_messages, unique_timestamps
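LastMessageAggregator keeps only the most recent (message, timestamp) per node. The alternative "mean" option averages all pending messages.

The plain-Python sketch below shows both. For the mean variant it assumes the timestamp reported is that of the last message; check MeanMessageAggregator in the repository before relying on this.

```python
def aggregate_last(node_ids, messages):
    to_update, agg_messages, agg_times = [], [], []
    for node in sorted(set(node_ids)):  # np.unique: deduplicated and sorted
        if messages.get(node):
            to_update.append(node)
            agg_messages.append(messages[node][-1][0])
            agg_times.append(messages[node][-1][1])
    return to_update, agg_messages, agg_times


def aggregate_mean(node_ids, messages):
    to_update, agg_messages, agg_times = [], [], []
    for node in sorted(set(node_ids)):
        msgs = messages.get(node)
        if msgs:
            n = len(msgs)
            to_update.append(node)
            # element-wise mean over all pending message vectors
            agg_messages.append([sum(col) / n for col in zip(*(m for m, _ in msgs))])
            agg_times.append(msgs[-1][1])  # assumption: timestamp of the last message
    return to_update, agg_messages, agg_times


# Node 0 has two pending messages, node 1 none; node 0 appears twice, as positives can.
pending = {0: [([1.0, 1.0], 1.0), ([3.0, 5.0], 2.0)], 1: []}
last_result = aggregate_last([0, 1, 0], pending)
mean_result = aggregate_mean([0, 1, 0], pending)
```

Nodes without pending messages (node 1 here) are skipped entirely. This is why the memory updater only receives the ids in to_update_node_ids.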
self.memory_updater.get_updated_memory

By default, the code updates the memory with a GRU:

    class SequenceMemoryUpdater(MemoryUpdater):
        def __init__(self, memory, message_dimension, memory_dimension, device):
            super(SequenceMemoryUpdater, self).__init__()
            self.memory = memory
            self.layer_norm = torch.nn.LayerNorm(memory_dimension)
            self.message_dimension = message_dimension
            self.device = device

        def update_memory(self, unique_node_ids, unique_messages, timestamps):
            if len(unique_node_ids) <= 0:
                return

            assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
                "Trying to update memory to time in the past"

            memory = self.memory.get_memory(unique_node_ids)
            self.memory.last_update[unique_node_ids] = timestamps

            updated_memory = self.memory_updater(unique_messages, memory)

            self.memory.set_memory(unique_node_ids, updated_memory)

        def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
            if len(unique_node_ids) <= 0:
                # self.memory is the Memory object defined below.
                # self.memory.memory is initialized as an all-zero matrix of shape
                # [n_nodes, memory_dimension] that requires no gradient.
                # self.memory.last_update is initialized as an all-zero vector of shape [n_nodes]
                # that requires no gradient.
                # .data.clone() returns a copy, so changing it does not affect the stored values.
                # Once messages have been stored (from the second batch on), this branch is normally not taken.
                return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

            assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
                "Trying to update memory to time in the past"

            updated_memory = self.memory.memory.data.clone()
            updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])

            updated_last_update = self.memory.last_update.data.clone()
            updated_last_update[unique_node_ids] = timestamps

            return updated_memory, updated_last_update


    class GRUMemoryUpdater(SequenceMemoryUpdater):
        def __init__(self, memory, message_dimension, memory_dimension, device):
            super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)

            self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                             hidden_size=memory_dimension)
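GRUMemoryUpdater calls nn.GRUCell as memory_updater(unique_messages, memory). The aggregated message is the input and the node's current memory is the hidden state.

The plain-Python sketch below runs one GRUCell step for a single node, using the gate equations from the PyTorch GRUCell documentation. One gate matrix per key in `params` is a simplification for readability; PyTorch stacks the three gates inside weight_ih and weight_hh.

```python
import math


def matvec(w, v):
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gru_cell(x, h, params):
    """One GRUCell step: x = aggregated message, h = current node memory.

    params maps each gate ("r", "z", "n") to (W_i, b_i, W_h, b_h).
    """
    def pre(gate):
        w_i, b_i, w_h, b_h = params[gate]
        gi = [a + b for a, b in zip(matvec(w_i, x), b_i)]
        gh = [a + b for a, b in zip(matvec(w_h, h), b_h)]
        return gi, gh

    ri, rh = pre("r")
    zi, zh = pre("z")
    ni, nh = pre("n")
    r = [sigmoid(a + b) for a, b in zip(ri, rh)]                # reset gate
    z = [sigmoid(a + b) for a, b in zip(zi, zh)]                # update gate
    n = [math.tanh(a + rj * b) for a, rj, b in zip(ni, r, nh)]  # candidate memory
    return [(1.0 - zj) * nj + zj * hj for zj, nj, hj in zip(z, n, h)]


def zero_params(input_size, hidden_size):
    def zeros(rows, cols):
        return [[0.0] * cols for _ in range(rows)]

    return {gate: (zeros(hidden_size, input_size), [0.0] * hidden_size,
                   zeros(hidden_size, hidden_size), [0.0] * hidden_size)
            for gate in ("r", "z", "n")}


new_memory = gru_cell([1.0, 2.0, 3.0], [4.0, -2.0], zero_params(3, 2))
```

With all weights at zero, both gates equal 0.5 and the candidate is 0. The new memory is therefore exactly half of the old one, which is a quick sanity check of the equations.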
Memory
    class Memory(nn.Module):

        def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                     device="cpu", combination_method='sum'):
            super(Memory, self).__init__()
            self.n_nodes = n_nodes
            self.memory_dimension = memory_dimension
            self.input_dimension = input_dimension
            self.message_dimension = message_dimension
            self.device = device

            self.combination_method = combination_method

            self.__init_memory__()

        # initialization
        def __init_memory__(self):
            """
            Initializes the memory to all zeros. It should be called at the start of each epoch.
            """
            # Treat memory as parameter so that it is saved and loaded together with the model
            # self.memory_dimension = 172
            # self.n_nodes = 9228
            # self.memory has shape [9228, 172]: every node has its own 172-dimensional
            # memory vector, initialized to all zeros
            self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                       requires_grad=False)
            # last_update shape = [9228]
            self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                            requires_grad=False)

            self.messages = defaultdict(list)

        def store_raw_messages(self, nodes, node_id_to_messages):
            for node in nodes:
                self.messages[node].extend(node_id_to_messages[node])

        def get_memory(self, node_idxs):
            return self.memory[node_idxs, :]

        def set_memory(self, node_idxs, values):
            self.memory[node_idxs, :] = values

        def get_last_update(self, node_idxs):
            return self.last_update[node_idxs]

        def backup_memory(self):
            messages_clone = {}
            for k, v in self.messages.items():
                messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]

            return self.memory.data.clone(), self.last_update.data.clone(), messages_clone

        def restore_memory(self, memory_backup):
            self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()

            self.messages = defaultdict(list)
            for k, v in memory_backup[2].items():
                self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]

        def detach_memory(self):
            self.memory.detach_()

            # Detach all stored messages
            for k, v in self.messages.items():
                new_node_messages = []
                for message in v:
                    new_node_messages.append((message[0].detach(), message[1]))

                self.messages[k] = new_node_messages

        def clear_messages(self, nodes):
            for node in nodes:
                self.messages[node] = []
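backup_memory copies every stored (message, timestamp) pair rather than the dict alone. A shallow copy would share the per-node lists with the live store, and store_raw_messages extends those lists in place.

The plain-Python illustration below shows that aliasing, with lists standing in for tensors. `restore` is a hypothetical helper modelled on restore_memory.

```python
import copy
from collections import defaultdict

# Toy message store: node -> [(message, timestamp), ...]
messages = defaultdict(list)
messages[3].append(([0.5, 0.5], 1.0))

shallow_backup = dict(messages)        # copies the dict but shares the per-node lists
deep_backup = copy.deepcopy(messages)  # analogous to cloning every (message, timestamp) pair

# store_raw_messages extends the existing per-node list in place ...
messages[3].extend([([1.0, 1.0], 2.0)])
# ... while clear_messages rebinds the key to a new, empty list
messages[3] = []


def restore(backup):
    # Like restore_memory: rebuild a fresh store from copies of the backup.
    restored = defaultdict(list)
    for node, entries in backup.items():
        restored[node] = [(list(msg), t) for msg, t in entries]
    return restored


restored = restore(deep_backup)
```

The shallow backup silently picks up the message appended after the backup was taken; the deep one does not.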
get_embedding_module

Here module_type=graph_attention.

    def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                             time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                             embedding_dimension, device,
                             n_heads=2, dropout=0.1, n_neighbors=None,
                             use_memory=True):
        # This is the branch used for embedding_module
        if module_type == "graph_attention":
            return GraphAttentionEmbedding(node_features=node_features, edge_features=edge_features,
                                           memory=memory, neighbor_finder=neighbor_finder,
                                           time_encoder=time_encoder, n_layers=n_layers,
                                           n_node_features=n_node_features, n_edge_features=n_edge_features,
                                           n_time_features=n_time_features,
                                           embedding_dimension=embedding_dimension, device=device,
                                           n_heads=n_heads, dropout=dropout, use_memory=use_memory)
        elif module_type == "graph_sum":
            return GraphSumEmbedding(node_features=node_features, edge_features=edge_features,
                                     memory=memory, neighbor_finder=neighbor_finder,
                                     time_encoder=time_encoder, n_layers=n_layers,
                                     n_node_features=n_node_features, n_edge_features=n_edge_features,
                                     n_time_features=n_time_features,
                                     embedding_dimension=embedding_dimension, device=device,
                                     n_heads=n_heads, dropout=dropout, use_memory=use_memory)
        elif module_type == "identity":
            return IdentityEmbedding(node_features=node_features, edge_features=edge_features,
                                     memory=memory, neighbor_finder=neighbor_finder,
                                     time_encoder=time_encoder, n_layers=n_layers,
                                     n_node_features=n_node_features, n_edge_features=n_edge_features,
                                     n_time_features=n_time_features,
                                     embedding_dimension=embedding_dimension, device=device,
                                     dropout=dropout)
        elif module_type == "time":
            return TimeEmbedding(node_features=node_features, edge_features=edge_features,
                                 memory=memory, neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder, n_layers=n_layers,
                                 n_node_features=n_node_features, n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension, device=device,
                                 dropout=dropout, n_neighbors=n_neighbors)
        else:
            raise ValueError("Embedding Module {} not supported".format(module_type))
GraphAttentionEmbedding
    class GraphEmbedding(EmbeddingModule):
        def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                     n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                     n_heads=2, dropout=0.1, use_memory=True):
            super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                                 neighbor_finder, time_encoder, n_layers,
                                                 n_node_features, n_edge_features, n_time_features,
                                                 embedding_dimension, device, dropout)

            self.use_memory = use_memory
            self.device = device

        def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                              use_time_proj=True):
            """Recursive implementation of curr_layers temporal graph attention layers.

            src_idx_l [batch_size]: users / items input ids.
            cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
            curr_layers [scalar]: number of temporal convolutional layers to stack.
            num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.
            """
            # The stacked temporal graph attention layers are implemented recursively.
            # memory: the memory matrix
            # source_nodes: ids of the source, destination and negative nodes concatenated
            #               (only in the top-level call; in recursive calls these are neighbor ids)
            # timestamps: 200 * 3 timestamps
            # n_layers: number of recursive layers, here 2
            # n_neighbors: number of neighbors to sample, here 10
            # time_diffs: normalized time differences

            assert (n_layers >= 0)

            # source_nodes_torch shape = [len(source_nodes)], i.e. [600] in the top-level call
            source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
            # timestamps_torch shape = [3 * 200, 1]
            timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

            # query node always has the start time -> time span == 0
            # time_encoder is a module computing cos(linear(t)); see TimeEncode below
            # torch.zeros_like(timestamps_torch) is all zeros, shape = [3 * 200, 1]
            # source_nodes_time_embedding shape = [3 * 200, 1, 172]
            source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))

            # self.node_features is an all-zero matrix here
            # self.node_features shape is [n_nodes, node_dim] = [9228, 172]
            # source_node_features: raw features of the 600 nodes, shape [600, 172]
            source_node_features = self.node_features[source_nodes_torch, :]

            if self.use_memory:
                # add each node's memory to its raw features
                source_node_features = memory[source_nodes, :] + source_node_features

            # ==================== recursion below ====================
            # n_layers = 1
            if n_layers == 0:
                return source_node_features
            else:
                # Recursive call on the same nodes with one layer fewer; returns features of shape [600, 172]
                source_node_conv_embeddings = self.compute_embedding(memory,
                                                                     source_nodes,
                                                                     timestamps,
                                                                     n_layers=n_layers - 1,
                                                                     n_neighbors=n_neighbors)

                # For each of the 3 * 200 nodes, take (up to) its 10 most recent neighbors that
                # interacted before the corresponding timestamp.
                # neighbors shape is [3 * 200, n_neighbors]
                # edge_idxs shape is [3 * 200, n_neighbors]
                # edge_times shape is [3 * 200, n_neighbors]
                neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
                    source_nodes,
                    timestamps,
                    n_neighbors=n_neighbors)

                # neighbor ids of every node in source_nodes, as a tensor
                neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

                edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

                # time deltas between each of the 600 query times and its neighbors' interaction times
                edge_deltas = timestamps[:, np.newaxis] - edge_times

                edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

                # flatten to 600 * 10 = 6000 ids
                neighbors = neighbors.flatten()
                # neighbor_embeddings shape = [600 * 10, 172]
                neighbor_embeddings = self.compute_embedding(memory,
                                                             neighbors,  # 6000 ids
                                                             np.repeat(timestamps, n_neighbors),  # also 6000
                                                             n_layers=n_layers - 1,
                                                             n_neighbors=n_neighbors)

                effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
                # neighbor_embeddings shape = [600, 10, 172]
                neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
                # edge_time_embeddings shape is [600, 10, 172]
                edge_time_embeddings = self.time_encoder(edge_deltas_torch)

                # self.edge_features shape [157475, 172]
                # edge_idxs shape [600, 10]
                # edge_features shape [600, 10, 172]
                edge_features = self.edge_features[edge_idxs, :]

                # neighbor id 0 marks padding
                mask = neighbors_torch == 0

                # aggregate is implemented in GraphAttentionEmbedding below. Its inputs:
                # n_layers: 1
                # source_node_conv_embeddings: embeddings of the original 600 nodes
                # source_nodes_time_embedding: encoding of the all-zero time deltas, shape [600, 1, 172]
                #     (not zeros: cos(0 * w + b) = cos(b), which is all ones at initialization)
                # neighbor_embeddings: embeddings of the neighbors those 600 nodes interacted with
                # edge_time_embeddings: encoding of the time deltas
                # edge_features: features of the edges between each of the 600 nodes and its 10 neighbors
                # mask: shape [600, 10]
                source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                                  source_nodes_time_embedding,
                                                  neighbor_embeddings,
                                                  edge_time_embeddings,
                                                  edge_features,
                                                  mask)

                return source_embedding

        def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                      neighbor_embeddings,
                      edge_time_embeddings, edge_features, mask):
            return NotImplemented


    class GraphAttentionEmbedding(GraphEmbedding):
        def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                     n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                     n_heads=2, dropout=0.1, use_memory=True):
            super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                          neighbor_finder, time_encoder, n_layers,
                                                          n_node_features, n_edge_features,
                                                          n_time_features,
                                                          embedding_dimension, device,
                                                          n_heads, dropout,
                                                          use_memory)

            self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
                n_node_features=n_node_features,
                n_neighbors_features=n_node_features,
                n_edge_features=n_edge_features,
                time_dim=n_time_features,
                n_head=n_heads,
                dropout=dropout,
                output_dimension=n_node_features)
                for _ in range(n_layers)])

        def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                      neighbor_embeddings,
                      edge_time_embeddings, edge_features, mask):
            attention_model = self.attention_models[n_layer - 1]

            source_embedding, _ = attention_model(source_node_features,
                                                  source_nodes_time_embedding,
                                                  neighbor_embeddings,
                                                  edge_time_embeddings,
                                                  edge_features,
                                                  mask)

            return source_embedding
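compute_embedding recurses on both the query nodes and their sampled neighbors. As a result, the number of layer-0 lookups (memory plus raw node features) grows as n·(1+k)^L, for n query nodes, k neighbors and L layers.

The small sketch below follows the same recursion but counts lookups instead of computing tensors. `count_base_lookups` is a hypothetical helper.

```python
def count_base_lookups(n_nodes: int, n_layers: int, n_neighbors: int) -> int:
    """Mirror compute_embedding's recursion, counting layer-0 feature lookups."""
    if n_layers == 0:
        return n_nodes  # memory[source_nodes] + node_features[source_nodes]
    k = n_neighbors if n_neighbors > 0 else 1  # effective_n_neighbors
    # one call on the same nodes + one call on the flattened neighbors
    return (count_base_lookups(n_nodes, n_layers - 1, n_neighbors)
            + count_base_lookups(n_nodes * k, n_layers - 1, n_neighbors))


one_layer = count_base_lookups(600, 1, 10)   # 600 + 6000
two_layers = count_base_lookups(600, 2, 10)
```

With 600 query nodes and 10 neighbors, one layer needs 6,600 lookups and two layers need 72,600. This is why the number of layers is usually kept small.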
TimeEncode
    class TimeEncode(torch.nn.Module):
        # Time Encoding proposed by TGAT
        def __init__(self, dimension):
            super(TimeEncode, self).__init__()

            self.dimension = dimension  # 172
            self.w = torch.nn.Linear(1, dimension)

            # todo
            self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                               .float().reshape(dimension, -1))
            self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

        def forward(self, t):  # -> [batch_size, seq_len, dimension]
            # t has shape [batch_size, seq_len]
            # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
            t = t.unsqueeze(dim=2)

            # output has shape [batch_size, seq_len, dimension]
            output = torch.cos(self.w(t))

            return output
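TimeEncode computes cos(w·Δt + b). Its frequencies are initialised to w_i = 1 / 10^{linspace(0, 9, d)_i} and its bias to b = 0.

The plain-Python sketch below reproduces the initial (untrained) encoding. It shows why encoding a zero time delta, as compute_embedding does for the query node, gives a vector of ones at initialisation. Because the bias is a trainable parameter, this only holds before training.

```python
import math


def linspace(start, stop, num):
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def time_encode(delta_t, dimension=172, bias=None):
    """cos(w * t + b) with TGN's initial frequencies; the bias defaults to zeros."""
    w = [1.0 / 10 ** e for e in linspace(0, 9, dimension)]
    b = bias if bias is not None else [0.0] * dimension
    return [math.cos(wi * delta_t + bi) for wi, bi in zip(w, b)]


enc_zero = time_encode(0.0)                # query node: time span == 0
enc = time_encode(1000.0, dimension=4)     # frequencies 1, 1e-3, 1e-6, 1e-9
```

The geometric spread of frequencies lets the first dimensions react to small time gaps, while the last dimensions change only over very long gaps.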
NeighborFinder
    class NeighborFinder:
        def __init__(self, adj_list, uniform=False, seed=None):
            self.node_to_neighbors = []
            self.node_to_edge_idxs = []
            self.node_to_edge_timestamps = []

            for neighbors in adj_list:
                # Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
                # We sort the list based on timestamp
                sorted_neighhbors = sorted(neighbors, key=lambda x: x[2])
                # One array per node: entry i holds the neighbors node i interacted with, sorted by time
                self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors]))
                self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors]))
                self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors]))

            self.uniform = uniform

            if seed is not None:
                self.seed = seed
                self.random_state = np.random.RandomState(self.seed)

        def find_before(self, src_idx, cut_time):
            """
            Extracts all the interactions happening before cut_time for user src_idx in the overall
            interaction graph. The returned interactions are sorted by time.

            Returns 3 lists: neighbors, edge_idxs, timestamps
            """
            i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

            return (self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i],
                    self.node_to_edge_timestamps[src_idx][:i])

        def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
            """
            Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood
            of each user in the list.

            Params
            ------
            src_idx_l: List[int]
            cut_time_l: List[float],
            num_neighbors: int
            """
            assert (len(source_nodes) == len(timestamps))

            tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
            # NB! All interactions described in these matrices are sorted in each row by time
            # shape [600, 10]; each entry in position (i,j) represent the id of the item targeted by
            # user src_idx_l[i] with an interaction happening before cut_time_l[i]
            neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)
            # each entry in position (i,j) represent the timestamp of an interaction between user
            # src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]
            edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.float32)
            # each entry in position (i,j) represent the interaction index of an interaction between
            # user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]
            edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)

            for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
                # extracts all neighbors, interactions indexes and timestamps of all interactions
                # of user source_node happening before cut_time
                source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                                         timestamp)

                if len(source_neighbors) > 0 and n_neighbors > 0:
                    if self.uniform:  # if we are applying uniform sampling, shuffles the data above before sampling
                        sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)

                        neighbors[i, :] = source_neighbors[sampled_idx]
                        edge_times[i, :] = source_edge_times[sampled_idx]
                        edge_idxs[i, :] = source_edge_idxs[sampled_idx]

                        # re-sort based on time
                        pos = edge_times[i, :].argsort()
                        neighbors[i, :] = neighbors[i, :][pos]
                        edge_times[i, :] = edge_times[i, :][pos]
                        edge_idxs[i, :] = edge_idxs[i, :][pos]
                    else:
                        # Take most recent interactions
                        source_edge_times = source_edge_times[-n_neighbors:]
                        source_neighbors = source_neighbors[-n_neighbors:]
                        source_edge_idxs = source_edge_idxs[-n_neighbors:]

                        assert (len(source_neighbors) <= n_neighbors)
                        assert (len(source_edge_times) <= n_neighbors)
                        assert (len(source_edge_idxs) <= n_neighbors)

                        neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
                        edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
                        edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

            return neighbors, edge_idxs, edge_times
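find_before relies on each node's interaction list being sorted by timestamp. np.searchsorted, whose default side='left' behaves like bisect.bisect_left, therefore finds how many interactions happened strictly before cut_time in O(log d).

The non-uniform branch then keeps the last n_neighbors of them, right-aligned in a zero-filled row. This is why neighbor id 0 later acts as padding.

The plain-Python sketch below covers that branch only. `most_recent_neighbors` is a hypothetical helper and assumes n_neighbors > 0, as the real code checks.

```python
from bisect import bisect_left


def most_recent_neighbors(neighbors, timestamps, cut_time, n_neighbors):
    """neighbors/timestamps: one node's interactions, sorted by timestamp; n_neighbors > 0."""
    i = bisect_left(timestamps, cut_time)  # number of interactions strictly before cut_time
    recent_nbrs = neighbors[:i][-n_neighbors:]
    recent_times = timestamps[:i][-n_neighbors:]
    pad = n_neighbors - len(recent_nbrs)
    # Zero-fill on the left so the most recent interactions sit at the end of the row.
    return [0] * pad + recent_nbrs, [0.0] * pad + recent_times


nbrs_a, times_a = most_recent_neighbors([4, 9, 2, 7], [1.0, 3.0, 5.0, 5.0], 5.0, 3)
nbrs_b, times_b = most_recent_neighbors([4, 9, 2, 7], [1.0, 3.0, 5.0, 5.0], 6.0, 3)
```

Interactions at exactly cut_time (the two at time 5.0) are excluded in the first call. This keeps the embedding of an interaction from seeing that interaction itself.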
TemporalAttentionLayer

    class TemporalAttentionLayer(torch.nn.Module):
        """
        Temporal attention layer. Return the temporal embedding of a node given the node itself,
        its neighbors and the edge timestamps.
        """

        def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
                     output_dimension, n_head=2,
                     dropout=0.1):
            super(TemporalAttentionLayer, self).__init__()

            self.n_head = n_head

            self.feat_dim = n_node_features
            self.time_dim = time_dim

            self.query_dim = n_node_features + time_dim
            self.key_dim = n_neighbors_features + time_dim + n_edge_features

            self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

            self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                           kdim=self.key_dim,
                                                           vdim=self.key_dim,
                                                           num_heads=n_head,
                                                           dropout=dropout)

        def forward(self, src_node_features, src_time_features, neighbors_features,
                    neighbors_time_features, edge_features, neighbors_padding_mask):
            """
            Temporal attention model
            :param src_node_features: float Tensor of shape [batch_size, n_node_features]
            :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
            :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
            :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
            time_dim]
            :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
            :param neighbors_padding_mask: float Tensor of shape [batch_size, n_neighbors]
            :return:
            attn_output: float Tensor of shape [1, batch_size, n_node_features]
            attn_output_weights: [batch_size, 1, n_neighbors]
            """

            # src_node_features_unrolled shape is [600, 1, 172]
            src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

            # concatenate the node features with the time features
            # query shape is [600, 1, 172 * 2]
            query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
            # concatenate neighbor features, edge features and time-delta features
            # key shape = [600, 10, 516]
            key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

            # query shape is [1, 600, 344]
            query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
            # key shape is [10, 600, 516]
            key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

            # True for a row (source node) whose neighbors are all padding, i.e. it has no valid neighbor
            invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
            # Unmask the first (padding) neighbor of such nodes: a softmax over a fully masked row
            # yields NaN, which would propagate into the gradients. Their output is zeroed below.
            neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False

            # print(query.shape, key.shape)

            attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                                      key_padding_mask=neighbors_padding_mask)

            # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
            # mask = mask.permute([0, 2, 1])
            # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
            #                                                           mask=mask)

            # attn_output shape = [600, 344]
            # attn_output_weights shape = [600, 10]
            attn_output = attn_output.squeeze()
            attn_output_weights = attn_output_weights.squeeze()

            # Source nodes with no neighbors have an all zero attention output. The attention output is
            # then added or concatenated to the original source node features and then fed into an MLP.
            # This means that an all zero vector is not used.
            attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
            attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

            # Skip connection with temporal attention over neighborhood and the features of the node itself
            # attn_output shape = [600, 172]
            attn_output = self.merger(attn_output, src_node_features)

            return attn_output, attn_output_weights
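The attention step is scaled dot-product attention in which padded neighbors (id 0) are excluded through key_padding_mask. When every neighbor of a node is padding, the softmax over that row is undefined. TemporalAttentionLayer handles this case by computing invalid_neighborhood_mask and zeroing those rows with masked_fill.

The single-head, plain-Python sketch below has no learned projections. This is a simplification: the real layer uses nn.MultiheadAttention with a 344-dimensional query and 516-dimensional keys and values (kdim/vdim). `attend` is a hypothetical helper.

```python
import math


def attend(query, keys, values, padding_mask):
    """Single-head scaled dot-product attention over one node's neighbors.

    padding_mask[j] is True for padded neighbors (id 0), as in key_padding_mask.
    """
    if all(padding_mask):
        # No valid neighbor: the softmax would be undefined, so return all-zero output
        # and weights (what the masked_fill(invalid_neighborhood_mask, 0) calls produce).
        return [0.0] * len(values[0]), [0.0] * len(keys)
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale if not masked else -math.inf
              for key, masked in zip(keys, padding_mask)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0 for masked slots
    total = sum(exps)
    weights = [e / total for e in exps]
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return output, weights


out, w = attend([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
                [[1.0], [3.0], [100.0]], [False, False, True])
```

The padded third neighbor has a large value (100.0) but receives zero weight, so it cannot leak into the output.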
MergeLayer
    class MergeLayer(torch.nn.Module):
        def __init__(self, dim1, dim2, dim3, dim4):
            super().__init__()
            self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
            self.fc2 = torch.nn.Linear(dim3, dim4)
            self.act = torch.nn.ReLU()

            torch.nn.init.xavier_normal_(self.fc1.weight)
            torch.nn.init.xavier_normal_(self.fc2.weight)

        def forward(self, x1, x2):
            x = torch.cat([x1, x2], dim=1)
            h = self.act(self.fc1(x))
            return self.fc2(h)
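MergeLayer is a two-layer MLP on the concatenation of its inputs: fc2(ReLU(fc1([x1; x2]))). In TemporalAttentionLayer it is built as MergeLayer(query_dim, n_node_features, n_node_features, output_dimension). With 172-dimensional features, fc1 therefore sees 344 + 172 = 516 inputs.

The plain-Python sketch below runs one forward pass for a single row, with hand-set weights. This is illustrative only; the real layer uses nn.Linear weights with Xavier-normal initialisation.

```python
def linear(w, b, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]


def merge_layer(x1, x2, fc1, fc2):
    """fc2(ReLU(fc1(concat(x1, x2)))) for one example; fc1/fc2 are (weights, bias) pairs."""
    x = x1 + x2  # torch.cat([x1, x2], dim=1) for a single row
    h = [max(0.0, v) for v in linear(*fc1, x)]
    return linear(*fc2, h)


fc1 = ([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]], [0.0, 0.0])  # 3 inputs -> 2 hidden units
fc2 = ([[1.0, 1.0]], [0.5])                             # 2 hidden units -> 1 output
result = merge_layer([2.0], [3.0, 5.0], fc1, fc2)
```

The first hidden unit computes 2 - 5 = -3 and is clipped by the ReLU. Only the second unit (3.0) plus the bias reaches the output.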

If you repost this article, please credit the source: http://www.pswp.cn/news/212390.shtml