上一章中,我們介紹了神經網絡的學習,并通過數值微分計算了神經網絡的權重參數的梯度(嚴格來說,是損失函數關于權重參數的梯度)。數值微分雖然簡單,也容易實現,但缺點是計算上比較費時間。本章我們將學習一個能夠高效計算權重參數的梯度的方法——誤差反向傳播法。
正確理解誤差反向傳播法,我個人認為有兩種方法:一種是基于數學式;另一種是基于計算圖(computational graph)。
本章希望大家通過計算圖,直觀地理解誤差反向傳播法。
5.1 計算圖
5.1.1 用計算圖求解
問題1:太郎在超市買了2個100日元一個的蘋果,消費稅是10%,請計算支付金額。
可以將“2”和“1.1”分別作為變量“蘋果的個數”和“消費稅”標在○外面。
問題2:太郎在超市買了2個蘋果、3個橘子。其中,蘋果每個100日元,橘子每個150日元。消費稅是10%,請計算支付金額。
綜上,用計算圖解題的情況下,需要按如下流程進行。
1.構建計算圖。
2.在計算圖上,從左向右進行計算。
這里的第2歩“從左向右進行計算”是一種正方向上的傳播,簡稱為正向傳播(forward propagation)。正向傳播是從計算圖出發點到結束點的傳播。 既然有正向傳播這個名稱,當然也可以考慮反向(從圖上看的話,就是從右向左)的傳播。實際上,這種傳播稱為反向傳播(backward propagation)。反向傳播將在接下來的導數計算中發揮重要作用。
5.1.2 局部計算
局部計算是指,無論全局發生了什么,都能只根據與自己相關的信息輸出接下來的結果。
我們用一個具體的例子來說明局部計算。比如,在超市買了2個蘋果和其他很多東西
這里的重點是,各個節點處的計算都是局部計算。這意味著,例如蘋果和其他很多東西的求和運算(4000 + 200 → 4200)并不關心4000這個數字是如何計算而來的,只要把兩個數字相加就可以了。換言之,各個節點處只需進行與自己有關的計算(在這個例子中是對輸入的兩個數字進行加法運算),不用考慮全局
5.1.3 為何用計算圖解題
那么計算圖到底有什么優點呢?
一個優點就在于前面所說的局部計算。無論全局是多么復雜的計算,都可以通過局部計算使各個節點致力于簡單的計算,從而簡化問題。另一個優點是,利用計算圖可以將中間的計算結果全部保存起來(比如,計算進行到2個蘋果時的金額是200日元、加上消費稅之前的金額650日元等)。但是只有這些理由可能還無法令人信服。實際上,使用計算圖最大的原因是,可以通過反向傳播高效計算導數。
這里,假設我們想知道蘋果價格的上漲會在多大程度上影響最終的支付金額,即求“支付金額關于蘋果的價格的導數”。設蘋果的價格為x,支付金額為L,則相當于求
。這個導數的值表示當蘋果的價格稍微上漲時,支付金額會增加多少。
反向傳播使用與正方向相反的箭頭(粗線)表示。反向傳播傳遞“局部導數”,將導數的值寫在箭頭的下方。在這個例子中,反向傳播從右向左傳遞導數的值(1 → 1.1 → 2.2)。從這個結果中可知,“支付金額關于蘋果的價格的導數”的值是2.2。這意味著,如果蘋果的價格上漲1日元,最終的支付金額會增加2.2日元(嚴格地講,如果蘋果的價格增加某個微小值,則最終的支付金額將增加那個微小值的2.2倍)。
5.2 鏈式法則
5.2.1 計算圖的反向傳播
讓我們先來看一個使用計算圖的反向傳播的例子。假設存在y = f(x)的計算,這個計算的反向傳播如圖5-6所示。
如圖所示,反向傳播的計算順序是,將信號E乘以節點的局部導數
,然后將結果傳遞給下一個節點。這里所說的局部導數是指正向傳播中y = f(x)的導數,也就是y關于x的導數
。比如,假設y = f(x) =
, 則局部導數為
= 2x。把這個局部導數乘以上游傳過來的值(本例中為E),然后傳遞給前面的節點。
5.2.2 什么是鏈式法則
介紹鏈式法則時,我們需要先從復合函數說起。復合函數是由多個函數構成的函數。比如,z = (x + y) 2 是由式(5.1)所示的兩個式子構成的。
鏈式法則是關于復合函數的導數的性質,定義如下。
如果某個函數由復合函數表示,則該復合函數的導數可以用構成復合函數的各個函數的導數的乘積表示。
這就是鏈式法則的原理,乍一看可能比較難理解,但實際上它是一個非常簡單的性質。以式(5.1)為例,
(z關于x的導數)可以用
(z關于t 的導數)和
(t關于x的導數)的乘積表示。用數學式表示的話,可以寫成式(5.2)。
式(5.2)中的? t正好可以像下面這樣“互相抵消”,所以記起來很簡單
所以最后要計算的結果
5.2.3 鏈式法則和計算圖
現在我們嘗試將式(5.4)的鏈式法則的計算用計算圖表示出來。如果用“**2”節點表示平方運算的話,則計算圖如圖5-7所示。
根據鏈式法則,
成立,對應“z關于x的導數”。也就是說,反向傳播是基于鏈式法則的。
5.3 反向傳播
5.3.1 加法節點的反向傳播
這里以z = x + y為對象,觀察它的反向傳播。z = x + y的導數可由下式(解析性地)計算出來。
在圖5-9中,反向傳播將從上游傳過來的導數(本例中是
)乘以1,然后傳向下游。也就是說,因為加法節點的反向傳播只乘以1,所以輸入的值會原封不動地流向下一個節點。
另外,本例中把從上游傳過來的導數的值設為
。這是因為,如圖5-10 所示,我們假定了一個最終輸出值為L的大型計算圖。
的計算位于這個大型計算圖的某個地方,從上游會傳來
,并向下游傳遞
和
現在來看一個加法的反向傳播的具體例子。假設有“10 + 5=15”這一計算,反向傳播時,從上游會傳來值1.3。用計算圖表示的話,如圖5-11所示。
5.3.2 乘法節點的反向傳播
這里我們考慮z = xy。這個式子的導數用式(5.6)表示。
乘法的反向傳播會將上游的值乘以正向傳播時的輸入信號的“翻轉值”后傳遞給下游。翻轉值表示一種翻轉關系,如圖5-12所示,正向傳播時信號是x的話,反向傳播時則是y;正向傳播時信號是y的話,反向傳播時則是x。
現在我們來看一個具體的例子。比如,假設有“10 × 5 = 50”這一計算,反向傳播時,從上游會傳來值1.3。用計算圖表示的話,如圖5-13所示。
因為乘法的反向傳播會乘以輸入信號的翻轉值,所以各自可按1.3 × 5 = 6.5、1.3 × 10 = 13計算。另外,加法的反向傳播只是將上游的值傳給下游,并不需要正向傳播的輸入信號。但是,乘法的反向傳播需要正向傳播時的輸入信號值。因此,實現乘法節點的反向傳播時,要保存正向傳播的輸入信號。
5.3.3 蘋果的例子
蘋果的例子(2個蘋果和消費稅)。這里要解的問題是蘋果的價格、蘋果的個數、消費稅這3個變量各自如何影響最終支付的金額。這個問題相當于求“支付金額關于蘋果的價格的導數”“支付金額關于蘋果的個數的導數”“支付金額關于消費稅的導數”。用計算圖的反向傳播來解的話,求解過程如圖5-14所示。
結果可知,蘋果的價格的導數是2.2,蘋果的個數的導數是110,消費稅的導數是200。這可以解釋為,如果消費稅和蘋果的價格增加相同的值,則消費稅將對最終價格產生200倍大小的影響,蘋果的價格將產生2.2倍大小的影響。