神經網絡入門—自定義網絡

網絡模型

定義一個兩層網絡

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F# 定義神經網絡模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 定義損失函數和優化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 生成一些示例數據
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)# 訓練模型
num_epochs = 1000
for epoch in range(num_epochs):# 清零梯度optimizer.zero_grad()# 前向計算outputs = model(x_train)loss = criterion(outputs, y_train)# 反向傳播loss.backward()# 更新參數optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model.pth')# 加載模型
loaded_model = Net()
loaded_model.load_state_dict(torch.load('model.pth'))
loaded_model.eval()  # 將模型設置為評估模式# 輸入新數據進行預測
new_input = torch.tensor([[5.0]], dtype=torch.float32)
with torch.no_grad():prediction = loaded_model(new_input)print(f"輸入 {new_input.item()} 的預測結果: {prediction.item()}")

運行結果

訓練好的參數值:
參數名: fc1.weight, 參數值: tensor([[ 0.5051],
? ? ? ? [ 0.2675],
? ? ? ? [ 0.4080],
? ? ? ? [ 0.3069],
? ? ? ? [ 0.9132],
? ? ? ? [ 0.2250],
? ? ? ? [-0.2428],
? ? ? ? [ 0.4821],
? ? ? ? [ 0.0998],
? ? ? ? [ 0.6737]])
參數名: fc1.bias, 參數值: tensor([ 0.5201, -0.0252, ?0.0504, ?0.6593, -0.4250, ?0.6001, ?0.9645, -0.2310,
? ? ? ? -0.2038, ?0.2116])
參數名: fc2.weight, 參數值: tensor([[ 0.5492, ?0.2550, ?0.3046, ?0.3183, ?0.8147, ?0.3062, -0.4165, ?0.2969,
? ? ? ? ? 0.0482, ?0.5535]])
參數名: fc2.bias, 參數值: tensor([0.0147])

  • fc1?層

    • fc1.weight:這是輸入層到隱藏層的權重矩陣,其形狀為?(10, 1),意味著輸入層有 1 個神經元,隱藏層有 10 個神經元。矩陣中的每個元素代表從輸入神經元到對應隱藏層神經元的連接權重。
    • fc1.bias:這是隱藏層每個神經元的偏置項,形狀為?(10,),也就是每個隱藏層神經元都有一個對應的偏置值。
  • fc2?層

    • fc2.weight:這是隱藏層到輸出層的權重矩陣,形狀為?(1, 10),表明隱藏層有 10 個神經元,輸出層有 1 個神經元。矩陣中的每個元素代表從隱藏層神經元到輸出層神經元的連接權重。
    • fc2.bias:這是輸出層神經元的偏置項,形狀為?(1,),即輸出層只有一個神經元,所以只有一個偏置值。

不同的優化器

神經網絡入門—計算函數值-CSDN博客

激活函數解析

激活函數的作用

激活函數賦予神經網絡非線性映射能力,使其能夠更好地處理復雜的現實世界數據2。常見的激活函數包括ReLU、PReLU等。激活函數通常用于卷積層和全連接層,以增加模型的表達能力。

常見的激活函數

Sigmoid 函數

  • 公式σ(x)= ?\frac{1}{1+e^{-x}}
  • 特點:輸出范圍在?(0, 1)?之間,能夠把輸入映射為概率值,常用于二分類問題。不過它存在梯度消失問題,當輸入值非常大或者非常小時,梯度會趨近于 0。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
sigmoid_output = torch.sigmoid(x)
print("Sigmoid 輸出:", sigmoid_output)

Tanh 函數

  • 公式:\(\tanh(x)=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}\)
  • 特點:輸出范圍在?(-1, 1)?之間,零中心化,相較于 Sigmoid 函數,梯度消失問題有所緩解,但仍然存在。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
tanh_output = torch.tanh(x)
print("Tanh 輸出:", tanh_output)

ReLU 函數

  • 公式:\(ReLU(x)=\max(0, x)\)
  • 特點:計算簡單,能夠有效緩解梯度消失問題,在深度學習中被廣泛使用。不過它存在死亡 ReLU 問題,即某些神經元可能永遠不會被激活。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
relu_output = F.relu(x)
print("ReLU 輸出:", relu_output)

?Leaky ReLU 函數

  • 公式:\(LeakyReLU(x)=\begin{cases}x, & x\geq0 \\ \alpha x, & x < 0\end{cases}\),其中?\(\alpha\)?是一個小的常數,例如 0.01。
  • 特點:解決了死亡 ReLU 問題,當輸入為負數時,也會有一個小的梯度。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
