神經網絡中 標量求導和向量求導

0. 引出問題

在神經網絡反向傳播過程中 loss = [loss?,loss?, loss?],為什么 ?loss/?w

?loss?/?w 
?loss?/?w
?loss?/?w 

?loss?/?w 和 loss 維度一樣都是三位向量 ,[?loss?/?w, ?loss?/?w, ?loss?/?w] 就變成3*3的矩陣
如下所示:

import torchw = torch.tensor([1.0, 2.0,3.0], requires_grad=True)
loss = w * 3  
print("loss: \n", loss)loss_m = []for i, val in enumerate(loss):w.grad = None  # 清零val.backward(retain_graph=True)print(f"?loss{i+1}/?w = {w.grad}")loss_m.append(w.grad.clone())print("loss_m: \n", torch.stack(loss_m))

輸出結果:

loss: tensor([3., 6., 9.], grad_fn=<MulBackward0>)?loss1/?w = tensor([3., 0., 0.])
?loss2/?w = tensor([0., 3., 0.])
?loss3/?w = tensor([0., 0., 3.])loss_m: tensor([[3., 0., 0.],[0., 3., 0.],[0., 0., 3.]])

loss: tensor([3., 6., 9.]) 為向量,對w求導時為矩陣
但是 w.grad 必須 是標量或張量,不能是向量矩陣

1. 標量求導

import torchw = torch.tensor([1.0, 2.0,3.0], requires_grad=True)
loss = w * 3  
print("loss: \n", loss)loss_m = []
# 方法1:分別計算
for i, val in enumerate(loss):w.grad = None  # 清零val.backward(retain_graph=True)print(f"?loss{i+1}/?w = {w.grad}")loss_m.append(w.grad.clone())print("loss_m: \n", torch.stack(loss_m))grads = torch.autograd.grad(loss.sum(), w,retain_graph=True)
print("grads: \n", grads)  grads1 = torch.autograd.grad(loss.mean(), w)[0]
print("grads1: \n", grads1) 

輸出;

loss: tensor([3., 6., 9.], grad_fn=<MulBackward0>)
?loss1/?w = tensor([3., 0., 0.])
?loss2/?w = tensor([0., 3., 0.])
?loss3/?w = tensor([0., 0., 3.])
loss_m: tensor([[3., 0., 0.],[0., 3., 0.],[0., 0., 3.]])
grads: (tensor([3., 3., 3.]),)
grads1: tensor([1., 1., 1.])

同樣的例子:

import torch# 3個樣本的真實數據
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
y_true = torch.tensor([1.0, 2.0, 3.0])# 線性模型:y = w?x? + w?x?
w = torch.tensor([0.5, 0.5], requires_grad=True)
predictions = (x @ w)  # [1.5, 3.5, 5.5]
print("預測值:", predictions)
# 計算每個樣本的梯度
individual_grads = []
for i in range(3):loss = (predictions[i] - y_true[i])**2loss.backward(retain_graph=True)individual_grads.append(w.grad.clone())w.grad.zero_()print("樣本1梯度:", individual_grads[0]) 
print("樣本2梯度:", individual_grads[1])  
print("樣本3梯度:", individual_grads[2])  # 標量梯度:自動綜合
total_loss = ((predictions - y_true)**2).mean()
total_loss.backward()# 驗證:標量梯度 = 向量梯度的平均
manual_average = (individual_grads[0] + individual_grads[1] + individual_grads[2]) / 3print("手動平均:", manual_average)  
print("標量結果:", w.grad)  

輸出結果:

預測值: tensor([1.5000, 3.5000, 5.5000], grad_fn=<MvBackward0>)
樣本1梯度: tensor([1., 2.])
樣本2梯度: tensor([ 9., 12.])
樣本3梯度: tensor([25., 30.])
手動平均: tensor([11.6667, 14.6667])
標量結果: tensor([11.6667, 14.6667])

訓練神經網絡是為了最小化整體損失,不是單獨優化每個樣本

# 實際訓練:最小化平均損失
batch_loss = individual_losses.mean()  # 標量
batch_loss.backward()  # 得到平均梯度
optimizer.step()       # 朝平均最優方向更新

2. 什么時候需要向量梯度?

僅用于研究:分析樣本敏感性

