淺析Estimator、model_fn與EstimatorSpec

參考閱讀:https://zhuanlan.zhihu.com/p/74857888

文章目錄

  • 綜合對比
      • Estimator
      • model_fn
      • EstimatorSpec
      • 關系
      • 總結
  • Estimator
      • 主要功能
      • 構造函數參數
      • 示例用法
      • 小結
  • model_fn
  • EstimatorSpec
      • 字段解釋
      • 解釋代碼
      • 用途

綜合對比

Estimatormodel_fnEstimatorSpec 是 TensorFlow 中用于構建、訓練和評估模型的三個核心組件。它們之間的關系可以總結如下:

Estimator

  • 定義: Estimator 是 TensorFlow 提供的高層 API,用于簡化和標準化模型的訓練、評估和預測。
  • 功能:
    • 封裝訓練、評估和預測的邏輯。
    • 管理檢查點、日志記錄和模型保存。
    • 提供一致的接口來處理不同類型的模型。
  • 參數:
    • model_fn: 定義模型的函數。
    • model_dir: 模型保存目錄。
    • config: 執行環境的配置信息。
    • params: 超參數字典。
    • warm_start_from: 熱啟動配置。

model_fn

  • 定義: model_fn 是一個函數,定義了模型的結構和行為。它由 Estimator 在訓練、評估和預測時調用。
  • 功能:
    • 構建模型的計算圖。
    • 根據運行模式(TRAIN、EVAL、PREDICT)返回不同的操作。
    • 接受特征、標簽、模式、超參數和配置信息作為輸入。
  • 返回值:
    • 返回一個 EstimatorSpec 對象,定義了模型在不同模式下的行為。

EstimatorSpec

  • 定義: EstimatorSpec 是一個對象,包含了模型在訓練、評估和預測模式下的所有必要信息。
  • 功能:
    • 定義模型的預測、損失、訓練操作和評估指標。
    • 提供一致的接口,使 Estimator 能夠在不同模式下正確運行模型。
  • 字段:
    • mode: 運行模式(TRAIN、EVAL、PREDICT)。
    • predictions: 預測結果。
    • loss: 損失值。
    • train_op: 訓練操作。
    • eval_metric_ops: 評估指標操作。
    • export_outputs: 導出輸出。
    • training_chief_hooks, training_hooks, scaffold, evaluation_hooks, prediction_hooks: 各種鉤子和腳手架對象,用于在不同階段執行自定義操作。

關系

  1. Estimator 使用 model_fn:

    • Estimator 調用 model_fn 來構建模型的計算圖并定義其行為。
    • model_fn 接受特征、標簽、模式、超參數和配置信息,并返回一個 EstimatorSpec 對象。
  2. model_fn 返回 EstimatorSpec:

    • model_fn 根據當前的運行模式(TRAIN、EVAL、PREDICT)創建并返回一個 EstimatorSpec 對象。
    • EstimatorSpec 對象包含了模型在當前模式下所需的所有操作和輸出。
  3. Estimator 使用 EstimatorSpec:

    • Estimator 使用 EstimatorSpec 中定義的操作來執行訓練、評估和預測。
    • 根據 EstimatorSpec 中的信息,Estimator 知道如何處理模型的預測、損失計算和訓練步驟。

總結

  • Estimator 是高層接口,用于管理和運行模型。
  • model_fn 是用戶定義的函數,用于構建模型的計算圖并返回 EstimatorSpec
  • EstimatorSpec 定義了模型在不同模式下的行為,由 model_fn 返回,并由 Estimator 使用。

Estimator

Estimator 是 TensorFlow 提供的一個高層 API,用于簡化模型的訓練和評估。它封裝了一個模型,模型通過 model_fn 指定。Estimator 負責處理訓練、評估和預測所需的所有操作,并將結果輸出到指定的目錄。

主要功能

  1. 模型訓練、評估和預測: Estimator 封裝了這些操作,簡化了模型的開發和部署過程。
  2. 模型保存和恢復: 所有輸出(如檢查點、事件文件等)都寫入 model_dir,或其子目錄。這樣可以方便地保存和恢復模型。
  3. 運行配置: 通過 config 參數,Estimator 可以獲取有關執行環境的信息,并將其傳遞給 model_fn
  4. 超參數傳遞: 通過 params 參數,Estimator 可以將超參數傳遞給 model_fn 和輸入函數。

