前饋神經網絡多分類任務

pytorch深度學習的套路都差不多,多看多想多寫多測試,自然就會了。主要的技術還是在于背后的數學思想和數學邏輯。

廢話不多說,上代碼自己看。

import torch
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsclass Network(nn.Module):def __init__(self ,input_dim ,hidden_dim ,out_dim):super().__init__()self.layer1 = nn.Sequential(  # 全連接層     [1, 28, 28]nn.Linear(784, 400),       # 輸入維度,輸出維度nn.BatchNorm1d(400),  # 批標準化,加快收斂,可不需要nn.ReLU()  				 # 激活函數)self.layer2 = nn.Sequential(nn.Linear(400, 200),nn.BatchNorm1d(200),nn.ReLU())self.layer3 = nn.Sequential(   # 全連接層nn.Linear(200, 100),nn.BatchNorm1d(100),nn.ReLU())self.layer4 = nn.Sequential(   # 最后一層為實際輸出,不需要激活函數,因為有 10 個數字,所以輸出維度為 10,表示10 類nn.Linear(100, 10),)def forward(self ,x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)output = self.layer4(x)return output
def get_num_correct(preds, labels):return (preds.argmax(dim=1) == labels).sum().item()def dropout(x, keep_prob = 0.5):'''np.random.binomial 當輸入二維數組時,按行按列(每個維度)都是按照給定概率生成1的個數,
比如 輸入 10 * 6的矩陣,按照0.5的概率生成1 那么每列都大概會有5個1,每行大概會有3個1,
其實就不用考慮按行drop或者按列drop,相當于每行生成的mask都是不一樣的,那么矩陣中每行的元素(代表一層中的神經元)都是按照不同的mask失活的
當矩陣形狀改變行列代表的意義不一樣時,由于每行每列(各個維度)的1的個數都是按照prob留存的,因此對結果沒有影響。'''mask = torch.from_numpy(np.random.binomial(1,keep_prob,x.shape))return x * mask / keep_probif __name__ == "__main__":train_set = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transforms.Compose([transforms.ToTensor()]))test_set = torchvision.datasets.MNIST(root='./data',train=False,download=False,transform=transforms.Compose([transforms.ToTensor()]))train_loader = torch.utils.data.DataLoader(train_set, batch_size=512, shuffle=True)test_loader = torch.utils.data.DataLoader(test_set, batch_size=512, shuffle=True)net = Network(28 * 28, 256, 10)optimizer = torch.optim.SGD(net.parameters(), lr=0.01)criterion = nn.CrossEntropyLoss()epoch = 10for i in range(epoch):train_accur = 0.0train_loss = 0.0for batch in train_loader:images, labels = batch#images, labels = images.to(device), labels.to(device)images = images.squeeze(1).reshape(images.shape[0], -1)preds = net(images)optimizer.zero_grad()loss = criterion(preds, labels)loss.backward()optimizer.step()train_loss += loss.item()train_accur += get_num_correct(preds, labels)print("loss :" + str(train_loss) + "train accur:" + str(train_accur * 1.0 / 60000))global correctwith torch.no_grad():correct = 0for batch in test_loader:images, labels = batch#images, labels = images.to(device), labels.to(device)images = images.squeeze(1).reshape(-1, 784)preds = net(images)preds = preds.argmax(dim=1)correct += (preds == labels).sum()print(correct)print(correct.item() * 1.0 / len(test_set))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/41957.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/41957.shtml
英文地址,請注明出處:http://en.pswp.cn/news/41957.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【騰訊云Cloud Studio實戰訓練營】使用Cloud Studio社區版快速構建React完成點餐H5頁面還原

陳老老老板🦸 👨?💻本文專欄:生活(主要講一下自己生活相關的內容) 👨?💻本文簡述:生活就像海洋,只有意志堅強的人,才能到達彼岸。 👨?💻上一篇…

成集云 | 用友U8采購請購單同步釘釘 | 解決方案

源系統成集云目標系統 方案介紹 用友U8是中國用友集團開發和推出的一款企業級管理軟件產品。具有豐富的功能模塊,包括財務管理、采購管理、銷售管理、庫存管理、生產管理、人力資源管理、客戶關系管理等,可根據企業的需求選擇相應的模塊進行集…

什么是原子交換?

安全地在各個區塊鏈網絡之間傳輸資產對于釋放被困流動性并吸引更多用戶進入這一領域至關重要,同時也保持 Web3 的信任最小化核心價值。原子交換是一種讓兩個人在不依賴于中介來促成交易的情況下,在不同的區塊鏈網絡之間交換通證資產的方式。這為 DeFi 用…

Linux硬鏈接和軟連接

1、硬鏈接 硬連接指通過索引節點來進行連接。在 Linux 的文件系統中,保存在磁盤分區中的文件不管是什么類型都給它分配一個編號,稱為索引節點號(Inode Index)。在 Linux 中,多個文件名指向同一索引節點是存在的。比如:A 是 B 的硬…

數據結構之隊列詳解(包含例題)

一、隊列的概念 隊列是一種特殊的線性表,特殊之處在于它只允許在表的前端(front)進行刪除操作,而在表的后端(rear)進行插入操作,和棧一樣,隊列是一種操作受限制的線性表。進行插入操…

【Windows 常用工具系列 5 -- Selenium IDE的使用方法 】

文章目錄 Selenium 介紹Selenium IDE 介紹 Selenium IDE安裝Chrome 瀏覽器安裝Selenium IDE使用 Selenium 介紹 Selenium是一個用于Web應用程序測試的工具。Selenium測試直接運行在瀏覽器中,就像真正的用戶在操作一樣。 Selenium家庭成員有三個,分別是S…

