數據蒸餾:Dataset Distillation by Matching Training Trajectories 論文翻譯和理解

一、TL;DR

  1. 數據集蒸餾的任務是合成一個較小的數據集,使得在該合成數據集上訓練的模型能夠達到在完整數據集上訓練的模型相同的測試準確率,號稱優于coreset的選擇方法
  2. 本文中,對于給定的網絡,我們在蒸餾數據上對其進行幾次迭代訓練,預先計算并存儲在真實數據集上訓練的專家網絡的訓練軌跡,并根據合成訓練參數與在真實數據上訓練的參數之間的距離來優化蒸餾數據。
  3. 有一個問題哈,這種蒸餾方法強依賴GT,如果新增數據優化模型,沒有GT可能還是只能使用coreset的方法來做

二、方法介紹

數據蒸餾的目標是從大型訓練數據集中提取知識,將其濃縮到一個非常小的合成訓練圖像集合中(每個類別低至一張圖像),以便在蒸餾數據上訓練模型能夠獲得與在原始數據集上訓練相似的測試性能,如下圖所示:

與經典的數據壓縮不同,數據集蒸餾旨在保留足夠的任務相關的信息,以便在小的合成數據集上訓練的模型能夠泛化到未見過的測試數據,如圖2所示。因此,蒸餾算法必須在大量壓縮信息的同時,保留區分性特征

之前的方法的問題:

  1. 大多數先前的方法都集中在小型數據集(如MNIST和CIFAR)上,而在真實、更高分辨率的圖像上卻難以取得進展
  2. 一些方法考慮了端到端的訓練,但往往需要巨大的計算和內存資源,并且存在近似松弛或訓練不穩定性的問題
  3. 另外一些方法專注于短期行為,強制在蒸餾數據上進行單次訓練步驟以匹配真實數據上的訓練步驟,在評估中可能會累積誤差。

在本工作中:

  1. 提出了一種新的數據集蒸餾方法,不僅在性能上超越了以前的工作,而且適用于大規模數據集,如圖1所示。
  2. 本文方法試圖直接模仿在真實數據集上訓練的網絡的長期訓練動態;
  3. 我們將合成數據上訓練的參數軌跡段與從在真實數據上訓練的模型記錄的專家軌跡段進行匹配,從而避免了短視(即,專注于單個步驟)或難以優化(即,建模完整軌跡)的問題
  4. 將真實數據集視為引導網絡訓練動態的黃金標準,我們可以認為誘導的網絡參數序列是一個專家軌跡。如果我們的蒸餾數據集能夠誘導網絡的訓練動態遵循這些專家軌跡那么合成訓練的網絡將在參數空間中接近于在真實數據上訓練的模型,并實現類似的測試性能

在我們的方法中,我們的損失函數直接鼓勵蒸餾數據集引導網絡優化沿著類似的軌跡進行(圖3)。

訓練流程:

  1. 從頭開始在真實數據集上訓練一組模型,并記錄它們的專家訓練軌跡。
  2. 從隨機選擇的專家軌跡中隨機選擇一個時間步來初始化一個新模型,并在合成數據集上進行幾次迭代訓練。
  3. 我們根據這個合成訓練的網絡與專家軌跡的偏離程度來懲罰蒸餾數據,并通過訓練迭代進行反向傳播。本質上,我們將許多專家訓練軌跡的知識轉移到了蒸餾圖像上。

實驗結果:

  1. 輕松超越了現有的數據集蒸餾方法以及核心集選擇方法,在標準數據集上表現優異,包括CIFAR-10、CIFAR-100和Tiny ImageNet。
  2. CIFAR-10上,我們使用每個類別一張圖像時達到了46.3%,每個類別50張圖像時達到了71.5%的準確率
  3. 首次能夠從ImageNet中蒸餾出高128×128分辨率的圖像

三、近期工作(直接翻譯)

3.1 數據集蒸餾

數據集蒸餾最早由Wang等人[44]提出,他們提出將模型權重表示為蒸餾圖像的函數,并使用基于梯度的超參數優化方法對其進行優化[23],這種方法在元學習研究中也得到了廣泛應用[8, 27]。隨后,通過學習軟標簽[2, 38]、通過梯度匹配放大學習信號[47]、采用數據增強[45]以及針對無限寬度核極限進行優化[25, 26],一些工作顯著提高了結果。數據集蒸餾已經實現了多種應用,包括持續學習[44, 45, 47]、高效的神經架構搜索[45, 47]、聯邦學習[11, 37, 50]以及針對圖像、文本和醫學影像數據的隱私保護機器學習[22, 37]。正如引言中提到的,我們的方法不依賴于單步行為匹配[45, 47]、成本高昂的完整優化軌跡展開[38, 44]或大規模神經切線核計算[25, 26]。相反,我們的方法通過從預訓練的專家中轉移知識來實現長期軌跡匹配。

