目錄
- 創建 Tensor
- 常用操作
- unsqueeze
- squeeze
- Softmax
- 代碼1
- 代碼2
- 代碼3
- argmax
- item
創建 Tensor
使用 Torch 接口創建 Tensor
import torch
參考:https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html
常用操作
unsqueeze
將多維數組解套,并嵌入新的一層維度。
data = [[1, 2],[3, 4]]x_data = torch.tensor(data)print("x_data")print(x_data)x2_data = x_data.unsqueeze(-1)print("x_data>> unsqueeze -1")print(x2_data)x2_data = x_data.unsqueeze(0)print("x_data>> unsqueeze 0")print(x2_data)x2_data = x_data.unsqueeze(1)print("x_data>> unsqueeze 1")print(x2_data)x2_data = x_data.unsqueeze(2)print("x_data>> unsqueeze 2")print(x2_data)
結果:
x_data
tensor([[1, 2],[3, 4]])
x_data>> unsqueeze -1 # -1 代表最內層,將最內層的數用一個新的維度包起來
tensor([[[1],[2]],[[3],[4]]])
x_data>> unsqueeze 0 # 0 代表最外層,將原來的多維數組整個多套一層
tensor([[[1, 2],[3, 4]]])
x_data>> unsqueeze 1 # 代表原來第一維里的每個元素,套一層
tensor([[[1, 2]],[[3, 4]]])
x_data>> unsqueeze 2 # 代表原來第二維里的每個元素,套一層
tensor([[[1], # 當前一共兩維,所以效果和 -1 一樣[2]],[[3],[4]]])
squeeze
去掉指定或全部的維度中只有一個元素的多維數組。
比如輸入為 Ax1xBxCx1xD 維的數組,輸出變成了 AxBxCxD 維的數組。
https://pytorch.org/docs/stable/generated/torch.squeeze.html
data = [[1], [2],[3], [4]]x_data = torch.tensor(data)print("x_data")print(x_data)x2_data = x_data.squeeze()print("x_data>> squeeze")print(x2_data)x2_data = x_data.squeeze(1)print("x_data>> squeeze 1")print(x2_data)
結果:
x_data
tensor([[1],[2],[3],[4]])
x_data>> squeeze
tensor([1, 2, 3, 4])
x_data>> squeeze 1
tensor([1, 2, 3, 4])
Softmax
https://pytorch.org/docs/stable/generated/torch.softmax.html
歸一化操作。
代碼1
data = torch.tensor([1,2,3], dtype=torch.float) # 維度 3; 注意,此處 dtype 是 int 或 long 接口報錯x_data = torch.softmax(data, 0)print("x_data")print(x_data)
結果:
x_data
tensor([0.0900, 0.2447, 0.6652]) # 維度 3
代碼2
data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 維度 3x1x_data2 = torch.softmax(data, 0)print("x_data2")print(x_data2)
結果:
x_data2 # 維度 3x1
tensor([[0.0900],[0.2447],[0.6652]])
代碼3
data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 維度 3x1x_data2 = torch.softmax(data, 1) # 沿著第一維求print("x_data2")print(x_data2)
結果:
x_data2
tensor([[1.],[1.],[1.]])
此時,每維都是 1 個元素,針對自身求 softmax,所以,結果是 1.
argmax
https://pytorch.org/docs/stable/generated/torch.argmax.html
返回一個多維數組的最大值的索引,如果是多維數組,則返回第一維的索引。
item
https://pytorch.org/docs/stable/generated/torch.Tensor.item.html
返回一個 Tensor 中攜帶的 Python Number 對象。該接口只對 Tensor 是一維的有效。
x = torch.tensor([1.0])
x.item()