20250226-代碼筆記05-class CVRP_Decoder

文章目錄

  • 前言
  • 一、class CVRP_Decoder(nn.Module):__init__(self, **model_params)
    • 函數功能
    • 函數代碼
  • 二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)
    • 函數功能
    • 函數代碼
  • 三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)
    • 函數功能
    • 函數代碼
  • 四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)
    • 函數功能
    • 函數代碼
  • 五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)
    • 函數功能
    • 函數代碼
  • 附錄
    • class CVRP_Decoder代碼(全)


前言

class CVRP_DecoderCVRP_Model.py里的類。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Decoder(nn.Module):init(self, **model_params)

函數功能

init 方法是 CVRP_Decoder 類中的構造函數,主要功能是初始化該類所需的所有網絡層、權重矩陣和參數。
該方法設置了用于多頭注意力機制的權重、一個用于表示"遺憾"的參數、以及其他必要的操作用于計算注意力權重。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']# self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))self.regret_embedding.data.uniform_(-1, 1)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.k = None  # saved key, for multi-head attentionself.v = None  # saved value, for multi-head_attentionself.single_head_key = None  # saved, for single-head attention# self.q1 = None  # saved q1, for multi-head attentionself.q2 = None  # saved q2, for multi-head attention

二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)

函數功能

set_kv 方法的功能是將 encoded_nodes 中的節點嵌入轉換為多頭注意力機制所需的 鍵(K)值(V),并將它們分別保存為類的屬性。
這個方法將輸入的節點嵌入通過權重矩陣進行線性變換,得到鍵和值的表示,并為后續的多頭注意力計算做好準備。
執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def set_kv(self, encoded_nodes):# encoded_nodes.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)# shape: (batch, head_num, problem+1, qkv_dim)self.single_head_key = encoded_nodes.transpose(1, 2)# shape: (batch, embedding, problem+1)

三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)

函數功能

set_q1 方法的主要功能是 計算查詢(Q) 并將其轉換為適用于多頭注意力機制的形狀。
該方法接受輸入的查詢張量 encoded_q1,通過線性層 self.Wq_1 映射到一個新的維度,并使用 reshape_by_heads 函數將其調整為適合多頭注意力機制計算的形狀。計算出的查詢會被保存為類的屬性 q1,供后續使用。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def set_q1(self, encoded_q1):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)

四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)

函數功能

set_q2 方法的主要功能是 計算查詢(Q) 并將其轉換為適用于多頭注意力機制的形狀。
該方法接收輸入的查詢張量 encoded_q2,通過線性層 self.Wq_2 映射到一個新的維度,并使用 reshape_by_heads 函數將其調整為適合多頭注意力計算的形狀。
執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def set_q2(self, encoded_q2):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)

五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)

函數功能

forward 方法是 CVRP_Decoder 類中的前向傳播函數,主要功能是執行 多頭自注意力機制 和 單頭注意力計算,并最終輸出每個可能節點的選擇概率(probs)。
該方法通過多頭注意力計算、前饋神經網絡處理,以及概率計算來進行節點選擇。

執行流程圖鏈接
在這里插入圖片描述

函數代碼

    def forward(self, encoded_last_node, load, ninf_mask):# encoded_last_node.shape: (batch, pomo, embedding)# load.shape: (batch, pomo)# ninf_mask.shape: (batch, pomo, problem)head_num = self.model_params['head_num']#  Multi-Head Attention#######################################################input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)# shape = (batch, group, EMBEDDING_DIM+1)q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)# shape: (batch, head_num, pomo, qkv_dim)# q = self.q1 + self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)# q = q_last# shape: (batch, head_num, pomo, qkv_dim)q = self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)# shape: (batch, pomo, head_num*qkv_dim)mh_atten_out = self.multi_head_combine(out_concat)# shape: (batch, pomo, embedding)#  Single-Head Attention, for probability calculation#######################################################score = torch.matmul(mh_atten_out, self.single_head_key)# shape: (batch, pomo, problem)sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']logit_clipping = self.model_params['logit_clipping']score_scaled = score / sqrt_embedding_dim# shape: (batch, pomo, problem)score_clipped = logit_clipping * torch.tanh(score_scaled)score_masked = score_clipped + ninf_maskprobs = F.softmax(score_masked, dim=2)# shape: (batch, pomo, problem)return probs

附錄

class CVRP_Decoder代碼(全)

