Pytorch系列教程:可視化Pytorch模型訓練過程

深度學習和理解訓練過程中的學習和進步機制對于優化性能、診斷欠擬合或過擬合等問題至關重要。將訓練過程可視化的過程為學習的動態提供了有價值的見解,使我們能夠做出合理的決策。訓練進度必須可視化的兩種方法是:使用Matplotlib和Tensor Board。在本文中,我們將學習如何在Pytorch中可視化模型訓練進度。

使用Matplotlib在PyTorch中可視化訓練進度

Matplotlib是Python中廣泛使用的繪圖庫,它為在Python中創建靜態,動畫和交互式可視化提供了靈活而強大的工具。它特別適合于創建出版質量的圖表。
在這里插入圖片描述

**步驟1:**導入必要的庫并生成樣本數據集

在這一步中,我們將導入必要的庫并生成樣本數據集。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# Sample data
X = torch.randn(100, 1)  # Sample features
y = 3 * X + 2 + torch.randn(100, 1)  # Sample labels with noise

**步驟2:**定義模型

  1. PyTorch中的LinearRegression類定義了一個簡單的線性回歸模型。它繼承自nn。模塊的類,使其成為一個神經網絡模型。
  2. 構造函數(__init__方法)初始化模型的結構,創建具有一個輸入特征和一個輸出特征的單一線性層(‘nn.Linear’)。
  3. 這個線性層被存儲為名為 self.linear的屬性。“forward”方法定義了如何通過這個線性層處理輸入數據“x”以產生模型的輸出。
  4. 具體來說,輸入x是通過 self.linear,并返回結果輸出。該方法封裝了神經網絡的前向傳遞計算,決定了模型如何將輸入轉換為輸出。
# Define a simple linear regression model
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # One input feature, one outputdef forward(self, x):return self.linear(x)model = LinearRegression()

**步驟3:**定義損失函數、優化器和訓練循環

在下面的代碼中,我們將均方誤差定義為損失函數,將隨機梯度下降(SGD)優化器定義為優化器,該優化器通過使用學習率為0.01的計算梯度來修改模型的參數。

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

這段代碼運行了一個神經網絡模型在多個時代的訓練循環,使用梯度下降計算和優化損失。損失值被存儲以進行繪圖,進度每10次打印一次。

# Training loop
num_epochs = 100
losses = []
for epoch in range(num_epochs):# Forward passoutputs = model(X)loss = criterion(outputs, y)# Backward pass and optimizationoptimizer.zero_grad()loss.backward()optimizer.step()# Print progressif (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# Store loss for plottinglosses.append(loss.item())

**步驟4:**使用Matplotlib在PyTorch中可視化訓練進度

使用下面的代碼,我們可以使用matplotlib可視化訓練損失曲線。

  • plot(損失)線根據epoch號繪制存儲在損失列表中的損失值。
  • x軸表示歷元數,y軸表示相應的損失值。
  • plt.xlabel(‘Epoch’), plt.ylabel(‘Loss‘)和plt.xlabel(’Epoch’).title()‘Training Loss’)行設置情節的標簽和標題。
  • 最后,plot .show()顯示該圖,允許您可視化地分析損失如何在訓練期間減少(或收斂)。
# Plot the loss curve
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

通常,您會期望在損失曲線中看到下降的趨勢,這表明模型正在隨著時間的推移而學習和改進。

完整的代碼:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# Sample data
X = torch.randn(100, 1)  # Sample features
y = 3 * X + 2 + torch.randn(100, 1)  # Sample labels with noise# Define a simple linear regression model
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # One input feature, one outputdef forward(self, x):return self.linear(x)model = LinearRegression()# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Training loop
num_epochs = 100
losses = []
for epoch in range(num_epochs):# Forward passoutputs = model(X)loss = criterion(outputs, y)# Backward pass and optimizationoptimizer.zero_grad()loss.backward()optimizer.step()# Print progressif (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# Store loss for plottinglosses.append(loss.item())# Plot the loss curve
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

在這里插入圖片描述

