Day 35

模型可視化與推理
知識點回顧:

三種不同的模型可視化方法:推薦torchinfo打印summary+權重分布可視化
進度條功能:手動和自動寫法,讓打印結果更加美觀
推理的寫法:評估模式
模型結構可視化
理解一個深度學習網絡最重要的2點:

1. 了解損失如何定義的,知道損失從何而來----把抽象的任務通過損失函數量化出來

2. 了解參數總量,即知道每一層的設計---層設計決定參數總量

為了了解參數總量,我們需要知道層設計,以及每一層參數的數量。下面介紹1幾個層可視化工具:

1. nn.model自帶的方法

#  nn.Module 的內置功能,直接輸出模型結構
print(model)

這是最基礎、最簡單的方法,會直接打印模型對象,它會輸出模型的結構,顯示模型中各個層的名稱和參數信息

# nn.Module 的內置功能,返回模型的可訓練參數迭代器
for name, param in model.named_parameters():print(f"Parameter name: {name}, Shape: {param.shape}")

可以將模型中帶有weight的參數(即權重)提取出來,并轉為 numpy 數組形式,對其計算統計分布,并且繪制可視化圖表

# 提取權重數據
import numpy as np
weight_data = {}
for name, param in model.named_parameters():if 'weight' in name:weight_data[name] = param.detach().cpu().numpy()# 可視化權重分布
fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')for i, (name, weights) in enumerate(weight_data.items()):# 展平權重張量為一維數組weights_flat = weights.flatten()# 繪制直方圖axes[i].hist(weights_flat, bins=50, alpha=0.7)axes[i].set_title(name)axes[i].set_xlabel('Weight Value')axes[i].set_ylabel('Frequency')axes[i].grid(True, linestyle='--', alpha=0.7)plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()# 計算并打印每層權重的統計信息
print("\n=== 權重統計信息 ===")
for name, weights in weight_data.items():mean = np.mean(weights)std = np.std(weights)min_val = np.min(weights)max_val = np.max(weights)print(f"{name}:")print(f"  均值: {mean:.6f}")print(f"  標準差: {std:.6f}")print(f"  最小值: {min_val:.6f}")print(f"  最大值: {max_val:.6f}")print("-" * 30)

對比 fc1.weight 和 fc2.weight 的統計信息 ,可以發現它們的均值、標準差、最值等存在差異。這反映了不同層在模型中的作用不同。權重統計信息可以為超參數調整提供參考。

2.torchsummary庫的summary方法
# pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchsummary import summary
# 打印模型摘要,可以放置在模型定義后面
summary(model, input_size=(4,))

?該方法不顯示輸入層的尺寸,因為輸入的神經網是自己設置的,所以不需要顯示輸入層的尺寸。但是在使用該方法時,input_size=(4,) 參數是必需的,因為 PyTorch 需要知道輸入數據的形狀才能推斷模型各層的輸出形狀和參數數量。

? ? ? ? 這是因為PyTorch 的模型在定義時是動態的,它不會預先知道輸入數據的具體形狀。nn.Linear(4, 10) 只定義了 “輸入維度是 4,輸出維度是 10”,但不知道輸入的批量大小和其他維度,比如卷積層需要知道輸入的通道數、高度、寬度等信息。----并非所有輸入數據都是結構化數據

? ? ? ? 因此,要生成模型摘要(如每層的輸出形狀、參數數量),必須提供一個示例輸入形狀,讓 PyTorch “運行” 一次模型,從而推斷出各層的信息。

summary 函數的核心邏輯是:

1. 創建一個與 input_size 形狀匹配的虛擬輸入張量(通常填充零)

2. 將虛擬輸入傳遞給模型,執行一次前向傳播(但不計算梯度)

3. 記錄每一層的輸入和輸出形狀,以及參數數量

4. 生成可讀的摘要報告

構建神經網絡的時候

1. 輸入層不需要寫:x多少個特征 輸入層就有多少神經元

2. 隱藏層需要寫,從第一個隱藏層可以看出特征的個數

3. 輸出層的神經元和任務有關,比如分類任務,輸出層有3個神經元,一個對應每個類別

可學習參數計算

1. Linear-1對應self.fc1 = nn.Linear(4, 10),表明前一層有4個神經元,這一層有10個神經元,每2個神經元之間靠著線相連,所有有4*10個權重參數+10個偏置參數=50個參數

2. relu層不涉及可學習參數,可以把它和前一個線性層看成一層,圖上也是這個含義

3. Linear-3層對應代碼 self.fc2 = nn.Linear(10,3),10*3個權重參數+3個偏置=33個參數

總參數83個,占用內存幾乎為0

1.3 torchinfo庫的summary方法
?torchinfo 是提供比 torchsummary 更詳細的模型摘要信息,包括每層的輸入輸出形狀、參數數量、計算量等。

