torch.optim Optimizers
torch.optim is the PyTorch module for optimizing neural network parameters. It implements a collection of commonly used optimization algorithms such as SGD, Adam, and RMSprop, and its main job is to update the model's parameters according to their gradients.
🏗️ Core Components
1. Common Optimizers
| Optimizer | Role | Typical Parameters |
| --- | --- | --- |
| torch.optim.SGD | Standard stochastic gradient descent, supports momentum | lr, momentum, weight_decay |
| torch.optim.Adam | Adaptive learning rates, stable in practice | lr, betas, weight_decay |
| torch.optim.RMSprop | Smooths gradients, commonly used for RNNs | lr, alpha, momentum |
| torch.optim.AdamW | Improved Adam with decoupled weight decay | lr, weight_decay |
| torch.optim.Adagrad | For sparse features, adapts the learning rate per parameter | lr, lr_decay, weight_decay |
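For reference, the sketch below constructs a few of the optimizers above with their typical parameters. The tiny linear model and the hyperparameter values are placeholders chosen only for illustration:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model; any nn.Module's parameters() can be passed in

# SGD with momentum and L2 regularization via weight_decay
sgd = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# Adam with explicit beta coefficients
adam = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# AdamW: weight decay decoupled from the gradient update
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# RMSprop with smoothing constant alpha
rmsprop = torch.optim.RMSprop(model.parameters(), lr=1e-2, alpha=0.99, momentum=0.9)

# Adagrad with a learning-rate decay
adagrad = torch.optim.Adagrad(model.parameters(), lr=1e-2, lr_decay=1e-4)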
Demo Code
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # a simple linear layer
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    output = model(torch.randn(4, 10))                  # simulated input
    loss = (output - torch.randn(4, 1)).pow(2).mean()   # pretend this is an MSE loss
    optimizer.zero_grad()   # zero the gradients
    loss.backward()         # backpropagation
    optimizer.step()        # update the parameters
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)

for epoch in range(100):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()        # zero the gradients
        result_loss.backward()   # backpropagation
        optim.step()             # update the parameters
        running_loss += result_loss.item()
    print(running_loss)
Modifying an Existing Network Model
import torchvision
from torch import nn

# train_data = torchvision.datasets.ImageNet(root='./data_IMG', split="train",
#                                            transform=torchvision.transforms.ToTensor())

# How to modify an existing network architecture
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

train_data = torchvision.datasets.CIFAR10(root='./data_CIF', train=True,
                                          transform=torchvision.transforms.ToTensor(), download=True)

# Add a linear layer: either append it to the whole model, or append it inside the classifier sub-module
vgg16_true.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
vgg16_true.classifier.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))

# Replace an existing linear layer
vgg16_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg16_false)
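Note that the pretrained argument used above is deprecated in recent torchvision releases in favor of a weights argument. A minimal sketch of the newer call, assuming torchvision 0.13 or later:

import torchvision
from torchvision.models import VGG16_Weights

vgg16_pretrained = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)  # pretrained ImageNet weights
vgg16_random = torchvision.models.vgg16(weights=None)                       # randomly initialized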
Saving and Loading Network Models
# model_save.py
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# Method 1: save the model structure + parameters
torch.save(vgg16, "vgg16.pth")

# Method 2: save only the parameters (officially recommended)
torch.save(vgg16.state_dict(), "vgg16_state_dict.pth")

# Pitfall: saving a whole model that uses a custom class
class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

tudui = Tudui()
torch.save(tudui, "tudui_method1.pth")
# model_load.py
import torch
import torchvision
from torch import nn

# Method 1: load the whole model (structure + parameters)
# model = torch.load("vgg16.pth", weights_only=False)
# print(model)

# Method 2: load only the parameters into a freshly built model
vgg16 = torchvision.models.vgg16(pretrained=False)
# state_dict = torch.load("vgg16_state_dict.pth")
vgg16.load_state_dict(torch.load("vgg16_state_dict.pth"))
# print(vgg16)

# Pitfall
class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

# Without this class definition, loading fails: torch.save(model) serializes a reference
# to the class, so if the Tudui class cannot be found in the loading file, the load blows up.
# Either copy the class definition into this file (as above), or add
# `from model_save import *` at the top.

# The following pattern is recommended instead:
"""
# Save
torch.save(model.state_dict(), "tudui_method2.pth")

# Load
model = Tudui()
model.load_state_dict(torch.load("tudui_method2.pth"))

Advantages: the model loads as long as a Tudui class is available, no matter which file
it lives in; it avoids errors caused by changes to the class; and it is more flexible,
which suits later modifications to the network structure.
"""

model = torch.load("tudui_method1.pth", weights_only=False)
print(model)
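Putting the recommended pattern together, here is a minimal end-to-end sketch of the state_dict round trip. The file name tudui_method2.pth follows the snippet above; the only requirement is that the Tudui class is defined (or importable) wherever the load happens:

import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

# Save: only the parameter tensors go to disk
model = Tudui()
torch.save(model.state_dict(), "tudui_method2.pth")

# Load: rebuild the structure first, then fill in the parameters
model2 = Tudui()
model2.load_state_dict(torch.load("tudui_method2.pth"))
model2.eval()  # switch to inference mode before evaluation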