基于梯度相似性的聯邦學習客戶端選擇算法
- Abstract 摘要
- introduction
- **背景**
- **目的**
- **結論**
- **結果**
- **討論**
- **思路**
鏈接:https://link.springer.com/article/10.1007/s10586-024-04846-0
三區
Abstract 摘要
聯邦學習(FL)是一種創新的機器學習方法,終端設備在中央服務器協調下共同訓練全局模型,解決了數據隱私和數據孤島問題,而無需將數據傳輸到中央服務器。然而,在聯邦學習中,客戶端數據的異質性顯著影響了 FL 的性能。為了解決模型精度低和收斂速度慢的問題,提出了一種基于梯度相似性的客戶端選擇算法(FedGSCS)。該算法通過比較客戶端梯度與平均梯度之間的相似性來選擇客戶端,優先選擇能夠加速模型聚合以促進模型收斂的客戶端。在 MNIST、FEMNIST 和莎士比亞數據集上的實驗表明,與 Federated Averaging(FedAvg)算法、選擇損失最高的客戶端的 Power-of-Choice 算法以及基于多樣客戶端選擇的聯邦平均(DivFL)算法相比,FedGSCS 將通信輪數減少了最多 80%,并提高了最多 16.38%的準確性。
introduction
隨著人工智能的快速發展,每天從各種移動設備和物聯網(IoT)設備中生成數以萬計的數據點[1]。利用這些數據可以幫助我們構建更大、更復雜的神經網絡模型,從而提高模型的準確性。然而,在傳統的分布式機器學習中,設備上的數據需要上傳到中央服務器進行復雜的模型訓練,涉及數百萬個參數。但在實際場景中,由于數據隱私方面的擔憂[2],各方通常不會與其他方共享私有數據,這可能導致數據孤島問題。在這種情況下,傳統的機器學習技術變得不再適用。
為了解決這個問題,Google 提出了聯邦學習(FL)[3]。在保護用戶隱私的前提下,該方法通過中央服務器協調參與者,利用各自的本地數據集進行聯合訓練,從而形成增強的全局模型。聯邦學習技術通過繞過直接將數據上傳到中央服務器的方式,有效保護了數據隱私。因此,聯邦學習在醫療行業[4]、智慧城市[5]、智能交通[6]等多個領域得到了越來越廣泛的應用。
盡管在聯邦學習方面取得了顯著進展以解決數據隱私和數據孤島問題,但仍存在若干技術挑戰。移動設備面臨設備狀態、帶寬和網絡連接等方面的限制,使得通信成本成為聯邦學習(FL)環境中的瓶頸。此外,中央服務器必須通過學習本地數據樣本來更新全局模型。FL 客戶端通常從各種來源收集數據、使用不同的工具并在不同的環境中運行,導致非獨立同分布(non-IID)數據。這種數據異質性會對 FL 系統的性能和收斂性產生嚴重影響 [7, 8]。為了緩解這些問題,已經提出了各種優化策略。一些方法側重于通過引入約束修改本地模型更新過程 [9, 10],而另一些方法則通過引入公共數據集 [11, 12] 或使用控制變量 [13] 來平衡客戶端差異。然而,這些方法往往忽視了數據異質性對全局模型性能的影響。 通常,在模型聚合過程中,所有客戶端都會被同等對待,而不考慮它們的具體數據特性。這種做法不僅增加了通信成本,還可能影響全局模型的性能和泛化能力。
為了解決 Federated Averaging(FedAvg)算法在非 IID 數據場景下的性能退化問題,Sattler 等人[14]引入了 CFL 框架。該框架表明,客戶端梯度之間的余弦相似度可以有效地指示兩個客戶端是否具有相同的數據生成分布。實驗結果表明,當客戶端數據表現出聚類結構時,CFL 框架顯著提高了分類準確率,并在困惑度方面優于傳統的 FL 方法。Palihawadana 等人[15]還提出了 FedSim 算法,該算法利用余弦相似度將 FL 模型聚合分為兩個階段:先進行局部聚合,再對具有相似梯度的客戶端進行全局聚合。這種策略減少了方差,增強了全局模型的穩定性和覆蓋率。進一步的實驗結果證實,使用余弦相似度衡量潛在客戶端的相似性可以顯著提高模型性能和穩定性。此外,Marnissi 等人[16]引入了一種基于梯度范數重要性的設備選擇策略。 **基于這一理論和實驗基礎,本文提出了一種基于梯度相似度的聯邦學習客戶端選擇方法。該方法旨在通過優先選擇其梯度與平均梯度具有更高余弦相似度的客戶端來提升聯邦學習性能。**具體而言,在每次模型聚合之前,會從每個客戶端上傳的梯度中計算出平均梯度。這一步驟加快了收斂速度,并減少了非 IID 數據對全局模型的影響。隨后,使用余弦相似度來衡量平均梯度與每個客戶端梯度之間的相似度。基于這種相似度度量,選擇能夠更有效地促進模型收斂的客戶端。通過智能選擇客戶端,該算法不僅在圖像分類和文本預測任務中提高了模型質量,還加快了模型的收斂速度。本文的主要貢獻如下:
本文提出了一種新的客戶端選擇算法,名為 FedGSCS。該算法旨在通過基于客戶端梯度與均值梯度之間的余弦相似度選擇客戶端,從而提高 FL 全局模型的準確性和收斂速度。該算法的主要目標是在聚合全局模型之前,利用余弦相似度篩選出低質量的客戶端,以提高整體性能。由于客戶端之間數據異質性存在差異,FedGSCS 算法巧妙地利用客戶端梯度與平均梯度之間的余弦相似度,戰略性地選擇客戶端,從而豐富所選客戶端數據的多樣性。在 MNIST、FEMNIST 和莎士比亞數據集上的實驗表明,與 Federated Averaging(FedAvg)、基于最高損失選擇客戶端的 Power-of-Choice 算法以及基于多樣客戶端選擇的 Federated Averaging(DivFL)相比,采用 FedGSCS 可以將 FL 訓練通信輪數最多減少 80%,并提高準確率高達 16.38%。
背景
聯邦學習(FL)通過終端設備協作訓練全局模型,解決了數據隱私和孤島問題,但客戶端數據的異構性顯著影響模型性能。傳統方法如聯邦平均算法(FedAvg)在非獨立同分布(non-IID)數據下存在收斂慢、準確率低的問題。現有客戶端選擇策略(如Power-of-Choice、DivFL)未充分利用梯度相似性優化選擇過程,導致通信效率和模型性能受限。
目的
提出一種基于梯度相似性的客戶端選擇算法(FedGSCS),通過篩選與全局梯度相似性高的客戶端參與聚合,提升模型在非IID數據下的收斂速度和準確率,同時減少通信開銷。
結論
- 性能優勢:FedGSCS在MNIST、FEMNIST和Shakespeare數據集上,較基線算法(FedAvg、Power-of-Choice、DivFL)準確率提升最高達16.38%,通信輪次減少最多80%。
- 有效性驗證:梯度相似性選擇策略有效篩選高貢獻客戶端,在非IID場景中顯著優于隨機選擇和基于損失的選擇方法。
- 魯棒性:在不同數據分布下均表現穩定,適用于圖像分類和文本預測任務。
結果
- 準確率對比
- MNIST:FedGSCS在Case 1非IID場景中準確率達90.52%,優于FedAvg(88.08%)和DivFL(89.16%)。
- FEMNIST:Case 1場景下準確率提升14.28%(從80.71%到94.99%)。
- Shakespeare:Case 1場景中準確率達40.44%,較Power-of-Choice(24.06%)提升顯著。
- 通信效率
- MNIST:ToA@0.9-Case 2場景中通信輪次減少41.03%(從82輪降至69輪)。
- Shakespeare:ToA@0.4-Case 1場景中僅需33輪,而Power-of-Choice未達標。
討論
- 梯度相似性的作用:通過余弦相似性量化客戶端與全局梯度的一致性,有效過濾低貢獻客戶端,加速模型收斂。
- 非IID數據適應性:在數據分布差異較大的場景(如β=0.5的Shakespeare數據集)中,FedGSCS仍保持較高準確率,驗證了其魯棒性。
- 局限性:未考慮設備資源異構性(如計算能力、能耗),未來需結合資源感知策略優化選擇過程。
思路
- 梯度平均計算:服務器計算所有客戶端梯度的平均值 g t = 1 K ∑ i = 1 K g i t g^t = \frac{1}{K}\sum_{i=1}^K g_i^t gt=K1?∑i=1K?git?,作為全局優化方向。
- 余弦相似性度量:通過公式 c o s k t = ? g t , g k t ? ∥ g t ∥ ∥ g k t ∥ cos_k^t = \frac{\langle g^t, g_k^t \rangle}{\|g^t\| \|g_k^t\|} coskt?=∥gt∥∥gkt?∥?gt,gkt??? 量化客戶端梯度與全局梯度的相似性。
- Top-P客戶端選擇:按相似性排序,選擇前P個客戶端參與聚合,提升模型更新質量。