python打卡day40

知識點回顧:

  1. 彩色和灰度圖片測試和訓練的規范寫法:封裝在函數中
  2. 展平操作:除第一個維度batchsize外全部展平
  3. dropout操作:訓練階段隨機丟棄神經元,測試階段eval模式關閉dropout

導入包

# 先繼續之前的代碼
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt
import warnings
# 忽略警告信息
warnings.filterwarnings("ignore")
# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")

數據預處理和模型定義

# 1. 數據預處理
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差
])# 2. 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 創建數據加載器
batch_size = 64  # 每批處理64個樣本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定義模型、損失函數和優化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 將28x28的圖像展平為784維向量self.layer1 = nn.Linear(784, 128)  # 第一層:784個輸入,128個神經元self.relu = nn.ReLU()  # 激活函數self.layer2 = nn.Linear(128, 10)  # 第二層:128個輸入,10個輸出(對應10個數字類別)def forward(self, x):x = self.flatten(x)  # 展平圖像x = self.layer1(x)   # 第一層線性變換x = self.relu(x)     # 應用ReLU激活函數x = self.layer2(x)   # 第二層線性變換,輸出logitsreturn x# 初始化模型
model = MLP()
model = model.to(device)  # 將模型移至GPU(如果可用)# from torchsummary import summary  # 導入torchsummary庫
# print("\n模型結構信息:")
# summary(model, input_size=(1, 28, 28))  # 輸入尺寸為MNIST圖像尺寸criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數,適用于多分類問題
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器

訓練定義

# 5. 訓練模型(記錄每個 iteration 的損失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 設置為訓練模式# 新增:記錄每個 iteration 的損失all_iter_losses = []  # 存儲所有 batch 的損失iter_indices = []     # 存儲 iteration 序號(從1開始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):# enumerate() 是 Python 內置函數,用于遍歷可迭代對象(如列表、元組)并同時獲取索引和值。# batch_idx:當前批次的索引(從 0 開始)# (data, target):當前批次的樣本數據和對應的標簽,是一個元組,這是因為dataloader內置的getitem方法返回的是一個元組,包含數據和標簽。# 只需要記住這種固定寫法即可data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向傳播loss = criterion(output, target)  # 計算損失loss.backward()  # 計算optimizer.step()  # 更新參數# 記錄當前 iteration 的損失(注意:這里直接使用單 batch 損失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序號從1開始# 統計準確率和損失running_loss += loss.item() #將loss轉化為標量值并且累加到running_loss中,計算總損失_, predicted = output.max(1) # output:是模型的輸出(logits),形狀為 [batch_size, 10](MNIST 有 10 個類別)# 獲取預測結果,max(1) 返回每行(即每個樣本)的最大值和對應的索引,這里我們只需要索引total += target.size(0) # target.size(0) 返回當前批次的樣本數量,即 batch_size,累加所有批次的樣本數,最終等于訓練集的總樣本數correct += predicted.eq(target).sum().item() # 返回一個布爾張量,表示預測是否正確,sum() 計算正確預測的數量,item() 將結果轉換為 Python 數字# 每100個批次打印一次訓練信息(可選:同時打印單 batch 損失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 單Batch損失: {iter_loss:.4f} | 累計平均損失: {running_loss/(batch_idx+1):.4f}')# 測試、打印 epoch 結果epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%')# 繪制所有 iteration 的損失曲線plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 級曲線(可選)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最終測試準確率

測試定義

# 6. 測試模型(不變)
def test(model, test_loader, criterion, device):model.eval()  # 設置為評估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不計算梯度,節省內存和計算資源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回損失和準確率

分析圖像定義

# 7. 繪制每個 iteration 的損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('每個 Iteration 的訓練損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()

流程

