論文筆記
資料
1.代碼地址
https://github.com/qinenergy/cotta
2.論文地址
https://arxiv.org/abs/2203.13591
3.數據集地址
論文摘要的翻譯
TTA的目的是在不使用任何源數據的情況下,將源預先訓練的模型適應到目標域。現有的工作主要考慮目標域是靜態的情況。然而,現實世界的機器感知系統運行在非靜態和不斷變化的環境中,其中目標域分布可能會隨著時間的推移而變化。現有的方法大多基于自訓練和熵正則化,可能會受到這些非平穩環境的影響。由于目標域中的分布隨時間移動,偽標簽變得不可靠。嘈雜的偽標簽會進一步導致錯誤累積和災難性的遺忘。為了解決這些問題,我們提出了一種連續測試時間適應方法(continual test-time adaptation,COTTA),該方法包括兩個部分。首先,我們建議通過使用通常更準確的加權平均和增廣平均預測來減少誤差積累。另一方面,為了避免災難性的遺忘,我們建議在每次迭代期間隨機將一小部分神經元恢復到源預先訓練的權重,以幫助長期保存源知識。該方法能夠對網絡中的所有參數進行長期自適應。CONTTA易于實施,并且可以很容易地整合到現成的預訓練的模型中。我們在四個分類任務和一個連續測試時間自適應的分割任務上證明了我們的方法的有效性,我們在這方面的表現優于現有的方法。
1 介紹
TTA旨在通過在推理時從未標記的測試(目標)數據中學習來適配源預先訓練的模型。由于源訓練數據和目標測試數據之間的域分布差異,需要進行自適應以獲得良好的性能。源數據通常被認為在推理時間內不可用,這使得它比無監督的域自適應更具挑戰性但更現實。
現有的測試時間自適應工作通常通過使用偽標記法或熵正則化來更新模型參數來處理源域和固定目標域之間的分布差異
然而,當目標測試數據來自一個不斷變化的環境時,它們可能是不穩定的。這有兩個方面的原因:
首先,在不斷變化的環境下,由于分布偏移,偽標簽變得更噪聲和錯誤校準。因此,早期預測錯誤更有可能導致誤差累積。
其次,由于模型長時間地不斷適應新的分布,來自源域的知識更難保存,導致災難性的遺忘。
這里主要介紹online continual test-time adaptation的實際問題。
如圖1所示,目標是從現成的源代碼預先訓練的模型開始,并不斷地使其適應當前的測試數據。
為了有效地使預先訓練源模型適應不斷變化的測試數據,我們提出了一種連續測試時間自適應方法(COTTA),解決了現有方法的兩個主要局限性。該方法的第一個組成部分旨在減少誤差累積。我們建議在自訓練框架下通過兩種不同的方法來提高偽標簽的質量。一方面,由于教師平均預測往往比標準模型[55]具有更高的質量,我們使用加權平均教師模型來提供更準確的預測。另一方面,對于領域差距較大的測試數據,我們使用了增廣平均預測來進一步提高偽標簽的質量。提出的方法的第二個組成部分旨在幫助保存源知識和避免遺忘。我們建議隨機地將網絡中的一小部分神經元恢復到預先訓練的源模型。通過減少錯誤積累和保存知識,CONTTA能夠在不斷變化的環境中進行長期適應,并使訓練網絡的所有參數成為可能。相比之下,以前的方法只能訓練BN的參數。
2論文的創新點
- 提出了一種連續的測試時間自適應方法COTTA,該方法能夠有效地使現成的源預訓練模型適應不斷變化的目標數據。
- 具體地說,通過使用更準確的加權平均和增廣平均偽標簽來減少誤差累積。
- 通過顯式地保存來自源模型的知識來緩解長期遺忘效應。
- 該方法顯著提高了分類基準和分割基準的continual test-time adaptation的性能。
3 Continual Test-Time Domain Adaptation方法的概述
3.1 問題定義
給定具有對源數據 ( X S , Y S ) (\mathcal{X^S},\mathcal{Y^S}) (XS,YS)訓練的參數 θ θ θ的現有預訓練模型 f θ 0 ( x ) f_{θ_0}(x) fθ0??(x),在不訪問任何源數據的情況下以在線方式不斷變化的目標域。順序地提供未標記的目標域數據 X T \mathcal{X^T} XT,并且該模型只能訪問當前時間步長的數據。在時間步長 t t t處,目標數據 X t T X^T_t XtT?被提供作為輸入,并且模型 f θ t f_{θ_t} fθt??需要做出預測 f θ t ( X t T ) f_{θ_t}(X^T_t) fθt??(XtT?),并相應地適應未來輸入 θ t → θ t + 1 θ_t→θ_{t+1} θt?→θt+1?。 X t T X^T_t XtT?的數據分布不斷變化。根據在線預測對該模型進行了評估。這種設置在很大程度上是由于機器感知應用在不斷變化的環境中的需求。我們在表1中列出了我們的在線連續測試時間適應設置與現有適應設置之間的主要區別。與以前專注于固定目標域的設置相比,我們考慮的是對不斷變化的目標環境的長期適應
3.2 方法
提出了一種用于在線連續測試時間自適應設置的自適應方法。該方法采用現成的源預訓練模型,并在線自適應不斷變化的目標數據。基于錯誤累積是自訓練框架中的關鍵瓶頸之一這一事實,我們提出使用加權和增強平均偽標簽來減少錯誤累積。此外,為了幫助減少連續適應中的遺忘,我們建議顯式保留來自源模型的信息。圖2顯示了所建議方法的概述。
3.2.1 Source Model
現有的測試時間自適應工作往往需要在源模型的訓練過程中進行特殊處理,以提高領域泛化能力,促進自適應。本方法不需要修改體系結構或額外的源訓練過程。因此,任何現有的預先訓練的模型都可以使用,而不需要對源進行重新培訓。
3.2.2 Weight-Averaged Pseudo-Labels
在給定目標數據 x t T x^T_t xtT?和模型 f θ t f_{θ_t} fθt??的情況下,自訓練框架下的共同測試時間目標是最小化預測 y ^ t T = f θ t ( x t T ) a \hat{y}_{t}^{T} = f_{\theta_{t}}(x_{t}^{T}) \mathrm{a} y^?tT?=fθt??(xtT?)a和偽標簽之間的交叉熵一致性。例如,直接使用模型預測本身作為偽標簽導致TENT[61]的訓練目標(即熵最小化)。雖然這對固定的目標域有效,但由于分布偏移,對于不斷變化的目標數據,偽標簽的質量可能會顯著下降。
由于觀察到訓練步驟中的加權平均模型通常比最終模型提供更準確的模型,我們使用加權平均教師模型 f θ ′ f_{\theta^{\prime}} fθ′?來生成偽標簽。在時間步長 t = 0 t=0 t=0時,教師網絡被初始化為與源預訓練網絡相同。在時間處于 t t t時,首先由教師 y ′ ^ t T = f θ t ′ ( x t T ) . \hat{y^{\prime}}_{t}^{T}=f_{\theta_{t}^{\prime}}(x_{t}^{T}). y′^?tT?=fθt′??(xtT?).生成偽標簽。
然后通過學生和教師預測之間的交叉點損失來更新學生 f θ t f_{θ_t} fθt??
L θ t ( x t T ) = ? ∑ c y ′ ^ t c T log ? y ^ t c T , ( 1 ) \mathcal{L}_{\theta_{t}}(x_{t}^{T})=-\sum_{c}\hat{y'}_{tc}^{T}\log\hat{y}_{tc}^{T},\quad(1) Lθt??(xtT?)=?c∑?y′^?tcT?logy^?tcT?,(1)
在使用公式1更新學生模型 θ t → θ t + 1 θ_t→θ_{t+1} θt?→θt+1?之后,我們使用學生權重通過指數移動平均來更新教師模型的權重 θ t + 1 ′ = α θ t ′ + ( 1 ? α ) θ t + 1 , ( 2 ) \theta'_{t+1}=\alpha\theta'_t+(1-\alpha)\theta_{t+1},\quad(2) θt+1′?=αθt′?+(1?α)θt+1?,(2)
其中,α是一個平滑因子。我們對輸入數據 x t T x^T_t xtT?的最終預測是 y ′ ^ t T \hat{y^{\prime}}_{t}^{T} y′^?tT?中具有最高概率的類。
重量平均一致性的好處有兩個。一方面,通過使用通常更準確的加權平均預測作為偽標簽目標,我們的模型在連續自適應過程中遭受的誤差累積較少。
另一方面,平均教師預測 y ′ ^ t T \hat{y^{\prime}}_{t}^{T} y′^?tT?編碼了過去迭代中來自模型的信息,因此在長期的連續適應中不太可能遭受災難性遺忘,并提高了對新的未知領域的泛化能力。
3.2.3 Augmentation-Averaged Pseudo-Labels
訓練時間內的數據擴充已被廣泛應用于提高模型的性能。對于不同的數據集,通常手動設計或搜索不同的擴充策略。雖然測試時間擴充也已被證明能夠提高穩健性,但擴充策略通常是針對特定數據集確定和固定的,而不考慮推理時間期間的分布變化。在不斷變化的環境下,測試分發可能會發生巨大變化,這可能會使增強策略無效。在這里,我們考慮了測試時間域的分布差異,并用預測置信度來逼近域差異。僅當域差異較大時才應用增強,以減少誤差累積。
y ′ ~ t T = 1 N ∑ i = 0 N ? 1 f θ t ′ ( arg ? i ( x t T ) ) , (3) y ′ t T = { y ′ ^ t T , if?conf ( f θ 0 ( x t T ) ) ≥ p t h y ′ ~ t T , otherwise , (4) \begin{aligned}\tilde{y'}_{t}^{T}&=\frac{1}{N}\sum_{i=0}^{N-1}f_{\theta_{t}^{'}}(\arg_{i}(x_{t}^{T})),&\text{(3)}\\{y'}_{t}^{T}&=\begin{cases}\hat{y'}_{t}^{T},&\text{if conf}(f_{\theta_{0}}(x_{t}^{T}))\geq p_{th}\\\tilde{y'}_{t}^{T},&\text{otherwise},\end{cases}&\text{(4)}\end{aligned} y′~?tT?y′tT??=N1?i=0∑N?1?fθt′??(argi?(xtT?)),={y′^?tT?,y′~?tT?,?if?conf(fθ0??(xtT?))≥pth?otherwise,??(3)(4)?
其中 y ′ ~ t T \widetilde{y^{\prime}}_{t}^{T} y′ ?tT?是來自教師模型的增廣平均預測, y ′ ^ t T \hat{y^{\prime}}_{t}^{T} y′^?tT?是來自教師模型的直接預測, c o n f ( f θ 0 ( X t T ) ) conf(f_{θ_0}(X_t^T)) conf(fθ0??(XtT?)) 是源預訓練模型對當前輸入 x t T x^T_t xtT?的預測置信度,以及 P t h P_{th} Pth?是置信度閾值。通過使用公式4中的預訓練模型 f θ 0 來 f_{θ_0}來 fθ0??來計算對當前輸入Xtt的預測置信度,我們試圖逼近源和當前域之間的域差異。我們假設較低的置信度表示較大的域間隙,而相對較高的置信度表示較小的域間隙。因此,當置信度高且大于閾值時,我們直接使用 y ′ ^ t T \hat{y^{\prime}}_{t}^{T} y′^?tT?作為我們的偽標簽,而不使用任何增廣。當置信度較低時,我們采用額外的N個隨機增強來進一步提高偽標簽的質量。過濾是至關重要的,過濾是至關重要的,因為我們觀察到隨機增強,因為我們觀察到,在具有小域間隙的自信樣本上的隨機增加有時會降低模型的性能。我們在補充材料中對這一觀察結果進行了詳細討論。總而言之,我們使用置信度來逼近域差異,并確定何時應用擴展。學生通過改進的偽標簽進行更新:
3.2.4 Stochastic Restoration
雖然更準確的偽標簽可以減少錯誤積累,但長期自我訓練的持續適應不可避免地會引入錯誤并導致遺忘。如果我們在數據序列中遇到強烈的域移,這個問題可能特別相關,因為強烈的分布移位會導致錯誤校準甚至錯誤的預測。在這種情況下,自我訓練可能只會強化錯誤的預測。更糟糕的是,在遇到硬性例子后,即使新數據沒有嚴重漂移,模型也可能因為不斷的適應而無法恢復。為了進一步解決災難性遺忘問題,我們提出了一種隨機恢復方法,該方法顯式地恢復源預先訓練模型中的知識。考慮基于時間步 t t t處的公式1的梯度更新之后的學生模型 f θ f_θ fθ?內的卷積層: x l + 1 = W t + 1 ? x l , ( 6 ) x_{l+1}=W_{t+1}*x_{l},\quad(6) xl+1?=Wt+1??xl?,(6)其中,?表示卷積運算, x l 和 x l + 1 x_l和x_{l + 1} xl?和xl+1?表示到該層的輸入和輸出, W t + 1 W_{t + 1} Wt+1?表示可訓練的卷積濾波器。建議的隨機恢復方法還通過以下方式更新權重 W W W: M ~ Bernoulli ( p ) , ( 7 ) W t + 1 = M ⊙ W 0 + ( 1 ? M ) ⊙ W t + 1 , ( 8 ) \begin{aligned}M&\sim\text{Bernoulli}(p),\quad&(7)\\W_{t+1}&=M\odot W_0+(1-M)\odot W_{t+1},\quad&(8)\end{aligned} MWt+1??~Bernoulli(p),=M⊙W0?+(1?M)⊙Wt+1?,?(7)(8)?其中同 ⊙ \odot ⊙表示逐個元素的乘法。 p p p是一個小的恢復概率, M 是與 W t + 1 M是與W_{t+1} M是與Wt+1?形狀相同的掩模張量。隨機恢復也可以看作是丟棄的一種特殊形式。通過隨機地將可訓練權值中的少量張量元素恢復到初始權值,網絡避免了距離初始源模型太遠的漂移,從而避免了災難性遺忘。此外,通過保存來自源模型的信息,我們能夠訓練所有可訓練的參數,而不會遭受模型崩潰的痛苦。這為自適應帶來了更多的容量,并且與僅訓練用于測試時間自適應的BN參數的熵最小化方法如算法1所示,將改進的偽標記法與隨機恢復相結合,得到了在線連續測試時間自適應(COTTA)方法。
4 論文實驗
五個連續測試時間自適應基準任務:CIFAR10-to-CIFAR10C(標準和漸進式)、CIFAR100-to-CIFAR100C、ImageNet-to-ImageNet-C以及用于語義分割的Cityscapses-to-ACDC上對我們的方法進行了評估。
4.1 Experiments on CIFAR10-to-CIFAR10C
我們首先評估了所提出的模型在CIFAR10到CIFAR10C任務上的有效性。我們將我們的方法與純源代碼基線和四種流行的方法進行了比較。
如表2所示,直接使用沒有自適應的預訓練模型產生了43.5%的高平均錯誤率,表明自適應是必要的。BN統計自適應方法保持網絡權重,并使用來自當前迭代的輸入數據的批量歸一化統計用于預測。該方法簡單且完全在線,在僅限源代碼的基線上顯著提高了性能。使用硬偽標簽來更新BN可訓練參數可以將錯誤率降低到19.8%。如果帳篷在線方法能夠訪問附加域信息,并在遇到新域時將其自身重置為初始的預訓練模型,則性能可以進一步提高到18.6%。然而,這樣的信息在實際應用中通常是不可用的。如果不能訪問這些附加信息,帳篷連續方法不會比BNStats Adapt方法產生任何改進。值得一提的是,在適應的早期階段,帳篷持續的表現優于國陣統計適應。然而,在觀察到三種類型的腐敗后,該模型很快就惡化了。這表明,由于誤差累積,基于帳篷的方法在長期持續適應下可能不穩定。通過使用加權平均一致性,我們提出的方法可以持續地優于上述所有方法。誤碼率顯著降低到16.2%。此外,由于我們的隨機恢復方法,它在長期內不會受到性能下降的影響。
這一部分的消融實驗
表2的下部分
4.2 Experiments on CIFAR100-to-CIFAR100C
為了進一步證明所提方法的有效性,我們在難度更大的CIFAR100to-CIFAR100C任務上進行了評估。表4總結了實驗結果。
4.3 Experiments on ImageNet-to-ImageNet-C
為了對所提出的方法進行更全面的評估,在嚴重性級別為5的10個不同的腐敗類型序列上進行了ImageNet到ImageNet-C的實驗。如表6所示,CONTA能夠持續地優于帳篷和其他競爭方法。±之后的數字是10種不同損壞類型序列的標準偏差。
4.4 Experiments on Cityscapes-to-ACDC
此外,我們還在更復雜的連續測試時間語義分割Cityscapesto-ACDC任務上對我們的方法進行了評估。實驗結果如表5所示。實驗結果表明,我們的方法對于語義分割任務也是有效的,并且對不同的體系結構選擇具有較強的魯棒性。我們提出的方法在基準的基礎上產生了1.9%的絕對改進,并且達到了58.6%的MIU.值得一提的是,BN統計適應和帳篷在這項任務中表現不佳,隨著時間的推移,性能會顯著下降。這在一定程度上是因為兩者都是專門為具有批歸一化層的網絡設計的,而Segformer中只有一個批歸一化層,而transform模型中的大多數歸一化層都基于LayerNorm。然而,我們的方法不依賴于特定的層,并且仍然可以在非常不同的體系結構上有效地完成這項更復雜的任務。改進的性能在經過相對較長的時間不斷調整后也基本保持不變。
5 總結
在這項工作中,關注的是在非靜態環境中的連續測試時間適應,其中目標域分布可以隨著時間的推移而不斷變化。為了解決這種方法中的誤差累積和災難性遺忘問題,我們提出了一種新的方法COTTA,該方法包括兩部分。==首先,我們通過使用加權平均和增廣平均預測來減少誤差積累,這兩種預測往往更準確。==其次,為了保存來自源模型的知識,我們隨機地將一小部分權重恢復到源預先訓練的權重。所提出的方法可以結合到現成的預訓練模型中,而不需要對源數據的任何訪問。在4個分類和1個分割任務上驗證了COTTA的有效性。