# pip install torchinfo -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchinfo import summary
summary(model, input_size=(4, ))

進度條功能
tqdm這個庫非常適合用在循環中觀察進度。尤其在深度學習這種訓練是循環的場景中。他最核心的邏輯如下

1. 創建一個進度條對象,并傳入總迭代次數。一般用with語句創建對象,這樣對象會在with語句結束后自動銷毀,保證資源釋放。with是常見的上下文管理器,這樣的使用方式還有用with打開文件,結束后會自動關閉文件。

2. 更新進度條,通過pbar.update(n)指定每次前進的步數n(適用于非固定步長的循環)。

1.手動更新

from tqdm import tqdm  # 先導入tqdm庫
import time  # 用于模擬耗時操作# 創建一個總步數為10的進度條
with tqdm(total=10) as pbar:  # pbar是進度條對象的變量名# pbar 是 progress bar(進度條)的縮寫,約定俗成的命名習慣。for i in range(10):  # 循環10次(對應進度條的10步)time.sleep(0.5)  # 模擬每次循環耗時0.5秒pbar.update(1)  # 每次循環后,進度條前進1步
from tqdm import tqdm
import time# 創建進度條時添加描述(desc)和單位(unit)
with tqdm(total=5, desc="下載文件", unit="個") as pbar:# 進度條這個對象,可以設置描述和單位# desc是描述,在左側顯示# unit是單位,在進度條右側顯示for i in range(5):time.sleep(1)pbar.update(1)  # 每次循環進度+1

unit 參數的核心作用是明確進度條中每個進度單位的含義,使可視化信息更具可讀性。在深度學習訓練中,常用的單位包括:

  • epoch:訓練輪次(遍歷整個數據集一次)。
  • batch:批次(每次梯度更新處理的樣本組)。
  • sample:樣本(單個數據點)
  • @浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/82713.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/82713.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/82713.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[yolov11改進系列]基于yolov11引入自注意力與卷積混合模塊ACmix提高FPS+檢測效率python源碼+訓練源碼

[ACmix的框架原理] 1.1 ACMix的基本原理 ACmix是一種混合模型,結合了自注意力機制和卷積運算的優勢。它的核心思想是,傳統卷積操作和自注意力模塊的大部分計算都可以通過1x1的卷積來實現。ACmix首先使用1x1卷積對輸入特征圖進行投影,生成一組…

[DS]使用 Python 庫中自帶的數據集來實現上述 50 個數據分析和數據可視化程序的示例代碼

使用 Python 庫中自帶的數據集來實現上述 50 個數據分析和數據可視化程序的示例代碼 摘要:由于 sample_data.csv 是一個占位符文件,用于代表任意數據集,我將使用 Python 庫中自帶的數據集來實現上述 50 個數據分析和數據可視化程序的示例代碼…

【Python 中 lambda、map、filter 和 reduce】詳細功能介紹及用法總結

以下是 Python 中 lambda、map、filter 和 reduce 的詳細功能介紹及用法總結,涵蓋基礎語法、高頻場景和示例代碼。 一、lambda 匿名函數 功能 用于快速定義一次性使用的匿名函數。不需要顯式命名,適合簡化小規模邏輯。 語法 lambda 參數1, 參數2, ..…

貪心算法——分數背包問題

一、背景介紹 給定𝑛個物品,第𝑖個物品的重量為𝑤𝑔𝑡[𝑖?1]、價值為𝑣𝑎𝑙[𝑖?1],和一個容量為𝑐𝑎&#…

《軟件工程》第 5 章 - 需求分析模型的表示

目錄 5.1需求分析與驗證 5.1.1 順序圖 5.1.2 通信圖 5.1.3 狀態圖 5.1.4 擴充機制 5.2 需求分析的過程模型 5.3 需求優先級分析 5.3.1 確定需求項優先級 5.3.2 排定用例分析的優先順序 5.4 用例分析 5.4.1 精化領域概念模型 5.4.2 設置分析類 5.4.3 構思分析類之間…

基于MATLAB的大規模MIMO信道仿真

