隱私計算實訓營第二期第十課:基于SPU機器學習建模實踐

隱私計算實訓營第二期-第十課

  • 第十課:基于SPU機器學習建模實踐
    • 1 隱私保護機器學習背景
      • 1.1 機器學習中隱私保護的需求
      • 1.2 PPML提供的技術解決方案
    • 2 SPU架構
      • 2.1 SPU前端
      • 2.2 SPU編譯器
      • 2.3 SPU運行時
      • 2.4 SPU目標
    • 3 密態訓練與推理
      • 3.1 四個基本問題
      • 3.2 解決數據來源問題
      • 3.3 解決數據安全問題
      • 3.4 解決模型計算問題
      • 3.5 解決密態計算問題
      • 3.6 如何應對更復雜的模型
      • 3.7 已有模型的復用
    • 4 作業實踐
      • 4.1 基礎NN模型作業
      • 4.2 進階Transformer模型作業

第十課:基于SPU機器學習建模實踐

首先必須感謝螞蟻集團及隱語社區帶來的隱私計算實訓第二期的學習機會!
本節課由螞蟻隱私計算部算法工程師吳豪奇老師講解。

在這里插入圖片描述
本節課主要內容為:

  • 隱私保護機器學習背景
  • SPU架構簡介
  • NN密態訓練/推理示例

1 隱私保護機器學習背景

1.1 機器學習中隱私保護的需求

本節課前兩個小節的內容,我們這之前的課程中已有一些了解,
本節課可以回顧一下。
數據和模型的隱私保護需求是產生隱私保護機器學習的根因。

在這里插入圖片描述

1.2 PPML提供的技術解決方案

MPC提供了隱私保護的技術解決方案。

在這里插入圖片描述

使用MPC結合機器學習,為模型訓練和推理提供隱私保護。
問題
我們是否可以直接以 MPC 的方式高效地運行已有的機器學習程序?

在這里插入圖片描述

2 SPU架構

SPU架構我們在之前已經學習過,宏觀上主要分為三部分:

  1. 前端部分
  2. 編譯器
  3. 運行時

2.1 SPU前端

SPU前端盡量支持原生的AI編程方式,支持JAX、TensorFlow,Pytorch
等典型的AI編程框架。

在這里插入圖片描述

2.2 SPU編譯器

SPU的編譯器以優化方式生成SPU的密態中間語言。

SPu

2.3 SPU運行時

SPU的運行時支持多種并行模式(數據并行+指令并行),多種MPC協議
以及多種部署模式。

在這里插入圖片描述

2.4 SPU目標

SPU的最終目標是實現易用、可擴展和高性能的密態計算虛擬設備。

在這里插入圖片描述

3 密態訓練與推理

3.1 四個基本問題

密態的訓練和推理需要解決的四個問題:

  • 數據從哪來?
  • 如何加密保護數據?
  • 如何定義模型計算?
  • 如何執行密態模型計算?

在這里插入圖片描述

3.2 解決數據來源問題

數據由數據個參與方以密態的形式提供。

在這里插入圖片描述

3.3 解決數據安全問題

數據安全通過MPC協議或者同態加密等外部模式解決。

在這里插入圖片描述

3.4 解決模型計算問題

NN模型的計算問題通過JAX實現前向和反向傳播。

在這里插入圖片描述

3.5 解決密態計算問題

NN模型的密態計算SPU的編譯器轉換為密態算子,然后按照MPC協議
進行計算。

在這里插入圖片描述

密態的計算過程與明文類似,通過SPU密態計算配置實現密態訓練。
在這里插入圖片描述

3.6 如何應對更復雜的模型

對于復雜模型,使用stax和flax來進行實現。

在這里插入圖片描述

3.7 已有模型的復用

已有模型的復用問題,根據明文實現來進行密態計算的遷移。
比如,明文實現的GPT2模型。
在這里插入圖片描述

然后進行密態遷移:

在這里插入圖片描述

在支持不同的模型方面,SPU還需要更新和優化自己的實現以滿足不同
模型的需求。

在這里插入圖片描述

4 作業實踐

4.1 基礎NN模型作業

本次課程有兩個作業,一個是基礎的NN模型。另一個是進階的Transformer
模型。

在這里插入圖片描述

完成步驟如下:

1、加載數據集

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizerdef breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):x, y = load_breast_cancer(return_X_y=True)x = (x - np.min(x)) / (np.max(x) - np.min(x))x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)if train:if party_id:if party_id == 1:return x_train[:, :15], _else:return x_train[:, 15:], y_trainelse:return x_train, y_trainelse:return x_test, y_test

2、定義模型

from typing import Sequence
import flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return x

3、定義訓練參數

