1.Autograd
- grad和grad_fn
- grad:該tensor的梯度值,每次在計算backward時都需要將前一時刻的梯度歸零,否則梯度值會一直累加
- grad_fn:葉子結點通常為None,只有結果節點的grad_fn才有效,用于只是梯度函數時哪種類型
- torch.autograd.backward(tensors,grad_tensors,retain_graph,create_graph)
- 參數tensors:用于計算梯度的tensor
- 參數grad_tensors:在計算矩陣的梯度時會用到
- 參數retain_graph:通常在調用一次backward后,pytorch會自動把計算圖銷毀,所以想要對某個變量重復調用backward,則需要將該參數值設為True,默認值為False
- 參數create_graph:如果為True,則會創建一個專門計算微分的圖
- torch.autograd.grad(output,input,grad_output,retain_graph,create_graph,only_input,allow_unused)
- 計算和返回output關于input的梯度的和
- 參數output:函數的因變量,即需要求導的函數
- 參數input:函數的自變量
- 參數grad_output、retain_graph、create_graph:同backward
- 參數only_input:值為True時只計算input的梯度
- 參數allow_unused:值為False時,當計算輸出出錯時,指明不使用inpiu
- torch.autograd.Function
- 每一個原始的自動求導運算實際上是兩個在tensor上運行的函數
- forward函數:計算從輸入tensor獲得的輸出tensor
- backward函數:接收輸出tensor對于某個標量值的梯度,并且計算輸入tensor相對于該相同標量值的梯度
- 利用apply方法執行相應的運算
- 每一個原始的自動求導運算實際上是兩個在tensor上運行的函數
import torch
class line(torch.autograd.Function):@staticmethoddef forward(ctx, w, x, b):ctx.save_for_backward(w,x,b)return w * x + b@staticmethoddef backward(ctx,grad_out):w,x,b = ctx.saved_tensorsgrad_w = grad_out * xgrad_x = grad_out * wgrad_b = grad_outreturn grad_w, grad_x, grad_b
w = torch.rand(2,2,requires_grad=True)
x = torch.rand(2,2,requires_grad=True)
b = torch.rand(2,2,requires_grad=True)
out = line.apply(w,x,b)
out.backward(torch.ones(2,2))
print(w,x,b)
print(w.grad)
print(x.grad)
print(b.grad)
- 其他torch.autograd包中的函數
- torch.autograd.enable_grad:啟動梯度計算的上下文管理器
- torch.autograd.no_grad:禁止梯度計算的上下文管理器
- torch.autograd.set_grad_enabled(mode):設置是否進行梯度計算的上下文管理器
2.nn庫
torch.nn庫是專門為神經網絡設計的模塊化接口,自動計算前向傳播和反向傳播,可以用來定義和運行神經網絡。
- nn.Parameter & nn.ParameterList & nn.ParameterDict
- 定義可訓練參數
- nn.Linear & nn.conv2d & nn.ReLU & nn.MaxPool2d & nn.MSELoss
- 各種神經網絡層的定義,繼承于nn.Module的子類
- nn.functional
- 包含了torch.nn庫中所有的函數,包含大量loss和activation function
- nn.functional.xxx是函數接口
- nn.functional.xxx無法與nn.Sequential結合使用
- nn.Sequential
- 通過一個序列的方法完成對一個網絡的定義
- nn.ModuleList
- 用于搭建一個網絡模型
- nn.MouduleDict
- 通過字典的方式搭建一個網絡模型
具體案例使用,后期在神經網絡的學習中
?知識點為聽課總結筆記,課程為B站“2025最新整合!公認B站講解最強【PyTorch】入門到進階教程,從環境配置到算法原理再到代碼實戰逐一解讀,比自學效果強得多!”:2025最新整合!公認B站講解最強【PyTorch】入門到進階教程,從環境配置到算法原理再到代碼實戰逐一解讀,比自學效果強得多!_嗶哩嗶哩_bilibili