dask.dataframe.shuffle.set_index
中獲取 divisions 的步驟分析
主要流程概述
在 set_index
函數中,當 divisions=None
時,系統需要通過分析數據來動態計算分區邊界。這個過程分為以下幾個關鍵步驟:
1. 初始檢查和準備
if divisions is None:sizes = df.map_partitions(sizeof) if repartition else []divisions = index2._repartition_quantiles(npartitions, upsample=upsample)mins = index2.map_partitions(M.min)maxes = index2.map_partitions(M.max)divisions, sizes, mins, maxes = base.compute(divisions, sizes, mins, maxes)
步驟說明:
- 計算每個分區的大小(如果啟用重新分區)
- 調用
_repartition_quantiles
計算近似分位數 - 并行計算每個分區的最小值和最大值
- 使用
base.compute
觸發實際計算
2. 分位數計算過程 (_repartition_quantiles
)
_repartition_quantiles
方法調用 partition_quantiles
函數,該函數執行以下步驟:
2.1 生成采樣策略
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):# 計算隨機百分位比例random_percentage = 1 / (1 + (4 * num_new / num_old) ** 0.5)# 生成等間距和隨機百分位
2.2 創建計算圖
# 1. 數據類型信息
dtype_dsk = {(name0, 0): (dtype_info, df_keys[0])}# 2. 每個分區的百分位摘要
val_dsk = {(name1, i): (percentiles_summary, key, df.npartitions, npartitions, upsample, state)for i, (state, key) in enumerate(zip(state_data, df_keys))
}# 3. 合并和壓縮摘要
merge_dsk = create_merge_tree(merge_and_compress_summaries, sorted(val_dsk), name2)# 4. 最終處理
last_dsk = {(name3, 0): (pd.Series, (process_val_weights, merged_key, npartitions, (name0, 0)), qs, None, df.name)
}
3. 數據后處理
divisions = methods.tolist(divisions)
if type(sizes) is not list:sizes = methods.tolist(sizes)
mins = methods.tolist(mins)
maxes = methods.tolist(maxes)
4. 空數據檢測和重新分區
empty_dataframe_detected = pd.isnull(divisions).all()
if repartition or empty_dataframe_detected:total = sum(sizes)npartitions = max(math.ceil(total / partition_size), 1)npartitions = min(npartitions, df.npartitions)# 插值生成新的分界點divisions = np.interp(x=np.linspace(0, n - 1, npartitions + 1),xp=np.linspace(0, n - 1, n),fp=divisions,).tolist()
5. 數據類型特殊處理
if pd.api.types.is_categorical_dtype(index2.dtype):dtype = index2.dtypemins = pd.Categorical(mins, dtype=dtype).codes.tolist()maxes = pd.Categorical(maxes, dtype=dtype).codes.tolist()
6. 排序優化檢查
if (mins == sorted(mins) and maxes == sorted(maxes) and all(mx < mn for mx, mn in zip(maxes[:-1], mins[1:]))):divisions = mins + [maxes[-1]]result = set_sorted_index(df, index, drop=drop, divisions=divisions)return result.map_partitions(M.sort_index)
這個檢查的作用:
- 如果數據已經按索引排序,可以直接使用最小值和最大值作為分界點
- 避免昂貴的shuffle操作
分位數計算詳細過程
核心算法:percentiles_summary
函數
def percentiles_summary(df, num_old, num_new, upsample, state):"""Summarize data using percentiles and derived weights."""# 1. 生成采樣百分位qs = sample_percentiles(num_old, num_new, len(df), upsample, state)# 2. 計算百分位值vals = df.quantile(qs)# 3. 轉換為權重return percentiles_to_weights(qs, vals, len(df))
權重計算:percentiles_to_weights
函數
def percentiles_to_weights(qs, vals, length):"""Weigh percentile values by length and the difference between percentiles"""if length == 0:return ()diff = np.ediff1d(qs, 0.0, 0.0)weights = 0.5 * length * (diff[1:] + diff[:-1])return vals.tolist(), weights.tolist()
權重計算原理:
- 每個百分位值的權重 = 0.5 × 分區長度 × (前一個百分位差 + 后一個百分位差)
- 這樣確保權重反映該值在數據分布中的重要性
合并和壓縮:merge_and_compress_summaries
函數
def merge_and_compress_summaries(vals_and_weights):"""Merge and sort percentile summaries that are already sorted."""vals, weights = zip(*vals_and_weights)vals = list(merge_sorted(*vals))weights = list(merge_sorted(*weights))# 壓縮重復值compressed_vals = []compressed_weights = []# ... 壓縮邏輯return compressed_vals, compressed_weights
最終處理:process_val_weights
函數
這個函數(我們之前分析過的)將合并后的值-權重對轉換為最終的分區邊界。
完整流程圖和示例
讓我創建一個詳細的流程圖來展示整個過程:
開始 set_index(df, index, divisions=None)↓
檢查 divisions 是否為 None↓
是 → 開始動態計算分界點↓
1. 并行計算:- 每個分區的大小 (sizes)- 每個分區的分位數 (divisions)- 每個分區的最小值 (mins)- 每個分區的最大值 (maxes)↓
2. 觸發計算:base.compute(divisions, sizes, mins, maxes)↓
3. 分位數計算過程:┌─────────────────────────────────────────┐│ 對每個分區執行 percentiles_summary: ││ 1. 生成采樣百分位 (等間距 + 隨機) ││ 2. 計算百分位值 ││ 3. 轉換為權重 │└─────────────────────────────────────────┘↓
4. 合并所有分區的摘要:┌─────────────────────────────────────────┐│ merge_and_compress_summaries: ││ 1. 合并排序的值和權重 ││ 2. 壓縮重復值 │└─────────────────────────────────────────┘↓
5. 最終處理:┌─────────────────────────────────────────┐│ process_val_weights: ││ 1. 處理巨型值 ││ 2. 計算最終分界點 ││ 3. 處理數據類型 │└─────────────────────────────────────────┘↓
6. 后處理:- 轉換為列表格式- 檢查空數據- 處理重新分區- 特殊數據類型處理↓
7. 排序優化檢查:- 如果數據已排序,使用 min/max 作為分界點- 否則繼續到 shuffle 階段↓
調用 set_partition 進行實際的數據重排↓
結束
關鍵優化策略
- 采樣策略:結合等間距和隨機百分位,平衡計算效率和準確性
- 排序檢測:如果數據已排序,避免昂貴的shuffle操作
- 數據類型感知:特別處理分類、時間等特殊數據類型
- 內存優化:通過壓縮和合并減少內存使用
- 分布式計算:利用Dask的并行計算能力
性能考慮
- 時間復雜度:O(n log n),主要由排序和分位數計算決定
- 空間復雜度:O(n),存儲采樣數據和權重
- 網絡開銷:需要收集所有分區的統計信息
- 計算開銷:需要兩次數據遍歷(統計 + shuffle)
總結
dask.dataframe.shuffle.set_index
中獲取 divisions 的過程是一個復雜的分布式算法,主要包含以下步驟:
核心步驟
- 并行統計:計算每個分區的分位數、大小、最小值、最大值
- 分位數計算:使用采樣策略生成代表性百分位
- 權重分配:根據數據分布為每個值分配權重
- 合并壓縮:合并所有分區的統計信息并壓縮重復值
- 分界點計算:使用
process_val_weights
計算最終分界點 - 優化檢查:檢測數據是否已排序,避免不必要的shuffle
關鍵特點
- 分布式設計:充分利用Dask的并行計算能力
- 智能采樣:結合等間距和隨機采樣策略
- 類型感知:特別處理不同數據類型
- 性能優化:檢測已排序數據,避免重復計算
- 內存高效:通過壓縮和合并減少內存使用
這個算法是Dask DataFrame實現高效分布式排序和分區的核心,通過巧妙的采樣和合并策略,在保證準確性的同時實現了良好的性能。
自己實現
import numpy as np
import pandas as pd# 1?? 采樣百分位
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):"""簡單版本:等間距百分位"""return np.linspace(0, 1, num_new + 1)# 2?? 計算百分位摘要(值+權重)
def percentiles_summary(series, num_old, num_new):qs = sample_percentiles(num_old, num_new, len(series))vals = series.quantile(qs).to_numpy()diff = np.ediff1d(qs, 0.0, 0.0)weights = 0.5 * len(series) * (diff[1:] + diff[:-1])return vals.tolist(), weights.tolist()# 3?? 合并多個分區的摘要
def merge_and_compress_summaries(summaries):all_vals = []all_weights = []for vals, weights in summaries:all_vals.extend(vals)all_weights.extend(weights)# 按值排序order = np.argsort(all_vals)vals = np.array(all_vals)[order]weights = np.array(all_weights)[order]# 壓縮重復值compressed_vals = []compressed_weights = []last_val = Nonefor v, w in zip(vals, weights):if last_val is not None and v == last_val:compressed_weights[-1] += welse:compressed_vals.append(v)compressed_weights.append(w)last_val = vreturn np.array(compressed_vals), np.array(compressed_weights)# 4?? 最終處理:計算分界點
def process_val_weights(vals, weights, npartitions):if len(vals) == 0:return np.array([])if len(vals) == npartitions + 1:return valselif len(vals) < npartitions + 1:q_weights = np.cumsum(weights)q_target = np.linspace(q_weights[0], q_weights[-1], npartitions + 1)return np.interp(q_target, q_weights, vals)else:target_weight = weights.sum() / npartitionsjumbo_mask = weights >= target_weightjumbo_vals = vals[jumbo_mask]trimmed_vals = vals[~jumbo_mask]trimmed_weights = weights[~jumbo_mask]trimmed_npartitions = npartitions - len(jumbo_vals)q_weights = np.cumsum(trimmed_weights)q_target = np.linspace(0, q_weights[-1], trimmed_npartitions + 1)left = np.searchsorted(q_weights, q_target, side="left")right = np.searchsorted(q_weights, q_target, side="right") - 1lower = np.minimum(left, right)trimmed = trimmed_vals[lower]rv = np.concatenate([trimmed, jumbo_vals])rv.sort()return rv# 5?? 模擬 set_index 中 divisions 的獲取
def simulate_set_index(df, column, npartitions):num_old = len(df)# 假設原始有分區(這里手動切分成2塊模擬)partitions = np.array_split(df[column], 2)summaries = [percentiles_summary(p, num_old, npartitions) for p in partitions]vals, weights = merge_and_compress_summaries(summaries)divisions = process_val_weights(vals, weights, npartitions)return divisions# ========== DEMO 使用 ==========
df = pd.DataFrame({"x": np.random.randint(0, 100, size=50)})divs = simulate_set_index(df, "x", npartitions=4)print("原始數據示例:\n", df.head())
print("\n計算得到的 divisions:", divs)