【連續學習之VCL算法】2017年論文:Variational continual learning

1 介紹

年份:2017

期刊: arXiv preprint

Nguyen C V, Li Y, Bui T D, et al. Variational continual learning[J]. arXiv preprint arXiv:1710.10628, 2017.

本文提出的算法是變分連續學習(Variational Continual Learning, VCL),它是一種基于變分推斷的在線學習方法,結合了在線變分推斷(VI)和蒙特卡洛VI的最新進展,用于訓練深度判別模型和生成模型,以實現在連續學習設置中避免災難性遺忘并適應新任務的能力。關鍵步驟包括使用變分推斷來近似后驗分布,并通過核心集(coreset)數據摘要方法增強模型的記憶能力。本文算法屬于基于變分推斷的算法,它通過在線更新模型參數的后驗分布來實現連續學習,這可以歸類為基于正則化的算法,因為它利用KL散度最小化來正則化模型參數,以平衡對新數據的適應性和對舊數據的保留。

2 創新點

  1. 變分連續學習框架(VCL)
    • 提出了一種新的連續學習框架,即變分連續學習(VCL),它結合了在線變分推斷(VI)和蒙特卡洛VI,適用于復雜的連續學習環境。
  2. 深度模型的連續學習
    • 將VCL框架應用于深度判別模型和深度生成模型,展示了該框架在這些復雜神經網絡模型中的有效性。
  3. 核心集(coreset)數據摘要
    • 引入了核心集的概念,這是一種小型的代表性數據集,用于保留先前任務的關鍵信息,幫助算法在新任務學習中避免遺忘舊任務。
  4. 自動和無參數的連續學習
    • VCL框架避免了傳統方法中需要手動調整的超參數,實現了完全自動化的學習過程,且無需額外的驗證集來調整參數。
  5. 實驗結果的優越性
    • 在多個任務上的實驗結果顯示,VCL在避免災難性遺忘方面優于現有的連續學習方法,且不需要調整任何超參數。
  6. 理論基礎和擴展性
    • 基于貝葉斯推斷的理論基礎,VCL提供了一種原則性強、可擴展的解決方案,可以應用于多種不同的模型和學習場景。
  7. 適用于復雜任務演化
    • VCL能夠處理任務隨時間演變以及全新任務出現的情況,這對于現實世界中任務不斷變化的場景具有重要意義。

3 算法

3.1 算法原理

  1. 貝葉斯推斷框架
    • 貝葉斯推斷提供了一個自然框架來處理連續學習問題。它通過保留模型參數的分布來表示參數的不確定性,這有助于在新數據到來時更新知識,同時保留舊知識。
  2. 在線變分推斷(Online VI)
    • 在線VI是一種近似貝葉斯推斷的方法,它通過迭代更新近似后驗分布來處理新數據。VCL利用在線VI來遞歸地更新模型參數的后驗分布。
  3. 變分連續學習(VCL)
    • VCL通過最小化KL散度(Kullback-Leibler divergence)來找到最佳近似后驗分布。具體來說,對于每一步新數據的到來,VCL通過結合之前的后驗分布和新數據的似然函數,然后通過變分推斷找到新的近似后驗分布。
  4. 核心集(Coreset)
    • 為了緩解連續學習中累積的近似誤差,VCL引入了核心集的概念。核心集是從先前任務中提取的代表性數據點集合,用于在訓練過程中刷新模型對舊任務的記憶。
  5. 遞歸更新
    • VCL遞歸地更新模型參數的近似后驗分布。給定前一步的后驗分布和新數據,VCL通過乘以似然函數并重新歸一化來獲得新的后驗分布。
  6. 預測和參數更新
    • 在測試時,VCL使用最終的變分分布來進行預測。在訓練時,VCL通過最大化變分下界(variational lower bound)來更新變分參數,這涉及到計算期望對數似然和KL散度。
  7. 蒙特卡洛方法
    • 為了處理期望對數似然的計算,VCL采用蒙特卡洛方法來近似這些期望值,這通常涉及到使用重參數化技巧(reparameterization trick)來計算梯度。

