Torch -- 卷積學習day4 -- 完整項目流程

完整項目流程總結

1. 環境準備與依賴導入

import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
import wandb
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import *
import matplotlib.pyplot as plt

2. 數據準備與增強

# 數據增強變換
transform = transforms.Compose([transforms.RandomRotation(45),transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
?
# 測試集變換
transformtest = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
?
# 數據集加載
train_dataset = CIFAR10(root=datapath,train=True,download=True,transform=transform,
)
?
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,
)

3. 模型構建與初始化

# 獲取ResNet18模型并調整全連接層
model = resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=10)
?
# 加載預訓練權重(如果有)
if os.path.exists(weightpath):weights_default = torch.load(weightpath)weights_default.pop("fc.weight", None)weights_default.pop("fc.bias", None)new_state_dict = model.state_dict()weights_default_process = {k: v for k, v in weights_default.items() if k in new_state_dict}new_state_dict.update(weights_default_process)model.load_state_dict(new_state_dict)
?
model.to(device)

4. 訓練過程

# 初始化訓練工具
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
?
# 可視化工具初始化
wandb.init(project="my-qianyi-project", config={...})
write1 = SummaryWriter(log_dir=log_dir)
write1.add_graph(model, input_to_model=torch.randn(1, 3, 32, 32).to(device))
?
# 訓練循環
for epoch in range(epochs):model.train()# 訓練代碼...torch.save(model.state_dict(), weightpath)

5. 驗證與評估

# 加載最佳模型進行驗證
model.load_state_dict(torch.load(weightpath))
model.eval()
?
# 驗證過程
# 保存預測結果到CSV
# 生成分類報告和混淆矩陣

6. 模型應

# 加載模型進行推理
def predict_image(image_path):# 圖像預處理# 模型預測# 返回結果

7. 模型移植與部署

7.1 模型轉換(PyTorch → ONNX/)

python

# 轉換為ONNX格式
def convert_to_onnx(model, input_size, onnx_path):model.eval()dummy_input = torch.randn(1, *input_size).to(device)torch.onnx.export(model,dummy_input,onnx_path,export_params=True,opset_version=11,do_constant_folding=True,input_names=['input'],output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f"Model converted to ONNX and saved to {onnx_path}")
?
# 使用示例
convert_to_onnx(model, (3, 32, 32), "model.onnx")
7.2 模型量化(減小模型大小,加速推理)

python

