利用“Flower”實現聯邦機器學習的實戰指南

一個很尷尬的現狀就是我們用于訓練 AI 模型的數據快要用完了。所以我們在大量的使用合成數據!

據估計,目前公開可用的高質量訓練標記大約有 40 萬億到 90 萬億個,其中流行的 FineWeb 數據集包含 15 萬億個標記,僅限于英語。

作為參考,最近發布的 Llama 4 在文本、圖像和視頻數據集上進行了預訓練,使用的標記數量超過 30 萬億個,是 Llama 3 的兩倍多。

這讓我們意識到,我們距離訓練數據達到極限可能只有幾年的時間了。

但那真的是極限嗎?私人數據集呢?

這些數據集的規模可能是公開數據集的 10 到 20 倍(甚至更多),所有存儲的消息中大約有 650 萬億個標記,電子郵件中大約有 1200 萬億個標記。

令人驚訝的是,許多公司收集的大量數據從未被分析過,因此被稱為暗數據(Dark data)。

再想想政府機構、醫院、律師事務所、金融機構、用戶設備等存儲的數據。

我同意這些數據是敏感的,而且有嚴格的數據保護法規來規范其處理方式。

其中大部分數據可能確實不適合用于訓練機器學習模型,但肯定有一部分數據可以為人類和組織帶來巨大價值。

如果有一種方法可以在不共享數據本身的情況下,使用多個組織的敏感合規數據來訓練機器學習模型,那該多好啊!

這就是聯邦機器學習(Federated Machine Learning)的用武之地!

接下來,我們將深入探討聯邦機器學習是什么,它是如何工作的,然后編寫一個聯邦學習流程,使用多個醫療機構的數據安全地訓練一個可以檢測眼部疾病的機器學習模型。

讓我們開始吧!

但是,聯邦機器學習到底是什么?

為了理解聯邦機器學習是什么,我們先來看看傳統的機器學習模型訓練方法。

舉個例子,我們想訓練一個可以檢測 CT 掃描圖像中癌癥的機器學習模型。

第一步是收集來自不同地理位置的多家醫院的正常和癌癥患者的 CT 掃描圖像。

None

選擇多樣化數據源的原因是:

  • 增加樣本量;
  • 減少由于不同因素(包括人口統計、專家和機構因素)導致的偏差。

這使得我們的模型即使對于訓練數據集中未充分代表的群體也能具有泛化能力。

一旦這些數據被收集到一個中央的強大服務器上,我們就可以使用這些數據來訓練模型并對其進行評估。

None

你能發現這種方法有哪些問題,使得執行起來幾乎不可能嗎?

首先,敏感的醫療數據受到法律(如GDPRHIPAA)的嚴格監管,這使得將這些數據傳輸到中央服務器變得非常困難。

其次,中央服務器必須有足夠的計算和存儲資源來處理這些數據和訓練,這使得這種方法非常昂貴。

如果我們反過來,不是把數據移動到訓練中,而是把訓練移動到數據那里呢?

這就是聯邦機器學習做的事情。

聯邦機器學習是一種機器學習技術,多個組織可以在去中心化的方式下協作訓練機器學習模型,而無需共享他們的數據集。

以下是使用這種方法的步驟:

  1. 在中央服務器上初始化一個基礎/全局模型。

None

  1. 將該模型的參數發送到參與組織的服務器(稱為客戶端/節點),這些服務器包含本地數據。

None

  1. 每個客戶端在其本地數據上訓練模型一段時間(不是直到模型收斂,而是進行幾步/一到幾個周期)。

None

  1. 在本地訓練完成后,每個客戶端將其模型參數或累積的梯度發送回中央服務器。

None

  1. 由于每個客戶端的參數因在不同的本地數據集上訓練而與其他客戶端不同,因此需要通過一個稱為聚合的過程將它們結合起來。聚合的結果用于更新基礎/全局模型的參數。

可以使用多種技術進行聚合,其中一種流行的方法是聯邦平均(Federated Averaging)。

