和鯨社區深度學習基礎訓練營2025年關卡4

使用 pytorch 構建一個簡單的卷積神經網絡(CNN)模型,完成對 CIFAR-10 數據集的圖像分類任務。 直接使用 CNN 進行分類的模型性能。 提示: 數據集:CIFAR-10 網絡結構:可以使用 2-3 層卷積層,ReLU 激活,MaxPooling 層,最后連接全連接層。

#1. 數據預處理與加載
import torch
import torchvision
import torchvision.transforms as transforms# 數據增強與歸一化(使用CIFAR-10官方均值和標準差)
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),       # 隨機裁剪增強泛化性transforms.RandomHorizontalFlip(),          # 隨機水平翻轉transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])# 加載數據集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)# 數據加載器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)#2. CNN模型架構
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # 輸入通道3(RGB),輸出32通道self.bn1 = nn.BatchNorm2d(32)                 # 批量歸一化self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.pool = nn.MaxPool2d(2, 2)                # 池化層(尺寸減半)self.fc1 = nn.Linear(128 * 4 * 4, 256)       # 全連接層(輸入尺寸計算:32x32 → 16x16 → 8x8 → 4x4)self.fc2 = nn.Linear(256, 10)                 # 輸出10類def forward(self, x):x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32x32 → 16x16x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16x16 → 8x8x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8x8 → 4x4x = x.view(-1, 128 * 4 * 4)                    # 展平x = F.relu(self.fc1(x))x = self.fc2(x)return x# 實例化模型并移至GPU(若可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SimpleCNN().to(device)#3. 訓練與優化
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 每5輪學習率×0.1# 訓練循環(10個epoch)
for epoch in range(10):net.train()running_loss = 0.0for i, (inputs, labels) in enumerate(trainloader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:  # 每100批次打印一次print(f'Epoch [{epoch+1}/10], Step [{i+1}/{len(trainloader)}], Loss: {running_loss/100:.3f}')running_loss = 0.0scheduler.step()  # 更新學習率print(f"Epoch {epoch+1} completed, learning rate: {scheduler.get_last_lr()[0]:.6f}")#4. 模型評估與可視化
net.eval()
correct, total = 0, 0
with torch.no_grad():for (images, labels) in testloader:images, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

運行結果:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914192.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914192.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914192.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

前端性能優化全攻略:從加載到渲染

目錄 前言網絡請求優化資源加載優化JavaScript執行優化渲染優化用戶體驗優化性能監控與分析總結 前言 隨著Web應用復雜度不斷提升,前端性能優化變得尤為重要。本文將系統性地介紹從資源加載到頁面渲染的全鏈路性能優化策略,幫助開發者構建高效、流暢的…

hiredis: 一個輕量級、高性能的 C 語言 Redis 客戶端庫

目錄 1.簡介 2.安裝和配置 2.1.源碼編譯安裝(通用方法) 2.2.包管理器安裝(特定系統) 2.3.Windows 安裝 3.常用的函數及功能 3.1.連接管理函數 3.2.命令執行函數 3.3.異步操作函數 3.4.回復處理函數 3.5.錯誤處理 3.6.…

TCP套接字

1.概念套接字是專門進行網絡間數據通信的一種文件類型,可以實現不同主機之間雙向通信,包含了需要交換的數據和通信雙方的IP地址和port端口號。2.套接字文件的創建int socket(int domain, int type, int protocol); 功能:該函數用來創建各種各…

Go語言高并發聊天室(一):架構設計與核心概念

Go語言高并發聊天室(一):架構設計與核心概念 🚀 引言 在當今互聯網時代,實時通信已成為各類應用的核心功能。從微信、QQ到各種在線協作工具,高并發聊天系統的需求無處不在。本系列文章將手把手教你使用Go語…

Java基礎:泛型

什么是泛型? 簡單來說,Java泛型是JDK 5引入的一種特性,它允許你在定義類、接口和方法時使用類型參數(Type Parameters)。這些類型參數可以在編譯時被具體的類型(如 String, Integer, MyCustomClass 等&…

RMSNorm實現

當前Qwen、Llama等系列RMSNorm實現源碼均一致。具體現實如下: class RMSNorm(nn.Module):def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Parameter(torch.ones(hidden_size))self.variance_epsilon epsdef forward(self, hidden_s…

智能Agent場景實戰指南 Day 11:財務分析Agent系統開發

【智能Agent場景實戰指南 Day 11】財務分析Agent系統開發 文章標簽 AI Agent,財務分析,LLM應用,智能財務,Python開發 文章簡述 本文是"智能Agent場景實戰指南"系列第11篇,聚焦財務分析Agent系統的開發。文章深入解析如何構建一個能夠自動處理財務報表…

人工智能安全基礎復習用:可解釋性

一、可解釋性的核心作用1. 錯誤檢測與模型改進發現模型的異常行為(如過擬合、偏見),優化性能。例:醫療模型中,可解釋性幫助識別誤診原因。2. 安全與可信性關鍵領域(醫療、軍事)需透明決策&#…

Qt:QCustomPlot類介紹

QCustomPlot的核心類就是QCustomPlot類。這個類繼承自QWidget,因此可以像其他QWidget一樣使用,比如放入布局中。QCustomPlot類基本結構一個QCustomPlot對象可以包含多個圖層(通過QCPLayer表示),通常使用默認圖層。它包…

Visual Studio 2022 上使用ffmpeg

目錄 1. 添加包含目錄 2. 添加庫目錄 3. 添加依賴項 4. 添加動態庫目錄 5. 測試 在解決方案中右擊項目名稱,彈出的窗口中選擇 "屬性"。 1. 添加包含目錄 "C/C" -> "常規" -> "附加包含目錄"中添加 ffmpeg中的…

Elasticsearch 線程池

Elasticsearch 線程池「每個線程池到底采用哪種實現策略」:Elasticsearch 線程池(ThreadPool)中 **所有內置線程池名稱的常量定義**。 每個字符串常量對應一個 **線程池的名字(name)**,也就是你在 Thread…

深入理解 Next.js API 路由:構建全棧應用的終極指南

Next.js 是一個強大的 React 框架,不僅支持服務端渲染(SSR)和靜態站點生成(SSG),還提供了內置的 API 路由功能,使開發者能夠輕松構建全棧應用。傳統的全棧開發通常需要單獨搭建后端服務&#xf…

【6.1.2 漫畫分布式事務技術選型】

漫畫分布式事務技術選型 🎯 學習目標:掌握架構師核心技能——分布式事務技術選型與一致性解決方案,構建高可靠的分布式系統 🎭 第一章:分布式事務模式對比 🤔 2PC vs 3PC vs TCC vs Saga 想象分布式事務就…

液冷智算數據中心崛起,AI算力聯動PC Farm與云智算開拓新藍海(二)

從算法革新到基礎設施升級,從行業滲透到地域布局,人工智能算力正以 “規模擴張 效率提升”雙輪驅動中國數字經濟轉型。中國智能算力規模將在 2025 年突破 1000 EFLOPS,2028 年達到 2781.9 EFLOPS,五年復合增長率 46.2%&#xff0…

《QtPy:Python與Qt的完美橋梁》

QtPy 是什么 在 Python 的廣袤編程宇宙中,當涉及到圖形用戶界面(GUI)開發,Qt 框架宛如一顆璀璨的明星,散發著獨特的魅力。而 QtPy,作為 Python 與 Qt 生態系統交互中的關鍵角色,更是為開發者們開…

ubuntu環境下調試 RT-Thread

調試 RT-Thread 下載源碼 github 搜索 RT-Thread 下載源碼 安裝 python scons 環境 你已經安裝了 kconfiglib,但 scons --menuconfig 仍然提示找不到它。這種情況通常是由于 Python 環境不一致 導致的:你在一個 Python 環境中安裝了 kconfiglib&#xff…

【數據結構初階】--順序表(二)

🔥個人主頁:草莓熊Lotso 🎬作者簡介:C研發方向學習者 📖個人專欄: 《C語言》 《數據結構與算法》《C語言刷題集》《Leetcode刷題指南》 ??人生格言:生活是默默的堅持,毅力是永久的…

Java中的方法傳參機制

1. 概述Java中的方法傳參機制分為兩種:值傳遞(Pass by Value) 和 引用傳遞(Pass by Reference)。然而,Java中所有的參數傳遞都是值傳遞,只不過對于對象來說,傳遞的是對象的引用地址的…

C++——this關鍵字和new關鍵字

一、this 關鍵字1. 什么是 this?this 是 C 中的一個隱式指針,它指向當前對象(即調用成員函數的對象),在成員函數內部使用,用于引用調用該函數的對象。每個類的非靜態成員函數內部都可以使用 this。使用 thi…

Python中類靜態方法:@classmethod/@staticmethod詳解和實戰示例

在 Python 中,類方法 (classmethod) 和靜態方法 (staticmethod) 是類作用域下的兩種特殊方法。它們使用裝飾器定義,并且與實例方法 (def func(self)) 的行為有所不同。1. 三種方法的對比概覽方法類型是否訪問實例 (self)是否訪問類 (cls)典型用途實例方法…