實現多層感知機

目錄

多層感知機:

介紹:

代碼實現:

運行結果:

問題答疑:

線性變換與非線性變換

參數含義

為什么清除梯度?

反向傳播的作用

為什么更新權重?


多層感知機:

介紹:

縮寫:MLP,這是一種人工神經網絡,由一個輸入層、一個或多個隱藏層以及一個輸出層組成,每一層都由多個節點(神經元)構成。在MLP中,節點之間只有前向連接,沒有循環連接,這使得它屬于前饋神經網絡的一種。每個節點都應用一個激活函數,如sigmoid、ReLU等,以引入非線性,從而使網絡能夠擬合復雜的函數和數據分布。

代碼實現:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# Step 1: Define the MLP model
class SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layerself.fc2 = nn.Linear(128, 64)   # Hidden layer to another hidden layerself.fc3 = nn.Linear(64, 10)    # Hidden layer to output layerself.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 784)             # Flatten the input from 28x28 to 784x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# Step 2: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# Step 3: Define loss function and optimizer
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 4: Train the model
num_epochs = 5
for epoch in range(num_epochs):for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# Step 5: Evaluate the model on the test set (optional)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

運行結果:

問題答疑:

線性變換與非線性變換

在神經網絡中

線性變換通常指的是權重矩陣和輸入數據的矩陣乘法,再加上偏置向量。數學上,對于一個輸入向量𝑥x和權重矩陣𝑊W,加上偏置向量𝑏b,線性變換可以表示為: 𝑧=𝑊𝑥+𝑏z=Wx+b

非線性變換是指在神經網絡的每一層之后應用的激活函數,如ReLU、sigmoid或tanh等。這些函數引入了非線性,使神經網絡能夠學習和表達復雜的函數關系。沒有非線性變換,無論多少層的神經網絡最終都將簡化為一個線性模型。

參數含義

在上述模型中,參數如784, 128, 64, 10并不是字節,而是神經網絡層的尺寸,具體來說是神經元的數量:

  • 784: 這是輸入層的神經元數量,對應于MNIST數據集中每個圖片的像素數量。MNIST的圖片是28x28像素,因此總共有784個像素點。
  • 128?和?64: 這是兩個隱藏層的神經元數量。它們代表了第一層和第二層的寬度,即這一層有多少個神經元。
  • 10: 這是輸出層的神經元數量,對應于MNIST數據集中的10個數字類別(0到9)。

為什么清除梯度?

在每一次前向傳播和反向傳播過程中,梯度會被累積在張量的.grad屬性中。如果不手動清零,這些梯度將會被累加,導致不正確的梯度值。因此,在每次迭代開始之前,都需要調用optimizer.zero_grad()來清空梯度。

反向傳播的作用

反向傳播(Backpropagation)是一種算法,用于計算損失函數相對于神經網絡中所有權重的梯度。它的目的是為了讓神經網絡知道,當損失函數值較高時,哪些權重需要調整,以及調整的方向和幅度。這些梯度隨后被用于權重更新,以最小化損失函數。

為什么更新權重?

權重更新是基于梯度下降算法進行的。在反向傳播計算出梯度后,權重通過optimizer.step()函數更新,以朝著減小損失函數的方向移動。

這是訓練神經網絡的核心,即通過不斷調整權重和偏置,使模型能夠更好地擬合訓練數據,從而提高預測準確性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/44853.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/44853.shtml
英文地址,請注明出處:http://en.pswp.cn/web/44853.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

taocms 3.0.1 本地文件泄露漏洞(CVE-2021-44983)

前言 CVE-2021-44983 是一個影響 taoCMS 3.0.1 的遠程代碼執行(RCE)漏洞。該漏洞允許攻擊者通過上傳惡意文件并在服務器上執行任意代碼來利用這一安全缺陷。 漏洞描述 taoCMS 是一個內容管理系統(CMS),用于創建和管…

持續集成的自動化之旅:Gradle在CI中的配置秘籍

持續集成的自動化之旅:Gradle在CI中的配置秘籍 引言 持續集成(Continuous Integration, CI)是現代軟件開發中的一項基礎實踐,它通過自動化的構建和測試流程來提高軟件質量和開發效率。Gradle作為一個靈活的構建工具,…

【眼疾病識別】圖像識別+深度學習技術+人工智能+卷積神經網絡算法+計算機課設+Python+TensorFlow