class CVRP_Decoder(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']# self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))self.regret_embedding.data.uniform_(-1, 1)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.k = None  # saved key, for multi-head attentionself.v = None  # saved value, for multi-head_attentionself.single_head_key = None  # saved, for single-head attention# self.q1 = None  # saved q1, for multi-head attentionself.q2 = None  # saved q2, for multi-head attentiondef set_kv(self, encoded_nodes):# encoded_nodes.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)# shape: (batch, head_num, problem+1, qkv_dim)self.single_head_key = encoded_nodes.transpose(1, 2)# shape: (batch, embedding, problem+1)def set_q1(self, encoded_q1):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def set_q2(self, encoded_q2):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def forward(self, encoded_last_node, load, ninf_mask):# encoded_last_node.shape: (batch, pomo, embedding)# load.shape: (batch, pomo)# ninf_mask.shape: (batch, pomo, problem)head_num = self.model_params['head_num']#  Multi-Head Attention#######################################################input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)# shape = (batch, group, EMBEDDING_DIM+1)q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)# shape: (batch, head_num, pomo, qkv_dim)# q = self.q1 + self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)# q = q_last# shape: (batch, head_num, pomo, qkv_dim)q = self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)# shape: (batch, pomo, head_num*qkv_dim)mh_atten_out = self.multi_head_combine(out_concat)# shape: (batch, pomo, embedding)#  Single-Head Attention, for probability calculation#######################################################score = torch.matmul(mh_atten_out, self.single_head_key)# shape: (batch, pomo, problem)sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']logit_clipping = self.model_params['logit_clipping']score_scaled = score / sqrt_embedding_dim# shape: (batch, pomo, problem)score_clipped = logit_clipping * torch.tanh(score_scaled)score_masked = score_clipped + ninf_maskprobs = F.softmax(score_masked, dim=2)# shape: (batch, pomo, problem)return probs

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/70722.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/70722.shtml
英文地址,請注明出處:http://en.pswp.cn/web/70722.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

洛谷 P3628/SPOJ 15648 APIO2010 特別行動隊 Commando

題意 你有一支由 n n n 名預備役士兵組成的部隊,士兵從 1 1 1 到 n n n 編號,你要將他們拆分成若干特別行動隊調入戰場。出于默契的考慮,同一支特別行動隊中隊員的編號應該連續,即為形如 i , i 1 , ? , i k i, i 1, \cdo…

PCL源碼分析:曲面法向量采樣

文章目錄 一、簡介二、源碼分析三、實現效果參考資料一、簡介 曲面法向量點云采樣,整個過程如下所述: 1、空間劃分:使用遞歸方法將點云劃分為更小的區域, 每次劃分選擇一個維度(X、Y 或 Z),將點云分為兩部分,直到劃分區域內的點少于我們指定的數量,開始進行區域隨機采…

Go語言--語法基礎2--下載安裝

2、下載安裝 1、下載源碼包: go1.18.4.linux-amd64.tar.gz。 官方地址:https://golang.google.cn/dl/ 云盤地址:鏈接: https://pan.baidu.com/s/1N2jrRHaPibvmmNFep3VYag 提 取碼: zkc3 2、將下載的源碼包解壓…

lowagie(itext)老版本手繪PDF,包含頁碼、水印、圖片、復選框、復雜行列合并等。

入口類:exportPdf ? package xcsy.qms.webapi.service;import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.alibaba.nacos.common.utils.StringUtils; import com.ibm.icu.text.RuleBasedNumberFormat; import com.lowa…

3-2 WPS JS宏 工作簿的打開與保存(模板批量另存為工作)學習筆記

************************************************************************************************************** 點擊進入 -我要自學網-國內領先的專業視頻教程學習網站 *******************************************************************************************…

Ubuntu20.04之VNC的安裝使用與常見問題

Ubuntu20.04之VNC的安裝與使用 安裝圖形桌面選擇安裝gnome桌面選擇安裝xface桌面 VNC-Server安裝配置開機自啟 VNC Clientroot用戶無法登入問題臨時方案永久方案 安裝圖形桌面 Ubuntu20.04主流的圖形桌面有gnome和xface兩種,兩種桌面的安裝方式我都會寫&#xff0c…

Day46 反轉字符串

