動手實現一個帶自動微分的深度學習框架

動手實現一個帶自動微分的深度學習框架

轉自:Automatic Differentiation Tutorial

參考代碼:https://github.com/borgwang/tinynn-autograd (主要看 core/tensor.py 和 core/ops.py)

在這里插入圖片描述

目錄

  • 簡介
  • 自動求導設計
  • 自動求導實現
  • 一個例子
  • 總結
  • 參考資料

簡介

梯度下降(Gradient Descent)及其衍生算法是神經網絡訓練的基礎,梯度下降本質上就是求解損失關于網絡參數的梯度,不斷計算這個梯度對網絡參數進行更新。現代的神經網絡框架都實現了自動求導的功能,只需要要定義好網絡前向計算的邏輯,在運算時自動求導模塊就會自動把梯度算好,不用自己手寫求導梯度。

筆者在之前的 一篇文章 中講解和實現了一個迷你的神經網絡框架 tinynn,在 tinynn 中我們定義了網絡層 layer 的概念,整個網絡是由一層層的 layer 疊起來的(全連接層、卷積層、激活函數層、Pooling 層等等),如下圖所示

在這里插入圖片描述

在實現的時候需要顯示為每層定義好前向 forward 和反向 backward(梯度計算)的計算邏輯。從本質上看 這些 layer 其實是一組基礎算子的組合,而這些基礎算子(加減乘除、矩陣變換等等)的導函數本身都比較簡單,如果能夠將這些基礎算子的導函數寫好,同時把不同算子之間連接邏輯記錄(計算依賴圖)下來,那么這個時候就不再需要自己寫反向了,只需要計算損失,然后從損失函數開始,讓梯度自己用預先定義好的導函數,沿著計算圖反向流動即可以得到參數的梯度,這個就是自動求導的核心思想。tinynn 中之所有 layer 這個概念,一方面是符合我們直覺上的理解,另一方面是為了在沒有自動求導的情況下方便實現。有了自動求導,我們可以拋開 layer 這個概念,神經網絡的訓練可以抽象為定義好一個網絡的計算圖,然后讓數據前向流動,讓梯度自動反向流動( TensorFlow 這個名字起得相當有水準)。

我們可以看看 PyTorch 的一小段核心的訓練代碼(來源官方文檔 MNIST 例子)

for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()  # 初始化梯度output = model(data)  # 從 data 到 output 的計算圖loss = F.nll_loss(output, target) # 從 output 到 loss 的計算圖loss.backward()  # 梯度從 loss 開始反向流動optimizer.step()  # 使用梯度對參數更新

可以看到 PyTorch 的基本思路和我們上面描述的是一致的,定義好計算圖 -> forward 得到損失 -> 梯度反向流動。

自動求導設計

知道了自動求導的基本流程之后,我們考慮如何來實現。先考慮沒有自動求導,為每個運算手動寫 backward 的情況,在這種情況下我們實際上定義了兩個計算圖,一個前向一個反向,考慮最簡單的線性回歸的運算 WX+BWX+BWX+B,其計如下所示。

在這里插入圖片描述

可以看到這兩個計算圖的結構實際上是一樣的,只是在前向流動的是計算的中間結果,反向流動的是梯度,以及中間的運算反向的時候是導數運算。實際上我們可以把兩者結合到一起,只定義一次前向計算圖,讓反向計算圖自動生成

在這里插入圖片描述

從實現的角度看,如果我們不需要自動求導,那么網絡框架中的 Tensor 類只需要對 Tensor 運算符有定義,能夠進行數值運算(tinynn 中就簡單的使用 ndarray 作為 Tensor 的實現)。但如果要實現自動求導,那么 Tensor 類需要額外做幾件事:

  1. 增加一個梯度的變量保存當前 tensor 的梯度
  2. 保存當前 tensor 依賴的 tensor(如上圖中 O1O1 依賴于 X,WX,W)
  3. 保存下對各個依賴 tensor 的導函數(這個導函數的作用是將當前 tensor 的梯度傳到依賴的 tensor 上)

