文章目錄
- 一、nn.Linear
一、nn.Linear
??nn.Linear 是 PyTorch 中的一個類,用于定義線性變換(全連接層)。它是神經網絡中常用的一種層類型,作為輸入張量與權重矩陣之間的線性變換。
nn.Linear(in_features, out_features, bias=True)
參數說明:
- in_features:輸入特征的大小,即輸入張量的最后一維大小。
- out_features:輸出特征的大小,即輸出張量的最后一維大小。
- bias:是否使用偏置項,默認為 True,表示使用偏置項。
import torch
import torch.nn as nn# 創建一個線性層,輸入特征大小為 3,輸出特征大小為 2
linear_layer = nn.Linear(3, 2)# 輸入張量
x = torch.tensor([[1, 2, 3],[4, 5, 6]])# 進行線性變換
output = linear_layer(x)print(output)
tensor([[ 1.5323, -0.2660],[ 4.5969, -1.0649]], grad_fn=<AddmmBackward>)