當訓練任務結束,常常需要評價函數(Metrics)來評估模型的好壞。不同的訓練任務往往需要不同的Metrics函數。例如,對于二分類問題,常用的評價指標有precision(準確率)、recall(召回率)等,而對于多分類任務,可使用宏平均(Macro)和微平均(Micro)來評估。
MindSpore提供了大部分常見任務的評價函數,如Accuracy、Precision、MAE和MSE等,由于MindSpore提供的評價函數無法滿足所有任務的需求,很多情況下用戶需要針對具體的任務自定義Metrics來評估訓練的模型。
本章主要介紹如何自定義Metrics以及如何在mindspore.train.Model中使用Metrics。
自定義Metrics
自定義Metrics函數需要繼承mindspore.train.Metric父類,并重新實現父類中的clear方法、update方法和eval方法。
- clear:初始化相關的內部參數。
- update:接收網絡預測輸出和標簽,計算誤差,每次step后并更新內部評估結果。
- eval:計算最終評估結果,在每次epoch結束后計算最終的評估結果。
平均絕對誤差(MAE)算法如式(1)所示:
下面以簡單的MAE算法為例,介紹clear、update和eval三個函數及其使用方法。
模型訓練中使用Metrics
mindspore.train.Model是用于訓練和評估的高層API,可以將自定義或MindSpore已有的Metrics作為參數傳入,Model能夠自動調用傳入的Metrics進行評估。
在網絡模型訓練后,需要使用評價指標,來評估網絡模型的訓練效果,因此在演示具體代碼之前首先簡單擬定數據集,對數據集進行加載和定義一個簡單的線性回歸網絡模型:
使用內置評價指標
使用MindSpore內置的Metrics作為參數傳入Model時,Metrics可以定義為一個字典類型,字典的key值為字符串類型,字典的value值為MindSpore內置的評價指標,如下示例使用train.Accuracy計算分類的準確率。
使用自定義評價指標
如下示例在Model中傳入上述自定義的評估指標MAE(),將驗證數據集傳入model.fit()接口邊訓練邊驗證。
驗證結果為一個字典類型,驗證結果的key值與metrics的key值相同,驗證結果的value值為預測值與實際值的平均絕對誤差。
?