使用 PyTorch 數據讀取,JAX 框架來訓練一個簡單的神經網絡

使用 PyTorch 數據讀取,JAX 框架來訓練一個簡單的神經網絡

本文例程部分主要參考官方文檔。

JAX簡介

JAX 的前身是 Autograd ,也就是說 JAX 是 Autograd 升級版本,JAX 可以對 Python 和 NumPy 程序進行自動微分。可以通過 Python的大量特征子集進行區分,包括循環、分支、遞歸和閉包語句進行自動求導,也可以求三階導數(三階導數是由原函數導數的導數的導數。 所謂三階導數,即原函數導數的導數的導數,將原函數進行三次求導)。通過 grad ,JAX 支持反向模式和正向模式的求導,而且這兩種模式可以任意組合成任何順序,具有一定靈活性。

另一個特點是基于 XLA 的 JIT 即時編譯,大大提高速度。

需要注意的是,JAX 僅提供計算時的優化,相當于是一個支持自動微分和 JIT 編譯的 NumPy。也就是說,數據處理 Dataloader 等其他框架都會提供的 utils 功能這里是沒有的。所幸 JAX 可以比較好的支持 PyTorch、 TensorFlow 等主流框架的數據讀取。本文就將基于 PyTorch 的數據讀取工具和 JAX 框架來訓練一個簡單的神經網絡

以下是國內優秀的機器學習框架 OneFlow 同名公司的創始人袁進輝老師在知乎上的一個評價:

如果說tensorflow 是主打lazy, 偏functional 的思想,但實現的臃腫面目可憎;pytorch 則主打eager, 偏imperative 編程,但內核簡單,可視為支持gpu的numpy, 加上一個autograd。JAX 像是這倆框架的混合體,取了tensorflow的functional和PyTorch的精簡,即支持gpu的 numpy, 具有autograd功能,非常追求函數式編程的思想,強調無狀態,immutable,加上JIT修飾符后就是lazy,可以使用xla對計算流程進行靜態分析和優化。當然JAX不帶jit也可像pytorch那種命令式編程和eager執行。

JAX有可能和PyTorch競爭。

安裝

安裝可以通過源碼編譯,也可以直接 pip。源碼編譯詳見[官方文檔: Building from source][2],對于官方沒有提供預編譯包的 cuda-cudnn 版本組合,只能通過自己源碼構建。pip的方式比較簡單,在 github 倉庫的 README 文檔中就有介紹。要注意,不同于 PyTorch 等框架,JAX 不會再 pip 安裝中綁定 CUDA 或 cuDNN 進行安裝,若未安裝,需要自己先手動安裝。僅使用 CPU 的版本也有支持。

筆者是 CUDA11.1,CUDNN 8.2,安裝如下:

pip install --upgrade pip
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

前面已經提到過,本文會借用 PyTorch 的數據處理工具,因此 torch 和 torchvision 也是必不可少的(已經安裝的可跳過):

pip install torch torchvision

構建簡單的神經網絡訓練

框架安裝完畢,我們正式開始。接下來我們使用 JAX 在 MNIST 上指定和訓練一個簡單的 MLP 進行計算,用 PyTorch 的數據加載 API 來加載圖像和標簽。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

超參數

# 本函數用來隨機初始化網絡權重
def random_layer_params(m, n, key, scale=1e-2):w_key, b_key = random.split(key)return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))# 初始化各個全連接層
def init_network_params(sizes, key):keys = random.split(key, len(sizes))return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

自動分批次預測

對于小批量,我們稍后將使用 JAX 的 vmap 函數來自動處理,而不會降低性能。我們現在先準備一個單張圖像推理預測函數:

from jax.scipy.special import logsumexpdef relu(x):return jnp.maximum(0, x)# 對單張圖像進行推理的函數
def predict(params, image):activations = imagefor w, b in params[: -1]:outputs = jnp.dot(w, activations) + bactivations = relu(outputs)final_w, final_b = params[-1]logits = jnp.dot(final_w, activations) + final_breturn logits - logsumexp(logits)

這個函數應該只能用來處理單張圖像推理預測,而不能批量處理,我們簡單測試一下,對于單張:

random_flattened_images = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_images)
print(preds.shape)

輸出:

(10,)

對于批次:

random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:preds = predict(params, random_flattened_images)
except TypeError:print('Invalid shapes!')

輸出:

Invalid shapes!

現在我們使用 vmap 來使它能夠處理批量數據:

# 用 vmap 來實現一個批量版本
batched_predict = vmap(predict, in_axes=(None, 0))# batched_predict 的調用與 predict 相同
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

輸出:

(10, 10)

現在,我們已經做好了準備工作,接下來就是要定義一個神經網絡并且進行訓練了,我們已經構建了的自動批處理版本的 predict 函數,并且將在損失函數中也使用它。我們將使用 grad 來得到損失關于神經網絡參數的導數。而且,這一切都可以用 jit 進行加速。

實用工具函數和損失函數

