示例代碼:
plt.imshow(np.transpose(tensor_denorm.numpy(), (1, 2, 0)))
它的作用是:把一個 PyTorch 的圖像張量轉換成 NumPy 格式,并按照正確的維度順序顯示出來。
🚀 一步步解釋:
? tensor_denorm
這是一個形狀為 (3, H, W)
的 PyTorch Tensor,表示一個圖像:
- 3:表示三個顏色通道(RGB)
- H:圖像高度
- W:圖像寬度
PyTorch 中的圖像張量格式是 (C, H, W)
? .numpy()
這一步把 PyTorch Tensor 轉換成 NumPy 數組(前提是 Tensor 在 CPU 上):
tensor_denorm.numpy()
得到一個 NumPy 數組,形狀依然是 (3, H, W)
? np.transpose(..., (1, 2, 0))
NumPy 默認顯示圖像的格式是 (H, W, C)
,也就是:
- 高度(H)
- 寬度(W)
- 通道(C)
所以要把 (3, H, W)
轉換成 (H, W, 3)
,需要換維度順序:
np.transpose(tensor_denorm.numpy(), (1, 2, 0))
? plt.imshow(...)
這是 matplotlib.pyplot
的圖像顯示函數。它接收一個 (H, W, 3)
的數組并顯示出來:
plt.imshow(...)
📌 舉個例子:
假設我們有這個張量:
tensor = torch.rand(3, 150, 150) # 隨機圖像,3通道 150x150
執行這一步:
plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
就能把這個隨機圖像展示出來了。
? 總結一句話:
plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
等價于:
“把 PyTorch 中格式為
(C, H, W)
的圖像轉成(H, W, C)
并顯示出來”