def compute_sample_gradients(model, x, y):"""僅用于分析,不用于訓練"""grads = []for xi, yi in zip(x, y):model.zero_grad()pred = model(xi.unsqueeze(0))loss = ((pred - yi) ** 2)loss.backward()grads.append(model.weight.grad.clone())return grads  # 每個樣本的單獨梯度

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/920008.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/920008.shtml
英文地址,請注明出處:http://en.pswp.cn/news/920008.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

tcpdump命令打印抓包信息

tcpdump命令打印抓包信息 下面是在服務器抓取打印服務端7701端口打印 rootgb:/home/gb# ifconfig -a eth0: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500inet 10.250.251.197 netmask 255.255.255.0 broadcast 10.250.251.255inet6 fe80::76fe:48ff:fe94:5a5 …

Mysql-經典實戰案例(13):如何通過Federated實現跨實例訪問表

實現原理&#xff1a;使用Federated引擎本創建一個鏈接表實現&#xff0c;但是Federated 引擎只是一個按列的順序和類型解析遠程返回的數據流準備工作&#xff1a; 1. 本地庫啟用 Federated 引擎查看是否已啟用&#xff1a; SHOW ENGINES;如果Federated 引擎的 Support 是 YES …

Linux -- 動靜態庫

一、什么是庫1、動靜態庫概念# 庫是寫好的現有的&#xff0c;成熟的&#xff0c;可以復?的代碼。現實中每個程序都要依賴很多基礎的底層庫&#xff0c;不可能每個?的代碼都從零開始&#xff0c;因此庫的存在意義?同尋常。# 本質上來說庫是?種可執?代碼的?進制形式&#x…

Linux筆記---單例模式與線程池

1. 單例模式單例模式是一種常用的設計模式&#xff0c;它確保一個類只有一個實例&#xff0c;并提供一個全局訪問點來獲取這個實例。這種模式在需要控制資源訪問、管理共享狀態或協調系統行為時非常有用。單例模式的核心特點&#xff1a;私有構造函數&#xff1a;防止外部通過n…

Linux中的指令

1.adduseradduser的作用是創立一個新的用戶。當我們在命令行中輸入1中的指令后&#xff0c;就會彈出2中的命令行&#xff0c;讓我們設立新的密碼&#xff0c;緊接著就會讓我們再次輸入新的密碼&#xff0c;對于密碼的輸入它是不會顯示出來的&#xff0c;如果輸入錯誤就會讓我們…

【n8n】Docker容器中安裝ffmpeg

容器化部署 n8n 時&#xff0c;常常會遇到一些環境依賴問題。缺少 docker 命令或無法安裝 ffmpeg 是較為常見的場景&#xff0c;如果處理不當&#xff0c;會導致流程執行受限。 本文介紹如何在 n8n 容器中解決 docker 命令不可用和 ffmpeg 安裝受限的問題&#xff0c;并給出多…

【基礎算法】初識搜索:遞歸型枚舉與回溯剪枝

文章目錄一、搜索1. 什么是搜索&#xff1f;2. 遍歷 vs 搜索3. 回溯與剪枝二、OJ 練習1. 枚舉子集 ?(1) 解題思路(2) 代碼實現2. 組合型枚舉 ?(1) 解題思路請添加圖片描述(2) 代碼實現3. 枚舉排列 ?(1) 解題思路(2) 代碼實現4. 全排列問題 ?(1) 解題思路(2) 代碼實現一、搜…

Node.js異步編程——async/await實現

一、async/await基礎語法 在Node.Js編程中,async關鍵字用于定義異步函數,這個異步函數執行完會返回一個Promise對象,異步函數的內部可以使用await關鍵字來暫停當前代碼的繼續執行,直到Promise操作完成。 在用法上,async關鍵字主要用于聲明一個異步函數,await關鍵字主要…

搭建一個簡單的Agent

準備本案例使用deepseek&#xff0c;登錄deepseek官網&#xff0c;登錄賬號&#xff0c;充值幾塊錢&#xff0c;然后創建Api key可以創建虛擬環境&#xff0c;python版本最好是3.12&#xff0c;以下是文件目錄。test文件夾中&#xff0c;放一些txt文件做測試&#xff0c;main.p…

uv,下一代Python包管理工具

什么是uv uv&#xff08;Universal Virtual&#xff09;是由Astral團隊&#xff08;知名Python工具Ruff的開發者&#xff09;推出的下一代Python包管理工具&#xff0c;使用Rust編寫。它集成了包管理、虛擬環境、依賴解析、Python版本控制等功能&#xff0c;它聚焦于三個關鍵點…