# 8. 執行訓練和測試(設置 epochs=2 驗證效果)
epochs = 2  
print("開始訓練模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/82010.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/82010.shtml
英文地址,請注明出處:http://en.pswp.cn/web/82010.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

系統性學習C語言-第十二講-深入理解指針(2)

系統性學習C語言-第十二講-深入理解指針(2) 1. const 修飾指針1.1 const 修飾變量1.2 const 修飾指針變量 2. 野指針2.1 野指針成因2.2 如何規避野指針2.2.1 指針初始化2.2.2 小心指針越界2.2.3 指針變量不再使用時,及時置 NULL &…

《高等數學》(同濟大學·第7版) 第一節《映射與函數》超詳細解析

集合(Set)—— 最基礎的數學容器 定義: 集合是由確定的、互不相同的對象(稱為元素)組成的整體。 表示方法: 列舉法:A {1, 2, 3} 描述法:B {x | x > 0}(表示所有大于…

Spring Boot整活指南:從Helo World到“真香”定律

📌 一、Spring Boot的"真香"本質(不是996的福報) 你以為Spring Boot只是個簡化配置的工具?Too young!它其實是程序員的??摸魚加速器??。 ??經典場景還原??: 產品經理:“這個…

打字練習:平臺推薦

1.打字練習 . 1)平臺推薦 下面推薦兩個打字練習平臺 Keybr:https://www.keybr.com/ TypingClub:https://www.edclub.com/sportal/ . 2)平臺對比 特性KeybrTypingClub核心優勢AI智能弱項訓練結構化課程體系適合人群開發者/…

ASP.NET Core 中JWT的基本使用

文章目錄 前言一、JWT與RBAC二、JWT 的作用三、RBAC 的核心思想四、使用1、配置文件 (appsettings.json)2、JWT配置模型 (Entity/JwtSettings.cs)3、服務擴展類,JWT配置 (Extensions/ServiceExtensions.cs)4、用戶倉庫接口服務5、認證服務 (Interface/IAuthService.…

(19)java在區塊鏈中的應用

🔗 Java在區塊鏈中的應用:智能合約開發全攻略 TL;DR: Java在區塊鏈領域主要通過Hyperledger Fabric、Web3j和專用JVM實現智能合約開發,相比Solidity具有更強的企業級支持和開發效率,但在執行效率和Gas消耗方面存在差異&#xff0c…

深入理解設計模式之訪問者模式

深入理解設計模式之訪問者模式(Visitor Pattern) 一、什么是訪問者模式? 訪問者模式(Visitor Pattern)是一種行為型設計模式。它的主要作用是將數據結構與數據操作分離,使得在不改變數據結構的前提下&…

div或button一些好看實用的 CSS 樣式示例

1:現代漸變按鈕 .count {width: 800px;background: linear-gradient(135deg, #72EDF2 0%, #5151E5 100%);padding: 12px 24px;border-radius: 10px;box-shadow: 0 4px 15px rgba(81, 81, 229, 0.3);color: white;font-weight: bold;border: none;cursor: pointer;t…

【基于STM32的新能源汽車智能循跡系統開發全解析】

基于STM32的新能源汽車智能循跡系統開發全解析(附完整工程代碼) 作者聲明 作者: 某新能源車企資深嵌入式工程師(專家認證) 技術方向: 智能駕駛底層控制 | 車規級嵌入式開發 原創聲明: 本文已申…

HTML Day02

Day02 0. 引言1. 文本格式化1.1 HTML文本格式化標簽1.2 HTML"計算機輸出"標簽1.3 HTML 引文,引用及標簽定義 2. HTML鏈接2.1鏈接跳轉原理(有點亂可跳過)2.2 HTML超鏈接2.3 target屬性2.4 id屬性2.4.1 id屬性在頁面內和不同頁面的定…

MIT 6.S081 2020 Lab6 Copy-on-Write Fork for xv6 個人全流程

文章目錄 零、寫在前面一、Implement copy-on write1.1 說明1.2 實現1.2.1 延遲復制與釋放1.2.2 寫時復制 零、寫在前面 可以閱讀下 《xv6 book》 的第五章中斷和設備驅動。 問題 在 xv6 中,fork() 系統調用會將父進程的整個用戶空間內存復制到子進程中。**如果父…

xhr、fetch和axios

XMLHttpRequest (XHR) XMLHttpRequest 是最早用于在瀏覽器中進行異步網絡請求的 API。它允許網頁在不刷新整個頁面的情況下與服務器交換數據。 // 創建 XHR 對象 const xhr new XMLHttpRequest();// 初始化請求 xhr.open(GET, https://api.example.com/data, true);// 設置請…

電腦驅動程序更新工具, 3DP Chip 中文綠色版,一鍵更新驅動!

介紹 3DP Chip 是一款免費的驅動程序更新工具,可以幫助用戶快速、方便地識別和更新計算機硬件驅動程序。 驅動程序更新工具下載 https://pan.quark.cn/s/98895d47f57c 軟件截圖 軟件特點 簡單易用:用戶界面簡潔明了,操作方便,…

機器學習與深度學習06-決策樹02

目錄 前文回顧5.決策樹中的熵和信息增益6.什么是基尼不純度7.決策樹與回歸問題8.隨機森林是什么 前文回顧 上一篇文章地址:鏈接 5.決策樹中的熵和信息增益 熵和信息增益是在決策樹中用于特征選擇的重要概念,它們幫助選擇最佳特征進行劃分。 熵&#…

【Kotlin】數字字符串數組集合

【Kotlin】簡介&變量&類&接口 【Kotlin】數字&字符串&數組&集合 文章目錄 Kotlin_數字&字符串&數組&集合數字字面常量顯式轉換數值類型轉換背后發生了什么 運算字符串字符串模板字符串判等修飾符數組集合通過序列提高效率惰性求值序列的操…

oscp練習PG Monster靶機復現

端口掃描 nmap -A -p- -T4 -Pn 192.168.134.180 PORT STATE SERVICE VERSION 80/tcp open http Apache httpd 2.4.41 ((Win64) OpenSSL/1.1.1c PHP/7.3.10) |_http-server-header: Apache/2.4.41 (Win64) OpenSSL/1.1.1c PHP/7.3.10 | http-methods:…

近期知識庫開發過程中遇到的一些問題

我們正在使用Rust開發一個知識庫系統,遇到了一些問題,在此記錄備忘。 錯誤:Unable to make method calls because underlying connection is closed 場景:在docker中調用headless_chrome時出錯 原因:為減小鏡像大小&am…

Ubuntu 22.04 系統下 Docker 安裝與配置全指南

Ubuntu 22.04 系統下 Docker 安裝與配置全指南 一、前言 Docker 作為現代開發中不可或缺的容器化工具,能極大提升應用部署和環境管理的效率。本文將詳細介紹在 Ubuntu 22.04 系統上安裝與配置 Docker 的完整流程,包括環境準備、安裝步驟、權限配置及鏡…

C#獲取磁盤容量:代碼實現與應用場景解析

C#獲取磁盤容量:代碼實現與應用場景解析 在軟件開發過程中,尤其是涉及文件存儲、數據備份等功能時,獲取磁盤容量信息是常見的需求。通過獲取磁盤的可用空間和總大小,程序可以更好地進行資源管理、預警提示等操作。在 C# 語言中&a…

2025年- H56-Lc164--200.島嶼數量(圖論,深搜)--Java版

1.題目描述 2.思路 (1)主函數,存儲圖結構 (2)主函數,visit數組表示已訪問過的元素 (3)輔助函數,用遞歸(深搜),遍歷以已訪問過的元素&…