在開發一個PyTorch模塊時,遇到了一個詭異的現象,將他描述出來就是下面這樣:
f[..., :p_index - 1] = f[..., 1:p_index]
這個操作將f張量的部分數值進行左移,我在模型訓練的時候還能正常跑,但是當我將模型部署到項目中時,這行代碼報錯了!
Traceback (most recent call last):File "<input>", line 1, in <module>
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
這個PyTorch報錯是因為在執行操作時,輸入張量和目標張量共享了同一塊內存地址(存在內存重疊),導致PyTorch無法安全地完成原地(in-place)操作。
既然這樣的話為什么在模型訓練的時候不會這樣呢?后面我仔細研究了一下午,發現了下面的原因:
當我們模型在訓練階段中,f的形狀通常是(B,F)的形式存在的,而在部署的時候,作推理時數據通常是(1,F)的形式,所以會出現下面的情況:
# 創建高維張量(3維)
f_3d = torch.randn(16, 1, 25)
slice_3d = f_3d[..., 1:24] # 源切片print("高維張量切片是否連續:")
print(slice_3d.is_contiguous()) # 輸出 False# 創建一維張量對比
f_1d = torch.randn(1, 1, 25)
slice_1d = f_1d[..., 1:24]print("\n一維張量切片是否連續:")
print(slice_1d.is_contiguous()) # 輸出 True
可以看到,當張量是維度大于1時,其在內存中是非連續存儲的,而張量維度為1時,其在內存中是連續存儲的。對于非連續張量,PyTorch會在賦值時隱式創建臨時副本,避免內存覆蓋。因此在進行原地賦值時不會報錯。
最后,為了加強代碼的魯棒性,我在所有涉及這部分操作的代碼后面加上了clone()函數。
f[..., :p_index - 1] = f[..., 1:p_index].clone()