機器學習中qa測試_如何對機器學習做單元測試

作者:Chase Roberts

編譯:ronghuaiyang

導讀

養成良好的單元測試的習慣,真的是受益終身的,特別是機器學習代碼,有些bug真不是看看就能看出來的。

9ad1a64e1939284f73524ee5c01a7131.png

在過去的一年里,我把大部分的工作時間都花在了深度學習研究和實習上。那一年,我犯了很多大錯誤,這些錯誤不僅幫助我了解了ML,還幫助我了解了如何正確而穩健地設計這些系統。我在谷歌Brain學到的一個主要原則是,單元測試可以決定算法的成敗,可以為你節省數周的調試和訓練時間。

然而,在如何為神經網絡代碼編寫單元測試方面,似乎沒有一個可靠的在線教程。即使是像OpenAI這樣的地方,也只是通過盯著他們代碼的每一行,并試著思考為什么它會導致bug來發現bug的。顯然,我們大多數人都沒有這樣的時間,所以希望本教程能夠幫助你開始理智地測試你的系統!

讓我們從一個簡單的例子開始。試著找出這段代碼中的錯誤。

def make_convnet(input_image):    net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")    net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")    net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')    net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")    net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool2')    net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool3')    net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")    return net

你看到了嗎?網絡實際上并沒有堆積起來。在編寫這段代碼時,我復制并粘貼了slim.conv2d(…)行,并且只修改了內核大小,而沒有修改實際的輸入。

我很不好意思地說,這件事在一周前就發生在我身上了……但這是很重要的一課!由于一些原因,這些bug很難捕獲。

  1. 這段代碼不會崩潰,不會產生錯誤,甚至不會變慢。
  2. 這個網絡仍在運行,損失仍將下降。
  3. 幾個小時后,這些值就會收斂,但結果卻非常糟糕,讓你摸不著頭腦,不知道需要修復什么。

當你唯一的反饋是最終的驗證錯誤時,你惟一需要搜索的地方就是你的整個網絡體系結構。不用說,你需要一個更好的系統。

那么,在我們進行完整的多日訓練之前,我們如何真正抓住這個機會呢?關于這個最容易注意到的是層的值實際上不會到達函數外的任何其他張量。假設我們有某種類型的損失和一個優化器,這些張量永遠不會得到優化,所以它們總是有它們的默認值。

我們可以通過簡單的訓練步驟和前后對比來檢測它。

