基于PyTorch的深度學習5——神經網絡工具箱

可以學習如下內容:

? 介紹神經網絡核心組件。

? 如何構建一個神經網絡。

? 詳細介紹如何構建一個神經網絡。

? 如何使用nn模塊中Module及functional。

? 如何選擇優化器。

? 動態修改學習率參數。

5.1 核心組件

神經網絡核心組件不多,把這些組件確定后,這個神經網絡基本就確定了。這些核心組件包括:

1)層:神經網絡的基本結構,將輸入張量轉換為輸出張量。

2)模型:層構成的網絡。

3)損失函數:參數學習的目標函數,通過最小化損失函數來學習各種參數。

4)優化器:如何使損失函數最小,這就涉及優化器。

多個層鏈接在一起構成一個模型或網絡,輸入數據通過這個模型轉換為預測值,然后損失函數把預測值與真實值進行比較,得到損失值(損失值可以是距離、概率值等)?,該損失值用于衡量預測值與目標結果的匹配或相似程度,優化器利用損失值更新權重參數,從而使損失值越來越小。這是一個循環過程,當損失值達到一個閥值或循環次數到達指定次數,循環結束。接下來利用PyTorch的nn工具箱,構建一個神經網絡實例。nn中對這些組件都有現成包或類,可以直接使用,非常方便。

——————————實現神經網絡實例

構建網絡層可以基于Module類,或函數(nn.functional)。

nn.Module中的大多數層(Layer)在functional中都有與之對應的函數。

nn.functional中函數與nn.Module中的Layer的主要區別是后者繼承Module類,會自動提取可學習的參數。

而nn.functional更像是純函數。

兩者功能相同,且性能也沒有很大區別,那么如何選擇呢?像卷積層、全連接層、Dropout層等因含有可學習參數,一般使用nn.Module,而激活函數、池化層不含可學習參數,可以使用nn.functional中對應的函數。

下面通過實例來說明如何使用nn構建一個網絡模型。

這節將利用神經網絡完成對手寫數字進行識別的實例,來說明如何借助nn工具箱來實現一個神經網絡,并對神經網絡有個直觀了解。在這個基礎上,后續我們將對nn的各模塊進行詳細介紹。實例環境使用PyTorch1.0+,GPU或CPU,源數據集為MNIST。

主要步驟:

1)利用PyTorch內置函數mnist下載數據。

2)利用torchvision對數據進行預處理,調用torch.utils建立一個數據迭代器。

3)可視化源數據。

4)利用nn工具箱構建神經網絡模型。

5)實例化模型,并定義損失函數及優化器。

6)訓練模型。

7)可視化結果。

import numpy as np
import torchfrom torchvision.datasets import mnist#導入預處理模塊
import torchvision.transforms as transforms
from torch.utils.data import DataLoader#導入nn及優化器
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

接下來,定義一些超參數

train_batch_size=64
test_batch_size=128
learning_rate=0.01
num_epoches=20
lr=0.01
momentum=0.5

下載數據并對數據進行預處理

#定義預處理函數,這些預處理依次放在Compose函數中。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
#下載數據,并對數據進行預處理
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform)#dataloader是一個可迭代對象,可以使用迭代器一樣使用。
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
  • transforms.Compose:

    • Compose?是 PyTorch 中的一個工具,用于將多個預處理操作組合在一起,依次執行。
  • transforms.ToTensor():

    • 將 PIL 圖像或 NumPy 數組轉換為 PyTorch 張量(Tensor)。
    • 同時會將圖像的像素值從?[0, 255]?轉換到?[0, 1]?的范圍。
  • transforms.Normalize([0.5], [0.5]):

    • 對張量進行歸一化處理。
    • 歸一化的公式是:output = (input - mean) / std
    • 這里的參數?[0.5]?表示均值(mean),[0.5]?表示標準差(std)。
    • 因此,歸一化后數據的范圍會從?[0, 1]?變為?[-1, 1]
    • DataLoader?是 PyTorch 提供的一個工具,用于將數據集分批加載,并支持多線程、打亂順序等功能。
  • 參數解釋:
    • train_dataset?和?test_dataset: 分別指定訓練集和測試集。
    • batch_size: 每次加載的數據批量大小。train_batch_size?和?test_batch_size?應該在代碼其他地方定義。
    • shuffle=True: 是否在每個 epoch 開始時打亂數據順序。通常在訓練集上設置為?True,以增加模型的泛化能力。
    • shuffle=False: 測試集一般不需要打亂順序。

