Python 整理3種查看神經網絡結構的方法

1. 網絡結構代碼

import torch
import torch.nn as nn# 定義Actor-Critic模型
class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim):super(ActorCritic, self).__init__()self.actor = nn.Sequential(# 全連接層,輸入維度為 state_dim,輸出維度為 256nn.Linear(state_dim, 64),nn.ReLU(),nn.Linear(64, action_dim),# Softmax 函數,將輸出轉換為概率分布,dim=-1 表示在最后一個維度上應用 Softmaxnn.Softmax(dim=-1))self.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.ReLU(),nn.Linear(64, 1))def forward(self, state):policy = self.actor(state)value = self.critic(state)return policy, value# 參數設置
state_dim = 1
action_dim = 2model = ActorCritic(state_dim, action_dim)

2. 查看結構

2.1 直接打印模型

print(model)

輸出:

ActorCritic((actor): Sequential((0): Linear(in_features=1, out_features=64, bias=True)(1): ReLU()(2): Linear(in_features=64, out_features=2, bias=True)(3): Softmax(dim=-1))(critic): Sequential((0): Linear(in_features=1, out_features=64, bias=True)(1): ReLU()(2): Linear(in_features=64, out_features=1, bias=True))
)

2.2 可視化網絡結構(需要安裝 torchviz 包)

安裝 torchsummary 包:

$ pip install torchsummary

python 代碼:

from torchviz import make_dot# 創建一個虛擬輸入
x = torch.randn(1, state_dim)
# 生成計算圖
dot = make_dot(model(x), params=dict(model.named_parameters()))
dot.render("actor_critic_model", format="png")  # 保存為PNG圖片

輸出 actor_critic_model

digraph {graph [size="12,12"]node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled]140281544075344 [label="(1, 2)" fillcolor=darkolivegreen1]140281544213744 [label=SoftmaxBackward0]140281544213840 -> 140281544213744140281544213840 [label=AddmmBackward0]140281544213600 -> 140281544213840140285722327344 [label="actor.2.bias(2)" fillcolor=lightblue]140285722327344 -> 140281544213600140281544213600 [label=AccumulateGrad]140281544214032 -> 140281544213840140281544214032 [label=ReluBackward0]140281544213984 -> 140281544214032140281544213984 [label=AddmmBackward0]140281544214176 -> 140281544213984140285722327024 [label="actor.0.bias(64)" fillcolor=lightblue]140285722327024 -> 140281544214176140281544214176 [label=AccumulateGrad]140281544214224 -> 140281544213984140281544214224 [label=TBackward0]140281543934832 -> 140281544214224140285722327264 [label="actor.0.weight(64, 1)" fillcolor=lightblue]140285722327264 -> 140281543934832140281543934832 [label=AccumulateGrad]140281544213648 -> 140281544213840140281544213648 [label=TBackward0]140281544214080 -> 140281544213648140285722327184 [label="actor.2.weight(2, 64)" fillcolor=lightblue]140285722327184 -> 140281544214080140281544214080 [label=AccumulateGrad]140281544213744 -> 140281544075344140285722328704 [label="(1, 1)" fillcolor=darkolivegreen1]140281544213888 [label=AddmmBackward0]140281544214368 -> 140281544213888140285722328064 [label="critic.2.bias(1)" fillcolor=lightblue]140285722328064 -> 140281544214368140281544214368 [label=AccumulateGrad]140281544214128 -> 140281544213888140281544214128 [label=ReluBackward0]140281544214464 -> 140281544214128140281544214464 [label=AddmmBackward0]140281544214512 -> 140281544214464140285722327424 [label="critic.0.bias(64)" fillcolor=lightblue]140285722327424 -> 140281544214512140281544214512 [label=AccumulateGrad]140281544214560 -> 140281544214464140281544214560 [label=TBackward0]140281544214704 -> 140281544214560140285722327504 [label="critic.0.weight(64, 1)" fillcolor=lightblue]140285722327504 -> 140281544214704140281544214704 [label=AccumulateGrad]140281544213696 -> 140281544213888140281544213696 [label=TBackward0]140281544214272 -> 140281544213696140285722327584 [label="critic.2.weight(1, 64)" fillcolor=lightblue]140285722327584 -> 140281544214272140281544214272 [label=AccumulateGrad]140281544213888 -> 140285722328704
}

