Python訓練打卡Day35

模型可視化與推理

知識點回顧:

  1. 三種不同的模型可視化方法:推薦torchinfo打印summary+權重分布可視化
  2. 進度條功能:手動和自動寫法,讓打印結果更加美觀
  3. 推理的寫法:評估模式

模型結構可視化

理解一個深度學習網絡最重要的2點:

1. 了解損失如何定義的,知道損失從何而來----把抽象的任務通過損失函數量化出來

2. 了解參數總量,即知道每一層的設計---層設計決定參數總量

為了了解參數總量,我們需要知道層設計,以及每一層參數的數量。下面介紹1幾個層可視化工具:

1. nn.model自帶的方法
#  nn.Module 的內置功能,直接輸出模型結構
print(model)

這是最基礎、最簡單的方法,會直接打印模型對象,它會輸出模型的結構,顯示模型中各個層的名稱和參數信息

# nn.Module 的內置功能,返回模型的可訓練參數迭代器
for name, param in model.named_parameters():print(f"Parameter name: {name}, Shape: {param.shape}")

可以將模型中帶有weight的參數(即權重)提取出來,并轉為 numpy 數組形式,對其計算統計分布,并且繪制可視化圖表

# 提取權重數據
import numpy as np
weight_data = {}
for name, param in model.named_parameters():if 'weight' in name:weight_data[name] = param.detach().cpu().numpy()# 可視化權重分布
fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')for i, (name, weights) in enumerate(weight_data.items()):# 展平權重張量為一維數組weights_flat = weights.flatten()# 繪制直方圖axes[i].hist(weights_flat, bins=50, alpha=0.7)axes[i].set_title(name)axes[i].set_xlabel('Weight Value')axes[i].set_ylabel('Frequency')axes[i].grid(True, linestyle='--', alpha=0.7)plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()# 計算并打印每層權重的統計信息
print("\n=== 權重統計信息 ===")
for name, weights in weight_data.items():mean = np.mean(weights)std = np.std(weights)min_val = np.min(weights)max_val = np.max(weights)print(f"{name}:")print(f"  均值: {mean:.6f}")print(f"  標準差: {std:.6f}")print(f"  最小值: {min_val:.6f}")print(f"  最大值: {max_val:.6f}")print("-" * 30)

對比 fc1.weight 和 fc2.weight 的統計信息 ,可以發現它們的均值、標準差、最值等存在差異。這反映了不同層在模型中的作用不同。權重統計信息可以為超參數調整提供參考。

2.torchsummary庫的summary方法
# pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchsummary import summary
# 打印模型摘要,可以放置在模型定義后面
summary(model, input_size=(4,))

????????該方法不顯示輸入層的尺寸,因為輸入的神經網是自己設置的,所以不需要顯示輸入層的尺寸。但是在使用該方法時,input_size=(4,) 參數是必需的,因為 PyTorch 需要知道輸入數據的形狀才能推斷模型各層的輸出形狀和參數數量。

????????這是因為PyTorch 的模型在定義時是動態的,它不會預先知道輸入數據的具體形狀。nn.Linear(4, 10) 只定義了 “輸入維度是 4,輸出維度是 10”,但不知道輸入的批量大小和其他維度,比如卷積層需要知道輸入的通道數、高度、寬度等信息。----并非所有輸入數據都是結構化數據

????????因此,要生成模型摘要(如每層的輸出形狀、參數數量),必須提供一個示例輸入形狀,讓 PyTorch “運行” 一次模型,從而推斷出各層的信息。

summary 函數的核心邏輯是:

1. 創建一個與 input_size 形狀匹配的虛擬輸入張量(通常填充零)

2. 將虛擬輸入傳遞給模型,執行一次前向傳播(但不計算梯度)

3. 記錄每一層的輸入和輸出形狀,以及參數數量

4. 生成可讀的摘要報告

構建神經網絡的時候

1. 輸入層不需要寫:x多少個特征 輸入層就有多少神經元

