完整項目流程總結
1. 環境準備與依賴導入
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics import *
2. 數據準備與增強
# Mean / std triples handed to Normalize by both pipelines below.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2471, 0.2435, 0.2616)

# Data augmentation transform (training set).
transform = transforms.Compose(
    [
        transforms.RandomRotation(45),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ]
)

# Test-set transform: tensor conversion and normalization only, no augmentation.
transformtest = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ]
)

# Dataset loading: CIFAR-10 training split rooted at `datapath`.
train_dataset = CIFAR10(
    root=datapath,
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches served by two worker processes.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
3. 模型構建與初始化
# Build a ResNet-18 without pretrained weights and replace its final fully
# connected layer with a 10-way classifier.
model = resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=10)

# Load pretrained weights (if a checkpoint file exists).
if os.path.exists(weightpath):
    # map_location="cpu": the model is still on the CPU at this point, and this
    # also lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (or pass weights_only=True on PyTorch versions that support it).
    weights_default = torch.load(weightpath, map_location="cpu")
    new_state_dict = model.state_dict()
    # Keep only entries whose name AND shape match the current model. A
    # classifier head of a different size is skipped, while a matching 10-class
    # head (e.g. from the checkpoint the training loop writes to this same
    # path) is restored instead of being thrown away.
    weights_default_process = {
        k: v
        for k, v in weights_default.items()
        if k in new_state_dict and v.shape == new_state_dict[k].shape
    }
    new_state_dict.update(weights_default_process)
    model.load_state_dict(new_state_dict)

model.to(device)
4. 訓練過程
# Training utilities: loss function and optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Visualization tools.
# NOTE(review): `{...}` is a placeholder kept from the original; as written it
# is a set containing Ellipsis, so substitute the real hyperparameter dict
# before running.
wandb.init(project="my-qianyi-project", config={...})
write1 = SummaryWriter(log_dir=log_dir)
_graph_input = torch.randn(1, 3, 32, 32).to(device)
write1.add_graph(model, input_to_model=_graph_input)

# Training loop.
for epoch in range(epochs):
    model.train()
    # Training code elided in this summary.
    # The state dict is written to `weightpath` on every epoch, overwriting
    # whatever the previous epoch saved there.
    torch.save(model.state_dict(), weightpath)
5. 驗證與評估
# Load the best model for validation.
# NOTE(review): the training loop shown in this document overwrites
# `weightpath` every epoch, so this is the most recently saved checkpoint.
# map_location keeps the load working when the checkpoint was saved on a
# different device (e.g. trained on GPU, validated on a CPU-only machine).
model.load_state_dict(torch.load(weightpath, map_location=device))
model.eval()

# Validation loop (elided in this summary).
# Save the predictions to CSV (elided in this summary).
# Generate the classification report and confusion matrix (elided in this summary).
6. 模型應用
# Load the model for inference
def predict_image(image_path):
    """Classify a single image file (outline only; not implemented here).

    The original outline lists three steps without any code: image
    preprocessing, model prediction, and returning the result. Until they
    are filled in, calling this function fails loudly instead of silently
    returning ``None``.

    Args:
        image_path: Path of the image to run inference on.

    Raises:
        NotImplementedError: Always, because the steps are not implemented.
    """
    # Image preprocessing
    # Model prediction
    # Return the result
    raise NotImplementedError(
        "predict_image is an outline only: preprocessing, prediction and "
        "result handling still need to be implemented"
    )
7. 模型移植與部署
7.1 模型轉換(PyTorch → ONNX)
python
# Convert to ONNX format
def convert_to_onnx(model, input_size, onnx_path):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The model is switched to eval mode (and left in it) before export.

    Args:
        model: The ``torch.nn.Module`` to export.
        input_size: Shape of one sample without the batch dimension,
            e.g. ``(3, 32, 32)``.
        onnx_path: Destination path of the ``.onnx`` file.
    """
    model.eval()
    # Build the dummy input on the device the model actually lives on rather
    # than on a module-level global, so the export cannot hit a device mismatch.
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )
    dummy_input = torch.randn(1, *input_size).to(target_device)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # Axis 0 of both tensors is exported as a symbolic batch size.
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )
    print(f"Model converted to ONNX and saved to {onnx_path}")


# Usage example
convert_to_onnx(model, (3, 32, 32), "model.onnx")
7.2 模型量化(減小模型大小,加速推理)
python
# Dynamic quantization
def quantize_model(model):
    """Return a dynamically quantized (int8) copy of ``model``.

    Only ``nn.Linear`` modules are quantized; every other layer stays in
    floating point. The input model is not modified.

    Args:
        model: The float ``torch.nn.Module`` to quantize.

    Returns:
        A CPU copy of ``model`` whose ``nn.Linear`` layers are replaced by
        their dynamically quantized counterparts.
    """
    # The dynamic-quantization kernels run on the CPU, so quantize a private
    # CPU copy; this also leaves the caller's model on its original device.
    cpu_model = copy.deepcopy(model).to("cpu")
    quantized_model = torch.quantization.quantize_dynamic(
        cpu_model,
        {nn.Linear},
        dtype=torch.qint8,
        inplace=True,  # cpu_model is already our own copy
    )
    return quantized_model


# Usage example
quantized_model = quantize_model(model)
torch.save(quantized_model.state_dict(), "quantized_model.pth")
7.3 減少有效參數數量(權重剪枝:將部分權重置零)
# Simple weight pruning
def prune_model(model, pruning_percentage=0.2):
    """Apply global L1 unstructured pruning to all Conv2d / Linear weights.

    The pruning is applied in place through PyTorch's pruning
    reparametrization: the selected weights are masked to zero, but the
    tensors keep their shape, so this sparsifies the model rather than
    shrinking its stored size.

    Args:
        model: The ``torch.nn.Module`` to prune (modified in place).
        pruning_percentage: Value forwarded as ``amount`` to
            ``prune.global_unstructured`` (0.2 prunes 20% of the collected
            weights, chosen globally by smallest L1 magnitude).

    Returns:
        The same ``model`` object, for convenience.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # Nothing to prune: return early instead of letting the pruning call fail
    # on an empty parameter list.
    if not parameters_to_prune:
        return model

    # `prune` is the explicitly imported torch.nn.utils.prune submodule; it is
    # not guaranteed to be reachable as an attribute after a bare `import torch`.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=pruning_percentage,
    )
    return model


# Usage example
pruned_model = prune_model(model)
7.4 移動端/C++部署(使用TorchScript / LibTorch;亦可將7.1導出的ONNX模型交給ONNX Runtime)
# Save in LibTorch format (usable from C++)
# Trace in eval mode so layers such as BatchNorm are recorded with their
# inference behaviour rather than their training behaviour.
model.eval()
example = torch.rand(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
7.5 Web部署(使用ONNX.js)
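# 首先轉換為ONNX(見7.1),然後使用ONNX.js在瀏覽器中運行
# (ONNX.js已停止維護,建議改用其後繼者ONNX Runtime Web:https://github.com/microsoft/onnxruntime)
# 注意:https://github.com/onnx/tensorflow-onnx 是將TensorFlow模型轉換為ONNX的工具,並不用於在瀏覽器中運行PyTorch導出的模型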
7.6 邊緣設備部署(使用TensorRT、OpenVINO等)
# 使用NVIDIA TensorRT優化(需要先轉換為ONNX)
# 或使用Intel OpenVINO工具包
8. 性能監控與優化
# Model inference speed test
def benchmark_model(model, input_size, num_runs=100):
    """Measure the average single-sample inference latency of ``model``.

    The model is put in eval mode, run 10 times as warm-up, then timed over
    ``num_runs`` forward passes on one random input.

    Args:
        model: The ``torch.nn.Module`` to benchmark.
        input_size: Shape of one sample without the batch dimension,
            e.g. ``(3, 32, 32)``.
        num_runs: Number of timed forward passes; must be positive.

    Returns:
        Tuple ``(avg_time, fps)``: mean seconds per forward pass and its
        reciprocal.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    model.eval()
    # Benchmark on the device the model actually lives on.
    first_param = next(model.parameters(), None)
    run_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )
    input_tensor = torch.randn(1, *input_size).to(run_device)
    on_cuda = run_device.type == "cuda"

    # no_grad: do not pay for autograd bookkeeping while measuring inference.
    with torch.no_grad():
        # GPU warm-up
        for _ in range(10):
            model(input_tensor)
        # CUDA kernels launch asynchronously; wait for them so the timer
        # brackets the real work rather than just the launches.
        if on_cuda:
            torch.cuda.synchronize(run_device)

        # Timing
        start_time = time.perf_counter()
        for _ in range(num_runs):
            model(input_tensor)
        if on_cuda:
            torch.cuda.synchronize(run_device)
        end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_runs
    # Guard against a zero reading from the timer.
    fps = 1 / avg_time if avg_time > 0 else float("inf")
    print(f"Average inference time: {avg_time*1000:.2f} ms, FPS: {fps:.2f}")
    return avg_time, fps


# Usage example
benchmark_model(model, (3, 32, 32))
這個完整的流程涵蓋了從數據準備到模型部署的全過程,特別是新增的模型移植部分,提供了將訓練好的模型部署到不同平臺和設備的方法,這對於實際應用非常重要。