keras訓練完以后怎么預測_還在使用“龜速”的單顯卡訓練模型?動動手,讓TPU節省你的時間...

點擊上方關注,All in AI中國

本文將介紹如何使用Keras和Google CoLaboratory與TPU一起訓練LSTM模型,與本地計算機上的GPU相比,這樣訓練能大大縮短訓練時間。

f4dda0965b13cd3fdc03ae60f4d8b13b.png

很長一段時間以來,我都在單張GTX 1070顯卡上訓練我的模型,它的單精度大約為8.18 TFlops。后來Google的Colab開放了免費的Tesla K80顯卡,配備12GB RAM,8.73TFlops。直到最近,Colab的運行時類型選擇器中還會彈出帶有180 TFlops的Cloud TPU選項。這篇教程將簡要介紹如何將現有的Keras模型轉換為TPU模型,然后在Colab上訓練。與在GTX1070上訓練相比,TPU能夠加速20倍。

我們將構建一個易于理解,但訓練起來非常復雜的Keras模型,這樣我們就可以稍微"預熱"一下Cloud TPU。在IMDB情感分類任務上訓練LSTM模型可能是一個很好的例子,因為相比密集層和卷積層來說,訓練LSTM對算力要求更高。

工作流程概述:

  • 使用靜態輸入batch_size構建用于功能API訓練的Keras模型
  • 將Keras模型轉換為TPU模型
  • 使用靜態batch_size * 8訓練TPU模型,并將權重保存到文件
  • 創建一個結構相同,但輸入批大小可變的Keras模型,用于推理
  • 加載模型權重
  • 基于推理模型進行預測

