《昇思25天學習打卡營第6天 | 函數式自動微分》

《昇思25天學習打卡營第6天 | 函數式自動微分》

目錄

  • 《昇思25天學習打卡營第6天 | 函數式自動微分》
    • 函數式自動微分
    • 簡單的單層線性變換模型
    • 函數與計算圖
    • 微分函數與梯度計算
    • Stop Gradient

函數式自動微分

神經網絡的訓練主要使用反向傳播算法,模型預測值(logits)與正確標簽(label)送入損失函數(loss function)獲得loss,然后進行反向傳播計算,求得梯度(gradients),最終更新至模型參數(parameters)。自動微分能夠計算可導函數在某點處的導數值,是反向傳播算法的一般化。自動微分主要解決的問題是將一個復雜的數學運算分解為一系列簡單的基本運算,該功能對用戶屏蔽了大量的求導細節和過程,大大降低了框架的使用門檻。

MindSpore使用函數式自動微分的設計理念,提供更接近于數學語義的自動微分接口grad和value_and_grad。

簡單的單層線性變換模型

我們通過學習使用一個簡單的單層線性變換模型來了解

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函數與計算圖

計算圖是用圖論語言表示數學函數的一種方式,也是深度學習框架表達神經網絡模型的統一方法。我們將根據下面的計算圖構造計算函數和神經網絡。

compute-graph
在這個模型中, 𝑥
為輸入, 𝑦
為正確值, 𝑤
和 𝑏
是我們需要優化的參數。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我們根據計算圖描述的計算過程,構造計算函數。 其中,binary_cross_entropy_with_logits 是一個損失函數,計算預測值和目標值之間的二值交叉熵損失。

def function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss

執行計算函數,可以獲得計算的loss值。

loss = function(x, y, w, b)
print(loss)

Tensor(shape=[], dtype=Float32, value= 0.914285)

微分函數與梯度計算

為了優化模型參數,需要求參數對loss的導數: ? loss ? ? w \frac{\partial \operatorname{loss}}{\partial w} ?w?loss? ? loss ? ? b \frac{\partial \operatorname{loss}}{\partial b} ?b?loss?,此時我們調用mindspore.grad函數,來獲得function的微分函數。

這里使用了grad函數的兩個入參,分別為:

  • fn:待求導的函數。
  • grad_position:指定求導輸入位置的索引。

由于我們對 w w w b b b求導,因此配置其在function入參對應的位置(2, 3)

使用grad獲得微分函數是一種函數變換,即輸入為函數,輸出也為函數。

grad_fn = mindspore.grad(function, (2, 3))

執行微分函數,即可獲得 𝑤
、 𝑏
對應的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]]),
Tensor(shape=[3], dtype=Float32, value= [ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]))

Stop Gradient

通常情況下,求導時會求loss對參數的導數,因此函數的輸出只有loss一項。當我們希望函數輸出多項時,微分函數會求所有輸出項對參數的導數。此時如果想實現對某個輸出項的梯度截斷,或消除某個Tensor對梯度的影響,需要用到Stop Gradient操作。

這里我們將function改為同時輸出loss和z的function_with_logits,獲得微分函數并執行。

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]]),
Tensor(shape=[3], dtype=Float32, value= [ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]))

可以看到求得 w w w b b b對應的梯度值發生了變化。此時如果想要屏蔽掉z對梯度的影響,即仍只求參數對loss的導數,可以使用ops.stop_gradient接口,將梯度在此處截斷。我們將function實現加入stop_gradient,并執行。

def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/37395.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/37395.shtml
英文地址,請注明出處:http://en.pswp.cn/web/37395.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

建站小記:遷移域名DNS到CloudFlare

CloudFlare一直有賽博菩薩之稱,據說用它做DNS解析服務又快又好又免費,還能防DDOS攻擊,并且可以提供頁面訪問統計功能。 正好我博客網頁打開略卡頓,所以決定將自己的DNS解析遷移到CloudFlare。 1.登錄CF控制臺,添加自己…

LeetCode-刷題記錄-二分法合集(本篇blog會持續更新哦~)

