用deepseek學大模型04-模型可視化與數據可視化

deepseek.com: pytorch可視化工具 生成神經網絡圖

在 PyTorch 中,可視化神經網絡結構的常用工具和方法有以下幾種,以下將詳細介紹它們的用法:


1. TensorBoard (PyTorch 官方集成)

PyTorch 通過 torch.utils.tensorboard 支持 TensorBoard,可以直接可視化模型結構。

步驟:
  1. 安裝 TensorBoard

    pip install tensorboard
    
  2. 代碼示例

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.models import resnet18# 定義模型
    model = resnet18()
    # 創建 SummaryWriter 對象
    writer = SummaryWriter("runs/experiment1")
    # 生成隨機輸入數據(需與模型輸入形狀匹配)
    dummy_input = torch.randn(1, 3, 224, 224)
    # 將模型結構寫入 TensorBoard
    writer.add_graph(model, dummy_input)
    writer.close()
    
  3. 啟動 TensorBoard

    tensorboard --logdir=runs
    

    在瀏覽器中打開 http://localhost:6006,在 GRAPHS 標簽頁查看模型結構。


2. torchviz (生成計算圖)

torchviz 基于 Graphviz,生成模型的計算圖(前向傳播路徑)。

步驟:
  1. 安裝依賴

    pip install torchviz graphviz
    
  2. 代碼示例

    import torch
    from torchviz import make_dot
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)
    # 生成計算圖
    dot = make_dot(output, params=dict(model.named_parameters()))
    # 保存為圖片(格式可為 PDF 或 PNG)
    dot.render("resnet18", format="png")
    

3. HiddenLayer (輕量級可視化)

hiddenlayer 是一個輕量級工具,適合快速生成模型結構圖。

步驟:
  1. 安裝

    pip install hiddenlayer
    
  2. 代碼示例

    import torch
    import hiddenlayer as hl
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    # 構建可視化圖表
    hl.build_graph(model, dummy_input).save("resnet18.png")
    

4. Netron (模型結構查看器)

Netron 是一個獨立的可視化工具,支持 PyTorch 模型文件(.pt.pth)或 ONNX 格式。

步驟:
  1. 導出模型為 ONNX

    import torch
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "resnet18.onnx")
    
  2. 使用 Netron 打開
    訪問 Netron Web 版 或下載桌面端,上傳 .onnx 文件即可查看交互式模型結構。


5. PyTorch 內置打印

通過 print(model) 快速查看模型層級(適合簡單結構):

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 1))def forward(self, x):return self.layers(x)model = MyModel()
print(model)

工具對比

工具優點缺點
TensorBoard集成訓練監控,交互式可視化需要啟動服務,略復雜
torchviz顯示計算細節,適合調試依賴 Graphviz,圖可能較復雜
HiddenLayer輕量級,快速生成功能相對簡單
Netron無需代碼,支持多格式依賴導出模型文件

根據需求選擇工具:調試計算流程用 torchviz,訓練集成用 TensorBoard,快速分享用 Netron

在 PyTorch 中,可視化訓練集、測試集、預測結果和訓練過程可以通過多種工具實現。以下是常用的工具和方法,按場景分類:


1. 可視化數據集(訓練集/測試集)

工具:Matplotlib/Seaborn

直接使用 Python 繪圖庫可視化輸入數據或特征分布。

import matplotlib.pyplot as plt
import torchvision# 示例:可視化 CIFAR10 訓練集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 顯示前 9 張圖片
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for i, ax in enumerate(axes.flat):img, label = dataset[i]ax.imshow(img)ax.set_title(f"Label: {classes[label]}")ax.axis('off')
plt.show()

2. 可視化訓練過程

工具 1:TensorBoard(PyTorch 集成)

監控訓練損失、準確率等指標,支持動態更新。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("runs/experiment1")for epoch in range(num_epochs):# 訓練代碼...train_loss = ...val_accuracy = ...# 記錄標量數據writer.add_scalar('Loss/train', train_loss, epoch)writer.add_scalar('Accuracy/val', val_accuracy, epoch)# 記錄模型權重分布for name, param in model.named_parameters():writer.add_histogram(name, param, epoch)# 啟動 TensorBoard
# tensorboard --logdir=runs
工具 2:Weights & Biases(第三方協作工具)

云端記錄實驗,支持超參數跟蹤和團隊協作。

import wandb# 初始化
wandb.init(project="my-project")# 記錄指標
wandb.log({"train_loss": train_loss, "val_acc": val_accuracy})# 記錄預測結果(圖像示例)
wandb.log({"predictions": [wandb.Image(img, caption=f"Pred:{pred}, True:{true}")]})

3. 可視化預測結果

