PyTorch 在深度學習中提供了多種 IEEE 754 二進制浮點格式的支持,包括半精度(float16
)、Brain?float(bfloat16
)、單精度(float32
)和雙精度(float64
),并通過統一的 torch.dtype
接口進行管理citeturn0search0turn0search3。用戶可利用 torch.finfo
查詢各類型的數值極限(如最大值、最小值、機器 ε 等),通過 torch.set_default_dtype
/torch.get_default_dtype
設置或獲取全局默認浮點精度,并使用 torch.promote_types
控制運算中的類型提升規則citeturn0search2turn0search4。在現代 GPU 上,PyTorch 提供了 torch.amp.autocast
和 torch.amp.GradScaler
等自動混合精度(AMP)工具,能夠在保證數值穩定性的前提下,大幅提升訓練速度和降低顯存占用citeturn0search6turn0search11。
PyTorch 浮點類型對比
類型 (torch.dtype ) | 別名 | 位寬 | 符號位 | 指數位 | 尾數位 (顯式) | 有效精度 (含隱含位) | 典型用途 |
---|---|---|---|---|---|---|---|
torch.float16 | torch.half | 16 | 1 | 5 | 10 | 11 位 (~3.3 十進制位) | 推理加速,對精度要求不高的場景 |
torch.bfloat16 | — | 16 | 1 | 8 | 7 | 8 位 (~2.4 十進制位) | 大規模訓練(TPU、支持 BF16 的 GPU) |
torch.float32 | torch.float | 32 | 1 | 8 | 23 | 24 位 (~7.2 十進制位) | 深度學習訓練/推理的標準精度 |
torch.float64 | torch.double | 64 | 1 | 11 | 52 | 53 位 (~15.9 十進制位) | 科學計算、高精度數值分析 |
上表位寬、指數位、尾數位數據遵循 IEEE 754 標準:二進制16(binary16)格式指數 5 位、尾數 10 位citeturn1search0;二進制32(binary32)格式指數 8 位、尾數 23 位citeturn1search8;二進制64(binary64)格式指數 11 位、尾數 52 位citeturn1search8。
數值屬性查詢
torch.finfo(dtype)
:返回指定浮點類型的數值極限信息,包括:bits
:總位寬eps
:機器 ε,即最小增量min
/max
:可表示的最小/最大值tiny
/smallest_normal
:最小非規范/規范化值 citeturn0search2。
import torch
print(torch.finfo(torch.float32))
# finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)
默認精度與類型提升
-
全局默認浮點精度
torch.get_default_dtype()
:獲取當前默認浮點類型,初始值為torch.float32
citeturn0search9。torch.set_default_dtype(d)
:設置默認浮點類型,僅支持浮點類型輸入;后續通過 Pythonfloat
構造的張量將采用該類型citeturn0search4。
-
類型提升 (Type Promotion)
torch.promote_types(type1, type2)
:返回在保證不降低精度與范圍的前提下,最小的可兼容浮點類型,用于混合類型運算時的結果類型推斷citeturn0search5。
自動混合精度(AMP)
PyTorch 的 AMP 機制在 前向/反向傳播 中自動選擇低精度(float16
或 bfloat16
)計算,而在 權重更新 等關鍵環節保留 float32
,以兼顧性能與數值穩定性。
torch.amp.autocast
:上下文管理器,針對支持的設備(如 CUDA GPU 或 CPU)自動切換運算精度;在 CUDA 上默認使用float16
,在 CPU 上可指定dtype=torch.bfloat16
citeturn0search6。torch.amp.GradScaler
:動態縮放梯度,避免低精度下的梯度下溢,實現穩定訓練;與autocast
搭配使用可獲顯著加速(1.5–2×)和顯存節省citeturn0search11。
示例(CUDA 上的混合精度訓練):
from torch.amp import autocast, GradScaler
scaler = GradScaler()
for data, target in loader:optimizer.zero_grad()with autocast():output = model(data)loss = loss_fn(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
實踐建議
- 開發與調試階段:優先使用
float32
,確保數值穩定。 - 大規模訓練:若硬件支持 BF16,可嘗試
bfloat16
訓練;否則在 GPU 上結合 AMP 使用float16
。 3. 部署推理:在對精度容忍度高的場景下采用float16
,監控精度變化。 - 默認設置優化:根據項目需求使用
torch.set_default_dtype
控制全局默認精度,并結合torch.promote_types
處理跨類型運算。