簡介
簡介:這次學習的OpenGAN主要學習一個思路,跳出傳統GAN對于判斷真假的識別到判斷是已知種類還是未知種類。重點內容不在于代碼而是思路,會簡要給出一個設計的代碼。
論文題目:OpenGAN: Open-Set Recognition via Open Data Generation(基于開放數據生成的開放集識別)
期刊:IEEE TRANSACTIONS ON PATTERN ANALYSIS AND MACHINE INTELLIGENCE(if=20.8,超級Top)
摘要:現實世界的機器學習系統需要分析可能與訓練數據不同的測試數據。 在K-way分類中,這被清晰地表述為開集識別,其核心是區分K個閉集類之外的開集數據的能力。 開集判別的兩個概念是:1)通過利用一些離群數據作為開集判別學習開-閉二元判別器; 2)利用GAN的判別器作為開集似然函數,對閉集數據分布進行無監督學習。 然而,由于對訓練異常值的過度擬合,前者對各種開放測試數據的泛化效果較差,而訓練異常值不太可能詳盡地跨越開放世界。 后者不能很好地工作,可能是由于gan的訓練不穩定。 在上述的激勵下,我們提出了OpenGAN,它通過將每種方法與幾個技術見解相結合來解決每種方法的局限性。 首先,我們證明了在一些真實的離群數據上精心選擇的gan鑒別器已經達到了最先進的水平。 其次,我們用對抗合成的“假”數據增強真實開放訓練樣本的可用集。 第三,也是最重要的,我們在封閉世界K-way網絡計算的特征上建立了鑒別器。 這使得OpenGAN可以通過建立在現有K-way網絡之上的輕量級鑒別器頭來實現。 大量的實驗表明,OpenGAN顯著優于先前的開集方法。
問題背景的具體例子
想象你要訓練一個自動駕駛汽車的視覺系統:
- 訓練數據:汽車、行人、紅綠燈、建筑物等19個類別
- 現實問題:路上突然出現嬰兒車、街頭小攤等訓練時沒見過的物體
- 危險后果:系統可能把嬰兒車錯誤識別為"摩托車",導致不當的避讓策略
OpenGAN的具體操作流程
第一步:準備基礎設施
1. 已有一個訓練好的K類分類器(比如識別19種交通場景物體)
2. 從這個分類器的倒數第二層提取特征(不是直接用原始圖片)
3. 收集少量"其他"類別的數據作為已知的異常樣本
為什么用特征而不用像素?
- 原始圖片:1024×2048×3 = 600萬維度,太復雜
- 提取的特征:可能只有512維,包含了高級語義信息
- 就像人類識別物體時關注的是形狀、紋理,而不是每個像素點
第二步:構建OpenGAN架構
OpenGAN包含兩個核心組件:
判別器D(Discriminator):
- 輸入:特征向量
- 輸出:該特征屬于"已知類別"的概率
- 作用:區分"已知"vs"未知"
生成器G(Generator):
- 輸入:隨機噪聲
- 輸出:假的"未知類別"特征
- 作用:生成更多樣的"未知"樣本來訓練判別器
第三步:對抗訓練過程
這是核心!OpenGAN使用三種數據同時訓練:
# 偽代碼展示訓練過程
for epoch in training:# 1. 真實的已知類別數據real_closed_features = extract_features(known_class_images)# 2. 真實的未知類別數據(少量)real_open_features = extract_features(outlier_images)# 3. 生成器創造的假未知數據noise = random_noise()fake_open_features = Generator(noise)# 訓練判別器:讓它學會區分已知和未知d_loss = - log(D(real_closed_features)) # 已知類別應該得高分- λ_o * log(1-D(real_open_features)) # 真實未知應該得低分 - λ_G * log(1-D(fake_open_features)) # 生成未知應該得低分# 訓練生成器:讓它生成能騙過判別器的"未知"特征g_loss = - log(1-D(Generator(noise))) # 生成的特征要能騙過判別器
關鍵參數解釋:
λ_o
:控制真實異常數據的重要性λ_G
:控制生成假數據的重要性- 當
λ_o=0
時,就是OpenGAN-0(不使用真實異常數據)
第四步:模型選擇的巧妙方法
傳統GAN的問題:
- 訓練到最后,生成器太強了,判別器分不出真假
- 判別器失去了區分"已知"和"未知"的能力
OpenGAN的解決方案:
- 在訓練過程中保存多個判別器快照
- 用少量真實的異常數據作為驗證集
- 選擇在驗證集上表現最好的判別器
# 模型選擇過程
best_auroc = 0
best_discriminator = Nonefor checkpoint in training_checkpoints:discriminator = load_checkpoint(checkpoint)auroc = evaluate_on_validation_set(discriminator, validation_outliers)if auroc > best_auroc:best_auroc = aurocbest_discriminator = discriminator# 使用最佳判別器進行最終預測
第五步:實際應用
在測試時:
def predict_open_set(test_image):# 1. 提取特征features = pretrained_classifier.extract_features(test_image)# 2. 用OpenGAN判別器打分confidence = best_discriminator(features)# 3. 做決策if confidence > threshold:# 進行正常的K類分類class_prediction = pretrained_classifier(test_image)return class_predictionelse:return "UNKNOWN_OBJECT" # 未知物體