線性探針是什么
線性探針是一種在機器學習和相關領域廣泛應用的技術,用于評估預訓練模型特征、檢測數據中的特定序列等。在不同的應用場景下,線性探針有著不同的實現方式和作用:
-
評估預訓練模型特征:在機器學習中,線性探針是一種評估預訓練模型“特征遷移能力”的標準化方法。其核心是在凍結預訓練模型所有參數的情況下,僅用極少的標注數據(每個類別幾個樣本)訓練一個簡單的線性分類器(如邏輯回歸) 。通過這種方式來測試預訓練模型提取的特征是否足夠通用。
-
例如,有一個預訓練好的視覺模型CLIP,目標任務是“識別10種罕見鳥類”,且每個鳥類只有4張標注照片(4-shot)。此時,凍結CLIP的圖像編碼器,只訓練一個“512維→10類”的線性分類器(僅一層全連接層),用這4張/類的數據訓練。如果分類器準確率高,說明CLIP的圖像特征已經“理解”了“鳥的種類”的語義,特征遷移能力強 。
-
理解神經網絡中間層特征:可以用于監控模型每一層的特征,并衡量它們是否適合分類,以此來更好地理解中間層的角色和特點。例如在對流行的Inceptionv3和RESNET-50的研究中?