自動求導實現

我們按照上面的分析開始實現 Tensor 類如下,初始化方法中首先把 tensor 的值保存下來,然后有一個 requires_grad 的 bool 變量表明這個 tensor 是不是需要求梯度,還有一個 dependency 的列表用于保存該 tensor 依賴的 tensor 以及對于他們的導函數。

zero_grad() 方法比較簡單,將當前 tensor 的梯度設置為 0,防止梯度的累加。自動求導從調用計算圖的最后一個節點 tensor 的 backward() 方法開始(在神經網絡中這個節點一般是 loss)。backward() 方法主要流程為

  • 確保改 tensor 確實需要求導 self.requires_grad == True
  • 將從上個 tensor 傳進來的梯度加到自身梯度上,如果沒有(反向求導的起點 tensor),則將梯度初始化為 1.0
  • 對每一個依賴的 tensor 運行保存下來的導函數,計算傳播到依賴 tensor 的梯度,然后調用依賴 tensor 的 backward() 方法。可以看到這其實就是 Depth-First Search 計算圖的節點
def as_tensor(obj):if not isinstance(obj, Tensor):obj = Tensor(obj)return objclass Tensor:def __init__(self, values, requires_grad=False, dependency=None):self._values = np.array(values)self.shape = self.values.shapeself.grad = Noneif requires_grad:self.zero_grad()self.requires_grad = requires_gradif dependency is None:dependency = []self.dependency = dependency@propertydef values(self):return self._values@values.setterdef values(self, new_values):self._values = np.array(new_values)self.grad = Nonedef zero_grad(self):self.grad = np.zeros(self.shape)def backward(self, grad=None):assert self.requires_grad, "Call backward() on a non-requires-grad tensor."grad = 1.0 if grad is None else gradgrad = np.array(grad)# accumulate gradientself.grad += grad# propagate the gradient to its dependenciesfor dep in self.dependency:grad_for_dep = dep["grad_fn"](grad)dep["tensor"].backward(grad_for_dep)

可能看到這里讀者可能會疑問,一個 tensor 依賴的 tensor 和對他們的導函數(也就是 dependency 里面的東西)從哪里來?似乎沒有哪一個方法在做保存依賴這件事。

假設我們可能會這樣使用我們的 Tensor 類

W = Tensor([[1], [3]], requires_grad=True)  # 2x1 tensor
X = Tensor([[1, 2], [3, 4], [5, 6], [7, 8]], requires_grad=True)  # 4x2 tensor
O = X @ W  # suppose to be a 4x1 tensor

如何讓 XW 完成矩陣乘法輸出正確的 O 的同時,讓 O 能記下他依賴于 WX 呢?答案是重載運算符

class Tensor:# ...def __matmul__(self, other):# 1. calculate forward valuesvalues = self.values @ other.values# 2. if output tensor requires_gradrequires_grad = ts1.requires_grad or ts2.requires_grad# 3. build dependency listdependency = []if self.requires_grad:# O = X @ W# D_O / D_X = grad @ W.Tdef grad_fn1(grad):return grad @ other.values.Tdependency.append(dict(tensor=self, grad_fn=grad_fn1))if other.requires_grad:# O = X @ W# D_O / D_W = X.T @ graddef grad_fn2(grad):return self.values.T @ graddependency.append(dict(tensor=other, grad_fn=grad_fn2))return Tensor(values, requires_grad, dependency)# ...

關于 Python 中如何重載運算符這里不展開,讀者有興趣可以參考官方文檔或者這篇文章。基本上在 Tensor 類內定義了 __matmul__ 這個方法后,實際上是重載了矩陣乘法運算符 @ (Python 3.5 以上支持) 。當運行 X @ W 時會自動調用 X__matmul__ 方法。