方法 1:Matplotlib 直接繪制
# 示例:分類結果可視化
import numpy as npmodel.eval()
with torch.no_grad():inputs, labels = next(iter(test_loader))outputs = model(inputs)preds = torch.argmax(outputs, dim=1)# 顯示預測結果
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):ax.imshow(inputs[i].permute(1, 2, 0))  # 調整通道順序ax.set_title(f"Pred: {classes[preds[i]]}\nTrue: {classes[labels[i]]}")ax.axis('off')
plt.tight_layout()
plt.show()
方法 2:混淆矩陣(分類任務)
from sklearn.metrics import confusion_matrix
import seaborn as sns# 計算混淆矩陣
cm = confusion_matrix(true_labels, pred_labels)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

4. 高級可視化工具

工具 1:Plotly(交互式可視化)

繪制動態訓練曲線:

import plotly.express as px# 假設 logs 是包含訓練歷史的字典
fig = px.line(logs, x='epoch', y=['train_loss', 'val_loss'], title="Training and Validation Loss")
fig.show()
工具 2:Gradio(快速構建交互式 Demo)

部署模型預測交互界面:

import gradio as grdef classify_image(img):img = preprocess(img)  # 預處理pred = model(img)      # 推理return classes[torch.argmax(pred)]gr.Interface(fn=classify_image, inputs="image", outputs="label").launch()

5. 訓練過程可視化(高級)

工具:PyTorch Lightning

自動集成 TensorBoard 和 WandB,簡化日志記錄:

import pytorch_lightning as plclass MyModel(pl.LightningModule):def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自動記錄到日志return loss# 訓練時指定 logger
trainer = pl.Trainer(logger=pl.loggers.TensorBoardLogger("logs/"),# 或使用 WandB# logger=pl.loggers.WandbLogger(project="my-project")
)
trainer.fit(model)

工具對比

工具/方法適用場景優點缺點
TensorBoard訓練指標跟蹤、模型結構可視化官方集成,功能全面需本地啟動服務
WandB團隊協作、云端實驗管理實時同步、超參數跟蹤需要注冊賬號
Matplotlib靜態數據可視化靈活、無需額外依賴交互性弱
Plotly交互式動態圖表支持網頁嵌入、動態更新學習曲線稍陡峭
Gradio快速部署預測 Demo零代碼交互界面功能相對簡單

關鍵場景總結

  1. 訓練過程監控:優先選擇 TensorBoard 或 WandB。
  2. 數據集預覽:Matplotlib/Seaborn 快速繪制。
  3. 預測結果分析:混淆矩陣(分類)、BBox 標注(檢測)、Matplotlib 對比圖(回歸)。
  4. 協作與報告:WandB 或 TensorBoard.dev(云端共享)。

可根據需求組合使用工具,例如:TensorBoard + Matplotlib(本地開發)或 WandB + Gradio(團隊協作 + 演示)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/70049.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/70049.shtml
英文地址,請注明出處:http://en.pswp.cn/web/70049.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JavaScript設計模式 -- 外觀模式

在實際開發中,往往會遇到多個子系統協同工作時,直接操作各個子系統不僅接口繁瑣,還容易導致客戶端與內部實現緊密耦合。**外觀模式(Facade Pattern)**通過為多個子系統提供一個統一的高層接口,將復雜性隱藏…

【性能測試】如何理解“10個線程且10次循環“的請求和“100線程且1次循環“的請求

在性能測試中,我們常常會見到不同的并發配置:比如“10個線程且10次循環”與“100線程且1次循環”。乍一看,這兩個設置的總請求數都是100次,但它們對系統的壓力和測試場景卻截然不同。了解其中的區別,能幫助你更精準地模…

Spring Boot 實戰:輕松實現文件上傳與下載功能

目錄 一、引言 二、Spring Boot 文件上傳基礎 (一)依賴引入 (二)配置文件設置 (三)文件上傳接口編寫 (一)文件類型限制 (二)文件大小驗證 &#xff0…

【Golang】GC探秘/寫屏障是什么?

之前寫了 一篇【Golang】內存管理 ,有了很多的閱讀量,那么我就接著分享一下Golang的GC相關的學習。 由于Golang的GC機制一直在持續迭代,本文敘述的主要是Go1.9版本及以后的GC機制,該版本中Golang引入了 混合寫屏障大幅度地優化了S…

DeepSeek教unity------MessagePack-03

數據契約兼容性 你可以使用 [DataContract] 注解代替 [MessagePackObject]。如果類型用 DataContract 進行注解,可以使用 [DataMember] 注解代替 [Key],并使用 [IgnoreDataMember] 代替 [IgnoreMember]。 然后,[DataMember(Order int)] 的…

【對比】Pandas 和 Polars 的區別

Pandas vs Polars 對比表 特性PandasPolars開發語言Python(Cython 實現核心部分)Rust(高性能系統編程語言)性能較慢,尤其在大數據集上(內存占用高,計算效率低)極快,利用…

百度千帆平臺對接DeepSeek官方文檔

目錄 第一步:注冊賬號,開通千帆服務 第二步:創建應用,獲取調用秘鑰 第三步:調用模型,開啟AI對話 方式一:通過API直接調用 方式二:使用SDK快速調用 方式三:在千帆大模…