leaky_relu_output = F.leaky_relu(x, negative_slope=0.01)
print("Leaky ReLU 輸出:", leaky_relu_output)

損失函數解析

# 定義損失函數和優化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

該程序使用了MSELoss損失函數和SGD優化器

全部損失函數總類有

__all__ = ["L1Loss","NLLLoss","NLLLoss2d","PoissonNLLLoss","GaussianNLLLoss","KLDivLoss","MSELoss","BCELoss","BCEWithLogitsLoss","HingeEmbeddingLoss","MultiLabelMarginLoss","SmoothL1Loss","HuberLoss","SoftMarginLoss","CrossEntropyLoss","MultiLabelSoftMarginLoss","CosineEmbeddingLoss","MarginRankingLoss","MultiMarginLoss","TripletMarginLoss","TripletMarginWithDistanceLoss","CTCLoss",
]

  1. L1Loss:計算輸入和目標之間的平均絕對誤差(MAE),即?loss = 1/n * sum(|input - target|)
  2. NLLLoss:負對數似然損失,常用于分類任務,通常在模型輸出經過?log_softmax?變換后使用。
  3. NLLLoss2d:二維的負對數似然損失,適用于圖像等二維數據的分類任務。
  4. PoissonNLLLoss:泊松負對數似然損失,適用于泊松分布的數據,常用于計數數據的回歸。
  5. GaussianNLLLoss:高斯負對數似然損失,假設數據服從高斯分布,用于回歸任務。
  6. KLDivLoss:Kullback-Leibler 散度損失,用于衡量兩個概率分布之間的差異。
  7. MSELoss:均方誤差損失,計算輸入和目標之間的平均平方誤差,即?loss = 1/n * sum((input - target) ** 2),常用于回歸任務。
  8. BCELoss:二元交叉熵損失,用于二分類任務,輸入和目標都應該是概率值(在 0 到 1 之間)。
  9. BCEWithLogitsLoss:將?Sigmoid?函數和?BCELoss?結合在一起,適用于輸入是未經過激活函數的原始輸出(logits)的情況。
  10. HingeEmbeddingLoss:用于度量兩個輸入樣本之間的相似性,常用于度量學習任務。
  11. MultiLabelMarginLoss:多標簽分類的邊緣損失,適用于一個樣本可能屬于多個類別的情況。
  12. SmoothL1Loss:平滑的 L1 損失,在 L1 損失的基礎上進行了平滑處理,在某些情況下比 L1 和 L2 損失表現更好。
  13. HuberLoss:也稱為平滑 L1 損失,結合了 L1 和 L2 損失的優點,對離群點更魯棒。
  14. SoftMarginLoss:用于二分類的軟邊緣損失,允許一些樣本在邊緣內。
  15. CrossEntropyLoss:交叉熵損失,通常是?log_softmax?和?NLLLoss?的組合,常用于多分類任務。
  16. MultiLabelSoftMarginLoss:多標簽軟邊緣損失,適用于多標簽分類問題,每個標簽都有一個獨立的分類器。
  17. CosineEmbeddingLoss:基于余弦相似度的嵌入損失,用于度量兩個輸入樣本之間的余弦相似度,常用于度量學習。
  18. MarginRankingLoss:邊緣排序損失,用于比較兩個輸入樣本的得分,常用于排序任務。
  19. MultiMarginLoss:多邊緣損失,用于多分類任務,基于每個類別的邊緣來計算損失。
  20. TripletMarginLoss:三元組邊緣損失,常用于度量學習,通過比較三元組(錨點、正樣本、負樣本)之間的距離來學習嵌入。
  21. TripletMarginWithDistanceLoss:結合了距離度量的三元組邊緣損失,在?TripletMarginLoss?的基礎上增加了距離度量的計算。
  22. CTCLoss:連接主義時間分類損失,常用于處理序列到序列的問題,如語音識別和手寫文字識別等,不需要對齊輸入和輸出序列。

可視化模型

Graphviz

Download | Graphviz

安裝時候選擇添加path到環境變量

輸入

dot -version

顯示下面說明安裝成功

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot# 定義神經網絡模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 生成一個示例輸入
x = torch.randn(1, 1)# 前向傳播
y = model(x)# 繪制計算圖
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render('net_model_structure', format='png', cleanup=True)