一、項目介紹 眼疾識別系統,使用Python作為主要編程語言進行開發,基于深度學習等技術使用TensorFlow搭建ResNet50卷積神經網絡算法,通過對眼疾圖片4種數據集進行訓練(‘白內障’, ‘糖尿病性視網膜病變’, ‘青光眼’, ‘正常’&…

jenkins系列-05-jenkins構建golang程序

下載go1.20.2.linux-arm64.tar.gz 并存放到jenkins home目錄: 寫一個golang demo程序:靜態文件服務器:https://gitee.com/jelex/jenkins_golang package mainimport ("encoding/base64""flag""fmt""lo…

window下安裝go環境

一、go官網下載安裝包 官網地址如下:https://golang.google.cn/dl/ 選擇對應系統的安裝包,這里是window系統,可以選擇zip包,下載完解壓就可以使用 二、配置環境變量 這里的截圖配置以win11為例 我的文件解壓目錄是 D:\Software…

力扣32.最長有效括號

力扣32.最長有效括號 class Solution {public:int longestValidParentheses(string s) {int n s.size();int res0;int start -1;vector<int> st;for(int i0;i<n;i){if(s[i] ()st.push_back(i);else{//前面沒有( , (開啟下一段)下一段的開始更新為當前下標if(st.emp…

機器學習和人工智能在農業的應用——案例分析

作者主頁: 知孤云出岫 目錄 引言機器學習和人工智能在農業的應用1. 精準農業作物健康監測土壤分析 2. 作物產量預測3. 農業機器人自動化播種和收割智能灌溉 4. 農業市場分析價格預測需求預測 機器學習和人工智能帶來的變革1. 提高生產效率2. 降低生產成本3. 提升作物產量和質量…

Elsaticsearch java基本操作

索引 基本操作 package com.orchids.elasticsearch.web.controller;import cn.hutool.core.collection.CollUtil; import cn.hutool.json.JSONUtil; import com.orchids.elasticsearch.web.po.User; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOpe…

探索JT808協議在車輛遠程視頻監控系統中的應用

一、部標JT808協議概述 隨著物聯網技術的迅猛發展&#xff0c;智能交通系統&#xff08;ITS&#xff09;已成為現代交通領域的重要組成部分。其中&#xff0c;車輛遠程監控與管理技術作為ITS的核心技術之一&#xff0c;對于提升交通管理效率、保障道路安全具有重要意義。 JT8…

TensorBoard ,PIL 和 OpenCV 在深度學習中的應用

重要工具介紹 TensorBoard&#xff1a; 是一個TensorFlow提供的強大工具&#xff0c;用于可視化和理解深度學習模型的訓練過程和結果。下面我將介紹TensorBoard的相關知識和使用方法。 TensorBoard 簡介 TensorBoard是TensorFlow提供的一個可視化工具&#xff0c;用于&#x…

尚品匯-(十七)

目錄&#xff1a; &#xff08;1&#xff09;獲取價格信息 &#xff08;2&#xff09;獲取銷售信息 前面的表&#xff1a; &#xff08;1&#xff09;獲取價格信息 繼續編寫接口&#xff1a;ManagerService /*** 獲取sku價格* param skuId* return*/ BigDecimal getSkuPrice…

『 Linux 』匿名管道應用 - 簡易進程池

文章目錄 池化技術進程池框架及基本思路進程的描述組織管道通信建立的潛在問題 任務的描述與組織子進程讀取管道信息控制子進程進程退出及資源回收 池化技術 池化技術是一種編程技巧,一般用于優化資源的分配與復用; 當一種資源需要被使用時這意味著這個資源可能會被進行多次使…

mqtt.fx連接阿里云

本文主要是記述一下如何使用mqtt.fx連接在阿里云上創建好的MQTT服務。 1 根據MQTT填寫對應端口即可 找到設備信息&#xff0c;里面有MQTT連接參數 2 使用物模型通信Topic&#xff0c;注意這里的post說設備上報&#xff0c;那也就是意味著云端訂閱post&#xff1b;set則意味著設…

【輕松拿捏】Java-final關鍵字(面試)

目錄 1. 定義和基本用法 回答要點&#xff1a; 示例回答&#xff1a; 2. final 變量 回答要點&#xff1a; 示例回答&#xff1a; 3. final 方法 回答要點&#xff1a; 示例回答&#xff1a; 4. final 類 回答要點&#xff1a; 示例回答&#xff1a; 5. final 關鍵…

搭建hadoop+spark完全分布式集群環境

目錄 一、集群規劃 二、更改主機名 三、建立主機名和ip的映射 四、關閉防火墻(master,slave1,slave2) 五、配置ssh免密碼登錄 六、安裝JDK 七、hadoop之hdfs安裝與配置 1)解壓Hadoop 2)修改hadoop-env.sh 3)修改 core-site.xml 4)修改hdfs-site.xml 5) 修改s…

【進階篇-Day9:JAVA中單列集合Collection、List、ArrayList、LinkedList的介紹】

目錄 1、集合的介紹1.1 概念1.2 集合的分類 2、單列集合&#xff1a;Collection2.1 Collection的使用2.2 集合的通用遍歷方式2.2.1 迭代器遍歷&#xff1a;&#xff08;1&#xff09;例子&#xff1a;&#xff08;2&#xff09;迭代器遍歷的原理&#xff1a;&#xff08;3&…

排序——交換排序

在上篇文章我們詳細介紹了排序的概念與插入排序&#xff0c;大家可以通過下面這個鏈接去看&#xff1a; 排序的概念及插入排序 這篇文章就介紹一下一種排序方式&#xff1a;交換排序。 一&#xff0c;交換排序 基本思想&#xff1a;兩兩比較&#xff0c;如果發生逆序則交換…

jenkins系列-09.jpom構建java docker harbor

本地先啟動jpom server agent: /Users/jelex/Documents/work/jpom-2.10.40/server-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % sh Server.sh start/Users/jelex/Documents/work/jpom-2.10.40/agent-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % ./Agent.…

達夢數據庫的系統視圖v$sessions

達夢數據庫的系統視圖v$sessions 達夢數據庫&#xff08;DM Database&#xff09;是中國的一款國產數據庫管理系統&#xff0c;它提供了類似于Oracle的系統視圖來監控和管理數據庫。V$SESSIONS 是達夢數據庫中的一個系統視圖&#xff0c;用于顯示當前數據庫會話的信息。 以下…

全自主巡航無人機項目思路:STM32/PX4 + ROS + AI 實現從傳感融合到智能規劃的端到端解決方案

1. 項目概述 本項目旨在設計并實現一款高度自主的自動巡航無人機系統。該系統能夠按照預設路徑自主飛行&#xff0c;完成各種巡航任務&#xff0c;如電力巡線、森林防火、邊境巡邏和災害監測等。 1.1 系統特點 基于STM32F4和PX4的高性能嵌入式飛控系統多傳感器融合技術實現精…