與我們的工作同時進行的,Zhao和Bilen[46]的方法完全忽略了優化步驟,而是專注于合成數據和真實數據之間的分布匹配。盡管這種方法由于降低了內存需求而適用于更高分辨率的數據集(例如Tiny ImageNet),但在大多數情況下,其性能表現不如以往的工作(例如,與之前的作品[45, 47]相比)。相比之下,我們的方法在標準基準測試和更高分辨率數據集上同時降低了內存成本,同時超越了現有作品[45, 47]和同時進行的方法[46]。

還有一條相關的研究路線是學習一個生成模型來合成訓練數據[24, 36]。然而,這些方法并沒有生成一個小尺寸的數據集,因此不能直接與數據集蒸餾方法進行比較。

3.2 模仿學習

模仿學習試圖通過觀察一系列專家演示來學習一個良好的策略[29, 30, 31]。行為克隆訓練學習策略以與專家演示相同的方式行動。一些更復雜的形式涉及使用專家的標記進行在線策略學習[33],而其他方法則完全避免任何標記,例如通過分布匹配[16]。這些方法(特別是行為克隆)已被證明在離線環境中效果良好[9, 12]。我們的方法可以被視為模仿通過在真實數據集上訓練獲得的一系列專家網絡訓練軌跡。因此,它可以被視為在優化軌跡上進行模仿學習。

3.3 核心集和實例選擇

與數據集蒸餾類似,核心集[1, 4, 13, 34, 41]和實例選擇[28]旨在選擇整個訓練數據集的一個子集,其中在這個小子集上進行訓練能夠獲得良好的性能。這些方法中的大多數并不適用于現代深度學習,但基于雙層優化的新公式在持續學習等應用中已經顯示出有希望的結果[3]。與核心集相關,其他研究路線旨在了解哪些訓練樣本對現代機器學習是“有價值的”,包括測量單個樣本的準確性[20]和計算誤分類率[39]。事實上,數據集蒸餾是這類想法的推廣,因為蒸餾數據不需要是真實的,也不需要來自訓練集。

四、方法詳細介紹

數據集蒸餾指的是策劃一個小的、合成的訓練集 Dsyn?,使得在該合成數據上訓練的模型在真實測試集上的表現與在大型真實訓練集 Dreal? 上訓練的模型相似。本文方法直接模仿真實數據訓練的長期行為,將蒸餾數據上的多個訓練步驟與真實數據上的更多步驟進行匹配。

3.1 專家軌跡

如何獲取在真實數據集上訓練的網絡的專家軌跡?

方法的核心:

  1. 利用專家軌跡 τ? 來指導我們合成數據集的蒸餾。專家軌跡是指在完整的真實數據集上訓練神經網絡時獲得的參數時間序列 {θt??}0T?。

如何生成這些專家軌跡?

  1. 我們簡單地在真實數據集上訓練大量網絡,每個模型不同epoch組成一條expert trajectory。作者稱這些參數序列為“expert trajectory”,因為它們代表了數據集蒸餾任務的理論上限:在完整的真實數據集上訓練的網絡的性能。
  2. 同樣,我們定義學生參數 θ^t? 為在訓練步驟 t 時在合成圖像上訓練的網絡參數。我們的目標是蒸餾一個數據集,使其誘導出與真實訓練集誘導的軌跡(給定相同的起始點)相似的軌跡,從而使我們最終得到一個類似的模型

由于這些專家軌跡僅使用真實數據計算,因此我們可以在蒸餾之前預先計算它們。對于給定數據集的所有實驗,我們都使用相同的預先計算的專家軌跡集合,這使得蒸餾和實驗能夠快速進行。

3.2 長期參數匹配

本文方法通過鼓勵蒸餾數據集誘導與真實數據集相似的長期網絡參數軌跡,從而使得在合成數據上訓練的網絡表現類似于在真實數據上訓練的網絡。

我們的蒸餾過程從構成我們expert trajectories中的參數序列 {θt??}0T? 中學習。與以往工作不同,我們的方法直接鼓勵我們合成數據集誘導的長期訓練動態與在真實數據上訓練的網絡的動態相匹配。