3.2 算法步驟

  1. 初始化:選擇一個先驗分布 p ( θ ) p(\theta) p(θ)并初始化變分近似 q 0 ( θ ) = p ( θ ) q_0(\theta) = p(\theta) q0?(θ)=p(θ)
  2. 核心集初始化:初始化核心集 C 0 = ? C_0 = \emptyset C0?=?
  3. 對于每一個新任務 t = 1 , 2 , … , T t = 1, 2, \ldots, T t=1,2,,T執行以下步驟:a. 觀察新數據集 D t D_t Dt?。b. 更新核心集 C t C_t Ct?,使用 C t ? 1 C_{t-1} Ct?1? D t D_t Dt?來選擇新的代表性數據點。c. 更新非核心集數據點的變分分布:

q ~ t ( θ ) = arg ? min ? q ∈ Q K L ( q ( θ ) ∥ q ~ t ? 1 ( θ ) p ( D t ∪ C t ? 1 ? C t ∣ θ ) Z ) \tilde{q}_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_{t-1}(\theta) p(D_t \cup C_{t-1} \setminus C_t | \theta)}{Z} \right) q~?t?(θ)=argqQmin?KL(q(θ)Zq~?t?1?(θ)p(Dt?Ct?1??Ct?θ)?)

其中, Z Z Z是歸一化常數。

d. 計算最終的變分分布(僅用于預測):

q t ( θ ) = arg ? min ? q ∈ Q K L ( q ( θ ) ∥ q ~ t ( θ ) p ( C t ∣ θ ) Z ) q_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_t(\theta) p(C_t | \theta)}{Z} \right) qt?(θ)=argqQmin?KL(q(θ)Zq~?t?(θ)p(Ct?θ)?)

e. 進行預測:在測試輸入 x ? x^* x?上,使用 q t ( θ ) q_t(\theta) qt?(θ)來計算預測分布:

p ( y ? ∣ x ? , D 1 : t ) = ∫ q t ( θ ) p ( y ? ∣ θ , x ? ) d θ p(y^* | x^*, D_{1:t}) = \int q_t(\theta) p(y^* | \theta, x^*) d\theta p(y?x?,D1:t?)=qt?(θ)p(y?θ,x?)dθ

4 實驗分析

圖1展示了論文中測試的多頭網絡架構,包括判別模型(a)和生成模型(b),其中判別模型中低層網絡參數θS在多個任務中共享,每個任務t有自己的“頭部網絡”θtH,映射到共同隱藏層的輸出;生成模型中頭部網絡生成來自潛在變量z的中間層表示。

圖6展示了在訓練后各個任務生成器生成的圖像,其中每列代表特定任務生成器的輸出,每行顯示所有訓練任務生成器的結果,明顯地,簡單直接的在線學習方法遭受了災難性遺忘,而其他方法(如VCL)成功地記住了之前的任務。實驗結論是,與簡單在線學習相比,VCL等方法在連續學習環境中能更好地保留對先前任務的記憶,避免了災難性遺忘,展現出更好的長期記憶性能。

5 思考

(1)代碼舉例理解本文算法

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.nn.functional import softmax# 假設我們有一個簡單的神經網絡模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 變分連續學習算法的實現
def variational_continual_learning(model, prior_mu, prior_sigma, tasks_num, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)for t in range(tasks_num):# 加載當前任務的數據datasets, labels = data_loader(t)# 遍歷當前任務的數據進行訓練for data, label in zip(datasets, labels):# 前向傳播output = model(data)log_likelihood = softmax(output, dim=1).gather(1, label.unsqueeze(1)).squeeze(1).log()# 計算損失函數,包括負對數似然和KL散度loss = -log_likelihood + kl_divergence(model.fc2.weight, model.fc2.bias, prior_mu, prior_sigma)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()return modeldef kl_divergence(weights, biases, prior_mu, prior_sigma):# 計算權重和偏置的KL散度posterior_mu = weightsposterior_sigma = torch.nn.functional.softplus(biases) + 1e-6  # 防止sigma為0# KL散度計算公式kl_w = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 + (posterior_mu - prior_mu)**2 / posterior_sigma**2 - 1)kl_b = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 - 1)return kl_w.sum() + kl_b.sum()# 假設我們有一個數據加載器,用于加載連續的任務
def data_loader(task_id):# 這里只是一個示例,實際中需要根據task_id加載不同的數據# 返回當前任務的數據和標簽pass# 初始化模型
input_size = 784  # 例如MNIST數據集
hidden_size = 100
output_size = 10  # 假設有10個類別
model = SimpleNN(input_size, hidden_size, output_size)# 設置先驗分布的均值和標準差
prior_mu = torch.zeros(output_size)
prior_sigma = torch.ones(output_size)# 執行變分連續學習算法
tasks_num = 5  # 假設有5個連續的任務
trained_model = variational_continual_learning(model, prior_mu, prior_sigma, tasks_num)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/64204.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/64204.shtml
英文地址,請注明出處:http://en.pswp.cn/web/64204.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

