本文演示了PyTorch中張量(Tensor)和模型參數的保存與加載方法,并提供完整的代碼示例及輸出結果,幫助讀者快速掌握數據持久化的核心操作。
1. 保存和加載單個張量
通過torch.save
和torch.load
可以直接保存和讀取張量。
import torch# 創建并保存張量
x = torch.arange(4)
torch.save(x, 'x-file')# 加載張量
x2 = torch.load('x-file')
print(x2) # 輸出:tensor([0, 1, 2, 3])
輸出結果:
tensor([0, 1, 2, 3])
2. 保存和加載張量列表
可以將多個張量存儲為列表,并一次性加載。
# 創建兩個張量并保存為列表
y = torch.zeros(4)
torch.save([x, y], 'x-files')# 加載列表
x2, y2 = torch.load('x-files')
print((x2, y2))
輸出結果:
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
3. 保存和加載字典
通過字典可以更靈活地管理多個張量。
# 創建字典并保存
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')# 加載字典
mydict2 = torch.load('mydict')
print(mydict2)
輸出結果:
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
4. 定義神經網絡模型
以下是一個簡單的全連接神經網絡示例:
from torch import nn
from torch.nn import functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256) # 隱藏層self.output = nn.Linear(256, 10) # 輸出層def forward(self, x):return self.output(F.relu(self.hidden(x)))# 實例化模型并進行前向傳播
net = Model()
x = torch.rand(size=(2, 20))
y = net(x)
print(y)
輸出結果(因隨機初始化可能不同):
tensor([[-0.0711, 0.1161, -0.1113, ..., 0.0787],[-0.0151, 0.0275, -0.1652, ..., 0.0109]], grad_fn=<AddmmBackward0>)
5. 保存模型參數
使用state_dict
保存模型參數:
torch.save(net.state_dict(), 'net.params')
6. 加載模型參數并驗證
加載參數到新模型實例,并驗證一致性:
# 創建新模型并加載參數
clone = Model()
clone.load_state_dict(torch.load('net.params'))
clone.eval() # 設置為評估模式(關閉Dropout/BatchNorm等)# 比較輸出結果
Y_clone = clone(x)
print(Y_clone == y)
輸出結果:
tensor([[True, True, ..., True],[True, True, ..., True]])
總結
-
張量讀寫:直接使用
torch.save
和torch.load
,支持列表和字典。 -
模型參數保存:通過
state_dict
保存模型狀態,加載時需重新實例化模型。 -
驗證一致性:加載參數后,輸出與原模型一致表明操作成功。
通過本文的代碼示例,讀者可以快速掌握PyTorch中數據和模型參數的持久化方法,為模型訓練和部署提供便利。