20250225-代碼筆記03-class CVRPModel AND other class

文章目錄

  • 前言
  • 一、class CVRPModel(nn.Module):__init__(self, **model_params)
    • 函數功能
    • 函數代碼
  • 二、class CVRPModel(nn.Module):pre_forward(self, reset_state)
    • 函數功能
    • 函數代碼
  • 三、class CVRPModel(nn.Module):forward(self, state)
    • 函數功能
    • 函數代碼
  • 四、def _get_encoding(encoded_nodes, node_index_to_pick)
    • 函數功能
    • 函數代碼
  • 五、class CVRP_Encoder(nn.Module)
  • 六、class EncoderLayer(nn.Module)
  • 七、CVRP_Decoder(nn.Module)
  • 八、def reshape_by_heads(qkv, head_num)
    • 函數功能
    • 函數代碼
  • 九、def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None)
    • 函數功能
    • 函數代碼
  • 十、class AddAndInstanceNormalization(nn.Module):__init__(self, **model_params)
    • 函數功能
    • Batch Normalization (BN) 是什么?
      • Batch Normalization 的具體操作
        • 1. **計算均值和方差**
        • 2. **標準化**
        • 3. **縮放和平移**
      • Batch Normalization 的優勢
    • 函數代碼
  • 十一、class AddAndInstanceNormalization(nn.Module):forward(self, input1, input2)
    • 函數功能
    • 函數代碼
  • 十二、class FeedForward(nn.Module):__init__(self, **model_params)
    • 函數功能
    • 函數代碼
  • 十三、class FeedForward(nn.Module):forward(self, input1)
    • 函數功能
    • 函數代碼
  • 附錄
    • 代碼(全)


前言

學習代碼:
class CVRPModel(nn.Module):
class CVRP_Encoder(nn.Module):
class EncoderLayer(nn.Module):
class CVRP_Decoder(nn.Module):
class AddAndInstanceNormalization(nn.Module):
class AddAndBatchNormalization(nn.Module):
class FeedForward(nn.Module):

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRPModel(nn.Module):init(self, **model_params)

函數功能

init 是 CVRPModel 類的構造函數,負責初始化模型的各個組件。
主要任務包括:

  • 接收和存儲模型的參數(model_params)。
  • 初始化編碼器(encoder)和解碼器(decoder)子模塊。
  • 初始化 encoded_nodes 變量,用于存儲經過編碼的節點數據。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def __init__(self, **model_params):super().__init__()self.model_params = model_paramsself.encoder = CVRP_Encoder(**model_params)self.decoder = CVRP_Decoder(**model_params)self.encoded_nodes = None# shape: (batch, problem+1, EMBEDDING_DIM)

二、class CVRPModel(nn.Module):pre_forward(self, reset_state)

函數功能

pre_forward 是 CVRPModel 類的一個前向傳播前的準備函數。它的主要任務是根據給定的初始狀態(reset_state)準備和編碼數據,為模型的后續前向傳播(forward)過程做準備
具體來說,函數的作用是:

  • 提取并處理初始狀態的數據。
  • 使用編碼器對節點進行編碼,得到編碼后的節點表示。
  • 為解碼器設置額外的嵌入信息,并將編碼后的節點與額外的嵌入信息拼接。
  • 設置解碼器中的 kv(key-value)信息,為解碼過程做準備。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def pre_forward(self, reset_state):depot_xy = reset_state.depot_xy# shape: (batch, 1, 2)node_xy = reset_state.node_xy# shape: (batch, problem, 2)node_demand = reset_state.node_demand# shape: (batch, problem)node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)# shape: (batch, problem, 3)encoded_nodes = self.encoder(depot_xy, node_xy_demand)# shape: (batch, problem+1, embedding)_ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1,self.decoder.regret_embedding.size(-1))# _ 的shape:(batch,1,embedding)self.encoded_nodes = torch.cat((encoded_nodes, _), dim=1)# self.encoded_nodes的shape:(batch,problem+2,embedding)self.decoder.set_kv(self.encoded_nodes)

三、class CVRPModel(nn.Module):forward(self, state)

函數功能