更新后的全局模型參數 = ∑ i = 1 N 客戶端? i 的更新參數 × 客戶端? i 的數據量 ∑ i = 1 N 客戶端? i 的數據量 \text{更新后的全局模型參數} = \frac{\sum_{i=1}^{N} \text{客戶端 } i \text{ 的更新參數} \times \text{客戶端 } i \text{ 的數據量}}{\sum_{i=1}^{N} \text{客戶端 } i \text{ 的數據量}} 更新后的全局模型參數=i=1N?客戶端?i?的數據量i=1N?客戶端?i?的更新參數×客戶端?i?的數據量?

在這種方法中,不同客戶端的更新會被平均,并根據每個客戶端用于訓練的數據點數量進行加權

None

  1. 更新后的基礎模型參數被發送回客戶端,然后重復上述訓練過程,直到獲得一個完全訓練好的模型。

None

你有沒有注意到聯邦機器學習帶來的優勢?

首先,數據保留在其生成的地方,從未被傳輸到一個中央位置,這使得這種方法是去中心化的。

其次,減少了對單一強大基礎設施的需求,因為計算是在所有參與服務器之間共享的。

最后,我們有一種稱為差分隱私(Differential Privacy)的技術,可以保護客戶端數據的隱私。

這是一種技術,通過它,無法從聯邦學習過程中共享的模型更新中識別出關于單個數據點的敏感信息。

為了實現差分隱私,使用了兩種過程:

  • 裁剪客戶端模型更新,以限制單個數據點的影響;
  • 加噪,即在裁剪后的更新中添加校準后的噪聲。

根據這些過程發生的位置,我們有:

  • 中央差分隱私:中央服務器在全局參數上進行加噪,這些全局參數是通過接收客戶端的裁剪更新進行聚合的,或者是由中央服務器進行裁剪的。

None

  • 本地差分隱私:每個客戶端在將模型更新發送到中央服務器之前,本地應用裁剪和加噪。

None

訓練你的第一個聯邦學習機器學習模型

現在你已經了解了聯邦機器學習的基礎知識,是時候動手實踐并編寫一些代碼了。

視網膜疾病影響著全球數億人,是導致視力喪失和失明的主要原因之一。

**光學相干斷層掃描(OCT)**可以為我們提供視網膜及其他眼層的詳細橫截面圖像。

利用這些圖像,我們的目標是訓練一個機器學習模型,能夠區分健康視網膜和受疾病影響的視網膜。

本教程中的所有代碼都是使用 PyTorch 框架在 Jupyter 筆記本中編寫的。

下載并探索數據集

我們將使用的數據集名為 OCTMNIST。

它是 MedMNIST 數據集的一個子集,包含 109,309 張大小為 28 × 28 像素的灰度、居中裁剪的視網膜 OCT 圖像。

OCTMNIST 是一個多分類數據集,包含以下類別/標簽:

  1. 脈絡膜新生血管(CNV)
  2. 糖尿病性黃斑水腫(DME)
  3. 玻璃膜疣(Drusen)
  4. 正常

我們先在 Jupyter 筆記本中安裝 medmnist 包,并獲取有關 OCTMNIST 數據集的一些信息。

!uv pip install medmnist
from medmnist import INFO# 獲取 OCTMNIST 數據集信息
info = INFO["octmnist"]print("數據集類型: ", info["task"])
print("數據集標簽: ", info["label"])
print("圖像通道數: ", info["n_channels"])
print("訓練樣本數量: ", info["n_samples"]["train"])
print("驗證樣本數量: ", info["n_samples"]["val"])
print("測試樣本數量: ", info["n_samples"]["test"])

輸出結果如下:

數據集類型:  多分類
數據集標簽:  {'0': '脈絡膜新生血管','1': '糖尿病性黃斑水腫', '2': '玻璃膜疣', '3': '正常'}
圖像通道數:  1
訓練樣本數量:  97477
驗證樣本數量:  10832
測試樣本數量:  1000

接下來,我們下載這個數據集,對其應用轉換,并繪制其中的一些圖像。

