Day 40 訓練和測試的規范寫法

知識點回顧:

  1. 彩色和灰度圖片測試和訓練的規范寫法:封裝在函數中
  2. 展平操作:除第一個維度batchsize外全部展平
  3. dropout操作:訓練階段隨機丟棄神經元,測試階段eval模式關閉dropout

作業:仔細學習下測試和訓練代碼的邏輯,這是基礎,這個代碼框架后續會一直沿用,后續的重點慢慢就是轉向模型定義階段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 1. 數據預處理
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差
])# 2. 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 創建數據加載器
batch_size = 64  # 每批處理64個樣本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定義模型、損失函數和優化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 將28x28的圖像展平為784維向量self.layer1 = nn.Linear(784, 128)  # 第一層:784個輸入,128個神經元self.relu = nn.ReLU()  # 激活函數self.layer2 = nn.Linear(128, 10)  # 第二層:128個輸入,10個輸出(對應10個數字類別)def forward(self, x):x = self.flatten(x)  # 展平圖像x = self.layer1(x)   # 第一層線性變換x = self.relu(x)     # 應用ReLU激活函數x = self.layer2(x)   # 第二層線性變換,輸出logitsreturn x# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = MLP()
model = model.to(device)  # 將模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數,適用于多分類問題
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 5. 訓練模型(記錄每個 iteration 的損失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 設置為訓練模式# 新增:記錄每個 iteration 的損失all_iter_losses = []  # 存儲所有 batch 的損失iter_indices = []     # 存儲 iteration 序號(從1開始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向傳播loss = criterion(output, target)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數# 記錄當前 iteration 的損失(注意:這里直接使用單 batch 損失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序號從1開始# 統計準確率和損失(原邏輯保留,用于 epoch 級統計)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100個批次打印一次訓練信息(可選:同時打印單 batch 損失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 單Batch損失: {iter_loss:.4f} | 累計平均損失: {running_loss/(batch_idx+1):.4f}')# 原 epoch 級邏輯(測試、打印 epoch 結果)不變epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%')# 繪制所有 iteration 的損失曲線plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 級曲線(可選)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最終測試準確率# 6. 測試模型
def test(model, test_loader, criterion, device):model.eval()  # 設置為評估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不計算梯度,節省內存和計算資源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回損失和準確率# 7.繪制每個 iteration 的損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('每個 Iteration 的訓練損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 執行訓練和測試(設置 epochs=2 驗證效果)
epochs = 2  
print("開始訓練模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/92763.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/92763.shtml
英文地址,請注明出處:http://en.pswp.cn/web/92763.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

分析代碼并回答問題

代碼 <template><div>Counter: {{ counter }}</div><div>Double Counter: {{ doubleCounter }}</div> </template><script setup lang"ts"> import { ref, computed } from "vue";const counter ref(0);const …

在macOS上掃描192.168.1.0/24子網的所有IP地址

在macOS上掃描192.168.1.0/24子網的所有IP地址&#xff0c;可以通過終端命令實現。以下是幾種常用方法&#xff1a; 使用ping命令循環掃描 打開終端執行以下腳本&#xff0c;會逐個ping測試192.168.1.1到192.168.1.254的地址&#xff0c;并過濾出有響應的IP&#xff1a; for i …

Java基礎05——類型轉換(本文為個人學習筆記,內容整理自嗶哩嗶哩UP主【遇見狂神說】的公開課程。 > 所有知識點歸屬原作者,僅作非商業用途分享)

Java基礎05——類型轉換 類型轉換 由于Java是強類型語言&#xff0c;所以要進行有些運算的時候&#xff0c;需要用到類型轉換。 如&#xff1a;byte(占1個字節)&#xff0c;short(占2個字節)&#xff0c;char(占2個字節)→int(4個字節)→long(占8個字節)→float(占4個字節)→do…

mysql基礎(二)五分鐘掌握全量與增量備份

全量備份 Linux環境 數據備份 數據庫的備份與恢復有多中方法&#xff0c;通過mysql自帶的mysqldump工具可對數據庫進行備份。語法&#xff1a; mysqldump -u username -p password --databases db_name > file_name .sql說明&#xff1a; -u參數指定用戶名&#xff0c;usern…

使用Windbg分析多線程死鎖項目實戰問題分享

目錄 1、問題描述 2、使用.effmach x86命令切換到32位上下文 3、切換到UI線程&#xff0c;發現UI線程死鎖了 4、使用!locks命令查看臨界區鎖的詳細信息&#xff0c;遇到了問題 5、使用dt命令查看臨界區對象信息&#xff0c;找到發生死鎖的多個線程 6、用戶態鎖與內核態鎖…

防火墻組網方式總結

一、部署模式&#xff1a;靈活適配多樣網絡環境下一代防火墻&#xff08;NGAF&#xff09;具備極強的網絡適應能力&#xff0c;支持五種核心部署模式&#xff0c;可根據不同網絡需求靈活選擇。路由模式&#xff1a;防火墻相當于路由器&#xff0c;位于內外網之間負責路由尋址&a…

AI大模型:(二)5.1 文生視頻(Text-to-Video)模型發展史

目錄 1.介紹 2.發展歷史 2.1.早期探索階段(2015-2019) 2.1.1.技術萌芽期 2.1.2.RNN/LSTM時代 2.2.技術突破期(2020-2021) 2.2.1 Transformer引入視頻生成 2.2.2 擴散模型的興起 2.3.商業化突破期(2022-2023) 2.3.1 產品化里程碑 2.3.2 競爭格局形成 2.4.革命…

14mm尋北儀能否塞進液壓支架生死縫隙?

在煤礦井下世界的方寸之間&#xff0c;液壓支架的每個關鍵節點都承載著千鈞重壓。頂梁鉸接點、立柱頂端、掩護梁角落&#xff0c;恰恰是空間最為局促的“禁區”。ER-MNS-10A MEMS尋北儀應運而生&#xff01;它采用了先進的MEMS陀螺技術&#xff0c;以14mm至薄高度、40g極致輕盈…

python之淺拷貝深拷貝

文章目錄潛拷貝(shallow copy)深拷貝(deep copy)總結一下python的淺拷貝和深拷貝.潛拷貝(shallow copy) python中潛拷貝指的是:構造一個新的復合對象&#xff0c;然后將原對象中的對象引用插入其中 平常開發過程中潛拷貝是比深拷貝更常見的場景. 比如編程中使用到的一些基本的…

普通大學本科生如何入門強化學習?

問題:你平時是如何緊跟大型語言模型和智能體技術前沿的&#xff1f;有哪些具體的學習和跟蹤方式&#xff1f;回答:我會通過“輸入-內化-實踐”結合的方式跟蹤前沿。首先&#xff0c;學術動態方面&#xff0c;每天花10分鐘瀏覽arXiv的http://cs.CL和http://cs.AI板塊&#xff0c…

新手向:Python實現數據可視化圖表生成

Python數據可視化入門&#xff1a;從零開始生成圖表數據可視化是數據分析過程中不可或缺的關鍵環節&#xff0c;它通過將抽象的數字信息轉化為直觀的圖形展示&#xff0c;幫助分析師和決策者更快速、更準確地發現數據中隱藏的模式、規律和發展趨勢。在當今大數據時代&#xff0…

VBA即用型代碼手冊:計算選擇的單詞數Count Words in Selection

我給VBA下的定義&#xff1a;VBA是個人小型自動化處理的有效工具。可以大大提高自己的勞動效率&#xff0c;而且可以提高數據的準確性。我這里專注VBA,將我多年的經驗匯集在VBA系列九套教程中。作為我的學員要利用我的積木編程思想&#xff0c;積木編程最重要的是積木如何搭建及…

DNS(域名系統)

分層結構根域名&#xff08;ipv4&#xff0c;13臺&#xff09;&#xff0c;二級域名&#xff0c;三級域名……相關記錄A將域名解析為ipv4地址AAAA將域名解析為ipv6地址MX指名該區域為郵件服務區PTR反向查詢將主機名解析為域名NS記錄服務器的名字CNAME別名查詢方式遞歸查詢迭代查…

【大模型】強化學習算法總結

角色和術語定義 State&#xff1a;狀態Action&#xff1a;動作Policy/actor model&#xff1a;策略模型&#xff0c;用于決策行動的主要模型Critic/value model&#xff1a;價值模型&#xff0c;用于評判某個行動的價值大小Reward model&#xff1a;獎勵模型&#xff0c;用于給…

基于梅特卡夫定律的開源鏈動2+1模式AI智能名片S2B2C商城小程序價值重構研究

摘要&#xff1a;梅特卡夫定律揭示了網絡價值與用戶數量的平方關系&#xff0c;在互聯網經濟中&#xff0c;連接的深度與形式正因人的參與發生質變。本文以開源鏈動21模式、AI智能名片與S2B2C商城小程序的協同應用為研究對象&#xff0c;通過實證分析其在社群團購、下沉市場等場…

Ubuntu22.04安裝CH340驅動及串口

一、CH340驅動安裝 1.1 查看USB設備能否被識別 CtrlAltT打開終端&#xff1a; lsusb 插入設備前&#xff1a; 插入設備后&#xff1a; 輸出中包含ID 1a86:7523 QinHeng Electronics CH340 serial converter的信息&#xff0c;這表明CH340設備已經被系統識別。 1.2 查看USB轉串…

CPU緩存(CPU Cache)和TLB(Translation Lookaside Buffer)緩存現代計算機體系結構中用于提高性能的關鍵技術

CPU緩存&#xff08;CPU Cache&#xff09;和TLB&#xff08;Translation Lookaside Buffer&#xff09;緩存是現代計算機體系結構中用于提高性能的關鍵技術。它們通過減少CPU訪問數據和指令的延遲來提高系統的整體效率。以下是對這兩者的詳細解釋&#xff1a; 1. CPU 緩存 CPU…

唐揚·高并發系統設計40問

課程下載&#xff1a;https://download.csdn.net/download/m0_66047725/91644703 00開篇詞 _ 為什么你要學習高并發系統設計&#xff1f;.pdf 00開篇詞丨為什么你要學習高并發系統設計&#xff1f;.mp3 01 _ 高并發系統&#xff1a;它的通用設計方法是什么&#xff1f;.pdf …

基于Spring Data Elasticsearch的分布式全文檢索與集群性能優化實踐指南

基于Spring Data Elasticsearch的分布式全文檢索與集群性能優化實踐指南 技術背景與應用場景 隨著大數據時代的到來&#xff0c;海量信息的存儲與檢索成為各類應用的核心需求。Elasticsearch 作為一款分布式搜索引擎&#xff0c;憑借其高可擴展、高可用和實時檢索的優勢&#x…

Linux系統編程——基礎IO

一些前置知識&#xff1a;文件 屬性 內容文件 分為 打開的文件、未打開的文件打開的文件&#xff1a;由進程打開&#xff0c;本質是 進程與文件 的關系&#xff1b;維護的文件對象先加載文件屬性&#xff0c;文件內容一般按需加載未打開的文件&#xff1a;在永久性存儲介質 —…