1 介紹
年份:2017
期刊: arXiv preprint
Nguyen C V, Li Y, Bui T D, et al. Variational continual learning[J]. arXiv preprint arXiv:1710.10628, 2017.
本文提出的算法是變分連續學習(Variational Continual Learning, VCL),它是一種基于變分推斷的在線學習方法,結合了在線變分推斷(VI)和蒙特卡洛VI的最新進展,用于訓練深度判別模型和生成模型,以實現在連續學習設置中避免災難性遺忘并適應新任務的能力。關鍵步驟包括使用變分推斷來近似后驗分布,并通過核心集(coreset)數據摘要方法增強模型的記憶能力。本文算法屬于基于變分推斷的算法,它通過在線更新模型參數的后驗分布來實現連續學習,這可以歸類為基于正則化的算法,因為它利用KL散度最小化來正則化模型參數,以平衡對新數據的適應性和對舊數據的保留。
2 創新點
- 變分連續學習框架(VCL):
- 提出了一種新的連續學習框架,即變分連續學習(VCL),它結合了在線變分推斷(VI)和蒙特卡洛VI,適用于復雜的連續學習環境。
- 深度模型的連續學習:
- 將VCL框架應用于深度判別模型和深度生成模型,展示了該框架在這些復雜神經網絡模型中的有效性。
- 核心集(coreset)數據摘要:
- 引入了核心集的概念,這是一種小型的代表性數據集,用于保留先前任務的關鍵信息,幫助算法在新任務學習中避免遺忘舊任務。
- 自動和無參數的連續學習:
- VCL框架避免了傳統方法中需要手動調整的超參數,實現了完全自動化的學習過程,且無需額外的驗證集來調整參數。
- 實驗結果的優越性:
- 在多個任務上的實驗結果顯示,VCL在避免災難性遺忘方面優于現有的連續學習方法,且不需要調整任何超參數。
- 理論基礎和擴展性:
- 基于貝葉斯推斷的理論基礎,VCL提供了一種原則性強、可擴展的解決方案,可以應用于多種不同的模型和學習場景。
- 適用于復雜任務演化:
- VCL能夠處理任務隨時間演變以及全新任務出現的情況,這對于現實世界中任務不斷變化的場景具有重要意義。
3 算法
3.1 算法原理
- 貝葉斯推斷框架:
- 貝葉斯推斷提供了一個自然框架來處理連續學習問題。它通過保留模型參數的分布來表示參數的不確定性,這有助于在新數據到來時更新知識,同時保留舊知識。
- 在線變分推斷(Online VI):
- 在線VI是一種近似貝葉斯推斷的方法,它通過迭代更新近似后驗分布來處理新數據。VCL利用在線VI來遞歸地更新模型參數的后驗分布。
- 變分連續學習(VCL):
- VCL通過最小化KL散度(Kullback-Leibler divergence)來找到最佳近似后驗分布。具體來說,對于每一步新數據的到來,VCL通過結合之前的后驗分布和新數據的似然函數,然后通過變分推斷找到新的近似后驗分布。
- 核心集(Coreset):
- 為了緩解連續學習中累積的近似誤差,VCL引入了核心集的概念。核心集是從先前任務中提取的代表性數據點集合,用于在訓練過程中刷新模型對舊任務的記憶。
- 遞歸更新:
- VCL遞歸地更新模型參數的近似后驗分布。給定前一步的后驗分布和新數據,VCL通過乘以似然函數并重新歸一化來獲得新的后驗分布。
- 預測和參數更新:
- 在測試時,VCL使用最終的變分分布來進行預測。在訓練時,VCL通過最大化變分下界(variational lower bound)來更新變分參數,這涉及到計算期望對數似然和KL散度。
- 蒙特卡洛方法:
- 為了處理期望對數似然的計算,VCL采用蒙特卡洛方法來近似這些期望值,這通常涉及到使用重參數化技巧(reparameterization trick)來計算梯度。
3.2 算法步驟
- 初始化:選擇一個先驗分布 p ( θ ) p(\theta) p(θ)并初始化變分近似 q 0 ( θ ) = p ( θ ) q_0(\theta) = p(\theta) q0?(θ)=p(θ)。
- 核心集初始化:初始化核心集 C 0 = ? C_0 = \emptyset C0?=?。
- 對于每一個新任務 t = 1 , 2 , … , T t = 1, 2, \ldots, T t=1,2,…,T執行以下步驟:a. 觀察新數據集 D t D_t Dt?。b. 更新核心集 C t C_t Ct?,使用 C t ? 1 C_{t-1} Ct?1?和 D t D_t Dt?來選擇新的代表性數據點。c. 更新非核心集數據點的變分分布:
q ~ t ( θ ) = arg ? min ? q ∈ Q K L ( q ( θ ) ∥ q ~ t ? 1 ( θ ) p ( D t ∪ C t ? 1 ? C t ∣ θ ) Z ) \tilde{q}_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_{t-1}(\theta) p(D_t \cup C_{t-1} \setminus C_t | \theta)}{Z} \right) q~?t?(θ)=argq∈Qmin?KL(q(θ)∥Zq~?t?1?(θ)p(Dt?∪Ct?1??Ct?∣θ)?)
其中, Z Z Z是歸一化常數。
d. 計算最終的變分分布(僅用于預測):
q t ( θ ) = arg ? min ? q ∈ Q K L ( q ( θ ) ∥ q ~ t ( θ ) p ( C t ∣ θ ) Z ) q_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_t(\theta) p(C_t | \theta)}{Z} \right) qt?(θ)=argq∈Qmin?KL(q(θ)∥Zq~?t?(θ)p(Ct?∣θ)?)
e. 進行預測:在測試輸入 x ? x^* x?上,使用 q t ( θ ) q_t(\theta) qt?(θ)來計算預測分布:
p ( y ? ∣ x ? , D 1 : t ) = ∫ q t ( θ ) p ( y ? ∣ θ , x ? ) d θ p(y^* | x^*, D_{1:t}) = \int q_t(\theta) p(y^* | \theta, x^*) d\theta p(y?∣x?,D1:t?)=∫qt?(θ)p(y?∣θ,x?)dθ
4 實驗分析
圖1展示了論文中測試的多頭網絡架構,包括判別模型(a)和生成模型(b),其中判別模型中低層網絡參數θS在多個任務中共享,每個任務t有自己的“頭部網絡”θtH,映射到共同隱藏層的輸出;生成模型中頭部網絡生成來自潛在變量z的中間層表示。
圖6展示了在訓練后各個任務生成器生成的圖像,其中每列代表特定任務生成器的輸出,每行顯示所有訓練任務生成器的結果,明顯地,簡單直接的在線學習方法遭受了災難性遺忘,而其他方法(如VCL)成功地記住了之前的任務。實驗結論是,與簡單在線學習相比,VCL等方法在連續學習環境中能更好地保留對先前任務的記憶,避免了災難性遺忘,展現出更好的長期記憶性能。
5 思考
(1)代碼舉例理解本文算法
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.nn.functional import softmax# 假設我們有一個簡單的神經網絡模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 變分連續學習算法的實現
def variational_continual_learning(model, prior_mu, prior_sigma, tasks_num, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)for t in range(tasks_num):# 加載當前任務的數據datasets, labels = data_loader(t)# 遍歷當前任務的數據進行訓練for data, label in zip(datasets, labels):# 前向傳播output = model(data)log_likelihood = softmax(output, dim=1).gather(1, label.unsqueeze(1)).squeeze(1).log()# 計算損失函數,包括負對數似然和KL散度loss = -log_likelihood + kl_divergence(model.fc2.weight, model.fc2.bias, prior_mu, prior_sigma)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()return modeldef kl_divergence(weights, biases, prior_mu, prior_sigma):# 計算權重和偏置的KL散度posterior_mu = weightsposterior_sigma = torch.nn.functional.softplus(biases) + 1e-6 # 防止sigma為0# KL散度計算公式kl_w = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 + (posterior_mu - prior_mu)**2 / posterior_sigma**2 - 1)kl_b = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 - 1)return kl_w.sum() + kl_b.sum()# 假設我們有一個數據加載器,用于加載連續的任務
def data_loader(task_id):# 這里只是一個示例,實際中需要根據task_id加載不同的數據# 返回當前任務的數據和標簽pass# 初始化模型
input_size = 784 # 例如MNIST數據集
hidden_size = 100
output_size = 10 # 假設有10個類別
model = SimpleNN(input_size, hidden_size, output_size)# 設置先驗分布的均值和標準差
prior_mu = torch.zeros(output_size)
prior_sigma = torch.ones(output_size)# 執行變分連續學習算法
tasks_num = 5 # 假設有5個連續的任務
trained_model = variational_continual_learning(model, prior_mu, prior_sigma, tasks_num)