論文閱讀筆記:Dataset Condensation with Gradient Matching
- 1. 解決了什么問題?(Motivation)
- 2. 關鍵方法與創新點 (Key Method & Innovation)
- 2.1 核心思路的演進:從參數匹配到梯度匹配
- 2.2 算法實現細節 (Implementation Details)
- 3. 實驗結果與貢獻 (Experiments & Contributions)
- 4.個人思考與啟發
- 主要代碼
- 算法邏輯總結
ICLR2021 github
核心思想一句話總結:
本文提出了一種創新的數據集壓縮方法——數據集凝縮(Dataset Condensation,DC),其核心思想是通過梯度匹配(Gradient Matching),將一個大型數據集
T
濃縮成一個極小的、信息量豐富的合成數據集S
。在S
上從頭訓練的模型,其性能可以逼近在T
上訓練的模型,從而極大地節省了存儲和訓練成本。
1. 解決了什么問題?(Motivation)
- 問題:現代深度學習依賴于大規模數據集,導致存儲成本、數據傳輸寬帶和模型訓練時間急劇增加。
- 目標:創建一個微型合成數據集
S
,它能作為原始大型數據集T
的高效替代品,用于從零開始訓練神經網絡
2. 關鍵方法與創新點 (Key Method & Innovation)
2.1 核心思路的演進:從參數匹配到梯度匹配
- 參數匹配 (Parameter Matching) - 一個被否定的思路
- 想法:直接讓
S
訓練收斂后的模型參數θS\theta_SθS?與用T
訓練收斂后的θT\theta_TθT?盡可能接近。 - 缺陷:
- 優化路徑復雜:深度網絡的參數空間非凸,直接走向目標θT\theta_TθT?極易陷入局部最優。
- 計算成本高:需要嵌套的雙層優化,內循環必須將模型訓練至收斂,計算上不可行。
- 想法:直接讓
- 梯度匹配 (Gradient Matching) - 本文的核心創新
- 想法:放棄匹配靜態的”終點“,轉而匹配動態的”過程“。即,確保在訓練每一步,模型在合成數據
S
上產生的梯度?Ls?L_s?Ls?在真實數據T
上產生的梯度?LT?L_T?LT?方向一致。 - 優勢:
- 計算高效:通過一個巧妙的近似,極大提高了效率和可擴展性。
- 優化路徑清晰:每一步都有明確的監督信號(梯度差異),引導
S
的優化,避免了在復雜空間中盲目搜索。 - 對齊學習動態:保證了模型在
S
上的學習方式與T
上一致,結果更魯棒。
- 想法:放棄匹配靜態的”終點“,轉而匹配動態的”過程“。即,確保在訓練每一步,模型在合成數據
2.2 算法實現細節 (Implementation Details)
- 課程學習 (Curriculum Learning)
- 為了讓合成數據
S
具有泛化性,算法采用了一個”課程學習“的框架。在整個凝縮過程中,會周期性地重新隨機初始化網絡參數θ\thetaθ。 - 這確保了
S
不會過擬合到某一個特定的網絡初始化,而是對多種隨機起點都有效。
- 為了讓合成數據
- 梯度匹配損失函數(Gradient Matching Loss)
- 使用**余弦距離(1-Cosine Similarity)**來衡量兩個梯度的差異。這更關注梯度的方向而非大小,與梯度下降的本質契合。
- 按輸出節點分組計算:并非所有層的梯度粗暴地展平,而是按輸出神經元分組計算余弦距離,更好地保留了網絡結構信息。
- 重要的工程技巧(Practical Tricks)
- BatchNorm層預熱與凍結:由于合成數據批次極小,為了避免BN層統計量不穩定,每次迭代前都先用一個較大的真實數據批次來計算并”凍結“BN層的均值和方差。
- 按類別獨立匹配:在計算梯度時,按類別獨立進行,即用”貓“的合成數據區匹配”貓“的真實數據梯度。這降低了學習難度和內存消耗。
3. 實驗結果與貢獻 (Experiments & Contributions)
- 性能優越:在CIFAR-10, CIFAR-100, SVHN等數據集上,僅用極少量合成樣本(如IPC=1或10),就能訓練出性能遠超當時其他數據壓縮方法的模型。
- 開創性貢獻:
- 首次提出了梯度匹配這一高效且可擴展的數據集凝縮范式,為后續大量的研究(如DSA, MTT, FTD等)奠定了基礎。
- 成功將數據集凝縮技術應用到了大型網絡上,證明了其可行性。
- 展示了其在持續學習和神經架構搜索 (NAS) 等資源受限場景下的巨大潛力。
4.個人思考與啟發
- ”過程“比”結果”更重要:這篇論文最精妙的哲學在于,它揭示了在復雜優化問題中,對齊“過程”(梯度)比直接追求“結果”(參數)更有效、更可行。這一思想在很多其他領域也具有啟發性。
- 理論與實踐的結合:論文不僅提出了一個優雅的理論框架,還通過BN層處理等工程技巧解決了實際應用中的痛點。
主要代碼
''' training '''# 為合成圖像image_syn創建一個優化器# 我們只優化image_syn這個張量,所有優化器只傳入它。# 這里的優化器是SGD,意味著我們會用梯度下降法來更新圖像的像素值。optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data# 清空優化器的梯度緩存optimizer_img.zero_grad()# 定義用于計算分類損失的損失函數,這里是標準的交叉熵損失。criterion = nn.CrossEntropyLoss().to(args.device)print('%s training begins'%get_time())# 主迭代循環開始# 這個循環是整個數據集凝縮過程的核心,總共進行Iteration+1次。for it in range(args.Iteration+1):# 評估合成數據(在特定迭代點觸發)''' Evaluate synthetic data '''if it in eval_it_pool:for model_eval in model_eval_pool:# 遍歷model_eval_pool中的每一個模型架構,用于評估。# 這運行我們測試合成數據集在不同模型上的泛化能力。print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))# 設置評估時的數據增強策略if args.dsa:# 如果是DSA方法,使用其特定的增強策略。args.epoch_eval_train = 1000args.dc_aug_param = Noneprint('DSA augmentation strategy: \n', args.dsa_strategy)print('DSA augmentation parameters: \n', args.dsa_param.__dict__)else:# 如果是DC方法,調用 get_daparam 獲取專為DC設計的增強參數。# 注意:這些增強只在評估時使用,在生成合成數據時不用。args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.print('DC augmentation parameters: \n', args.dc_aug_param)# 如果在評估時使用了任何數據增強,就需要更多的訓練輪數來讓模型充分學習。if args.dsa or args.dc_aug_param['strategy'] != 'none':args.epoch_eval_train = 1000 # Training with data augmentation needs more epochs.else:args.epoch_eval_train = 300# --- 3.2 執行評估 ---# 創建一個空列表,用于存儲多次評估的準確率accs = []# 為了結果的穩定性,我們會用當前的合成數據訓練num_eval個獨立,隨機初始化的模型。for it_eval in range(args.num_eval):# 每一次都創建一個全新的、隨機初始化的評估網絡。net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model# 深拷貝當前的合成數據和標簽,以防止在評估函數中被意外修改。# detach()是為了確保我們只復制數據,不帶計算圖。image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification# 調用核心評估函數 evaluate_synset。# 這個函數會:# 1. 拿 image_syn_eval 從頭開始訓練 net_eval。# 2. 在訓練結束后,用訓練好的 net_eval 在真實的測試集 testloader 上進行測試。# 3. 返回在測試集上的準確率 acc_test。_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)accs.append(acc_test)# 打印這次評估的平均準確率和標準差。print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))# 如果這是最后一次迭代,將這次評估的所有準確率結果記錄到總的實驗結果字典中。if it == args.Iteration: # record the final resultsaccs_all_exps[model_eval] += accs# 可視化并保存合成圖像''' visualize and save '''save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))# 深拷貝合成圖像,并移到CPU上進行處理。image_syn_vis = copy.deepcopy(image_syn.detach().cpu())# 對圖像進行反歸一化,以便人眼觀察# 訓練時圖像通常是歸一化的。# 反歸一化公式:pixel = pixel * std + meanfor ch in range(channel):image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]# 將像素值裁剪到[0,1]范圍內,防止因浮點數誤差導致顯示異常。image_syn_vis[image_syn_vis<0] = 0.0image_syn_vis[image_syn_vis>1] = 1.0# 使用torchvision.utils.save_image將合成圖像保存為一張網格圖。# nrow=args.ipc表示每行顯示ipc張圖像。save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.# --- 初始化課程學習環境 ---''' Train synthetic data '''# 每次主迭代(it)開始,都創建一個全新的、隨機初始的網絡。# 這是”課程學習“的關鍵:確保合成數據對不同的網絡初始化方法都有效,而不是過擬合到某一個。net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random modelnet.train() # 將網絡設置為訓練模式# 獲取網絡的所有可學習參數net_parameters = list(net.parameters())# 為這個新網絡創建一個優化器,用于在內循環中更新網絡參數optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer_img for synthetic dataoptimizer_net.zero_grad()# 初始化平均損失,用于記錄和打印loss_avg = 0# 在生成合成數據時,不使用任何數據增強,以與DC論文的設置保持一致args.dc_aug_param = None # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in oder to be consistent with DC paper.# --- 課程學習外循環(Outer Loop) ---# 這個循環對應論文算法中的外循環,用于實現課程學習。for ol in range(args.outer_loop):# -- BatchNorm層預熱與凍結(一個非常重要的工程技巧) --''' freeze the running mu and sigma for BatchNorm layers '''# Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.# So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.# This would make the training with BatchNorm layers easier.# 動機:合成數據的批次非常小(例如ipc=1),如果讓BN層在這么小的批次上計算均值和方差,結果會極其不穩定,導致訓練困難。# 解決方案:先用一個包含多個真實樣本的”大“批次來預熱BN層,計算出穩定的統計量,然后將其凍結。BN_flag = FalseBNSizePC = 16 # for batch normalization 每個類別用于BN預熱的樣本數# 檢查網絡中是否存在BN層for module in net.modules():if 'BatchNorm' in module._get_name(): #BatchNormBN_flag = Trueif BN_flag:# 從每個類別中抽取BNSizePC個真實圖像,拼接成一個大批次。img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)# 確保網絡在訓練模式,以便BN層可以更新其 running_mean 和 running_var。net.train() # for updating the mu, sigma of BatchNorm# 進行一次前向傳播,這個操作會自動更新BN層的統計量。output_real = net(img_real) # get running mu, sigma# 將所有BN層切換到評估模式。# 在評估模式下,BN層會使用已經計算好的 running_mean 和 running_var,而不會再根據新的輸入來更新它們。# 這就實現了“凍結”的效果。for module in net.modules():if 'BatchNorm' in module._get_name(): #BatchNormmodule.eval() # fix mu and sigma of every BatchNorm layer# --- 核心:通過梯度匹配更新合成數據 ---''' update synthetic data '''# 初始化當前外循環的總損失loss = torch.tensor(0.0).to(args.device)# 按照類別獨立進行梯度匹配,這個是論文提出的另外一個技巧。for c in range(num_classes):# 準備真實數據和合成數據img_real = get_images(c, args.batch_real)lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * cimg_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c# 如果使用DSA方法,對真實和合成圖像應用相同的可微數據增強if args.dsa:seed = int(time.time() * 1000) % 100000img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)# --- 計算真實梯度 gw_real ---output_real = net(img_real)loss_real = criterion(output_real, lab_real)# 計算損失對網絡參數的梯度gw_real = torch.autograd.grad(loss_real, net_parameters)# clone()和detach()是為了將梯度值復制下來,并切斷其與計算圖的聯系,# 因為我們只需要它的數值作為匹配目標,不希望梯度回流真實數據。gw_real = list((_.detach().clone() for _ in gw_real))# -- 計算合成梯度gw_syn --output_syn = net(img_syn)loss_syn = criterion(output_syn, lab_syn)# 關鍵所在:create_graph=True# 這個參數告訴pytorch,在計算gw_syn時,要保留其計算圖。# 這意味著gw_syn本身也成為了一個計算圖中的節點,它依賴于iamge_syn.# 因此,后續對gw_syn的損失進行反向傳播時,梯度可以一直流回image_syn。gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)# 計算真實梯度和合成梯度之間的匹配損失,余弦相似度loss += match_loss(gw_syn, gw_real, args)# 更新合成圖像optimizer_img.zero_grad() # 清空image_syn的梯度緩存loss.backward() # 反向傳播,計算匹配損失對image_syn對image_syn的梯度optimizer_img.step() # 根據梯度更新image_syn的像素值loss_avg += loss.item() # 累加損失用于打印# 如果是最后一個外循環,就不需要再更新網絡了,直接跳出。if ol == args.outer_loop - 1:break# --- 2.3 內循環:用更新后的合成數據訓練網絡 ---''' update network '''# 第二步:現在輪到網絡來適應更新后的合成數據了。image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modificationdst_syn_train = TensorDataset(image_syn_train, label_syn_train)trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)# 對網絡進行inner_loop次的訓練更新。for il in range(args.inner_loop):epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)# 記錄和保存# 計算并打印平均損失loss_avg /= (num_classes*args.outer_loop)if it%10 == 0:print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))# 如果是最后一次主迭代,保存所有結果if it == args.Iteration: # only record the final resultsdata_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))
算法邏輯總結
“你追我趕”的雙重優化過程:
- 課程學習 (Outer Loop):
- 每一次外循環,都像是新學期開學,我們找來一個“新生”(一個隨機初始化的
net
)。 - 這個“新生”的存在,是為了確保我們的“教材”(合成數據
image_syn
)是普適的,對任何基礎的學生都有效。
- 每一次外循環,都像是新學期開學,我們找來一個“新生”(一個隨機初始化的
- 教材編寫 (Update Synthetic Data):
- 這是核心步驟。我們讓“新生”
net
分別看“官方教材”(真實數據img_real
)和我們正在編寫的“濃縮筆記”(合成數據img_syn
)。 - 我們記錄下“新生”看完兩種材料后的“學習心得”(梯度
gw_real
和gw_syn
)。 - 我們的目標是修改“濃縮筆記”
img_syn
,使得“新生”看完它之后產生的“學習心得”gw_syn
和看完“官方教材”產生的gw_real
一模一樣。 create_graph=True
是實現這一點的技術關鍵,它允許我們對“學習心得”本身求導,從而知道該如何修改“濃y縮筆記”的每一個字(像素)。
- 這是核心步驟。我們讓“新生”
- 學生自習 (Update Network):
- “濃縮筆記”
image_syn
更新完畢后,我們讓“新生”net
對著這本新版的筆記自習幾遍(inner_loop
次)。 - 這會讓“新生”對當前的“濃縮筆記”有更深的理解,為下一輪的“教材編寫”做好準備。
- “濃縮筆記”
這個“編寫教材 -> 學生自習 -> 換個新生再來一遍”的過程不斷重復,最終使得“濃縮筆記” image_syn
變得越來越精華,能夠高效地替代“官方教材” T
。