!uv pip install torch torchvision
import torch# 如果可用,使用 GPU
if torch.backends.mps.is_available():device = torch.device("mps")
elif torch.cuda.is_available():device = torch.device("cuda")
else:device = torch.device("cpu")print(f"使用設備: {device}")
from torchvision import transforms
from medmnist import OCTMNIST# 定義轉換
transform = transforms.ToTensor()# 下載數據集,大小為 64 x 64
train_dataset = OCTMNIST(split='train', transform=transform, download=True, size=64)
val_dataset = OCTMNIST(split='val', transform=transform, download=True, size=64)
test_dataset = OCTMNIST(split='test', transform=transform, download=True, size=64)
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 定義標簽映射
label_map = {0: '脈絡膜新生血管',1: '糖尿病性黃斑水腫',2: '玻璃膜疣',3: '正常'
}# 從數據加載器中獲取一批數據
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
images, labels = next(iter(train_loader))# 在 3 x 3 網格中繪制
rows, cols = 3, 3
fig, axes = plt.subplots(rows, cols, figsize=(5, 5))for i in range(rows * cols):ax = axes[i // cols, i % cols]ax.imshow(images[i][0], cmap='gray')ax.set_title(label_map[int(labels[i].item())], fontsize=6)ax.axis('off')plt.tight_layout()
plt.show()

None

將數據集拆分為子集

現實世界中的醫療數據通常存在類別不平衡和偏差。

為了模擬這種情況,我們將 OCTMNIST 數據集拆分為三個子集。

可以將這些子集視為屬于三家不同醫院的數據集,每個子集都排除了一種眼部疾病標簽。

from torch.utils.data import Subset# 創建子數據集
def create_sub_datasets(full_dataset):targets = torch.tensor([label.item() for _, label in full_dataset])mask_A = (targets == 0) | (targets == 2) | (targets == 3)mask_B = (targets == 0) | (targets == 1) | (targets == 3)mask_C = (targets == 1) | (targets == 2) | (targets == 3)indices_A = mask_A.nonzero(as_tuple=True)[0]indices_B = mask_B.nonzero(as_tuple=True)[0]indices_C = mask_C.nonzero(as_tuple=True)[0]dataset_A = Subset(train_dataset, indices_A)  # 包含:CNV、DRUSEN、NORMAL(排除 DME)dataset_B = Subset(train_dataset, indices_B)  # 包含:CNV、DME、NORMAL(排除 DRUSEN)dataset_C = Subset(train_dataset, indices_C)  # 包含:DME、DRUSEN、NORMAL(排除 CNV)return [dataset_A, dataset_B, dataset_C]dataset_A, dataset_B, dataset_C = create_sub_datasets(train_dataset)

接下來,我們定義一個 ResNet-18 模型,用于將圖像分類到相應的類別中。

from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn# ResNet-18
def get_resnet_model(num_classes=4):model = resnet18(weights=ResNet18_Weights.DEFAULT)# 修改第一層卷積層以接受 1 通道輸入model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 替換最終的全連接層model.fc = nn.Linear(model.fc.in_features, num_classes)return model.to(device)
訓練與評估

為了模擬在每個醫院使用本地數據進行訓練,我們在之前定義的子數據集上分別訓練三個 ResNet 模型。

以下是訓練和評估的函數。

!uv pip install tqdm
import torch.optim as optim
from tqdm import tqdm# 訓練函數
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=10):for epoch in range(epochs):model.train()running_correct, running_total = 0, 0loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)for images, labels in loop:images = images.to(device)labels = labels.squeeze().long().to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()preds = torch.argmax(outputs, dim=1)running_correct += (preds == labels).sum().item()running_total += labels.size(0)loop.set_postfix(loss=loss.item(), acc=running_correct / running_total)train_acc = running_correct / running_totalval_acc = evaluate_model(model, val_loader)print(f"Epoch [{epoch+1}/{epochs}]  Train Acc: {train_acc:.4f}  Val Acc: {val_acc:.4f}")# 評估函數
def evaluate_model(model, test_loader):model.eval()correct, total = 0, 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)correct += (preds == labels).sum().item()total += labels.size(0)return correct / total
# 在子數據集上訓練的函數
def train_on_subset(subset_dataset, val_loader, epochs=10):loader = DataLoader(subset_dataset, batch_size=64, shuffle=True)model = get_resnet_model()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)train_model(model, criterion, optimizer, loader, val_loader, epochs)return model
# 在子數據集上評估的函數
def evaluate_on_test(model, test_loader):model.eval()all_preds = []all_labels = []with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())acc = sum([p == t for p, t in zip(all_preds, all_labels)]) / len(all_labels)return acc, all_preds, all_labels