輸出模型圖片:
在這里插入圖片描述

2.3 使用 summary 方法(需要安裝 torchsummary 包)

安裝 torchsummary 包:

pip install torchsummary

代碼:

from torchsummary import summarydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)
summary(model, input_size=(state_dim,))#查看模型參數
print("查看模型參數:")
for name, param in model.named_parameters():print(f"Layer: {name} | Size: {param.size()} | Values: {param[:2]}...")

輸出:

cuda:0
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                   [-1, 64]             128ReLU-2                   [-1, 64]               0Linear-3                    [-1, 2]             130Softmax-4                    [-1, 2]               0Linear-5                   [-1, 64]             128ReLU-6                   [-1, 64]               0Linear-7                    [-1, 1]              65
================================================================
Total params: 451
Trainable params: 451
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
查看模型參數:
Layer: actor.0.weight | Size: torch.Size([64, 1]) | Values: tensor([[ 0.7747],[-0.0440]], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.0.bias | Size: torch.Size([64]) | Values: tensor([ 0.5995, -0.2155], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.2.weight | Size: torch.Size([2, 64]) | Values: tensor([[ 0.0373,  0.0851,  0.1000,  0.1060,  0.0387,  0.0479,  0.0127,  0.0696,0.0388,  0.0033,  0.1173, -0.1195, -0.0830,  0.0186,  0.0063, -0.0863,-0.0353,  0.0782, -0.0558,  0.0011, -0.0533,  0.1241,  0.0120, -0.0906,-0.0551, -0.0673, -0.1070,  0.0402, -0.0662,  0.0596, -0.0811,  0.0457,0.0349,  0.0564, -0.0155, -0.0404,  0.0843, -0.0978,  0.0459,  0.1097,-0.0858,  0.0736, -0.0067, -0.0756, -0.0363, -0.0525, -0.0426, -0.1087,-0.0611,  0.0420, -0.1038,  0.0402,  0.0065, -0.1217, -0.0467,  0.0383,-0.0217,  0.0283,  0.0800,  0.0228,  0.0415, -0.0473, -0.0199, -0.0436],[-0.1118, -0.0806, -0.0700, -0.0224,  0.0335, -0.0087,  0.0265, -0.1196,-0.0907, -0.0360,  0.0621, -0.0471, -0.0939, -0.0912, -0.1061,  0.1051,-0.0592, -0.0757,  0.0758, -0.1082, -0.0317,  0.1208, -0.0279, -0.0693,0.0920, -0.0318, -0.0476,  0.0236, -0.0761,  0.0591,  0.0862, -0.0712,0.0156, -0.1073,  0.1133,  0.0039, -0.0191,  0.0605, -0.0686, -0.1202,0.0962,  0.0581,  0.1145,  0.0741, -0.0993, -0.0987,  0.0939,  0.1006,0.0773, -0.0756, -0.1096,  0.0156, -0.0599,  0.0857,  0.1005, -0.0618,0.0474,  0.0066, -0.0531, -0.0479,  0.1136,  0.0356,  0.1169, -0.0023]],device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.2.bias | Size: torch.Size([2]) | Values: tensor([-0.0039,  0.0937], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.0.weight | Size: torch.Size([64, 1]) | Values: tensor([[0.5799],[0.0473]], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.0.bias | Size: torch.Size([64]) | Values: tensor([ 0.6507, -0.6974], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.2.weight | Size: torch.Size([1, 64]) | Values: tensor([[ 0.0738, -0.0370, -0.1010, -0.0333, -0.0595, -0.0172,  0.0928,  0.0815,0.1221, -0.0842,  0.0511,  0.0452, -0.0386, -0.0503, -0.0964,  0.0370,-0.0341, -0.0693, -0.0845,  0.0424, -0.0491, -0.0439, -0.0443,  0.0203,0.0960, -0.1178, -0.0836, -0.0144, -0.0576, -0.0851,  0.0461,  0.1160,0.0120,  0.1180,  0.0255,  0.1047, -0.0398,  0.0786,  0.1143,  0.0806,0.1125,  0.0267,  0.0534, -0.0318,  0.1125, -0.0727,  0.1169,  0.0120,-0.0178, -0.0845,  0.0069,  0.0194,  0.1188,  0.0481,  0.1077, -0.0840,0.1013,  0.0586, -0.0857, -0.0974, -0.0630,  0.0359, -0.0080, -0.0926]],device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.2.bias | Size: torch.Size([1]) | Values: tensor([0.0621], device='cuda:0', grad_fn=<SliceBackward0>)...

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/78564.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/78564.shtml
英文地址,請注明出處:http://en.pswp.cn/web/78564.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux 查詢CPU飆高的原因

獲取進程ID ps -efgrep xxxx查詢占用最高的線程ID top -Hp 線程ID線程ID 轉 16進制數 printf 0x%x\n 線程ID基于jstack工具 跟蹤堆棧定位代碼位置 jstack 進程ID | grep 16禁止線程ID -A 20

Oracle OCP認證考試考點詳解083系列09

題記&#xff1a; 本系列主要講解Oracle OCP認證考試考點&#xff08;題目&#xff09;&#xff0c;適用于19C/21C,跟著學OCP考試必過。 41. 第41題&#xff1a; 題目 解析及答案&#xff1a; 關于應用程序容器&#xff0c;以下哪三項是正確的&#xff1f; A) 它可以包含單個…

GESP2024年3月認證C++八級( 第二部分判斷題(1-5))

孫子定理參考程序&#xff1a; #include <iostream> #include <vector> using namespace std;// 擴展歐幾里得算法&#xff1a;用于求逆元 int extendedGCD(int a, int b, int &x, int &y) {if (b 0) {x 1; y 0;return a;}int x1, y1;int gcd extende…

C 語言比較運算符:程序如何做出“判斷”?

各類資料學習下載合集 ??https://pan.quark.cn/s/8c91ccb5a474?? 在編寫程序時,我們經常需要根據不同的條件來執行不同的代碼。比如,如果一個分數大于 60 分,就判斷為及格;如果用戶的年齡小于 18 歲,就禁止訪問某個內容等等。這些“判斷”的核心,就依賴于程序能夠比…

WITH在MYSQL中的用法

WITH 子句&#xff08;也稱為公共表表達式&#xff0c;Common Table Expression&#xff0c;簡稱 CTE&#xff09;是 SQL 中一種強大的查詢構建工具&#xff0c;它可以顯著提高復雜查詢的可讀性和可維護性。 一、基本語法結構 WITH cte_name AS (SELECT ... -- 定義CTE的查詢…

多序列比對軟件MAFFT介紹

MAFFT(Multiple Alignment using Fast Fourier Transform)是一款廣泛使用且高效的多序列比對軟件,由日本京都大學的Katoh Kazutaka等人開發,最早發布于2002年,并持續迭代優化至今。 它支持從幾十條到上萬條核酸或蛋白質序列的快速比對,同時在準確率和計算效率之間提供靈…

APP 設計中的色彩心理學:如何用色彩提升用戶體驗

在數字化時代&#xff0c;APP 已成為人們日常生活中不可或缺的一部分。用戶在打開一個 APP 的瞬間&#xff0c;首先映入眼簾的便是其色彩搭配&#xff0c;而這些色彩并非只是視覺上的裝飾&#xff0c;它們蘊含著強大的心理暗示力量&#xff0c;能夠潛移默化地影響用戶的情緒、行…

Compose 中使用 WebView

在 Jetpack Compose 中&#xff0c;我們可以使用 AndroidView 組件來集成傳統的 Android WebView。以下是幾種實現方式&#xff1a; 基礎 WebView 實現 Composable fun WebViewScreen(url: String) {AndroidView(factory { context ->WebView(context).apply {// 設置布局…

2025年01月03日美蜥(杭州普瑞兼職)二面

目錄 為何 nginx 可以實現跨域請求&#xff0c;原理是什么為何 nodejs 可以實現跨域請求&#xff0c;原理是什么瀏覽器的請求頭有哪些瀏覽器的響應頭有哪些瀏覽器輸入網址后發生什么http 協議和 https 有什么區別你的核心優勢是什么瀏覽器緩存機制https 的加密機制tcp 的三次握…

如何選擇合適的光源?

目錄 工業相機光源類型全面指南 1. 環形光源及其變體 高角度環形光源 優點 缺點 典型應用場景 低角度環形光源&#xff08;暗場照明&#xff09; 優點 缺點 典型應用場景 2. 條形光源與組合照明系統 技術特點 組合條形光源 優點 缺點 典型應用場景 3. 同軸光源…

「OC」源碼學習——對象的底層探索

「OC」源碼學習——對象的底層探索 前言 上次我們說到了源碼里面的調用順序&#xff0c;現在我們繼續了解我們上一篇文章沒有講完的關于對象的內容函數&#xff0c;完整了解對象的產生對于isa賦值以及內存申請的內容 函數內容 先把_objc_rootAllocWithZone函數的內容先貼上…

【C++指南】STL list容器完全解讀(一):從入門到掌握基礎操作

. &#x1f493; 博客主頁&#xff1a;倔強的石頭的CSDN主頁 &#x1f4dd;Gitee主頁&#xff1a;倔強的石頭的gitee主頁 ? 文章專欄&#xff1a;《C指南》 期待您的關注 文章目錄 一、初識list容器1.1 什么是list&#xff1f;1.2 核心特性1.3 典型應用場景 二、核心成員函數…

labelimg快捷鍵

一、核心標注快捷鍵 ?W?&#xff1a;調出標注十字架&#xff0c;開始繪制矩形框&#xff08;最常用功能&#xff09;?A/D?&#xff1a;切換上一張(A)或下一張(D)圖片&#xff0c;實現快速導航?Del?&#xff1a;刪除當前選中的標注框 二、文件操作快捷鍵 ?CtrlS?&…

linux-文件操作

在 Linux 系統中&#xff0c;文件操作與管理是日常使用和系統管理的重要組成部分。下面將詳細介紹文件的復制、移動、鏈接創建&#xff0c;以及文件查找、文本處理、排序、權限管理等相關知識。 一、文件的復制 在 Linux 里&#xff0c;cp 命令可用于復制文件或目錄&#xff…

C++ 復習

VS 修改 C 語言標準 右鍵項目-屬性 輸入輸出 //引用頭文件&#xff0c;用<>包裹起來的一般是系統提供的寫好的代碼 編譯器會在專門的系統路徑中去進行查找 #include <iostream> //自己寫的代碼文件一般都用""包裹起來 編譯器會在當前文件所在的目錄中査…

openGauss新特性 | HTAP新特性介紹

一、行列融合功能簡介 HTAP 行列融合特性在單機、主備場景下&#xff0c;通過節點的行列雙格式內存模式&#xff0c;實現openGauss HTAP一體化數據庫架構。 通過高效的行列轉換技術方案&#xff0c;節點讀取磁盤行存數據&#xff0c;生成列存儲單元&#xff08;Column Unit&am…

雙目測量中的將視差圖重投影成三維坐標圖

雙目測距主要步驟如下&#xff1a; 左右兩張圖片 → 匹配 → 得到視差圖 disp&#xff1b; 使用 cv2.reprojectImageTo3D(disp, Q) 將視差圖 重投影 成三維坐標圖 → 得到 points_3d 什么是 points_3d&#xff1f; points_3d cv2.reprojectImageTo3D(disp, Q)points_3d.shap…

《深度剖析:SOAP與REST,API集成的兩極選擇》

API作為不同系統之間交互的橋梁&#xff0c;其設計與實現的優劣直接影響著整個軟件生態的運轉效率。而在API的設計領域&#xff0c;SOAP和REST猶如兩座巍峨的山峰&#xff0c;各自代表著截然不同的設計理念與應用方向&#xff0c;成為開發者在構建API時必須慎重權衡的關鍵選項。…

非對稱加密算法(RSA、ECC、SM2)——密碼學基礎

對稱加密算法&#xff08;AES、ChaCha20和SM4&#xff09;Python實現——密碼學基礎(Python出現No module named “Crypto” 解決方案) 這篇的續篇&#xff0c;因此實踐部分少些&#xff1b; 文章目錄 一、非對稱加密算法基礎二、RSA算法2.1 RSA原理與數學基礎2.2 RSA密鑰長度…

Pillow 玩圖術:輕松獲取圖片尺寸和顏色模式

前言 在這個“圖像為王”的時代,誰還敢說自己沒被一張圖折磨過?一張圖片不講武德,說崩就崩,說卡就卡,仿佛像素里藏著程序員的眼淚。不管你是網頁設計師、AI煉丹師,還是只是想把貓片修得像藝術品,圖片的尺寸和顏色模式都是你必須掌握的第一手情報。如果你不知道它有多寬…