在 PyTorch 中,tensor
是一種強大且靈活的數據結構,可以與多種 Python 常用數據結構(如 int
, list
, numpy array
等)互相轉換。下面是詳細解釋和代碼示例:
1. Tensor ? int / float
轉為 int / float(前提是 tensor 中只有一個元素)
import torcht = torch.tensor(3.14)
i = t.item() # 轉為 float
j = int(t.item()) # 強制轉為 intprint(i) # 3.14
print(j) # 3
.item()
只能用于單元素張量:tensor.numel() == 1
,否則會報錯。
2. Tensor ? list
Tensor 轉 list(Python 原生嵌套 list)
t = torch.tensor([[1, 2], [3, 4]])
lst = t.tolist()
print(lst) # [[1, 2], [3, 4]]
list 轉 Tensor
lst = [[1, 2], [3, 4]]
t = torch.tensor(lst)
print(t) # tensor([[1, 2], [3, 4]])
支持嵌套 list(矩陣)、一維 list(向量)。
3. Tensor ? numpy.ndarray
PyTorch Tensor 和 NumPy array 之間可以無縫轉換,共享內存(改變其中一個會影響另一個)。
Tensor → numpy array
import numpy as np
t = torch.tensor([[1, 2], [3, 4]])
a = t.numpy()
print(type(a)) # <class 'numpy.ndarray'>
numpy array → Tensor
a = np.array([[1, 2], [3, 4]])
t = torch.from_numpy(a)
print(type(t)) # <class 'torch.Tensor'>
numpy 數組必須是數值型(不能是對象數組等),否則會報錯。
4. Tensor ? Python scalar 類型(int, float)
如果你從計算結果中獲取單個數值,比如:
t = torch.tensor([5.5])
val = float(t) # 也可以使用 float(t.item())
print(val) # 5.5# 對于整型:
t2 = torch.tensor([3])
val2 = int(t2) # 等效于 int(t2.item())
print(val2) # 3
5. Tensor ? bytes(用于序列化,如保存到文件)
Tensor → bytes
t = torch.tensor([1, 2, 3])
b = t.numpy().tobytes()
bytes → Tensor
import numpy as np
b = b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00'
a = np.frombuffer(b, dtype=np.int32)
t = torch.from_numpy(a)
print(t) # tensor([1, 2, 3], dtype=torch.int32)
6.實戰示例
下面我們從三個實際應用場景來講解 PyTorch 中 tensor 與其他類型(如 list
、numpy
、int
等)相互轉換的用途和技巧:
場景一:數據加載與預處理
讀取圖像數據(使用 PIL) → 轉為 tensor
from PIL import Image
from torchvision import transformsimg = Image.open('cat.jpg') # 打開圖片為 PIL.Image
to_tensor = transforms.ToTensor()
t = to_tensor(img) # 轉為 [C, H, W] 的 float32 Tensor
此時你獲得了一個 Tensor
,可以送入模型。但如果你想可視化或分析:
Tensor → numpy → 可視化或保存
import matplotlib.pyplot as pltimg_np = t.permute(1, 2, 0).numpy() # [H, W, C]
plt.imshow(img_np)
plt.show()
permute
是因為ToTensor
會變成[C,H,W]
,而 matplotlib 需要[H,W,C]
。
場景二:模型推理后的結果處理(轉為 Python 值)
假設你有一個分類網絡,輸出如下:
output = torch.tensor([[0.1, 0.7, 0.2]]) # 假設輸出為 batch_size=1 的 logits
pred_idx = output.argmax(dim=1) # tensor([1])
你要拿到預測類別的整數值:
pred_class = pred_idx.item() # 1
print(type(pred_class)) # <class 'int'>
.item()
在推理階段非常常用!
場景三:保存 Tensor 到磁盤 / 網絡傳輸
Tensor 保存和加載時經常需要轉為 numpy 或 byte 流:
保存為 bytes 再寫入文件
t = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
with open("tensor.bin", "wb") as f:f.write(t.numpy().tobytes())
從文件讀回 tensor
with open("tensor.bin", "rb") as f:byte_data = f.read()import numpy as np
arr = np.frombuffer(byte_data, dtype=np.int32)
t2 = torch.from_numpy(arr)
print(t2) # tensor([1, 2, 3, 4], dtype=torch.int32)
你必須記住原始
dtype
和shape
才能正確還原!
場景四:構造 batch 時將 list 轉為 Tensor
在訓練時經常從數據集中拿到多個樣本組成 batch(Python list):
samples = [[1.0, 2.0], [3.0, 4.0]]
batch_tensor = torch.tensor(samples, dtype=torch.float32)
print(batch_tensor.shape) # torch.Size([2, 2])
或者更通用的方式(可以處理動態 shape):
batch_tensor = torch.stack([torch.tensor(s) for s in samples])
補充:在 with torch.no_grad()
中常用轉換
推理階段經常用 Tensor → numpy → list
:
with torch.no_grad():output = model(input_tensor)pred = output.softmax(dim=1)top1_class = pred.argmax(dim=1).item()
小結對照表
轉換類型 | 方法 | 注意事項 |
---|---|---|
Tensor → int/float | .item() | 只能單元素 |
Tensor → list | .tolist() | 支持嵌套 |
list → Tensor | torch.tensor(list) | 自動推斷類型 |
Tensor → ndarray | .numpy() | 共享內存 |
ndarray → Tensor | torch.from_numpy(ndarray) | 共享內存 |
Tensor → bytes | tensor.numpy().tobytes() | 用于存儲 |
bytes → Tensor | np.frombuffer + from_numpy | 需知道 dtype |