是時候訓練這些模型了!

# 在子數據集上訓練模型
val_loader = DataLoader(val_dataset, batch_size=64)model_A = train_on_subset(dataset_A, val_loader)
model_B = train_on_subset(dataset_B, val_loader)
model_C = train_on_subset(dataset_C, val_loader)

訓練過程的輸出如下:

Epoch [1/10]  Train Acc: 0.9162  Val Acc: 0.8073
Epoch [2/10]  Train Acc: 0.9454  Val Acc: 0.8477
Epoch [3/10]  Train Acc: 0.9526  Val Acc: 0.8588
Epoch [4/10]  Train Acc: 0.9587  Val Acc: 0.8509
Epoch [5/10]  Train Acc: 0.9597  Val Acc: 0.8574
Epoch [6/10]  Train Acc: 0.9671  Val Acc: 0.8619
Epoch [7/10]  Train Acc: 0.9700  Val Acc: 0.8629
Epoch [8/10]  Train Acc: 0.9747  Val Acc: 0.8623
Epoch [9/10]  Train Acc: 0.9774  Val Acc: 0.8541
Epoch [10/10]  Train Acc: 0.9787  Val Acc: 0.8647
Epoch [1/10]  Train Acc: 0.9466  Val Acc: 0.8498
Epoch [2/10]  Train Acc: 0.9725  Val Acc: 0.8988
Epoch [3/10]  Train Acc: 0.9780  Val Acc: 0.8967
Epoch [4/10]  Train Acc: 0.9816  Val Acc: 0.9027
Epoch [5/10]  Train Acc: 0.9841  Val Acc: 0.9031
Epoch [6/10]  Train Acc: 0.9854  Val Acc: 0.8917
Epoch [7/10]  Train Acc: 0.9881  Val Acc: 0.9060
Epoch [8/10]  Train Acc: 0.9899  Val Acc: 0.9060
Epoch [9/10]  Train Acc: 0.9911  Val Acc: 0.9053
Epoch [10/10]  Train Acc: 0.9930  Val Acc: 0.9005
Epoch [1/10]  Train Acc: 0.9001  Val Acc: 0.6071
Epoch [2/10]  Train Acc: 0.9429  Val Acc: 0.6188
Epoch [3/10]  Train Acc: 0.9531  Val Acc: 0.6117
Epoch [4/10]  Train Acc: 0.9509  Val Acc: 0.6280
Epoch [5/10]  Train Acc: 0.9610  Val Acc: 0.6289
Epoch [6/10]  Train Acc: 0.9649  Val Acc: 0.6283
Epoch [7/10]  Train Acc: 0.9675  Val Acc: 0.6265
Epoch [8/10]  Train Acc: 0.9699  Val Acc: 0.6321
Epoch [9/10]  Train Acc: 0.9750  Val Acc: 0.6330
Epoch [10/10]  Train Acc: 0.9768  Val Acc: 0.6363

然后,我們在完整的測試集上測試這些模型的性能(這將是我們在實際應用中運行模型的情況)。

# 在完整測試集上評估
test_loader = DataLoader(test_dataset, batch_size=64)acc_A, preds_A, labels_A = evaluate_on_test(model_A, test_loader)
acc_B, preds_B, labels_B = evaluate_on_test(model_B, test_loader)
acc_C, preds_C, labels_C = evaluate_on_test(model_C, test_loader)# 報告準確率
print(f"測試準確率 | 在排除 DME 的數據集上訓練的模型: {acc_A:.4f}")
print(f"測試準確率 | 在排除 DRUSEN 的數據集上訓練的模型: {acc_B:.4f}")
print(f"測試準確率 | 在排除 CNV 的數據集上訓練的模型: {acc_C:.4f}")

輸出結果如下:

測試準確率 | 在排除 DME 的數據集上訓練的模型: 0.6420
測試準確率 | 在排除 DRUSEN 的數據集上訓練的模型: 0.7080
測試準確率 | 在排除 CNV 的數據集上訓練的模型: 0.7030

我們可以看到,這些模型在測試數據集中未見過的類別上表現不佳。