def one_hot(x, k, dtype=jnp.float32):"""構建一個 x 的 k 維 one-hot 編碼."""return jnp.array(x[:, None] == jnp.arange(k), dtype)def accuracy(params, images, targets):target_class = jnp.argmax(targets, axis=1)predicted_class =  jnp.argmax(batched_predict(params, images), axis=1)return jnp.mean(predicted_class == target_class)def loss(params, images, targets):preds = batched_predict(params, images)return -jnp.mean(preds * targets)@jit
def update(params, x, y):grads = grad(loss)(params, x, y)return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

使用 PyTorch 進行數據讀取

JAX 是一個專注于程序轉換和支持加速的 NumPy,對于數據的讀取,已經有很多優秀的工具了,這里我們就直接用 PyTorch 的 API。我們會做一個小的 shim 來使得它能夠支持 NumPy 數組。

import numpy as np
from torch.utils import data
from torchvision.datasets import MNISTdef numpy_collate(batch):if isinstance(batch[0], np.ndarray):return np.stack(batch)elif isinstance(batch[0], (tuple, list)):transposed = zip(*batch)return [numpy_collate(samples) for samples in transposed]else:return np.array(batch)class NumpyLoader(data.DataLoader):def __init__(self, dataset, batch_size=1,shuffle=False, sampler=None,batch_sampler=None, num_workers=0,pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):super(self.__class__, self).__init__(dataset,batch_size=batch_size,shuffle=shuffle,sampler=sampler,batch_sampler=batch_sampler,collate_fn=numpy_collate,num_workers=num_workers,pin_memory=pin_memory,drop_last=drop_last,timeout=timeout,worker_init_fn=worker_init_fn)class FlattenAndCast(object):def __call__(self, pic):return np.ravel(np.array(pic, dtype=jnp.float32))

接下來借助 PyTorch 的 datasets,定義我們自己的 dataset:

mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

此處應該輸出一堆下載 MNIST 數據集的信息,就不貼了。

接下來分別拿到整個訓練集和整個測試集,下面會用于測準確率:

train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

開始訓練

import time
for epoch in range(num_epochs):start_time = time.time()for x, y in training_generator:y = one_hot(y, n_targets)params = update(params, x, y)epoch_time = time.time() - start_timetrain_acc = accuracy(params, train_images, train_labels)test_acc = accuracy(params, test_images, test_labels)print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))print("Training set accuracy {}".format(train_acc))print("Test set accuracy {}".format(test_acc))

輸出:

Epoch 0 in 3.29 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9196999669075012
...
Epoch 7 in 1.78 sec
Training set accuracy 0.9736666679382324
Test set accuracy 0.9670999646186829

在本文的過程中,我們已經使用了整個 JAX API:grad 用于自動微分、jit 用于加速、vmap 用于自動矢量化。我們使用 NumPy 來進行我們所有的計算,并從 PyTorch 借用了出色的數據加載器,并在 GPU 上運行了整個過程。

Ref:

https://juejin.cn/post/6994695537316331556

https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

https://jax.readthedocs.io/en/latest/developer.html#building-from-source

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532574.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532574.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532574.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Yapi Mock 遠程代碼執行漏洞

跟風一波復現Yapi 漏洞描述: YApi接口管理平臺遠程代碼執行0day漏洞,攻擊者可通過平臺注冊用戶添加接口,設置mock腳本從而執行任意代碼。鑒于該漏洞目前處于0day漏洞利用狀態,強烈建議客戶盡快采取緩解措施以避免受此漏洞影響 …

C++ ACM模式輸入輸出

