AI技術實戰:從零搭建圖像分類系統全流程詳解

AI技術實戰:從零搭建圖像分類系統全流程詳解


在這里插入圖片描述

人工智能學習 https://www.captainbed.cn/ccc

前言

本文將以圖像分類任務為切入點,手把手教你完成AI模型從數據準備到工業部署的全鏈路開發。通過一個完整的Kaggle貓狗分類項目(代碼兼容PyTorch/TensorFlow),覆蓋以下核心技能:

  • 數據清洗與增強的工程化實現
  • 模型構建與訓練技巧
  • 模型壓縮與TensorRT部署優化
  • 可視化監控與性能調優

所有代碼均提供可運行的Colab鏈接,建議邊閱讀邊實踐。


目錄

  1. 環境搭建與數據準備

    • 1.1 本地/云端開發環境配置
    • 1.2 數據爬取與清洗腳本開發
    • 1.3 自動化標注工具實戰
  2. 圖像分類模型實戰

    • 2.1 手寫CNN模型構建(帶可運行代碼)
    • 2.2 遷移學習Fine-tuning技巧
    • 2.3 訓練過程可視化監控
  3. 模型優化與部署

    • 3.1 模型剪枝與量化壓縮
    • 3.2 ONNX格式轉換與TensorRT加速
    • 3.3 RESTful API服務封裝
  4. 工業級增強技巧

    • 4.1 解決類別不平衡問題
    • 4.2 應對小樣本學習的策略
    • 4.3 模型熱更新方案

1. 環境搭建與數據準備

1.1 開發環境配置(PyTorch示例)

# 創建虛擬環境
conda create -n ai_tutorial python=3.8
conda activate ai_tutorial# 安裝核心依賴
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python albumentations pandas

1.2 數據爬取實戰

# 使用Bing圖片下載API批量獲取數據
import requestsdef download_images(keyword, count=100):headers = {'Ocp-Apim-Subscription-Key': 'YOUR_API_KEY'}params = {'q': keyword, 'count': count}response = requests.get('https://api.bing.microsoft.com/v7.0/images/search', headers=headers, params=params)for idx, img in enumerate(response.json()['value']):img_data = requests.get(img['contentUrl']).contentwith open(f'dataset/{keyword}_{idx}.jpg', 'wb') as f:f.write(img_data)# 執行下載
download_images('cat')
download_images('dog')

1.3 自動化數據清洗

# 使用OpenCV過濾損壞圖片
import cv2
import osdef clean_dataset(folder):valid_extensions = ['.jpg', '.jpeg', '.png']for filename in os.listdir(folder):filepath = os.path.join(folder, filename)try:img = cv2.imread(filepath)if img is None or img.size == 0:os.remove(filepath)elif os.path.splitext(filename)[1].lower() not in valid_extensions:os.remove(filepath)except Exception as e:print(f"刪除損壞文件: {filename}")os.remove(filepath)clean_dataset('dataset/train')

2. 圖像分類模型實戰

2.1 自定義CNN模型(PyTorch實現)

import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self, num_classes=2):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1), # 輸入3通道nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Linear(128 * 28 * 28, 512), # 根據輸入尺寸調整nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)return self.classifier(x)

2.2 遷移學習實戰(ResNet50微調)

from torchvision import models# 加載預訓練模型
model = models.resnet50(pretrained=True)# 替換最后一層全連接
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, 2)
)# 凍結早期層參數
for param in model.parameters():param.requires_grad = False
for param in model.layer4.parameters():param.requires_grad = True

2.3 訓練過程可視化(TensorBoard集成)

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter()for epoch in range(epochs):# 訓練代碼...writer.add_scalar('Loss/train', loss.item(), epoch)writer.add_scalar('Accuracy/train', acc, epoch)# 可視化特征圖if epoch % 10 == 0:writer.add_images('Feature Maps', model.features[0](images[:4]), epoch)

3. 模型優化與部署

3.1 模型剪枝實戰

import torch.nn.utils.prune as prune# 對卷積層進行L1非結構化剪枝
parameters_to_prune = ((model.features[0], 'weight'),(model.features[3], 'weight'),
)prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2, # 剪枝20%的權重
)

3.2 TensorRT加速部署

