文章目錄
- 計算
- 參考文獻
計算
數組切片如下
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3])
切片結果是前3列
dask.array<getitem, shape=(10, 3), dtype=int64, chunksize=(5, 3), chunktype=numpy.ndarray>
Dask是懶惰計算,就是說,當你要求結果時,它才會計算。
調用這方法設置任務圖,然后調用compute方法得到結果。
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3].compute())
print(sum(a[0,0:5].compute()))#0+1+2+3+4=10
[[ 0 1 2][100 101 102][200 201 202][300 301 302][400 401 402][500 501 502][600 601 602][700 701 702][800 801 802][900 901 902]]
10
按列求和
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(sum(a[1:3,0:3].compute()))
[[100 101 102][200 201 202]]
[300 302 304]
調用numpy的函數
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(a[1:3,0:3].mean().compute())
print(a[1:3,0:3].sum().compute())
print(np.cos(a[1:3,0:3]).compute())
print(a[1:3,0:3].T.compute())
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[100 101 102][200 201 202]]
151.0
906
[[0.86231887 0.89200487 0.1015857 ][0.48718768 0.99808296 0.59134538]]
[[100 200][101 201][102 202]]
可以調用JAX的函數試下
import dask.array as da
import jax.numpy as jnp
from dask import delayed# 創建 Dask 數組
x = da.random.random((1000, 1000), chunks=(100, 100))# 定義一個使用 JAX 的函數
@delayed
def jax_computation(arr):jax_arr = jnp.array(arr) # 轉換為 JAX 數組return jnp.sum(jax_arr * 2).block_until_ready() # 使用 JAX 計算# 應用計算
result = jax_computation(x.compute()) # 先計算 Dask 數組,再傳給 JAX
from dask import compute
import jax# 在多個設備上并行運行 JAX 函數
@delayed
def jax_operation(data):device = jax.devices()[0] # 可以使用不同設備with jax.default_device(device):return jnp.sum(data * 2)# 創建多個延遲任務
tasks = [jax_operation(jnp.ones(100)) for _ in range(10)]
results = compute(*tasks) # 并行計算
另外,分布式 JAX 計算,可以考慮使用 JAX 的 pmap 進行多設備并行
import jax
import jax.numpy as jnp
from jax import pmap# 檢查可用設備
print(jax.devices()) # 例如: [GpuDevice(id=0), GpuDevice(id=1)]# 定義一個簡單的函數
def f(x):return x * 2 + 1# 創建并行化版本
parallel_f = pmap(f)# 準備輸入數據 (注意: 第一維對應設備數量)
x = jnp.array([[1., 2.], [3., 4.]]) # 形狀 (2, 2)# 并行執行
result = parallel_f(x) # 在2個設備上并行計算
print(result)
TensorFlow Probability (TFP) 可以與 TensorFlow 的分布式策略結合使用,實現大規模的統計計算和概率建模。
基本概率分布計算
def run_in_distributed_environment():# 在策略范圍內創建變量和計算with strategy.scope():# 創建TFP正態分布normal = tfp.distributions.Normal(loc=0., scale=1.)# 分布式計算samples = normal.sample(1000)mean = tf.reduce_mean(samples)stddev = tf.math.reduce_std(samples)return mean, stddevmean, stddev = run_in_distributed_environment()
print(f"均值: {mean.numpy()}, 標準差: {stddev.numpy()}")
參考文獻
- https://docs.dask.org/en/stable
- deepseek