2. 隱藏層需要寫,從第一個隱藏層可以看出特征的個數

3. 輸出層的神經元和任務有關,比如分類任務,輸出層有3個神經元,一個對應每個類別

可學習參數計算

1. Linear-1對應self.fc1 = nn.Linear(4, 10),表明前一層有4個神經元,這一層有10個神經元,每2個神經元之間靠著線相連,所有有4*10個權重參數+10個偏置參數=50個參數

2. relu層不涉及可學習參數,可以把它和前一個線性層看成一層,圖上也是這個含義

3. Linear-3層對應代碼 self.fc2 = nn.Linear(10,3),10*3個權重參數+3個偏置=33個參數

總參數83個,占用內存幾乎為0

1.3 torchinfo庫的summary方法

?torchinfo 是提供比 torchsummary 更詳細的模型摘要信息,包括每層的輸入輸出形狀、參數數量、計算量等。

# pip install torchinfo -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchinfo import summary
summary(model, input_size=(4, ))

進度條功能

tqdm這個庫非常適合用在循環中觀察進度。尤其在深度學習這種訓練是循環的場景中。他最核心的邏輯如下

1. 創建一個進度條對象,并傳入總迭代次數。一般用with語句創建對象,這樣對象會在with語句結束后自動銷毀,保證資源釋放。with是常見的上下文管理器,這樣的使用方式還有用with打開文件,結束后會自動關閉文件。

2. 更新進度條,通過pbar.update(n)指定每次前進的步數n(適用于非固定步長的循環)。

1.手動更新
from tqdm import tqdm  # 先導入tqdm庫
import time  # 用于模擬耗時操作# 創建一個總步數為10的進度條
with tqdm(total=10) as pbar:  # pbar是進度條對象的變量名# pbar 是 progress bar(進度條)的縮寫,約定俗成的命名習慣。for i in range(10):  # 循環10次(對應進度條的10步)time.sleep(0.5)  # 模擬每次循環耗時0.5秒pbar.update(1)  # 每次循環后,進度條前進1步
from tqdm import tqdm
import time# 創建進度條時添加描述(desc)和單位(unit)
with tqdm(total=5, desc="下載文件", unit="個") as pbar:# 進度條這個對象,可以設置描述和單位# desc是描述,在左側顯示# unit是單位,在進度條右側顯示for i in range(5):time.sleep(1)pbar.update(1)  # 每次循環進度+1

unit 參數的核心作用是明確進度條中每個進度單位的含義,使可視化信息更具可讀性。在深度學習訓練中,常用的單位包括:

  • epoch:訓練輪次(遍歷整個數據集一次)。
  • batch:批次(每次梯度更新處理的樣本組)。
  • sample:樣本(單個數據點)
2.自動更新
from tqdm import tqdm
import time# 直接將range(3)傳給tqdm,自動生成進度條
# 這個寫法我覺得是有點神奇的,直接可以給這個對象內部傳入一個可迭代對象,然后自動生成進度條
for i in tqdm(range(3), desc="處理任務", unit="epoch"):time.sleep(1)

for i in tqdm(range(3), desc="處理任務", unit="個")這個寫法則不需要在循環中調用update()方法,更加簡潔。實際上這2種寫法都隨意選取,這里都介紹下

?

# 用tqdm的set_postfix方法在進度條右側顯示實時數據(如當前循環的數值、計算結果等):
from tqdm import tqdm
import timetotal = 0  # 初始化總和
with tqdm(total=10, desc="累加進度") as pbar:for i in range(1, 11):time.sleep(0.3)total += i  # 累加1+2+3+...+10pbar.update(1)  # 進度+1pbar.set_postfix({"當前總和": total})  # 顯示實時總和

