文章目錄
- 版本介紹
- 隱私計算介紹
- 前言
- FATE架構
- 總體架構
- FateBoard架構
- 前端架構
- 后端架構
- FateClient架構
- 創建DAG方式
- DAG生成
- 任務管理
- python SDK方式
- FateFlow架構
- Eggroll架構
- FATE算法架構
- Cpn層
- FATE ML層
- 組件新增流程
- 新增組件流程
- 新增算法流程
版本介紹
WeBank的FATE開源版本 2.2.0
隱私計算介紹
(對隱私計算已經有了解的朋友可以跳過這節)
隱私計算,顧名思義,在保護隱私的前提下實現計算。
計算分為集中式計算和分布式計算,對于集中式計算的隱私包含,就是對集中式數據的保護。
1、集中式計算需要從各方收集數據,然后中心式進行計算。
目前主流的方法是采用可信執行環境,采用硬件加密的技術,建立虛擬機并對內存進行加密,并在生產環境部署后關閉虛擬機登錄入口。(保證了數據對中心計算節點不可見)
2、分布式計算則相反,不需要從各方收集數據到中心。
2.1 對于數據挖掘、模型訓練等需求往往采用聯邦學習(FL)的方式。
2.2 對于隱私求交、匿蹤查詢等需求,往往采用多方安全計算等方式,基于密碼協議實現分布式計算(多方安全計算往往需要針對一類需求設計一個密碼協議,目前使用仍然不廣泛)。
這里的中間結果,指的是對原始數據的提取信息,比如深度學習模型參數,模型梯度等信息。這些信息很難去還原原始數據信息。
前言
自從近年來隱私計算逐漸有熱度以后,國內目前主流的隱私計算框架也層出不窮。
在工業界,目前開源生態較為完善的主要有微眾銀行的FATE和螞蟻集團的SecretFlow等等。其中FATE主要專注于機器學習等軟件層面的功能,而SecretFlow主要專注于軟件、硬件一體化融合的功能。
而在學術界,目前主要是基于聯邦學習(FL)pytorch進行擴展的pysyft框架(筆者讀大三時發現有內存泄漏問題,不知道現在修復了沒有),以及tensorflow federated(TFF)框架。還有一種簡單粗暴的方式,就是直接采用本地深度學習框架進行模擬,使用串行方式來模擬并行方式,這對學術界的快速想法驗證具有較好的效果。
隨著國家將數據要素視為生產要素以后,以及各大數據交易所的成立,如深圳數據交易所,上海數據交易所等,這些交易所可以極大的促進數據流通,釋放數據潛力。但是隨著而來的就是,數據確權(兩方數據計算后的中間數據屬于誰?)、數據泄漏定責(數據泄漏后如何定位到具體是哪一個使用方進行的泄漏)、數據安全(使用方被攻擊導致數據丟失)等一系列問題。
重點:對于數據如何安全共享的需求,隱私計算提供了一套可行且靠譜的解決方案。SecretFlow適合于二次開發需求弱,后期計劃部署TEE環境的公司。FATE適合于定制化開發需求較高,且計劃長期迭代開發的公司。
FATE架構
所有的分布式系統架構都可以分為四個部分
1、調度系統
2、計算系統
3、存儲系統
4、監控系統
總體架構
這四個模塊的關系如下:
fate也不例外,FATE的基本架構如下:
在介紹各系統核心功能之前,先介紹一下任務從創建到執行的完成流程圖。
1、首先,用戶通過API、命令行、python SDK等方式,生成一個DAG(有向無環圖)。
2、將存儲DAG的yaml文件發送到FateFlow,由FateFlow進行解析調度DAG。
3、DAG由FateFlow發送到FateBoard,由FateBoard反向解析DAG,生成一張可視化的執行圖。
4、Eggroll接收到FateFlow的執行命令,調用本地的Fate算法庫進行執行具體計算代碼。
5、任務的元信息、模型的元信息等存儲到MySQL數據庫中。而具體的模型參數存儲到文件或Eggroll中(配置文件中指定)。
其中,各系統的核心功能如下:
FateBoard:可視化展示任務狀態,展示DAG的可視化,重試、取消任務等操作的可視化。(僅僅只是可視化發送指令到Flow)
FateClient:生成需要執行的DAG的yaml文件。
FateFlow:由于是分布式計算系統,所以需要進行任務調度。其負責任務的拆解調度以及任務的具體管控等行為。
OSX:統一網關,負責鑒權,消息控制等行為。在Fate1.x中,由RabbitMq+Nginx進行解決。
Eggroll:集群存儲和計算系統,用于存儲訓練的模型參數與MapReduce訓練。在Fate1.x時,采用Spark作為分布式計算框架。
FateBoard架構
總的來說,FateBoard主要提供了一個可視化功能,將前端的請求封裝為具體的FateFlow的請求進行轉發。轉發類為FlowFeign相關的類。由于不是核心算法相關內容,所以WeBank團隊這部分做的很潦草。
前端架構
筆者不是很了解前端開發,如有錯誤,歡迎批評指正。
開發語言:Typescript
開發框架:Vue
功能:根據組件描述的yaml文件、以及組件間的依賴關系(JSON文件)生成可視化的執行流程。
后端架構
開發語言:java
開發框架:Spring
功能:主要是作為前端和Flow的橋梁,將task狀態查詢請求、DAG依賴關系查詢請求等包裝后發送給flow(通過REST接口的方式)。同時也具有一些簡單的用戶管理功能。此外,需要實時更新的請求,如Log日志、Job的執行狀態則需要通過websocket協議向前端進行實時推送。
注: 采用javax的websocket協議進行ws服務管理,對于服務類中使用到的Bean是通過InitializingBean, ApplicationContextAware進行了手動注入。(因為javax產生的servlet不由spring進行管理,無法進行Bean的依賴注入)
FateClient架構
FateClient主要是用來提交任務,即提交DAG。這里可以對比一下SecretFlow,SecretFlow將FateClient和FateBoard綁在了一起,可以前端通過拖拉拽的方式構建流程圖,而Fate則是FateClient構建DAG并提交,由FateBoard查看訓練的狀態以及結果。
創建DAG方式
目前僅支持SDK的一種方式創建DAG,但是進行任務管理時,可以有三種方式進行任務管理,分別為API、命令行、python SDK。
DAG生成
python sdk提供了類似于Java Netty框架的方式的pipeline,構建流水線類,最終生成對應的DAG文件,并提交給flow進行調度。
這里舉個例子更好說明:
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import argparse
import json
from dataclasses import asdictfrom fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import PSI, Reader
from fate_client.pipeline.utils import test_utils# 實際上是封裝了一個post請求,向flow發送消息
def main(config="../config.yaml", namespace=""):# 解析config文件,獲取到各配置參數if isinstance(config, str):config = test_utils.load_job_config(config)parties = config.partiesguest = parties.guest[0]host = parties.host[0]# 創建流水線,初始化FateFlowExecutor()作為執行器,可以指定回調函數pipeline = FateFlowPipeline().set_parties(guest=guest, host=host)# 創建Reader組件,獲取各方數據集reader_0 = Reader("reader_0")reader_0.guest.task_parameters(namespace=f"experiment{namespace}",name="breast_hetero_guest")reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}",name="breast_hetero_host")# 創建PSI組件,定義PSI操作# PSI進行求交的列,是在數據上傳時指定的psi_0 = PSI("psi_0",hashType="sha512", input_data=reader_0.outputs["output_data"])# 將組件加入流水線中,組成一個DAG# 部署階段可以傳pipeline,訓練階段只可以傳component。將component傳入到一個tasks字典中,映射為[name:component]pipeline.add_tasks([reader_0, psi_0])# 依據pipeline當前的屬性,內部創建DAGSpecdag = pipeline.compile().get_dag();print(dag)# 將DAG發送給flow執行,這里還需要指定當前參與方的id,對DAG進行選擇性執行。下面是FateFlowExecutor()的執行邏輯。# def fit(self, dag_schema: DAGSchema, component_specs: Dict[str, ComponentSpec],# local_role: str, local_party_id: str, callback_handler: CallbackHandler) -> FateFlowModelInfo:# flow_job_invoker = FATEFlowJobInvoker()# local_party_id = self.get_site_party_id(flow_job_invoker, dag_schema, local_role, local_party_id)## return self._run(# dag_schema,# local_role,# local_party_id,# flow_job_invoker,# callback_handler,# event="fit"# )# 將DAG進行submit后,輪詢監視,1s輪詢一次#pipeline.fit()if __name__ == "__main__":parser = argparse.ArgumentParser("PIPELINE DEMO")parser.add_argument("--config", type=str, default="../config.yaml",help="config file")parser.add_argument("--namespace", type=str, default="",help="namespace for data stored in FATE")args = parser.parse_args()main(config=args.config, namespace=args.namespace)
這段代碼介紹了生成一個讀取數據,并對兩方數據進行psi的任務流程圖。具體的介紹已經在注釋中說明。生成的DAG如下:
dag:parties:- party_id: ['9999']role: guest- party_id: ['10000']role: hostparty_tasks:guest_9999:parties:- party_id: ['9999']role: guesttasks:reader_0:parameters: {name: breast_hetero_guest, namespace: experiment}host_10000:parties:- party_id: ['10000']role: hosttasks:reader_0:parameters: {name: breast_hetero_host, namespace: experiment}stage: traintasks:psi_0:component_ref: psidependent_tasks: [reader_0]inputs:data:input_data:task_output_artifact:output_artifact_key: output_dataparties:- party_id: ['9999']role: guest- party_id: ['10000']role: hostproducer_task: reader_0parameters: {hashType: sha512}stage: defaultreader_0:component_ref: readerparameters: {}stage: default
schema_version: 2.2.0
任務管理
1、API方式進行管理
官方文檔詳細聲明了API方式如何進行任務管理,也詳細列出了參數,這里不做過多介紹。
2、命令行方式管理
這個方式我沒有找到詳細的官方文檔,但是在fate的開源代碼中的Fate_client包下的flow_cli中有詳細的代碼實現,這個也不是我關心的重點,所以我只介紹兩個常用的命令。
1.數據上傳命令(所有在訓練過程中使用到的數據都需要將元信息進行上傳,包括文件名、文件內容描述等信息)
flow data upload -c upload.json
其中,upload.json的參考格式如下:
{"file": "examples/data/breast_hetero_guest.csv", #需要上傳文件的絕對路徑"head": true, # 是否存在表頭"partitions": 16, # 用于分布式并行計算的參數,這里默認16就好,也可以根據cpu的核心數進行配置"extend_sid": true, # 用于PSI操作,如果是縱向數據,就需要加上"meta": { # 對數據表結構的定義"delimiter": ",", # csv文件的分隔符"label_name": "y", # 聲明標簽列"match_id_name": "id" # 聲明主鍵列,唯一標識數據"dtype": "str" #標識數據類型為字符串,不指定的話默認為float32},"namespace": "experiment", # 數據的命名空間"name": "breast_hetero_guest" # 通過命令空間與name可以唯一索引到數據,用于后期的Reader組件進行數據讀取
}
上面只是標識了一部分的常用參數,如果需要具體的定制化操作,我沒有找到官方的具體文檔,但是解析數據上傳代碼類的路徑為:fate_flow/python/fate_flow/components/components/upload.py,這個py文件詳細定義了數據上傳時的默認參數設置。
2.任務提交命令
flow job submit -c train_lr.yaml
這個命令的作用是將存儲dag的yaml文件交給flow進行解析執行。
執行的具體結果有兩種方式可以查看:
一、直接訪問flow的日志,路徑為:fate_flow安裝路徑/logs/{task id}/xxxx(安裝環境的時候可以用,一般不采用這種方式)
二、開啟fate_board,并配置application.properties文件,將flow端口指向部署主機的9380端口。可以直接在前端通過可視化的方式查看任務的執行結果。
python SDK方式
python的SDK方式可以通過寫代碼的方式生成dag的yaml文件進行手動上傳提交任務,也可以直接調用pipeline的fit方式,直接自動進行解析并上傳到flow中。前文已經介紹過了,這里不做贅述。
FateFlow架構
關于FateFlow的架構,官方給了一張很詳細的圖進行說明。
但是這張圖給的細節太多了,如果第一次接觸FateFlow的朋友,估計一下子很難抓住重點,所以我簡化了一下,從開源代碼的文件夾的角度進行介紹。
這里的幾個架構我解釋一下:
app: 用于對接fate client,作為fate flow的入口。
scheduler: 提供了一些關于job和task的接口,如創建job,停止job等操作。
manager: Data、Component、Log組件相關的接口會進行調用,主要是對Data、Component、Log進行管理。
controller: 其余事項的一些服務,比如DAG中組件的依賴關系的查詢等。
其余的還有Eggroll和Spark存儲架構,這里不是flow的重心,我這里不做介紹。
OXS目前我還不太了解,只知道時用來做網關路由的,所以先不做介紹。
Eggroll架構
Eggroll負責存儲和分布式計算,但是Eggroll沒有設計自己的存儲引擎,Eggroll可以依托于MySQL等數據庫來存儲數據。Eggroll配套了對應了DashBoard可以進行監控,需要在conf/eggroll.properties文件中進行修改。
eggroll.resourcemanager.clustermanager.jdbc.driver.class.name=com.mysql.cj.jdbc.Driver
eggroll.resourcemanager.clustermanager.jdbc.url=jdbc:mysql://數據庫服務器ip:端口/數據庫名稱?useSSL=false&serverTimezone=UTC&characterEncoding=utf8&allowPublicKeyRetrieval=true
開發語言: python(引擎部分)+java(dashboard部分)
架構概覽:
Eggroll的計算模式采用的MapReduce架構,由各分區進行并行計算,并進行結果收集后處理。
其中,MySQL中存儲的主要是數據的原始信息,如store_locator中存儲的是數據的分區數量,分區id等信息,而store_partition存儲的是分區id以及每個分區的每個partition的id。
FATE算法架構
fate的算法架構官方的圖片講的較為全面,我這里就使用官方的圖片進行介紹。
這里可以看到,從左往右總共分了Cpn層,FateML層,FateArch層。
Cpn層
其中Cpn是Component的簡寫,在fate框架中主要由components/core下的py文件來進行實現。目的是為了實現組件的注冊發現于管理。在components/components包下的py文件中進行調用。這里不做具體算法的細節邏輯,更關注組件的交互流程。
cpn使用的方式較多,但是對于機器學習算法設計而言,主要使用到他的@cpn.component、@xxx.train()、@xxx.predict()三個裝飾器。我這里就僅僅介紹這三個裝飾器的作用。
1、@Cpn.component
這個裝飾器是為了根據函數名來生成一個Component類,作為組件管理。需要會自動裝配參數如下:
cpn = Component(name=cpn_name, # 組件名字roles=roles, # 當前運行方的角色provider=provider, # 組件庫,默認為fateversion=version, # 組件版本description=desc, # 組件描述callback=f, # 回調函數,就是聲明解釋器的函數parameters=parameters,artifacts=artifacts,is_subcomponent=is_subcomponent,)
具體的使用方式如下:
@cpn.component(roles=[GUEST, HOST, ARBITER], provider="fate")
def coordinated_lr(ctx, role):...
這里需要注意,如果是不區分訓練階段于測試階段的組件,則需要在這個函數中完成具體的邏輯設計。而對于大部分機器學習模型,往往需要針對訓練階段或測試結果進行不同的代碼運行。所以,這時@xxx.train()、@xxx.predict()裝飾器就派上了用場。
2、@xxx.train()與@xxx.predict()
在完成@cpn.component的組件注冊時,會產生coordinated_lr這個函數,這個函數也作為裝飾器進行修飾其他函數。這里針對具體的組件,對于訓練的train函數,在train函數上方添加xxx.train()裝飾器,而對于測試階段的predict函數,則在predict函數上方添加xxx.predict()裝飾器。
3、其余Cpn參數
Fate對所有的Component操作都進行了封裝,包括了參數。所以,所有的cpn相關的參數都需要以cpn封裝的形式進行提供。
舉個例子:
@coordinated_lr.train()
def train(ctx: Context, # 組件上下文,往往用于流水線中傳遞非規格化的信息,如host向guest傳遞數據,向前端寫入log信息等role: Role, # 當前執行組件的角色,隱私計算中guest和host執行的流程往往不同train_data: cpn.dataframe_input(roles=[GUEST, HOST]), # 訓練數據,由上游組件輸入validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), # 驗證數據,由上游組件輸入...early_stop: cpn.parameter(type=params.string_choice(["weight_diff", "diff", "abs"]),default="diff",desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}",), # 非規格化參數,需要指定默認值,給定參數類型,參數的描述等信息。train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), # 組件的數據輸出,下一個組件可以從中獲取數據輸入output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), # 組件的模型輸出warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), # 組件支持在預訓練模型的基礎上進行訓練,這里可以傳入之前訓練好的模型
):
FATE ML層
ml層為隱私計算算法的具體邏輯實現,由Cpn層中的Components/Components下的組件進行調用。這里的大部分文件夾都可以根據名字猜出組件名字,我這里介紹一下比較重要的三個文件夾,分別為abc,aggregator,nn。
1、abc
abc是Abstract Base Classes的簡寫。因為python沒有接口的概念,但是Fate的作者想要為所有的機器學習模型提供一種統一的抽象,以便上次cpn進行調用,所以定義了一個抽象基類module。
class Module:mode: str@typing.overloaddef fit(self, ctx: Context, input_data):...def fit(self,ctx: Context,*args,**kwargs,) -> None:...def transform(self, ctx: Context, transform_data: DataFrame) -> DataFrame: #計算一些中間結果,比如平均數等...def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: #執行預測階段的任務...def from_model(cls, model: Union[dict, Model]): # 將json格式或二進制格式的模型進行反序列化...def get_model(self) -> Union[dict, Model]: # 將當前模型進行序列化...
由于python支持多繼承機制,所以所有繼承了module類的隱私計算算法類都需要重寫這些方法。
2、aggregator
aggregator是很核心的一個文件夾,提供了明文聚合和安全聚合兩種聚合方式,分別由兩個類實現。幾乎所有需要設計到聚合操作的算法,比如橫向聯邦學習等都需要用到這個類。聚合操作目前僅在橫向聚合時才需要使用到,所以接下來分析的都是橫向操作。(70%的相關調用類在nn/trainer/trainer_base.py文件夾下)
先展示一下aggregator下的__init__.py的代碼:
class AggregatorType(enum.Enum):PLAINTEXT = "plaintext"SECURE_AGGREGATE = "secure_aggregate"aggregator_map = {AggregatorType.PLAINTEXT.value: (PlainTextAggregatorClient, PlainTextAggregatorServer),AggregatorType.SECURE_AGGREGATE.value: (SecureAggregatorClient, SecureAggregatorServer),
}from fate.ml.aggregator.aggregator_wrapper import AggregatorClientWrapper, AggregatorServerWrapper__all__ = ["PlainTextAggregatorClient","PlainTextAggregatorServer","SecureAggregatorClient","SecureAggregatorServer","AggregatorServerWrapper","AggregatorClientWrapper",
]
這里面涉及到了多個調用關系,我畫了一張類的調用關系圖方便進行梳理:
文件的調用關系為:
nn/homo/fedavg.py —> nn/trainer/trainer_base.py
圖片為:
這里標綠的標識為外部框架transformers。Fate在封裝算法時,為了避免大量重寫已有算法,采用了大量的類繼承自torch。
這里aggregator中的聚合器在FedAVGServer類進行初始化時需要傳入,具體的方式采用字符串與類的綁定,降低了代碼的耦合性。
3、nn
nn作為神經網絡相關的最重要的一個類,下面分了好幾個文件夾:
(1)datatest的作用時從本地加載數據,fate上傳數據時,只是將數據的元信息存儲在了MySQL中,并沒有將數據納入管控,只是得到了數據存儲的絕對路徑。可以根據自己的需要在這里實現dataloader,我基于自己的需求,在這里實現了時間序列數據的讀取、圖片數據的讀取,不過fate原始只支持CSV數據的讀取 。
(2)hetero的作用為封裝了縱向訓練的相關操作,將參與方分為了guest與host。
(3)homo文件夾下主要實現了FedAVG算法,前面已經介紹過了,這里不贅述。
(4)model_zoo中存放了所有的機器學習模型,這里需要注意,對于縱向模型需要由top模型和bottom模型,這是由fedpass算法決定的,也可以選擇sshe算法進行聚合,我目前還沒有試過。橫向模型的化,只需要一個通用模型即可,fate會自動對模型參數進行聚合。
(5)trainer中存放的是所有和橫向或者縱向訓練流程相關的類。主要涉及聚合操作的類。
組件新增流程
fate官方給了一個組件新增流程,但是講的比較粗略,我這里詳細介紹一下。分別從兩個方面進行介紹,新增組件流程和新增模型流程。
新增組件流程
Fate中所謂的新增組件,就是新增一個類,并且這個類可以綁定在pipeline中進行處理。
①進入fate項目,在python/fate/components/components/下新建組件(以psi為例)新建一個my_dsj.py 其內容如下:
一般開發一個組件包含以下幾個部分(參考feature_scale):
#1.先定義組件
@cpn.component(roles=[GUEST, HOST], provider="fate")def 組件名稱(ctx, role):...
#2. 組件實現(一般包含
@組件名稱.train() 模型訓練
@組件名稱.predict() 模型預測
@組件名稱.cross_validation()) 交叉驗證 這三個裝飾器實現不同的階段,如果只有一個也可以不用任何裝飾器,例如:PSI。每個階段必須有ctx(上下文),role(角色范圍),每個階段可定義輸入(cpn.dataframe_input(數據輸入)| cpn.json_model_input(模型輸入))與輸出(cpn.dataframe_output(數據輸出)| cpn.json_model_output(模型輸出))。
②查看服務器conda生成fate對應的python對應的虛擬環境例如下
(venv=/data/projects/fate/common/miniconda3)
${venv}/lib/python3.10/site-packages/ fate/components/components/
然后將該文件 psi_my.py復制到該目錄下.
③在python/fate/components/components/目錄下在__init__.py注冊組件其內容如下:
@_lazy_cpn def psi_my(self):from .psi_my import psi_myreturn psi_my
然后進入${venv}/lib/python3.10/site-packages/fate_client/pipeline/component_define目錄執行以下命令生成組件描述
python -m fate.components component desc --name psi_my --save psi_my.yaml
④進入${venv}/lib/python3.10/site-packages/fate_client/pipeline/components/fate目錄,新建psi_my.py文件
⑤然后在__init__.py中引入新加入的組件from .psi_my import PSI_MY完成后查詢新注冊的組件。
python -m fate.components component list#返回: {'buildin': ['feature_scale', 'reader', 'coordinated_lr', 'coordinated_linr', 'homo_nn', 'hetero_nn', 'homo_lr', 'hetero_secureboost', 'dataframe_transformer', 'psi', 'psi_my', 'evaluation', 'artifact_test', 'statistics', 'hetero_feature_binning', 'hetero_feature_selection', 'feature_correlation', 'union', 'sample', 'data_split', 'sshe_lr', 'sshe_linr', 'toy_example', 'dataframe_io_test', 'multi_model_test', 'cv_test2'], 'thirdparty': []}
⑥以fate自帶的隱私求交為例進入examples/pipeline/psi/目錄,新建test_psi_my.py
import argparsefrom fate_client.pipeline import FateFlowPipeline from fate_client.pipeline.components.fate import PSI_MY, Reader from fate_client.pipeline.utils import test_utilsdef main(config="../config.yaml", namespace=""):if isinstance(config, str):config = test_utils.load_job_config(config)parties = config.parties guest = parties.guest[0]host = parties.host[0]pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) # 初始化pipelinereader_0 = Reader("reader_0")reader_0.guest.task_parameters(namespace=f"experiment{namespace}",name="breast_hetero_guest" )reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}",name="breast_hetero_host" )psi_0 = PSI_MY("psi_0", input_data=reader_0.outputs["output_data"])pipeline.add_tasks([reader_0, psi_0]) # 往pipeline中添加任務pipeline.compile() # 編譯生成dag# print(pipeline.get_dag())pipeline.fit() # 上傳dag到fate flow進行執行,并定時查詢任務狀態if __name__ == "__main__":parser = argparse.ArgumentParser("PIPELINE DEMO")parser.add_argument("--config", type=str, default="../config.yaml", # yaml文件主要是指定guest和host的idhelp="config file")parser.add_argument("--namespace", type=str, default="",help="namespace for data stored in FATE")args = parser.parse_args()main(config=args.config, namespace=args.namespace)
然后執行示例(隱私求交): python test_psi_my.py
⑦在guest端處查詢到最終的結果
新增算法流程
在model_zoo中完成模型定義,如:
import torch
import torch.nn as nn
from fate_client.pipeline.components.fate.nn.torch.base import TorchModule
import logging# 定義 CNN 模型
class SimpleCNN(nn.Module):def __init__(self, in_features, out_features, height, width):super(SimpleCNN, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.height = heightself.width = width# 第一次卷積與池化self.height = (self.height - 2) // 2self.width = (self.width - 2) // 2# 第二次卷積與池化self.height = (self.height - 2) // 2self.width = (self.width - 2) // 2self.linner_in = self.height * self.width * 64self.defNetwork()def defNetwork(self):self.conv1 = nn.Conv2d(in_channels=self.in_features, out_channels=32, kernel_size=3)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.fc1 = nn.Linear(self.linner_in, 64)self.fc2 = nn.Linear(64, out_features=self.out_features)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, self.linner_in)x = torch.relu(self.fc1(x))x = self.fc2(x)if self.training:return xelse:softmax_out = nn.Softmax(dim=-1)(x)return softmax_outclass CNN(SimpleCNN, TorchModule):def __init__(self, in_features, out_features,height,width, **kwargs):TorchModule.__init__(self)self.param_dict["in_features"] = in_featuresself.param_dict["out_features"] = out_featuresself.param_dict["height"] = heightself.param_dict["width"] = widthself.param_dict.update(kwargs)SimpleCNN.__init__(self, **self.param_dict)
這里唯一需要注意的就是,最終對外暴露的CNN,需要繼承自from fate_client.pipeline.components.fate.nn.torch.base import TorchModule這個類,這個類可以理解為只重寫了to_string方法(),在pipeline中組件中傳遞模型時,具有非常大的作用。