當我們繪制混淆矩陣并可視化結果時,這一點更加明顯。

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaydef plot_confusion_matrix(y_true, y_pred, title):cm = confusion_matrix(y_true, y_pred, labels=[0,1,2,3])disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["CNV", "DME", "DRUSEN", "NORMAL"])disp.plot(cmap=plt.cm.Blues)plt.title(title)plt.show()# 繪制混淆矩陣
plot_confusion_matrix(labels_A, preds_A, "混淆矩陣 - 排除 DME 的模型")
plot_confusion_matrix(labels_B, preds_B, "混淆矩陣 - 排除 DRUSEN 的模型")
plot_confusion_matrix(labels_C, preds_C, "混淆矩陣 - 排除 CNV 的模型")

None

None

None

在子數據集上進行聯邦學習

現在輪到聯邦學習大顯身手了。

我們使用 Flower 框架,它允許我們使用任何機器學習框架和任何編程語言進行聯邦學習、分析和評估。

我仍然使用 PyTorch 框架進行本教程,以使其對大多數人更易于理解。

我使用的是 MacBook M4 Max 來運行以下代碼,但如果你在 Google Colab 上使用 T4 GPU,這將報錯。

這是因為 Colab 只將一個 GPU 暴露給主筆記本進程。當 Flower(使用 Ray 作為其默認后端運行模擬)為每個模擬客戶端啟動額外的 Python 工作進程時,這些工作進程無法訪問 GPU,程序就會崩潰。

如果你想在 Google Colab 上運行代碼,建議使用 CPU 作為設備。不過,這會使訓練變得非常緩慢。

安裝 Flower 包
!uv pip install "flwr[simulation]"

回顧一下我們之前學到的內容,在聯邦學習過程中,客戶端和中央服務器之間會交換模型參數/權重。

當客戶端從中央服務器接收到模型參數時,它將使用這些參數/權重更新其本地模型。

訓練完成后,它將這些本地模型參數/權重發送回中央服務器。

定義客戶端函數

兩個函數可以幫助我們執行這些操作:

  • get_weights:此函數用于訓練完成后獲取客戶端模型的更新權重,并將其發送回中央服務器。

它接受一個機器學習模型的引用,迭代其 state_dict 中的項,將每個項轉換為 Numpy ndarray,并返回這些 ndarray 的列表。

# 獲取客戶端模型的更新權重
def get_weights(net):return [val.cpu().numpy() for _, val in net.state_dict().items()]
  • set_weights:此函數用于在訓練開始之前,使用從中央服務器收到的新權重更新客戶端模型的權重。

它接受一個機器學習模型的引用和一個 ndarray 列表。使用這個列表,它更新模型 state_dict 中的所有項。

from collections import OrderedDict# 更新客戶端模型的權重
def set_weights(net, parameters):params_dict = zip(net.state_dict().keys(), parameters)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)

接下來,我們定義一個 FlowerClient 類,它將幫助我們在客戶端上訓練和評估模型。

from flwr.client import NumPyClient
from typing import Dict
from flwr.common import NDArrays, Scalar# Flower 客戶端
class FlowerClient(NumPyClient):def __init__(self, net, trainset, valset, testset):self.net = netself.trainset = trainsetself.valset = valsetself.testset = testset# 本地訓練def fit(self, parameters, config):set_weights(self.net, parameters)# 數據加載器train_loader = DataLoader(self.trainset, batch_size=64, shuffle=True)val_loader = DataLoader(self.valset, batch_size=64)# 損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.net.parameters(), lr=0.001)train_model(self.net,criterion,optimizer,train_loader,val_loader,epochs= 1,)return get_weights(self.net), len(self.trainset), {}# 本地評估def evaluate(self, parameters, config):set_weights(self.net, parameters)loss, acc = evaluate_model(self.net, DataLoader(self.testset, batch_size=64))return loss, len(self.testset), {"accuracy": acc}

client_fn 函數幫助我們根據需要創建此類的實例。

ClientApp 作為客戶端邏輯的入口點,當 Flower 客戶端從中央服務器接收任務時運行。

