嵌入式學習-PyTorch(7)-day23

損失函數的調用

import torch
from torch import nn
from torch.nn import L1Lossinputs = torch.tensor([1.0,2.0,3.0])
target = torch.tensor([1.0,2.0,5.0])inputs = torch.reshape(inputs, (1, 1, 1, 3))
target = torch.reshape(target, (1, 1, 1, 3))
#損失函數
loss = L1Loss(reduction='sum')
#MSELoss均值方差
loss_mse = nn.MSELoss()
result1 = loss(inputs, target)
result2 = loss_mse(inputs, target)
print(result1, result2)

?實際應用

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2ddataset = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)self.maxpool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)self.maxpool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)self.maxpool3 = nn.MaxPool2d(kernel_size=2)self.flatten = nn.Flatten()self.linear1 = nn.Linear(in_features=1024, out_features=64)self.linear2 = nn.Linear(in_features=64, out_features=10)self.model1 = nn.Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:imgs,targets = dataoutputs = tudui(imgs)result1 = loss(outputs, targets)print(result1)#反向傳播result1.backward()#梯度grad會改變,從而通過grad來降低loss

torch.nn.CrossEntropyLoss?

🧩 CrossEntropyLoss 是什么?

本質上是:

Softmax + NLLLoss(負對數似然) 的組合。

公式:

\text{Loss} = - \sum_{i} y_i \log(\hat{p}_i)

  • \hat{p}_i?:模型預測的概率(通過 softmax 得到)

  • y_i?:真實類別的 one-hot 標簽

PyTorch 不需要你手動做 softmax,它會直接從 logits(未經過 softmax 的原始輸出)算起,防止數值不穩定。


🏷? 常用參數

torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')

參數含義
weight給不同類別加權(處理類別不均衡)
ignore_index忽略某個類別(常見于 NLP 的 padding)
reductionmean(默認平均)、sum(求和)、none(逐個樣本返回 loss)


🎨 最小使用例子

import torch
import torch.nn as nncriterion = nn.CrossEntropyLoss()# 假設 batch_size=3, num_classes=5
outputs = torch.tensor([[1.0, 2.0, 0.5, -1.0, 0.0],[0.1, -0.2, 2.3, 0.7, 1.8],[2.0, 0.1, 0.0, 1.0, 0.5]])  # logits
labels = torch.tensor([1, 2, 0])  # 真實類別索引loss = criterion(outputs, labels)
print(loss.item())
  • outputs:模型輸出 logits,不需要 softmax;

  • labels:真實類別(索引型),如 0, 1, 2,...

  • loss.item():輸出標量值。


💡 你需要注意:

?? 重點📌 說明
logits 直接輸入不要提前做 softmax
label 是類別索引不是 one-hot,而是整數(如 [1, 3, 0]
自動求 batch 平均默認 reduction='mean'
多分類用它最合適二分類也能用,但 BCEWithLogitsLoss 更常見


🎁 總結

優點缺點
? 簡單強大,適合分類? 不適合回歸任務
? 內置 softmax + log? label 不能是 one-hot
? 數值穩定性強? 類別極度不均衡需額外加 weight


🎯 一句話總結

CrossEntropyLoss 是深度學習中分類問題的“首選痛點衡量尺”,幫你用“正確標簽”去教訓“錯誤預測”,模型越聰明 loss 越小。

?公式:

?

1?? 第一部分:

- \log \left( \frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])} \right)

這是經典 負對數似然(Negative Log-Likelihood):

  • 分子:你模型對正確類別 class 輸出的得分(logits),取 exp;

  • 分母:所有類別的 logits 做 softmax 歸一化;

  • 再取負 log —— 意思是“你對正確答案預測得越自信,loss 越小”。


2?? 推導為:

= - x[\text{class}] + \log \left( \sum_j \exp(x[j]) \right)

log(a/b) = log(a) - log(b) 的變形:

  • - x[\text{class}]:你對正確類輸出的分值直接扣掉;

  • +\log(\sum_j \exp(x[j])):對所有類別的總分值做歸一化。

這是交叉熵公式最常用的“log-sum-exp”形式。


📌 為什么這么寫?

  • 避免直接用 softmax(softmax+log 合并后可以避免數值不穩定 🚀)

  • 計算量更高效(框架底層可以優化)


🌟 直觀理解:

場景解釋
正確類分數高x[\text{class}]越大,loss 越小
錯誤類分數高\sum \exp(x[j])越大,loss 越大
目標壓低 log-sum-exp,拉高正確類別 logits