完整代碼:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # 導入tqdm庫用于進度條顯示# 設置GPU設備
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 加載鳶尾花數據集
iris = load_iris()
X = iris.data  # 特征數據
y = iris.target  # 標簽數據# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 歸一化數據
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 將數據轉換為PyTorch張量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 輸入層到隱藏層self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隱藏層到輸出層def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 實例化模型并移至GPU
model = MLP().to(device)# 分類問題使用交叉熵損失函數
criterion = nn.CrossEntropyLoss()# 使用隨機梯度下降優化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 訓練模型
num_epochs = 20000  # 訓練的輪數# 用于存儲每100個epoch的損失值和對應的epoch數
losses = []
epochs = []start_time = time.time()  # 記錄開始時間# 創建tqdm進度條
with tqdm(total=num_epochs, desc="訓練進度", unit="epoch") as pbar:# 訓練模型for epoch in range(num_epochs):# 前向傳播outputs = model(X_train)  # 隱式調用forward函數loss = criterion(outputs, y_train)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()# 記錄損失值并更新進度條if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)# 更新進度條的描述信息pbar.set_postfix({'Loss': f'{loss.item():.4f}'})# 每1000個epoch更新一次進度條if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新進度條# 確保進度條達到100%if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)  # 計算剩余的進度并更新time_all = time.time() - start_time  # 計算訓練時間
print(f'Training time: {time_all:.2f} seconds')# # 可視化損失曲線
# plt.figure(figsize=(10, 6))
# plt.plot(epochs, losses)
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Training Loss over Epochs')
# plt.grid(True)
# plt.show()

模型的推理

測試這個詞在大模型領域叫做推理(inference),意味著把數據輸入到訓練好的模型的過程。

?

# 在測試集上評估模型,此時model內部已經是訓練好的參數了
# 評估模型
model.eval() # 設置模型為評估模式
with torch.no_grad(): # torch.no_grad()的作用是禁用梯度計算,可以提高模型推理速度outputs = model(X_test)  # 對測試數據進行前向傳播,獲得預測結果_, predicted = torch.max(outputs, 1) # torch.max(outputs, 1)返回每行的最大值和對應的索引#這個函數返回2個值,分別是最大值和對應索引,參數1是在第1維度(行)上找最大值,_ 是Python的約定,表示忽略這個返回值,所以這個寫法是找到每一行最大值的下標# 此時outputs是一個tensor,p每一行是一個樣本,每一行有3個值,分別是屬于3個類別的概率,取最大值的下標就是預測的類別# predicted == y_test判斷預測值和真實值是否相等,返回一個tensor,1表示相等,0表示不等,然后求和,再除以y_test.size(0)得到準確率# 因為這個時候數據是tensor,所以需要用item()方法將tensor轉化為Python的標量# 之所以不用sklearn的accuracy_score函數,是因為這個函數是在CPU上運行的,需要將數據轉移到CPU上,這樣會慢一些# size(0)獲取第0維的長度,即樣本數量correct = (predicted == y_test).sum().item() # 計算預測正確的樣本數accuracy = correct / y_test.size(0)print(f'測試集準確率: {accuracy * 100:.2f}%')

模型的評估模式簡單來說就是評估階段會關閉一些訓練相關的操作和策略 ,比如更新參數 正則化等操作,確保模型輸出結果的穩定性和一致性。

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/907199.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/907199.shtml
英文地址,請注明出處:http://en.pswp.cn/news/907199.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

四、生活常識

一、效應定律 效應 1、沉沒成本效應 投入的越多&#xff0c;退出的難度就越大&#xff0c;因為不甘心自己之前的所有付出都付之東流。 2、破窗效應 干凈的環境下&#xff0c;沒有人會第一個丟垃圾&#xff0c;但是當環境變得糟糕&#xff0c;人們就開始無所妒忌的丟垃圾。…

機器學習圣經PRML作者Bishop20年后新作中文版出版!

機器學習圣經PRML作者Bishop20年后新書《深度學習&#xff1a;基礎與概念》出版。作者克里斯托弗M. 畢曉普&#xff08;Christopher M. Bishop&#xff09;微軟公司技術研究員、微軟研究 院 科學智 能 中 心&#xff08;Microsoft Research AI4Science&#xff09;負責人。劍橋…