forward 是 CVRPModel 類的核心前向傳播函數,用于根據當前狀態(state)生成模型的輸出,包括選擇的節點(selected)和相關的概率(prob)。
它的主要功能是基于當前的狀態和歷史選擇來決定接下來應該選擇哪個節點,并輸出相應的概率。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def forward(self, state):batch_size = state.BATCH_IDX.size(0)pomo_size = state.BATCH_IDX.size(1)if state.selected_count == 0:  # First Move, depotselected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)prob = torch.ones(size=(batch_size, pomo_size))# # Use Averaged encoded nodes for decoder input_1# encoded_nodes_mean = self.encoded_nodes.mean(dim=1, keepdim=True)# # shape: (batch, 1, embedding)# self.decoder.set_q1(encoded_nodes_mean)# Use encoded_depot for decoder input_2encoded_first_node = self.encoded_nodes[:, [0], :]# shape: (batch, 1, embedding)self.decoder.set_q2(encoded_first_node)elif state.selected_count == 1:  # Second Move, POMOselected = torch.arange(start=1, end=pomo_size+1)[None, :].expand(batch_size, pomo_size)prob = torch.ones(size=(batch_size, pomo_size))else:encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)# shape: (batch, pomo, embedding)probs = self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask)# shape: (batch, pomo, problem+1)if self.training or self.model_params['eval_type'] == 'softmax':while True:  # to fix pytorch.multinomial bug on selecting 0 probability elementswith torch.no_grad():selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \.squeeze(dim=1).reshape(batch_size, pomo_size)# shape: (batch, pomo)prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)# shape: (batch, pomo)if (prob != 0).all():breakelse:probs=probs[:,:,:-1]selected = probs.argmax(dim=2)# shape: (batch, pomo)prob = None  # value not needed. Can be anything.return selected, prob

四、def _get_encoding(encoded_nodes, node_index_to_pick)

函數功能

_get_encoding 的作用是從 encoded_nodes 中按照 node_index_to_pick 選擇相應的編碼,并返回選中的編碼信息。

函數執行流程圖鏈接
在這里插入圖片描述

函數代碼

def _get_encoding(encoded_nodes, node_index_to_pick):# encoded_nodes.shape: (batch, problem, embedding)# node_index_to_pick.shape: (batch, pomo)batch_size = node_index_to_pick.size(0)pomo_size = node_index_to_pick.size(1)embedding_dim = encoded_nodes.size(2)gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)# shape: (batch, pomo, embedding)picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)# shape: (batch, pomo, embedding)return picked_nodes

五、class CVRP_Encoder(nn.Module)

筆記:20250226-代碼筆記04-class CVRP_Encoder AND class EncoderLayer


六、class EncoderLayer(nn.Module)

筆記:20250226-代碼筆記04-class CVRP_Encoder AND class EncoderLayer


七、CVRP_Decoder(nn.Module)

筆記:20250226-代碼筆記05-class CVRP_Decoder


八、def reshape_by_heads(qkv, head_num)

函數功能

reshape_by_heads 函數的功能是將輸入的張量(如查詢 q, 鍵 k, 或值 v)從一個緊湊的多頭結構 (batch, n, head_num * key_dim) 轉換為適合多頭注意力機制計算的結構 (batch, head_num, n, key_dim)
此操作將多個注意力頭的維度進行拆分,并將其調整為每個頭獨立計算的格式。
執行流程圖鏈接
在這里插入圖片描述

函數代碼

def reshape_by_heads(qkv, head_num):# q.shape: (batch, n, head_num*key_dim)   : n can be either 1 or PROBLEM_SIZEbatch_s = qkv.size(0)n = qkv.size(1)q_reshaped = qkv.reshape(batch_s, n, head_num, -1)# shape: (batch, n, head_num, key_dim)q_transposed = q_reshaped.transpose(1, 2)# shape: (batch, head_num, n, key_dim)return q_transposed

九、def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None)

函數功能

multi_head_attention 函數的主要功能是實現 多頭注意力機制。該函數接收查詢(Q)、鍵(K)和值(V),并計算多頭注意力輸出。它通過計算查詢與鍵之間的相似度,生成加權值的結果,并結合所有頭的輸出生成最終的注意力表示。
執行流程圖鏈接
在這里插入圖片描述

函數代碼