構造函數參數

  • model_fn: 模型函數,定義了如何構建模型。它接受以下參數:

    • features: 從 input_fn 返回的特征,通常是 TensorTensor 字典。
    • labels: 從 input_fn 返回的標簽,通常是 TensorTensor 字典。在預測模式下,labelsNone
    • mode: 運行模式,可以是 TRAINEVALPREDICT
    • params: 超參數字典,包含傳遞給 Estimator 的超參數。
    • config: RunConfig 對象,包含執行環境的配置信息。
  • model_dir: 模型參數、圖等的保存目錄,也可以用于從目錄加載檢查點以繼續訓練之前保存的模型。

  • config: RunConfig 配置對象,包含執行環境的配置信息。如果model_fn函數也定義config這個變量,則會將config傳給model_fn。

  • params: 超參數字典,包含傳遞給 model_fn 的超參數。

  • warm_start_from: 檢查點或 SavedModel 的文件路徑,用于熱啟動,或一個 WarmStartSettings 對象以完全配置熱啟動。

示例用法

  1. 創建一個 Estimator 實例

    estimator = tf.estimator.DNNClassifier(feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],hidden_units=[1024, 512, 256],warm_start_from="/path/to/checkpoint/dir"
    )
    
  2. 定義 model_fn

    def my_model_fn(features, labels, mode, params):# 構建模型logits = build_model(features, mode, params)predictions = {'classes': tf.argmax(input=logits, axis=1),'probabilities': tf.nn.softmax(logits)}# PREDICT 模式if mode == tf.estimator.ModeKeys.PREDICT:return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)# 計算損失loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)# 訓練操作if mode == tf.estimator.ModeKeys.TRAIN:optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)# 評估指標eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions['classes'])}return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
    
  3. 使用 Estimator 進行訓練、評估和預測

    # 訓練
    estimator.train(input_fn=train_input_fn, steps=1000)# 評估
    eval_result = estimator.evaluate(input_fn=eval_input_fn)
    print(eval_result)# 預測
    predictions = estimator.predict(input_fn=predict_input_fn)
    for pred in predictions:print(pred)
    

小結

Estimator 提供了一種結構化的方法來定義和管理 TensorFlow 模型,使得模型的訓練、評估和預測更加方便和標準化。它通過 model_fn 將模型的構建與訓練、評估和預測邏輯分離,并且通過配置和參數化提供了靈活性。

model_fn

輸入:

  • features: 從 input_fn 返回的特征,通常是 TensorTensor 字典。
  • labels: 從 input_fn 返回的標簽,通常是 TensorTensor 字典。在預測模式下,labelsNone
  • mode: 運行模式,可以是 TRAINEVALPREDICT
  • params: 超參數字典,包含傳遞給 Estimator 的超參數。
  • config: RunConfig 對象,包含執行環境的配置信息。

返回值:
一個EstimatorSpec

前兩個參數是從輸入函數中返回的特征和標簽批次;也就是說,features 和 labels 是模型將使用的數據。

params 是一個字典,它可以傳入許多參數用來構建網絡或者定義訓練方式等。例如通過設置params[‘n_classes’]來定義最終輸出節點的個數等。
config 通常用來控制checkpoint或者分布式什么,這里不深入研究。
mode 參數表示調用程序是請求訓練、評估還是預測,分別通過tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT 來定義。另外通過觀察DNNClassifier的源代碼可以看到,mode這個參數并不用手動傳入,因為Estimator會自動調整。例如當你調用estimator.train(…)的時候,mode則會被賦值tf.estimator.ModeKeys.TRAIN。

模型有訓練,驗證和測試三種階段,而且對于不同模式,對數據有不同的處理方式。例如在訓練階段,我們需要將數據喂給模型,模型基于輸入數據給出預測值,然后我們在通過預測值和真實值計算出loss,最后用loss更新網絡參數,而在評估階段,我們則不需要反向傳播更新網絡參數,換句話說,model_fn需要對三種模式設置三套代碼

EstimatorSpec

collections.namedtuple 是 Python 標準庫中的一個函數,用于創建不可變的、具名的元組(named tuple)。這些具名元組可以像類一樣使用,有字段名稱,使代碼更具可讀性和可維護性。

在這段代碼中,collections.namedtuple 被用來創建一個名為 EstimatorSpec 的具名元組,它包含了一組用于定義模型在不同模式下行為的字段。以下是每個字段的解釋:

字段解釋

  1. mode: 模式,表示當前的運行模式,可以是訓練(TRAIN)、評估(EVAL)或預測(PREDICT)模式。
  2. predictions: 預測值,可以是一個 TensorTensor 字典,用于預測模式下輸出結果。
  3. loss: 損失值,一個標量 Tensor,表示模型的損失,用于訓練和評估模式。
  4. train_op: 訓練操作,表示在訓練模式下執行的操作(通常是優化步驟)。
  5. eval_metric_ops: 評估指標操作,是一個字典,包含評估模式下的度量結果。
  6. export_outputs: 導出輸出,是一個字典,定義了模型在導出為 SavedModel 時的輸出簽名。
  7. training_chief_hooks: 主訓練鉤子,是一個迭代器,包含在主 worker 上運行的 SessionRunHook 對象。
  8. training_hooks: 訓練鉤子,是一個迭代器,包含在所有 worker 上運行的 SessionRunHook 對象。
  9. scaffold: 腳手架,是一個 tf.train.Scaffold 對象,用于設置初始化、保存和恢復操作。
  10. evaluation_hooks: 評估鉤子,是一個迭代器,包含在評估過程中運行的 SessionRunHook 對象。
  11. prediction_hooks: 預測鉤子,是一個迭代器,包含在預測過程中運行的 SessionRunHook 對象。

解釋代碼

collections.namedtuple('EstimatorSpec', ['mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops','export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold','evaluation_hooks', 'prediction_hooks'
])

這行代碼創建了一個名為 EstimatorSpec 的具名元組類,它包含了上述的這些字段。EstimatorSpec 類可以用于存儲和傳遞這些字段的值,使得在模型函數(model_fn)中可以方便地定義和返回這些值。

用途

EstimatorSpec 主要用于 TensorFlow 的 Estimator API 中,以統一的方式定義模型的各個組成部分。通過使用 EstimatorSpec,可以確保模型在不同模式下的行為是一致且正確的。例如:

  • 在訓練模式下,必須提供 losstrain_op
  • 在評估模式下,必須提供 loss
  • 在預測模式下,必須提供 predictions

使用 EstimatorSpec,可以更簡潔和清晰地定義模型的各個部分,并且通過具名元組的方式,使代碼更加可讀和易于維護。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/39321.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/39321.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/39321.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

西電811考研、140分專業課及811/821經驗

被擬錄取了,說一說自己考研經驗,本人跟的研夢考研全程班,胖覃學長很負責任,貌似已經直博西電了,但也很負責。 1、通信工程學院分為學碩與專碩,學碩包含信息與通信工程、交通運輸工程、軍隊指揮學&#xff…

Perl語言中的排序藝術:深入探討內置排序函數

Perl是一種功能強大的腳本語言,以其靈活的文本處理能力而聞名。在Perl中,排序是一項常見的任務,無論是對數組元素進行排序,還是對復雜數據結構進行排序,Perl都提供了多種內置的排序函數,以滿足不同的需求。…

深入掌握Symfony與Composer:PHP依賴管理的藝術

引言 Composer是PHP的依賴管理工具,廣泛用于Symfony等現代PHP應用程序中。它允許開發者聲明依賴項,自動處理依賴的安裝和更新,確保應用程序的依賴項得到有效管理。本文將詳細介紹Composer的使用方法,包括基本命令、依賴管理、自動…

Linux環境安裝配置nginx服務流程

Linux環境的Centos、麒麟、統信操作系統安裝配置nginx服務流程操作: 1、官網下載 下載地址 或者通過命令下載 wget http://nginx.org/download/nginx-1.20.2.tar.gz 2、上傳到指定的服務器并解壓 tar -zxvf nginx-1.20.1.tar.gzcd nginx-1.20.1 3、編譯并安裝到…

條件過濾檢索

背景介紹 在大多數業務場景中,單純使用向量進行相似性檢索并無法滿足業務需求,通常需要在滿足特定過濾條件、或者特定的“標簽”的前提下,再進行相似性檢索。 向量檢索服務DashVector支持條件過濾和向量相似性檢索相結合,在精確滿…

數字化供應鏈:背景特點

?背景 1、外部環境 近年來,供應鏈脆弱性凸顯,企業供應鏈壓力難以緩解。 美國媒體針對美國零售聯合會、美國服裝和鞋類協會、美國供應鏈管理專業委員會等主體進行的一項供應鏈調查顯示: 61%的供應鏈經理預計,供應鏈紊亂問題至少…

C++(第一天-----命名空間和引用)

一、C/C的區別 1、與C相比   c語言面向過程,c面向對象。   c能夠對函數進行重載,可使同名的函數功能變得更加強大。   c引入了名字空間,可以使定義的變量名更多。   c可以使用引用傳參,引用傳參比起指針傳參更加快&#…

企業化運維(5)_mysql數據庫