多視圖 (Multi-view) 與多模態 (Multi-modal)

多視圖 (Multi-view) 與多模態 (Multi-modal) 是兩種不同的數據處理方式,它們在機器學習和數據分析中有著重要的應用。盡管這兩者有一些相似之處,但它們關注的角度和處理方法有所不同。 多視圖 (Multi-view) 定義:多視圖指的是同一數據對象…

MySQL 性能瓶頸,為什么 MySQL 表的數據量不能太大?

MySQL的性能瓶頸(為什么MySQL有幾萬的qps,怎么來的?性能分析 為什么 MySQL 表不能太大網上大部分人的說法:問題的關鍵: B樹層數對查詢性能的影響到底有多大? 是什么導致的 MySQL 查詢緩慢?如何解決: MySQL的性能瓶頸(為什么MySQL有幾萬的qps,怎么來的? 一個全表掃描的查詢…

Linux 實用命令 grep、wc

grep 命令詳解 grep [選項] ‘模式’ 文件名 grep [參數] [選項] [操作對象]grep ‘error’ -c 5 --color info.log [模式]:是要搜索的字符串或正則表達式。 [選項]:是可選的,用于定制grep的行為。 [操作對象]:是要搜索的文件…

【Transformer】深入淺出自注意力機制

寫在前面:博主本人也是剛接觸計算機視覺領域不久,本篇文章是為了記錄自己的學習,大家一起學習,有問題歡迎大家指出。(博主本人的習慣是看文章看到不懂的有立馬去看不懂的那塊,所以博文可能內容比較雜&#…

HarmonyOS NEXT 實戰之元服務:靜態案例效果---教育培訓服務

背景: 前幾篇學習了元服務,后面幾期就讓我們開發簡單的元服務吧,里面豐富的內容大家自己加,本期案例 僅供參考 先上本期效果圖 ,里面圖片自行替換 效果圖1完整代碼案例如下: import { authentication } …

3.阿里云flinkselectdb-py作業

1.概述 Python API中文文檔 本文介紹在阿里云實時計算flink中使用python作業,把oss中的數據同步數據到阿里云selectdb的過程。python簡單的語法特性更適合flink作業的開發; 先說結論: 在實際開發中遇到了很多問題,導致python作業基本基本無法…

互聯網視頻云平臺EasyDSS無人機推流直播技術如何助力野生動植物保護工作?

在當今社會,隨著科技的飛速發展,無人機技術已經廣泛應用于各個領域,為我們的生活帶來了諸多便利。而在動植物保護工作中,無人機的應用更是為這一領域注入了新的活力。EasyDSS,作為一款集視頻處理、分發、存儲于一體的綜…

51c視覺~YOLO~合集8

我自己的原文哦~ https://blog.51cto.com/whaosoft/12897680 1、Yolo9 1.1、YOLOv9SAM實現動態目標檢測和分割 主要介紹基于YOLOv9SAM實現動態目標檢測和分割 背景介紹 在本文中,我們使用YOLOv9SAM在RF100 Construction-Safety-2 數據集上實現自定義對象檢測模…

Docker Container 可觀測性最佳實踐

Docker Container 介紹 Docker Container( Docker 容器)是一種輕量級、可移植的、自給自足的軟件運行環境,它在 Docker 引擎的宿主機上運行。容器在許多方面類似于虛擬機,但它們更輕量,因為它們不需要模擬整個操作系統…

氣相色譜-質譜聯用分析方法中的常用部件,分流平板更換

分流平板,是氣相色譜-質譜聯用分析方法中的一個常用部件,它可以實現氣相色譜柱流與MS檢測器流的分離和分流。常見的氣質聯用儀分流平板有很多種,如單層T型分流平板、雙層T型分流平板、螺旋分流平板等等。 操作視頻http://www.spcctech.com/v…

易基因: BS+ChIP-seq揭示DNA甲基化調控非編碼RNA(VIM-AS1)抑制腫瘤侵襲性|Exp Mol Med

大家好,這里是專注表觀組學十余年,領跑多組學科研服務的易基因。 肝細胞癌(hepatocellular carcinoma,HCC)早期復發仍然是一個具有挑戰性的領域,其中涉及的機制尚未完全被理解。盡管微血管侵犯&#xff08…

鴻蒙系統文件管理基礎服務的設計背景和設計目標

有一定經驗的開發者通常對文件管理相關的api應用或者底層邏輯都比較熟悉,但是關于文件管理服務的設計背景和設計目標可能了解得不那么清楚,本文旨在分享文件管理服務的設計背景及目標,方便廣大開發者更好地理解鴻蒙系統文件管理服務。 1 鴻蒙…

如何配置 Java 環境變量:設置 JAVA_HOME 和 PATH

目錄 一、什么是 Java 環境變量? 二、配置 Java 環境變量 1. 下載并安裝 JDK 2. 配置 JAVA_HOME Windows 系統 Linux / macOS 系統 3. 配置 PATH Windows 系統 Linux / macOS 系統 4. 驗證配置 三、常見問題與解決方案 1. 無法識別 java 或 javac 命令 …

Doris 數據庫外部表-JDBC 外表,Oracle to Doris

簡介 提供了 Doris 通過數據庫訪問的標準接口 (JDBC) 來訪問外部表,外部表省去了繁瑣的數據導入工作,讓 Doris 可以具有了訪問各式數據庫的能力,并借助 Doris 本身的 OLAP 的能力來解決外部表的數據分析問題: 支持各種數據源接入…

分布式 IO 模塊助力沖壓機械臂產線實現智能控制

在當今制造業蓬勃發展的浪潮中,沖壓機械臂產線的智能化控制已然成為提升生產效率、保障產品質量以及增強企業競爭力的關鍵所在。而分布式 IO 模塊的應用,正如同為這條產線注入了一股強大的智能動力,開啟了全新的高效生產篇章。 傳統挑戰 沖壓…

CSS系列(37)-- Overscroll Behavior詳解

前端技術探索系列:CSS Overscroll Behavior詳解 📱 致讀者:探索滾動交互的藝術 👋 前端開發者們, 今天我們將深入探討 CSS Overscroll Behavior,這個強大的滾動行為控制特性。 基礎概念 🚀 …

深度學習中的并行策略概述:4 Tensor Parallelism

深度學習中的并行策略概述:4 Tensor Parallelism 使用 PyTorch 實現 Tensor Parallelism 。首先定義了一個簡單的模型 SimpleModel,它包含兩個全連接層。然后,本文使用 torch.distributed.device_mesh 初始化了一個設備網格,這代…

企業銷售人員培訓系統|Java|SSM|VUE| 前后端分離

【技術棧】 1??:架構: B/S、MVC 2??:系統環境:Windowsh/Mac 3??:開發環境:IDEA、JDK1.8、Maven、Mysql5.7 4??:技術棧:Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5??數據庫…

vue 本地自測iframe通訊

使用 postMessage API 來實現跨窗口(跨域)的消息傳遞。postMessage 允許你安全地發送消息到其他窗口,包括嵌套的 iframe,而不需要擔心同源策略的問題。 發送消息(父應用) 1. 父應用:發送消息給…

Linux:code:network:devinet_sysctl_forward;IN_DEV_FORWARD

文章目錄 簡介sysctl 設置使用,arp_process間接使用IN_DEV_RX_REDIRECTSdev_disable_lro簡介 最近在看Linux里的forwarding的功能。順便在這里總結一下。有些詳細代碼邏輯,如果可以記錄一下,會好一點。 sysctl 設置 這個函數在查看的時候需要注意的問題:變量名起的有點簡…