1. 什么是免訓練指標(Zero-Cost Proxies,ZC proxies)?
免訓練指標是一類 無需完整訓練模型即可評估其性能的度量方法,主要用于提高 神經架構搜索(NAS) 的效率。
傳統 NAS 需要訓練候選架構來評估其性能,但訓練消耗巨大,因此免訓練指標提供了一種 基于模型本身特性(如梯度、參數分布)快速估計模型質量的方法。
核心思想:
只用一個小批量數據 計算某些統計量(如梯度、參數重要性、激活值分布),從而 近似衡量模型的好壞,而不需要完整訓練整個模型。
2. 免訓練指標的類別
免訓練指標可以大致分為兩類:
- 傳統結構分析指標(如 SNIP、Synflow、Fisher)
- 基于知識蒸餾的指標(如 DisWOT)
(1)傳統結構分析指標
這些方法通常通過計算 梯度、權重、Hessian 矩陣 等信息來評估模型的質量。
① SNIP(Single-shot Network Pruning)
- 計算梯度的重要性,衡量每個參數對損失函數的影響:
ρ s n i p = ∣ ? L ? W ⊙ W ∣ \rho_{snip} = \left| \frac{\partial \mathcal{L}}{\partial \mathcal{W}} \odot \mathcal{W} \right| ρsnip?= ??W?L?⊙W ? - 核心思想:如果去掉某個權重后損失變化較大,則該權重很重要。因此,可以用梯度信息估算整個網絡的質量。
② Synflow
- 通過梯度流分析,避免層塌陷(layer collapse):
ρ s y n f l o w = ? L ? W ⊙ W \rho_{synflow} = \frac{\partial \mathcal{L}}{\partial \mathcal{W}} \odot \mathcal{W} ρsynflow?=?W?L?⊙W - 核心思想:確保不同層的梯度能夠均勻流動,以保持架構的穩定性。
③ Fisher
- 計算激活梯度的平方和,用于通道剪枝:
ρ f i s h e r = ( ? L ? A A ) 2 \rho_{fisher} = \left( \frac{\partial \mathcal{L}}{\partial \mathcal{A}} \mathcal{A} \right)^2 ρfisher?=(?A?L?A)2 - 核心思想:通道(Channel)如果對梯度變化敏感,則在訓練時影響更大,可以用它來衡量模型質量。
(2)基于知識蒸餾的指標
DisWOT(Distillation Without Training)
-
這是一種 基于知識蒸餾的免訓練指標,通過計算 教師-學生模型的特征匹配誤差 來評估網絡質量:
ρ D i s W O T = D L 2 ( G ( [ A S , A T ] ) ) + D L 2 ( G ( [ F S , F T ] ) ) \rho_{DisWOT} = \mathcal{D}_{L2} (\mathcal{G}([AS,AT])) + \mathcal{D}_{L2} (\mathcal{G}([FS,FT])) ρDisWOT?=DL2?(G([AS,AT]))+DL2?(G([FS,FT])) -
其中:
- ( AS, AT ) 是教師-學生模型的 激活圖(Activation Maps)
- ( FS, FT ) 是教師-學生模型的 特征圖(Feature Maps)
- ( \mathcal{D}_{L2} ) 計算的是 L2 距離(歐幾里得距離),衡量特征匹配誤差
-
核心思想:如果一個模型可以很好地模仿教師模型的特征分布(即 L2 誤差小),則這個模型的質量更好。
3. 免訓練指標如何用于 NAS
在 NAS 中,免訓練指標可以用于:
- 快速評估候選架構
- 在搜索空間中 篩選掉性能較差的架構,減少訓練計算量。
- 結合搜索算法優化架構
- 可以將 梯度信息(SNIP, Synflow) 或 知識蒸餾誤差(DisWOT) 作為搜索目標,指導 NAS 選擇更優的架構。
- 設計高效的蒸餾感知 NAS(DAS)
- 結合 DAS(Distillation-aware Architecture Search),讓 NAS 選擇對知識蒸餾更友好的模型,提高輕量化模型的性能。