I. 編寫一個函數,其作用是將輸入的字符串反轉過來。輸入字符串以字符數組 s 的形式給出。 不要給另外的數組分配額外的空間,你必須原地修改輸入數組、使用 O(1) 的額外空間解決這一問題。 class Solution {public void reverseString(char[] s) {int i …

用FileZilla Server 1.9.4給Windows Server 2025搭建FTP服務端

FileZilla Server 是一款免費的開源 FTP 和 FTPS 服務器軟件,分為服務器版和客戶端版。服務器版原本只支持Windows操作系統,比如筆者曾長期使用過0.9.60版,那時候就只支持Windows操作系統。當時我們生產環境對FTP穩定性要求較高,比…

【愚公系列】《Python網絡爬蟲從入門到精通》033-DataFrame的數據排序

標題詳情作者簡介愚公搬代碼頭銜華為云特約編輯,華為云云享專家,華為開發者專家,華為產品云測專家,CSDN博客專家,CSDN商業化專家,阿里云專家博主,阿里云簽約作者,騰訊云優秀博主,騰訊云內容共創官,掘金優秀博主,亞馬遜技領云博主,51CTO博客專家等。近期榮譽2022年度…

營銷過程烏龜圖模版

營銷過程烏龜圖模版 輸入 公司現狀產品服務客戶問詢客戶期望電話、電腦系統品牌軟件硬件材料 售前 - 溝通 - 確定需求 - 滿足需求 - 售后 機料環 電話、電腦等設備軟件硬件、系統品牌等工具材料 人 責任人協助者生產者客戶 法 訂單由誰評審控制程序營銷過程控制程序顧客滿意度…

Kubernetes (K8S) 高效使用技巧與實踐指南

Kubernetes(K8S)作為容器編排領域的核心工具,其靈活性和復雜性并存。本文結合實戰經驗,從運維效率提升、生產環境避坑、核心功能應用等維度,總結高頻使用技巧與最佳實踐,分享如何快速掌握 K8S。 一、kubect…

Idea java項目結構介紹

一般來說,一個典型的 IntelliJ IDEA Java 項目具有特定的結構,以下是對其主要部分的介紹: 項目根目錄 項目的最頂層目錄,包含了整個項目的所有文件和文件夾,通常以項目名稱命名。在這個目錄下可以找到.idea文件夾、.g…

C++大整數類的設計與實現

1. 簡介 我們知道現代的計算機大多數都是64位的,因此能處理最大整數為 2 64 ? 1 2^{64}-1 264?1。那如果是超過了這個數怎么辦呢,那就需要我們自己手動模擬數的加減乘除了。 2. 思路 我們可以用一個數組來存儲大數,數組中的每一個位置表…

2024年第十五屆藍橋杯大賽軟件賽省賽Python大學A組真題解析

文章目錄 試題A: 拼正方形(本題總分:5 分)解析答案試題B: 召喚數學精靈(本題總分:5 分)解析答案試題C: 數字詩意解析答案試題A: 拼正方形(本題總分:5 分) 【問題描述】 小藍正在玩拼圖游戲,他有7385137888721 個2 2 的方塊和10470245 個1 1 的方塊,他需要從中挑出一些…

開源RAG主流框架有哪些?如何選型?

開源RAG主流框架有哪些?如何選型? 一、開源RAG框架全景圖 (一)核心框架類型對比 類型典型工具技術特征適用場景傳統RAGLangChain, Haystack線性流程(檢索→生成)通用問答、知識庫檢索增強型RAGRAGFlow, AutoRAG支持重排序、多路召回優化高精度問答、復雜文檔處理輕量級…

Java SE與Java EE

Java SE(Java 平臺標準版) Java SE 是 Java 平臺的核心,提供了 Java 語言的基礎功能。它包含了 Java 開發工具包(JDK),其中有 Java 編譯器(javac)、Java 虛擬機(JVM&…

【Java企業生態系統的演進】從單體J2EE到云原生微服務

Java企業生態系統的演進:從單體J2EE到云原生微服務 目錄標題 Java企業生態系統的演進:從單體J2EE到云原生微服務摘要1. 引言2. 整體框架演進:從原始Java到Spring Cloud2.1 原始Java階段(1995-1999)2.2 J2EE階段&#x…

kicad中R樹的使用

在 KiCad 中,使用 R樹(R-tree)進行空間索引和加速查詢通常不在用戶層面直接操作,而是作為工具的一部分用于優化電路板設計的性能,尤其在布局、碰撞檢測、設計規則檢查(DRC)以及元件搜索等方面。…

org.springframework.boot不存在的其中一個解決辦法

最近做項目的時候發現問題,改了幾次pom.xml文件之后突然發現項目中的注解全部爆紅。 可以嘗試點擊左上角的循環小圖標,同步所有maven項目。 建議順便檢查一下Project Structure中的SDK和Language Level是否對應,否則可能報類似:“…

C語言實現通訊錄項目

一、通訊錄功能 實現一個可以存放100個人的信息的通訊錄(這里采用靜態版本),每個人的信息有姓名、性別、年齡、電話、地址等。 通訊錄可以執行的操作有添加聯系人信息、刪除指定聯系人、查找指定聯系人信息、修改指定聯系人信息、顯示聯系人信…