def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):# q shape: (batch, head_num, n, key_dim)   : n can be either 1 or PROBLEM_SIZE# k,v shape: (batch, head_num, problem, key_dim)# rank2_ninf_mask.shape: (batch, problem)# rank3_ninf_mask.shape: (batch, group, problem)batch_s = q.size(0)head_num = q.size(1)n = q.size(2)key_dim = q.size(3)input_s = k.size(2)score = torch.matmul(q, k.transpose(2, 3))# shape: (batch, head_num, n, problem)score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))if rank2_ninf_mask is not None:score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)if rank3_ninf_mask is not None:score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)weights = nn.Softmax(dim=3)(score_scaled)# shape: (batch, head_num, n, problem)out = torch.matmul(weights, v)# shape: (batch, head_num, n, key_dim)out_transposed = out.transpose(1, 2)# shape: (batch, n, head_num, key_dim)out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)# shape: (batch, n, head_num*key_dim)return out_concat

十、class AddAndInstanceNormalization(nn.Module):init(self, **model_params)

函數功能

對輸入數據進行基于嵌入維度的批量標準化操作,從而使得模型在訓練過程中能夠更好地收斂和提高穩定性。

Batch Normalization (BN) 是什么?

Batch Normalization (BN) 是一種在訓練深度神經網絡時常用的技術,它的目的是提高網絡的訓練速度、穩定性,并幫助避免梯度消失或爆炸問題。
Batch Normalization 操作的核心思想是對每一層的輸入數據進行標準化,使得輸入數據的均值接近 0,方差接近 1。這樣可以避免激活函數輸出過大或過小的問題,幫助優化過程更加穩定。

Batch Normalization 的具體操作

1. 計算均值和方差

對于一批輸入樣本(batch),在每個特征維度上計算均值和方差:

  • 均值
    μ B = 1 m ∑ i = 1 m x i \mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i μB?=m1?i=1m?xi?

  • 方差
    σ B 2 = 1 m ∑ i = 1 m ( x i ? μ B ) 2 \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 σB2?=m1?i=1m?(xi??μB?)2

其中, m m m 是一個批次中的樣本數, x i x_i xi?是每個樣本的輸入值。

2. 標準化

使用計算出的均值和方差將輸入數據標準化,使得每個特征的均值為 0,方差為 1:

x ^ i = x i ? μ B σ B 2 + ? \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i?=σB2?+? ?xi??μB??

這里 ? \epsilon ?是一個非常小的數值,用來防止除以零的情況。

3. 縮放和平移

由于標準化可能會影響到模型的表達能力,Batch Normalization 還會引入兩個可學習的參數 γ \gamma γ(縮放參數)和 β \beta β(平移參數),它們允許模型重新調整標準化后的數據:

y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi?=γx^i?+β

其中, γ \gamma γ β \beta β是學習的參數,通常會通過反向傳播進行優化。

Batch Normalization 的優勢

  • 加速訓練:Batch Normalization 通過減少輸入數據的偏移(internal covariate shift),使得每一層的輸入分布更加穩定,從而加速了網絡的訓練過程。
  • 提高穩定性:由于它通過標準化輸入避免了梯度爆炸或梯度消失問題,使得訓練更加穩定。
  • 緩解過擬合:在一些情況下,Batch Normalization 也可以起到正則化的作用,減少了模型對訓練數據的過擬合。
  • 減少對初始化的依賴:Batch Normalization 可以在一定程度上緩解對權重初始化的敏感性。

函數代碼

    def __init__(self, **model_params):super().__init__()embedding_dim = model_params['embedding_dim']self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

十一、class AddAndInstanceNormalization(nn.Module):forward(self, input1, input2)

函數功能

forward 方法,它執行了加法和批量歸一化操作。
forward 方法的主要功能是:

  • 加法操作:將兩個輸入張量 input1input2 相加。
  • 批量歸一化:將加法結果進行批量歸一化(Batch Normalization),標準化其特征維度。
  • 形狀恢復:批量歸一化后,將張量的形狀恢復到原來的維度。

執行流程:

函數代碼

  1. 獲取輸入張量的維度:
