【項目實戰】——深度學習.全連接神經網絡

目錄

1.使用全連接網絡訓練和驗證MNIST數據集

2.使用全連接網絡訓練和驗證CIFAR10數據集


1.使用全連接網絡訓練和驗證MNIST數據集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os# 數據預處理
transform = transforms.Compose([transforms.ToTensor()])# 數據準備
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定義網絡結構
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(784, 256)self.bn1 = nn.BatchNorm1d(256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.bn2 = nn.BatchNorm1d(128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.bn1(self.fc1(x))x = self.relu(x)x = self.bn2(self.fc2(x))x = self.relu(x)x = self.fc3(x)return xmodel = MyNet()
# 定義損失函數
criterion = nn.CrossEntropyLoss()
# 定義優化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練
def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')# 驗證
def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')# 保存模型
def save_model():torch.save(model.state_dict(), 'mnist_fc_model.pt')# 預測
def predict(img_path):model = MyNet()model.load_state_dict(torch.load('mnist_fc_model.pt'))model.eval()img = Image.open(img_path).convert('L')transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor()])t_img = transform(img).unsqueeze(0)print(t_img.shape)with torch.no_grad():output = model(t_img)_, predicted = torch.max(output.data, 1)print(predicted.item())epochs = 5train(model, train_loader, epochs)
eval(model, eval_loader)save_model()img_path = './img/7.png'
predict(img_path)

2.使用全連接網絡訓練和驗證CIFAR10數據集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim# 數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 數據準備
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定義網絡結構
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(32 * 32 * 3, 1024)self.bn1 = nn.BatchNorm1d(1024)self.dropout1 = nn.Dropout(0.3)self.fc2 = nn.Linear(1024, 512)self.bn2 = nn.BatchNorm1d(512)self.dropout2 = nn.Dropout(0.3)self.fc3 = nn.Linear(512, 256)  # 增加第三層self.bn3 = nn.BatchNorm1d(256)self.fc4 = nn.Linear(256, 10)self.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 32 * 32 * 3)x = self.dropout1(self.bn1(self.fc1(x)))x = self.relu(x)x = self.dropout2(self.bn2(self.fc2(x)))x = self.relu(x)x = self.bn3(self.fc3(x))x = self.relu(x)x = self.fc4(x)return xmodel = MyNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:data, target = data.to(device), target.to(device)output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')epochs = 25train(model, train_loader, epochs)
eval(model, eval_loader)

思考:為什么CIFAR10數據集的準確率很低?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/916043.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/916043.shtml
英文地址,請注明出處:http://en.pswp.cn/news/916043.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

嵌入式學習的第三十四天-進程間通信-TCP

一、TCPTCP : 傳輸控制協議 傳輸層1. TCP特點(1).面向連接,避免部分數據丟失 (2).安全、可靠 (3).面向字節流 (4).占用資源開銷大2.TCP安全可靠機制三次握手:指建立tcp連接時,需要客戶端和服務端總共發送三次報文確認連接。確保雙方均已做好 收發…

【爬蟲】06 - 自動化爬蟲selenium