Ubuntu 20.04 與 ROS noetic安裝 gtsam 編譯 LIO-SAM 的適配版本

Ubuntu 20.04 基于 ROS noetic安裝 gtsam, 編譯 LIO-SAM 的適配版本 摘要安裝GTSAM(ros-noetic-gtsam版本)編譯LIO-SAM的適配版本 摘要 本文簡介在 Ubuntu 20.04 下以 ROS noetic 為基礎安裝 GTSAM 并成功編譯 LIO-SAM 的適配版本。 安裝GTSAM(ros-noetic-gtsam版…

騰訊云國際站代充-阿里云ECS怎么一鍵遷移到騰訊云cvm?

今天主要來介紹一下如何通過阿里云國際ECS控制臺一鍵遷移至騰訊云國際CVM。騰訊云國際站云服務器CVM提供全面廣泛的服務內容。無-需-綁-定PayPal,代-充-值騰訊云國際站、阿里云國際站、AWS亞馬遜云、GCP谷歌云,官方授權經銷商!靠譜&#xff0…

視頻匯聚集中存儲EasyCVR平臺調用iframe地址視頻無法播放,該如何解決?

安防監控視頻匯聚平臺EasyCVR基于云邊端一體化架構,具有強大的數據接入、處理及分發能力,可提供視頻監控直播、云端錄像、視頻云存儲、視頻集中存儲、視頻存儲磁盤陣列、錄像檢索與回看、智能告警、平臺級聯、云臺控制、語音對講、AI算法中臺智能分析無縫…

【SpringBoot】中的ApplicationRunner接口 和 CommandLineRunner接口

1. ApplicationRunner接口 用法: 類型: 接口 方法: 只定義了一個run方法 使用場景: springBoot項目啟動時,若想在啟動之后直接執行某一段代碼,就可以用 ApplicationRunner這個接口,并實現接口…

vue3+elementUI-plus實現select下拉框的虛擬滾動

網上查了幾個方案,要不就是不兼容,要不就是不支持vue3, 最終找到一個合適的,并且已上線使用,需要修改一下樣式: 代碼如下: main.js里引用 import vue3-virtual-scroller/dist/vue3-virtual-scroller.css; …

xollam勒索病毒數據恢復|金蝶、用友、管家婆、OA、速達、ERP等軟件數據庫恢復

引言: 數字時代的繁榮與便捷,也孕育著各種網絡安全威脅。其中,.xollam勒索病毒以其毒害性和隱蔽性引發了廣泛關注。本文91數據恢復將為您深入解析.xollam勒索病毒的威脅,探討解密方法,同時分享預防.xollam勒索病毒的關…

Python入門教程23:math模塊的用法

**math是Python 的一個內置模塊,它提供了許多數學函數和常量,用于進行數學計算。**以下是一些常用的math模塊中的函數和常量: math.pi:圓周率π的近似值,約等于3.14159。 math.e:自然對數的底數e的近似值…

【Tomcat】(Tomcat 下載Tomcat 啟動Tomcat 簡單部署 基于Tomcat進行網站后端開發)

文章目錄 Tomcat下載Tomcat啟動Tomcat簡單部署 基于Tomcat進行網站后端開發 Tomcat Tomcat 是一個 HTTP 服務器.HTTP 協議就是 HTTP 客戶端和 HTTP 服務器之間的交互數據的格式. HTTP 服務器我們可以通過 Java Socket 來實現. 而 Tomcat 就是基于 Java 實現的一個開源免費,也是…

Python爬蟲:如何使用Python爬取網站數據

更新:2023-08-13 15:30 想要獲取網站的數據?使用Python爬蟲是一個絕佳的選擇。Python爬蟲是通過自動化程序來提取互聯網上的信息。本文章將會詳細介紹Python爬蟲的相關技術。 一、網絡協議和請求 在使用Python爬蟲之前,我們需要理解網絡協…

Synopsys EDA數字設計與仿真

搭建EDA環境 參考如下博文安裝Synopsys EDA開發工具 https://blog.csdn.net/tugouxp/article/details/132255002?csdn_share_tail%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22132255002%22%2C%22source%22%3A%22tugouxp%22%7D Synopsys ED…

【Git】本地搭建Gitee、Github環境

本地 (Local) 1、使用命令生成公鑰(pub文件) 1. $ ssh-keygen -t rsa -C "xxxxxxxemail.com" -f "github_id_rsa" 2. $ ssh-keygen -t rsa -C "xxxxxxxemail.com" -f "gitee_id_rsa" …

配置pyqt5開發環境

安裝庫 pip install pyqt5 -i https://mirrors.aliuyun.com/pypi/simple pip install pyqt5-tools -i https://mirrors.aliuyun.com/pypi/simple pip install PyQt5designer -i https://mirrors.aliuyun.com/pypi/simple配置External Tools Name:QtDesigner Program:C:\Anaco…

常見的 JavaScript 框架比較

以下是10種常見的JavaScript框架的比較: React:是由Facebook開發和維護的開源JavaScript庫,用于構建用戶界面。它允許你使用組件來構建復雜的UI,并專注于每個組件的內部邏輯,而不必擔心管理整個應用程序的狀態。WebBu…

使用路由器更改設備IP_跨網段連接PLC

在一些設備IP已經固定,但是需要采集此設備的數據,需要用到跨網段采集 1、將路由器WAN(外網撥號口)設置為靜態IP 2、設置DMZ主機,把DMZ主機地址設置成跨網段的PLC地址 DMZ主機 基本信息. DMZ (Demilitarized Zone)即俗稱的非軍事區&#xff0…