提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔
文章目錄
- 摘要
- Abstract
- 一、方法介紹
- 2.Rainbow Memory(RM)
- 2.1多樣性感知內存更新
- 2.2通過數據增強增強樣本多樣性(DA)
- 二、使用步驟
- 1.實驗概況
- 2.RM核心代碼
- 總結
摘要
本博客概述了文章《Rainbow Memory: Continual Learning with a Memory of Diverse Samples》聚焦于任務邊界模糊的持續學習場景,提出基于樣本分類不確定性和數據增強的Rainbow Memory (RM)記憶管理策略。多數研究在任務不共享類別的較人為的設置下評估相關方法,但在現實世界應用場景中,任務之間的類分布是不斷變化的,更現實和實用的是任務共享類別的模糊CIL設置。在這種設置下,之前存儲少量舊數據的方法雖在緩解災難性遺忘方面有成果,但也引出了如何管理記憶(memory)的最優策略問題。基于該問題,研究者在新定義的模糊CIL設置下更好地持續學習的兩個因素:記憶的采樣和記憶中的數據增強,進而提出Rainbow Memory(RM)方法。通過在MNIST、CIFAR10、CIFAR100和ImageNet數據集上的實證驗證,RM在模糊持續學習設置中顯著提高了準確性,大幅超越現有技術。
文章鏈接
Abstract
This blog summarizes the article “Rainbow Memory: Continual Learning with a Memory of Diverse Samples”, which focuses on the continuous learning scenario with fuzzy task boundaries, and proposes a Rainbow Memory (RM) memory management strategy based on sample classification uncertainty and data augmentation. Most studies evaluate the relevant methods in a more artificial setting where tasks do not share categories, but in real-world application scenarios, the class distribution between tasks is constantly changing, and it is more realistic and practical to see the fuzzy CIL settings of task sharing categories. In this setting, the previous method of storing a small amount of old data has been successful in mitigating catastrophic forgetting, but it also raises the question of the optimal strategy for managing memory. Based on this problem, the researchers proposed a rainbow memory (RM) method for better continuous learning under the newly defined fuzzy CIL setting: memory sampling and data enhancement in memory. Through empirical verification on MNIST, CIFAR10, CIFAR100, and ImageNet datasets, RM significantly improves accuracy in fuzzy continuous learning settings, significantly outperforming existing technologies.
一、方法介紹
模糊類增量學習的設置要求如下:1)每個任務作為流順序地給出,(2)大多數(分配的)任務類別彼此不同,以及(3)模型只能利用先前任務的非常小的一部分數據。 如下圖所示,在模糊CIL中,任務共享類,與傳統的不相交CIL相反。建議的記憶管理策略更新的情景記憶與當前任務的樣本,以保持不同的樣本在內存中。數據擴充(DA)進一步增強了內存中樣本的多樣性。
2.Rainbow Memory(RM)
在模糊類增量學習的場景中,現有方法因樣本多樣性不足導致模型易過擬合或遺忘嚴重。為了解決該問題,研究者提出了Rainbow Memory(RM),RM提出通過多樣性記憶管理和數據增強解決 Blurry-CIL 問題。
2.1多樣性感知內存更新
研究者認為,被選擇存儲在內存中的樣本應該不僅是代表其相應的類,還要識別其他類。為了選擇這樣的樣本,研究者認為,在分類邊界附近的樣本是最具鑒別力的,靠近分布中心的樣本是最具代表性的。為了滿足這兩個特點,研究者建議抽樣的樣本是不同的特征空間。
由于計算樣本與樣本之間的距離O(N2)較為復雜和昂貴,研究者通過分類模型估計的樣本的不確定性來估計相對位置,即假設模型的更確定的樣本將位于更靠近類分布的中心,通過測量擾動樣本的模型輸出方差來計算樣本的不確定性,擾動樣本通過各種數據增強轉換方法進行:包括顏色抖動、剪切和剪切,如下圖所示:
通過蒙特-卡羅(MC)法近似計算分布p(y = c)的不確定度|x),當給定擾動樣本x的先驗時,即p(x| x)的情況下,推導過程可以寫成:
其中,x、x^~、y和A分別表示樣本、擾動樣本、樣本的標簽和擾動方法的數量。分布D * 表示由擾動樣本λ x定義的數據分布。特別地,擾動樣本λ x由隨機函數fr(·)繪制,如下:
其中θr是表示第r次擾動的隨機因子的超參數。
測量樣品相對于擾動的不確定性為:
其中u(x)表示樣本x的不確定性,Sc是類別c是預測的前1類別的次數。1c表示二進制類索引向量。較低的u(x)值對應于擾動上更一致的top-1類,表明x位于模型強置信的區域.
2.2通過數據增強增強樣本多樣性(DA)
為了進一步增強記憶中的示例的多樣性,研究者采用了數據增強(DA)。 DA的通過圖像級或特征擾動使給定的樣本多樣化,這對應于通過確保多樣性來更新內存的理念。
隨著任務迭代的進行,新任務中的樣本可能會遵循與情節內存中的樣本(即,從以前的任務中)遵循不同的分布。 研究者在新任務的類別和內存中舊類的示例中采用混合標記的DA來“混合”圖像。 這種混合標簽DA減輕了由類分布在任務上的變化引起的副作用,并改善了表現。
混合標記的DA方法之一,CutMix 生成了混合樣品和平滑標簽,鑒于一組監督樣品(X1,Y1)和(X2,Y2),其公式如下:
二、使用步驟
1.實驗概況
研究者通過將RM與各種實驗設置中的藝術狀態進行比較,從經驗上驗證了RM的功效。 基準測試的CIL任務設置,情節內存的內存大小和性能指標。在MNIST、CIFAR10、CIFAR100和ImageNet數據集上進行實驗。采用多種CIL任務設置、不同的記憶大小和性能指標評估RM方法。將RM與EWC、Rwalk、iCaRL等標準CIL方法對比 ,比較不同方法在各種設置下的Last Accuracy(A5)、Last Forgetting(F5)和Intransigence(I5)等指標。分析RM在不同模糊水平(如Blurry0、Blurry10、Blurry30)下的性能,還探究了不確定性測量方法、記憶更新算法、數據增強方法等對性能的影響。
2.RM核心代碼
RM部分的完整核心代碼如下:
import logging
import randomimport numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom methods.finetune import Finetune
from utils.data_loader import cutmix_data, ImageDatasetlogger = logging.getLogger()
writer = SummaryWriter("tensorboard")def cycle(iterable):# iterate with shufflingwhile True:for i in iterable:yield iclass RM(Finetune):def __init__(self, criterion, device, train_transform, test_transform, n_classes, **kwargs):super().__init__(criterion, device, train_transform, test_transform, n_classes, **kwargs)self.batch_size = kwargs["batchsize"]self.n_worker = kwargs["n_worker"]self.exp_env = kwargs["stream_env"]if kwargs["mem_manage"] == "default":self.mem_manage = "uncertainty"def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):if len(self.memory_list) > 0:mem_dataset = ImageDataset(pd.DataFrame(self.memory_list),dataset=self.dataset,transform=self.train_transform,)memory_loader = DataLoader(mem_dataset,shuffle=True,batch_size=(batch_size // 2),num_workers=n_worker,)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# train_list == streamed_list in RMtrain_list = self.streamed_listtest_list = self.test_listrandom.shuffle(train_list)# Configuring a batch with streamed and memory data equally.train_loader, test_loader = self.get_dataloader(stream_batch_size, n_worker, train_list, test_list)logger.info(f"Streamed samples: {len(self.streamed_list)}")logger.info(f"In-memory samples: {len(self.memory_list)}")logger.info(f"Train samples: {len(train_list)+len(self.memory_list)}")logger.info(f"Test samples: {len(test_list)}")# TRAINbest_acc = 0.0eval_dict = dict()self.model = self.model.to(self.device)for epoch in range(n_epoch):# initialize for each taskif epoch <= 0: # Warm start of 1 epochfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1elif epoch == 1: # Then set to maxlrfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lrelse: # Aand go!self.scheduler.step()train_loss, train_acc = self._train(train_loader=train_loader, memory_loader=memory_loader,optimizer=self.optimizer, criterion=self.criterion)eval_dict = self.evaluation(test_loader=test_loader, criterion=self.criterion)writer.add_scalar(f"task{cur_iter}/train/loss", train_loss, epoch)writer.add_scalar(f"task{cur_iter}/train/acc", train_acc, epoch)writer.add_scalar(f"task{cur_iter}/test/loss", eval_dict["avg_loss"], epoch)writer.add_scalar(f"task{cur_iter}/test/acc", eval_dict["avg_acc"], epoch)writer.add_scalar(f"task{cur_iter}/train/lr", self.optimizer.param_groups[0]["lr"], epoch)logger.info(f"Task {cur_iter} | Epoch {epoch+1}/{n_epoch} | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | "f"test_loss {eval_dict['avg_loss']:.4f} | test_acc {eval_dict['avg_acc']:.4f} | "f"lr {self.optimizer.param_groups[0]['lr']:.4f}")best_acc = max(best_acc, eval_dict["avg_acc"])return best_acc, eval_dictdef update_model(self, x, y, criterion, optimizer):optimizer.zero_grad()do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)_, preds = logit.topk(self.topk, 1, True, True)loss.backward()optimizer.step()return loss.item(), torch.sum(preds == y.unsqueeze(1)).item(), y.size(0)def _train(self, train_loader, memory_loader, optimizer, criterion):total_loss, correct, num_data = 0.0, 0.0, 0.0self.model.train()if memory_loader is not None and train_loader is not None:data_iterator = zip(train_loader, cycle(memory_loader))elif memory_loader is not None:data_iterator = memory_loaderelif train_loader is not None:data_iterator = train_loaderelse:raise NotImplementedError("None of dataloder is valid")for data in data_iterator:if len(data) == 2:stream_data, mem_data = datax = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])else:x = data["image"]y = data["label"]x = x.to(self.device)y = y.to(self.device)l, c, d = self.update_model(x, y, criterion, optimizer)total_loss += lcorrect += cnum_data += dif train_loader is not None:n_batches = len(train_loader)else:n_batches = len(memory_loader)return total_loss / n_batches, correct / num_datadef allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size
1.內存管理與數據混合(對應論文 Section 4.1)
將內存中的舊任務樣本(memory_loader)與當前任務的流數據(train_loader)按比例混合(默認各占50%)。
使用cycle(memory_loader)循環讀取內存數據,避免內存樣本因容量限制被忽略。
實現多樣性記憶回放,通過混合新舊任務樣本緩解災難性遺忘,確保模型同時學習新任務和鞏固舊任務知識。
def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):# 加載內存數據(舊任務樣本)和流數據(新任務樣本)if len(self.memory_list) > 0:mem_dataset = ImageDataset(self.memory_list, transform=self.train_transform)memory_loader = DataLoader(mem_dataset, batch_size=(batch_size // 2), ...)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# 混合流數據和內存數據data_iterator = zip(train_loader, cycle(memory_loader)) # 循環迭代內存數據x = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])
數據增強:CutMix
以50%概率應用CutMix,將兩張圖像局部區域混合,并生成對應的混合標簽(labels_a和labels_b)。
計算混合損失(lam * loss_a + (1-lam) * loss_b),鼓勵模型學習更魯棒的特征,實現標簽混合增強(Section 4.2),通過生成邊界復雜的樣本提升記憶庫多樣性,增強模型泛化能力。
def update_model(self, x, y, criterion, optimizer):# CutMix增強:混合圖像和標簽do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)
動態學習率與批量調整
# Warm start學習率調整
if epoch <= 0:for param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1 # 初始低學習率
elif epoch == 1:param_group["lr"] = self.lr # 恢復基準學習率
else:self.scheduler.step() # 后續按計劃調整# 動態調整新舊任務批量大小
def allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size
初始階段使用低學習率(10%基準值)進行預熱(Warm-up),避免訓練初期不穩定。
根據新舊類別比例動態分配批量大小,平衡新舊任務的學習強度,防止新任務數據主導學習過程。
4. 訓練流程與評估
# 訓練與評估循環
for epoch in range(n_epoch):train_loss, train_acc = self._train(...) # 訓練eval_dict = self.evaluation(...) # 評估logger.info(f"Task {cur_iter} | Epoch {epoch+1} | train_acc {train_acc:.4f} | test_acc {eval_dict['avg_acc']:.4f}")
3.實驗結果
研究者將提出的RM與各種數據集的“ Blurry10-Online”設置中的其他方法進行了比較,并總結了如下表的結果,如表所示,RM始終優于所有其他方法,并且當類(| C |)增加時,增益會更大。但是,在MNIST上,沒有DA的RM表現最好。 研究者認為,DA會干擾模型培訓,因為示例足以避免忘記。
下表列出了三個情節記憶大小(K)的CIFAR10-Blurry10Online的比較; 200、500和1,000。結果表明,這些方法在最終任務中保留了有效的示例,足以恢復以前任務中發生的遺忘。 ICARL,GDUMB和BIC對于不固定(i5)的有效性較小,并且與EWC和RWALK相比,它們在忘記方面的表現較大,作為權衡。
研究者進一步比較了任務流的準確性軌跡; 由隨機分配的函數ψ(c)生成的三個流,具有不同的隨機種子,用于Imagenet和單個流,用Imagenet,并總結了下圖中的結果:
RM在整個任務流中始終優于其他基線。
總結
研究結論:研究者提出一種名為彩虹記憶(RM)的方法,用于處理任務共享類別(模糊 - CIL)的現實持續學習場景。通過基于樣本分類不確定性的新的多樣性增強采樣方法和多種數據增強技術,在CIFAR10、CIFAR100和ImageNet的模糊 - CIL場景中,RM大幅優于現有方法,在不連續和離線CIL設置中也有可比性能。
研究的創新性:一是提出基于樣本擾動不確定性的多樣性增強采樣方法管理有限容量記憶;二是采用多種數據增強技術提高樣本多樣性,增強記憶中樣本的代表性和判別性。
研究展望:可研究基于不確定性的記憶更新和數據增強在訓練時的關系,及其對不同CIL任務的影響。還可探索RM在更多類型數據集或其他領域持續學習場景中的應用效果。