自動化爬蟲selenium 文章目錄自動化爬蟲selenium一:Selenium簡介1:什么是selenium2:安裝準備二:元素定位1:id 定位2:name 定位3:class 定位4:tag 定位5:xpath 定位(最常用…

2025年中國移動鴻鵠大數據實訓營(大數據方向)kafka講解及實踐-第2次作業指導

書接上回,第二次作業比較容易解決,我問了ai,讓他對我進行指導,按照它提供的步驟,我完成了本次實驗,接下來我會標注出需要注意的細節,指導大家完成此次任務。 🎯 一、作業目標 ??…

三十七、【高級特性篇】定時任務:基于 APScheduler 實現測試計劃的靈活調度

三十七、【高級特性篇】定時任務:基于 APScheduler 實現測試計劃的靈活調度 前言 準備工作 第一部分:后端實現 - `APScheduler` 集成與任務調度 1. 安裝 `django-apscheduler` 2. 配置 `django-apscheduler` 3. 數據庫遷移 4. 創建調度觸發函數 5. 啟動 APScheduler 調度器 6…

RabbitMQ--消息順序性

看本章之前強烈建議先去看博主的這篇博客 RabbitMQ--消費端單線程與多線程-CSDN博客 一、消息順序性概念 消息順序性是指消息在生產者發送的順序和消費者接收處理的順序保持一致。 二、RabbitMQ 順序性保證機制 情況順序保證情況備注單隊列,單消費者消息嚴格按發送順…

.net core接收對方傳遞的body體里的json并反序列化

1、首先我在通用程序里有一個可以接收對象型和數組型json串的反序列化方法public static async Task<Dictionary<string, string>> AllParameters(this HttpRequest request){Dictionary<string, string> parameters QueryParameters(request);request.Enab…

(10)機器學習小白入門 YOLOv:YOLOv8-cls 模型評估實操

YOLOv8-cls 模型評估實操 (1)機器學習小白入門YOLOv &#xff1a;從概念到實踐 (2)機器學習小白入門 YOLOv&#xff1a;從模塊優化到工程部署 (3)機器學習小白入門 YOLOv&#xff1a; 解鎖圖片分類新技能 (4)機器學習小白入門YOLOv &#xff1a;圖片標注實操手冊 (5)機器學習小…

Vue 腳手架基礎特性

一、ref屬性1.被用來給元素或子組件注冊引用信息&#xff08;id的替代者&#xff09;2.應用在html標簽上獲取的是真實DOM元素&#xff0c;用在組件標簽上是組件實例對象3.使用方式&#xff1a;(1).打標識&#xff1a;<h1 ref"xxx">...</h1> 或 <Schoo…

Ubuntu安裝k8s集群入門實踐-v1.31

準備3臺虛擬機 在自己電腦上使用virtualbox 開了3臺1核2G的Ubuntu虛擬機&#xff0c;你可以先安裝好一臺&#xff0c;安裝第一臺的時候配置臨時調高到2核4G&#xff0c;安裝速度會快很多&#xff0c;安裝完通過如下命令關閉桌面&#xff0c;能夠省內存占用&#xff0c;后面我們…

Word Press富文本控件的保存

新建富文本編輯器&#xff0c;并編寫save方法如下&#xff1a; edit方法&#xff1a; export default function Edit({ attributes, setAttributes }) {return (<><div { ...useBlockProps() }><RichTexttagNameponChange{ (value) > setAttributes({ noteCo…

【編程趣味游戲】:基于分支循環語句的猜數字、關機程序

&#x1f31f;菜鳥主頁&#xff1a;晨非辰的主頁 &#x1f440;學習專欄&#xff1a;《C語言學習》 &#x1f4aa;學習階段&#xff1a;C語言方向初學者 ?名言欣賞&#xff1a;"編程的核心是實踐&#xff0c;而非空談" 目錄 1. 游戲1--猜數字 1.1 rand函數 1.2 sr…

UE5 UI 控件切換器

文章目錄分類作用屬性分類 面板 作用 可以根據索引切換要顯示哪個子UI&#xff0c;可以擁有多個子物體&#xff0c;但是任何時間只能顯示一個 屬性 在這里指定要顯示的UI的索引

scikit-learn 包

文章目錄scikit-learn 包核心功能模塊案例其他用法**常用功能詳解****(1) 分類任務示例&#xff08;SVM&#xff09;****(2) 回歸任務示例&#xff08;線性回歸&#xff09;****(3) 聚類任務示例&#xff08;K-Means&#xff09;****(4) 特征工程&#xff08;PCA降維&#xff0…

Excel 將數據導入到SQLServer數據庫

一般系統上線前期都會導入期初數據&#xff0c;業務人員一般要求你提供一個Excel模板&#xff0c;業務人員根據要求整理數據。SQLServer管理工具是支持批量導入數據的&#xff0c;所以我們可以使用該工具導入期初。Excel格式 第一行為字段1、連接登入的數據庫并且選中你需要導入…

剪枝和N皇后在后端項目中的應用

剪枝算法&#xff08;Pruning Algorithm&#xff09; 生活比喻&#xff1a;就像修剪樹枝一樣&#xff0c;把那些明顯不會結果的枝條提前剪掉&#xff0c;節省養分。 在后端項目中的應用場景&#xff1a; 搜索優化&#xff1a;在商品搜索中&#xff0c;如果某個分類下沒有符合條…

cocos 2d游戲中多邊形碰撞器會觸發多次,怎么解決

子彈打到敵機 一發子彈擊中&#xff0c;碰撞回調多次執行 我碰撞組件原本是多邊形碰撞組件 PolygonCollider2D&#xff0c;我改成盒碰撞組件BoxCollider2D 就好了 用前端的節流方式。或者loading處理邏輯。我測試過了&#xff0c;是可以 本來就是多次啊,設計上貌似就是這樣的…

Kubernetes環境中GPU分配異常問題深度分析與解決方案

Kubernetes環境中GPU分配異常問題深度分析與解決方案 一、問題背景與核心矛盾 在基于Kubernetes的DeepStream應用部署中&#xff0c;GPU資源的獨占性分配是保障應用性能的關鍵。本文將圍繞一個典型的GPU分配異常問題展開分析&#xff1a;多個請求GPU的容器本應獨占各自的GPU&…

Django與模板

我叫補三補四&#xff0c;很高興見到大家&#xff0c;歡迎一起學習交流和進步今天來講一講視圖Django與模板文件工作流程模板引擎&#xff1a;主要參與模板渲染的系統。內容源&#xff1a;輸入的數據流。比較常見的有數據庫、XML文件和用戶請求這樣的網絡數據。模板&#xff1a…

日本上市IT企業|8月25日將在大連舉辦赴日it招聘會

株式會社GSD的核心戰略伙伴貝斯株式會社&#xff0c;將于2025年8月25日在大連香格里拉大酒店商務會議室隆重舉辦赴日技術人才專場招聘會。本次招聘會面向全國范圍內的優秀IT人才&#xff0c;旨在為貝斯株式會社東京本社長期發展招募優質的系統開發與管理人才。招聘計劃&#xf…

低功耗設計雙目協同畫面實現光學變焦內帶AI模型

低功耗設計延長續航&#xff0c;集成儲能模塊保障陰雨天氣下的鐵塔路線的安全一、智能感知與識別技術 多光譜融合監控結合可見光、紅外熱成像、激光補光等技術&#xff0c;實現全天候監測。例如&#xff0c;紅外熱成像可穿透雨霧監測山火隱患&#xff0c;激光補光技術則解決夜間…