###1.源碼編譯mysql### 對壓縮包進行解壓,并對mysql進行源碼編譯,其中需要下載依賴才能編譯成功。 官網: www.mysql.com解壓并進入目錄 [rootserver1 ~]# tar xf mysql-boost-5.7.40.tar.gz [rootserver1 ~]# cd mysql-5.7.40/安裝依賴性…

初識Java(復習版)

一. 什么是Java Java是一種面向對象的編程語言,和C語言有所不同,C語言是一門面向過程的語言。偏底層實現,比較注重底層的邏輯實現。不能一味的說某一種語言特別好,每一種語言都是在特定的情況下有自己的優勢。 二.Java語言發展史…

昇思25天學習打卡營第2天|yulang

今天主要了解快速入門,主要包含了處理數據集、網絡構建、模型訓練、保存模型和加載模型,這些對于不是算法工程師理解起來可能稍微有一點的難度,學習起來有點枯燥,期待后續實戰部分能完成一些獨立的比較有意思的項目。

鴻蒙項目實戰-月木學途:2.自定義底部導航

效果預覽 Tabs組件簡介 Tabs組件的頁面組成包含兩個部分,分別是TabContent和TabBar。TabContent是內容頁,TabBar是導航頁簽欄,頁面結構如下圖所示,根據不同的導航類型,布局會有區別,可以分為底部導航、頂部…

使用ECharts實現動態數據可視化的最佳實踐

使用ECharts實現動態數據可視化的最佳實踐 大家好,我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編,也是冬天不穿秋褲,天冷也要風度的程序猿! 引言 隨著數據驅動決策的重要性日益增強,動態數據可視…

第二十站:Java未來光譜——量子計算與新興技術的展望

Java作為一門成熟且廣泛使用的編程語言,其在傳統計算領域已經取得了巨大的成功。然而,隨著量子計算等新興技術的出現,Java也在探索其在這些領域的應用潛力。IBM Qiskit是一個開源的量子計算軟件框架,它允許開發者使用多種編程語言…

登錄驗證碼高擴展性設計方案

登錄驗證碼高擴展性建設方案 本文分享了一種登錄驗證碼高擴展性的建設方案,通過工廠模式策略模式,增強了驗證碼服務中驗證碼生成器、驗證碼存儲器、驗證碼圖片生成器的擴展性,實現了服務組件的多樣化,降低了維護成本 登錄驗證碼高…

8617 階乘數字和

這是一個關于計算階乘結果所有位上的數字之和的問題。我們可以通過以下步驟來解決這個問題: 1. 首先,我們需要一個函數來計算階乘。由于n的范圍可以達到50,階乘的結果可能非常大,所以我們需要使用一個可以處理大整數的數據類型&a…

adb shell logcat -b all|grep如何可以grep兩個子串?

在adb shell logcat命令中結合grep來過濾日志時,如果你想要同時匹配兩個子串,你可以使用管道(|)將兩個grep命令連接起來,或者使用grep的-E(或egrep,它等同于-E)選項來支持擴展的正則…

[課程][原創]opencv圖像在C#與C++之間交互傳遞

opencv圖像在C#與C之間交互傳遞 課程地址:https://edu.csdn.net/course/detail/39689 無限期視頻有效期 課程介紹課程目錄討論留言 你將收獲 學會如何封裝C的DLL 學會如何用C#調用C的DLL 掌握opencv在C#和C傳遞思路 學會如何配置C的opencv 適用人群 擁有C#…

報錯:pathspec ‘xxx‘ did not match any file(s) known to git

在 escode 中進行分支切換時報如下錯誤 PS > git checkout xxx error: pathspec xxx did not match any file(s) known to git遠程分支已經在 gitlab 客戶端手動創建,在 escode 中也使用了拉取之類的操作,但是切換分支時依然報錯。 解決方案 查看分…

怎么找到DNS服務器的地址?

所有域都注冊到域名名稱服務器(DNS)點,以解析域名應指向的IP地址。此查找類似于在查找個人名稱并查找其電話號碼時的電話簿如何運行。如果DNS服務器設置錯誤或指向錯誤的名稱服務器,則域可能無法加載相應的網頁。 如何查找當前的…

【深度學習】C++ onnx Yolov8 目標檢測推理

【深度學習】C onnx Yolov8 目標檢測推理 導出onnx模型代碼onnx_detect_infer.honnx_detect_infer.cppmain.cppCMAKELIST 導出onnx模型 python 中導出 from ultralytics import YOLO# Load the YOLOv8 model model YOLO("best.pt")# # Export the model to ONNX f…