在每個蒸餾步驟中,我們首先從我們的專家軌跡之一中隨機時間步采樣參數 θt??,并用這些參數初始化我們的學生參數 θ^t?:=θt??。在初始化我們的學生網絡后,我們接著對合成數據的分類損失進行 N 次梯度下降更新,更新學生參數:

其中A是可微分增強操作,α是個可學習的學習率。然后計算更新后的學生參數和expert trajectory的模型參數的匹配損失,根據權重匹配損失更新我們的蒸餾圖像,即更新后學生參數 θ^t+N? 與已知未來的專家參數 θt+M?? 之間的歸一化平方 L2 誤差:

通過將反向傳播通過學生網絡的所有 N 次更新來最小化這個目標,更新我們蒸餾數據集的像素,以及我們的可訓練學習率 α。可訓練學習率 α 的優化起到了自動調整學生和專家更新次數(超參數 M 和 N)的作用。我們使用帶有動量的隨機梯度下降(SGD)來優化 Dsyn? 和 α,以達到上述目標。整體如下所示:

3.3 內存限制

本文如何減少內存消耗?

原式是這樣進行梯度更新的,由于Dataset太大,因此可以將一式轉化為三式

我們可以為學生網絡的每次更新(即算法 1 第 10 行的內循環)采樣一個新的小批量 b,這樣在計算最終權重匹配損失(方程 2)時,所有的蒸餾圖像都將被看到。小批量 b 仍然包含來自不同類別的圖像,但每個類別的圖像數量要少得多。在這種情況下,我們的學生網絡更新變為

這種分批方法允許我們在確保同一類別蒸餾圖像之間存在一定程度的異質性的同時,蒸餾出一個更大的合成數據集。

五、實驗

對于 CIFAR-10,這些蒸餾圖像可以在圖 4 中看到。CIFAR-100 的圖像在補充材料中進行了可視化。

如表 1 所示,我們的方法在每種設置中都顯著優于所有基線。事實上,在每個類別一張圖像的設置中,我們在兩個數據集上都將次優方法(DSA [45])的測試準確率幾乎提高了一倍。

在表 2 中,我們還與最近的方法 KIP [25, 26] 進行了比較:

正如之前的方法 [44] 所指出的,我們還發現在合成數據集中允許更多圖像時,收益會顯著減少

  1. 例如,在 CIFAR-10 上,當我們將每個類別的圖像數量從 1 增加到 10 時,分類準確率從 46.3% 提高到 65.3%,
  2. 但當我們將每個類別的蒸餾圖像數量從 10 增加到 50 時,僅從 65.3% 提高到 71.5%。

如果我們查看圖 4(頂部)中每個類別一張圖像的可視化,我們會看到每個類別的非常抽象但仍然可以識別的表示。當我們只允許每個類別有一張合成圖像時,優化被迫將盡可能多的類別區分信息壓縮到一個樣本中。當我們允許更多圖像來分散類別的信息時,優化有自由度將類別的區分特征分散到多個樣本中,從而產生我們在圖 4(底部)中看到的多樣化的一組結構化圖像(例如,不同類型的汽車和馬,具有不同的姿勢)。

跨架構泛化。我們還在 CIFAR-10、每個類別一張圖像的任務上評估了我們的合成數據在與用于蒸餾它的架構不同的架構上的表現。在表 3 中,我們展示了我們的基線 ConvNet 性能,并在 ResNet 、VGG 和 AlexNet 上進行了評估。

表明我們的方法對架構的變化具有魯棒性

4.2 短期匹配與長期匹配

非常短期的匹配(N=1 且 M 較小)通常比長期匹配表現更差,當 N 和 M 都相對較大時,達到最佳性能

對于這兩種方法,我們測試它們使用蒸餾數據從相同的初始參數訓練網絡到目標參數的接近程度。DSA 僅針對短期行為進行優化,因此在更長時間的訓練過程中可能會累積誤差。實際上,隨著 Δt 變得更大,DSA 在更長距離上無法模仿真實數據的行為。相比之下,我們的方法針對長期匹配進行了優化,因此表現更好。

六、總結

在這項工作中,我們介紹了一種數據集蒸餾算法,通過直接優化合成數據來誘導與真實數據相似的網絡訓練動態。我們的方法與以往方法的主要區別在于,我們既不受限于短期單步匹配,也不受優化整個訓練過程的不穩定性以及計算強度的影響。我們的方法在這兩個方面都取得了平衡,并且在這兩方面都顯示出改進。與以往方法不同,我們的方法首次擴展到128×128的ImageNet圖像

