聯邦學習的收斂性分析
在聯邦學習中,我們的目標是分析全局模型的收斂性,考慮設備異構性(不同用戶的本地訓練輪次不同)和數據異質性(用戶數據分布不均勻)。以下推導從全局模型更新開始,逐步引入假設并推導期望損失的遞減關系,最終給出收斂性結論。
1. 全局模型更新與泰勒展開
全局模型更新
在聯邦學習中,設全局模型在第 t t t 輪為 g t g_t gt?,共有 U U U 個用戶參與訓練。每個用戶 k k k 從全局模型 g t g_t gt? 開始(即 w t k , 0 = g t w_t^{k, 0} = g_t wtk,0?=gt?),進行 l k t l_k^t lkt? 輪本地梯度下降更新:
w t k , i + 1 = w t k , i ? η ? G t k , i , w_t^{k, i+1} = w_t^{k, i} - \eta \nabla \mathcal{G}_t^{k, i}, wtk,i+1?=wtk,i??η?Gtk,i?,
其中 η \eta η 是學習率, ? G t k , i \nabla \mathcal{G}_t^{k, i} ?Gtk,i? 是用戶 k k k 在第 i i i 輪本地訓練時的梯度。經過 l k t l_k^t lkt? 輪訓練后,用戶 k k k 的本地模型為:
w t k , l k t = w t k , 0 ? η ∑ i = 0 l k t ? 1 ? G t k , i = g t ? η ∑ i = 0 l k t ? 1 ? G t k , i . w_t^{k, l_k^t} = w_t^{k, 0} - \eta \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i} = g_t - \eta \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i}. wtk,lkt??=wtk,0??ηi=0∑lkt??1??Gtk,i?=gt??ηi=0∑lkt??1??Gtk,i?.
全局模型通過聚合所有用戶的本地模型得到:
g t + 1 = 1 U ∑ k = 1 U w t k , l k t = g t ? η U ∑ k = 1 U ∑ i = 0 l k t ? 1 ? G t k , i . g_{t+1} = \frac{1}{U} \sum_{k=1}^U w_t^{k, l_k^t} = g_t - \frac{\eta}{U} \sum_{k=1}^U \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i}. gt+1?=U1?k=1∑U?wtk,lkt??=gt??Uη?k=1∑U?i=0∑lkt??1??Gtk,i?.
泰勒展開
為了分析全局損失 F ( g t + 1 ) F(g_{t+1}) F(gt+1?) 的變化,我們對 F ( g t + 1 ) F(g_{t+1}) F(gt+1?) 在 g t g_t gt? 處進行二階泰勒展開:
F ( g t + 1 ) ≈ F ( g t ) + ? F ( g t ) T ( g t + 1 ? g t ) + 1 2 ( g t + 1 ? g t ) T ? 2 F ( g t ) ( g t + 1 ? g t ) . F(g_{t+1}) \approx F(g_t) + \nabla F(g_t)^T (g_{t+1} - g_t) + \frac{1}{2} (g_{t+1} - g_t)^T \nabla^2 F(g_t) (g_{t+1} - g_t). F(gt+1?)≈F(gt?)+?F(gt?)T(gt+1??gt?)+21?(gt+1??gt?)T?2F(gt?)(gt+1??gt?).
代入 g t + 1 ? g t = ? η U ∑ k = 1 U ∑ i = 0 l k t ? 1 ? G t k , i g_{t+1} - g_t = -\frac{\eta}{U} \sum_{k=1}^U \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i} gt+1??gt?=?Uη?∑k=1U?∑i=0lkt??1??Gtk,i?:
F ( g t + 1 ) ≈ F ( g t ) ? η U ? F ( g t ) T ( ∑ k = 1 U ∑ i = 0 l k t ? 1 ? G t k , i ) + η 2 2 ( 1 U ∑ k = 1 U ∑ i = 0 l k t ? 1 ? G t k , i ) T ? 2 F ( g t ) ( 1 U ∑ k = 1 U ∑ i = 0 l k t ? 1 ? G t k , i ) . F(g_{t+1}) \approx F(g_t) - \frac{\eta}{U} \nabla F(g_t)^T \left( \sum_{k=1}^U \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i} \right) + \frac{\eta^2}{2} \left( \frac{1}{U} \sum_{k=1}^U \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i} \right)^T \nabla^2 F(g_t) \left( \frac{1}{U} \sum_{k=1}^U \sum_{i=0}^{l_k^t - 1} \nabla \mathcal{G}_t^{k, i} \right). F(gt+1?)≈F(gt?)?Uη??F(gt?)T