torch.autograd.Function
是 PyTorch 提供的一個接口,用于自定義前向傳播和反向傳播的操作。自定義操作需要繼承 torch.autograd.Function 并重載 forward 和 backward 方法。
下面是一個簡單的示例,展示如何自定義一個平方操作的前向傳播和反向傳播。
示例一:
import torch
from torch.autograd import Function
class SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一個上下文對象,用于存儲反向傳播所需的信息ctx.save_for_backward(input)return input * input@staticmethoddef backward(ctx, grad_output):# 從上下文對象中取回前向傳播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input
# 輸入張量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定義的 SquareFunction
output = SquareFunction.apply(input)# 進行反向傳播
output.backward(torch.tensor([1.0, 1.0]))# 打印梯度
print(input.grad) # 輸出:tensor([4., 6.])
示例二:
import torchclass SignWithSigmoidGrad(torch.autograd.Function):@staticmethoddef forward(ctx, x):result = (x > 0).float()sigmoid_result = torch.sigmoid(x)ctx.save_for_backward(sigmoid_result)return result@staticmethoddef backward(ctx, grad_result):(sigmoid_result,) = ctx.saved_tensorsif ctx.needs_input_grad[0]:grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)else:grad_input = Nonereturn grad_input
這段代碼定義了一個自定義的 PyTorch autograd 函數 SignWithSigmoidGrad,這個函數在前向傳播中計算輸入張量 x 的符號函數(sign function),在反向傳播中計算與 sigmoid 函數有關的梯度。
示例三:
import torch
from torch.autograd import Functionclass SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一個上下文對象,用于存儲反向傳播所需的信息ctx.save_for_backward(input)return torch.sum(input)@staticmethoddef backward(ctx, grad_output):# 從上下文對象中取回前向傳播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input# 輸入張量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定義的 SquareFunction
output = SquareFunction.apply(input)# 進行反向傳播
output.backward(torch.tensor(2.0))# 打印梯度
print(input.grad) # 輸出:tensor([8., 12.])