Flag 驗證器使用教程
Flag 驗證器
是一種常用工具,用來驗證命令行參數或配置文件中的標志(flag)是否符合預期規則。這些工具可以幫助開發者確保傳入的參數滿足一定的條件,避免因參數錯誤而導致程序運行失敗。以下是對各個驗證器功能的中文說明以及使用示例。
功能解釋
1. register_validator
用于注冊一個驗證函數,該函數用來驗證某個特定 flag 的值是否有效。
- 用法:
register_validator("learning_rate", lambda lr: lr > 0, message="學習率必須為正數。")
- 第一個參數是 flag 的名稱,例如
"learning_rate"
。 - 第二個參數是一個驗證函數,接收 flag 的值作為輸入,返回
True
表示合法,拋出異常或返回False
表示非法。 message
參數是可選的,用于在驗證失敗時輸出提示信息。
- 第一個參數是 flag 的名稱,例如
2. validator
這是一個裝飾器,用來定義并注冊驗證器函數。它和 register_validator
類似,但更簡潔。
- 用法:
@validator def validate_positive_learning_rate(value):return value > 0 # 學習率必須為正數
3. register_multi_flags_validator
用于驗證多個 flags 之間的關系。適用于當多個 flag 需要滿足某種依賴關系或約束時。
- 用法:
register_multi_flags_validator(["learning_rate", "batch_size"],lambda lr, bs: lr < 1 and bs > 0,message="學習率必須小于 1 且批量大小必須大于 0。" )
- 第一個參數是 flag 名稱的列表。
- 第二個參數是驗證函數,接收多個 flag 的值作為輸入。
message
參數用于驗證失敗時的提示。
4. multi_flags_validator
這是 register_multi_flags_validator
的裝飾器版本,用來簡化驗證器的定義。
- 用法:
@multi_flags_validator(["flag_a", "flag_b"]) def validate_flags(flag_a, flag_b):return flag_a != flag_b # 確保 flag_a 和 flag_b 的值不同
5. mark_flag_as_required
標記某個 flag 為必需。如果運行程序時未提供該 flag,則會報錯。
- 用法:
mark_flag_as_required("model_path") # 模型路徑是必需的
6. mark_flags_as_required
標記多個 flag 為必需。如果這些 flag 中的任意一個未提供,則會報錯。
- 用法:
mark_flags_as_required(["input_path", "output_path"]) # 輸入路徑和輸出路徑都是必需的
7. mark_flags_as_mutual_exclusive
確保多個 flag 是互斥的,即只能設置其中一個。如果多個 flag 同時被設置,則會報錯。
- 用法:
mark_flags_as_mutual_exclusive(["use_gpu", "use_tpu"]) # GPU 和 TPU 不能同時使用
8. mark_bool_flags_as_mutual_exclusive
這是 mark_flags_as_mutual_exclusive
的專門版本,用于布爾類型的 flag。確保多個布爾 flag 中最多只有一個為 True
。
- 用法:
mark_bool_flags_as_mutual_exclusive(["debug", "production"]) # debug 和 production 模式不能同時開啟
這些工具如何協同使用
這些驗證器通常用于框架(如 TensorFlow、PyTorch)或自定義的命令行工具中,用來確保傳入的參數符合要求。以下是一個示例,展示如何結合使用這些驗證器。
示例代碼
以下代碼展示了如何使用這些驗證器來定義和驗證命令行 flag。
from _validators import (register_validator,register_multi_flags_validator,mark_flag_as_required,mark_flags_as_mutual_exclusive,mark_bool_flags_as_mutual_exclusive,
)# 定義 flags
flags.DEFINE_float("learning_rate", 0.01, "優化器的學習率。")
flags.DEFINE_integer("batch_size", 32, "訓練的批量大小。")
flags.DEFINE_boolean("use_gpu", False, "是否使用 GPU 進行訓練。")
flags.DEFINE_boolean("use_tpu", False, "是否使用 TPU 進行訓練。")
flags.DEFINE_string("output_dir", None, "保存訓練結果的目錄。")# 注冊驗證器
# 確保學習率為正數
register_validator("learning_rate", lambda lr: lr > 0, message="學習率必須為正數!")# 確保批量大小大于 0
register_validator("batch_size", lambda bs: bs > 0, message="批量大小必須大于 0!")# 確保輸出目錄是必需的
mark_flag_as_required("output_dir")# 確保 GPU 和 TPU 是互斥的
mark_bool_flags_as_mutual_exclusive(["use_gpu", "use_tpu"])# 確保學習率和批量大小滿足一定的關系
register_multi_flags_validator(["learning_rate", "batch_size"],lambda lr, bs: lr * bs < 1,message="學習率和批量大小的乘積必須小于 1!"
)
運行結果
-
如果未提供
output_dir
:錯誤:output_dir 是必需的,請指定保存路徑。
-
如果同時啟用了
use_gpu
和use_tpu
:錯誤:use_gpu 和 use_tpu 是互斥的,請選擇其中之一。
-
如果
learning_rate
為負數:錯誤:學習率必須為正數!
-
如果
learning_rate * batch_size >= 1
:錯誤:學習率和批量大小的乘積必須小于 1!
總結
通過以上的工具和方法,可以輕松實現以下功能:
- 驗證單個 flag 的合法性,如檢查參數范圍。
- 驗證多個 flag 的依賴關系,如互斥性或相關性。
- 確保必需的 flag 被提供,避免缺少關鍵參數導致程序失敗。
因此在jaxpi的代碼里:
import os# Deterministic
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions --xla_gpu_autotune_level=0"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1" # DETERMINISTICfrom absl import app
from absl import flags
from absl import loggingfrom ml_collections import config_flagsimport jax
jax.config.update("jax_default_matmul_precision", "highest")import train
import evalFLAGS = flags.FLAGSflags.DEFINE_string("workdir", ".", "Directory to store model data.")config_flags.DEFINE_config_file("config","./configs/default.py","File path to the training hyperparameter configuration.",lock_config=True,
)def main(argv):if FLAGS.config.mode == "train":train.train_and_evaluate(FLAGS.config, FLAGS.workdir)elif FLAGS.config.mode == "eval":eval.evaluate(FLAGS.config, FLAGS.workdir)if __name__ == "__main__":flags.mark_flags_as_required(["config", "workdir"])app.run(main)
將 config 和 workdir 標記為必需的命令行參數。
如果運行程序時未提供這兩個參數,會報錯。
作用:config:配置文件的路徑,程序需要通過它加載配置。workdir:工作目錄,用于保存訓練結果、模型檢查點等。