from flwr.client import Client, ClientApp
from flwr.common import Contexttrain_sets = [dataset_A, dataset_B, dataset_C]# 創建客戶端的函數
def client_fn(context: Context) -> Client:cid = int(context.node_config["partition-id"])trainset = train_sets[cid]return FlowerClient(get_resnet_model(),trainset,val_dataset,test_dataset,).to_client()client = ClientApp(client_fn)

這就是客戶端需要的所有內容。

定義服務器函數

接下來,我們定義一個 evaluate 函數,中央服務器在每輪聯邦學習之后使用它來評估全局模型。

我們還定義了一個名為 filter_by_classes 的函數,它返回一個測試集的子集,其中只包含指定類別列表中的樣本。

這有助于我們在每個客戶端可用的類別子集上測試模型。

def filter_by_classes(dataset, class_list):indices = [i for i, (_, label) in enumerate(dataset) if label.item() in class_list]return Subset(dataset, indices)# 包含:CNV、DRUSEN、NORMAL - 排除:DME
testset_no_dme = filter_by_classes(test_dataset, [0, 2, 3])# 包含:CNV、DME、NORMAL - 排除:DRUSEN
testset_no_drusen = filter_by_classes(test_dataset, [0, 1, 3])# 包含:DME、DRUSEN、NORMAL - 排除:CNV
testset_no_cnv = filter_by_classes(test_dataset, [1, 2, 3])
# 評估全局模型
def evaluate(server_round, parameters, config, num_rounds = 20):net = get_resnet_model()set_weights(net, parameters)batch_size = 64acc_tot = evaluate_model(net, DataLoader(test_dataset, batch_size=batch_size))acc_A = evaluate_model(net, DataLoader(testset_no_dme, batch_size=batch_size))acc_B = evaluate_model(net, DataLoader(testset_no_drusen, batch_size=batch_size))acc_C = evaluate_model(net, DataLoader(testset_no_cnv, batch_size=batch_size))print(f"[Round {server_round}] 全局準確率: {acc_tot:.4f}")print(f"[Round {server_round}] (CNV,DRUSEN,NORMAL) 準確率: {acc_A:.4f}")print(f"[Round {server_round}] (CNV,DME,NORMAL)    準確率: {acc_B:.4f}")print(f"[Round {server_round}] (DME,DRUSEN,NORMAL) 準確率: {acc_C:.4f}")# 在最后一輪繪制混淆矩陣if server_round == num_rounds:acc_final, preds_final, labels_final = evaluate_on_test(net, DataLoader(test_dataset, batch_size=64))plot_confusion_matrix(labels_final, preds_final, "最終全局混淆矩陣")

接下來,使用 server_fn 函數,我們設置中央服務器,它使用聯邦平均聚合策略。

from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvgnet = get_resnet_model()
params = ndarrays_to_parameters(get_weights(net))# 設置全局服務器的函數
def server_fn(context: Context, num_rounds = 5):# 聯邦平均策略strategy = FedAvg(fraction_fit=1.0,fraction_evaluate=0.0,initial_parameters=params,evaluate_fn=evaluate,)config=ServerConfig(num_rounds)return ServerAppComponents(strategy=strategy,config=config,)server = ServerApp(server_fn=server_fn)

現在我們已經準備好訓練我們的機器學習模型了。

訓練與評估

為了模擬在三個客戶端上的訓練,我們使用 run_simulation 函數如下:

from flwr.simulation import run_simulation
from logging import ERROR# 為了保持日志輸出簡潔
backend_setup = {"init_args": {"logging_level": ERROR, "log_to_driver": False}}# 運行訓練模擬
run_simulation(server_app=server,client_app=client,num_supernodes=3,backend_config=backend_setup,
)

以下是經過 20 輪聯邦學習后的結果。

INFO : aggregate_fit: received 3 results and 0 failures
[Round 20] 全局準確率: 0.7710
[Round 20] (CNV,DRUSEN,NORMAL) 準確率: 0.7933
[Round 20] (CNV,DME,NORMAL)    準確率: 0.8947
[Round 20] (DME,DRUSEN,NORMAL) 準確率: 0.6960

在這里插入圖片描述

模型在每個客戶端的標簽分布過濾后的測試數據集上表現良好(如三個測試子集的準確率所示)。

最棒的是,盡管每個客戶端在其本地數據集中缺少一個疾病標簽,全局模型仍然能夠很好地識別所有標簽(如完整測試集上的全局準確率所示)。