# 動態量化
def quantize_model(model):quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)return quantized_model
?
# 使用示例
quantized_model = quantize_model(model)
torch.save(quantized_model.state_dict(), "quantized_model.pth")
7.3 減少參數數量
# 簡單的權重剪枝
def prune_model(model, pruning_percentage=0.2):parameters_to_prune = []for name, module in model.named_modules():if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):parameters_to_prune.append((module, 'weight'))torch.nn.utils.prune.global_unstructured(parameters_to_prune,pruning_method=torch.nn.utils.prune.L1Unstructured,amount=pruning_percentage,)return model
?
# 使用示例
pruned_model = prune_model(model)
7.4 移動端部署(使用ONNX Runtime)
# 保存為LibTorch格式(C++可用)
example = torch.rand(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
7.5 Web部署(使用ONNX.js)
# 首先轉換為ONNX,然后使用ONNX.js在瀏覽器中運行
# 或者使用第三方工具如https://github.com/onnx/tensorflow-onnx
7.6 邊緣設備部署(使用TensorRT、OpenVINO等)
# 使用NVIDIA TensorRT優化(需要先轉換為ONNX)
# 或使用Intel OpenVINO工具包

8. 性能監控與優化

# 模型推理速度測試
def benchmark_model(model, input_size, num_runs=100):model.eval()input_tensor = torch.randn(1, *input_size).to(device)# GPU預熱for _ in range(10):_ = model(input_tensor)# 計時start_time = time.time()for _ in range(num_runs):_ = model(input_tensor)end_time = time.time()avg_time = (end_time - start_time) / num_runsfps = 1 / avg_timeprint(f"Average inference time: {avg_time*1000:.2f} ms, FPS: {fps:.2f}")return avg_time, fps
?
# 使用示例
benchmark_model(model, (3, 32, 32))

這個完整的流程涵蓋了從數據準備到模型部署的全過程,特別是新增的模型移植部分,提供了將訓練好的模型部署到不同平臺和設備的方法,這對于實際應用非常重要。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/93926.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/93926.shtml
英文地址,請注明出處:http://en.pswp.cn/web/93926.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MTK Linux DRM分析(七)- KMS drm_plane.c

一、簡介在 Linux DRM(Direct Rendering Manager)子系統中,Plane(平面)代表了一個圖像源,可以在掃描輸出過程中與 CRTC 混合或疊加顯示。每個 Plane 從 drm_framebuffer 中獲取輸入數據,并負責圖…

OpenHarmony之 藍牙子系統全棧剖析:從協議棧到芯片適配的端到端實踐(大合集)

1. 系統架構概述 OpenHarmony藍牙系統采用分層架構設計,基于HDF(Hardware Driver Foundation)驅動框架和系統能力管理(System Ability)機制實現。 1.1 架構層次 ┌─────────────────────────…

探索 Ultralytics YOLOv8標記圖片

1、下載YOLOv8模型文件 下載地址:https://docs.ultralytics.com/zh/models/yolov8/#performance-metrics 2、編寫python腳本 aaa.py import cv2 import numpy as np from ultralytics import YOLO import matplotlib.pyplot as pltdef plot_detection(image, box…

Matplotlib數據可視化實戰:Matplotlib子圖布局與管理入門

Matplotlib多子圖布局實戰 學習目標 通過本課程的學習,學員將掌握如何在Matplotlib中創建和管理多個子圖,了解子圖布局的基本原理和調整方法,能夠有效地展示多個數據集,提升數據可視化的效果。 相關知識點 Matplotlib子圖 學習內容…

【python實用小腳本-194】Python一鍵給PDF加水印:輸入文字秒出防偽文件——再也不用開Photoshop

Python一鍵給PDF加水印:輸入文字秒出防偽文件——再也不用開Photoshop PDF加水印, 本地腳本, 零會員費, 防偽標記, 瑞士軍刀 故事開場:一把瑞士軍刀救了投標的你 周五下午,你把 100 頁標書 PDF 發給客戶,卻擔心被同行盜用。 想加水…

開源 C++ QT Widget 開發(四)文件--二進制文件查看編輯

文章的目的為了記錄使用C 進行QT Widget 開發學習的經歷。臨時學習,完成app的開發。開發流程和要點有些記憶模糊,趕緊記錄,防止忘記。 相關鏈接: 開源 C QT Widget 開發(一)工程文件結構-CSDN博客 開源 C…

【密碼學實戰】X86、ARM、RISC-V 全量指令集與密碼加速技術全景解析

前言 CPU 指令集是硬件與軟件交互的核心橋梁,其設計直接決定計算系統的性能邊界與應用場景。在數字化時代,信息安全依賴密碼算法的高效實現,而指令集擴展則成為密碼加速的 “隱形引擎”—— 從服務器端的高吞吐量加密,到移動端的…

2025-08-21 Python進階2——數據結構

文章目錄1 列表(List)1.1 列表常用方法1.2 列表的特殊用途1.2.1 實現堆棧(后進先出)1.2.2 實現隊列(先進先出)1.3 列表推導式1.4 嵌套列表推導式2 del 語句3 元組(Tuple)4 集合&…

告別手工編寫測試腳本!Claude+Playwright MCP快速生成自動化測試腳本

在進行自動化測試時,前端頁面因為頻繁迭代UI 結構常有變動,這往往使得自動化測試的腳本往往“寫得快、廢得也快”,維護成本極高。在大模型之前大家往往都會使用錄制類工具,但錄制類工具生成的代碼靈活性較差、定位方式不太合理只能…

一款更適合 SpringBoot 的API文檔新選擇(Spring Boot 應用 API 文檔)

SpringDoc:Spring Boot 應用 API 文檔生成的現代化解決方案 概述 SpringDoc 是一個專為 Spring Boot 應用設計的開源庫,能夠自動生成符合 OpenAPI 3 規范的 API 文檔。它通過掃描項目中的控制器、方法注解及相關配置,動態生成 JSON/YAML/HTML…

文獻閱讀 250821-When and where soil dryness matters to ecosystem photosynthesis

When and where soil dryness matters to ecosystem photosynthesis 來自 <When and where soil dryness matters to ecosystem photosynthesis | Nature Plants> ## Abstract: Background: Projected increases in the intensity and frequency of droughts in the twen…

React學習(九)

目錄&#xff1a;1.react-進階-antd-新增2.react-進階-antd-刪除選中1.react-進階-antd-新增新增代碼&#xff0c;跟需改的代碼類似&#xff0c;直接copy修改組件代碼進行修改userEffect可以先帶著&#xff0c;沒啥用A6組件用到的函數跟修改的也類似&#xff1a;這個useEffect函…

零基礎從頭教學Linux(Day 17)

三層交換機一、三層交換機的配置1.關于如何配置三層交換機&#xff0c;首先我們應該先創建VLANSwitch>en Switch#vlan database % Warning: It is recommended to configure VLAN from config mode,as VLAN database mode is being deprecated. Please consult userdocument…

任務十四 推薦頁面接口開發

一、接口準備 在對接qq音樂接口之前,首先要將之前的項目,一定要記得備份一份; 備份完成之后,首先要在vscode終端安裝axios,這個是請求后端的工具,和之前的ajax一樣,都是請求后端的工具。只不過axios更專業化,跟強大 至于qq音樂接口怎么獲取,一般有兩個途徑,第一個是…

醫療AI與醫院數據倉庫的智能化升級:異構采集、精準評估與高效交互的融合方向(下)

核心功能創新詳解: 統一門戶與角色化工作臺: 統一入口: 用戶通過單一URL登錄,系統根據其角色和權限自動呈現專屬工作臺。 角色化工作臺: 臨床醫生工作臺: 首屏展示常用患者查詢入口、快速統計(如“我的患者檢驗異常趨勢”)、相關臨床文獻推薦、待處理任務(如報告審核)…

數據庫面試常見問題

數據庫 Delete Truncate Drop 區別 答:這三個操作都是針對數據庫的表進行操作,都有刪除表的功能,其中的區別在于: Delete:只將表中的數據進行刪除,不刪除定義不釋放空間,是dml語句,需要提交事務,如果不想刪除可以回滾。delete每次刪除一行,并在事務日志中為所刪除…

用nohup setsid繞過超時斷連,穩定反彈Shell

在We滲透過程中&#xff0c;我們常常會利用目標系統的遠程代碼執行&#xff08;RCE&#xff09;漏洞進行反彈Shell。然而&#xff0c;由于Web服務器&#xff08;如PHP、Python后端&#xff09;的執行環境通常存在超時限制&#xff08;如max_execution_time或進程管理策略&#…

Java設計模式-模板方法模式

Java設計模式-模板方法模式 模式概述 模板方法模式簡介 核心思想&#xff1a;定義一個操作中的算法骨架&#xff08;模板方法&#xff09;&#xff0c;將算法中某些步驟的具體實現延遲到子類中完成。子類可以在不改變算法整體結構的前提下&#xff0c;重定義這些步驟的行為&…

Centos7物理安裝 Redis8.2.0

Centos7物理安裝 Redis8.2.0一、準備依賴環境首先安裝編譯 Redis 所需的依賴&#xff1a;# CentOS/RHEL系統 yum install -y gcc gcc-c make wget 二、下載并編譯 Redis 8.2.0# 1. 下載Redis 8.2.0源碼包 wget https://download.redis.io/releases/redis-8.2.0.tar.gz# 2. 解壓…

牛津大學xDeepMind 自然語言處理(3)

條件語言模型無條件語言模型 概率計算&#xff1a;通過鏈式法則分解為預測下一詞概率&#xff08;將語言建模問題簡化為建模給定前面詞語歷史的下一個詞的概率&#xff09;基于循環神經網絡的無條件語言模型&#xff1a;根據歷史詞語預測下一個詞的概率條件語言模型 定義&#…