1. 系統模型與參數設置 以下是一個單小區大規模MIMO系統的參數配置示例,適用于多發多收和單發單收場景。 % 參數配置 params.N_cell 1; % 小區數量(單小區仿真) params.cell_radius 500; % 小區半徑(米&#xff09…

想查看或修改 MinIO 桶的匿名訪問權限(public/private/custom)

在 Ubuntu 下,如果你想查看或修改 MinIO 桶的匿名訪問權限(public/private/custom),需要使用 mc anonymous 命令而不是 mc policy。以下是詳細操作指南: 1. 查看當前匿名訪問權限 mc anonymous get minio/test輸出示例…

HarmonyOS:相機選擇器

一、概述 相機選擇器提供相機拍照與錄制的能力。應用可選擇媒體類型實現拍照和錄制的功能。調用此類接口時,應用必須在界面UIAbility中調用,否則無法啟動cameraPicker應用。 說明 本模塊首批接口從API version 11開始支持。后續版本的新增接口&#xff0…

牛客AI簡歷篩選:提升招聘效率的智能解決方案

在競爭激烈的人才市場中,企業HR每天需處理海量簡歷,面臨篩選耗時長、標準不統一、誤判率高等痛點。牛客網推出的AI簡歷篩選工具,以“20分鐘處理1000份簡歷、準確率媲美真人HR”的高效表現,成為企業招聘的智能化利器。本文將深度解…

白楊SEO:做AI搜索優化的DeepSeek、豆包、Kimi、百度文心一言、騰訊元寶、通義、智譜、天工等AI生成內容信息采集主要來自哪?占比是多少?

大家好,我是白楊SEO,專注SEO十年以上,全網SEO流量實戰派,AI搜索優化研究者。 在開始寫之前,先說個抱歉。 上周在上海客戶以及線下聚會AI搜索優化分享說各大AI模型的聯網搜索是關閉的,最開始上來確實是的。…

QML與C++交互2

在QML與C的交互中,主要有兩種方式:在C中調用QML的方法和在QML中調用C的方法。以下是具體的實現方法。 在C中調用QML的方法 首先,我們需要在QML文件中定義一個函數,然后在C代碼中調用它。 示例 //QML main.qml文件 import QtQu…

OpenGL Chan視頻學習-8 How I Deal with Shaders in OpenGL

bilibili視頻鏈接: 【最好的OpenGL教程之一】https://www.bilibili.com/video/BV1MJ411u7Bc?p5&vd_source44b77bde056381262ee55e448b9b1973 函數網站: docs.gl 說明: 1.之后就不再整理具體函數了,網站直接翻譯會更直觀也…

動態防御新紀元:AI如何重構DDoS攻防成本格局

1. 傳統高防IP的靜態瓶頸與成本困境 傳統高防IP依賴預定義規則庫,面對SYN Flood、CC攻擊等威脅時,常因規則更新滯后導致誤封合法流量。例如,某電商平臺曾因靜態閾值過濾誤封20%的訂單接口流量,直接影響營收。以下代碼模擬傳統方案…

如何實現高性能超低延遲的RTSP或RTMP播放器

隨著直播行業的快速發展,RTSP和RTMP協議成為了廣泛使用的流媒體傳輸協議,尤其是在實時視頻直播領域,如何構建一個高性能超低延遲的直播播放器,已經成為了決定直播平臺成功與否的關鍵因素之一。作為音視頻直播SDK技術老兵&#xff…

UE5 編輯器工具藍圖

文章目錄 簡述使用方法樣例自動生成Actor,并根據模型的包圍盒設置Actor的大小批量修改場景中Actor的屬性,設置Actor的名字,設置Actor到指定的文件夾 簡述 使用編輯器工具好處是可以在非運行時可以對資源或場景做一些操作,例如自動…

解鎖5月游戲新體驗 高速電腦配置推薦

很多玩家用戶會發現一個規律,618大促前很多商家會提前解鎖各種福利,5月選購各種電腦配件有時候會更劃算!并且,STEAM在5月還有幾個年度主題促銷,“生物收集游戲節”、“僵尸大戰吸血鬼游戲節”等等,配件大促…

干貨|VR全景是什么?

VR全景技術解析:概念、特點與用途 VR全景,全稱為虛擬現實全景技術(Virtual Reality Panorama Technology),是基于虛擬現實(Virtual Reality,VR)技術的創新展示方式。VR全景技術利用專業的拍攝設…

Nacos適配GaussDB超詳細部署流程,通過二進制包、以及 Docker 打通用鏡像包部署保姆級教程

1部署openGauss 官方文檔下載 https://support.huaweicloud.com/download_gaussdb/index.html 社區地址 安裝包下載 本文主要是以部署輕量級為主要教程,系統為openEuler,ip: 192.168.1.15 1.1系統環境準備 操作系統選擇 系統AARCH64X86-64openEuler√√CentOS7√Docker…

MySQL 表內容的增刪查改 -- CRUD操作,聚合函數,group by 子句

目錄 1. Create 1.1 語法 1.2 單行數據 全列插入 1.3 多行數據 指定列插入 1.4 插入數據否則更新數據 1.5 替換 2. Retrieve 2.1 SELECT 列 2.1.1 全列查詢 2.1.2 指定列查詢 2.1.3 查詢字段為表達式 2.1.4 為查詢結果指定別名 2.1.5 結構去重 2.2 WHERE 條件 …

LabVIEW累加器標簽通道

主要展示了 Accumulator Tag 通道的使用,通過三個并行運行的循環模擬不同數值的多個隨機序列,分別以不同頻率向累加器寫入數值,右側循環每秒讀取累加器值,同時可切換查看每秒內每次事件的平均值,用于演示多線程數據交互…