請注意,訓練并不完美,還需要進一步優化和調整超參數以獲得更好的結果。

數據集本身也存在類別不平衡問題,正常 OCT 圖像的樣本最多,而玻璃膜疣(Drusen)的樣本最少,這可能解釋了在最終全局混淆矩陣中對這一標簽的誤分類。

我們可以通過繪制 OCTMNIST 訓練集中的類別分布來觀察類別不平衡。

# 檢查類別不平衡import matplotlib.pyplot as plt
from collections import Counter# 統計類別出現次數
labels = [label.item() for _, label in train_dataset]
class_counts = Counter(labels)
print(class_counts)# 準備繪圖數據
class_names = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
counts = [class_counts[i] for i in range(4)]# 繪圖
plt.figure(figsize=(8, 5))
plt.bar(class_names, counts)
plt.title("OCTMNIST 訓練集中的類別分布")
plt.xlabel("類別")
plt.ylabel("樣本數量")
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

None

閱讀參考

  • Flower 框架文檔
  • DeepLearning.ai 上的“聯邦學習入門”課程
  • 具有差分隱私的 Gboard 語言模型的聯邦學習

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905256.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905256.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905256.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

自動化測試與功能測試詳解

🍅 點擊文末小卡片,免費獲取軟件測試全套資料,資料在手,漲薪更快 什么是自動化測試? 自動化測試是指利用軟件測試工具自動實現全部或部分測試,它是軟件測試的一個重要組成 部分,能完成許多手工測試無…

MySQL全量,增量備份與恢復

目錄 一.MySQL數據庫備份概述 1.數據備份的重要性 2.數據庫備份類型 3.常見的備份方法 二:數據庫完全備份操作 1.物理冷備份與恢復 2.mysqldump 備份與恢復 3.MySQL增量備份與恢復 3.1MySQL增量恢復 3.2MySQL備份案例 三:定制企業備份策略思路…

Ubuntu 安裝 Nginx

Nginx 是一個高性能的 Web 服務器和反向代理服務器,同時也可以用作負載均衡器和 HTTP 緩存。 Nginx 的主要用途 用途說明Web服務器提供網頁服務,處理用戶的 HTTP 請求,返回 HTML、CSS、JS、圖片等靜態資源。反向代理服務器將用戶請求轉發到…

人工智能 機器學習期末考試題

自測試卷2 一、選擇題 1.下面哪個屬性不是NumPy中數組的屬性( )。 A.ndim B.size C.shape D.add 2.一個簡單的Series是由( )的數據組成的。 A.兩…

使用阿里云CLI調用OpenAPI

介紹使用阿里云CLI調用OpenAPI的具體操作流程,包括安裝、配置憑證、生成并調用命令等步驟。 方案概覽 使用阿里云CLI調用OpenAPI,大致分為四個步驟: 安裝阿里云CLI:根據您使用設備的操作系統,選擇并安裝相應的版本。…

K8S Svc Port-forward 訪問方式

在 Kubernetes 中,kubectl port-forward 是一種 本地與集群內資源(Pod/Service)建立臨時網絡隧道 的訪問方式,無需暴露服務到公網,適合開發調試、臨時訪問等場景。以下是詳細使用方法及注意事項: 1. 基礎用…

23、DeepSeek-V2論文筆記