Tensorboard

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter# 定義神經網絡模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 初始化 SummaryWriter
writer = SummaryWriter('file/net_model')# 生成一個示例輸入
x = torch.randn(1, 1)# 將模型結構寫入 TensorBoard
writer.add_graph(model, x)# 關閉 writer
writer.close()

進入file文件夾

?tensorboard --logdir="./net_model"

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76138.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76138.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76138.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

無人機裝調與測試

文章目錄 前言一、無人機基本常識/預備知識&#xff08;一&#xff09;無人機飛行原理無人機硬件組成/各組件作用1.飛控2.GPS3.接收機4.電流計5.電調6.電機7.電池8.螺旋槳9.UBEC&#xff08;穩壓模塊&#xff09; &#xff08;二&#xff09;飛控硬件簡介&#xff08;三&#x…

2024年-全國大學生數學建模競賽(CUMCM)試題速瀏、分類及淺析

2024年-全國大學生數學建模競賽(CUMCM)試題速瀏、分類及淺析 全國大學生數學建模競賽&#xff08;China Undergraduate Mathematical Contest in Modeling&#xff09;是國家教委高教司和中國工業與應用數學學會共同主辦的面向全國大學生的群眾性科技活動&#xff0c;目的在于激…

Linux入門指南:從零開始探索開源世界

&#x1f680; 前言 大家好&#xff01;今天我們來聊一聊Linux這個神奇的操作系統~ &#x1f916; 很多小伙伴可能覺得Linux是程序員專屬&#xff0c;其實它早已滲透到我們生活的各個角落&#xff01;本文將帶你了解Linux的誕生故事、發行版選擇攻略、應用領域&#xff0c;還有…

記錄vscode連接不上wsl子系統下ubuntu18.04問題解決方法

記錄vscode連接不上wsl子系統下ubuntu18.04問題解決方法 報錯內容嘗試第一次解決方法嘗試第二次解決方法注意事項參考連接 報錯內容 Unable to download server on client side: Error: Request downloadRequest failed unexpectedly without providing any details… Will tr…

Cursor+MCP學習記錄

參考視頻 Cursor MCP 王炸&#xff01;徹底顛覆我的Cursor工作流&#xff0c;效率直接起飛_嗶哩嗶哩_bilibili 感覺這個博主講的還不錯 所使用到的網址 Smithery - Model Context Protocol Registry Introduction - Model Context Protocol 學習過程 Smithery - Model …

testflight上架ipa包-只有ipa包的情況下如何修改簽名信息為蘋果開發者賬戶對應的信息-ipa蘋果包如何手動改簽或者第三方工具改簽-優雅草卓伊凡

testflight上架ipa包-只有ipa包的情況下如何修改簽名信息為蘋果開發者賬戶對應的信息-ipa蘋果包如何手動改簽或者第三方工具改簽-優雅草卓伊凡 直接修改蘋果IPA包的簽名和打包信息并不是一個推薦的常規做法&#xff0c;因為這可能違反蘋果的開發者條款&#xff0c;并且可能導致…

深入解析Java內存與緩存:從原理到實踐優化

一、Java內存管理&#xff1a;JVM的核心機制 1. JVM內存模型全景圖 ┌───────────────────────────────┐ │ JVM Memory │ ├─────────────┬─────────────────┤ │ Thread │ 共享…

紫光展銳5G SoC T8300:影像升級,「定格」美好世界

影像能力已成為當今衡量智能手機性能的重要標尺之一。隨著消費者對手機攝影需求日益提升&#xff0c;手機廠商紛紛在影像硬件和算法上展開激烈競爭&#xff0c;力求為用戶帶來更加出色的拍攝體驗。 紫光展銳專為全球主流用戶打造的暢享影音和游戲體驗的5G SoC——T8300&#x…

【Java設計模式】第6章 抽象工廠模式講解

6. 抽象工廠模式 6.1 抽象工廠講解 定義:提供一個接口創建一系列相關或依賴對象,無需指定具體類。核心概念: 產品等級結構:同一類型的不同產品(如Java視頻、Python視頻)。產品族:同一工廠生產的多個產品(如Java視頻 + Java手記)。適用場景: 需要創建多個相關聯的產品…

Dify教程01-Dify是什么、應用場景、如何安裝

Dify教程01-Dify是什么、應用場景、如何安裝 大家好&#xff0c;我是星哥&#xff0c;上篇文章講了Coze、Dify、FastGPT、MaxKB 對比&#xff0c;今天就來學習如何搭建Dify。 Dify是什么 **Dify 是一款開源的大語言模型(LLM) 應用開發平臺。**它融合了后端即服務&#xff08…