單片機的輸出模式推挽和開漏如何選擇呢?

推挽和開漏是單片機的輸出模式&#xff0c;屬于I/O口配置的常見類型。開漏&#xff08;Open-Drain&#xff09;和推挽&#xff08;Push-Pull&#xff09;是兩種根本不同的輸出電路結構&#xff0c;理解它們的區別是正確使用任何單片機&#xff08;包括51和STM32&#xff09;GPI…

java18學習筆記-Simple Web Server

408:Simple Web Server Python、Ruby、PHP、Erlang 和許多其他平臺提供從命令行運行的開箱即用服務器。這種現有的替代方案表明了對此類工具的公認需求。 提供一個命令行工具來啟動僅提供靜態文件的最小web服務器。沒有CGI或類似servlet的功能可用。該工具將用于原型設計、即…

深度解析Atlassian 團隊協作套件(Jira、Confluence、Loom、Rovo)如何賦能全球分布式團隊協作

無窮無盡的聊天記錄、混亂不堪的文檔、反饋信息分散在各個不同時區……在全球分布式團隊中開展真正的高效協作&#xff0c;就像是一場不可能完成的任務。 為什么會這樣&#xff1f;因為即使是最聰明的團隊&#xff0c;也會遇到類似的障礙&#xff1a; 割裂的工作流&#xff1a…

理解AI 智能體:智能體架構

1. 引言 智能體架構&#xff08;agent architecture&#xff09;是一份藍圖&#xff0c;它定義了AI智能體各組件的組織方式和交互機制&#xff0c;使智能體能夠感知環境、進行推理并采取行動。本質上&#xff0c;它就像是智能體的數字大腦——整合了“眼睛”&#xff08;傳感器…

Spring Cloud系列—SkyWalking鏈路追蹤

上篇文章&#xff1a; Spring Cloud系列—Seata分布式事務解決方案TCC模式和Saga模式https://blog.csdn.net/sniper_fandc/article/details/149947829?fromshareblogdetail&sharetypeblogdetail&sharerId149947829&sharereferPC&sharesourcesniper_fandc&…

機器人領域的算法研發

研究生期間學習大模型&#xff0c;可投遞機器人領域的算法研發、技術支持等相關崗位&#xff0c;以下是具體推薦&#xff1a; AI算法工程師&#xff08;大模型方向-機器人應用&#xff09;&#xff1a;主要負責大模型開發與優化&#xff0c;如模型預訓練、調優及訓練效率提升等…

深度學習入門:神經網絡

文章目錄一、深度學習基礎認知二、神經網絡核心構造解析2.1 神經元的基本原理2.2 感知器&#xff1a;最簡單的神經網絡2.3 多層感知器&#xff1a;引入隱藏層解決非線性問題2.3.1 多層感知器的結構特點2.3.2 偏置節點的作用2.3.3 多層感知器的計算過程三、神經網絡訓練核心方法…

mysql的索引有哪些?

1. 主鍵索引&#xff08;PRIMARY KEY&#xff09;主鍵索引通常在創建表時定義&#xff0c;確保字段唯一且非空&#xff1a;-- 建表時直接定義主鍵 CREATE TABLE users (id INT NOT NULL,name VARCHAR(50),PRIMARY KEY (id) -- 單字段主鍵 );-- 復合主鍵&#xff08;多字段組合…

【計算機視覺與深度學習實戰】08基于DCT、DFT和DWT的圖像變換處理系統設計與實現(有完整代碼python3.13可直接粘貼使用)

1. 引言 數字圖像處理作為計算機視覺和信號處理領域的重要分支,在過去幾十年中得到了快速發展。圖像變換技術作為數字圖像處理的核心技術之一,為圖像壓縮、特征提取、去噪和增強等應用提供了強有力的數學工具。離散余弦變換(Discrete Cosine Transform, DCT)、離散傅里葉變…

使用Python實現DLT645-2007智能電表協議

文章目錄&#x1f334;通訊支持&#x1f334; 功能完成情況服務端架構設計一、核心模塊劃分二、數據層定義三、協議解析層四、通信業務層&#xff08;以DLT645服務端為例&#xff09;五、通信層&#xff08;以TCP為例&#xff09;使用例子&#x1f334;通訊支持 功能狀態TCP客…