DeepSeek-V2 1、背景2、KV緩存優化2.0 KV緩存(Cache)的核心原理2.1 KV緩存優化2.2 性能對比2.3 架構2.4多頭注意力 (MHA)2.5 多頭潛在注意力 (MLA)2.5.1 低秩鍵值聯合壓縮 (Low-Rank Key-Value …

MySQL OCP試題解析(2)

試題如下圖所示: 一、題目背景還原 假設存在以下MySQL用戶權限配置: -- 創建本地會計用戶CREATE USER accountinglocalhost IDENTIFIED BY acc_123;-- 創建匿名代理用戶(用戶名為空,允許任意主機)CREATE USER % IDENTI…

深度學習Y7周:YOLOv8訓練自己數據集

🍨 本文為🔗365天深度學習訓練營中的學習記錄博客🍖 原作者:K同學啊 一、配置環境 1.官網下載源碼 2.安裝需要環境 二、準備好自己的數據 目錄結構: 主目錄 data images(存放圖片) annotati…

英偉達Blackwell架構重構未來:AI算力革命背后的技術邏輯與產業變革

——從芯片暴力美學到分布式智能體網絡,解析英偉達如何定義AI基礎設施新范式 開篇:當算力成為“新石油”,英偉達的“煉油廠”如何升級? 2025年3月,英偉達GTC大會上,黃仁勛身披標志性皮衣,宣布了…

CurrentHashMap的整體系統介紹及Java內存模型(JVM)介紹

當我們提到ConurrentHashMap時,先想到的就是HashMap不是線程安全的: 在多個線程共同操作HashMap時,會出現一個數據不一致的問題。 ConcurrentHashMap是HashMap的線程安全版本。 它通過在相應的方法上加鎖,來保證多線程情況下的…

Android開發-設計規范

在Android應用開發中,遵循良好的設計規范不僅能夠提升用戶體驗,還能確保代碼的可維護性和擴展性。本文將從用戶界面(UI)、用戶體驗(UX)、性能優化以及代碼結構等多個維度探討Android開發中的設計規范&#…

泛型加持的策略模式:打造高擴展的通用策略工具類

一、傳統策略模式的痛點與突破 1.1 傳統策略實現回顧 // 傳統支付策略接口 public interface PaymentStrategy {void pay(BigDecimal amount); }// 具體策略實現 public class AlipayStrategy implements PaymentStrategy {public void pay(BigDecimal amount) { /* 支付寶支…

物聯網從HomeAssistant開始

文章目錄 一、什么是home-assistant?1.核心架構2.集成架構 二、在樹梅派5上安裝home-assistant三、接入米家1.對比下趨勢2.手動安裝插件3.配置方式 四、接入公牛1.手動安裝插件2.配置方式 五、接入海爾1.手動安裝插件2.配置方式 六、接入國家電網 一、什么是home-assistant? …

系統架構-嵌入式系統架構

原理與特征 嵌入式系統的典型架構可概括為兩種模式,即層次化模式架構和遞歸模式架構 層次化模式架構,位于高層的抽象概念與低層的更加具體的概念之間存在著依賴關系,封閉型層次架構指的是,高層的對象只能調用同一層或下一層對象…

計算機圖形學編程(使用OpenGL和C++)(第2版)學習筆記 09.天空和背景

天空和背景 對于 3D 場景,通常可以通過在遠處的地平線附近創造一些逼真的效果,來增強其真實感。我們可以采用天空盒、天空柱(Skydome)或天空穹(Skydome)等技術來模擬天空。 天空盒 天空盒(Sk…

【Leetcode 每日一題】1550. 存在連續三個奇數的數組

問題背景 給你一個整數數組 a r r arr arr,請你判斷數組中是否存在連續三個元素都是奇數的情況:如果存在,請返回 t r u e true true;否則,返回 f a l s e false false。 數據約束 1 ≤ a r r . l e n g t h ≤ 10…

面試題解析 | C++空類的默認成員函數(附生成條件與底層原理)

在C面試中,“空類默認生成哪些成員函數”是考察對象模型和編譯器行為的高頻題目。許多資料僅提及前4個函數,但完整的答案應包含6個核心函數,并結合C標準深入解析其生成規則與使用場景。 一、空類默認生成的6大成員函數 1. ?缺省構造函數? …

視頻編解碼學習7之視頻編碼簡介

視頻編碼技術發展歷程與主流編碼標準詳解 視頻編碼技術是現代數字媒體領域的核心技術之一,它通過高效的壓縮算法大幅減少了視頻數據的體積,使得視頻的存儲、傳輸和播放變得更加高效和經濟。從早期的H.261標準到最新的AV1和H.266/VVC,視頻編碼…

使用Stable Diffusion(SD)中,步數(Steps)指的是什么?該如何使用?

Ⅰ定義: 在Stable Diffusion(SD)中,步數(Steps) 指的是采樣過程中的迭代次數,也就是模型從純噪聲一步步“清晰化”圖像的次數。你可以理解為模型在畫這張圖時“潤色”的輪數。 Ⅱ步數的具體作…