以下是完整的操作流程:用 PyTorch 訓練模型 → 導出為 ONNX 格式 → 用 Java 加載并推理,兼顧開發效率(PyTorch 快速訓練)和生產部署(Java 穩定運行)。
一、PyTorch 訓練模型并導出為 ONNX
1. 安裝依賴
bash
pip install torch onnx # PyTorch 和 ONNX 庫
2. 訓練一個簡單模型(以線性回歸為例)
python
運行
import torch
import torch.nn as nn
import torch.optim as optim# 1. 定義模型(線性回歸:y = 2x + 3 + 噪聲)
class LinearModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(in_features=1, out_features=1) # 輸入1維,輸出1維def forward(self, x):return self.linear(x)# 2. 生成訓練數據
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[5.1], [7.2], [8.9], [11.3]], dtype=torch.float32) # 近似 y=2x+3# 3. 訓練模型
model = LinearModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)for epoch in range(1000):optimizer.zero_grad()y_pred = model(x_train)l