在閱讀本文的同時,你可以上手試驗相應的Colab Jupyter notebook:Keras_LSTM_TPU.ipynb。(https://colab.research.google.com/drive/1QZf1WeX3EQqBLeFeT4utFKBqq-ogG1FN)

首先,按照下圖中的說明來激活在Colab運行中的TPU。

80c751cfce9118d7a0f3aea2a65ce646.png

激活TPU

固定輸入批尺寸

大多數情況下,CPU和GPU上對輸入形狀沒有限制,但XLA/TPU環境下會強制使用固定的形狀和批尺寸。

Can TPU包含8個TPU核心,作為獨立的處理單元運行。如果沒有使用所有八個核心,那TPU就不會得到充分利用。為了充分提高訓練的矢量化速度,相比在單一GPU上訓練的同樣的模型,我們可以選擇較大的批尺寸。總批尺寸大小為1024(每個核心128個)通常是一個很好的起點。

如果你要訓練批尺寸較大的型號,請嘗試慢慢減小批尺寸,以保證TPU內存放得下,只需確保總批尺寸為64的倍數(每核心批尺寸應該是8的倍數)。

值得一提,在批尺寸較大時,通常可以提高優化器的學習速率,以實現更快的收斂。你可以在本文中找到參考——"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"。(https://arxiv.org/pdf/1706.02677.pdf)

在Keras中,要定義靜態批處理尺寸,我們使用函數API,然后為輸入層指定batch_size參數。請注意,模型構建在一個帶有batch_size參數的函數中,因此我們之后可以很方便地創建在CPU或GPU上運行的模型,這些模型接受可變批尺寸的輸入。

74a8d2e19c8fde1fcafdb4acfac198b3.png

此外,我們在這里使用了tf.train.Optimizer而不是標準的Keras優化器,因為TPU對Keras優化器的支持還處于實驗階段。

將Keras模型轉換為TPU模型

tf.contrib.tpu.keras_to_tpu_model函數將tf.keras模型轉換為等價的TPU版本。

86c584e222b04e31ba37735e86b9ee24.png

然后,我們使用標準的Keras方法來訓練,保存權重并評估模型。請注意,batch_size設置為模型輸入batch_size的八倍,因為輸入樣本在8個TPU核心上均勻分布。

a7f7262a5f5a2903955bccbeb8828d24.png

我做了一個實驗,用來比較在Windows PC上運行單個GTX1070和在Colab上運行的TPU之間的訓練速度,結果如下:

  • GPU和TPU都將輸入批尺寸設為128。
  • GPU:每個歷元179秒。20個歷元后的驗證準確率達到了76.9%,總計3600秒。
  • TPU:每個歷元5秒(第一個歷元需要49秒)。20個歷元后的驗證準確率達到了95.2%,總計150秒。
  • 在20個歷元之后TPU的驗證準確度高于在GPU上的表現,那是因為TPU上同時訓練8個批的樣本(每個批的大小為128)。

在CPU上進行推理

一旦我們獲得了模型權重,我們就可以像往常一樣加載它,然后在CPU或GPU等其他設備上進行預測。我們想要推理模型接受可變的輸入批大小,這可以使用之前的make_model()函數來實現。

bc4039b346df25e1dd24da3f71df66b2.png

你可以看到推理模型現在可以接受可變輸入樣本數目,

6bad7d1498641d83b1045fb652c683d4.png

然后,你可以使用標準的fit()、evaluate()函數與推理模型。

結論以及進一步閱讀

這篇快速教程向你簡要介紹了如何利用Google Colab上的免費Cloud TPU資源更快地訓練Keras模型。

云TPU文檔:https://cloud.google.com/tpu/docs/

云TPU性能指南:https://cloud.google.com/tpu/docs/performance-guide

云TPU故障排除指南:https://cloud.google.com/tpu/docs/troubleshooting

XLA概述:https://www.tensorflow.org/performance/xla/

5f6982a3dcefc2e09bd707df85b060f5.png

編譯出品

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/276445.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/276445.shtml
英文地址,請注明出處:http://en.pswp.cn/news/276445.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

PHP5加載|安裝外部C動態庫

[1] cd php-5.3.9/ext[2] ./ext_skel --extnamencdocxml[3] cd ncdocxml[4] nano -w config.m4############刪除 3 個 dnldnl PHP_ARG_WITH(my_module, for my_module support,dnl Make sure that the comment is aligned:dnl [ --with-my_module Include my_module support])或…

手把手教你寫個小程序定時器管理庫

背景凹凸曼是個小程序開發者,他要在小程序實現秒殺倒計時。于是他不假思索,寫了以下代碼:Page({init: function () {clearInterval(this.timer)this.timer setInterval(() > {// 倒計時計算邏輯console.log(setInterval)})}, })可是&…

[New Portal]Windows Azure Virtual Machine (14) 在本地制作數據文件VHD并上傳至Azure(1)

《Windows Azure Platform 系列文章目錄》 之前的內容里,我介紹了如何將本地的Server 2012中文版 VHD上傳至Windows Azure,并創建基于該Server 2012 VHD的虛擬機。 我們知道,VHD不僅僅可以保存操作系統,而且可以保存數據文件。 如…

python 退出程序_Python:用Ctrl+C解決終止多線程程序的問題!(建議收藏)

前言:今天為大家帶來的內容是Python:用CtrlC解決終止多線程程序的問題!文章中的代碼具有不錯的參考意義,希望在此能夠幫助到各位!(多數代碼用圖片的方式呈現出來,方便各位觀看與收藏)出發點:前段時間&#…

Mysql InnoDB Plugin安裝 install

轉載鏈接:http://www.orczhou.com/index.php/2010/03/innodb-plugin-setup/ InnoDB Plugin較之Built-in版本新增了很多特性:包括快速DDL、壓縮存儲等,而且引入了全新的文件格式Barracuda。眾多測試也表明,Plugin在很多方面優于Bu…

Hibernate的數據過濾查詢

數據過濾并不是一種常規的數據查詢方法,而是一種整體的篩選方法。數據過濾也可對數據進行篩選,因此,將其放在Hibernate的數據查詢框架中介紹。 如果一旦啟用了數據過濾器,則不管數據查詢,還是數據加載,該過…

若川知乎高贊:有哪些必看的 JS 庫?

歡迎星標我的公眾號,回復加群,長期交流學習我的知乎回答目前2w閱讀量,270贊,現在發到公眾號聲明原創。必看的js庫?只有當前階段值不值看。我從去年7月起看一些前端庫的源碼,歷時一年才寫了八篇《學習源碼整…

python用for循環求10的因數_python for循環練習(初級)

for循環練習1for i in range(4):print(i)D:\尚硅谷Python\venv\Scripts\python.exe D:/尚硅谷Python/test.py0123for循環練習2for x in range(1,40,5): #間隔5print(x)D:\尚硅谷Python\venv\Scripts\python.exe D:/尚硅谷Python/test.py16111621263136打印99乘法表for i in ran…

基于EasyUI的Web應用程序及過去一年的總結

前言 一個多月之前已經提交了離職申請,好在領導都已經批準了,過幾天就辦理手續了,在此感謝領導的栽培與挽留,感謝各位同事在工作中的給我的幫助,離開這個團隊確實有一些不舍,不為別的,只因為這個…

MySQL外鍵創建失敗1005原因總結

1、安裝mysql有InnoDB的插件擴展 ./configure --prefix/usr/local/mysql --with-pluginscsv,innobase,myisam,heap,innodb_plugin 2、找不到主表中 引用的列 3、主鍵和外鍵的字符編碼不一致 4、外鍵字段與要做外鍵校驗的字段類型不匹配 5、MySQL支持外鍵約束,并…

Hibernate的事件機制

4.8 事 件 機 制 通常,Hibernate執行持久化過程中,應用程序無法參與其中。所有的數據持久化操作,對用戶都是透明的,用戶無法插入自己的動作。 通過事件框架,Hibernate允許應用程序能響應特定的內部事件,從而…

快速使用Vue3最新的15個常用API

之前我寫了一篇博客介紹了Vue3的新特性,簡單了解了一下Vue3都有哪些特色,并且在文末帶大家稍微體驗了一下Vue3中 Compsition API 的簡單使用上一篇文章地址:緊跟尤大的腳步提前體驗Vue3新特性,你不會還沒了解過Vue3吧因為這個月的…

超級馬里奧代碼_任天堂的源碼泄露,揭示超級馬里奧的前世之生

被黑客盯上的任天堂任天堂遭到了史上最大規模的黑客攻擊,Wii 完整源碼、設計以及《寶可夢》多部作品的信息遭到泄露,而此次泄露事件的后續影響似乎也爆發了出來。《馬里奧賽車》和《超級馬里奧世界2》(耀西島)的早期原型視頻,以及《超級馬里奧…

行者寂寞

公元2009年7月20日。閏五月廿八。炎日,汗如雨。晨行。病臥于京西客站。是夜,不能寐。 公元2009年7月21日。閏五月廿九。戲于清華,游于星巴克。汗如雨。是夜,困于京國際機場5小時。 公元2009年7月22日。六月初一。晨時抵寧。大雨。…

Azure PowerShell (1) PowerShell整理

《Windows Azure Platform 系列文章目錄》 把之前Azure ASM的PowerShell都整理好了。 https://github.com/leizhang1984/AzureChinaPowerShell

漫畫 | 前端發展史的江湖恩怨情仇

時間總是過得很快, 似乎快得讓人忘記了昨天,前端WEB領域的發展更是如此,轉眼間已是近30年,時光荏苒,初心不變,在一代又一代前端人的努力下,前端已經是互聯網不可或缺的一部分。然而很多前端打工…

jquery1.9 下檢測瀏覽器類型和版本

原文鏈接:http://blog.csdn.net/lyc_2011_acm/article/details/8749177 Jquery1.9版本中$.browser已被剔除: 判斷瀏覽器類型: $.browser.mozilla /firefox/.test(navigator.userAgent.toLowerCase()); $.browser.webkit /webkit/.test(nav…

python可迭代對象 迭代器生成器_Python可迭代對象、迭代器和生成器

8.1 可迭代對象(Iterable)大部分對象都是可迭代,只要實現了__iter__方法的對象就是可迭代的。__iter__方法會返回迭代器(iterator)本身,例如:>>> lst [1,2,3]>>> lst.__iter__()Python提供一些語句和關鍵字用于訪問可迭代…

Windows Mobile下使用CppUnitLite輸出測試結果

背景 TDD測試驅動開發是當前流行的開發方法及模式。遵循TDD的方法對開發程序庫(Library)特別有用,因為Library就是為第三方提供一定功能接口的實現,使用TDD的方法可以預先為定義的接口提供測試案例,保證實現代碼能通過測試,保證Li…

sql注意事項2點

①對Null的判斷,如果要用<>與null判斷,則都會得到否定結果②insert into時,要把字段顯示指出,不然會因字段位置變化出錯③-一個數時,如果有可能存在Null,則結果會被置為Null解決方法,select出來的結果,最好加isnull判斷轉載于:https://www.cnblogs.com/lishenglyx/archiv…