pytorch小記(二十九):深入解析 PyTorch 中的 `torch.clip`(及其別名 `torch.clamp`)
- 深入解析 PyTorch 中的 `torch.clip`(及其別名 `torch.clamp`)
- 一、函數簽名
- 二、簡單示例
- 三、廣播支持
- 四、與 Autograd 的兼容性
- 五、典型應用場景
- 六、小結
深入解析 PyTorch 中的 torch.clip
(及其別名 torch.clamp
)
在深度學習任務中,我們經常需要對張量(Tensor)中的數值進行約束,以保證模型訓練的穩定性和數值的合理性。PyTorch 提供了 torch.clip
(以及早期版本中的別名 torch.clamp
)函數,能夠快速將張量中的元素裁剪到指定范圍。本文將帶你從函數簽名、參數說明,到實際示例和應用場景,一步步掌握 torch.clip
的用法。
一、函數簽名
torch.clip(input, min=None, max=None, *, out=None) → Tensor
# 等價于
torch.clamp(input, min=min, max=max, out=out)
- input (
Tensor
):待裁剪的輸入張量。 - min (
float
或Tensor
,可選):下界;所有元素小于此值的會被設置成該值。若為None
,則不進行下界裁剪。 - max (
float
或Tensor
,可選):上界;所有元素大于此值的會被設置成該值。若為None
,則不進行上界裁剪。 - out (
Tensor
,可選):可選的輸出張量,用于將結果寫入指定張量中,避免額外分配。
返回值:一個新的張量(或當指定了 out
時,原地寫入并返回該張量),其中的每個元素滿足:
output[i] =min if input[i] < min,max if input[i] > max,input[i] otherwise.
二、簡單示例
import torchx = torch.tensor([-5.0, -1.0, 0.0, 2.5, 10.0])# 裁剪到區間 [0, 5]
y = torch.clip(x, min=0.0, max=5.0)
print(y) # tensor([0.0, 0.0, 0.0, 2.5, 5.0])# 只有下界裁剪(所有 < 1 的值變成 1)
y_lower = torch.clip(x, min=1.0)
print(y_lower) # tensor([1.0, 1.0, 1.0, 2.5, 10.0])# 只有上界裁剪(所有 > 3 的值變成 3)
y_upper = torch.clip(x, max=3.0)
print(y_upper) # tensor([-5.0, -1.0, 0.0, 2.5, 3.0])
三、廣播支持
當 min
或 max
為張量時,torch.clip
會自動執行廣播對齊:
import torchx = torch.arange(6).reshape(2, 3).float()
# tensor([[0., 1., 2.],
# [3., 4., 5.]])min_vals = torch.tensor([[1., 2., 3.]])
max_vals = torch.tensor([[2., 3., 4.]])y = torch.clip(x, min=min_vals, max=max_vals)
print(y)
# tensor([[1., 2., 2.],
# [2., 3., 4.]])
四、與 Autograd 的兼容性
torch.clip
支持自動梯度(Autograd):
- 當輸入值位于
(min, max)
區間內時,梯度正常傳遞; - 當輸入值被裁剪到邊界時(小于
min
或大于max
),對應位置的梯度為 0,因為輸出對該輸入不敏感。
x = torch.tensor([-10.0, 0.5, 10.0], requires_grad=True)
y = torch.clip(x, min=-1.0, max=1.0)y.sum().backward()
print(x.grad) # tensor([0., 1., 0.])
五、典型應用場景
- 數值穩定性:避免激活值和梯度過大或過小導致溢出/下溢。
- 數據歸一化:將輸入特征裁剪到指定區間,例如將圖像像素限定在
[0, 1]
。 - 損失裁剪:限制損失值范圍,避免單次梯度過大影響整體訓練。
- 強化學習:裁剪策略梯度中的概率比率,防止策略更新過猛。
六、小結
torch.clip
(或 torch.clamp
)是 PyTorch 中一個高效且直觀的張量裁剪操作。通過簡單的參數設置,就能保證張量數值在合理范圍內,提升模型訓練的穩定性和魯棒性。掌握好它的用法,能讓你的深度學習工作流更加可靠。
希望本文能幫到你,如果有任何問題或討論,歡迎在評論區留言交流!