C ACM模式輸入輸出 以下我們都以求和作為題目要求,來看一下各種輸入輸出應該怎么寫。 1 只有一個或幾個輸入 輸入樣例: 3 5 7輸入輸出模板: int main() {int a, b, c;// 接收有限個輸入cin >> a >> b >> c;// 輸出結果…

CVE-2017-10271 WebLogic XMLDecoder反序列化漏洞

漏洞產生原因: CVE-2017-10271漏洞產生的原因大致是Weblogic的WLS Security組件對外提供webservice服務,其中使用了XMLDecoder來解析用戶傳入的XML數據,在解析的過程中出現反序列化漏洞,導致可執行任意命令。攻擊者發送精心構造的…

樹莓派攝像頭 C++ OpenCV YoloV3 實現實時目標檢測

樹莓派攝像頭 C OpenCV YoloV3 實現實時目標檢測 本文將實現樹莓派攝像頭 C OpenCV YoloV3 實現實時目標檢測,我們會先實現樹莓派對視頻文件的逐幀檢測來驗證算法流程,成功后,再接入攝像頭進行實時目標檢測。 先聲明一下筆者的主要軟硬件配…

【實戰】記錄一次服務器挖礦病毒處理

信息收集及kill: 查看監控顯示長期CPU利用率超高,懷疑中了病毒 top 命令查看進程資源占用: netstat -lntupa 命令查看有無ip進行發包 netstat -antp 然而并沒有找到對應的進程名 查看java進程和solr進程 ps aux :查看所有進程…

ag 搜索工具參數詳解

ag 搜索工具參數詳解 Ag 是類似ack, grep的工具,它來在文件中搜索相應關鍵字。 官方列出了幾點選擇它的理由: 它比ack還要快 (和grep不在一個數量級上)它會忽略.gitignore和.hgignore中的匹配文件如果有你想忽略的文…

CVE-2013-4547 文件名邏輯漏洞

搭建環境,訪問 8080 端口 漏洞說明: Nginx: Nginx是一款輕量級的Web 服務器/反向代理服務器及電子郵件(IMAP/POP3)代理服務器,在BSD-like 協議下發行。其特點是占有內存少,并發能力強&#xf…

CMake指令入門 ——以構建OpenCV項目為例

CMake指令入門 ——以構建OpenCV項目為例 轉自:https://blog.csdn.net/sandalphon4869/article/details/100589747 一、安裝 sudo apt-get install cmake安裝好后,輸入 cmake -version如果出現了cmake的版本顯示,那么說明安裝成功 二、c…

CVE-2017-7529Nginx越界讀取緩存漏洞POC

漏洞影響 低危,造成信息泄露,暴露真實ip等 實驗內容 漏洞原理 通過查看patch確定問題是由于對http header中range域處理不當造成,焦點在ngx_http_range_parse 函數中的循環: HTTP頭部range域的內容大約為Range: bytes4096-81…

Linux命令行性能監控工具大全

Linux命令行性能監控工具大全 作者:Arnold Lu 原文:https://www.cnblogs.com/arnoldlu/p/9462221.html 關鍵詞:top、perf、sar、ksar、mpstat、uptime、vmstat、pidstat、time、cpustat、munin、htop、glances、atop、nmon、pcp-gui、collect…

Weblogic12c T3 協議安全漏洞分析【CVE-2020-14645 CVE-2020-2883 CVE-2020-14645】

給個關注?寶兒! 給個關注?寶兒! 給個關注?寶兒! 關注公眾號:b1gpig信息安全,文章推送不錯過 ## 前言 WebLogic是美國Oracle公司出品的一個application server,確切的說是一個基于JAV…

Getshell總結

按方式分類: 0x01注入getshell: 0x02 上傳 getwebshell 0x03 RCE getshell 0x04 包含getwebshell 0x05 漏洞組合拳getshell 0x06 系統層getcmdshell 0x07 釣魚 getcmdshell 0x08 cms后臺getshell 0x09 紅隊shell競爭分析 0x01注入getshell:…

編寫可靠bash腳本的一些技巧

編寫可靠bash腳本的一些技巧 原作者:騰訊技術工程 原文鏈接:https://zhuanlan.zhihu.com/p/123989641 寫過很多 bash 腳本的人都知道,bash 的坑不是一般的多。 其實 bash 本身并不是一個很嚴謹的語言,但是很多時候也不得不用。以下…

python 到 poc

0x01 特殊函數 0x02 模塊 0x03 小工具開發記錄 特殊函數 # -*- coding:utf-8 -*- #內容見POC.demo; POC.demo2 ;def add(x,y):axyprint(a)add(3,5) print(------------引入lambad版本:) add lambda x,y : xy print(add(3,5)) #lambda函數,在lambda函數后面直接…

protobuf版本常見問題

protobuf版本常見問題 許多軟件都依賴 google 的 protobuf,我們很有可能在安裝多個軟件時重復安裝了多個版本的 protobuf,它們之間很可能出現沖突并導致在后續的工作中出現版本不匹配之類的錯誤。本文將討論筆者在使用 protobuf 中遇到的一些問題&#…

CMake常用命令整理

CMake常用命令整理 轉自:https://zhuanlan.zhihu.com/p/315768216 CMake 是什么我就不用再多說什么了,相信大家都有接觸才會看一篇文章。對于不太熟悉的開發人員可以把這篇文章當個查找手冊。 1.CMake語法 1.1 指定cmake的最小版本 cmake_minimum_r…

CVE-2021-41773 CVE-2021-42013 Apache HTTPd最新RCE漏洞復現 目錄穿越漏洞

給個關注?寶兒! 給個關注?寶兒! 給個關注?寶兒! CVE-2021-41773漏洞描述: Apache HTTPd是Apache基金會開源的一款流行的HTTP服務器。2021年10月8日Apache HTTPd官方發布安全更新,披…

SSRF,以weblogic為案例

給個關注?寶兒! 給個關注?寶兒! 給個關注?寶兒! 復習一下ssrf的原理及危害,并且以weblog的ssrf漏洞為案例 漏洞原理 SSRF(Server-side Request Forge, 服務端請求偽造) 通常用于控制web進而…

C++11 右值引用、移動語義、完美轉發、萬能引用

C11 右值引用、移動語義、完美轉發、引用折疊、萬能引用 轉自:http://c.biancheng.net/ C中的左值和右值 右值引用可以從字面意思上理解,指的是以引用傳遞(而非值傳遞)的方式使用 C 右值。關于 C 引用,已經在《C引用…

C++11 std::function, std::bind, std::ref, std::cref

C11 std::function, std::bind, std::ref, std::cref 轉自&#xff1a;http://www.jellythink.com/ std::function 看看這段代碼 先來看看下面這兩行代碼&#xff1a; std::function<void(EventKeyboard::KeyCode, Event*)> onKeyPressed; std::function<void(Ev…