Python應用嵌套猜數字小游戲

大家好!今天向大家分享的是有關“嵌套”的猜數字小游戲。希望能夠幫助大家理解嵌套。 代碼呈現: # 1. 構建一個隨機的數字變量 import random num random.randint(1, 10)guess_num int(input("輸入你要猜測的數字&#xff1a; "))# 2. 通過if判斷語句進行數字的猜…

黑馬k8s(十四)

1.Service-概述 service&#xff1a;用于四層路由的負載&#xff0c;Ingress七層路由的負載&#xff1b;&#xff0c;先學習service 開啟ipvs 2.Service-資源清單文件介紹 修改每個顯示的內容 ClusterIP類型的Service Endpoints&#xff1a;建立service與pod關聯 親和性測試…

Kotlin 中 Lambda 表達式的語法結構及簡化推導

在 Kotlin 編程中&#xff0c;Lambda 表達式是一項非常實用且強大的功能。今天&#xff0c;我們就來深入探討一下 Lambda 表達式的語法結構&#xff0c;以及它那些令人 “又愛又恨” 的簡化寫法。 一、Lambda 表達式完整語法結構 Lambda 表達式最完整的語法結構定義為{參數名…

Kafka Streams 和 Apache Flink 的無狀態流處理與有狀態流處理

Kafka Streams 和 Apache Flink 與數據庫和數據湖相比的無狀態和有狀態流處理的概念和優勢。 在數據驅動的應用中&#xff0c;流處理的興起改變了我們處理和操作數據的方式。雖然傳統數據庫、數據湖和數據倉庫對于許多基于批處理的用例來說非常有效&#xff0c;但在要求低延遲…

【后端高階面經:緩存篇】34、高并發下緩存穿透、擊穿、雪崩怎么解決

一、緩存三大核心問題:穿透、擊穿、雪崩的本質區別 (一)概念對比表 問題類型核心特征典型場景危害等級緩存穿透數據在緩存和數據庫中均不存在,請求直接穿透到數據庫惡意攻擊(偽造不存在的ID)、業務邏輯漏洞★★★★★緩存擊穿熱點數據在緩存中過期,大量并發請求同時擊穿…

使用Rancher在CentOS 環境上部署和管理多Kubernetes集群

引言 隨著容器技術的迅猛發展&#xff0c;Kubernetes已成為容器編排領域的事實標準。然而&#xff0c;隨著企業應用規模的擴大&#xff0c;多集群管理逐漸成為企業IT架構中的重要需求。 Rancher作為一個開源的企業級多集群Kubernetes管理平臺&#xff0c;以其友好的用戶界面和…

【Mini-F5265-OB開發板試用測評】按鍵控制測試

本文介紹了如何使用按鍵控制 MCU 引腳的輸出電平。 原理 由原理圖可知 板載用戶按鍵 K1 和 K2 分別與主控的 PB0 和 PB1 相連。 代碼 #define _MAIN_C_#include "platform.h" #include "gpio_key_input.h" #include "main.h"int main(void) …

用C#最小二乘法擬合圓形,計算圓心和半徑

用C#最小二乘法擬合圓形&#xff0c;計算圓心和半徑 using System; using System.Collections.Generic;namespace ConsoleApp2 {internal class Program{static void Main(string[] args){List<Tuple<double, double>> points new List<Tuple<double, doubl…

四、web安全-行業術語

1. 肉雞 所謂“肉雞”是一種很形象的比喻&#xff0c;比喻那些可以隨意被我們控制的電腦&#xff0c;對方可以是WINDOWS系統&#xff0c;也可以是UNIX/LINUX系統&#xff0c;可以是普通的個人電腦&#xff0c;也可以是大型的服務器&#xff0c;我們可以象操作自己的電腦那樣來…