import jax.numpy as jnpdef predict(params, x):# TODO(junfeng): investigate why need to have a duplicated definition in notebook,# which is not the case in a normal python program.from typing import Sequenceimport flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return xreturn MLP(FEATURES).apply(params, x)def loss_func(params, x, y):pred = predict(params, x)def mse(y, pred):def squared_error(y, y_pred):return jnp.multiply(y - y_pred, y - y_pred) / 2.0return jnp.mean(squared_error(y, pred))return mse(y, pred)def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):x = jnp.concatenate((x1, x2), axis=1)xs = jnp.array_split(x, len(x) / n_batch, axis=0)ys = jnp.array_split(y, len(y) / n_batch, axis=0)def body_fun(_, loop_carry):params = loop_carryfor x, y in zip(xs, ys):_, grads = jax.value_and_grad(loss_func)(params, x, y)params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)return paramsparams = jax.lax.fori_loop(0, n_epochs, body_fun, params)return paramsdef model_init(n_batch=10):model = MLP(FEATURES)return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))

4、驗證參數

from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):y_pred = predict(params, X_test)return roc_auc_score(y_test, y_pred)

5、開始明文訓練

import jax# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)# Hyperparameter
n_batch = 10
n_epochs = 10
step_size = 0.01# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

這里輸出的明文訓練結果為:

在這里插入圖片描述

6、開始密文訓練

import secretflow as sf# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

7、檢查參數

params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
params = sf.reveal(params_spu)
print(params)

8、輸出訓練結果

X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

密文訓練輸出結果為:

在這里插入圖片描述
可以看出,密文訓練和明文訓練的效果相同,本作業結束,

4.2 進階Transformer模型作業

完成步驟如下:
1、安裝Transformer模型

import sys
!{sys.executable} -m pip install transformers[flax] -i https://pypi.tuna.tsinghua.edu.cn/simple

2、設置鏡像huggingface

import os
import sys
!{sys.executable} -m pip install huggingface_hub
os.environ['HF_ENDPOINT']='https://hf-mirror.com'

3、加載模型

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

4、定義文本生成函數

def text_generation(input_ids, params):config = GPT2Config()model = FlaxGPT2LMHeadModel(config=config)for _ in range(10):outputs = model(input_ids=input_ids, params=params)next_token_logits = outputs[0][0, -1, :]next_token = jnp.argmax(next_token_logits)input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)return input_ids

5、進行明文的文本生成

import jax.numpy as jnpinputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

生成的輸出結果為:

在這里插入圖片描述
6、進行密文訓練

import secretflow as sf# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob', 'carol'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)def get_model_params():pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")return pretrained_model.paramsdef get_token_ids():tokenizer = AutoTokenizer.from_pretrained("gpt2")return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)output_token_ids = spu(text_generation)(input_token_ids_, model_params_)

這里由于機器配置不夠,內存不足,被系統kill進程,導致無法完成訓練。小伙伴們機器好的應該可以跑完。

在這里插入圖片描述

7、輸出密文訓練結果

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

至此,本次作業全部結束。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/39049.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/39049.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/39049.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

全新升級!中央集中式架構功能測試為新車型保駕護航

“軟件定義汽車”新時代下,整車電氣電氣架構向中央-區域集中式發展已成為行業共識,車型架構的變革帶來更復雜的整車功能定義、更多的新技術的應用(如SOA服務化、智能配電等)和更短的車型研發周期,對整車和新產品研發的…

OkHttp的源碼解讀1

介紹 OkHttp 是 Square 公司開源的一款高效的 HTTP 客戶端,用于與服務器進行 HTTP 請求和響應。它具有高效的連接池、透明的 GZIP 壓縮和響應緩存等功能,是 Android 開發中廣泛使用的網絡庫。 本文將詳細解讀 OkHttp 的源碼,包括其主要組件…

Qt實現手動切換多種布局

引言 之前寫了一個手動切換多個布局的程序,下面來記錄一下。 程序運行效果如下: 示例 需求 通過點擊程序界面上不同的布局按鈕,使主工作區呈現出不同的頁面布局,多個布局之間可以通過點擊不同布局按鈕切換。支持的最多的窗口…

如何使用 AppML

如何使用 AppML AppML(Application Markup Language)是一種輕量級的標記語言,旨在簡化Web應用的創建和部署過程。它允許開發者通過XML或JSON格式的配置文件來定義應用的結構和行為,從而實現快速開發和靈活擴展。AppML特別適用于構建數據驅動的企業級應用,它可以與各種后端…

pytorch跑手寫體實驗

目錄 1、環境條件 2、代碼實現 3、總結 1、環境條件 pycharm編譯器pytorch依賴matplotlib依賴numpy依賴等等 2、代碼實現 import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms import matpl…

burpsuite 設置監聽窗口 火狐利用插件快速切換代理狀態

一、修改burpsuite監聽端口 1、首先打開burpsuite,點擊Proxy下的Options選項: 2、可以看到默認的監聽端口為8080,首先選中我們想要修改的監聽,點擊Edit進行編輯 3、將端口改為9876,并保存 4、可以看到監聽端口修改成功…

typescript學習回顧(五)

