聯邦學習框架
本文主要期望介紹一個設計良好的聯邦學習框架 Flower,在開始介紹 Flower 框架的細節前,先了解下聯邦學習框架的基礎知識。
作為一個聯邦學習框架,必然會包含對橫向聯邦學習的支持。橫向聯邦是指擁有類似數據的多方可以在不泄露數據的情況下聯合訓練出一個模型,這個模型可以充分利用各方的數據,接近將全部數據集中在一起進行訓練的效果。橫向聯邦學習的一般流程如下:
橫向聯邦學習的過程簡單理解如下:
- 各個參與方基于自身的數據訓練出本地模型,將模型參數發送給公共的服務端;
- 服務端將收到的多個模型參數聚合生成全局的模型參數,然后下發給各個參與方;
- 參與方使用全局的模型參數更新本地模型,重復這個步驟直到模型訓練達到要求;
從上面的過程可以看到,作為一個聯邦學習框架,需要關注下面要點:
- 參與方本地模型訓練;
- 模型參數的傳輸;
- 模型的聚合策略;
Flower 框架上手
Flower 是一個輕量的聯邦學習框架,提出于 2020 年。一直以來,因為設計良好,方便擴展受到了比較多的關注。團隊通過論文 FLOWER: A FRIENDLY FEDERATED LEARNING FRAMEWORK 介紹了框架的設計思想。通過論文可以看到框架設計主要追求下面目標:
- 可拓展,支持大量的客戶端同時進行模型訓練;
- 使用靈活,支持異構的客戶端,通信協議,隱私策略,支持新功能的開銷小;
首先基于 Flower 框架實際進行了一個機器學習的模型訓練,通過實際動手可以感受下基于 Flower 框架可以用相當簡單的方式實現一個聯邦學習模型訓練流程。這個流程是參考 Flower Quickstart 實現的。
常規機器學習部分實現
首先實現機器學習訓練所需的基礎方法,主要是數據集的準備,定義所需的模型,封裝訓練與測試流程,這部分與聯邦學習無關。熟悉 pytorch 的應該很容易理解這部分代碼:
from collections import OrderedDictimport flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdmDEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 定義神經網絡模型class Net(nn.Module):def __init__(self) -> None:super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)# 定義模型訓練流程def train(net, trainloader, epochs):criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for _ in range(epochs):for images, labels in tqdm(trainloader):optimizer.zero_grad()criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()optimizer.step()# 定義模型推理流程def test(net, testloader):criterion = torch.nn.CrossEntropyLoss()correct, loss = 0, 0.0with torch.no_grad():for images, labels in tqdm(testloader):outputs = net(images.to(DEVICE))labels = labels.to(DEVICE)loss += criterion(outputs, labels).item()correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()accuracy = correct / len(testloader.dataset)return loss, accuracy# 定義數據集的獲取def load_data():"""Load CIFAR-10 (training and test set)."""trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = CIFAR10("./data", train=True, download=True, transform=trf)testset = CIFAR10("./data", train=False, download=True, transform=trf)return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)# 生成模型對象,實際獲取訓練與測試數據集net = Net().to(DEVICE)
trainloader, testloader = load_data()
如果是常規的機器學習模型訓練,直接調用上面的 train(net, trainloader, epochs=1)
即可完成模型的訓練,但是此時各方只能使用本地有限的數據訓練出一個機器學習模型。如果希望能充分利用各方的數據訓練出一個更好的機器學習模型,就需要基于 Flower 補充聯邦學習的能力了。
Flower 客戶端實現
Flower 的客戶端表示的是各個參與聯合訓練的一方,Flower 客戶端會完成本地的模型訓練,并將本地訓練的模型發送給服務端,然后接收服務端下發的聚合模型。需要實現如下:
class FlowerClient(fl.client.NumPyClient):# 獲取本地模型對應的參數def get_parameters(self, config):return [val.cpu().numpy() for _, val in net.state_dict().items()]# 接收模型參數,并更新本地模型def set_parameters(self, parameters):params_dict = zip(net.state_dict().keys(), parameters)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)# 本地模型訓練,會先調用 set_parameters() 基于收到的全局模型參數更新本地模型def fit(self, parameters, config):self.set_parameters(parameters)train(net, trainloader, epochs=1)return self.get_parameters(config={}), len(trainloader.dataset), {}# 基于測試數據集進行測試def evaluate(self, parameters, config):self.set_parameters(parameters)loss, accuracy = test(net, testloader)return loss, len(testloader.dataset), {"accuracy": accuracy}# 啟動 Flower 客戶端fl.client.start_numpy_client(server_address="127.0.0.1:8080",client=FlowerClient(),
)
可以看到 Flower 客戶端的訓練流程中,除了實現 fit()
方法進行本地模型的訓練之外,只需要額外實現兩個方法,使用 get_parameters()
獲取了本地模型的參數,方便發送本地模型參數給服務端用于模型聚合,使用 set_parameters()
方便根據服務端發來的聚合模型參數更新本地模型。
Flower 服務端實現
Flower 服務端主要用于接收客戶端發來的模型參數,聚合參數后下發給客戶端,具體實現如下:
from typing import List, Tupleimport flwr as fl
from flwr.common import Metrics# 定義指標聚合方法def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]examples = [num_examples for num_examples, _ in metrics]return {"accuracy": sum(accuracies) / sum(examples)}# 定義模型聚合策略strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)# 啟動 Flower 服務端fl.server.start_server(server_address="0.0.0.0:8080",config=fl.server.ServerConfig(num_rounds=3),strategy=strategy,
)
可以看到 Flower 服務端只需要定義模型聚合的策略,由于機器學習的主要流程是固定的,因此不需要手工實現。
通過上面的實現即可完成聯邦學習模型訓練,實際需要先啟動 Flower 服務端,然后可以根據需要啟動多個 Flower 客戶端,即可進行聯邦學習模型訓練。
Flower 框架設計
Flower 框架是如何設計從而實現上面的便利使用的呢?可以先看看官方的架構圖:
上面的框架中 RPC Server 就是服務端,作為中心節點協調大量的客戶端進行聯合的模型訓練。服務端中包含三大核心組件:
- ClientManager,用于管理現有的客戶端,提供 ClientProxy 作為客戶端的抽象,方便統一處理,利用 ClientProxy 進行實際數據的發送,基于 ClientProxy 消除了客戶端的異構性;
- Strategy,用于確定每個階段如何執行。比如在訓練階段,通過 Strategy 確定采取什么樣的模型聚合策略。Strategy 是可以由用戶自定義的;
- FL loop,定義整個執行機器學習流程,因為機器學習的訓練與預測流程都是固定的,因此根據確定的任務即可確定對應的 FL loop 流程;
Flower 源碼實現
通過上面的框架我們可以理解 Flower 是如何實現靈活性,接下來我們可以關注具體的源碼實現,具體了解 Flower 框架如何將我們實現的客戶端與服務端串聯起來,并了解 Flower 框架的支撐組件的具體信息。
Flower 客戶端
Flower 的客戶端的基礎類 NumPyClient
主要定義了一些必要的方法,這些方法需要在子類中實現,這個在上面的動手階段已經體現出來,沒有太多需要深入了解的。我們主要關注下客戶端的啟動方法 start_numpy_client()
,這部分的實現主要在 src/py/flwr/client/app.py
中的 start_client()
方法中完成,具體的實現如下:
def start_client(*,server_address: str,client_fn: Optional[ClientFn] = None,client: Optional[Client] = None,grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,root_certificates: Optional[Union[bytes, str]] = None,transport: Optional[str] = None,
) -> None:# 初始化連接上下文connection, address = _init_connection(transport, server_address)while True:sleep_duration: int = 0# 建立 RPC 連接with connection(address,grpc_max_message_length,root_certificates,) as conn:receive, send, create_node, delete_node = conn# 注冊當前客戶端if create_node is not None:create_node()while True:# 接收服務端發來的消息task_ins = receive()if task_ins is None:time.sleep(3)continue# 處理系統控制類消息,處理完控制類消息就退出task_res, sleep_duration = handle_control_message(task_ins=task_ins)if task_res:send(task_res)break# 處理普通任務中的消息task_res = handle(client_fn, task_ins)# 將執行結果發送給服務端send(task_res)# 注銷當前客戶端if delete_node is not None:delete_node()if sleep_duration == 0:breaktime.sleep(sleep_duration)
可以看到整體的實現流程還是比較簡單的,客戶端扮演一個被動的角色,通過 receive()
持續接收服務端發來的消息,系統消息通過 handle_control_message()
方法處理, 常規消息通過 handle()
方法進行處理,然后將處理的結果通過 send()
方法發送給服務端。
系統消息只包含一種 reconnect_ins
消息,主要用于客戶端重連或斷開,目前主要常規消息的處理,對應的處理方法 handle()
實現如下所示:
def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)# 臨時支持安全聚合的分支if server_msg is None:client = client_fn("-1")if task_ins.task.HasField("sa") and isinstance(client, SecureAggregationHandler):named_values = serde.named_values_from_proto(task_ins.task.sa.named_values)res = client.handle_secure_aggregation(named_values)task_res = TaskRes(task_id="",group_id="",workload_id=0,task=Task(ancestry=[],sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),),)return task_resraise NotImplementedError()# 處理消息client_msg = handle_legacy_message(client_fn, server_msg)# 處理結果的封裝task_res = wrap_client_message_in_task_res(client_msg)return task_resdef handle_legacy_message(client_fn: ClientFn, server_msg: ServerMessage
) -> ClientMessage:field = server_msg.WhichOneof("msg")client = client_fn("-1")if field == "get_properties_ins":return _get_properties(client, server_msg.get_properties_ins)# 獲取客戶端的模型參數if field == "get_parameters_ins":return _get_parameters(client, server_msg.get_parameters_ins)# 客戶端模型訓練if field == "fit_ins":return _fit(client, server_msg.fit_ins)# 客戶端基于測試集進行測試if field == "evaluate_ins":return _evaluate(client, server_msg.evaluate_ins)raise UnknownServerMessage()
上面的消息處理最終調用的就是客戶端需要實現的 fit()
與 get_parameters()
等方法。可以看到客戶端的實現還是比較簡單的,只是根據幾種消息,執行預先支持的對應的方法即可。考慮到客戶端完全是被動響應服務端的消息,因此主要的聯邦學習的流程的支持都是由服務端定義好的。
Flower 服務端
Flower 服務端需要自定義的代碼較少,主流程是通過直接調用 src/py/flwr/server/app.py
中的 start_server()
方法完成的,我們可以了解下對應的實現:
def start_server(*,server_address: str = ADDRESS_FLEET_API_GRPC_BIDI,server: Optional[Server] = None,config: Optional[ServerConfig] = None,strategy: Optional[Strategy] = None,client_manager: Optional[ClientManager] = None,grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
) -> History:# 構造啟動地址parsed_address = parse_address(server_address)if not parsed_address:sys.exit(f"Server IP address ({server_address}) cannot be parsed.")host, port, is_v6 = parsed_addressaddress = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"# 初始化 Server 對象,此對象中包含實際的模型訓練流程的支持initialized_server, initialized_config = init_defaults(server=server,config=config,strategy=strategy,client_manager=client_manager,)# 啟動 grpc 服務端,用于與客戶端進行通信grpc_server = start_grpc_server(client_manager=initialized_server.client_manager(),server_address=address,max_message_length=grpc_max_message_length,certificates=certificates,)# 執行訓練流程hist = run_fl(server=initialized_server,config=initialized_config,)# 停止 grpc 服務端grpc_server.stop(grace=1)return hist
執行訓練流程的方法 run_fl(initialized_server)
事實上就是調用 initialized_server.fit()
方法,主要的訓練流程都是在 src/py/flwr/server/server.py
中的 Server.fit()
實現的,下面就重點關注這邊的實現:
def fit(self, num_rounds: int, timeout: Optional[float]) -> History:# 初始化全局模型參數self.parameters = self._get_initial_parameters(timeout=timeout)# 執行 num_rounds 輪模型訓練for current_round in range(1, num_rounds + 1):# 執行單輪機器學習模型訓練res_fit = self.fit_round(server_round=current_round,timeout=timeout,)if res_fit is not None:parameters_prime, fit_metrics, _ = res_fit# 根據聚合生成的模型參數更新全局模型參數if parameters_prime:self.parameters = parameters_prime
可以看到單輪的聯邦學習模型訓練就是通過 fit_round()
實現,具體的代碼如下所示:
def fit_round(self,server_round: int,timeout: Optional[float],
) -> Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
]:client_instructions = self.strategy.configure_fit(server_round=server_round,parameters=self.parameters,client_manager=self._client_manager,)# 客戶端基于本地數據進行模型訓練results, failures = fit_clients(client_instructions=client_instructions,max_workers=self.max_workers,timeout=timeout,)# 聚合客戶端發來的模型參數aggregated_result: Tuple[Optional[Parameters],Dict[str, Scalar],] = self.strategy.aggregate_fit(server_round, results, failures)parameters_aggregated, metrics_aggregated = aggregated_resultreturn parameters_aggregated, metrics_aggregated, (results, failures)
下面具體 fit_clients()
是如何發起模型訓練的:
def fit_clients(client_instructions: List[Tuple[ClientProxy, FitIns]],max_workers: Optional[int],timeout: Optional[float],
) -> FitResultsAndFailures:# 多線程并發調用 fit_client 方法實現客戶端模型訓練with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:submitted_fs = {executor.submit(fit_client, client_proxy, ins, timeout)for client_proxy, ins in client_instructions}finished_fs, _ = concurrent.futures.wait(fs=submitted_fs,timeout=None,)results: List[Tuple[ClientProxy, FitRes]] = []failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []for future in finished_fs:_handle_finished_future_after_fit(future=future, results=results, failures=failures)return results, failures# 通過 ClientProxy 發起客戶端的 fit() 模型訓練,ins 中包含全局模型參數def fit_client(client: ClientProxy, ins: FitIns, timeout: Optional[float]
) -> Tuple[ClientProxy, FitRes]:fit_res = client.fit(ins, timeout=timeout)return client, fit_res
總結而言:服務端就是通過多次調用 fit_round()
方法實現多輪的聯邦學習模型訓練,在單輪的聯邦學習模型訓練中,客戶端會根據全局模型參數更新本地模型參數,然后進行本地的模型訓練,并將訓練后的模型的參數發回給服務端,服務端會根據 Strategy 策略聚合客戶端發來的模型參數,然后更新服務端的全局模型參數。
總結
通過上面的內容,將 Flower 框架的動手實踐以及對應的實現細節都介紹到了,主要涉及到 Flower 三大核心組件中的 Strategy 與 FL loop,而 ClientManager 目前沒有過多展開,這部分主要用于管理客戶端的連接,有興趣的可以自行去探索下。
從目前來看,Flower 基本上是一個最精簡的橫向聯邦學習的實現方案了,通過必要的抽象簡化,Flower 將橫向聯邦用簡單易用的方式進行了封裝,對于了解橫向聯邦學習有很大的幫助。而且值得一提的是,Flower 本身具備比較好的靈活性,可以比較方便地支持不同聯邦學習策略,因為 Proxy 機制的存在也能靈活支持異構的客戶端,對于設計一個高效的聯邦學習框架有不少的借鑒意義。