局限性。我們使用預先計算的軌跡,雖然節省了大量內存,但以增加磁盤存儲和專家模型訓練的計算成本為代價。訓練和存儲專家軌跡的計算開銷相當高。例如,CIFAR專家每個epoch大約需要3秒(所有200個CIFAR專家總共需要8個GPU小時),而每個ImageNet(子集)專家每個epoch大約需要11秒(所有100個ImageNet專家總共需要15個GPU小時)。在存儲方面,每個CIFAR專家大約占用60MB的存儲空間,而每個ImageNet專家大約占用120MB。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/74650.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/74650.shtml
英文地址,請注明出處:http://en.pswp.cn/web/74650.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【spring cloud Netflix】Ribbon組件

1.基本概念 SpringCloud Ribbon是基于Netflix Ribbon 實現的一套客戶端負載均衡的工具。簡單的說,Ribbon 是 Netflix 發布的開源項目,主要功能是提供客戶端的軟件負載均衡算法,將 Netflix 的中間層服務連接在一 起。Ribbon 的客戶端組件提供…

P1036 [NOIP 2002 普及組] 選數(DFS)

題目描述 已知 n 個整數 x1?,x2?,?,xn?&#xff0c;以及 1 個整數 k&#xff08;k<n&#xff09;。從 n 個整數中任選 k 個整數相加&#xff0c;可分別得到一系列的和。例如當 n4&#xff0c;k3&#xff0c;4 個整數分別為 3,7,12,19 時&#xff0c;可得全部的組合與它…

在響應式網頁的開發中使用固定布局、流式布局、彈性布局哪種更好

一、首先看下固定布局與流體布局的區別 &#xff08;一&#xff09;固定布局 固定布局的網頁有一個固定寬度的容器&#xff0c;內部組件寬度可以是固定像素值或百分比。其容器元素不會移動&#xff0c;無論訪客屏幕分辨率如何&#xff0c;看到的網頁寬度都相同。現代網頁設計…

二分查找與二叉樹中序遍歷——面試算法

目錄 二分查找與分治 循環方式 遞歸方式 元素中有重復的二分查找 基于二分查找的拓展問題 山脈數組的頂峰索引——局部有序 旋轉數字中的最小數字 找缺失數字 優化平方根 中序與搜索樹 二叉搜索樹中搜索特定值 驗證二叉搜索樹 有序數組轉化為二叉搜索樹 尋找兩個…

字符串——面試考察高頻算法題

目錄 轉換成小寫字母 字符串轉化為整數 反轉相關的問題 反轉字符串 k個一組反轉 僅僅反轉字母 反轉字符串里的單詞 驗證回文串 判斷是否互為字符重排 最長公共前綴 字符串壓縮問題 轉換成小寫字母 給你一個字符串 s &#xff0c;將該字符串中的大寫字母轉換成相同的…

現代復古電影海報品牌徽標設計襯線英文字體安裝包 Thick – Retro Vintage Cinematic Font

Thick 是一種大膽的復古字體&#xff0c;專為有影響力的標題和懷舊的視覺效果而設計。其厚實的字體、復古魅力和電影風格使其成為電影海報、產品標簽、活動品牌和編輯設計的理想選擇。無論您是在引導電影的黃金時代&#xff0c;還是在現代布局中注入復古活力&#xff0c;Thick …

[C++面試] new、delete相關面試點

一、入門 1、說說new與malloc的基本用途 int* p1 (int*)malloc(sizeof(int)); // C風格 int* p2 new int(10); // C風格&#xff0c;初始化為10 new 是 C 中的運算符&#xff0c;用于在堆上動態分配內存并調用對象的構造函數&#xff0c;會自動計算所需內存…

Unity URP管線與HDRP管線對比

1. 渲染架構與底層技術 URP 渲染路徑&#xff1a; 前向渲染&#xff08;Forward&#xff09;&#xff1a;默認單Pass前向&#xff0c;支持少量實時光源&#xff08;通常4-8個逐物體&#xff09;。 延遲渲染&#xff08;Deferred&#xff09;&#xff1a;可選但功能簡化&#…

JDK8卸載與安裝教程(超詳細)

JDK8卸載與安裝教程&#xff08;超詳細&#xff09; 最近學習一個項目&#xff0c;需要使用更高級的JDK&#xff0c;這里記錄一下卸載舊版本與安裝新版本JDK的過程。 JDK8卸載 以windows10操作系統為例&#xff0c;使用快捷鍵winR輸入cmd&#xff0c;打開控制臺窗口&#xf…

python爬蟲:DrissionPage實戰教程

