目錄
- PyTorch Tensor生成方式及復制方法詳解
- 一、Tensor的生成方式
- (一)從Python列表/元組創建
- (二)從NumPy數組創建
- (三)特殊初始化方法
- (四)從現有Tensor創建
- (五)高級初始化方法
- 二、復制方法對比
- (一) `torch.tensor()` vs `torch.from_numpy()`
- (二) `.clone()` vs `.copy_()` vs `copy.deepcopy()`
- (三) 深度拷貝(Deep Copy)
- 三、核心區別總結
- 四、最佳實踐建議
PyTorch Tensor生成方式及復制方法詳解
在PyTorch中,Tensor的創建和復制是深度學習開發的基礎操作。本文將全面總結Tensor的各種生成方式,并深入分析不同復制方法的區別。
一、Tensor的生成方式
(一)從Python列表/元組創建
import torch# 直接創建Tensor
t1 = torch.tensor([1, 2, 3]) # 整型Tensor
t2 = torch.tensor([[1.0, 2], [3, 4]]) # 浮點型Tensor
(二)從NumPy數組創建
import numpy as nparr = np.array([1, 2, 3])
t = torch.from_numpy(arr) # 共享內存
(三)特殊初始化方法
zeros = torch.zeros(2, 3) # 全0矩陣
ones = torch.ones(2, 3) # 全1矩陣
rand = torch.rand(2, 3) # [0,1)均勻分布
randn = torch.randn(2, 3) # 標準正態分布
arange = torch.arange(0, 10, 2) # 0-10步長為2
(四)從現有Tensor創建
x = torch.tensor([1, 2, 3])
x1 = x.new_tensor([4, 5, 6]) # 新Tensor(復制數據)
x2 = torch.zeros_like(x) # 形狀相同,全0
x3 = torch.randn_like(x) # 形狀相同,隨機值
(五)高級初始化方法
eye = torch.eye(3) # 3x3單位矩陣
lin = torch.linspace(0, 1, 5) # 0-1等分5份
log = torch.logspace(0, 2, 3) # 10^0到10^2等分3份
二、復制方法對比
(一) torch.tensor()
vs torch.from_numpy()
方法 | 數據源 | 內存共享 | 梯度傳遞 | 數據類型 |
---|---|---|---|---|
torch.tensor() | Python數據 | 不共享 | 支持 | 自動推斷 |
torch.from_numpy() | NumPy數組 | 共享 | 不支持 | 保持一致 |
# 示例:內存共享驗證
arr = np.array([1, 2, 3])
t = torch.from_numpy(arr)
arr[0] = 99 # 修改NumPy數組
print(t) # tensor([99, 2, 3]),同步變化
(二) .clone()
vs .copy_()
vs copy.deepcopy()
方法 | 內存共享 | 梯度傳遞 | 計算圖保留 | 使用場景 |
---|---|---|---|---|
.clone() | 不共享 | 保留梯度 | 保留計算圖 | 需要梯度回傳 |
.copy_() | 目標共享 | 不保留 | 破壞計算圖 | 高效覆蓋數據 |
copy.deepcopy() | 不共享 | 不保留 | 不保留 | 完全獨立拷貝 |
# 示例:梯度傳遞對比
x = torch.tensor([1.], requires_grad=True)
y = x.clone()
z = torch.tensor([2.], requires_grad=True)
z.copy_(x) # 覆蓋z的值y.backward() # 正常回傳梯度到x
# z.backward() # 報錯!copy_()破壞計算圖
(三) 深度拷貝(Deep Copy)
import copyorig = torch.tensor([1, 2, 3])
deep_copied = copy.deepcopy(orig) # 完全獨立拷貝
三、核心區別總結
-
內存共享:
from_numpy()
與NumPy共享內存- 視圖操作(如
view()
/切片)共享內存 - 其他方法均創建獨立副本
-
梯度處理:
.clone()
唯一保留梯度計算圖copy_()
會破壞目標Tensor的計算圖torch.tensor()
創建新計算圖
-
使用場景:
- 需要梯度回傳:使用
.clone()
- 高效數據覆蓋:使用
.copy_()
- 完全獨立拷貝:使用
copy.deepcopy()
- 與NumPy交互:使用
from_numpy()
/numpy()
- 需要梯度回傳:使用
四、最佳實踐建議
- 優先使用
torch.tensor()
創建新Tensor - 需要從NumPy導入數據且避免復制時用
from_numpy()
- 在計算圖中復制數據時必須使用
.clone()
- 需要覆蓋現有Tensor數據時使用
.copy_()
- 調試時注意內存共享可能導致的意外修改
# 正確梯度傳遞示例
x = torch.tensor([1.], requires_grad=True)
y = x.clone() ** 2 # 保留計算圖
y.backward() # 梯度可回傳到x