PyTorch 是什么?
它是一個基于 Python 的科學計算包, 其主要是為了解決兩類場景:
- NumPy 的替代品, 以使用 GPU 的強大加速功能
- 一個深度學習研究平臺, 提供最大的靈活性和速度
Tensors(張量)
Tensors 與 NumPy 的 ndarrays 非常相似, 除此之外還可以在 GPU 上使用張量來加速計算.
from __future__ import print_function
import torch
構建一個 5x3 的矩陣, 未初始化的:
x = torch.Tensor(5, 3)
print(x)
構建一個隨機初始化的矩陣:
x = torch.rand(5, 3)
print(x)
獲得 size:
print(x.size())
操作
針對操作有許多語法. 在下面的例子中, 我們來看看加法運算.
加法: 語法 1
y = torch.rand(5, 3)
print(x + y)
加法: 語法 2
print(torch.add(x, y))
NumPy Bridge
將一個 Torch Tensor 轉換為 NumPy 數組, 反之亦然.
Torch Tensor 和 NumPy 數組將會共享它們的實際的內存位置, 改變一個另一個也會跟著改變.