輸出圖顯示了訓練損失如何隨時間變化,并根據迭代次數繪制。這種可視化使人們能夠看到模型在訓練時是如何減少損失的。此外,Matplotlib圖還有其他東西,如軸標簽、標題,可能還有標記或線條,表示特定事件,如最小實現損失或損失急劇下降。

使用TensorBoard可視化訓練進度

為了在深度學習模型中可視化訓練過程,我們可以使用torch.utils.tensorboard模塊中的SummaryWriter類,該模塊與TensorFlow開發的可視化工具TensorBoard無縫集成。
在這里插入圖片描述

  • 集成:PyTorch在torch.utils.tensorboard模塊中提供了一個SummaryWriter類,它與TensorBoard無縫集成以實現可視化。
  • 日志記錄:在訓練循環中,您可以使用SummaryWriter記錄各種指標,如損失,準確性等,以實現可視化。
  • 可視化:TensorBoard提供了記錄指標的交互式和實時可視化,允許您動態監控訓練進度。
  • 監控:TensorBoard使您能夠監控訓練的多個方面,例如學習曲線,模型圖和權重直方圖,為優化您的模型提供見解。

使用以下命令安裝TensorBoard庫:

pip install tensorboard

步驟1:導入庫

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

步驟2:定義簡單的神經網絡

讓我們定義SimpleNN一個簡單神經網絡的類聲明,它包含兩個完全連接的層,以及定義網絡前向傳遞的forward函數。

# Define a simple neural network
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x

步驟3:加載MNIST數據集

讓我們加載用于訓練的MINST數據,將其分成批次并使用一些預處理技術進行轉換。

# Load a smaller subset of MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
small_train_dataset = torch.utils.data.Subset(train_dataset, range(1000))  # Subset of first 1000 samples
train_loader = DataLoader(small_train_dataset, batch_size=64, shuffle=True)

步驟4:初始化模型、損失函數和優化器

現在,初始化模型。與此同時,我們將使用交叉熵損失函數和adam優化器來更新模型參數。

# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步驟5:初始化用于日志記錄的SummaryWriter

SummaryWriter是導入模塊的對象,用于編寫要在TensorBoard中可視化的日志。

# Initialize SummaryWriter for logging
writer = SummaryWriter('logs_small')

第六步:循環訓練

  • 訓練循環:通過時代和批次,執行向前傳遞,計算損失,向后傳遞和更新模型參數。
  • 日志損失和準確性:記錄劃時代的訓練損失和準確性。
# Training loop
epochs = 5
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# Calculate accuracy_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Log losswriter.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)# Log accuracyaccuracy = 100 * correct / totalwriter.add_scalar('Accuracy/train', accuracy, epoch)print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy}%')print('Finished Training')
writer.close()

完整代碼:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# Define a simple neural network
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# Load a smaller subset of MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
small_train_dataset = torch.utils.data.Subset(train_dataset, range(1000))  # Subset of first 1000 samples
train_loader = DataLoader(small_train_dataset, batch_size=64, shuffle=True)# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# Initialize SummaryWriter for logging
writer = SummaryWriter('logs_small')# Training loop
epochs = 5
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# Calculate accuracy_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Log losswriter.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)# Log accuracyaccuracy = 100 * correct / totalwriter.add_scalar('Accuracy/train', accuracy, epoch)print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy}%')print('Finished Training')
writer.close()

運行示例,輸出如下:

Epoch [1/5], Loss: 1.8145772516727448, Accuracy: 47.1%
Epoch [2/5], Loss: 1.0121613591909409, Accuracy: 78.8%
Epoch [3/5], Loss: 0.6829517856240273, Accuracy: 84.1%
Epoch [4/5], Loss: 0.5442189555615187, Accuracy: 85.4%
Epoch [5/5], Loss: 0.46599634923040867, Accuracy: 87.0%
Finished Training

TensorBoard提供了一個基于web的儀表板,其中包含代表各種培訓方面的選項卡和可視化。標量度量將損失或準確度等值可視化,為訓練動態提供了不同的視角。此外,TensorBoard可以顯示直方圖、嵌入和基于日志信息的專門可視化。

