深度學習-10-測試
本文是《深度學習入門2-自製框架》 的學習筆記,記錄自己學習心得,以及對重點知識的理解。如果內容對你有幫助,請支持正版,去購買正版書籍,支持正版書籍不僅是尊重作者的辛勤勞動,也是鼓勵更多優秀作品問世。
當前筆記內容主要為:步驟 10 ?測試 章節的相關理解。
書籍總共分為5個階段,每個階段分很多步驟,最終是一步一步實現一個深度學習框架。例如前兩個階段為:
第 1 階段共包括 10 個步驟 。 在這個階段,將創建自動微分的機制
第 2 階段,從步驟11-24,該階段的主要目標是擴展當前的 DeZero ,使它能夠執行更復雜的計算 ,使它能 夠處理接收多個輸入的函數和返回多個輸出的函數
1.Python 的單元測試
軟件開發中測試必不可少,有時候測試都會占用項目流程中很大一段時間。為了保證項目質量,更是要求測試進行相關自動化,以便加速。而且分為sit uat 測試不同階段,來保證投產質量。
不同的變成語言,有不同測試框架,例如java 里面有junit 框架支持。python 語言里面有 unittest 庫來支持。 這里我們以unittest 庫來說明。
編寫代碼
class SquareTest(unittest.TestCase):def test_forward(self):x = Variable(np.array(2.0))y = square(x)expected = np.array(4.0)self.assertEquals(y.data, expected)
執行命令運行測試:
python -m unittest step10.py
注意如果你用的是創建了虛擬venv ,則需要先激活此環境,然后再執行命令
查看輸出結果:
(venv) PS C:\pyworkspace\Dezero> python -m unittest step10.py ?
C:\pyworkspace\Dezero\step10.py:97: DeprecationWarning: Please use assertEqual instead. ? ? ? ?self.assertEquals(y.data, expected)
.
----------------------------------------------------------------------
Ran 1 test in 0.002sOK
我們可以看到測試通過了,并且有匯總信息。這個測試案例是測試-平方函數,我們知道 2的平方等于4 ,結果確實等于4。
2.square 函數反向傳播的測試
對square 函數進行反向傳播測試, 增加一下代碼:
class SquareTest(unittest.TestCase):def test_forward(self):x = Variable(np.array(2.0))y = square(x)expected = np.array(4.0)self.assertEquals(y.data, expected)def test_backward(self):x = Variable(np.array(3.0))y = square(x)y.backward()expected = np.array(6.0)self.assertEquals(x.grad, expected)
其中 test_backward 函數是本次新加的代碼?? ??? ?
查看測試結果:
(venv) PS C:\pyworkspace\Dezero>
(venv) PS C:\pyworkspace\Dezero> python -m unittest step10.py
C:\pyworkspace\Dezero\step10.py:104: DeprecationWarning: Please use assertEqual instead.self.assertEquals(x.grad, expected) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
.. ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
----------------------------------------------------------------------
Ran 2 tests in 0.002s ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??OK ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
(venv) PS C:\pyworkspace\Dezero>?
結果正向傳播,反向傳播的兩個測試案例都通過了。y=x^2 的導函數是 y=2x 在x = 3.0 的時候,導函數的值為 2*3.0 = 6.0 正確。
3.通過梯度檢驗來自動測試
在上面的第二步驟中,我們是手動定義x = 3.0 ?并且我們手動求導發現導函數是 2x ,并且求得值是 ?6.0 ,這一步驟是否可以自動化呢?
這里引入一個方法:梯度檢驗 ,代替手動計算的測試方法。達到高效測試的目的。
# 求導公式計算任意函數倒數
def numberical_diff(f, x, eps= 13-4) :x0= Variable(x.data -eps)x1 = Variable(x.data + eps)y0 = f(x0)y1 = f(x1)return (y1.data -y0.data) /(2*eps)class SquareTest(unittest.TestCase):def test_forward(self):x = Variable(np.array(2.0))y = square(x)expected = np.array(4.0)self.assertEquals(y.data, expected)def test_backward(self):x = Variable(np.array(3.0))y = square(x)y.backward()expected = np.array(6.0)self.assertEquals(x.grad, expected)def test_gradient(self):x = Variable(np.random.random(1)) # 隨機生成x 值y = square(x)y.backward()num_grad = numberical_diff(square, x)flg = np.allclose(x.grad, num_grad) #判斷 ndarray 實例的a,b 值是否接近#如果 a 和 b 的所有元素滿足以 下條件,則返回 Trueself.assertTrue(flg)
再次執行測試案例:
python -m unittest step10.py?
查看執行結果:
(venv) PS C:\pyworkspace\Dezero> python -m unittest step10.py
C:\pyworkspace\Dezero\step10.py:111: DeprecationWarning: Please use assertEqual instead.self.assertEquals(x.grad, expected) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
... ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
----------------------------------------------------------------------
Ran 3 tests in 0.002s ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??OK ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
(venv) PS C:\pyworkspace\Dezero>?
4.本節所有代碼
'''
step10.py
測試,使用unittest 庫自動進行測試
'''import numpy as np
import unittestclass Variable:def __init__(self, data):if data is not None: # 新增if not isinstance(data, np.ndarray):raise TypeError('{} is not supported'.format(type(data)))self.data = dataself.grad = Noneself.creator = Nonedef set_creator(self, func):self.creator = funcdef backward(self):if self.grad is None:self.grad = np.ones_like(self.data)funcs = [self.creator]while funcs:f = funcs.pop()x, y = f.input, f.outputx.grad = f.backward(y.grad)if x.creator is not None:funcs.append(x.creator)class Function:def __call__(self, input):x = input.datay = self.forward(x) # 新增output = Variable(as_array(y)) # 轉成 ndarray 類型output.set_creator(self) # 輸出者保存創造者對象self.input = inputself.output = output # 保存輸出者。我是創造者的信息,這是動態建立 "連接"這 一 機制的核心return outputdef forward(self, x):raise NotImplementedError() # 使用Function 這個方法forward 方法的人 , 這個方法應該通過繼承采實現def backward(self, gy):raise NotImplementedError()class Square(Function):def forward(self, x):y = x ** 2return ydef backward(self, gy):x = self.input.datagx = 2 * x * gy # 方法的參數 gy 是 一個 ndarray 實例 , 它是從輸出傳播而來的導數 。return gxclass Exp(Function):def forward(self, x):y = np.exp(x)return ydef backward(self, gy):x = self.input.datagx = np.exp(x) * gyreturn gxdef square(x):f = Square()return f(x)def exp(x):f = Exp()return f(x)def as_array(x): # 新增if np.isscalar(x): # 使用 np.isscalar 函數來檢查 numpy.float64 等屬于標量return np.array(x)return x# 求導公式計算任意函數倒數
def numberical_diff(f, x, eps= 13-4) :x0= Variable(x.data -eps)x1 = Variable(x.data + eps)y0 = f(x0)y1 = f(x1)return (y1.data -y0.data) /(2*eps)class SquareTest(unittest.TestCase):def test_forward(self):x = Variable(np.array(2.0))y = square(x)expected = np.array(4.0)self.assertEquals(y.data, expected)def test_backward(self):x = Variable(np.array(3.0))y = square(x)y.backward()expected = np.array(6.0)self.assertEquals(x.grad, expected)def test_gradient(self):x = Variable(np.random.random(1))y = square(x)y.backward()num_grad = numberical_diff(square, x)flg = np.allclose(x.grad, num_grad) #判斷 ndarray 實例的a,b 值是否接近#如果 a 和 b 的所有元素滿足以 下條件,則返回 Trueself.assertTrue(flg)if __name__ == '__main__':x = Variable(np.array(0.5))a = square(x)b = exp(a)y = square(b)y.grad = np.array(1.0)y.backward()print(x.grad)# 優化ones_like 初始化后# 不需要定義 y.grad = np.array(1.0) 這個了x = Variable(np.array(0.5))y = square(exp(square(x)))y.backward()print(x.grad)# 錯誤使用x = Variable(np.array(1.0))x = Variable(None)# x = Variable(1.0) # 錯誤使用# Numpy 特性問題x = np.array([1.0])y = x ** 2print(type(x), x.ndim)print(type(y))x = np.array(1.0)y = x ** 2print(type(x), x.ndim)print(type(y))
5.測試小結
通過本節,可以學習如果使用 unittest 這個庫進行代碼測試。