batch_s = input1.size(0)
problem_s = input1.size(1)
embedding_dim = input1.size(2)
  • batch_s 表示批次大小,problem_s 表示問題的大小(特征的數量),embedding_dim 表示嵌入的維度。
  • 這些維度來自輸入張量input1,并且假設 input2 具有相同的形狀。
  1. 加法操作:
added = input1 + input2
  • input1 nput2 進行逐元素加法。此時,added 張量的形狀與 input1input2 相同,仍為 (batch_s, problem_s, embedding_dim)
  1. 批量歸一化:
normalized = self.norm_by_EMB(added.reshape(batch_s * problem_s, embedding_dim))
  • added 張量的形狀重塑為 (batch_s * problem_s, embedding_dim),將批次維度和問題維度合并,以便進行批量歸一化操作。這樣就對每個特征維度(embedding_dim)做了批量標準化。
  • self.norm_by_EMB 是一個 BatchNorm1d 層,它會對每個特征維度執行標準化,使得每個特征的均值接近 0,方差接近 1。
  1. 恢復形狀:
back_trans = normalized.reshape(batch_s, problem_s, embedding_dim)
  • 批量歸一化后,將 normalized 張量的形狀恢復回 (batch_s, problem_s, embedding_dim),即恢復原本的輸入形狀。
  1. 返回結果:
return back_trans
  • 返回經過批量歸一化的張量 back_trans,它的形狀與輸入相同,并且每個特征維度已經經過標準化。
    def forward(self, input1, input2):# input.shape: (batch, problem, embedding)added = input1 + input2# shape: (batch, problem, embedding)transposed = added.transpose(1, 2)# shape: (batch, embedding, problem)normalized = self.norm(transposed)# shape: (batch, embedding, problem)back_trans = normalized.transpose(1, 2)# shape: (batch, problem, embedding)return back_trans

十二、class FeedForward(nn.Module):init(self, **model_params)

函數功能

FeedForward 的類,它是一個典型的前饋神經網絡(Feedforward Neural Network)模塊,實現了一個簡單的兩層神經網絡。

  • __init__ 方法是類的構造函數,用來初始化網絡的層和超參數。
  • embedding_dim ff_hidden_dim 是通過 model_params 傳遞的超參數,分別表示嵌入維度和前饋神經網絡隱藏層的維度。
    • embedding_dim 是輸入和輸出的維度。
    • ff_hidden_dim 是隱藏層的維度,即在網絡的中間層。
  • self.W1 self.W2是兩個全連接層nn.Linear):
    • self.W1 將輸入的 embedding_dim 維度的向量轉換為 ff_hidden_dim 維度的向量。
    • self.W2 ff_hidden_dim 維度的向量轉換回 embedding_dim 維度的向量。

函數代碼

    def __init__(self, **model_params):super().__init__()embedding_dim = model_params['embedding_dim']ff_hidden_dim = model_params['ff_hidden_dim']self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)

十三、class FeedForward(nn.Module):forward(self, input1)