如果本文章看不懂可以看看上一篇文章&#xff0c;加強自己的基礎&#xff1a;爬蟲自動化工具&#xff1a;DrissionPage-CSDN博客 案例解析&#xff1a; 前提&#xff1a;我們以ChromiumPage為主&#xff0c;寫代碼工具使用Pycharm&#xff08;python環境3.9-3.10&#xff09; …

07-01-自考數據結構(20331)- 排序-內部排序知識點

內部排序算法是數據結構核心內容,主要包括插入類(直接插入、希爾)、交換類(冒泡、快速)、選擇類(簡單選擇、堆)、歸并和基數五大類排序方法。 知識拓撲 知識點介紹 直接插入排序 定義:將每個待排序元素插入到已排序序列的適當位置 算法步驟: 從第二個元素開始遍歷…

Go語言-初學者日記(八):構建、部署與 Docker 化

&#x1f9f1; 一、go build&#xff1a;最基礎的構建方式 Go 的構建工具鏈是出了名的輕量、簡潔&#xff0c;直接用 go build 就能把項目編譯成二進制文件。 ? 構建當前項目 go build -o myapp-o myapp 指定輸出文件名默認會構建當前目錄下的 main.go 或 package main &a…

教程:如何使用 JSON 合并腳本

目錄 1. 介紹 2. 使用方法 3. 注意事項 4. 示例 5.完整代碼 1. 介紹 該腳本用于將多個 COCO 格式的 JSON 標注文件合并為一個 JSON 文件。COCO 格式常用于目標檢測和圖像分割任務&#xff0c;包含以下三個主要部分&#xff1a; "images"&#xff1a;圖像信息&a…

Java學習總結-緩沖流性能分析

測試用例&#xff1a; 分別使用原始的字節流&#xff0c;以及字節緩沖流復制一個很大的視頻。 測試步驟&#xff1a; 在這個分析性能需要一個記錄時間的工具&#xff1a;這個是記錄1970-1-1 00&#xff1a;00&#xff1a;00到現在的總毫秒值。 long start System.currentT…

流影---開源網絡流量分析平臺(五)(成果展示)

目錄 前沿 攻擊過程 前沿 前四章我們已經成功安裝了流影的各個功能&#xff0c;那么接下來我們就看看這個開源工具的實力&#xff0c;本實驗將進行多個攻擊手段&#xff08;ip掃描&#xff0c;端口掃描&#xff0c;sql注入&#xff09;攻擊靶機&#xff0c;來看看流影的態感效…

vs環境中編譯osg以及osgQt

1、下載 OpenSceneGraph 獲取源代碼 您可以通過以下方式獲取 OSG 源代碼: 官網下載:https://github.com/openscenegraph/OpenSceneGraph/releases 使用 git 克隆: git clone https://github.com/openscenegraph/OpenSceneGraph.git 2、下載必要的第三方依賴庫 依賴庫 ht…

Unity:標簽(tags)

為什么需要Tags&#xff1f; 在游戲開發中&#xff0c;游戲對象&#xff08;GameObject&#xff09;數量可能非常多&#xff0c;比如玩家、敵人、子彈等。開發者需要一種簡單的方法來區分這些對象&#xff0c;并根據它們的類型執行不同的邏輯。 核心需求&#xff1a; 分類和管…

【C++11】lambda

lambda lambda表達式語法 lambda表達式本質是一個匿名函數對象&#xff0c;跟普通函數不同的是它可以定義在函數內部。lambda表達式語法使用層而言沒有類型&#xff0c;所以一般是用auto或者模板參數定義的對象去接收lambda對象。 lambda表達式的格式&#xff1a;[capture-l…

fpga:分秒計時器

任務目標 分秒計數器核心功能&#xff1a;實現從00:00到59:59的循環計數&#xff0c;通過四個七段數碼管顯示分鐘和秒。 復位功能&#xff1a;支持硬件復位&#xff0c;將計數器歸零并顯示00:00。 啟動/暫停控制&#xff1a;通過按鍵控制計時的啟動和暫停。 消抖處理&#…

《UNIX網絡編程卷1:套接字聯網API》第6章 IO復用:select和poll函數

《UNIX網絡編程卷1&#xff1a;套接字聯網API》第6章 I/O復用&#xff1a;select和poll函數 6.1 I/O復用的核心價值與適用場景 I/O復用是高并發網絡編程的基石&#xff0c;允許單個進程/線程同時監控多個文件描述符&#xff08;套接字&#xff09;的狀態變化&#xff0c;從而高…