MYSQL丟失pid處理方式

1、停止服務器 systemctl stop mysqld 2、修改 /data/mysql/etc/my.cnf pid-file /tmp/mysql/mysql.pid 改為 pid-file /data/mysql/mysql.pid 3、創建 touch /data/mysql/mysql.pid ch…

《計算機組成原理》第 2 章 - 計算機的發展及應用?

計算機從誕生至今&#xff0c;經歷了翻天覆地的變化&#xff0c;應用領域也在不斷拓展。本文將結合 Java 代碼實例&#xff0c;帶你深入了解計算機的發展歷程、應用場景及未來展望&#xff0c;讓你在學習理論的同時&#xff0c;還能通過實踐加深理解。? 2.1 計算機的發展史? …

Github 2025-05-26 開源項目周報Top15

根據Github Trendings的統計,本周(2025-05-26統計)共有15個項目上榜。根據開發語言中項目的數量,匯總情況如下: 開發語言項目數量Python項目5TypeScript項目3JavaScript項目3C++項目2Roff項目1Go項目1C#項目1Jupyter Notebook項目1Rust項目1CSS項目1Shell項目1Dockerfile項目…

詳解MYSQL索引失效問題排查

目錄 一、快速定位索引失效的步驟 1. 使用 EXPLAIN 分析執行計劃詳解Mysql的Explain語句 2. 確認索引是否存在 3. 檢查查詢條件是否符合索引規則 二、常見索引失效場景及解決方法 1. 索引列參與計算或函數 2. 隱式類型轉換 3. 使用 LIKE 以通配符開頭 4. 使用 OR 連接…

在 springboot3.x 使用 knife4j 以及常見報錯匯總

目錄 引言&#xff1a; 引入依賴&#xff1a; 配置文件&#xff1a; 過濾靜態資源&#xff1a; 增強模式&#xff1a; 便捷地址訪問&#xff1a; 常見問題&#xff1a; 注解使用實例&#xff1a; &#x1f4c4; ?文檔參考地址?&#xff1a; SpringBoot 3.x 結合 …

【C/C++】環形緩沖區:高效數據流轉核心

文章目錄 1 核心結構與原理1.1 組成1.2 內存布局1.3 關鍵操作 2 實現細節與優化2.1 滿/空狀態的判斷2.2 多線程安全&#xff08;無鎖實現&#xff09;2.3 性能優化 3 典型應用場景4 代碼示例5 優缺點6 對比7 進階 環形緩沖區&#xff08;Ring Buffer&#xff09;&#xff0c;又…

功耗僅4W!迷你服務器黑豹X2(Panther X2)卡刷、線刷刷入Armbian(ubuntu)系統教程

功耗僅4W&#xff01;迷你服務器黑豹X2&#xff08;Panther X2&#xff09;卡刷、線刷刷入Armbian&#xff08;ubuntu&#xff09;系統教程 前言 前段時間逛海鮮市場的時候留意到一個礦渣盒子&#xff0c;黑豹x2&#xff0c;又是一個類似迅雷賺錢寶這樣的挖礦項目已經gg的定制…

【Elasticsearch】更新操作原理

Elasticsearch 的更新操作&#xff08;如 _update 和 _update_by_query&#xff09;在底層實現上有一些復雜的原理&#xff0c;這些原理涉及到 Elasticsearch 的數據存儲機制、索引機制以及事務日志&#xff08;Translog&#xff09;的使用。以下是 Elasticsearch 更新操作的主…

【C++】紅黑樹的實現

目錄 前言 一、紅黑樹的概念 二、紅黑樹的實現 三、紅黑樹的查找 四、紅黑樹的驗證 五、紅黑樹的刪除 總結 前言 本文講解紅黑樹&#xff0c;主要講解插入部分的實現&#xff0c;建議在理解了AVL樹的旋轉后再來學習紅黑樹&#xff0c;因為紅黑樹也涉及旋轉&#xff0c;并…