Table of Contents
- Preface
- 1. class CVRP_Decoder(nn.Module): __init__(self, **model_params)
  - Function purpose
  - Function code
- 2. class CVRP_Decoder(nn.Module): set_kv(self, encoded_nodes)
  - Function purpose
  - Function code
- 3. class CVRP_Decoder(nn.Module): set_q1(self, encoded_q1)
  - Function purpose
  - Function code
- 4. class CVRP_Decoder(nn.Module): set_q2(self, encoded_q2)
  - Function purpose
  - Function code
- 5. class CVRP_Decoder(nn.Module): forward(self, encoded_last_node, load, ninf_mask)
  - Function purpose
  - Function code
- Appendix
  - Full code of class CVRP_Decoder
Preface
class CVRP_Decoder is a class defined in CVRPModel.py:
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py
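1. class CVRP_Decoder(nn.Module): __init__(self, **model_params)
Function purpose
__init__ is the constructor of the CVRP_Decoder class. It creates all the layers, weight matrices, and cached attributes that the decoder needs.
Specifically, it defines the linear projections used by multi-head attention (Wq_2, Wq_last, Wk, Wv, and the output layer multi_head_combine), a learnable regret_embedding parameter initialized uniformly in [-1, 1], and the placeholders k, v, single_head_key, and q2, which are filled in later by set_kv and set_q2. The Wq_1 layer and the q1 attribute are commented out in this version of the code.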
Function code

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # Query/key/value projections for multi-head attention
        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        # +1 input feature: the current load is concatenated in forward()
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        # Learnable regret embedding, initialized uniformly in [-1, 1]
        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        # Merges the concatenated heads back to embedding_dim
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention
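To make the layer shapes concrete, here is a minimal sketch that rebuilds two of the projections and the regret embedding outside the class and checks the resulting sizes. The hyperparameter values (embedding_dim=128, head_num=8, qkv_dim=16) and the batch and pomo sizes are illustrative assumptions, not values stated in this article. The sketch uses torch.empty in place of torch.Tensor(embedding_dim) for the uninitialized allocation. It also shows why Wq_last takes embedding_dim+1 input features: forward appends the scalar load to the last node's embedding.

```python
import torch
import torch.nn as nn

# Illustrative hyperparameters (assumed; the article does not state them).
embedding_dim, head_num, qkv_dim = 128, 8, 16
batch, pomo = 2, 5

Wq_last = nn.Linear(embedding_dim + 1, head_num * qkv_dim, bias=False)
Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

# Learnable regret embedding: allocate, then fill uniformly from [-1, 1].
regret_embedding = nn.Parameter(torch.empty(embedding_dim))
regret_embedding.data.uniform_(-1, 1)

encoded_last_node = torch.randn(batch, pomo, embedding_dim)
load = torch.rand(batch, pomo)

# Appending the load as one extra feature yields embedding_dim + 1 inputs.
input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
q_last_flat = Wq_last(input_cat)

print(input_cat.shape)         # last dimension is embedding_dim + 1
print(q_last_flat.shape)       # last dimension is head_num * qkv_dim
print(regret_embedding.shape)  # one vector of length embedding_dim
```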
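2. class CVRP_Decoder(nn.Module): set_kv(self, encoded_nodes)
Function purpose
set_kv converts the node embeddings in encoded_nodes into the keys (K) and values (V) required by multi-head attention and stores them as attributes of the class.
The node embeddings are passed through the linear layers Wk and Wv and then split into heads by reshape_by_heads, which gives self.k and self.v. The method also stores the transposed node embeddings as self.single_head_key, which forward uses later in the single-head scoring step.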
Function code

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)
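reshape_by_heads is defined elsewhere in CVRPModel.py and is not shown in this article. The following is a minimal sketch of a helper that matches the shape comments above; the name reshape_by_heads_sketch and all tensor sizes are assumptions made for illustration. It splits the last dimension into (head_num, qkv_dim) and moves the head axis in front of the node axis.

```python
import torch

def reshape_by_heads_sketch(qkv, head_num):
    """Assumed behavior: (batch, n, head_num*qkv_dim) -> (batch, head_num, n, qkv_dim)."""
    batch_s, n = qkv.size(0), qkv.size(1)
    # Split the last dimension into (head_num, qkv_dim) ...
    reshaped = qkv.reshape(batch_s, n, head_num, -1)
    # ... then swap the node axis and the head axis.
    return reshaped.transpose(1, 2)

# Illustrative sizes (assumed).
batch, n_nodes, head_num, qkv_dim, embedding_dim = 2, 6, 8, 16, 128

projected = torch.randn(batch, n_nodes, head_num * qkv_dim)  # stands in for Wk(encoded_nodes)
k = reshape_by_heads_sketch(projected, head_num)
print(k.shape)  # (batch, head_num, n_nodes, qkv_dim)

# The single-head key is just the node embeddings with the last two axes swapped.
encoded_nodes = torch.randn(batch, n_nodes, embedding_dim)
single_head_key = encoded_nodes.transpose(1, 2)
print(single_head_key.shape)  # (batch, embedding_dim, n_nodes)
```

Head 0 of node 0 is the first qkv_dim entries of that node's projected vector, so no values are mixed across heads; the reshape only regroups them.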
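3. class CVRP_Decoder(nn.Module): set_q1(self, encoded_q1)
Function purpose
set_q1 computes a query (Q) and converts it into the shape required by multi-head attention.
It takes the query tensor encoded_q1, projects it with the linear layer self.Wq_1, and uses reshape_by_heads to split the result into heads. The result is stored as the attribute q1 for later use. Note that self.Wq_1 is commented out in __init__ in this version of the code, and forward does not use q1, so calling set_q1 as written would raise an AttributeError; the method appears to be kept only for reference.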
Function code

    def set_q1(self, encoded_q1):
        # encoded_q1.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        # NOTE: self.Wq_1 is commented out in __init__, so this line fails
        # unless that layer is restored.
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)
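The effect of the commented-out Wq_1 layer can be reproduced with a small stand-in module. TinyDecoder below is a hypothetical class written only for this illustration; it is not the real CVRP_Decoder and it skips reshape_by_heads. It shows that a method referencing a layer that was never created fails with AttributeError, while the q2 path works.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in for CVRP_Decoder, reduced to the two query layers."""

    def __init__(self, embedding_dim=8):
        super().__init__()
        # self.Wq_1 = nn.Linear(embedding_dim, embedding_dim, bias=False)  # disabled, as in the article
        self.Wq_2 = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def set_q1(self, encoded_q1):
        self.q1 = self.Wq_1(encoded_q1)  # Wq_1 was never created

    def set_q2(self, encoded_q2):
        self.q2 = self.Wq_2(encoded_q2)

decoder = TinyDecoder()
x = torch.randn(1, 1, 8)

decoder.set_q2(x)  # works: Wq_2 exists

try:
    decoder.set_q1(x)
    q1_failed = False
except AttributeError as err:
    q1_failed = True
    print("set_q1 failed:", err)
```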
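4. class CVRP_Decoder(nn.Module): set_q2(self, encoded_q2)
Function purpose
set_q2 computes a query (Q) and converts it into the shape required by multi-head attention.
It takes the query tensor encoded_q2, projects it with the linear layer self.Wq_2, and uses reshape_by_heads to split the result into heads. The result is stored as the attribute q2, which forward adds to q_last to form the final query.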
Function code

    def set_q2(self, encoded_q2):
        # encoded_q2.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)
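The comment "n can be 1 or pomo" matters because forward computes q = self.q2 + q_last, where q_last has shape (batch, head_num, pomo, qkv_dim). A short sketch with assumed sizes shows that both cases produce a query of the same shape: with n = 1, PyTorch broadcasting reuses the single q2 vector for every rollout.

```python
import torch

# Illustrative sizes (assumed).
batch, head_num, pomo, qkv_dim = 2, 8, 5, 16

q2_single = torch.randn(batch, head_num, 1, qkv_dim)       # n = 1
q2_per_pomo = torch.randn(batch, head_num, pomo, qkv_dim)  # n = pomo
q_last = torch.randn(batch, head_num, pomo, qkv_dim)

q_a = q2_single + q_last    # q2 is broadcast along the pomo axis
q_b = q2_per_pomo + q_last  # element-wise sum, no broadcasting needed

print(q_a.shape, q_b.shape)  # both (batch, head_num, pomo, qkv_dim)
```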
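5. class CVRP_Decoder(nn.Module): forward(self, encoded_last_node, load, ninf_mask)
Function purpose
forward is the forward pass of the CVRP_Decoder class. It runs a multi-head attention step followed by a single-head attention step and returns probs, the selection probability of each candidate node for every rollout. The attention here is between the decoder query and the encoded nodes, not self-attention, and there is no feed-forward network in the decoder; the heads are merged by the single linear layer multi_head_combine.
The steps are: concatenate the last node's embedding with the current load and project the result to q_last; add the cached q2 to form the query q; attend over the cached keys and values under ninf_mask; merge the heads with multi_head_combine; compute single-head scores against single_head_key; scale the scores by sqrt_embedding_dim, clip them with logit_clipping * tanh, add the mask, and apply softmax.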
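The multi-head attention step described above relies on the helper multi_head_attention, which is defined elsewhere in CVRPModel.py and is not shown in this article. The following sketch is one plausible implementation of scaled dot-product attention with an additive -inf mask that matches the shapes used by forward; the function name multi_head_attention_sketch, the tensor sizes, and the details of the scaling are assumptions.

```python
import math
import torch
import torch.nn.functional as F

def multi_head_attention_sketch(q, k, v, rank3_ninf_mask=None):
    # q: (batch, head_num, n, qkv_dim); k, v: (batch, head_num, nodes, qkv_dim)
    batch_s, head_num, n, qkv_dim = q.shape

    # Scaled dot-product scores: (batch, head_num, n, nodes)
    score = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(qkv_dim)

    if rank3_ninf_mask is not None:
        # (batch, n, nodes) -> (batch, 1, n, nodes), broadcast over heads
        score = score + rank3_ninf_mask[:, None, :, :]

    weights = F.softmax(score, dim=3)  # masked nodes get weight 0
    out = torch.matmul(weights, v)     # (batch, head_num, n, qkv_dim)

    # Concatenate the heads: (batch, n, head_num * qkv_dim)
    return out.transpose(1, 2).reshape(batch_s, n, head_num * qkv_dim)

# Illustrative sizes (assumed).
batch, head_num, pomo, nodes, qkv_dim = 2, 8, 5, 6, 16
q = torch.randn(batch, head_num, pomo, qkv_dim)
k = torch.randn(batch, head_num, nodes, qkv_dim)
v = torch.randn(batch, head_num, nodes, qkv_dim)

mask = torch.zeros(batch, pomo, nodes)
mask[:, :, 0] = float('-inf')  # forbid node 0 for every rollout

out_concat = multi_head_attention_sketch(q, k, v, rank3_ninf_mask=mask)
print(out_concat.shape)  # (batch, pomo, head_num * qkv_dim)
```

Because a masked node receives an attention weight of exactly zero, its value vector has no effect on the output, which is how infeasible nodes are kept out of the context.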
Function code

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        # Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # Alternatives left disabled by the author:
        # q = self.q1 + self.q2 + q_last
        # q = q_last
        q = self.q2 + q_last
        # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        # Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs
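The single-head part of forward can be tried in isolation. The sketch below repeats the scale, clip, mask, and softmax steps on random tensors; embedding_dim=128, logit_clipping=10.0, and the tensor sizes are assumed values chosen for illustration, whereas the real values come from model_params. It shows two properties of the output: a node masked with -inf gets probability exactly zero, and the tanh clipping bounds every logit to the range [-logit_clipping, logit_clipping] before masking.

```python
import torch
import torch.nn.functional as F

# Illustrative values (assumed; the model reads them from model_params).
embedding_dim, logit_clipping = 128, 10.0
sqrt_embedding_dim = embedding_dim ** 0.5
batch, pomo, nodes = 1, 2, 4

mh_atten_out = torch.randn(batch, pomo, embedding_dim)
single_head_key = torch.randn(batch, embedding_dim, nodes)

ninf_mask = torch.zeros(batch, pomo, nodes)
ninf_mask[0, 0, 1] = float('-inf')  # rollout 0 may not pick node 1

score = torch.matmul(mh_atten_out, single_head_key)  # (batch, pomo, nodes)
score_clipped = logit_clipping * torch.tanh(score / sqrt_embedding_dim)
probs = F.softmax(score_clipped + ninf_mask, dim=2)

print(probs[0, 0])       # entry 1 is zero
print(probs.sum(dim=2))  # each rollout's probabilities sum to 1
```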
Appendix
Full code of class CVRP_Decoder

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # reshape_by_heads and multi_head_attention are helper functions defined
    # elsewhere in CVRPModel.py; they are not reproduced in this article.


    class CVRP_Decoder(nn.Module):
        def __init__(self, **model_params):
            super().__init__()
            self.model_params = model_params
            embedding_dim = self.model_params['embedding_dim']
            head_num = self.model_params['head_num']
            qkv_dim = self.model_params['qkv_dim']

            # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
            self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
            self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
            self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
            self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

            self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
            self.regret_embedding.data.uniform_(-1, 1)

            self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

            self.k = None  # saved key, for multi-head attention
            self.v = None  # saved value, for multi-head_attention
            self.single_head_key = None  # saved, for single-head attention
            # self.q1 = None  # saved q1, for multi-head attention
            self.q2 = None  # saved q2, for multi-head attention

        def set_kv(self, encoded_nodes):
            # encoded_nodes.shape: (batch, problem+1, embedding)
            head_num = self.model_params['head_num']

            self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
            self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
            # shape: (batch, head_num, problem+1, qkv_dim)
            self.single_head_key = encoded_nodes.transpose(1, 2)
            # shape: (batch, embedding, problem+1)

        def set_q1(self, encoded_q1):
            # encoded_q1.shape: (batch, n, embedding)  # n can be 1 or pomo
            head_num = self.model_params['head_num']
            # NOTE: self.Wq_1 is commented out in __init__, so this line fails
            # unless that layer is restored.
            self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
            # shape: (batch, head_num, n, qkv_dim)

        def set_q2(self, encoded_q2):
            # encoded_q2.shape: (batch, n, embedding)  # n can be 1 or pomo
            head_num = self.model_params['head_num']
            self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
            # shape: (batch, head_num, n, qkv_dim)

        def forward(self, encoded_last_node, load, ninf_mask):
            # encoded_last_node.shape: (batch, pomo, embedding)
            # load.shape: (batch, pomo)
            # ninf_mask.shape: (batch, pomo, problem)

            head_num = self.model_params['head_num']

            # Multi-Head Attention
            #######################################################
            input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
            # shape = (batch, group, EMBEDDING_DIM+1)

            q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
            # shape: (batch, head_num, pomo, qkv_dim)

            # Alternatives left disabled by the author:
            # q = self.q1 + self.q2 + q_last
            # q = q_last
            q = self.q2 + q_last
            # shape: (batch, head_num, pomo, qkv_dim)

            out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
            # shape: (batch, pomo, head_num*qkv_dim)

            mh_atten_out = self.multi_head_combine(out_concat)
            # shape: (batch, pomo, embedding)

            # Single-Head Attention, for probability calculation
            #######################################################
            score = torch.matmul(mh_atten_out, self.single_head_key)
            # shape: (batch, pomo, problem)

            sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
            logit_clipping = self.model_params['logit_clipping']

            score_scaled = score / sqrt_embedding_dim
            # shape: (batch, pomo, problem)

            score_clipped = logit_clipping * torch.tanh(score_scaled)

            score_masked = score_clipped + ninf_mask

            probs = F.softmax(score_masked, dim=2)
            # shape: (batch, pomo, problem)

            return probs