def test_convnet():  image = tf.placeholder(tf.float32, (None, 100, 100, 3)  model = Model(image)  sess = tf.Session()  sess.run(tf.global_variables_initializer())  before = sess.run(tf.trainable_variables())  _ = sess.run(model.train, feed_dict={               image: np.ones((1, 100, 100, 3)),               })  after = sess.run(tf.trainable_variables())  for b, a, n in zip(before, after):      # Make sure something changed.      assert (b != a).any()

在不到15行代碼中,我們現在驗證了至少我們創建的所有變量都得到了訓練。

這個測試超級簡單,超級有用。假設我們修復了前面的問題,現在我們要開始添加一些批歸一化。看看你能否發現這個bug。

  def make_convnet(image_input):        # Try to normalize the input before convoluting        net = slim.batch_norm(image_input)        net = slim.conv2d(net, 32, [11, 11], scope="conv1_11x11")        net = slim.conv2d(net, 64, [5, 5], scope="conv2_5x5")        net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')        net = slim.conv2d(net, 64, [5, 5], scope="conv3_5x5")        net = slim.conv2d(net, 128, [3, 3], scope="conv4_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool2')        net = slim.conv2d(net, 128, [3, 3], scope="conv5_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool3')        net = slim.conv2d(net, 32, [1, 1], scope="conv6_1x1")        return net

你看到了嗎?這個非常微妙。您可以看到,在tensorflow batch_norm中,is_training的默認值是False,所以添加這行代碼并不能使你在訓練期間的輸入正常化!值得慶幸的是,我們編寫的最后一個單元測試將立即發現這個問題!(我知道,因為這是三天前發生在我身上的事。)

再看一個例子。這實際上來自我一天看到的一篇文章(https://www.reddit.com/r/MachineLearning/comments/6qyvvg/p_tensorflow_response_is_making_no_sense/)。我不會講太多細節,但是基本上這個人想要創建一個輸出范圍為(0,1)的分類器。

class Model:  def __init__(self, input, labels):    """Classifier model    Args:      input: Input tensor of size (None, input_dims)      label: Label tensor of size (None, 1).         Should be of type tf.int32.    """    prediction = self.make_network(input)    # Prediction size is (None, 1).    self.loss = tf.nn.softmax_cross_entropy_with_logits(        logits=prediction, labels=labels)    self.train_op = tf.train.AdamOptimizer().minimize(self.loss)

注意到這個錯誤嗎?這是真的很難提前發現,并可能導致超級混亂的結果。基本上,這里發生的是預測只有一個輸出,當你將softmax交叉熵應用到它上時,它的損失總是0。

一個簡單的測試方法是確保損失不為0。

def test_loss():  in_tensor = tf.placeholder(tf.float32, (None, 3))  labels = tf.placeholder(tf.int32, None, 1))  model = Model(in_tensor, labels)  sess = tf.Session()  loss = sess.run(model.loss, feed_dict={    in_tensor:np.ones(1, 3),    labels:[[1]]  })  assert loss != 0

另一個很好的測試與我們的第一個測試類似,但是是反向的。你可以確保只有你想訓練的變量得到了訓練。以GAN為例。出現的一個常見錯誤是在進行優化時不小心忘記設置要訓練的變量。這樣的代碼經常發生。

class GAN:  def __init__(self, z_vector, true_images):    # Pretend these are implemented.    with tf.variable_scope("gen"):      self.make_geneator(z_vector)    with tf.variable_scope("des"):      self.make_descriminator(true_images)    opt = tf.AdamOptimizer()    train_descrim = opt.minimize(self.descrim_loss)    train_gen = opt.minimize(self.gen_loss)

這里最大的問題是優化器有一個默認設置來優化所有變量。在像GANs這樣的高級架構中,這是對你所有訓練時間的死刑判決。但是,你可以通過編寫這樣的測試來輕松地發現這些錯誤:

def test_gen_training():  model = Model  sess = tf.Session()  gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen')  des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des')  before_gen = sess.run(gen_vars)  before_des = sess.run(des_vars)  # Train the generator.  sess.run(model.train_gen)  after_gen = sess.run(gen_vars)  after_des = sess.run(des_vars)  # Make sure the generator variables changed.  for b,a in zip(before_gen, after_gen):    assert (a != b).any()  # Make sure descriminator did NOT change.  for b,a in zip(before_des, after_des):    assert (a == b).all()

可以為鑒別器編寫一個非常類似的測試。同樣的測試也可以用于許多強化學習算法。許多行為-批評模型有單獨的網絡,需要根據不同的損失進行優化。

下面是一些我推薦你進行測試的模式。

  1. 讓測試具有確定性。如果一個測試以一種奇怪的方式失敗,卻永遠無法重現這個錯誤,那就太糟糕了。如果你真的想要隨機輸入,確保使用種子隨機數,這樣你就可以輕松地重新運行測試。
  2. 保持測試簡短。不要使用單元測試來訓練收斂性并檢查驗證集。這樣做是在浪費自己的時間。
  3. 確保你在每個測試之間重置了計算圖。

總之,這些黑箱算法仍然有很多方法需要測試!花一個小時寫一個測試可以節省你幾天的重新運行訓練模型,并可以大大提高你的研究效率。因為我們的實現有缺陷而不得不放棄完美的想法,這不是很糟糕嗎?