🎯 一句話總結:

交叉熵 = “扣掉正確答案得分” + “對所有類別歸一化”,越接近正確答案,loss 越小。
這就是你訓練神經網絡時 模型越來越聰明的數學依據 😎

舉例:

logits = torch.tensor([1.0, 2.0, 0.1])  # 模型輸出 (C=3)
label = torch.tensor([1])  # 真實類別索引 = 1

其中:

  • N=1(batch size)

  • C=3(類別數)

  • 正確類別是索引1,對應第二個值:2.0

🎁 完整公式回顧

\text{loss}(x, y) = -x[y] + \log \sum_{j} \exp(x[j])


🟣 第一步:Softmax + log 邏輯

softmax 本質上是:

p = \frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}

但是 PyTorch 的 CrossEntropyLoss 內部直接用:

\text{loss} = - \log p


🧮 你這個例子手動算:

logits = [1.0, 2.0, 0.1],class = 1,對應 logit = 2.0

第一部分:

- x[\text{class}] = -2.0

第二部分:

\log \sum_{j=1}^{3} \exp(x[j]) = \log (\exp(1.0) + \exp(2.0) + \exp(0.1))

先算:

  • exp(1.0)≈2.718

  • exp(2.0)≈7.389

  • exp(0.1)≈1.105

加起來:

∑=2.718+7.389+1.105=11.212

取對數:

log?(11.212)≈2.418

最終 loss:

loss=?2.0+2.418=0.418

🌟 你可以這樣理解

部分含義
?x[class]- x[\text{class}]?x[class]懲罰正確答案打分太低
log?∑exp?(x)\log \sum \exp(x)log∑exp(x)考慮所有類別的對比,如果錯誤類別打分高也被懲罰
最終目標“提升正確答案打分、降低錯誤答案打分”

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914873.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914873.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914873.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

用 Ray 跨節點調用 GPU 部署 DeepSeek 大模型,實現分布式高效推理

在大模型時代,單節點 GPU 資源往往難以滿足大模型(如 7B/13B 參數模型)的部署需求。借助 Ray 分布式框架,我們可以輕松實現跨節點 GPU 資源調度,讓大模型在多節點間高效運行。本文將以 DeepSeek-llm-7B-Chat 模型為例&…

快速了解 HTTPS