這個方法里面做了三件事:

  1. 計算矩陣乘法結果(這個是必須的)

  2. 確定是否需要新生成的 tensor 是否需要梯度,這個由兩個操作數決定。比如在這個例子中,如果 W 或者 X 需要梯度,那么生成的 O也是需要計算梯度的(這樣才能夠計算 W 或者 X 的梯度)

  3. 建立 tensor 的依賴列表

    自動求導中最關鍵的部分就是在這里了,還是以 O = X @ W 為例子,這里我們會先檢查是否 X需要計算梯度,如果需要,我們需要把導函數 D_O / D_X 定義好,保存下來;同樣的如果 W 需要梯度,我們將 D_O / D_W 定義好保存下來。最后生成一個 dependency 列表保存著在新生成的 tensor O 中。

然后我們再回顧前面講的 backward()方法,backward() 方法會遍歷 tensor 的 dependency ,將用保存的 grad_fn 計算要傳給依賴 tensor 的梯度,然后調用依賴 tensor 的 backward() 方法將梯度傳遞下去,從而實現了梯度在整個計算圖的流動。

grad_for_dep = dep["grad_fn"](grad)
dep["tensor"].backward(grad_for_dep)

自動求導講到這里其實已經基本沒有什么新東西,剩下的工作就是以類似的方法大量地重載各種各樣的運算符,使其能夠 cover 住大部分所需要的操作(基本上照著 NumPy 的接口都給重載一次就差不多了 🤨)。無論你定義了多復雜的運算,只要重載了相關的運算符,就都能夠自動求導了,再也不用自己寫梯度了。

在這里插入圖片描述

一個例子

大量的重載運算符的工作在文章里就不貼上來了(過程不怎么有趣),我寫在了一個 notebook 上,大家有興趣可以去看看 borgwang/toys/ml-autograd。在這個 notebook 里面重載了實現一個簡單的線性回歸需要的幾種運算符,以及一個線性回歸的例子。這里把例子和結果貼上來

# training data
x = Tensor(np.random.normal(0, 1.0, (100, 3)))
coef = Tensor(np.random.randint(0, 10, (3,)))
y = x * coef - 3params = {"w": Tensor(np.random.normal(0, 1.0, (3, 3)), requires_grad=True),"b": Tensor(np.random.normal(0, 1.0, 3), requires_grad=True)
}learng_rate = 3e-4
loss_list = []
for e in range(101):# set gradient to zerofor param in params.values():param.zero_grad()# forwardpredicted = x @ params["w"] + params["b"]err = predicted - yloss = (err * err).sum()# backward automaticallyloss.backward()# updata parameters (gradient descent)for param in params.values():param -= learng_rate * param.gradloss_list.append(loss.values)if e % 10 == 0:print("epoch-%i \tloss: %.4f" % (e, loss.values))
epoch-0 	loss: 8976.9821
epoch-10 	loss: 2747.4262
epoch-20 	loss: 871.4415
epoch-30 	loss: 284.9750
epoch-40 	loss: 95.7080
epoch-50 	loss: 32.9175
epoch-60 	loss: 11.5687
epoch-70 	loss: 4.1467
epoch-80 	loss: 1.5132
epoch-90 	loss: 0.5611
epoch-100 	loss: 0.2111

在這里插入圖片描述

接口和 PyTorch 相似,在每個循環里面首先將參數梯度設為 0 ,然后定義計算圖,然后從 loss 開始反向傳播,最后更新參數。從結果可以看到 loss 隨著訓練進行非常漂亮地下降,說明我們的自動求導按照我們的設想 work 了。

總結

本文實現了討論了自動求導的設計思路和整個過程是怎么運作的。總結起來:自動求導就是在定義了一個有狀態的計算圖,該計算圖上的節點不僅保存了節點的前向運算,還保存了反向計算所需的上下文信息。利用上下文信息,通過圖遍歷讓梯度在圖中流動,實現自動求節點梯度。

