技術背景
在PyTorch中,沒有直接實現cbrt這一算子。這個算子是用于計算一個數的開立方,例如,最簡單的-8開立方就是-2。但這里有個問題是,在PyTorch中,因為沒有cbrt算子,如果直接用冪次計算去操作數字,就有可能出現問題。
代碼示例
首先看一下numpy做開立方的代碼示例:
In [1]: import numpy as npIn [2]: a = np.array(-8, np.float32)In [3]: a**(1/3)
<ipython-input-3-f6e83d4e282e>:1: RuntimeWarning: invalid value encountered in powera**(1/3)
Out[3]: np.float32(nan)In [4]: np.cbrt(a)
Out[4]: np.float32(-2.0)
在這個示例中,如果直接開立方,結果會是一個nan,很明顯不是我們想要的一個結果。而cbrt是一個單獨實現的開立方算子,可以支持負數的輸入,計算結果也是正確的。在PyTorch的場景下,只能用冪次運算:
In [1]: import torch as tcIn [2]: a=tc.tensor(-8,dtype=tc.float32)In [3]: a**(1/3)
Out[3]: tensor(nan)
這樣得到的結果是錯誤的。因此需要我們自己實現一個cbrt函數:
In [1]: import torch as tcIn [2]: cbrt=lambda x: tc.sign(x)*tc.abs(x)**(1/3)In [3]: a=tc.tensor(-8,dtype=tc.float32)In [4]: cbrt(a)
Out[4]: tensor(-2.)
其實邏輯也比較簡單,就是先把符號提取出來,然后再轉化為正數正常計算就好了。
總結概要
本文介紹了在PyTorch中直接使用冪次函數計算有可能導致的計算結果異常的問題。由于PyTorch中并未像Numpy和MindSpore一樣直接支持cbrt開立方函數,因此這里也提供了一個在PyTorch中計算開立方的函數。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/cbrt.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
請博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html