1. 引入 在 HTTP 協議 章節的 reference 段,曾提到過 HTTPS。這里對HTTPS進行詳細介紹。 HTTPS 是在 HTTP 的基礎上,引入了一個加密層 (SSL)。HTTP 是明文傳輸的 (不安全)。當下所見到的大部分網站都是 HTTPS 的。 起初是拜運營商劫持所賜(…

mysql備份與視圖

要求:1.將mydb9_stusys數據庫下的student、sc 和course表,備份到本地主機保存為st_msg_bak.sql文件,然后將數據表恢復到自建的db_test數據庫中;2.在db_test數據庫創建一視圖 stu_info,查詢全體學生的姓名,性別,課程名&…

【數據結構】 鏈表 + 手動實現單鏈表和雙鏈表的接口(圖文并茂附完整源碼)

文章目錄 一、 鏈表的概念及結構 二、鏈表的分類 ?編輯 三、手動實現單鏈表 1、定義單鏈表的一個節點 2、打印單鏈表 3、創建新節點 4、單鏈表的尾插 5、單鏈表的頭插 6、單鏈表的尾刪 7、單鏈表的頭刪 8、單鏈表的查找 9、在指定位置之前插入一個新節點 10、在指…

Go語言時間控制:定時器技術詳細指南

1. 定時器基礎:從 time.Sleep 到 time.Timer 的進化為什么 time.Sleep 不夠好?在 Go 編程中,很多人初學時會用 time.Sleep 來實現時間控制。比如,想讓程序暫停 2 秒,代碼可能是這樣:package mainimport (&q…

C# 轉換(顯式轉換和強制轉換)

顯式轉換和強制轉換 如果要把短類型轉換為長類型,讓長類型保存短類型的所有位很簡單。然而,在其他情況下, 目標類型也許無法在不損失數據的情況下容納源值。 例如,假設我們希望把ushort值轉化為byte。 ushort可以保存任何0~65535的…

淺談自動化設計最常用的三款軟件catia,eplan,autocad

筆者從上半年開始接觸這三款軟件,掌握了基礎用法,但是過了一段時間不用,發現再次用,遇到的問題短時間解決不了,忘記的有點多,這里記錄一下,防止下次忘記Elpan:問題1QF01是柜安裝板上的一個部件&…

網絡編程7.17

練習&#xff1a;服務器&#xff1a;#include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <pthread.h> #include &…

c++ 模板元編程

聽說模板元編程能在編譯時計算出常量&#xff0c;簡單測試下看看&#xff1a;template<int N> struct Summation {static constexpr int value N Summation<N - 1>::value; // 計算 1 2 ... N 的值 };template<> struct Summation<1> { // 遞歸終…

【深度學習】神經網絡過擬合與欠擬合-part5

八、過擬合與欠擬合訓練深層神經網絡時&#xff0c;由于模型參數較多&#xff0c;數據不足的時候容易過擬合&#xff0c;正則化技術就是防止過擬合&#xff0c;提升模型的泛化能力和魯棒性 &#xff08;對新數據表現良好 對異常數據表現良好&#xff09;1、概念1.1過擬合在訓練…

JavaScript的“硬件窺探術”:瀏覽器如何讀取你的設備信息?

JavaScript的“硬件窺探術”&#xff1a;瀏覽器如何讀取你的設備信息&#xff1f; 在Web開發的世界里&#xff0c;JavaScript一直扮演著“幕后魔術師”的角色。從簡單的頁面跳轉到復雜的實時數據處理&#xff0c;它似乎總能用最輕巧的方式解決最棘手的問題。但你是否想過&#…

論安全架構設計(層次)

安全架構設計&#xff08;層次&#xff09; 摘要 2021年4月&#xff0c;我有幸參與了某保險公司的“優車險”項目的建設開發工作&#xff0c;該系統以車險報價、車險投保和報案理賠為核心功能&#xff0c;同時實現了年檢代辦、道路救援、一鍵挪車等增值服務功能。在本項目中&a…

滾珠導軌常見的故障有哪些?

在自動化生產設備、精密機床等領域&#xff0c;滾珠導軌就像是設備平穩運行的 “軌道”&#xff0c;為機械部件的直線運動提供穩準導向。但導軌使用時間長了&#xff0c;難免會出現這樣那樣的故障。滾珠脫落&#xff1a;可能由安裝不當、導軌損壞、超負荷運行、維護不當或惡劣環…

機器視覺的包裝盒絲印應用

在包裝盒絲網印刷領域&#xff0c;隨著消費市場對產品外觀精細化要求的持續提升&#xff0c;傳統印刷工藝面臨多重挑戰&#xff1a;多色套印偏差、曲面基材定位困難、異形結構印刷失真等問題。雙翌光電科技研發的WiseAlign視覺系統&#xff0c;通過高精度視覺對位技術與智能化操…

Redis學習-03重要文件及作用、Redis 命令行客戶端

Redis 重要文件及作用 啟動/停止命令或腳本 /usr/bin/redis-check-aof -> /usr/bin/redis-server /usr/bin/redis-check-rdb -> /usr/bin/redis-server /usr/bin/redis-cli /usr/bin/redis-sentinel -> /usr/bin/redis-server /usr/bin/redis-server /usr/libexec/red…

SVN客戶端(TortoiseSVN)和SVN-VS2022插件(visualsvn)官網下載

SVN服務端官網下載地址&#xff1a;https://sourceforge.net/projects/win32svn/ SVN客戶端工具(TortoiseSVN):https://plan.io/tortoise-svn/ SVN-VS2022插件(visualsvn)官網下載地址&#xff1a;https://www.visualsvn.com/downloads/

990. 等式方程的可滿足性

題目&#xff1a;第一次思考&#xff1a; 經典并查集 實現&#xff1a;class UnionSet{public:vector<int> parent;public:UnionSet(int n) {parent.resize(n);}void init(int n) {for (int i 0; i < n; i) {parent[i] i;}}int find(int x) {if (parent[x] ! x) {pa…

HTML--教程

<!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>菜鳥教程(runoob.com)</title> </head> <body><h1>我的第一個標題</h1><p>我的第一個段落。</p> </body> </html&g…

Leetcode刷題營第二十七題:二叉樹的最大深度

104. 二叉樹的最大深度 給定一個二叉樹 root &#xff0c;返回其最大深度。 二叉樹的 最大深度 是指從根節點到最遠葉子節點的最長路徑上的節點數。 示例 1&#xff1a; 輸入&#xff1a;root [3,9,20,null,null,15,7] 輸出&#xff1a;3示例 2&#xff1a; 輸入&#xff…