我們通過重載運算符實現了一個支持自動求導的 Tensor 類,用一個簡單的線性回歸 demo 測試了自動求導。當然這只是最基本的能實現自動求導功能的 demo,從實現的角度上看還有很多需要優化的地方(內存開銷、運算速度等),筆者有空會繼續深入研究,讀者如果有興趣也可以自行查閱相關資料。Peace out. 🤘

參考資料

  • PyTorch Doc
  • PyTorch Autograd Explained - In-depth Tutorial
  • joelgrus/autograd
  • Automatic Differentiation in Machine Learning: a Survey

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532393.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532393.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532393.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

git安裝后找不見版本_結果發現git版本為1.7.4,(git --version)而官方提示必須是1.7.10及以后版本...

結果發現git版本為1.7.4,(git --version)而官方提示必須是1.7.10及以后版本升級增加ppasudo apt-add-repository ppa:git-core/ppasudo apt-get updatesudo apt-get install git如果本地已經安裝過Git,可以使用升級命令:sudo apt-get dist-upgradeapt命令…

隨機數生成算法:K進制逐位生成+拒絕采樣

隨機數生成算法:K進制逐位生成拒絕采樣 轉自:【宮水三葉】k 進制諸位生成 拒絕采樣 基本分析 給定一個隨機生成 1 ~ 7 的函數,要求實現等概率返回 1 ~ 10 的函數。 首先需要知道,在輸出域上進行定量整體偏移,仍然滿…

深入理解NLP Subword算法:BPE、WordPiece、ULM

深入理解NLP Subword算法:BPE、WordPiece、ULM 本文首發于微信公眾號【AI充電站】,感謝大家的贊同、收藏和轉發(▽) 轉自:深入理解NLP Subword算法:BPE、WordPiece、ULM 前言 Subword算法如今已經成為了一個重要的NLP模型性能提升…

http 錯誤 404.0 - not found_電腦Regsvr32 用法和錯誤消息的說明

? 對于那些可以自行注冊的對象鏈接和嵌入 (OLE) 控件,例如動態鏈接庫 (DLL) 文件或 ActiveX 控件 (OCX) 文件,您可以使用 Regsvr32 工具 (Regsvr32.exe) 來將它們注冊和取消注冊。Regsvr32.exe 的用法RegSvr32.exe 具有以下命令行選項: Regs…

mysql error 1449_MySql錯誤:ERROR 1449 (HY000)

筆者系統為 mac ,不知怎的,Mysql 竟然報如下錯誤:ERROR 1449 (HY000): The user specified as a definer (mysql.infoschemalocalhost) does not exist一時沒有找到是什么操作導致的這個錯誤。然后經過查詢,參考文章解決了問題。登…

MobileNet 系列:從V1到V3

MobileNet 系列:從V1到V3 轉自:輕量級神經網絡“巡禮”(二)—— MobileNet,從V1到V3 自從2017年由谷歌公司提出,MobileNet可謂是輕量級網絡中的Inception,經歷了一代又一代的更新。成為了學習輕…

mysql 查詢表的key_mysql查詢表和字段的注釋

1,新建表以及添加表和字段的注釋.create table auth_user(ID INT(19) primary key auto_increment comment 主鍵,NAME VARCHAR(300) comment 姓名,CREATE_TIME date comment 創建時間)comment 用戶信息表;2,修改表/字段的注釋.alter table auth_user comment 修改后的表注…

mysql 高級知識點_這是我見過最全的《MySQL筆記》,涵蓋MySQL所有高級知識點!...

作為運維和編程人員,對MySQL一定不會陌生,尤其是互聯網行業,對MySQL的使用是比較多的。MySQL 作為主流的數據庫,是各大廠面試官百問不厭的知識點,但是需要了解到什么程度呢?僅僅停留在 建庫、創表、增刪查改…

teechart mysql_TeeChart 的應用

TeeChart 是一個很棒的繪圖控件,不過由于里面沒有注釋,網上相關的資料也很少,所以在應用的時候只能是一點點的試。為了防止以后用到的時候忘記,我就把自己用到的東西都記錄下來,以便以后使用的時候查詢。1、進制縮放圖…

