谷歌大腦出品
paper: https://arxiv.org/abs/1805.09501
這里是個論文的閱讀心得,筆記,不等同論文全部內容
文章目錄
- 一、摘要
- 1.1 翻譯
- 1.2 筆記
- 二、(第三部分)自動增強:直接在感興趣的數據集上搜索最佳增強策略
- 2.1 翻譯
- 2.2 筆記
- 三、(第四部分)實驗與結果
- 3.1 翻譯
- 3.2 筆記
- 四、跳出論文,轉入應用——timm包
- 3.1 timm包的自動增強搜索策略
- 3.2 隨機增強參數解釋
- 3.3 策略增強的imagenet官方給的參數注釋
- 3.4 數據增強效果實驗
一、摘要
1.1 翻譯
數據增強是提高現代圖像分類器準確率的一種有效技術。然而,當前的數據增強實現是手工設計的。在本文中,我們描述了一個稱為AutoAugment的簡單過程,用于自動搜索改進的數據增強策略。在我們的實現中,我們設計了一個搜索空間,其中一個策略由許多子策略組成,其中一個子策略是為每個mini-batch中的每個圖像隨機選擇的。子策略由兩個操作組成,每個操作都是一個圖像處理函數,如平移、旋轉或剪切,以及應用這些函數的概率和大小。
我們使用搜索算法來找到最佳策略,使神經網絡在目標數據集上產生最高的驗證精度。我們的方法在CIFAR-10、CIFAR-100、SVHN和ImageNet上達到了最先進的精度(不需要額外的數據)。在ImageNet上,我們獲得了83.5%的Top-1準確率,比之前83.1%的記錄提高了0.4%。在CIFAR-10上,我們實現了1.5%的錯誤率,比以前的核心狀態好0.6%。我們發現增強策略在數據集之間是可轉移的。在ImageNet上學習的策略可以很好地轉移到其他數據集上,例如Oxford Flowers、Caltech-101、Oxford- iit Pets、FGVC Aircraft和Stanford Cars。
1.2 筆記
主要陳述了自動數據增廣的概念,這里比較重要的,我比較感興趣的是搜索算法來找到最佳策略,結合第一段,也就是作者提出一個搜索空間,一個策略分解為多個子策略,子策略也是隨機生成,而且每個子策略有2個數據增強的方法,然后搜索出最佳的子策略,然后表現在各大數據集上效果不錯,轉移到其他數據集也可以。
這里我比較好奇,如何去搜索最佳的策略?所以往下直接看方法。
二、(第三部分)自動增強:直接在感興趣的數據集上搜索最佳增強策略
2.1 翻譯
我們將尋找最佳增強策略的問題表述為一個離散搜索問題(參見圖1)。我們的方法由兩個部分組成:搜索算法和搜索空間。在高層次上,搜索算法(作為控制器RNN實現)對數據增強策略S進行采樣,該策略包含要使用的圖像處理操作、在每個批處理中使用該操作的概率以及操作的大小等信息。我們方法的關鍵是策略S將用于訓練具有固定架構的神經網絡,其驗證精度R將被發送回更新控制器。由于R不可微,控制器將通過策略梯度方法進行更新。在下一節中,我們將詳細描述這兩個組件。
圖1:概述我們使用搜索方法(例如,強化學習)來搜索更好的數據增強策略的框架。控制器RNN從搜索空間預測增強策略。具有固定架構的子網絡被訓練到收斂,達到精度R。獎勵R將與策略梯度方法一起使用來更新控制器,以便它可以隨著時間的推移生成更好的策略。
搜索空間細節:在我們的搜索空間中,一個策略由5個子策略組成,每個子策略由順序應用的兩個圖像操作組成。此外,每個操作還與兩個超參數相關聯:1)應用該操作的概率,以及2)操作的幅度。圖2顯示了在我們的搜索空間中具有5個子策略的策略示例。第一個子策略指定ShearX的順序應用程序,然后是Invert。這個應用ShearX的概率為0.9,當應用時,其大小為7(滿分為10)。然后我們以0.8的概率應用Invert。反相操作不使用幅度信息。我們強調這些操作是按照指定的順序進行的。
圖2:在SVHN上發現的策略之一,以及如何使用它來生成增強數據給定用于訓練神經網絡的原始圖像。該策略有5個子策略。對于小批處理中的每個圖像,我們均勻隨機地選擇一個子策略來生成變換后的圖像來訓練神經網絡。每個子策略由2個操作組成,每個操作與兩個數值相關聯:調用操作的概率和操作的大小。有可能調用某個操作,因此該操作可能不會應用到該小批處理中。但是,如果施加,則以固定的幅度施加。我們通過展示如何在不同的小批量中對一個圖像進行不同的轉換來強調應用子策略的隨機性,即使使用相同的子策略。正如文中所解釋的,在SVHN上,AutoAugment更經常地選擇幾何變換。可以看出為什么在SVHN上通常選擇反轉操作,因為圖像中的數字對于該變換是不變的。
我們在實驗中使用的操作來自PIL,一個流行的Python圖像庫為了通用性,我們考慮PIL中所有函數接受圖像作為輸入和輸出一個圖像。我們還使用了另外兩種很有前景的增強技術:Cutout[12]和samplep播[24]。我們搜索的操作是ShearX/Y, TranslateX/Y, Rotate, AutoContrast, Invert, Equalize, solalize, Posterize, Contrast, Color, Brightness, sharpening, cutout [12], Sample Pairing [24].總的來說,我們在搜索空間中有16項操作。每個操作還附帶一個默認的幅度范圍,這將在第4節中更詳細地描述。我們將震級范圍離散為10個值(均勻間隔),這樣我們就可以使用離散搜索算法來找到它們。同樣,我們也將應用該操作的概率離散為11個值(均勻間隔)。在(16×10×11)2種可能性的空間中查找每個子策略成為一個搜索問題。然而,我們的目標是同時找到5個這樣的子政策,以增加多樣性。有5個子策略的搜索空間大約有(16×10×11)10≈2.9×1032種可能性。
我們使用的16個操作及其默認值范圍如附錄中的表1所示。注意,在我們的搜索空間中沒有顯式的“Identity”操作;這個操作是隱式的,可以通過調用一個概率設置為0的操作來實現。
控制器在搜索過程中可以選擇的所有圖像轉換的列表。此外,控制器在搜索每個操作期間可以預測的幅度值如第三列所示(對于圖像大小為331x331)。有些變換不使用幅度信息(例如逆變和均衡)。
搜索算法細節: 我們在實驗中使用的搜索算法使用了強化學習,靈感來自[71,4,72,5]。搜索算法由兩個部分組成:控制器(遞歸神經網絡)和訓練算法(鄰域策略優化算法)[53]。在每一步,控制器預測由softmax產生的決策;然后將預測作為嵌入饋送到下一步。為了預測5個子策略,控制器總共有30個softmax預測,每個子策略有2個操作,每個操作需要操作類型、大小和概率。
控制器RNN的訓練: 控制器使用獎勵信號進行訓練,這表明該策略在改善“子模型”(作為搜索過程一部分訓練的神經網絡)的泛化方面有多好。在我們的實驗中,我們設置了一個驗證集來度量子模型的泛化。通過在訓練集(不包含驗證集)上應用5個子策略生成的增強數據來訓練子模型。對于mini-batch中的每個示例,隨機選擇5個子策略中的一個來增強圖像。然后在驗證集上評估子模型以測量準確性,并將其用作訓練循環網絡控制器的獎勵信號。在每個數據集上,控制器對大約15,000個策略進行采樣。
控制器RNN的架構和訓練超參數: 我們遵循[72]中的訓練過程和超參數來訓練控制器。更具體地,控制器RNN是一個單層LSTM[21],每層有100個隱藏單元,對與每個架構決策相關的兩個卷積單元(其中B通常為5)進行2 × 5B softmax預測。控制器RNN的10B個預測中的每一個都與一個概率相關聯。子網絡的聯合概率是這10B軟最大值的所有概率的乘積。該聯合概率用于計算控制器RNN的梯度。根據子網絡的驗證精度縮放梯度,以更新控制器RNN,使控制器為壞的子網絡分配低概率,為好的子網絡分配高概率。與[72]類似,我們采用學習率為0.00035的近端策略優化(PPO)[53]。為了鼓勵探索,我們還使用了權重為0.00001的熵懲罰。在我們的實現中,基線函數是先前獎勵的指數移動平均值,權重為0.95。控制器的權重在-0.1到0.1之間均勻初始化。出于方便,我們選擇使用PPO來訓練控制器,盡管先前的工作表明,其他方法(例如增強隨機搜索和進化策略)可以表現得同樣好,甚至略好[30]。
在搜索結束時,我們將最佳5個策略中的子策略連接到單個策略中(包含25個子策略)。最后這個包含25個子策略的策略用于訓練每個數據集的模型。
上述搜索算法是我們可以用來尋找最佳策略的許多可能的搜索算法之一。也許可以使用不同的離散搜索算法,如遺傳規劃[48]甚至隨機搜索[6]來改進本文的結果。
【關于訓練迭代數在5部分Discuss有提到,這里放在一起】
訓練步驟與子策略數量之間的關系:我們工作的一個重要方面是子策略在訓練過程中的隨機應用。每個圖像僅由每個小批中可用的許多子策略中的一個增強,子策略本身具有進一步的隨機性,因為每個轉換都有與其關聯的應用程序的概率。我們發現這種隨機性要求每個子策略有一定數量的epoch才能使AutoAugment有效。由于子模型每次訓練都用5個子策略,因此在模型完全受益于所有子策略之前,它們需要訓練超過80-100個epoch。 這就是為什么我們選擇訓練我們的child模型為120個epochs。每個子策略需要應用一定的次數,模型才能從中受益。在策略被學習之后,完整的模型被訓練更長的時間(例如CIFAR-10上的Shake-Shake訓練1800個epoch, ImageNet上的ResNet-50訓練270個epoch),這允許我們使用更多的子策略。
2.2 筆記
這里講了自動搜索算法,類似訓練的概念,學習出一個最優的數據增強策略,討論部分也提到了要更多的epoch來搜索。后面的實驗也就是設置一個基準,然后跟沒有用autoaugment或者跟其他方法比較,最后討論和消融實驗。這里就不往下看了,感興趣可以直接進最上面的原文鏈接看原文。
這里再看下應用,研究autoAug也是因為需要提升訓練精度,然后在timm包里發現了這個,進而來研究下,下面再做一下timm里面的學習筆記。
三、(第四部分)實驗與結果
看一下具體的實驗是怎么做的,比如:如何設置基準,具體實驗在什么基礎上去選取最佳策略
3.1 翻譯
實驗總結。在本節中,我們在autoaugmentdirect和AutoAugmenttransfer兩個用例中實證地研究了AutoAugment的性能。首先,我們將對AutoAugment進行基準測試,在高度競爭的數據集上直接搜索最佳增強策略:CIFAR-10[28]、CIFAR-100[28]、SVHN42和ImageNet10數據集。我們的研究結果表明,AutoAugment的直接應用顯著改善了基線模型,并在這些具有挑戰性的數據集上產生了最先進的精度。接下來,我們將研究增強策略在數據集之間的可轉移性。更具體地說,我們將把在ImageNet上找到的最佳增強策略轉移到細粒度分類數據集,如Oxford 102 Flowers、Caltech-101、Oxford- iiit Pets、FGVC Aircraft、Stanford Cars(第4.3節)。我們的研究結果還表明,增強策略具有驚人的可轉移性,并且在這些數據集上的強基線模型上產生了顯著的改進。最后,在第5節中,我們將進行比較AutoAugment與其他自動數據增強方法,并表明AutoAugment明顯更好。
CIFAR-10, CIFAR-100, SVHN結果
雖然CIFAR-10有50,000個訓練樣本,但我們在一個較小的數據集上執行最佳策略搜索,我們稱之為“減少的CIFAR-10”,它由4,000個隨機選擇的樣本組成,以節省在增強搜索過程中訓練子模型的時間(我們發現結果策略似乎對這個數字并不敏感)。我們發現,對于固定的訓練時間,允許子模型訓練更多的epoch比使用更多的訓練數據訓練更少的epoch更有用。對于子模型架構,我們使用小型WideResNet-40-2(40層-拓寬因子為2)模型[67],并進行120次epoch的訓練。使用小型Wide-ResNet是為了提高計算效率,因為每個子模型都是從頭開始訓練以計算控制器的梯度更新。我們使用10?4的權重衰減,0.01的學習率和一個退火周期的余弦學習衰減[36]。
在縮減后的CIFAR10上搜索到的策略隨后用于訓練CIFAR-10、縮減后的CIFAR-10和CIFAR-100上的最終模型。如上所述,我們將最好的5個策略中的子策略連接起來,形成具有25個子策略的單個策略,用于CIFAR數據集上的所有AutoAugment實驗。
基線預處理遵循最先進的CIFAR-10模型的慣例:標準化數據,使用50%概率的水平翻轉,零填充和隨機作物,最后使用16 × 16像素的Cutout[17,65,48,72]。除了標準基線預處理之外,還應用了AutoAugment策略:在一張圖像上,我們首先應用現有基線方法提供的基線增強,然后應用AutoAugment策略,然后應用Cutout。我們沒有優化cut - out區域的大小,使用16像素的建議值[12]。注意,由于Cutout是搜索空間中的一個操作,因此可以在同一圖像上使用兩次Cutout:第一次使用學習的區域大小,第二次使用固定的區域大小。在實踐中,由于第一次應用程序中Cutout操作的概率很小,因此通常對給定圖像使用一次Cutout。
在CIFAR-10上,AutoAugment主要選擇基于顏色的轉換。例如,CIFAR-10上最常用的轉換是均衡、自動對比度、顏色和亮度(參見附錄中的表1了解它們的描述)。像ShearX和ShearY這樣的幾何變換很少出現在好的策略中。此外,在成功的策略中幾乎從未應用過轉換Invert。在CIFAR-10上找到的策略包含在附錄中。下面,我們描述結果在CIFAR數據集使用在簡化的CIFAR-10上找到的策略。所有報告的結果都是5次運行的平均值。
CIFAR-10結果。 在表2中,我們展示了不同神經網絡架構上的測試集精度。我們在TensorFlow[1]中實現了Wide-ResNet-28-10[67]、Shake-Shake[17]和ShakeDrop[65]模型,并找到了權重衰減和學習率超參數,它們為常規訓練提供了最佳的驗證集準確性。其他超參數與介紹模型的論文中報道的相同[67,17,65],除了對Wide-ResNet-28-10使用余弦學習衰減。然后,我們使用相同的模型和超參數來評估AutoAugment的測試集準確性。對于amoebanet,我們使用了與[48]中用于基線增強和自動增強相同的超參數。從表中可以看出,我們使用ShakeDrop[65]模型的錯誤率為1.5%,比目前的狀態[48]提高了0.6%。值得注意的是,這個增益要比之前的AmoebaNet-B對ShakeDrop的增益(+0.2%)和ShakeDrop對Shake-Shake的增益(+0.2%)大得多。參考文獻[68]報告了在CIFAR-10上訓練的WideResNet-28-10模型的改進幅度為1.1%。
我們還在最近提出的CIFAR-10測試集上評估了使用AutoAugment訓練的最佳模型[50]。Recht等人[50]報告說,Shake-Shake (26 2x64d) + Cutout在這個新數據集上表現最好,錯誤率為7.0%(相對于原始CIFAR10測試集的錯誤率高4.1%)。此外,PyramidNet+ShakeDrop在新數據集上的錯誤率為7.7%(相對于原始測試集高4.6%)。我們最好的模型,使用AutoAugment訓練的PyramidNet+ShakeDrop的錯誤率為4.4%(比原始集的錯誤率高2.9%)。與在這個新數據集上評估的其他模型相比,我們的模型在準確性上的下降幅度要小得多。
3.2 筆記
是在小的數據集上先搜索出策略,然后從最好的5個策略中選擇子策略組成25個子策略再訓練整個數據集。對于CIFAR-10數據集,ResNet-40-2是訓練120個epoch去搜索最佳策略。全程沒有提到預訓練模型,要不默認,要不就沒用預訓練模型。我估計還是默認預訓練模型,然后都用預訓練模型作為baseline去比較。
四、跳出論文,轉入應用——timm包
參考:https://timm.fast.ai/AutoAugment#auto_augment_policy
原文:
在本教程中,我們將了解如何利用 AutoAugment 作為一種數據增強技術來訓練神經網絡。
我們看:
- 我們如何使用 timm 訓練腳本來應用 AutoAugment 。
- 我們如何使用 AutoAugment 作為自定義訓練循環的獨立數據增強技術。
- 深入研究 AutoAugment 的源代碼。
理解:
發現這里只是用了論文的預設結論或者其他的結論生成的一些策略,以及一些增強算子隨機增強。以下是對自動增強策略的解讀,以及實驗看下每個隨機增強的效果。
3.1 timm包的自動增強搜索策略
其中timm包的自動增強搜索策略包含:
- AutoContrast: 自動對比度調整。
- Equalize: 直方圖均衡化。
- Invert: 反轉圖像顏色。
- Rotate: 隨機旋轉圖像。
- Posterize: 減少圖像的色階。
- Solarize: 部分地反轉圖像的像素值。
- SolarizeAdd: 在圖像上添加一些反轉效果。
- Color: 隨機調整圖像的顏色。
- Contrast: 隨機調整圖像的對比度。
- Brightness: 隨機調整圖像的亮度。
- Sharpness: 隨機調整圖像的銳度。
- ShearX: 沿著 X 軸隨機剪切圖像。
- ShearY: 沿著 Y 軸隨機剪切圖像。
- TranslateXRel: 沿著 X 軸相對隨機平移圖像。
- TranslateYRel: 沿著 Y 軸相對隨機平移圖像。
3.2 隨機增強參數解釋
rand_augment_transform函數的注釋
這段代碼是用于創建一個 RandAugment 變換的函數。RandAugment 是一種數據增強的方法,通過對圖像應用一系列隨機的數據變換來增加訓練數據的多樣性。
這個函數接受兩個參數:
- config_str:一個字符串,定義了隨機增強的配置。這個字符串包括多個部分,由破折號(‘-’)分隔。第一個部分定義了具體的 RandAugment 變體(目前只有 ‘rand’)。其余的部分用于確定具體的配置參數,包括:
- ‘m’:整數,表示 RandAugment 的幅度(magnitude)。
- ‘n’:整數,表示每個圖像選擇的變換操作的數量。
- ‘w’:整數,表示概率權重的索引(一組權重集合的索引,用于影響操作的選擇)。
- ‘mstd’:浮點數,表示幅度噪聲的標準差,或者如果是無窮大(或大于100),則進行均勻采樣。
- ‘mmax’:設置幅度的上限,而不是默認的 _LEVEL_DENOM(10)。
- ‘inc’:整數(布爾值),表示是否使用隨著幅度增加而增加的增強(默認為0)。
- hparams:其他的超參數(關鍵字參數),用于配置 RandAugmentation 方案。
最終,這個函數返回一個與 PyTorch 兼容的變換(Transform),可以用于數據增強。這個變換將在訓練過程中被應用于圖像數據。
3.3 策略增強的imagenet官方給的參數注釋
policy = [[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],[('Rotate', 0.8, 8), ('Color', 0.4, 0)],[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],[('Color', 0.6, 4), ('Contrast', 1.0, 8)],[('Rotate', 0.8, 8), ('Color', 1.0, 2)],[('Color', 0.8, 8), ('Solarize', 0.8, 7)],[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],[('Color', 0.4, 0), ('Equalize', 0.6, 3)],[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],[('Color', 0.6, 4), ('Contrast', 1.0, 8)],[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
分別是:變換名,變換概率,變換強度
3.4 數據增強效果實驗
from timm.data.auto_augment import AugmentOp
from PIL import Image
from matplotlib import pyplot as pltimg_path = r"/path/to/imagenet-mini/val/n01537544/ILSVRC2012_val_00023438.JPEG"
mean = (0.485, 0.456, 0.406)
X = Image.open(img_path)
img_size_min = min(X.size)
plt.imshow(X)
plt.show()all_policy_use_op = [['AutoContrast', 1, 10], ['Equalize', 1, 10], ['Invert', 1, 10], ['Rotate', 1, 10], ['Posterize', 1, 10],['PosterizeIncreasing', 1, 10], ['PosterizeOriginal', 1, 10], ['Solarize', 1, 10], ['SolarizeIncreasing', 1, 10],['SolarizeAdd', 1, 10], ['Color', 1, 10], ['ColorIncreasing', 1, 10], ['Contrast', 1, 10],['ContrastIncreasing', 1, 10], ['Brightness', 1, 10], ['BrightnessIncreasing', 1, 10], ['Sharpness', 1, 10],['SharpnessIncreasing', 1, 10], ['ShearX', 1, 10], ['ShearY', 1, 10], ['TranslateX', 1, 10], ['TranslateY', 1, 10],['TranslateXRel', 1, 10], ['TranslateYRel', 1, 10]
]for op_name, p, m in all_policy_use_op:aug_op = AugmentOp(name=op_name, prob=p, magnitude=m,hparams={'translate_const': int(img_size_min * 0.45),'img_mean': tuple([min(255, round(255 * x)) for x in mean])})plt.imshow(aug_op(X))plt.title(f'{op_name}_{str(p)}_{str(m)}')plt.show()
原圖
AutoContrast
Equalize
Invert
Rotate
Posterize
PosterizeIncreasing
PosterizeOriginal
Solarize
SolarizeIncreasing
SolarizeAdd
Color
ColorIncreasing
Contrast
ContrastIncreasing
Brightness
BrightnessIncreasing
Sharpness
SharpnessIncreasing
ShearX
ShearY
TranslateX
TranslateY
TranslateXRel
TranslateYRel