49. c++計時器

為了測試某段特定代碼的執行時間&#xff0c;體現代碼的性能&#xff0c;可以使用計時器對代碼段計時。下面使用std::chrono中的api編寫簡單案例&#xff1a; // // main.cpp // HelloWorld // // Created by on 2024/11/28. //#include <iostream> #include <vec…

Natural Language Processing NLP

NLP 清晰版本查看 Sentence segmentation (split)Tokenisation (split)Named entity recognition (combine) 概念主要內容典型方法Distributional Semantics&#xff08;分佈式語義&#xff09;&#xff08;分銷語義&#xff08;分佈式語義&#xff09;單詞的語義來自於它的…

Linux中線程創建,線程退出,線程接合

線程的簡單了解 之前我們了解過 task_struct 是用于描述進程的核心數據結構。它包含了一個進程的所有重要信息&#xff0c;并且在進程的生命周期內保持更新。我們想要獲取進程相關信息往往從這里得到。 在Linux中&#xff0c;線程的實現方式與進程類似&#xff0c;每個線程都…

HarmonyOS:使用List實現分組列表(包含粘性標題)

一、支持分組列表 在列表中支持數據的分組展示&#xff0c;可以使列表顯示結構清晰&#xff0c;查找方便&#xff0c;從而提高使用效率。分組列表在實際應用中十分常見&#xff0c;如下圖所示聯系人列表。 聯系人分組列表 在List組件中使用ListItemGroup對項目進行分組&#…

django上傳文件

1、settings.py配置 # 靜態文件配置 STATIC_URL /static/ STATICFILES_DIRS [BASE_DIR /static, ]上傳文件 # 定義一個視圖函數&#xff0c;該函數接收一個 request 參數 from django.shortcuts import render # 必備引入 import json from django.views.decorators.http i…

【前端知識】瀏覽器兼容方案polyfill

瀏覽器兼容方案polyfill 什么是 Polyfill&#xff1f;Polyfill 的作用Polyfill 的工作原理1. **特性檢測**2. **加載 Polyfill**3. **模擬實現** Polyfill 的常見場景Polyfill 的使用方式Polyfill 的優缺點優點缺點 常見的 Polyfill 庫總結 什么是 Polyfill&#xff1f; Polyf…

C#學習之DateTime 類

目錄 一、DateTime 類的常用方法和屬性的匯總表格 二、常用方法程序示例 1. 獲取當前本地時間 2. 獲取當前 UTC 時間 3. 格式化日期和時間 4. 獲取特定部分的時間 5. 獲取時間戳 6. 獲取時區信息 三、總結 一、DateTime 類的常用方法和屬性的匯總表格 在 C# 中&#x…

dedecms 開放重定向漏洞(附腳本)(CVE-2024-57241)

免責申明: 本文所描述的漏洞及其復現步驟僅供網絡安全研究與教育目的使用。任何人不得將本文提供的信息用于非法目的或未經授權的系統測試。作者不對任何由于使用本文信息而導致的直接或間接損害承擔責任。如涉及侵權,請及時與我們聯系,我們將盡快處理并刪除相關內容。 0x0…

如何選擇合適的超參數來訓練Bert和TextCNN模型?

選擇合適的超參數來訓練Bert和TextCNN模型是一個復雜但關鍵的過程&#xff0c;它會顯著影響模型的性能。以下是一些常見的超參數以及選擇它們的方法&#xff1a; 1. 與數據處理相關的超參數 最大序列長度&#xff08;max_length&#xff09; 含義&#xff1a;指輸入到Bert模…

AWS 前端自動化部署流程指南

本文詳細介紹從前端代碼開發到 AWS 自動化部署的完整流程。 一、流程概覽 1.1 部署流程圖 #mermaid-svg-nYg7k6L5IKVBjDtr {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-nYg7k6L5IKVBjDtr .error-icon{fill:#552…

Office word打開加載比較慢處理方法

1.添加safe參數 ,找到word啟動項,右擊word,選擇屬性 , 添加/safe , 應用并確定 2.取消加載項,點擊文件,點擊選項 ,點擊加載項,點擊轉到,取消所有勾選,確定。

大數據SQL調優專題——Spark執行原理

引入 在深入MapReduce中有提到&#xff0c;MapReduce雖然通過“分而治之”的思想&#xff0c;解決了海量數據的計算處理問題&#xff0c;但性能還是不太理想&#xff0c;這體現在兩個方面&#xff1a; 每個任務都有比較大的overhead&#xff0c;都需要預先把程序復制到各個 w…

MYSQL下載安裝及使用

MYSQL官網下載地址&#xff1a;https://downloads.mysql.com/archives/community/ 也可以直接在服務器執行指令下載&#xff0c;但是下載速度比較慢。還是自己下載好拷貝過來比較快。 wget https://dev.mysql.com/get/Downloads/mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz 1…