在PyTorch中可視化訓練進度

為了運行TensorBoard,你應該打開終端,然后運行tensorboard use命令:

tensorboard --logdir=./logs_small

注意,這里logdir指定上節示例的路徑,采用相對路徑表示。訪問TensorBoard需要:打開瀏覽器,輸入TensorBoard提供的網址(通常為http://localhost:6006/)。

a
b

TensorBoard提供了一個基于web的儀表板,其中包含代表各種培訓方面的選項卡和可視化。標量度量將損失或準確度等值可視化,為訓練動態提供了不同的視角。此外,TensorBoard可以顯示直方圖、嵌入和基于日志信息的專門可視化。

在這篇博客中,我們介紹了如何使用matplotlib和tensorboard來可視化深度學習框架的訓練過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/897594.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/897594.shtml
英文地址,請注明出處:http://en.pswp.cn/news/897594.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

18 | 實現簡潔架構的 Handler 層

提示: 所有體系課見專欄:Go 項目開發極速入門實戰課;歡迎加入我的訓練營:云原生AI實戰營,一個助力 Go 開發者在 AI 時代建立技術競爭力的實戰營;本節課最終源碼位于 fastgo 項目的 feature/s14 分支&#x…

藍隊第三次

1.了解什么是盲注 盲注(Blind SQL Injection)是SQL注入的一種形式,攻擊者無法直接通過頁面回顯或錯誤信息獲取數據,而是通過觀察頁面的布爾狀態(真/假)或時間延遲來間接推斷數據庫信息。例如,通…

sql server 2016 版本補丁說明

包信息和發布類型 Microsoft為創建和分發的 SQL Server 的所有軟件更新包采用了標準化命名架構。 軟件更新包是一個可執行文件(.exe 或 .msi)文件,其中包含一個或多個文件,這些文件可能應用于 SQL Server 安裝以更正特定問題。 …

STM32之I2C硬件外設

注意:硬件I2C的引腳是固定的 SDA和SCL都是復用到外部引腳。 SDA發送時數據寄存器的數據在數據移位寄存器空閑的狀態下進入數據移位寄存器,此時會置狀態寄存器的TXE為1,表示發送寄存器為空,然后往數據控制寄存器中一位一位的移送數…

從青銅到王者:六大排序算法實戰解析

前言 在編程的世界里,排序算法如同一顆璀璨的明珠,閃耀著智慧的光芒。它不僅是計算機科學的基礎知識點,更是每一位程序員必備的技能。今天,就讓我們一同走進排序算法的世界,深入探究冒泡排序、選擇排序、插入排序、快速排序、歸并排序、堆排序這六大經典算法的精髓所在,…

小程序配置webview

1.在微信公眾平臺配置業務域名 1)包括把校驗文件放在服務器根目錄 2)配置域名 2.在小程序中 新建文件 小程序新建頁面:web-view json配置:{ "pageOrientation": "landscape", "renderer":&qu…

不用 Tomcat?SpringBoot 項目用啥代替?

在SpringBoot框架中,我們使用最多的是Tomcat,這是SpringBoot默認的容器技術,而且是內嵌式的Tomcat。 同時,SpringBoot也支持Undertow容器,我們可以很方便的用Undertow替換Tomcat,而Undertow的性能和內存使…

線索二叉樹構造及遍歷算法

線索二叉樹構造以及遍歷算法 線索二叉樹(中序遍歷版)構造線索二叉樹構造雙向線索鏈表遍歷中序線索二叉樹 線索二叉樹(中序遍歷版) 中序遍歷找到對應結點的前驅(土方法) #mermaid-svg-eunGO5d2GhjLxCn5 {fo…

基于SpringBoot的“體育購物商城”的設計與實現(源碼+數據庫+文檔+PPT)

基于SpringBoot的“體育購物商城”的設計與實現(源碼數據庫文檔PPT) 開發語言:Java 數據庫:MySQL 技術:SpringBoot 工具:IDEA/Ecilpse、Navicat、Maven 系統展示 系統總體模塊設計 前臺用戶登錄界面 系統首頁界面…

數據篇| App爬蟲入門(一)

App 的爬取相比 Web 端爬取更加容易,反爬蟲能力沒有那么強,而且數據大多是以 JSON 形式傳輸的,解析更加簡單。在 Web 端,我們可以通過瀏覽器的開發者工具監聽到各個網絡請求和響應過程,在 App 端如果想要查看這些內容就需要借助抓包軟件。常見抓包軟件有: ?工具名稱??…

go context學習

1.Context接口2.emptyCtx3.Deadline()方法4.Done()方法5.Err方法6.Value方法()7.contex應用場景8.其他context方法 1.Context接口 Context接口只有四個方法,以下是context源碼。 type Context interface {Deadline() (deadline time.Time, …

在VMware Workstation Pro上輕松部署CentOS7 Linux虛擬機

首先我們需要下載VM虛擬機和Centos7的鏡像 下載并安裝VMware Workstation Pro 訪問VMware Workstation Pro官網下載 https://www.vmware.com/ 第二步:下載centos7鏡像 訪問centos官網下載 https://www.centos.org/ 開始部署Centos7 點擊創建新的虛擬機 這里是Cen…

Jsoup 解析商品信息時需要注意哪些細節?

在使用Jsoup解析商品信息時,需要注意以下細節和最佳實踐,以確保爬蟲的穩定性和數據的準確性: 1. 檢查HTML文檔的合法性 在解析之前,需要確認所解析的文檔是否是一份合法正確的HTML文檔。如果HTML結構不完整或存在錯誤&#xff0…

Android AudioFlinger(五)—— 揭開AudioMixer面紗

前言: 在 Android 音頻系統中,AudioMixer 是音頻框架中一個關鍵的組件,用于處理多路音頻流的混音操作。它主要存在于音頻回放路徑中,是 AudioFlinger 服務的一部分。 上一節我們講threadloop的時候,提到了一個函數pr…

go的”ambiguous import in multiple modules”

執行“go mod tidy”報如下錯誤: go mod tidy -compat1.17 go: finding module for package github.com/gomooon/goredis go: found github.com/gomooon/goredis in github.com/gomooon/goredis v0.3.5 go: github.com/gomooon/core importsgithub.com/gomooon/gor…

從0開始的操作系統手搓教程27:下一步,實現我們的用戶進程

目錄 第一步:添加用戶進程虛擬空間 準備沖向我們的特權級3(用戶特權級) 討論下我們創建用戶線程的基本步驟 更加詳細的分析代碼 用戶進程的視圖 說一說BSS段 繼續看process.c中的函數 添加用戶線程激活 現在,我們做好了TSS…

Java線程池深度解析,從源碼到面試熱點

Java線程池深度解析,從源碼到面試熱點 一、線程池的核心價值與設計哲學 在開始討論多線程編程之前,可以先思考一個問題?多線程編程的原理是什么? 我們知道,現在的CUP是多核CPU,假設你的機器是4核的&#x…

大數據技術在土地利用規劃中的應用分析

大數據技術在土地利用規劃中的應用分析 一、引言 土地利用規劃是對一定區域內的土地開發、利用、整治和保護所作出的統籌安排與戰略部署,對于實現土地資源的優化配置、保障社會經濟的可持續發展具有關鍵意義。在當今數字化時代,大數據技術憑借其海量數據處理、高效信息挖掘等…

Node 使用 SSE 結合redis 推送數據(echarts 圖表實時更新)

1、實時通信有哪些實現方式? 特性輪詢(Polling)WebSocketSSE (Server-Sent Events)通信方向單向(客戶端 → 服務端)雙向(客戶端 ? 服務端)單向(服務端 → 客戶端)連接方…

Android Native 之 文件系統掛載

一、文件系統掛載流程概述 二、文件系統掛載流程細節 1、Init啟動階段 眾所周知,init進程為android系統的第一個進程,也是native世界的開端,要想讓整個android世界能夠穩定的運行,文件系統的創建和初始化是必不可少的&#xff…