——————————————可視化源數據

import matplotlib.pyplot as pltexamples=enumerate(test_loader)
batch_idx,(example_data,example_targets)=next(examples)fig=plt.figure()
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0],cmap="gray",interpolation='none')plt.title("Ground Truth:{}".format(example_targets[i]))plt.xticks([])plt.yticks([])

_______________________構建模型

數據預處理結束后,開始構建模型

1)構建網絡的類Net

class Net(nn.Module):"""使用sequential構建網絡,Sequential()函數的功能是將網絡的層組合到一起"""def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim,
n_hidden_1),nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d (n_hidden_2))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))x = self.layer3(x)return x

2)創造實例

#檢測是否有可用的GPU,有則使用,否則使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#實例化網絡
model = Net(28 * 28, 300, 100, 10)
model.to(device)# 定義損失函數criterion和優化器optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

以下內容為大語言模型生成:

這段代碼實現了一個使用nn.Sequential構建的三層全連接神經網絡,包含批歸一化和ReLU激活函數。以下是對代碼的分析及優化建議:


代碼結構分析

  1. 網絡定義

    • Layer1 :線性層(in_dimn_hidden_1) + 批歸一化。
    • Layer2 :線性層(n_hidden_1n_hidden_2) + 批歸一化。
    • Layer3 :線性層(n_hidden_2out_dim),無批歸一化和激活函數。
  2. 前向傳播

    • 前兩層使用ReLU激活,最后一層輸出原始 logits(適用于分類任務結合CrossEntropyLoss或回歸任務)。

————————訓練模型

這里使用for循環進行迭代。其中包括對訓練數據的訓練模型,然后用測試數據的驗證模型。

1.訓練模型

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm# 假設已定義模型、數據集、損失函數和優化器
# model = Net(...)
# train_dataset = ...
# test_dataset = ...# 數據預處理(展平)
class FlattenTransform:def __call__(self, img):return img.view(-1)transform = transforms.Compose([transforms.ToTensor(),FlattenTransform()
])train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)# 初始化
losses = []
acces = []
eval_losses = []
eval_acces = []for epoch in range(num_epoches):# 動態調整學習率if epoch > 0 and epoch % 5 == 0:for param_group in optimizer.param_groups:new_lr = max(param_group['lr'] * 0.1, 1e-6)param_group['lr'] = new_lrprint(f"Learning rate updated to {new_lr}")# 訓練階段model.train()train_loss = 0train_correct = 0train_total = 0for img, label in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):img, label = img.to(device), label.to(device)# 前向傳播output = model(img)loss = criterion(output, label)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()# 記錄指標train_loss += loss.item()_, pred = output.max(1)train_correct += (pred == label).sum().item()train_total += label.size(0)# 計算平均指標avg_train_loss = train_loss / len(train_loader)avg_train_acc = train_correct / train_totallosses.append(avg_train_loss)acces.append(avg_train_acc)# 測試階段model.eval()eval_loss, eval_acc = calculate_metrics(test_loader, model, criterion, device)eval_losses.append(eval_loss)eval_acces.append(eval_acc)# 打印日志print(f'Epoch {epoch+1}/{num_epoches}: 'f'Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, 'f'Test Loss: {eval_loss:.4f}, Test Acc: {eval_acc:.4f}')

以下內容由大語言模型生成

為什么需要 model.eval()