NLP新寵——淺談Prompt的前世今生

NLP新寵——淺談Prompt的前世今生 轉自:NLP新寵——淺談Prompt的前世今生 作者:閔映乾,中國人民大學信息學院碩士,目前研究方向為自然語言處理。 《Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in…

mysql key_len_淺談mysql explain中key_len的計算方法

mysql的explain命令可以分析sql的性能,其中有一項是key_len(索引的長度)的統計。本文將分析mysql explain中key_len的計算方法。1、創建測試表及數據CREATE TABLE member (id int(10) unsigned NOT NULL AUTO_INCREMENT,name varchar(20) DEFAULT NULL,age tinyint(…

requestfacade 這個是什么類?_Java 的大 Class 到底是什么?

作者在之前工作中,面試過很多求職者,發現有很多面試者對Java的 Class 搞不明白,理解的不到位,一知半解,一到用的時候,就不太會用。想寫一篇關于Java Class 的文章,沒有那么多專業名詞&#xff0…

初學機器學習:直觀解讀KL散度的數學概念

初學機器學習:直觀解讀KL散度的數學概念 轉自:初學機器學習:直觀解讀KL散度的數學概念 譯自:https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8 解讀…

php mysql讀取數據查詢_PHP MySQL 讀取數據

PHP MySQL 讀取數據從 MySQL 數據庫讀取數據SELECT 語句用于從數據表中讀取數據:SELECT column_name(s) FROM table_name我們可以使用 * 號來讀取所有數據表中的字段:SELECT * FROM table_name如需學習更多關于 SQL 的知識,請訪問我們的 SQL 教程。使用 …

MySQL應用安裝_mysql安裝和應用

1.下載mysql安裝包2.安裝mysql,自定義->修改路徑3.配置mysql,選擇自定義->server模式->500訪問量->勾選控制臺->設置gbk->設置密碼和允許root用戶遠程登錄等等。以管理員權限,在控制臺輸入:net start MySQL, 啟…

mysql 商品規格表_商品規格分析

產品表每次更新商品都會變動的,ID不能用,可是購物車還是用了,這就導致每次保存商品,哪怕什么都沒有改動,也會導致用戶的購物車失效。~~~其實可以考慮不是每次更新商品就除所有的SKU,畢竟有時什么都沒修改呢…

mysql維表的代理鍵字段_mysql多維數據倉庫指南--第三篇第12章(2)

賓夕法尼亞州地區客戶維在本節我將用賓夕法尼亞州地區客戶的子集維度來解釋第二種維度子集的類型。我也將向你說明如何測試該子集維度。相對的,一個向上鉆取的維包含了它基礎維的所有更高級別的數據。而一個特定子集維度則選擇了它基礎維的某個特定的數據集合。列表…

huggingface NLP工具包教程1:Transformers模型

huggingface NLP工具包教程1:Transformers模型 原文:TRANSFORMER MODELS 本課程會通過 Hugging Face 生態系統中的一些工具包,包括 Transformers, Datasets, Tokenizers, Accelerate 和 Hugging Face Hub。…

mysql日期比較timestamp_Mysql中的Datetime和Timestamp比較(轉載)

mysql中用于表示時間的三種類型date, datetime, timestamp (如果算上int的話,四種) 比較容易混淆,下面就比較一下這三種類型的異同相同點都可以用于表示時間都呈字符串顯示不同點1.顧名思義,date只表示YYYY-MM-DD形式的日期,datet…

隱馬爾可夫模型HMM推導

隱馬爾可夫模型HMM推導 機器學習-白板推導系列(十四)-隱馬爾可夫模型HMM(Hidden Markov Model) 課程筆記 背景介紹 介紹一下頻率派和貝葉斯派兩大流派發展出的建模方式。 頻率派 頻率派逐漸發展成了統計機器學習,該流派通常將任務建模為一…