一個很尷尬的現狀就是我們用于訓練 AI 模型的數據快要用完了。所以我們在大量的使用合成數據!
據估計,目前公開可用的高質量訓練標記大約有 40 萬億到 90 萬億個,其中流行的 FineWeb 數據集包含 15 萬億個標記,僅限于英語。
作為參考,最近發布的 Llama 4 在文本、圖像和視頻數據集上進行了預訓練,使用的標記數量超過 30 萬億個,是 Llama 3 的兩倍多。
這讓我們意識到,我們距離訓練數據達到極限可能只有幾年的時間了。
但那真的是極限嗎?私人數據集呢?
這些數據集的規模可能是公開數據集的 10 到 20 倍(甚至更多),所有存儲的消息中大約有 650 萬億個標記,電子郵件中大約有 1200 萬億個標記。
令人驚訝的是,許多公司收集的大量數據從未被分析過,因此被稱為暗數據(Dark data)。
再想想政府機構、醫院、律師事務所、金融機構、用戶設備等存儲的數據。
我同意這些數據是敏感的,而且有嚴格的數據保護法規來規范其處理方式。
其中大部分數據可能確實不適合用于訓練機器學習模型,但肯定有一部分數據可以為人類和組織帶來巨大價值。
如果有一種方法可以在不共享數據本身的情況下,使用多個組織的敏感合規數據來訓練機器學習模型,那該多好啊!
這就是聯邦機器學習(Federated Machine Learning)的用武之地!
接下來,我們將深入探討聯邦機器學習是什么,它是如何工作的,然后編寫一個聯邦學習流程,使用多個醫療機構的數據安全地訓練一個可以檢測眼部疾病的機器學習模型。
讓我們開始吧!
但是,聯邦機器學習到底是什么?
為了理解聯邦機器學習是什么,我們先來看看傳統的機器學習模型訓練方法。
舉個例子,我們想訓練一個可以檢測 CT 掃描圖像中癌癥的機器學習模型。
第一步是收集來自不同地理位置的多家醫院的正常和癌癥患者的 CT 掃描圖像。
選擇多樣化數據源的原因是:
- 增加樣本量;
- 減少由于不同因素(包括人口統計、專家和機構因素)導致的偏差。
這使得我們的模型即使對于訓練數據集中未充分代表的群體也能具有泛化能力。
一旦這些數據被收集到一個中央的強大服務器上,我們就可以使用這些數據來訓練模型并對其進行評估。
你能發現這種方法有哪些問題,使得執行起來幾乎不可能嗎?
首先,敏感的醫療數據受到法律(如GDPR和HIPAA)的嚴格監管,這使得將這些數據傳輸到中央服務器變得非常困難。
其次,中央服務器必須有足夠的計算和存儲資源來處理這些數據和訓練,這使得這種方法非常昂貴。
如果我們反過來,不是把數據移動到訓練中,而是把訓練移動到數據那里呢?
這就是聯邦機器學習做的事情。
聯邦機器學習是一種機器學習技術,多個組織可以在去中心化的方式下協作訓練機器學習模型,而無需共享他們的數據集。
以下是使用這種方法的步驟:
- 在中央服務器上初始化一個基礎/全局模型。
- 將該模型的參數發送到參與組織的服務器(稱為客戶端/節點),這些服務器包含本地數據。
- 每個客戶端在其本地數據上訓練模型一段時間(不是直到模型收斂,而是進行幾步/一到幾個周期)。
- 在本地訓練完成后,每個客戶端將其模型參數或累積的梯度發送回中央服務器。
- 由于每個客戶端的參數因在不同的本地數據集上訓練而與其他客戶端不同,因此需要通過一個稱為聚合的過程將它們結合起來。聚合的結果用于更新基礎/全局模型的參數。
可以使用多種技術進行聚合,其中一種流行的方法是聯邦平均(Federated Averaging)。
更新后的全局模型參數 = ∑ i = 1 N 客戶端? i 的更新參數 × 客戶端? i 的數據量 ∑ i = 1 N 客戶端? i 的數據量 \text{更新后的全局模型參數} = \frac{\sum_{i=1}^{N} \text{客戶端 } i \text{ 的更新參數} \times \text{客戶端 } i \text{ 的數據量}}{\sum_{i=1}^{N} \text{客戶端 } i \text{ 的數據量}} 更新后的全局模型參數=∑i=1N?客戶端?i?的數據量∑i=1N?客戶端?i?的更新參數×客戶端?i?的數據量?
在這種方法中,不同客戶端的更新會被平均,并根據每個客戶端用于訓練的數據點數量進行加權。
- 更新后的基礎模型參數被發送回客戶端,然后重復上述訓練過程,直到獲得一個完全訓練好的模型。
你有沒有注意到聯邦機器學習帶來的優勢?
首先,數據保留在其生成的地方,從未被傳輸到一個中央位置,這使得這種方法是去中心化的。
其次,減少了對單一強大基礎設施的需求,因為計算是在所有參與服務器之間共享的。
最后,我們有一種稱為差分隱私(Differential Privacy)的技術,可以保護客戶端數據的隱私。
這是一種技術,通過它,無法從聯邦學習過程中共享的模型更新中識別出關于單個數據點的敏感信息。
為了實現差分隱私,使用了兩種過程:
- 裁剪客戶端模型更新,以限制單個數據點的影響;
- 加噪,即在裁剪后的更新中添加校準后的噪聲。
根據這些過程發生的位置,我們有:
- 中央差分隱私:中央服務器在全局參數上進行加噪,這些全局參數是通過接收客戶端的裁剪更新進行聚合的,或者是由中央服務器進行裁剪的。
- 本地差分隱私:每個客戶端在將模型更新發送到中央服務器之前,本地應用裁剪和加噪。
訓練你的第一個聯邦學習機器學習模型
現在你已經了解了聯邦機器學習的基礎知識,是時候動手實踐并編寫一些代碼了。
視網膜疾病影響著全球數億人,是導致視力喪失和失明的主要原因之一。
**光學相干斷層掃描(OCT)**可以為我們提供視網膜及其他眼層的詳細橫截面圖像。
利用這些圖像,我們的目標是訓練一個機器學習模型,能夠區分健康視網膜和受疾病影響的視網膜。
本教程中的所有代碼都是使用 PyTorch 框架在 Jupyter 筆記本中編寫的。
下載并探索數據集
我們將使用的數據集名為 OCTMNIST。
它是 MedMNIST 數據集的一個子集,包含 109,309
張大小為 28 × 28
像素的灰度、居中裁剪的視網膜 OCT 圖像。
OCTMNIST 是一個多分類數據集,包含以下類別/標簽:
- 脈絡膜新生血管(CNV)
- 糖尿病性黃斑水腫(DME)
- 玻璃膜疣(Drusen)
- 正常
我們先在 Jupyter 筆記本中安裝 medmnist
包,并獲取有關 OCTMNIST 數據集的一些信息。
!uv pip install medmnist
from medmnist import INFO# 獲取 OCTMNIST 數據集信息
info = INFO["octmnist"]print("數據集類型: ", info["task"])
print("數據集標簽: ", info["label"])
print("圖像通道數: ", info["n_channels"])
print("訓練樣本數量: ", info["n_samples"]["train"])
print("驗證樣本數量: ", info["n_samples"]["val"])
print("測試樣本數量: ", info["n_samples"]["test"])
輸出結果如下:
數據集類型: 多分類
數據集標簽: {'0': '脈絡膜新生血管','1': '糖尿病性黃斑水腫', '2': '玻璃膜疣', '3': '正常'}
圖像通道數: 1
訓練樣本數量: 97477
驗證樣本數量: 10832
測試樣本數量: 1000
接下來,我們下載這個數據集,對其應用轉換,并繪制其中的一些圖像。
!uv pip install torch torchvision
import torch# 如果可用,使用 GPU
if torch.backends.mps.is_available():device = torch.device("mps")
elif torch.cuda.is_available():device = torch.device("cuda")
else:device = torch.device("cpu")print(f"使用設備: {device}")
from torchvision import transforms
from medmnist import OCTMNIST# 定義轉換
transform = transforms.ToTensor()# 下載數據集,大小為 64 x 64
train_dataset = OCTMNIST(split='train', transform=transform, download=True, size=64)
val_dataset = OCTMNIST(split='val', transform=transform, download=True, size=64)
test_dataset = OCTMNIST(split='test', transform=transform, download=True, size=64)
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 定義標簽映射
label_map = {0: '脈絡膜新生血管',1: '糖尿病性黃斑水腫',2: '玻璃膜疣',3: '正常'
}# 從數據加載器中獲取一批數據
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
images, labels = next(iter(train_loader))# 在 3 x 3 網格中繪制
rows, cols = 3, 3
fig, axes = plt.subplots(rows, cols, figsize=(5, 5))for i in range(rows * cols):ax = axes[i // cols, i % cols]ax.imshow(images[i][0], cmap='gray')ax.set_title(label_map[int(labels[i].item())], fontsize=6)ax.axis('off')plt.tight_layout()
plt.show()
將數據集拆分為子集
現實世界中的醫療數據通常存在類別不平衡和偏差。
為了模擬這種情況,我們將 OCTMNIST 數據集拆分為三個子集。
可以將這些子集視為屬于三家不同醫院的數據集,每個子集都排除了一種眼部疾病標簽。
from torch.utils.data import Subset# 創建子數據集
def create_sub_datasets(full_dataset):targets = torch.tensor([label.item() for _, label in full_dataset])mask_A = (targets == 0) | (targets == 2) | (targets == 3)mask_B = (targets == 0) | (targets == 1) | (targets == 3)mask_C = (targets == 1) | (targets == 2) | (targets == 3)indices_A = mask_A.nonzero(as_tuple=True)[0]indices_B = mask_B.nonzero(as_tuple=True)[0]indices_C = mask_C.nonzero(as_tuple=True)[0]dataset_A = Subset(train_dataset, indices_A) # 包含:CNV、DRUSEN、NORMAL(排除 DME)dataset_B = Subset(train_dataset, indices_B) # 包含:CNV、DME、NORMAL(排除 DRUSEN)dataset_C = Subset(train_dataset, indices_C) # 包含:DME、DRUSEN、NORMAL(排除 CNV)return [dataset_A, dataset_B, dataset_C]dataset_A, dataset_B, dataset_C = create_sub_datasets(train_dataset)
接下來,我們定義一個 ResNet-18 模型,用于將圖像分類到相應的類別中。
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn# ResNet-18
def get_resnet_model(num_classes=4):model = resnet18(weights=ResNet18_Weights.DEFAULT)# 修改第一層卷積層以接受 1 通道輸入model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 替換最終的全連接層model.fc = nn.Linear(model.fc.in_features, num_classes)return model.to(device)
訓練與評估
為了模擬在每個醫院使用本地數據進行訓練,我們在之前定義的子數據集上分別訓練三個 ResNet 模型。
以下是訓練和評估的函數。
!uv pip install tqdm
import torch.optim as optim
from tqdm import tqdm# 訓練函數
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=10):for epoch in range(epochs):model.train()running_correct, running_total = 0, 0loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)for images, labels in loop:images = images.to(device)labels = labels.squeeze().long().to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()preds = torch.argmax(outputs, dim=1)running_correct += (preds == labels).sum().item()running_total += labels.size(0)loop.set_postfix(loss=loss.item(), acc=running_correct / running_total)train_acc = running_correct / running_totalval_acc = evaluate_model(model, val_loader)print(f"Epoch [{epoch+1}/{epochs}] Train Acc: {train_acc:.4f} Val Acc: {val_acc:.4f}")# 評估函數
def evaluate_model(model, test_loader):model.eval()correct, total = 0, 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)correct += (preds == labels).sum().item()total += labels.size(0)return correct / total
# 在子數據集上訓練的函數
def train_on_subset(subset_dataset, val_loader, epochs=10):loader = DataLoader(subset_dataset, batch_size=64, shuffle=True)model = get_resnet_model()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)train_model(model, criterion, optimizer, loader, val_loader, epochs)return model
# 在子數據集上評估的函數
def evaluate_on_test(model, test_loader):model.eval()all_preds = []all_labels = []with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())acc = sum([p == t for p, t in zip(all_preds, all_labels)]) / len(all_labels)return acc, all_preds, all_labels
是時候訓練這些模型了!
# 在子數據集上訓練模型
val_loader = DataLoader(val_dataset, batch_size=64)model_A = train_on_subset(dataset_A, val_loader)
model_B = train_on_subset(dataset_B, val_loader)
model_C = train_on_subset(dataset_C, val_loader)
訓練過程的輸出如下:
Epoch [1/10] Train Acc: 0.9162 Val Acc: 0.8073
Epoch [2/10] Train Acc: 0.9454 Val Acc: 0.8477
Epoch [3/10] Train Acc: 0.9526 Val Acc: 0.8588
Epoch [4/10] Train Acc: 0.9587 Val Acc: 0.8509
Epoch [5/10] Train Acc: 0.9597 Val Acc: 0.8574
Epoch [6/10] Train Acc: 0.9671 Val Acc: 0.8619
Epoch [7/10] Train Acc: 0.9700 Val Acc: 0.8629
Epoch [8/10] Train Acc: 0.9747 Val Acc: 0.8623
Epoch [9/10] Train Acc: 0.9774 Val Acc: 0.8541
Epoch [10/10] Train Acc: 0.9787 Val Acc: 0.8647
Epoch [1/10] Train Acc: 0.9466 Val Acc: 0.8498
Epoch [2/10] Train Acc: 0.9725 Val Acc: 0.8988
Epoch [3/10] Train Acc: 0.9780 Val Acc: 0.8967
Epoch [4/10] Train Acc: 0.9816 Val Acc: 0.9027
Epoch [5/10] Train Acc: 0.9841 Val Acc: 0.9031
Epoch [6/10] Train Acc: 0.9854 Val Acc: 0.8917
Epoch [7/10] Train Acc: 0.9881 Val Acc: 0.9060
Epoch [8/10] Train Acc: 0.9899 Val Acc: 0.9060
Epoch [9/10] Train Acc: 0.9911 Val Acc: 0.9053
Epoch [10/10] Train Acc: 0.9930 Val Acc: 0.9005
Epoch [1/10] Train Acc: 0.9001 Val Acc: 0.6071
Epoch [2/10] Train Acc: 0.9429 Val Acc: 0.6188
Epoch [3/10] Train Acc: 0.9531 Val Acc: 0.6117
Epoch [4/10] Train Acc: 0.9509 Val Acc: 0.6280
Epoch [5/10] Train Acc: 0.9610 Val Acc: 0.6289
Epoch [6/10] Train Acc: 0.9649 Val Acc: 0.6283
Epoch [7/10] Train Acc: 0.9675 Val Acc: 0.6265
Epoch [8/10] Train Acc: 0.9699 Val Acc: 0.6321
Epoch [9/10] Train Acc: 0.9750 Val Acc: 0.6330
Epoch [10/10] Train Acc: 0.9768 Val Acc: 0.6363
然后,我們在完整的測試集上測試這些模型的性能(這將是我們在實際應用中運行模型的情況)。
# 在完整測試集上評估
test_loader = DataLoader(test_dataset, batch_size=64)acc_A, preds_A, labels_A = evaluate_on_test(model_A, test_loader)
acc_B, preds_B, labels_B = evaluate_on_test(model_B, test_loader)
acc_C, preds_C, labels_C = evaluate_on_test(model_C, test_loader)# 報告準確率
print(f"測試準確率 | 在排除 DME 的數據集上訓練的模型: {acc_A:.4f}")
print(f"測試準確率 | 在排除 DRUSEN 的數據集上訓練的模型: {acc_B:.4f}")
print(f"測試準確率 | 在排除 CNV 的數據集上訓練的模型: {acc_C:.4f}")
輸出結果如下:
測試準確率 | 在排除 DME 的數據集上訓練的模型: 0.6420
測試準確率 | 在排除 DRUSEN 的數據集上訓練的模型: 0.7080
測試準確率 | 在排除 CNV 的數據集上訓練的模型: 0.7030
我們可以看到,這些模型在測試數據集中未見過的類別上表現不佳。
當我們繪制混淆矩陣并可視化結果時,這一點更加明顯。
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaydef plot_confusion_matrix(y_true, y_pred, title):cm = confusion_matrix(y_true, y_pred, labels=[0,1,2,3])disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["CNV", "DME", "DRUSEN", "NORMAL"])disp.plot(cmap=plt.cm.Blues)plt.title(title)plt.show()# 繪制混淆矩陣
plot_confusion_matrix(labels_A, preds_A, "混淆矩陣 - 排除 DME 的模型")
plot_confusion_matrix(labels_B, preds_B, "混淆矩陣 - 排除 DRUSEN 的模型")
plot_confusion_matrix(labels_C, preds_C, "混淆矩陣 - 排除 CNV 的模型")
在子數據集上進行聯邦學習
現在輪到聯邦學習大顯身手了。
我們使用 Flower 框架,它允許我們使用任何機器學習框架和任何編程語言進行聯邦學習、分析和評估。
我仍然使用 PyTorch 框架進行本教程,以使其對大多數人更易于理解。
我使用的是 MacBook M4 Max 來運行以下代碼,但如果你在 Google Colab 上使用 T4 GPU,這將報錯。
這是因為 Colab 只將一個 GPU 暴露給主筆記本進程。當 Flower(使用 Ray 作為其默認后端運行模擬)為每個模擬客戶端啟動額外的 Python 工作進程時,這些工作進程無法訪問 GPU,程序就會崩潰。
如果你想在 Google Colab 上運行代碼,建議使用 CPU 作為設備。不過,這會使訓練變得非常緩慢。
安裝 Flower 包
!uv pip install "flwr[simulation]"
回顧一下我們之前學到的內容,在聯邦學習過程中,客戶端和中央服務器之間會交換模型參數/權重。
當客戶端從中央服務器接收到模型參數時,它將使用這些參數/權重更新其本地模型。
訓練完成后,它將這些本地模型參數/權重發送回中央服務器。
定義客戶端函數
兩個函數可以幫助我們執行這些操作:
get_weights
:此函數用于訓練完成后獲取客戶端模型的更新權重,并將其發送回中央服務器。
它接受一個機器學習模型的引用,迭代其 state_dict
中的項,將每個項轉換為 Numpy ndarray
,并返回這些 ndarray
的列表。
# 獲取客戶端模型的更新權重
def get_weights(net):return [val.cpu().numpy() for _, val in net.state_dict().items()]
set_weights
:此函數用于在訓練開始之前,使用從中央服務器收到的新權重更新客戶端模型的權重。
它接受一個機器學習模型的引用和一個 ndarray
列表。使用這個列表,它更新模型 state_dict
中的所有項。
from collections import OrderedDict# 更新客戶端模型的權重
def set_weights(net, parameters):params_dict = zip(net.state_dict().keys(), parameters)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)
接下來,我們定義一個 FlowerClient
類,它將幫助我們在客戶端上訓練和評估模型。
from flwr.client import NumPyClient
from typing import Dict
from flwr.common import NDArrays, Scalar# Flower 客戶端
class FlowerClient(NumPyClient):def __init__(self, net, trainset, valset, testset):self.net = netself.trainset = trainsetself.valset = valsetself.testset = testset# 本地訓練def fit(self, parameters, config):set_weights(self.net, parameters)# 數據加載器train_loader = DataLoader(self.trainset, batch_size=64, shuffle=True)val_loader = DataLoader(self.valset, batch_size=64)# 損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.net.parameters(), lr=0.001)train_model(self.net,criterion,optimizer,train_loader,val_loader,epochs= 1,)return get_weights(self.net), len(self.trainset), {}# 本地評估def evaluate(self, parameters, config):set_weights(self.net, parameters)loss, acc = evaluate_model(self.net, DataLoader(self.testset, batch_size=64))return loss, len(self.testset), {"accuracy": acc}
client_fn
函數幫助我們根據需要創建此類的實例。
ClientApp
作為客戶端邏輯的入口點,當 Flower 客戶端從中央服務器接收任務時運行。
from flwr.client import Client, ClientApp
from flwr.common import Contexttrain_sets = [dataset_A, dataset_B, dataset_C]# 創建客戶端的函數
def client_fn(context: Context) -> Client:cid = int(context.node_config["partition-id"])trainset = train_sets[cid]return FlowerClient(get_resnet_model(),trainset,val_dataset,test_dataset,).to_client()client = ClientApp(client_fn)
這就是客戶端需要的所有內容。
定義服務器函數
接下來,我們定義一個 evaluate
函數,中央服務器在每輪聯邦學習之后使用它來評估全局模型。
我們還定義了一個名為 filter_by_classes
的函數,它返回一個測試集的子集,其中只包含指定類別列表中的樣本。
這有助于我們在每個客戶端可用的類別子集上測試模型。
def filter_by_classes(dataset, class_list):indices = [i for i, (_, label) in enumerate(dataset) if label.item() in class_list]return Subset(dataset, indices)# 包含:CNV、DRUSEN、NORMAL - 排除:DME
testset_no_dme = filter_by_classes(test_dataset, [0, 2, 3])# 包含:CNV、DME、NORMAL - 排除:DRUSEN
testset_no_drusen = filter_by_classes(test_dataset, [0, 1, 3])# 包含:DME、DRUSEN、NORMAL - 排除:CNV
testset_no_cnv = filter_by_classes(test_dataset, [1, 2, 3])
# 評估全局模型
def evaluate(server_round, parameters, config, num_rounds = 20):net = get_resnet_model()set_weights(net, parameters)batch_size = 64acc_tot = evaluate_model(net, DataLoader(test_dataset, batch_size=batch_size))acc_A = evaluate_model(net, DataLoader(testset_no_dme, batch_size=batch_size))acc_B = evaluate_model(net, DataLoader(testset_no_drusen, batch_size=batch_size))acc_C = evaluate_model(net, DataLoader(testset_no_cnv, batch_size=batch_size))print(f"[Round {server_round}] 全局準確率: {acc_tot:.4f}")print(f"[Round {server_round}] (CNV,DRUSEN,NORMAL) 準確率: {acc_A:.4f}")print(f"[Round {server_round}] (CNV,DME,NORMAL) 準確率: {acc_B:.4f}")print(f"[Round {server_round}] (DME,DRUSEN,NORMAL) 準確率: {acc_C:.4f}")# 在最后一輪繪制混淆矩陣if server_round == num_rounds:acc_final, preds_final, labels_final = evaluate_on_test(net, DataLoader(test_dataset, batch_size=64))plot_confusion_matrix(labels_final, preds_final, "最終全局混淆矩陣")
接下來,使用 server_fn
函數,我們設置中央服務器,它使用聯邦平均聚合策略。
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvgnet = get_resnet_model()
params = ndarrays_to_parameters(get_weights(net))# 設置全局服務器的函數
def server_fn(context: Context, num_rounds = 5):# 聯邦平均策略strategy = FedAvg(fraction_fit=1.0,fraction_evaluate=0.0,initial_parameters=params,evaluate_fn=evaluate,)config=ServerConfig(num_rounds)return ServerAppComponents(strategy=strategy,config=config,)server = ServerApp(server_fn=server_fn)
現在我們已經準備好訓練我們的機器學習模型了。
訓練與評估
為了模擬在三個客戶端上的訓練,我們使用 run_simulation
函數如下:
from flwr.simulation import run_simulation
from logging import ERROR# 為了保持日志輸出簡潔
backend_setup = {"init_args": {"logging_level": ERROR, "log_to_driver": False}}# 運行訓練模擬
run_simulation(server_app=server,client_app=client,num_supernodes=3,backend_config=backend_setup,
)
以下是經過 20 輪聯邦學習后的結果。
INFO : aggregate_fit: received 3 results and 0 failures
[Round 20] 全局準確率: 0.7710
[Round 20] (CNV,DRUSEN,NORMAL) 準確率: 0.7933
[Round 20] (CNV,DME,NORMAL) 準確率: 0.8947
[Round 20] (DME,DRUSEN,NORMAL) 準確率: 0.6960
模型在每個客戶端的標簽分布過濾后的測試數據集上表現良好(如三個測試子集的準確率所示)。
最棒的是,盡管每個客戶端在其本地數據集中缺少一個疾病標簽,全局模型仍然能夠很好地識別所有標簽(如完整測試集上的全局準確率所示)。
請注意,訓練并不完美,還需要進一步優化和調整超參數以獲得更好的結果。
數據集本身也存在類別不平衡問題,正常 OCT 圖像的樣本最多,而玻璃膜疣(Drusen)的樣本最少,這可能解釋了在最終全局混淆矩陣中對這一標簽的誤分類。
我們可以通過繪制 OCTMNIST 訓練集中的類別分布來觀察類別不平衡。
# 檢查類別不平衡import matplotlib.pyplot as plt
from collections import Counter# 統計類別出現次數
labels = [label.item() for _, label in train_dataset]
class_counts = Counter(labels)
print(class_counts)# 準備繪圖數據
class_names = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
counts = [class_counts[i] for i in range(4)]# 繪圖
plt.figure(figsize=(8, 5))
plt.bar(class_names, counts)
plt.title("OCTMNIST 訓練集中的類別分布")
plt.xlabel("類別")
plt.ylabel("樣本數量")
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()
閱讀參考
- Flower 框架文檔
- DeepLearning.ai 上的“聯邦學習入門”課程
- 具有差分隱私的 Gboard 語言模型的聯邦學習