# 導出ONNX模型
torch.onnx.export(model, dummy_input, "model.onnx",opset_version=11)# 使用TensorRT轉換
trt_cmd = f"""
trtexec --onnx=model.onnx \--saveEngine=model.trt \--fp16 \--workspace=2048
"""
os.system(trt_cmd)

3.3 封裝Flask API服務

from flask import Flask, request
import trt_inference  # 自定義TRT推理模塊app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():file = request.files['image']img = preprocess(file.read())output = trt_inference.run(img)return {'class_id': int(output.argmax())}if __name__ == '__main__':app.run(host='0.0.0.0', port=5000)

4. 工業級增強技巧

4.1 類別不平衡解決方案

# 使用加權采樣器
from torch.utils.data import WeightedRandomSamplerclass_counts = [num_cat, num_dog] 
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[labels]sampler = WeightedRandomSampler(weights=samples_weights,num_samples=len(samples_weights),replacement=True
)

4.2 小樣本學習方案

# 使用MixUp數據增強
def mixup_data(x, y, alpha=0.2):lam = np.random.beta(alpha, alpha)batch_size = x.size()[0]index = torch.randperm(batch_size)mixed_x = lam * x + (1 - lam) * x[index]y_a, y_b = y, y[index]return mixed_x, y_a, y_b, lam# 修改損失函數
criterion = nn.CrossEntropyLoss()
loss = lam * criterion(output, y_a) + (1 - lam) * criterion(output, y_b)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/901280.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/901280.shtml
英文地址,請注明出處:http://en.pswp.cn/news/901280.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

NIPS2024論文 End-to-End Ontology Learning with Large Language Models

文章所謂的端到端本體學習,指的是從輸入到目標本體這個完整過程。在很多其他文章中,是把本體學習這個任務肢解了來做的,同樣也是肢解了之后評估。 文章號稱的貢獻,不但對通用本體學習提供所謂的baseline,而且還給出了驗…

【NLP】18. Encoder 和 Decoder

1. Encoder 和 Decoder 概述 在序列到序列(sequence-to-sequence,簡稱 seq2seq)的模型中,整個系統通常分為兩大部分:Encoder(編碼器)和 Decoder(解碼器)。 Encoder&…

Deepseek Bart模型相比Bert的優勢

BART(Bidirectional and Auto-Regressive Transformers)與BERT(Bidirectional Encoder Representations from Transformers)雖然均基于Transformer架構,但在模型設計、任務適配性和應用場景上存在顯著差異。以下是BART…

在人工智能與計算機技術融合的框架下探索高中教育數字化教學模式的創新路徑

一、引言 1.1 研究背景 在數字中國戰略與《中國教育現代化 2035》的政策導向下,人工智能與計算機技術的深度融合正深刻地重構著教育生態。隨著科技的飛速發展,全球范圍內的高中教育都面臨著培養具備數字化素養人才的緊迫需求,傳統的教學模式…

深度探索 C 語言:指針與內存管理的精妙藝術

C 語言作為一門歷史悠久且功能強大的編程語言,以其高效的性能和靈活的底層控制能力,在計算機科學領域占據著舉足輕重的地位。 指針和內存管理是 C 語言的核心特性,也是其最具挑戰性和魅力的部分。深入理解指針與內存管理,不僅能夠…

QQ郵箱授權碼如何獲取 QQ郵箱授權碼獲取方法介紹

QQ郵箱授權碼如何獲取 QQ郵箱授權碼獲取方法介紹 https://app.ali213.net/gl/857287.html

jupyter4.4安裝使用

一、chrome谷歌瀏覽器 1. 安裝 1.1 下載地址: 下載地址: https://www.google.cn/intl/zh-CN_ALL/chrome/fallback/ 2 插件markdown-viewer 2.1 下載地址: 下載地址:https://github.com/simov/markdown-viewer/releases 2.2…

STM32 HAL庫RTC實時時鐘超細詳解

一、引言 在嵌入式系統的應用中,實時時鐘(RTC)是一個非常重要的功能模塊。它能夠獨立于主系統提供精確的時間和日期信息,即使在系統斷電的情況下,也可以依靠備用電池繼續運行。STM32F407 是一款性能強大的微控制器&am…

vdso概念及原理,vdso_fault缺頁異常,vdso符號的獲取

