使用 PyTorch 數據讀取,JAX 框架來訓練一個簡單的神經網絡
本文例程部分主要參考官方文檔。
JAX簡介
JAX 的前身是 Autograd ,也就是說 JAX 是 Autograd 升級版本,JAX 可以對 Python 和 NumPy 程序進行自動微分。可以通過 Python的大量特征子集進行區分,包括循環、分支、遞歸和閉包語句進行自動求導,也可以求三階導數(三階導數是由原函數導數的導數的導數。 所謂三階導數,即原函數導數的導數的導數,將原函數進行三次求導)。通過 grad ,JAX 支持反向模式和正向模式的求導,而且這兩種模式可以任意組合成任何順序,具有一定靈活性。
另一個特點是基于 XLA 的 JIT 即時編譯,大大提高速度。
需要注意的是,JAX 僅提供計算時的優化,相當于是一個支持自動微分和 JIT 編譯的 NumPy。也就是說,數據處理 Dataloader 等其他框架都會提供的 utils 功能這里是沒有的。所幸 JAX 可以比較好的支持 PyTorch、 TensorFlow 等主流框架的數據讀取。本文就將基于 PyTorch 的數據讀取工具和 JAX 框架來訓練一個簡單的神經網絡。
以下是國內優秀的機器學習框架 OneFlow 同名公司的創始人袁進輝老師在知乎上的一個評價:
如果說tensorflow 是主打lazy, 偏functional 的思想,但實現的臃腫面目可憎;pytorch 則主打eager, 偏imperative 編程,但內核簡單,可視為支持gpu的numpy, 加上一個autograd。JAX 像是這倆框架的混合體,取了tensorflow的functional和PyTorch的精簡,即支持gpu的 numpy, 具有autograd功能,非常追求函數式編程的思想,強調無狀態,immutable,加上JIT修飾符后就是lazy,可以使用xla對計算流程進行靜態分析和優化。當然JAX不帶jit也可像pytorch那種命令式編程和eager執行。
JAX有可能和PyTorch競爭。
安裝
安裝可以通過源碼編譯,也可以直接 pip。源碼編譯詳見[官方文檔: Building from source][2],對于官方沒有提供預編譯包的 cuda-cudnn 版本組合,只能通過自己源碼構建。pip的方式比較簡單,在 github 倉庫的 README 文檔中就有介紹。要注意,不同于 PyTorch 等框架,JAX 不會再 pip 安裝中綁定 CUDA 或 cuDNN 進行安裝,若未安裝,需要自己先手動安裝。僅使用 CPU 的版本也有支持。
筆者是 CUDA11.1,CUDNN 8.2,安裝如下:
pip install --upgrade pip
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
前面已經提到過,本文會借用 PyTorch 的數據處理工具,因此 torch 和 torchvision 也是必不可少的(已經安裝的可跳過):
pip install torch torchvision
構建簡單的神經網絡訓練
框架安裝完畢,我們正式開始。接下來我們使用 JAX 在 MNIST 上指定和訓練一個簡單的 MLP 進行計算,用 PyTorch 的數據加載 API 來加載圖像和標簽。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
超參數
# 本函數用來隨機初始化網絡權重
def random_layer_params(m, n, key, scale=1e-2):w_key, b_key = random.split(key)return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))# 初始化各個全連接層
def init_network_params(sizes, key):keys = random.split(key, len(sizes))return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
自動分批次預測
對于小批量,我們稍后將使用 JAX 的 vmap
函數來自動處理,而不會降低性能。我們現在先準備一個單張圖像推理預測函數:
from jax.scipy.special import logsumexpdef relu(x):return jnp.maximum(0, x)# 對單張圖像進行推理的函數
def predict(params, image):activations = imagefor w, b in params[: -1]:outputs = jnp.dot(w, activations) + bactivations = relu(outputs)final_w, final_b = params[-1]logits = jnp.dot(final_w, activations) + final_breturn logits - logsumexp(logits)
這個函數應該只能用來處理單張圖像推理預測,而不能批量處理,我們簡單測試一下,對于單張:
random_flattened_images = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_images)
print(preds.shape)
輸出:
(10,)
對于批次:
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:preds = predict(params, random_flattened_images)
except TypeError:print('Invalid shapes!')
輸出:
Invalid shapes!
現在我們使用 vmap
來使它能夠處理批量數據:
# 用 vmap 來實現一個批量版本
batched_predict = vmap(predict, in_axes=(None, 0))# batched_predict 的調用與 predict 相同
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
輸出:
(10, 10)
現在,我們已經做好了準備工作,接下來就是要定義一個神經網絡并且進行訓練了,我們已經構建了的自動批處理版本的 predict
函數,并且將在損失函數中也使用它。我們將使用 grad
來得到損失關于神經網絡參數的導數。而且,這一切都可以用 jit
進行加速。
實用工具函數和損失函數
def one_hot(x, k, dtype=jnp.float32):"""構建一個 x 的 k 維 one-hot 編碼."""return jnp.array(x[:, None] == jnp.arange(k), dtype)def accuracy(params, images, targets):target_class = jnp.argmax(targets, axis=1)predicted_class = jnp.argmax(batched_predict(params, images), axis=1)return jnp.mean(predicted_class == target_class)def loss(params, images, targets):preds = batched_predict(params, images)return -jnp.mean(preds * targets)@jit
def update(params, x, y):grads = grad(loss)(params, x, y)return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
使用 PyTorch 進行數據讀取
JAX 是一個專注于程序轉換和支持加速的 NumPy,對于數據的讀取,已經有很多優秀的工具了,這里我們就直接用 PyTorch 的 API。我們會做一個小的 shim 來使得它能夠支持 NumPy 數組。
import numpy as np
from torch.utils import data
from torchvision.datasets import MNISTdef numpy_collate(batch):if isinstance(batch[0], np.ndarray):return np.stack(batch)elif isinstance(batch[0], (tuple, list)):transposed = zip(*batch)return [numpy_collate(samples) for samples in transposed]else:return np.array(batch)class NumpyLoader(data.DataLoader):def __init__(self, dataset, batch_size=1,shuffle=False, sampler=None,batch_sampler=None, num_workers=0,pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):super(self.__class__, self).__init__(dataset,batch_size=batch_size,shuffle=shuffle,sampler=sampler,batch_sampler=batch_sampler,collate_fn=numpy_collate,num_workers=num_workers,pin_memory=pin_memory,drop_last=drop_last,timeout=timeout,worker_init_fn=worker_init_fn)class FlattenAndCast(object):def __call__(self, pic):return np.ravel(np.array(pic, dtype=jnp.float32))
接下來借助 PyTorch 的 datasets,定義我們自己的 dataset:
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
此處應該輸出一堆下載 MNIST 數據集的信息,就不貼了。
接下來分別拿到整個訓練集和整個測試集,下面會用于測準確率:
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
開始訓練
import time
for epoch in range(num_epochs):start_time = time.time()for x, y in training_generator:y = one_hot(y, n_targets)params = update(params, x, y)epoch_time = time.time() - start_timetrain_acc = accuracy(params, train_images, train_labels)test_acc = accuracy(params, test_images, test_labels)print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))print("Training set accuracy {}".format(train_acc))print("Test set accuracy {}".format(test_acc))
輸出:
Epoch 0 in 3.29 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9196999669075012
...
Epoch 7 in 1.78 sec
Training set accuracy 0.9736666679382324
Test set accuracy 0.9670999646186829
在本文的過程中,我們已經使用了整個 JAX API:grad
用于自動微分、jit
用于加速、vmap
用于自動矢量化。我們使用 NumPy 來進行我們所有的計算,并從 PyTorch 借用了出色的數據加載器,并在 GPU 上運行了整個過程。
Ref:
https://juejin.cn/post/6994695537316331556
https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
https://jax.readthedocs.io/en/latest/developer.html#building-from-source