這個列表顯然不全面,但它是一個堅實的開始!

英文原文:https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/454365.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/454365.shtml
英文地址,請注明出處:http://en.pswp.cn/news/454365.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

項目寶提供的服務器,開源WebSocket服務器項目寶貝魚CshBBrain V4.0.1 和 V2.0.2發布

開源WebSocket服務器項目寶貝魚CshBBrain V4.0.1 和 V2.0.2發布更新的功能列表如下:1.解決開啟廣播消息開關時,不能同時接入2個客戶端的重大缺陷。2.對廣播消息做了重大優化,從以前一個線程發送廣播消息進化到使用工作線程池中的線程并行的發…

c# 無損高質量壓縮圖片代碼

/// 無損壓縮圖片 /// <param name"sFile">原圖片</param> /// <param name"dFile">壓縮后保存位置</param> /// <param name"dHeight">高度</param> /// <param name"dWidth"…

一個從文本文件里“查找并替換”的功能

12345678910111213141516171819202122232425# -*- coding: UTF-8 -*-file input("請輸入文件路徑:") word1 input("請輸入要替換的詞:") word2 input("請輸入新的詞&#xff1a;") fopen(file,"r") AAAf.read() count 0 def BBB()…

機器學習算法之 KNN

K近鄰法(k-nearst neighbors,KNN)是一種很基本的機器學習方法了&#xff0c;在我們平常的生活中也會不自主的應用。比如&#xff0c;我們判斷一個人的人品&#xff0c;只需要觀察他來往最密切的幾個人的人品好壞就可以得出了。這里就運用了KNN的思想。KNN方法既可以做分類&…

安裝云端服務器操作系統,安裝云端服務器操作系統

安裝云端服務器操作系統 內容精選換一換SAP云服務器規格在申請SAP ECS之前&#xff0c;請參考SAP標準Sizing方法進行SAPS值評估&#xff0c;并根據Sizing結果申請云端ECS服務器資源&#xff0c;詳細信息請參考SAP Quick Sizer。SAP 各組件最低硬盤空間、RAM&#xff0c;以及軟件…

python 進度條_六種酷炫Python運行進度條

轉自&#xff1a;一行數據閱讀文本大概需要 3 分鐘你的代碼進度還剩多少&#xff1f;今天給大家介紹下目前6種比較常用的進度條&#xff0c;讓大家都能直觀地看到腳本運行最新的進展情況。1.普通進度條2.帶時間進度條3.tpdm進度條4.progress進度條5.alive_progress進度條6.可視…

js 獲取多少天前