函數功能

  • forward 方法定義了數據流通過網絡的方式,也就是前向傳播過程。
  • 輸入 input1 的形狀為 (batch, problem, embedding),即批次大小 batch、問題數量 problem和每個問題的嵌入維度embedding
  • 執行的步驟如下:
    • 1.第一層線性變換(self.W1:輸入通過 self.W1 進行線性變換,將輸入的嵌入維度轉換為隱藏層的維度(ff_hidden_dim)。變換公式為:
      在這里插入圖片描述
      其中 x 是輸入,W1 是權重矩陣,b1 是偏置。

    • 2.激活函數(ReLU):對 self.W1 的輸出應用 ReLU 激活函數,ReLU 將負值歸零,保留正值。公式為:
      在這里插入圖片描述

    • 3.第二層線性變換(self.W2:通過 self.W2 進行線性變換,將隱藏層的輸出轉換回原始的嵌入維度(embedding_dim)。變換公式為:
      在這里插入圖片描述

  • 最終輸出是經過兩層線性變換和 ReLU 激活函數處理的結果,形狀仍然是 (batch, problem, embedding)。

函數代碼

    def forward(self, input1):# input.shape: (batch, problem, embedding)return self.W2(F.relu(self.W1(input1)))

附錄

代碼(全)


import torch
import torch.nn as nn
import torch.nn.functional as Fclass CVRPModel(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsself.encoder = CVRP_Encoder(**model_params)self.decoder = CVRP_Decoder(**model_params)self.encoded_nodes = None# shape: (batch, problem+1, EMBEDDING_DIM)def pre_forward(self, reset_state):depot_xy = reset_state.depot_xy# shape: (batch, 1, 2)node_xy = reset_state.node_xy# shape: (batch, problem, 2)node_demand = reset_state.node_demand# shape: (batch, problem)node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)# shape: (batch, problem, 3)encoded_nodes = self.encoder(depot_xy, node_xy_demand)# shape: (batch, problem+1, embedding)_ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1,self.decoder.regret_embedding.size(-1))# _ 的shape:(batch,1,embedding)self.encoded_nodes = torch.cat((encoded_nodes, _), dim=1)# self.encoded_nodes的shape:(batch,problem+2,embedding)self.decoder.set_kv(self.encoded_nodes)def forward(self, state):batch_size = state.BATCH_IDX.size(0)pomo_size = state.BATCH_IDX.size(1)if state.selected_count == 0:  # First Move, depotselected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)prob = torch.ones(size=(batch_size, pomo_size))# # Use Averaged encoded nodes for decoder input_1# encoded_nodes_mean = self.encoded_nodes.mean(dim=1, keepdim=True)# # shape: (batch, 1, embedding)# self.decoder.set_q1(encoded_nodes_mean)# Use encoded_depot for decoder input_2encoded_first_node = self.encoded_nodes[:, [0], :]# shape: (batch, 1, embedding)self.decoder.set_q2(encoded_first_node)elif state.selected_count == 1:  # Second Move, POMOselected = torch.arange(start=1, end=pomo_size+1)[None, :].expand(batch_size, pomo_size)prob = torch.ones(size=(batch_size, pomo_size))else:encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)# shape: (batch, pomo, embedding)probs = self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask)# shape: (batch, pomo, problem+1)if self.training or self.model_params['eval_type'] == 'softmax':while True:  # to fix pytorch.multinomial bug on selecting 0 probability elementswith torch.no_grad():selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \.squeeze(dim=1).reshape(batch_size, pomo_size)# shape: (batch, pomo)prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)# shape: (batch, pomo)if (prob != 0).all():breakelse:probs=probs[:,:,:-1]selected = probs.argmax(dim=2)# shape: (batch, pomo)prob = None  # value not needed. Can be anything.return selected, probdef _get_encoding(encoded_nodes, node_index_to_pick):# encoded_nodes.shape: (batch, problem, embedding)# node_index_to_pick.shape: (batch, pomo)batch_size = node_index_to_pick.size(0)pomo_size = node_index_to_pick.size(1)embedding_dim = encoded_nodes.size(2)gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)# shape: (batch, pomo, embedding)picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)# shape: (batch, pomo, embedding)return picked_nodes########################################
# ENCODER
########################################class CVRP_Encoder(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']encoder_layer_num = self.model_params['encoder_layer_num']self.embedding_depot = nn.Linear(2, embedding_dim)self.embedding_node = nn.Linear(3, embedding_dim)self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])def forward(self, depot_xy, node_xy_demand):# depot_xy.shape: (batch, 1, 2)# node_xy_demand.shape: (batch, problem, 3)embedded_depot = self.embedding_depot(depot_xy)# shape: (batch, 1, embedding)embedded_node = self.embedding_node(node_xy_demand)# shape: (batch, problem, embedding)out = torch.cat((embedded_depot, embedded_node), dim=1)# shape: (batch, problem+1, embedding)for layer in self.layers:out = layer(out)return out# shape: (batch, problem+1, embedding)class EncoderLayer(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)self.feed_forward = FeedForward(**model_params)self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)def forward(self, input1):# input1.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']q = reshape_by_heads(self.Wq(input1), head_num=head_num)k = reshape_by_heads(self.Wk(input1), head_num=head_num)v = reshape_by_heads(self.Wv(input1), head_num=head_num)# qkv shape: (batch, head_num, problem, qkv_dim)out_concat = multi_head_attention(q, k, v)# shape: (batch, problem, head_num*qkv_dim)multi_head_out = self.multi_head_combine(out_concat)# shape: (batch, problem, embedding)out1 = self.add_n_normalization_1(input1, multi_head_out)out2 = self.feed_forward(out1)out3 = self.add_n_normalization_2(out1, out2)return out3# shape: (batch, problem, embedding)########################################
# DECODER
########################################class CVRP_Decoder(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']# self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))self.regret_embedding.data.uniform_(-1, 1)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.k = None  # saved key, for multi-head attentionself.v = None  # saved value, for multi-head_attentionself.single_head_key = None  # saved, for single-head attention# self.q1 = None  # saved q1, for multi-head attentionself.q2 = None  # saved q2, for multi-head attentiondef set_kv(self, encoded_nodes):# encoded_nodes.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)# shape: (batch, head_num, problem+1, qkv_dim)self.single_head_key = encoded_nodes.transpose(1, 2)# shape: (batch, embedding, problem+1)def set_q1(self, encoded_q1):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def set_q2(self, encoded_q2):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def forward(self, encoded_last_node, load, ninf_mask):# encoded_last_node.shape: (batch, pomo, embedding)# load.shape: (batch, pomo)# ninf_mask.shape: (batch, pomo, problem)head_num = self.model_params['head_num']#  Multi-Head Attention#######################################################input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)# shape = (batch, group, EMBEDDING_DIM+1)q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)# shape: (batch, head_num, pomo, qkv_dim)# q = self.q1 + self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)# q = q_last# shape: (batch, head_num, pomo, qkv_dim)q = self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)# shape: (batch, pomo, head_num*qkv_dim)mh_atten_out = self.multi_head_combine(out_concat)# shape: (batch, pomo, embedding)#  Single-Head Attention, for probability calculation#######################################################score = torch.matmul(mh_atten_out, self.single_head_key)# shape: (batch, pomo, problem)sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']logit_clipping = self.model_params['logit_clipping']score_scaled = score / sqrt_embedding_dim# shape: (batch, pomo, problem)score_clipped = logit_clipping * torch.tanh(score_scaled)score_masked = score_clipped + ninf_maskprobs = F.softmax(score_masked, dim=2)# shape: (batch, pomo, problem)return probs########################################
# NN SUB CLASS / FUNCTIONS
########################################def reshape_by_heads(qkv, head_num):# q.shape: (batch, n, head_num*key_dim)   : n can be either 1 or PROBLEM_SIZEbatch_s = qkv.size(0)n = qkv.size(1)q_reshaped = qkv.reshape(batch_s, n, head_num, -1)# shape: (batch, n, head_num, key_dim)q_transposed = q_reshaped.transpose(1, 2)# shape: (batch, head_num, n, key_dim)return q_transposeddef multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):# q shape: (batch, head_num, n, key_dim)   : n can be either 1 or PROBLEM_SIZE# k,v shape: (batch, head_num, problem, key_dim)# rank2_ninf_mask.shape: (batch, problem)# rank3_ninf_mask.shape: (batch, group, problem)batch_s = q.size(0)head_num = q.size(1)n = q.size(2)key_dim = q.size(3)input_s = k.size(2)score = torch.matmul(q, k.transpose(2, 3))# shape: (batch, head_num, n, problem)score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))if rank2_ninf_mask is not None:score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)if rank3_ninf_mask is not None:score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)weights = nn.Softmax(dim=3)(score_scaled)# shape: (batch, head_num, n, problem)out = torch.matmul(weights, v)# shape: (batch, head_num, n, key_dim)out_transposed = out.transpose(1, 2)# shape: (batch, n, head_num, key_dim)out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)# shape: (batch, n, head_num*key_dim)return out_concatclass AddAndInstanceNormalization(nn.Module):def __init__(self, **model_params):super().__init__()embedding_dim = model_params['embedding_dim']self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)def forward(self, input1, input2):# input.shape: (batch, problem, embedding)added = input1 + input2# shape: (batch, problem, embedding)transposed = added.transpose(1, 2)# shape: (batch, embedding, problem)normalized = self.norm(transposed)# shape: (batch, embedding, problem)back_trans = normalized.transpose(1, 2)# shape: (batch, problem, embedding)return back_transclass AddAndBatchNormalization(nn.Module):def __init__(self, **model_params):super().__init__()embedding_dim = model_params['embedding_dim']self.norm_by_EMB = nn.BatchNorm1d(embedding_dim, affine=True)# 'Funny' Batch_Norm, as it will normalized by EMB dimdef forward(self, input1, input2):# input.shape: (batch, problem, embedding)batch_s = input1.size(0)problem_s = input1.size(1)embedding_dim = input1.size(2)added = input1 + input2normalized = self.norm_by_EMB(added.reshape(batch_s * problem_s, embedding_dim))back_trans = normalized.reshape(batch_s, problem_s, embedding_dim)return back_transclass FeedForward(nn.Module):def __init__(self, **model_params):super().__init__()embedding_dim = model_params['embedding_dim']ff_hidden_dim = model_params['ff_hidden_dim']self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)def forward(self, input1):# input.shape: (batch, problem, embedding)return self.W2(F.relu(self.W1(input1)))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/72129.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/72129.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/72129.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用Hydra進行AI項目的動態配置管理

引言:機器學習中的超參數調優挑戰 在機器學習領域,超參數調優是決定模型性能的關鍵環節。不同的模型架構,如神經網絡中的層數、節點數,決策樹中的最大深度、最小樣本分割數等;以及各種訓練相關的超參數,像學習率、優化器類型、批量大小等,其取值的選擇對最終模型的效果…

preg_replace 與 str_replace 的比較與選擇

preg_replace 與 str_replace 的比較與選擇 ——PHP字符串處理的核心工具深度解析 一、核心功能定位 在PHP的字符串處理中,str_replace和preg_replace是兩種最常用的替換函數,但其設計目標和應用場景存在本質差異: str_replace 簡單字符串替…

嵌入式開發:傅里葉變換(4):在 STM32上面實現FFT(基于STM32L071KZT6 HAL庫+DSP庫)

目錄 步驟 1:準備工作 步驟 2:創建 Keil 項目,并配置工程 步驟 3:在MDK工程上添加 CMSIS-DSP 庫 步驟 5:編寫代碼 步驟 6:配置時鐘和優化 步驟 7:調試與驗證 步驟 8:優化和調…

【MySQL篇】數據類型

目錄 前言: 1,數據類型的分類 ?編輯 2 ,數值類型 2.1 tinyint類型 2.2 bit類型 2.3 小數類型 2.3.1 float類型 2.3.2 decimal類型 3,字符串類型 3.1 char 3.2 varchar 3.3 char與varchar的比較 3.4日期和時間類型 3.5 …

nuxt常用組件庫html-validator應用解析

html-validator 主要用于自動驗證nuxt服務器呈現的HTML(SSR和SSG),以檢測可能導致水合錯誤的HTML常見問題,有助于減少水合錯誤,檢測常見的可訪問性錯誤。 安裝 npx nuxilatest module add html-validator配置 若自動更新nuxt.config.ts配置文…

智能圖像處理平臺:圖片管理

接著我們講圖片管理,先實現圖片基礎的增刪改查,再去考慮圖像處理。 主要是,我們需要完成查詢時,查詢的圖片的上傳者的角色等級小于等于我們當前登陸賬號。 后端controller: package com.llpp.controller;import cn.…

大模型知識蒸餾技術(8)——知識蒸餾應用場景

版權聲明 本文原創作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl1. 知識蒸餾概述 知識蒸餾是一種將大型復雜模型(教師模型)的知識遷移到小型簡單模型(學生模型)的技術。其核心原理是通過教師模型的輸出(通常是softmax后的概率分布)來指導學生模型的訓練,…

LeetCode:131. 分割回文串(DP Java)

目錄 131. 分割回文串 題目描述: 實現代碼與解析: 動態規劃 原理思路: 131. 分割回文串 題目描述: 給你一個字符串 s,請你將 s 分割成一些子串,使每個子串都是 回文串 。返回 s 所有可能的分割方案。…

INT202 Complexity of Algroithms 算法的復雜度

文章目錄 1. 前言1.1 算法(Algorithms)和數據結構(Data Structure)1.2 什么是好的算法?1.3 算法分析1.3.1 實驗分析(Experimental Analysis)1.3.2 理論分析1.3.2.1 偽代碼(Pseudo-co…

BDF報告翻譯簡介后:關于A φ方法criterion引理1如何由范數導出內積

關于A φ方法criterion 引理1 如何由范數導出內積 在數學中,特別是在泛函分析中,給定一個范數,可以定義一個與之相關的內積。這個過程不是總是可能的,但當一個賦范向量空間是完備的且滿足平行四邊形恒等式時,可以導出…

初識uniApp

詳細思考一下uniApp這個跨平臺開發框架。首先,我對uniApp還不是很了解,所以需要從基本概念開始,逐步深入。 什么是uniApp? 我記得uniApp是基于Vue.js的,可能是一個用來開發多個平臺的應用的框架。用戶可能想了解它是什…

olmOCR:使用VLM解析PDF

在PDF解析中,目前主流的開源工具包括Minuer、GOT OCR等。主要都是通過飛槳等OCR套件組裝的一套pipeline,或者直接通過VLM解析圖像。 #一、 olmOCR是使用VLM進行的端到端的PDF文檔解析 二、document-anchoring 與上述的不同在于,olmOCR使用…

Nginx 代理配置導致瀏覽器應用網頁頁面加載失敗的分析與解決

Nginx 代理配置導致應用頁面加載失敗的分析與解決 前期部署信息: 部署DM數據庫DEM時,配置了nginx代理,conf配置內容如下: charset utf-8;client_max_body_size 128M;listen 4567;server_name 192.168.1.156;root /opt/h5/;index…

Windows 11【1001問】查看Windows 11 版本的18種方法

隨著技術的飛速發展,操作系統作為連接硬件與軟件的核心橋梁,其版本管理和更新變得尤為重要。對于用戶而言,了解自己設備上運行的具體Windows 11版本不僅有助于優化系統性能,還能確保安全性和兼容性。然而,不同場景和需…

企業jsapi_ticket,java舉例

在企業微信開發中,使用 Java 獲取 jsapi_ticket 并生成簽名的步驟如下。以下是完整的 Java 示例代碼。 1. 獲取 jsapi_ticket 的流程 獲取 access_token。 使用 access_token 獲取 jsapi_ticket。 使用 jsapi_ticket 生成簽名(signature)。…

【Godot4.3】自定義簡易菜單欄節點ETDMenuBar

概述 Godot中的菜單創建是一個復雜的災難性工作,往往無從下手,我也是不止一次嘗試簡化菜單的創建。 從自己去年的發明“簡易樹形數據”用于簡化Tree控件獲得靈感,于是嘗試編寫了用于表示菜單數據的EasyMenuData類,以及對應的純文…

大數據與金融科技:革新金融行業的動力引擎

大數據與金融科技:革新金融行業的動力引擎 在今天的金融行業,大數據與金融科技的結合正在以驚人的速度推動著金融服務的創新與變革。通過精準的數據分析與智能化決策,金融機構能夠更高效地進行風險管理、客戶服務、資產管理等一系列關鍵操作…

二、IDE集成DeepSeek保姆級教學(使用篇)

各位看官老爺好,如果還沒有安裝DeepSeek請查閱前一篇 一、IDE集成DeepSeek保姆級教學(安裝篇) 一、DeepSeek在CodeGPT中使用教學 1.1、Edit Code 編輯代碼 選中代碼片段 —> 右鍵 —> CodeGPT —> Edit Code, 輸入自然語言可編輯代碼,點擊S…

Rohm發布TOLL封裝650V GaN HEMT,引領汽車用GaN器件大規模生產新浪潮

Rohm震撼發布TOLL封裝650V GaN HEMT,引領汽車用GaN器件大規模生產新浪潮。在創新的TOLL(TO LeadLess)封裝技術的懷抱中,Rohm精心孕育出650V GaN HEMT這一瑰寶,此技術正如一股強勁東風,日益吹拂于高功率處理…

Spring Boot 3.x 基于 Redis 實現郵箱驗證碼認證

文章目錄 依賴配置開啟 QQ 郵箱 SMTP 服務配置文件代碼實現驗證碼服務郵件服務接口實現執行流程 依賴配置 <dependencies> <!-- Spring Boot Starter Web --> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spr…