一、背景 vdso的全稱是Virtual Dynamic Shared Object,它是一個特殊的共享庫,是在編譯內核時生成,并在內核鏡像里某一段地址段作為該共享庫的內容。vdso的前身是vsyscall,為了兼容一些舊的程序,x86上還是默認加載了vs…

Linux中的文件傳輸(附加詳細實驗案例)

一、實驗環境的設置 ①該實驗需要兩臺主機,虛擬機名稱為 L2 和 L3 ,在終端分別更改主機名為 node1 和 node2,在實驗過程能夠更好分辨。 然后再重新打開終端,主機名便都更改了相應的名稱。 ②用 ip a 的命令分別查看兩個主機的 …

【從0到1學Elasticsearch】Elasticsearch從入門到精通(上)

黑馬商城作為一個電商項目,商品的搜索肯定是訪問頻率最高的頁面之一。目前搜索功能是基于數據庫的模糊搜索來實現的,存在很多問題。 首先,查詢效率較低。 由于數據庫模糊查詢不走索引,在數據量較大的時候,查詢性能很差…

圖論基礎理論

在我看來,想要掌握圖的基礎應用,僅需要三步走。 什么是圖(基本概念)、圖的構造(打地基)、圖的遍歷方式(應用的基礎) 只要能OK的掌握這三步、就算圖論入門了!&#xff0…

詳細解讀react框架中的hooks

React Hooks 是 React 16.8 引入的一項革命性特性,它允許你在函數組件中使用狀態(state)和其他 React 特性,而無需編寫 class 組件。下面將詳細解讀 React Hooks 的核心概念、常用 Hooks 及其工作原理。 一、Hooks 的核心概念 1. 什么是 Hooks Hooks …

主機IP動態變化時如何通過固定host.docker.internal訪問本機服務

場景需求——主機IP動態變化時,通過固定的 http://host.docker.internal:11555 訪問本機服務,核心問題在于 host.docker.internal 的解析邏輯與動態IP的適配。以下是分步解決方案: 一、核心原理:host.docker.internal 的本質與局…

插值算法 - 最近鄰插值實現

目錄 1. 導入必要的庫 2. nearest_neighbor_interpolation 3. 測試代碼 數學原理 完整代碼 本文實現了基于最近鄰插值算法的圖像縮放功能。 它使用 Python 編寫,主要依賴于NumPy和PIL(Python Imaging Library)庫。 NumPy用于高效的數值計算,而PIL僅用于圖像的加載和…

windows中搭建Ubuntu子系統

windows中搭建虛擬環境 1.配置2.windows中搭建Ubuntu子系統2.1windows配置2.1.1 確認啟用私有化2.1.2 將wsl2設置為默認版本2.1.3 確認開啟相關配置2.1.4重啟windows以加載更改配置 2.2 搭建Ubuntu子系統2.2.1 下載Ubuntu2.2.2 遷移位置 3.Ubuntu子系統搭建docker環境3.1安裝do…

MySQL事務機制

目錄 原子性 持久性 隔離性 隔離級別(并發事務之間的關系) 讀未提交 讀已提交 可重復讀 串行化(最嚴格的隔離級別) 一致性 問題 不可重復讀性(已經提交的數據) 什么是臟讀問題(未提交的數據)? 幻讀 保存點 自動提交機制--autocommit 會話隔離級別與全局隔離級…

Cadence學習筆記之---直插元件的封裝制作

目錄 01 | 引 言 02 | 環境描述 03 | 操作步驟 04 | 結 語 01 | 引 言 在之前發布的Cadence小記中,已經講述了怎樣制作熱風焊盤,貼片(SMD)焊盤、通孔、過孔,以及貼片元件的封裝。 本篇關于Cadence的小記主要講如何制作直插元件的封裝。 …

【第四十周】文獻閱讀:用于檢索-增強大語言模型的查詢與重寫

目錄 摘要Abstract用于檢索-增強大語言模型的查詢與重寫研究背景方法論基于凍結LLM的重寫方案基于可訓練重寫器的方案重寫器預熱訓練(Rewriter Warm-up)強化學習(Reinforcement Learning) 創新性實驗結果局限性總結 摘要 這篇論文…

java學習總結(if switch for)

一.基本結構 1.單分支if int num 10; if (num > 5) {System.out.println("num 大于 5"); } 2.雙分支if-else int score 60; if (score > 60) {System.out.println("及格"); } else {System.out.println("不及格"); } 3.多分支 int…