getBeforeDate: function(day, str) { var now new Date().getTime(); //獲取毫秒數 var before new Date(now - ((day > 0 && day ? day : 0) * 86400 * 1000)); var year before.getFullYear(); var month before.getMonth()1; var date before.getDate(); …

程序員的基本素質

給所有立志成為程序員的朋友 以及 自勉之&#xff01; 程序員基本素質&#xff1a; 作一個真正合格的程序員&#xff0c;或者說就是可以真正合格完成一些代碼工作的程序員&#xff0c;應該具有的素質。 1&#xff1a;團隊精神和協作能力 把它作為基本素質&#xff0c;并…

權限之淺理解

白馬過隙&#xff0c;在感嘆時光流逝的同時不得不承認在學習中隨著知識面的不斷擴展所接受的東西也越來越多&#xff0c;尤其是那些外形比較容易混淆的命令&#xff0c;著實讓作為新手的吃了很多苦頭&#xff0c;趁著學習緊張之時偷個懶整理這周易混淆的命令&#xff1a; chgrp…

機器學習算法之生成樹

一、什么是決策樹&#xff1f; 決策樹&#xff08;Decision Tree&#xff09;是一種基本的分類和回歸的方法。 分類決策樹模型是一種描述對實例進行分類的樹形結構。決策樹由結點&#xff08;node&#xff09;和有向邊&#xff08;directed edge&#xff09;組成。結點有兩種…

強烈推薦給從事IT業的同行們 (轉載)

作者&#xff1a;李學凌 文章來源&#xff1a;bbs.ustc.edu.cn 中國有很多小朋友&#xff0c;他們18,9歲或21,2歲&#xff0c;通過自學也寫了不少代碼&#xff0c;他們有的代碼寫的很漂亮&#xff0c;一些技術細節相當出眾&#xff0c;也很有鉆研精神&#xff0c;但是他…

微機原理控制轉移類指令

1.無條件跳轉指令 指令格式;JMP 目標地址 功能&#xff1a;JMP可以使程序無條件地跳轉到程序存儲器中某目標地址 注意點&#xff1a; 1&#xff09;指令目標地址若在JMP指令所在的代碼段內&#xff0c;屬段內跳轉&#xff0c;指令只修改IP內容。指令目標地址若在JMP指令所在的代…

OPENNMS的后臺并行管理任務

Concurrent management tasks: 1. . Action daemon - automated action (work flow)2. .數據采集Collection daemon - collects data3. .能力檢查Capability daemon - capability check on nodes4. .動態主機配置協議DHCP daemon - DHCP clien…

機器學習算法之集成學習

集成學習的思想是將若干個學習器(分類器&回歸器)組合之后產生一個新學習器。弱分類器(weak learner)指那些分類準確率只稍微好于隨機猜測的分類器(errorrate < 0.5)。 集成算法的成功在于保證弱分類器的多樣性(Diversity)。而且集成不穩定的算法也能夠得到一個比較明顯…

常用的方法論-NPS

轉載于:https://www.cnblogs.com/qjm201000/p/7687510.html

controller調用controller的方法_SpringBoot 優雅停止服務的幾種方法

轉自&#xff1a;博客園&#xff0c;作者&#xff1a;黃青石www.cnblogs.com/huangqingshi/p/11370291.html 在使用 SpringBoot 的時候&#xff0c;都要涉及到服務的停止和啟動&#xff0c;當我們停止服務的時候&#xff0c;很多時候大家都是kill -9 直接把程序進程殺掉&#x…

linux下安裝Oracle10g時,安裝rpm文件的技巧 (rpm -Uvh package名)

rpm -q package名 &#xff1a; 查詢該package是否已經被安裝了rpm -qa | grep package名 或是package 的關鍵字 &#xff1a; 查詢該package是否已經被安裝了rpm -Uvh package名 &#xff1a; 意思是update packagerpm -Uvh package名 --force &#xff1a; 意思是如果該…

機器學習之聚類概述

什么是聚類 聚類就是對大量未知標注的數據集&#xff0c;按照數據 內部存在的數據特征 將數據集劃分為 多個不同的類別 &#xff0c;使 類別內的數據比較相似&#xff0c;類別之間的數據相似度比較小&#xff1b;屬于 無監督學習。 聚類算法的重點是計算樣本項之間的 相似度&…

程序員-建立你的商業意識 閆輝 著

1 程序員為什么需要商業意識 幾 年前&#xff0c;當我剛剛認識Fishman的時候&#xff0c;聽到他神奇的創業經歷&#xff0c;覺得非常不可思議。甚至還專門寫了一篇報道發到《電腦報》上&#xff0c;題目是《從程序員到 CEO》。不久&#xff0c;Fishman將創建的又一個新公司…

qt release打包發布_幾種解決Qt程序打包后無法連接數據庫問題的方法

Qt是一個跨平臺C圖形用戶界面應用程序開發框架&#xff0c;使用它不僅可以方便地開發GUI程序&#xff0c;也可以開發非GUI程序&#xff0c;可以一次編寫&#xff0c;處處編譯。今天遇到的問題比較怪異&#xff0c;我開發的是一個桌面版訂單管理系統&#xff0c;整體架構就是一個…