本文通過示例代碼全面講解PyTorch中張量的基本操作,包含創建、運算、廣播機制、索引切片等核心功能,并提供完整的代碼和輸出結果。
1. 張量創建與基本屬性
import torch# 創建連續數值張量
x = torch.arange(12, dtype=torch.float32)
print("原始張量:\n", x)
print("形狀:", x.shape)
print("元素總數:", x.numel())# 創建全零/全一張量
zero = torch.zeros(2, 3, 4)
print("\n三維零張量:\n", zero)one = torch.ones(3, 4)
print("\n全一張量:\n", one)# 手動創建張量
a = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
print("\n自定義張量:\n", a)
輸出結果:
原始張量:tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
形狀: torch.Size([12])
元素總數: 12三維零張量:tensor([[[0., 0., 0., 0.],[0., 0., 0., 0.],[0., 0., 0., 0.]],[[0., 0., 0., 0.],[0., 0., 0., 0.],[0., 0., 0., 0.]]])全一張量:tensor([[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]])自定義張量:tensor([[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12]])
2. 張量重塑與轉置
x = x.reshape(3, 4)
print("重塑后的3x4張量:\n", x)
print("轉置張量:\n", x.T)
輸出結果:
重塑后的3x4張量:tensor([[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.]])
轉置張量:tensor([[ 0., 4., 8.],[ 1., 5., 9.],[ 2., 6., 10.],[ 3., 7., 11.]])
3. 數學運算
# 矩陣減法
print("x - one:\n", x - one)# 指數運算
b = torch.exp(a)
print("\n指數運算結果:\n", b)
輸出結果:
x - one:tensor([[-1., 0., 1., 2.],[ 3., 4., 5., 6.],[ 7., 8., 9., 10.]])指數運算結果:tensor([[2.7183e+00, 7.3891e+00, 2.0086e+01, 5.4598e+01],[1.4841e+02, 4.0343e+02, 1.0966e+03, 2.9810e+03],[8.1031e+03, 2.2026e+04, 5.9874e+04, 1.6275e+05]])
4. 張量拼接與比較
# 行拼接
c = torch.cat((x, one), dim=0)
print("行拼接結果:\n", c)# 列拼接
d = torch.cat((x, one), dim=1)
print("\n列拼接結果:\n", d)# 張量比較
print("\n張量比較:\n", x == a)
輸出結果:
行拼接結果:tensor([[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.],[ 1., 1., 1., 1.],[ 1., 1., 1., 1.],[ 1., 1., 1., 1.]])列拼接結果:tensor([[ 0., 1., 2., 3., 1., 1., 1., 1.],[ 4., 5., 6., 7., 1., 1., 1., 1.],[ 8., 9., 10., 11., 1., 1., 1., 1.]])張量比較:tensor([[False, False, False, False],[False, False, False, False],[False, False, False, False]])
5. 廣播機制
e = torch.arange(3).reshape(3, 1)
print("廣播加法:\n", x + e)
輸出結果:
廣播加法:tensor([[ 0., 1., 2., 3.],[ 5., 6., 7., 8.],[10., 11., 12., 13.]])
6. 索引與切片
print("最后一行:", x[-1])
print("第二到第三行:\n", x[1:3])x[1, 2] = 100 # 修改單個元素
x[0:2, 1:3] = 0 # 修改子區域
print("\n修改后的張量:\n", x)
輸出結果:
最后一行: tensor([ 8., 9., 10., 11.])
第二到第三行:tensor([[ 4., 5., 6., 7.],[ 8., 9., 10., 11.]])修改后的張量:tensor([[ 0., 0., 0., 3.],[ 4., 0., 0., 7.],[ 8., 9., 10., 11.]])
7. 內存地址管理
before = id(x)
x = x + a # 新內存分配
# x += a # 原地操作
print("內存地址是否變化:", before == id(x))D = x.clone()
print("克隆張量地址對比:", before == id(D))
輸出結果:
內存地址是否變化: False
克隆張量地址對比: False
8. PyTorch與NumPy轉換
A = x.numpy()
B = torch.tensor(A)
print("類型轉換:", type(A), type(B))
輸出結果:
類型轉換: <class 'numpy.ndarray'> <class 'torch.Tensor'>
9. 統計操作
sum_a = a.sum(axis=1, keepdims=True)
print("按行求和:\n", sum_a)
print("歸一化結果:\n", a / sum_a)
print("按列累加:\n", a.cumsum(axis=0))
輸出結果:
按行求和:tensor([[10],[26],[42]])
歸一化結果:tensor([[0.1000, 0.2000, 0.3000, 0.4000],[0.1923, 0.2308, 0.2692, 0.3077],[0.2143, 0.2381, 0.2619, 0.2857]])
按列累加:tensor([[ 1, 2, 3, 4],[ 6, 8, 10, 12],[15, 18, 21, 24]])
通過本文的示例代碼,您可以快速掌握PyTorch張量操作的核心功能。建議讀者在實際項目中多加練習以鞏固知識!