在神經網絡中,某些層(如 BatchNormDropout)在訓練(training)和推理(evaluation)階段的行為不同:

  1. BatchNorm 層

    • 訓練階段 :使用當前 batch 的均值和方差進行標準化,并更新運行統計量(running mean/variance)。
    • 推理階段 :使用訓練階段累積的運行統計量,而非當前 batch 的統計量。
    • 若不調用 model.eval(),推理時會繼續更新統計量,導致結果不穩定。
  2. Dropout 層

    • 訓練階段 :隨機丟棄部分神經元,防止過擬合。
    • 推理階段 :關閉 Dropout,使用所有神經元的輸出(權重按概率縮放)。
    • 若不調用 model.eval(),推理時會繼續隨機丟棄神經元,導致結果隨機。

功能詳解

調用 model.eval() 后:

  1. 關閉訓練相關行為
    • BatchNorm 層停止計算均值/方差,使用累積的統計量。
    • Dropout 層停止隨機丟棄神經元。
  2. 不影響梯度計算
    • model.eval() 僅改變層的行為,不涉及梯度計算。如果需要禁用梯度,需配合 torch.no_grad()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/71742.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/71742.shtml
英文地址,請注明出處:http://en.pswp.cn/web/71742.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

模擬調制技術詳解

內容摘要 本文系統講解模擬調制技術原理及Matlab實現,涵蓋幅度調制的四種主要類型:雙邊帶抑制載波調幅(DSB-SC)、含離散大載波調幅(AM)、單邊帶調幅(SSB)和殘留邊帶調幅(…

aws(學習筆記第三十一課) aws cdk深入學習(batch-arm64-instance-type)

aws(學習筆記第三十一課) aws cdk深入學習 學習內容: 深入練習aws cdk下部署batch-arm64-instance-type 1. 深入練習aws cdk下部署batch-arm64-instance-type 代碼鏈接 代碼鏈接 代碼鏈接 -> batch-arm64-instance-type之前代碼學習 之前學習代碼鏈接 -> aw…

讀書報告」網絡安全防御實戰--藍軍武器庫

一眨眼,20天過去了,刷完了這本書「網絡安全防御實戰--藍軍武器庫」,回味無窮,整理概覽如下,可共同交流讀書心得。在閱讀本書的過程中,我深刻感受到網絡安全防御是一個綜合性、復雜性極高的領域。藍軍需要掌…

生成任務,大模型

一個生成項目 輸入:文字描述(但是給的數據集是一串數字,id,ct描述,醫生描述) 輸出:診斷報告 一、數據處理 import pandas as pd #處理表格數據pre_train_file "data/train.csv"tr…

Spring Boot API 項目中 HAProxy 與 Nginx 的選擇與實踐

在開發 Spring Boot 構建的 RESTful API 項目時,負載均衡和反向代理是提升性能與可用性的關鍵環節。HAProxy 和 Nginx 作為兩種流行的工具,經常被用于流量分發,但它們各有側重。究竟哪一個更適合你的 Spring Boot API 項目?本文將…

Java常用集合與映射的線程安全問題深度解析

Java常用集合與映射的線程安全問題深度解析 一、線程安全基礎認知 在并發編程環境下,當多個線程同時操作同一集合對象時,若未采取同步措施,可能導致以下典型問題: 數據競爭:多個線程同時修改數據導致結果不可預測狀…

DeepLabv3+改進6:在主干網絡中添加SegNext_Attention|助力漲點

??【DeepLabv3+改進專欄!探索語義分割新高度】 ?? 你是否在為圖像分割的精度與效率發愁? ?? 本專欄重磅推出: ? 獨家改進策略:融合注意力機制、輕量化設計與多尺度優化 ? 即插即用模塊:ASPP+升級、解碼器 PS:訂閱專欄提供完整代碼 目錄 論文簡介 步驟一 步驟二…

使用 Elastic-Agent 或 Beats 將 Journald 中的 syslog 和 auth 日志導入 Elastic Stack

作者:來自 Elastic TiagoQueiroz 我們在 Elastic 一直努力將更多 Linux 發行版添加到我們的支持矩陣中,現在 Elastic-Agent 和 Beats 已正式支持 Debian 12! 本文演示了我們正在開發的功能,以支持使用 Journald 存儲系統和身份驗…

3.9[A]csd

在傳統CPU中心架構中,中央處理器通過內存訪問外部存儲器,而數據必須經過網絡接口卡才能到達外部存儲器。這種架構存在集中式計算、DRAM帶寬和容量挑戰、大量數據移動(服務器內和網絡)以及固定計算導致工作負載容量增長等問題。 而…

ESP32S3讀取數字麥克風INMP441的音頻數據

ESP32S3 與 INMP441 麥克風模塊的集成通常涉及使用 I2S 接口進行數字音頻數據的傳輸。INMP441 是一款高性能的數字麥克風,它通過 I2S 接口輸出音頻數據。在 Arduino 環境中,ESP32S3 的開發通常使用 ESP-IDF(Espressif IoT Development Framew…

DeepSeek大模型 —— 全維度技術解析

DeepSeek大模型 —— 全維度技術解析 前些天發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,可以分享一下給大家。點擊跳轉到網站。 https://www.captainbed.cn/ccc 文章目錄 DeepSeek大模型 —— 全維度技術解析一、模型架構全景解析1…

[Kubernetes] 7控制平面組件

1. 調度 kube- scheduler what 負責分配調度pod到集群節點監聽kube-apiserver,查詢未分配node的pod根據調度策略分配這些pod(更新pod的nodename)需要考慮的因素: 公平調度,資源有效利用,QoS,affinity, an…

PyTorch系列教程:編寫高效模型訓練流程

當使用PyTorch開發機器學習模型時,建立一個有效的訓練循環是至關重要的。這個過程包括組織和執行對數據、參數和計算資源的操作序列。讓我們深入了解關鍵組件,并演示如何構建一個精細的訓練循環流程,有效地處理數據處理,向前和向后…

LeetCode Hot100刷題——反轉鏈表(迭代+遞歸)

206.反轉鏈表 給你單鏈表的頭節點 head ,請你反轉鏈表,并返回反轉后的鏈表。 示例 1: 輸入:head [1,2,3,4,5] 輸出:[5,4,3,2,1]示例 2: 輸入:head [1,2] 輸出:[2,1]示例 3&#…

機器學習的發展史

機器學習(Machine Learning, ML)作為人工智能(AI)的一個分支,其發展經歷了多個階段。以下是機器學習的發展史概述: 1. 早期探索(20世紀50年代 - 70年代) 1950年:艾倫圖…

Springboot redis bitMap實現用戶簽到以及統計,保姆級教程

項目架構,這是作為demo展示使用: Redis config: package com.zy.config;import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.Ob…

Ardupilot開源無人機之Geek SDK進展2025Q1

Ardupilot開源無人機之Geek SDK進展2025Q1 1. 源由2. 內容匯總2.1 【jetson-fpv】YOLO INT8 coco8 dataset 精度降級2.2 【OpenIPC-Configurator】OpenIPC Configurator 固件升級失敗2.3 【OpenIPC-Adaptive-link】OpenIPC RF信號質量相關顯示2.4 【OpenIPC-msposd】.srt/.osd…

《云原生監控體系構建實錄:從Prometheus到Grafana的觀測革命》

PrometheusGrafana部署配置 Prometheus安裝 下載Prometheus服務端 Download | PrometheusAn open-source monitoring system with a dimensional data model, flexible query language, efficient time series database and modern alerting approach.https://prometheus.io/…

SpringMvc與Struts2

一、Spring MVC 1.1 概述 Spring MVC 是 Spring 框架的一部分,是一個基于 MVC 設計模式的輕量級 Web 框架。它提供了靈活的配置和強大的擴展能力,適合構建復雜的 Web 應用程序。 1.2 特點 輕量級:與 Spring 框架無縫集成,依賴…

數據類設計_圖片類設計之1_矩陣類設計(前端架構基礎)

前言 學的東西多了,要想辦法用出來.C和C是偏向底層的語言,直接與數據打交道.嘗試做一些和數據方面相關的內容 引入 圖形在底層是怎么表示的,用C來表示 認識圖片 圖片是個風景,動物,還是其他內容,人是可以看出來的.那么計算機是怎么看懂的呢?在有自主意識的人工智能被設計出來…