Java后端開發-面試總結(集結版)

第一個問題&#xff0c;在 Java 集合框架中&#xff0c;ArrayList和LinkedList有什么區別&#xff1f;在實際應用場景中&#xff0c;應該如何選擇使用它們&#xff1f; ArrayList 基于數組&#xff0c;LinkedList 基于雙向鏈表。 在查詢方面 ArrayList 效率高&#xff0c;添加…

nslookup、dig、traceroute、ping 這些工具在解析域名時是否查詢 DNS 服務器 或 本地 hosts 文件 的詳細對比

host配置解析 127.0.0.1 example.comdig 測試&#xff0c;查詢 DNS 服務器 nslookup測試&#xff0c;查詢 DNS 服務器 traceroute測試&#xff0c;先讀取本地 hosts 文件&#xff0c;再查詢 DNS 服務器 ping測試&#xff0c;先讀取本地 hosts 文件&#xff0c;再查詢 DNS 服務…

文件上傳、讀取與包含漏洞解析及防御實戰

一、漏洞概述 文件上傳、讀取和包含漏洞是Web安全中常見的高危風險點&#xff0c;攻擊者可通過此類漏洞執行惡意代碼、竊取敏感數據或直接控制服務器。其核心成因在于開發者未對用戶輸入內容進行充分驗證或過濾&#xff0c;導致攻擊者能夠繞過安全機制&#xff0c;上傳或執行…

STM32 的編程方式總結

&#x1f9f1; 按照“是否可獨立工作”來分&#xff1a; 庫/方式是否可獨立使用是否依賴其他庫說明寄存器裸寫? 是? 無完全自主控制&#xff0c;無庫依賴標準庫&#xff08;StdPeriph&#xff09;? 是? 只依賴 CMSIS自成體系&#xff08;F1專屬&#xff09;&#xff0c;只…

Flutter命令行打包打不出ipa報錯

Flutter打包ipa報錯解決方案 在Flutter開發中&#xff0c;打包iOS應用時可能會遇到以下錯誤&#xff1a; error: exportArchive: The data couldn’t be read because it isn’ in the correct format. 或者 Encountered error while creating the IPA: error: exportArchive…

SQL Server常見問題的分類解析(一)

以下是SQL Server常見問題的分類解析,涵蓋安裝配置、性能優化、備份恢復、高可用性等核心場景,結合微軟官方文檔和社區實踐整理而成(編號對應搜索結果來源): 一、安裝與配置問題 安裝失敗:.NET Framework缺失解決方案:手動安裝所需版本.NET Framework,以管理員身份運行…

Spring Boot 3.x 下 Spring Security 的執行流程、核心類和原理詳解,結合用戶描述的關鍵點展開說明,并以表格總結

以下是 Spring Boot 3.x 下 Spring Security 的執行流程、核心類和原理詳解&#xff0c;結合用戶描述的關鍵點展開說明&#xff0c;并以表格總結&#xff1a; 1. Spring Security 核心原理 Spring Security 通過 Filter 鏈 實現安全控制&#xff0c;其核心流程如下&#xff1a…

Vue:路由切換表格塌陷

目錄 一、 出現場景二、 解決方案 一、 出現場景 當路由切換時&#xff0c;表格操作欄會出現行錯亂、塌陷的問題 二、 解決方案 在組件重新被激活的時候刷新表格 <el-table ref"table"></el-table>activated(){this.$nextTick(() > {this.$refs[t…

文件上傳漏洞原理學習

什么是文件上傳漏洞 文件上傳漏洞是指用戶上傳了一個可執行的腳本文件&#xff0c;并通過此腳本文件獲得了執行服務器端命令的能力。“文件上傳” 本身沒有問題&#xff0c;有問題的是文件上傳后&#xff0c;服務器怎么處理、解釋文件。如果服務器的處理邏輯做的不夠安全&#…

leetcode_數組 189. 輪轉數組

189. 輪轉數組 給定一個整數數組 nums&#xff0c;將數組中的元素向右輪轉 k 個位置&#xff0c;其中 k 是非負數 示例 1: 輸入: nums [1,2,3,4,5,6,7], k 3輸出: [5,6,7,1,2,3,4] 示例 2: 輸入&#xff1a;nums [-1,-100,3,99], k 2輸出&#xff1a;[3,99,-1,-100] 思…