今天來分享一下ts的泛型,最后來做一個練習 泛型 有時候,我們在書寫某些函數的時候,會丟失一些類型信息,比如我下面有一個例子,我想提取一個數組的某個索引之前的所有數據 function getArraySomeData(newArr, n:numb…

JVM原理(十):JVM虛擬機調優分析與實戰

1. 大內存硬件上的程序部署策略 這是筆者很久之前處理過的一個案例,但今天仍然具有代表性。一個15萬PV/日左右的在線文檔類型網站最近更換了硬件系統,服務器的硬件為四路志強處理器、16GB物理內存,操作系統為64位CentOS5.4,Resin…

js數組方法歸納——concat、join、reverse

1、concat( ) 用途:可以連接兩個或多個數組,并將新的數組返回該方法不會對原數組產生影響 var arr ["孫悟空","豬八戒","沙和尚"];var arr2 ["白骨精","玉兔精","蜘蛛精"];var arr3 [&…

Vue Router的深度解析

引言 在現代Web應用開發中,客戶端路由已成為實現流暢用戶體驗的關鍵技術。與傳統的服務器端路由不同,客戶端路由通過JavaScript在瀏覽器中控制頁面內容的更新,避免了頁面的全量刷新。Vue Router作為Vue.js官方的路由解決方案,以其…

阿里云centos 取消硬盤掛載并重建數據盤信息再次掛載

一、取消掛載 umount [掛載點或設備] 如果要取消掛載/dev/sdb1分區,可以使用以下命令: umount /dev/sdb1 如果要取消掛載在/mnt/mydisk的掛載點,可以使用以下命令: umount /mnt/mydisk 如果設備正忙,無法立即取消…

【Spring Boot】簡單了解spring boot支持的三種服務器

Tomcat 概述:Tomcat 是 Apache 軟件基金會(Apache Software Foundation)的 Jakarta EE 項目中的一個核心項目,由 Apache、Sun 和其他一些公司及個人共同開發而成。它作為 Java Servlet、JSP、JavaServer Pages Expression Languag…

系統安全及應用(命令)

目錄 一、賬號安全控制 1.1 系統賬號清理 1.2 密碼安全控制 1.3 歷史記錄控制 1.4 終端自動注銷 二、系統引導和登陸控制 2.1 限制su命令用戶 2.2 PAM安全認證 示例一:通過pam 模塊來防止暴力破解ssh 2.3 sudo機制提升權限 2.3.1 sudo命令(ro…

Java的日期類常用方法

Java_Date 第一代日期類 獲取當前時間 Date date new Date(); System.out.printf("當前時間" date); 格式化時間信息 SimpleDateFormat simpleDateFormat new SimpleDateFormat("yyyy-mm-dd hh:mm:ss E); System.out.printf("格式化后時間" si…

【windows|012】光貓、路由器、交換機詳解

🍁博主簡介: 🏅云計算領域優質創作者 🏅2022年CSDN新星計劃python賽道第一名 🏅2022年CSDN原力計劃優質作者 ? 🏅阿里云ACE認證高級工程師 ? 🏅阿里云開發者社區專家博主 💊交流社…

windows USB 驅動開發-URB結構

通用串行總線 (USB) 客戶端驅動程序無法直接與其設備通信。 相反,客戶端驅動程序會創建請求并將其提交到 USB 驅動程序堆棧進行處理。 在每個請求中,客戶端驅動程序提供一個可變長度的數據結構,稱為 USB 請求塊 (URB) ,URB 結構描…

ctfshow-web入門-命令執行(web75-web77)

目錄 1、web75 2、web76 3、web77 1、web75 使用 glob 協議繞過 open_basedir&#xff0c;讀取根目錄下的文件&#xff0c;payload&#xff1a; c?><?php $anew DirectoryIterator("glob:///*"); foreach($a as $f) {echo($f->__toString(). ); } ex…

讀書筆記-Java并發編程的藝術-第3章(Java內存模型)-第9節(Java內存模型綜述)

3.9 Java內存模型綜述 前面對Java內存模型的基礎知識和內存模型的具體實現進行了說明。下面對Java內存模型的相關知識做一個總結。 3.9.1 處理器的內存模型 順序一致性內存模型是一個理論參考模型&#xff0c;JMM和處理器內存模型在設計時通常會以順序一致性內存模型為參照。…

ORB-SLAM2 安裝編譯運行(非 ROS)

安裝編譯 必備安裝工具 主要包括 cmake 、 git 、 gcc 、 g gcc 的全稱是 GNU Compiler Collection&#xff0c;它是由 GNU 推出的一款功能強大的、性能優越的 多平臺編譯器&#xff0c;是一個能夠編譯多種語言的編譯器。最開始 gcc 是作為 C 語言的編譯器&#xff08;GNU …

如何將等保2.0的要求融入日常安全運維實踐中?

等保2.0的基本要求 等保2.0是中國網絡安全領域的基本國策和基本制度&#xff0c;它要求網絡運營商按照網絡安全等級保護制度的要求&#xff0c;履行相關的安全保護義務。等保2.0的實施得到了《中華人民共和國網絡安全法》等法律法規的支持&#xff0c;要求相關行業和單位必須按…