一、二分查找概述 二分查找(Binary Search)是一種高效的查找算法,適用于有序數組或列表。(但其實只要滿足二段性,就可以使用二分法,本篇博客后面博主會持續更新一些題,來破除一下人們對“只有有…

(已解決)Adobe Flash Player已不再受支持

文章目錄 前言解決方案 前言 一般來說,很少遇到官方網站使用Adobe Flash Player來進行錄用名單公示了。但是,今天就偏偏遇到一次, 用谷歌瀏覽器打不開, 點了沒有反應,用其他的瀏覽器,例如windows自帶的那…

Golang | Leetcode Golang題解之第207題課程表

題目: 題解: func canFinish(numCourses int, prerequisites [][]int) bool {var (edges make([][]int, numCourses)indeg make([]int, numCourses)result []int)for _, info : range prerequisites {edges[info[1]] append(edges[info[1]], info[0]…

數據結構:期末考 第六次測試(總復習)

一、 單選題 (共50題,100分) 1、表長為n的順序存儲的線性表,當在任何位置上插入或刪除一個元素的概率相等時,插入一個元素所需移動元素的平均個數為( D ).(2.0) A、 &am…

在node環境使用MySQL

什么是Sequelize? Sequelize是一個基于Promise的NodeJS ORM模塊 什么是ORM? ORM(Object-Relational-Mapping)是對象關系映射 對象關系映射可以把JS中的類和對象,和數據庫中的表和數據進行關系映射。映射之后我們就可以直接通過類和對象來操作數據表和數據了, 就…

join()方法——連接字符串、元組、列表和字典

自學python如何成為大佬(目錄):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 語法參考 join()方法用于連接字符串數組。將字符串、元組、列表中的元素以指定的字符(分隔符)連接生成一個新的字符串&#…

喜報 | 極限科技獲得北京市“創新型”中小企業資格認證

2024年6月20日,北京市經濟和信息化局正式發布《關于對2024年度4月份北京市創新型中小企業名單進行公告的通知》,極限數據(北京)科技有限公司憑借其出色的創新能力和卓越的企業實力,成功獲得“北京市創新型中小企業”的…

學會python——在excel中寫入數據(python實例十三)

目錄 1.認識Python 2.環境與工具 2.1 python環境 2.2 Visual Studio Code編譯 3 .想Excel中寫入數據 3.1 代碼構思 3.2 代碼實例 3.3 運行結果 4.總結 1.認識Python Python 是一個高層次的結合了解釋性、編譯性、互動性和面向對象的腳本語言。 Python 的設計具有很強的…

數據結構算法之B樹

一、緒論 1.1 數據結構的概念和作用 1.2 B樹的起源和應用領域 二、B樹的基本原理 2.1 B樹的定義和特點 2.2 B樹的結構和節點組成 2.3 B樹的插入 2.4 B樹的刪除操作 三、B樹的優勢和應用 3.1 B樹在數據庫系統中的應用 3.2 B樹在文件系統中的應用 3.3 B樹在內存管理中…

HTML5的多線程技術:Shared Worker的使用示例

Shared Worker 與普通的 Web Worker 類似,但不同之處在于它可以被多個瀏覽器窗口、標簽頁或者iframe共享,使得這些上下文之間能夠相互通信。下面是一個使用 Shared Worker 的完整示例。共享Worker腳本(sharedWorker.js) self.add…

isupper()方法——判斷字符串是否全由大寫字母組成

自學python如何成為大佬(目錄):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 語法參考 isupper()方法用于判斷字符串中所有的字母是否都是大寫。isupper()方法的語法格式如下: str.isupper() 如果字符串中包含至少…

我是如何在bytemd中實現自定義目錄的

介紹 接著上文說完,實現了在markdown編輯器中插入視頻的能力,接下來還需要繼續優化 markdown文檔的閱讀體驗,比如 再加個目錄 熟悉markdown語法的朋友可能會說,直接在編輯時添加 toc 標簽,可以在文章頂部自動生成目錄…

實驗三 時序邏輯電路實驗

仿真 鏈接:https://pan.baidu.com/s/1z9KFQANyNF5PvUPPYFQ9Ow 提取碼:e3md 一、實驗目的 1、通過實驗,理解觸發的概念,理解JK、D等常見觸發器的功能; 2、通過實驗,加深集成計數器功能的理解,掌…

?Ollama的本地安裝?

先來逛一下咱們的主角Ollama的官網地址: Ollama 大概長這個樣子🤔 因為本地系統的原因,文章只提供Widows的安裝方式,使用Linux和Mac的大佬,可以自行摸索🧐 下載完成后就是安裝了🍕&#xff0c…

一、Redis簡介

一、Redis介紹與一般應用 1.1 基本了解 Redis全稱Remote Dictionary Server(遠程字典服務), 是一個開源的高性能鍵值存儲系統,通常用作數據庫、緩存和消息代理。使用ANSI C語言編寫遵守BSD協議,是一個高性能的Key-Value數據庫提供了豐富的數…

JVM性能監控與調優:生產環境的實踐指南

JVM性能監控與調優:生產環境的實踐指南 一、引言 在生產環境中,Java應用程序的性能監控和調優是確保系統穩定運行、提升用戶體驗的關鍵環節。JVM(Java Virtual Machine)作為Java應用程序的運行環境,其性能直接影響到…

Flink 本地任務添加配置參數

Flink 本地任務添加配置參數 配置一個Configuration,然后通過StreamExecutionEnvironment.getExecutionEnvironment(configuration)傳入。 例如: Configuration configuration new Configuration();configuration.set(RestartStrategyOptions.RESTART_…

蘋果筆記本能玩網頁游戲嗎 蘋果電腦玩steam游戲怎么樣 蘋果手機可以玩游戲嗎 mac電腦安裝windows

蘋果筆記本有著優雅的機身、強大的性能,每次更新迭代都備受用戶青睞。但是,當需要使用蘋果筆記本進行游戲時,很多人會有疑問:蘋果筆記本能玩網頁游戲嗎?蘋果筆記本適合打游戲嗎?本文將討論這兩個話題&#…

6-14題連接 - 高頻 SQL 50 題基礎版

目錄 1. 相關知識點2. 例子2.6. 使用唯一標識碼替換員工ID2.7- 產品銷售分析 I2.8 - 進店卻未進行過交易的顧客2.9 - 上升的溫度2.10 - 每臺機器的進程平均運行時間2.11- 員工獎金2.12-學生們參加各科測試的次數2.13-至少有5名直接下屬的經理2.14 - 確認率 1. 相關知識點 left …