Pytorch項目實戰-2:花卉分類

一、前言

在深度學習項目中,數據集的處理和模型的訓練、測試、預測是關鍵環節。本文將為小白詳細介紹從數據集搜集、清洗、劃分到模型訓練、測試、預測以及模型結構查看的全流程,附帶代碼和操作說明,讓你輕松上手!

二、數據集

二、數據集獲取

2.1 自建數據集 vs 公開數據集

  • 自建數據集:適合本科畢設、大作業等小規模場景,可通過自己拍攝爬蟲爬取(如百度圖片)構建。
  • 公開數據集:適合專業研究,例如醫學圖像分割可從ISIC Archive獲取。

2.2 百度圖片爬蟲實戰(附代碼)

代碼文件:data_get.py

# -*- coding: utf-8 -*-
import requests
import re
import osheaders = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/84.0.4147.125 Safari/537.36'}
name = input('請輸入要爬取的圖片類別:')
num = 0
num_1 = 0
num_2 = 0
x = input('請輸入要爬取的圖片數量?(1=60張,2=120張):')
list_1 = []for i in range(int(x)):name_1 = os.getcwd()name_2 = os.path.join(name_1, 'data/' + name)url = f'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word={name}&pn={i*30}'res = requests.get(url, headers=headers)htlm_1 = res.content.decode()a = re.findall('"objURL":"(.*?)",', htlm_1)if not os.path.exists(name_2):os.makedirs(name_2)for b in a:try:b_2 = re.findall('https:(.*?)&', b)[0]  # 提取圖片URLif b_2 not in list_1:num += 1img = requests.get(b)save_path = os.path.join(name_1, 'data/' + name, f'{name}{num}.jpg')with open(save_path, 'ab') as f:f.write(img.content)print(f'---------正在下載第{num}張圖片----------')list_1.append(b_2)else:num_1 += 1  # 統計重復圖片except:print(f'---------第{num}張圖片無法下載----------')num_2 += 1  # 統計失敗圖片print(f'下載完成!總共下載{num+num_1+num_2}張,成功{num}張,重復{num_1}張,失敗{num_2}張')

使用步驟

  1. 保存代碼為data_get.py
  2. 運行后輸入圖片類別(如 “向日葵”)和數量(1 或 2)
  3. 圖片會自動保存到data/類別名目錄下

?三、數據集清洗(解決中文路徑和壞圖問題)

3.1 為什么需要清洗?

  • OpenCV 對中文路徑支持差,會導致讀取錯誤
  • 爬取的圖片可能包含損壞文件(無法讀取的壞圖)

3.2 清洗代碼(data_clean.py)

import shutil
import cv2
import os
import numpy as np
from tqdm import tqdmdef cv_imread(file_path, type=-1):"""支持中文路徑讀取圖片"""img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), -1)return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if type==0 else imgdef cv_imwrite(file_path, cv_img, is_gray=True):"""支持中文路徑保存圖片"""if len(cv_img.shape)==3 and is_gray:cv_img = cv_img[:, :, 0]cv2.imencode(os.path.splitext(file_path)[1], cv_img)[1].tofile(file_path)def data_clean(src_folder, english_name):clean_folder = f{src_folder}_cleanedif os.path.isdir(clean_folder):shutil.rmtree(clean_folder)  # 刪除已存在目錄os.mkdir(clean_folder)image_names = os.listdir(src_folder)with tqdm(total=len(image_names)) as pabr:for i, name in enumerate(image_names):path = os.path.join(src_folder, name)try:img = cv_imread(path)# 保存為英文名稱的JPG圖片save_name = f{english_name}_{i}.jpgsave_path = os.path.join(clean_folder, save_name)cv_imwrite(save_path, img, is_gray=False)except:print(f{name}是壞圖)pabr.update(1)if __name__ == __main__:data_clean(src_folder=D:/數據集/向日葵, english_name=sunflowers)  # 替換為你的路徑

運行結果

  • 生成原目錄_cleaned文件夾,存放清洗后的圖片
  • 自動跳過壞圖,重命名為英文(如sunflowers_0.jpg

四、數據集劃分(6:2:2 比例)

4.1 適用場景

  • 當數據集未區分訓練集、驗證集、測試集時使用
  • 要求:圖片按類別存放在子目錄下(如data/向日葵,?data/玫瑰

4.2 劃分代碼(data_split.py)

import os
import shutil
import random
from tqdm import tqdmdef split_data(src_dir, save_dir, ratios=[0.6, 0.2, 0.2]):os.makedirs(save_dir, exist_ok=True)categories = os.listdir(src_dir)for cate in categories:cate_path = os.path.join(src_dir, cate)imgs = os.listdir(cate_path)random.shuffle(imgs)total = len(imgs)# 計算劃分索引train_idx = int(total * ratios[0])val_idx = train_idx + int(total * ratios[1])# 劃分數據集for phase, start, end in zip(['train', 'val', 'test'], [0, train_idx, val_idx]):phase_dir = os.path.join(save_dir, phase, cate)os.makedirs(phase_dir, exist_ok=True)for img in tqdm(imgs[start:end], desc=fProcessing {phase} {cate}):src_img = os.path.join(cate_path, img)dest_img = os.path.join(phase_dir, img)shutil.copyfile(src_img, dest_img)if __name__ == __main__:src_dir = D:/數據集_cleaned  # 清洗后的數據集路徑save_dir = D:/數據集_split  # 劃分結果保存路徑split_data(src_dir, save_dir)

關鍵操作

  1. 修改src_dir為清洗后的數據集路徑
  2. 運行后生成save_dir/split目錄,包含train/val/test子目錄
  3. 比例可在ratios參數中調整(總和需為 1)

五、模型訓練(以 ResNet50 為例)

5.1 準備工作

一開始執行之前會有一個會需要下載預訓練模型到指定目錄,由于眾所周知的原因,大家需要提前先把模型下載下來放置到這個目錄,這個大家自行探索。

image20221201114528829

右鍵直接運行train.py就可以開始訓練模型,代碼首先會輸出模型的基本信息(模型有幾個卷積層、池化層、全連接層構成)和運行的記錄。

  • 下載預訓練模型(如 ResNet50),放入指定目錄(代碼中標記TODO
  • 確保數據集劃分正確(訓練集路徑需對應)

5.2 開始訓練

from torchutils import *
from torchvision import datasets, models, transforms
import os.path as osp
import os
if torch.cuda.is_available():device = torch.device('cuda:0')
else:device = torch.device('cpu')
print(f'Using device: {device}')
# 固定隨機種子,保證實驗結果是可以復現的
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
data_path = r"G:\code\2023_pytorch110_classification_42-master\flowers_5_split" # todo 數據集路徑
# 注: 執行之前請先劃分數據集
# 超參數設置
params = {# 'model': 'vit_tiny_patch16_224',  # 選擇預訓練模型# 'model': 'resnet50d',  # 選擇預訓練模型'model': 'efficientnet_b3a',  # 選擇預訓練模型"img_size": 224,  # 圖片輸入大小"train_dir": osp.join(data_path, "train"),  # todo 訓練集路徑"val_dir": osp.join(data_path, "val"),  # todo 驗證集路徑'device': device,  # 設備'lr': 1e-3,  # 學習率'batch_size': 4,  # 批次大小'num_workers': 0,  # 進程'epochs': 10,  # 輪數"save_dir": "../checkpoints/",  # todo 保存路徑"pretrained": True,"num_classes": len(os.listdir(osp.join(data_path, "train"))),  # 類別數目, 自適應獲取類別數目'weight_decay': 1e-5  # 學習率衰減
}# 定義模型
class SELFMODEL(nn.Module):def __init__(self, model_name=params['model'], out_features=params['num_classes'],pretrained=True):super().__init__()self.model = timm.create_model(model_name, pretrained=pretrained)  # 從預訓練的庫中加載模型# self.model = timm.create_model(model_name, pretrained=pretrained, checkpoint_path="pretrained/resnet50d_ra2-464e36ba.pth")  # 從預訓練的庫中加載模型# classifierif model_name[:3] == "res":n_features = self.model.fc.in_features  # 修改全連接層數目self.model.fc = nn.Linear(n_features, out_features)  # 修改為本任務對應的類別數目elif model_name[:3] == "vit":n_features = self.model.head.in_features  # 修改全連接層數目self.model.head = nn.Linear(n_features, out_features)  # 修改為本任務對應的類別數目else:n_features = self.model.classifier.in_featuresself.model.classifier = nn.Linear(n_features, out_features)# resnet修改最后的全鏈接層print(self.model)  # 返回模型def forward(self, x):  # 前向傳播x = self.model(x)return x# 定義訓練流程
def train(train_loader, model, criterion, optimizer, epoch, params):metric_monitor = MetricMonitor()  # 設置指標監視器model.train()  # 模型設置為訓練模型nBatch = len(train_loader)stream = tqdm(train_loader)for i, (images, target) in enumerate(stream, start=1):  # 開始訓練images = images.to(params['device'], non_blocking=True)  # 加載數據target = target.to(params['device'], non_blocking=True)  # 加載模型output = model(images)  # 數據送入模型進行前向傳播loss = criterion(output, target.long())  # 計算損失f1_macro = calculate_f1_macro(output, target)  # 計算f1分數recall_macro = calculate_recall_macro(output, target)  # 計算recall分數acc = accuracy(output, target)  # 計算準確率分數metric_monitor.update('Loss', loss.item())  # 更新損失metric_monitor.update('F1', f1_macro)  # 更新f1metric_monitor.update('Recall', recall_macro)  # 更新recallmetric_monitor.update('Accuracy', acc)  # 更新準確率optimizer.zero_grad()  # 清空學習率loss.backward()  # 損失反向傳播optimizer.step()  # 更新優化器lr = adjust_learning_rate(optimizer, epoch, params, i, nBatch)  # 調整學習率stream.set_description(  # 更新進度條"Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch,metric_monitor=metric_monitor))return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]  # 返回結果# 定義驗證流程
def validate(val_loader, model, criterion, epoch, params):metric_monitor = MetricMonitor()  # 驗證流程model.eval()  # 模型設置為驗證格式stream = tqdm(val_loader)  # 設置進度條with torch.no_grad():  # 開始推理for i, (images, target) in enumerate(stream, start=1):images = images.to(params['device'], non_blocking=True)  # 讀取圖片target = target.to(params['device'], non_blocking=True)  # 讀取標簽output = model(images)  # 前向傳播loss = criterion(output, target.long())  # 計算損失f1_macro = calculate_f1_macro(output, target)  # 計算f1分數recall_macro = calculate_recall_macro(output, target)  # 計算recall分數acc = accuracy(output, target)  # 計算accmetric_monitor.update('Loss', loss.item())  # 后面基本都是更新進度條的操作metric_monitor.update('F1', f1_macro)metric_monitor.update("Recall", recall_macro)metric_monitor.update('Accuracy', acc)stream.set_description("Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch,metric_monitor=metric_monitor))return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]# 展示訓練過程的曲線
def show_loss_acc(acc, loss, val_acc, val_loss, sava_dir):# 從history中提取模型訓練集和驗證集準確率信息和誤差信息# 按照上下結構將圖畫輸出plt.figure(figsize=(8, 8))plt.subplot(2, 1, 1)plt.plot(acc, label='Training Accuracy')plt.plot(val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.ylabel('Accuracy')plt.ylim([min(plt.ylim()), 1])plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)plt.plot(loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.ylabel('Cross Entropy')plt.title('Training and Validation Loss')plt.xlabel('epoch')# 保存在savedir目錄下。save_path = osp.join(save_dir, "results.png")plt.savefig(save_path, dpi=100)if __name__ == '__main__':accs = []losss = []val_accs = []val_losss = []data_transforms = get_torch_transforms(img_size=params["img_size"])  # 獲取圖像預處理方式train_transforms = data_transforms['train']  # 訓練集數據處理方式valid_transforms = data_transforms['val']  # 驗證集數據集處理方式train_dataset = datasets.ImageFolder(params["train_dir"], train_transforms)  # 加載訓練集valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加載驗證集if params['pretrained'] == True:save_dir = osp.join(params['save_dir'], params['model']+"_pretrained_" + str(params["img_size"]))  # 設置模型保存路徑else:save_dir = osp.join(params['save_dir'], params['model'] + "_nopretrained_" + str(params["img_size"]))  # 設置模型保存路徑if not osp.isdir(save_dir):  # 如果保存路徑不存在的話就創建os.makedirs(save_dir)  #print("save dir {} created".format(save_dir))train_loader = DataLoader(  # 按照批次加載訓練集train_dataset, batch_size=params['batch_size'], shuffle=True,num_workers=params['num_workers'], pin_memory=True,)val_loader = DataLoader(  # 按照批次加載驗證集valid_dataset, batch_size=params['batch_size'], shuffle=False,num_workers=params['num_workers'], pin_memory=True,)print(train_dataset.classes)model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],pretrained=params['pretrained']) # 加載模型# model = nn.DataParallel(model)  # 模型并行化,提高模型的速度# resnet50d_1epochs_accuracy0.50424_weights.pthmodel = model.to(params['device'])  # 模型部署到設備上criterion = nn.CrossEntropyLoss().to(params['device'])  # 設置損失函數optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])  # 設置優化器# 損失函數和優化器可以自行設置修改。# criterion = nn.CrossEntropyLoss().to(params['device'])  # 設置損失函數# optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])  # 設置優化器best_acc = 0.0  # 記錄最好的準確率# 只保存最好的那個模型。for epoch in range(1, params['epochs'] + 1):  # 開始訓練acc, loss = train(train_loader, model, criterion, optimizer, epoch, params)val_acc, val_loss = validate(val_loader, model, criterion, epoch, params)accs.append(acc)losss.append(loss)val_accs.append(val_acc)val_losss.append(val_loss)if val_acc >= best_acc:# 保存的時候設置一個保存的間隔,或者就按照目前的情況,如果前面的比后面的效果好,就保存一下。# 按照間隔保存的話得不到最好的模型。save_path = osp.join(save_dir, f"{params['model']}_{epoch}epochs_accuracy{acc:.5f}_weights.pth")torch.save(model.state_dict(), save_path)best_acc = val_accshow_loss_acc(accs, losss, val_accs, val_losss, save_dir)print("訓練已完成,模型和訓練日志保存在: {}".format(save_dir))

運行結果:

  • 輸出模型結構(卷積層 / 池化層 / 全連接層)
  • 保存訓練曲線(acc.pngloss.png
  • 自動保存最優模型到指定目錄

六、模型測試與預測

6.1 測試代碼(test.py)

python

from torchutils import *
from torchvision import datasets, models, transforms
import os.path as osp
import os
from train import SELFMODELif torch.cuda.is_available():device = torch.device('cuda:0')
else:device = torch.device('cpu')
print(f'Using device: {device}')
# 固定隨機種子,保證實驗結果是可以復現的
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Truedata_path = "../flowers_data_split"  # todo 修改為數據集根目錄
model_path = "../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth"  # todo 模型地址
model_name = 'resnet50d'  # todo 模型名稱
img_size = 224  # todo 數據集訓練時輸入模型的大小
# 注: 執行之前請先劃分數據集
# 超參數設置
params = {# 'model': 'vit_tiny_patch16_224',  # 選擇預訓練模型# 'model': 'efficientnet_b3a',  # 選擇預訓練模型'model': model_name,  # 選擇預訓練模型"img_size": img_size,  # 圖片輸入大小"test_dir": osp.join(data_path, "test"),  # todo 測試集子目錄'device': device,  # 設備'batch_size': 4,  # 批次大小'num_workers': 0,  # 進程"num_classes": len(os.listdir(osp.join(data_path, "train"))),  # 類別數目, 自適應獲取類別數目
}def test(val_loader, model, params, class_names):metric_monitor = MetricMonitor()  # 驗證流程model.eval()  # 模型設置為驗證格式stream = tqdm(val_loader)  # 設置進度條# 對模型分開進行推理test_real_labels = []test_pre_labels = []with torch.no_grad():  # 開始推理for i, (images, target) in enumerate(stream, start=1):images = images.to(params['device'], non_blocking=True)  # 讀取圖片target = target.to(params['device'], non_blocking=True)  # 讀取標簽output = model(images)  # 前向傳播# loss = criterion(output, target.long())  # 計算損失# print(output)target_numpy = target.cpu().numpy()y_pred = torch.softmax(output, dim=1)y_pred = torch.argmax(y_pred, dim=1).cpu().numpy()test_real_labels.extend(target_numpy)test_pre_labels.extend(y_pred)# print(target_numpy)# print(y_pred)f1_macro = calculate_f1_macro(output, target)  # 計算f1分數recall_macro = calculate_recall_macro(output, target)  # 計算recall分數acc = accuracy(output, target)  # 計算acc# metric_monitor.update('Loss', loss.item())  # 后面基本都是更新進度條的操作metric_monitor.update('F1', f1_macro)metric_monitor.update("Recall", recall_macro)metric_monitor.update('Accuracy', acc)stream.set_description("mode: {epoch}.  {metric_monitor}".format(epoch="test",metric_monitor=metric_monitor))class_names_length = len(class_names)heat_maps = np.zeros((class_names_length, class_names_length))for test_real_label, test_pre_label in zip(test_real_labels, test_pre_labels):heat_maps[test_real_label][test_pre_label] = heat_maps[test_real_label][test_pre_label] + 1# print(heat_maps)heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)# print(heat_maps_sum)# print()heat_maps_float = heat_maps / heat_maps_sum# print(heat_maps_float)# title, x_labels, y_labels, harvestshow_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names, harvest=heat_maps_float,save_name="record/heatmap_{}.png".format(model_name))# 加上模型名稱return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['F1']["avg"], \metric_monitor.metrics['Recall']["avg"]def show_heatmaps(title, x_labels, y_labels, harvest, save_name):# 這里是創建一個畫布fig, ax = plt.subplots()# cmap https://blog.csdn.net/ztf312/article/details/102474190im = ax.imshow(harvest, cmap="OrRd")# 這里是修改標簽# We want to show all ticks...ax.set_xticks(np.arange(len(y_labels)))ax.set_yticks(np.arange(len(x_labels)))# ... and label them with the respective list entriesax.set_xticklabels(y_labels)ax.set_yticklabels(x_labels)# 因為x軸的標簽太長了,需要旋轉一下,更加好看# Rotate the tick labels and set their alignment.plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")# 添加每個熱力塊的具體數值# Loop over data dimensions and create text annotations.for i in range(len(x_labels)):for j in range(len(y_labels)):text = ax.text(j, i, round(harvest[i, j], 2),ha="center", va="center", color="black")ax.set_xlabel("Predict label")ax.set_ylabel("Actual label")ax.set_title(title)fig.tight_layout()plt.colorbar(im)plt.savefig(save_name, dpi=100)# plt.show()if __name__ == '__main__':data_transforms = get_torch_transforms(img_size=params["img_size"])  # 獲取圖像預處理方式# train_transforms = data_transforms['train']  # 訓練集數據處理方式valid_transforms = data_transforms['val']  # 驗證集數據集處理方式# valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加載驗證集# print(valid_dataset)test_dataset = datasets.ImageFolder(params["test_dir"], valid_transforms)class_names = test_dataset.classesprint(class_names)# valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加載驗證集test_loader = DataLoader(  # 按照批次加載訓練集test_dataset, batch_size=params['batch_size'], shuffle=True,num_workers=params['num_workers'], pin_memory=True,)# 加載模型model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],pretrained=False)  # 加載模型結構,加載模型結構過程中pretrained設置為False即可。weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 指標上的測試結果包含三個方面,分別是acc f1 和 recall, 除此之外,應該還有相應的熱力圖輸出,整體會比較好看一些。acc, f1, recall = test(test_loader, model, params, class_names)print("測試結果:")print(f"acc: {acc}, F1: {f1}, recall: {recall}")print("測試完成,heatmap保存在{}下".format("record"))

6.2 圖片預測(predict.py)

import torch
# from train_resnet import SelfNet
from train import SELFMODEL
import os
import os.path as osp
import shutil
import torch.nn as nn
from PIL import Image
from torchutils import get_torch_transformsif torch.cuda.is_available():device = torch.device('cuda')
else:device = torch.device('cpu')model_path = "../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth"  # todo  模型路徑
classes_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']  # todo 類名
img_size = 224  # todo 圖片大小
model_name = "resnet50d"  # todo 模型名稱
num_classes = len(classes_names)  # todo 類別數目def predict_batch(model_path, target_dir, save_dir):data_transforms = get_torch_transforms(img_size=img_size)valid_transforms = data_transforms['val']# 加載網絡model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)# model = nn.DataParallel(model)weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 讀取圖片image_names = os.listdir(target_dir)for i, image_name in enumerate(image_names):image_path = osp.join(target_dir, image_name)img = Image.open(image_path)img = valid_transforms(img)img = img.unsqueeze(0)img = img.to(device)output = model(img)label_id = torch.argmax(output).item()predict_name = classes_names[label_id]save_path = osp.join(save_dir, predict_name)if not osp.isdir(save_path):os.makedirs(save_path)shutil.copy(image_path, save_path)print(f"{i + 1}: {image_name} result {predict_name}")def predict_single(model_path, image_path):data_transforms = get_torch_transforms(img_size=img_size)# train_transforms = data_transforms['train']valid_transforms = data_transforms['val']# 加載網絡model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)# model = nn.DataParallel(model)weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 讀取圖片img = Image.open(image_path)img = valid_transforms(img)img = img.unsqueeze(0)img = img.to(device)output = model(img)label_id = torch.argmax(output).item()predict_name = classes_names[label_id]print(f"{image_path}'s result is {predict_name}")if __name__ == '__main__':# 批量預測函數predict_batch(model_path=model_path,target_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini",save_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini_result")# 單張圖片預測函數# predict_single(model_path=model_path, image_path="images/test_imgs/506659320_6fac46551e.jpg")

七、模型結構與參數量查看

7.1 查看模型結構(Netron 工具)

  1. 將模型轉換為 ONNX 格式(代碼utils/export_onnx.py):
import numpy as np
import onnxruntime
from PIL import Imageclass_names = {'0': '雛菊', '1': '蒲公英', '2': '玫瑰', '3': '向日葵', '4': '郁金香'}# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,標準差
# 預測圖片
session = onnxruntime.InferenceSession(r"C:\Users\nongc\Desktop\ImageClassifier.onnx")def process_image(image_path):# 讀取測試數據img = Image.open(image_path)# Resize,thumbnail方法只能進行縮小,所以進行了判斷if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width - 224) / 2bottom_margin = (img.height - 224) / 2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin,top_margin))# img.save('thumb.jpg')# 相同的預處理方法img = np.array(img) / 255mean = np.array([0.485, 0.456, 0.406])  # provided meanstd = np.array([0.229, 0.224, 0.225])  # provided stdimg = (img - mean) / std# 注意顏色通道應該放在第一個位置img = img.transpose((2, 0, 1))return imgimage_path = r"C:\Users\nongc\Desktop\百度云下載\2023_pytorch110_classification_42-master\2023_pytorch110_classification_42-master\flowers_5\roses\99383371_37a5ac12a3_n.jpg"  # '1':
img = process_image(image_path)
img = np.expand_dims(img, 0)outputs = session.run([], {"modelInput": img.astype('float32')})
result_index = int(np.argmax(np.squeeze(outputs)))
result = class_names['%d' % result_index]  # 獲得對應的名稱print(np.squeeze(outputs), '\n', img.shape)
print(f"預測種類為: {result} 對應索引為:{np.argmax(np.squeeze(outputs))}")
# print(np.min(outputs),np.argmin(np.squeeze(outputs)),np.max(outputs))

?打開Netron 官網,拖入resnet50.onnx即可可視化模型結構。

7.2 計算參數量(get_flops.py)

import torch
from torchstat import stat
from train import SELFMODELif torch.cuda.is_available():device = torch.device('cuda')
else:device = torch.device('cpu')
model_name = "resnet50d" # todo 模型名稱
num_classes = 5 # todo 類別數目
model_path = "../../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth" # todo 模型地址
model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)
weights = torch.load(model_path)
model.load_state_dict(weights)
model.eval()
stat(model, (3, 224, 224)) # 后面的224表示模型的輸入大小

八、總結

本文覆蓋了深度學習項目的核心流程:數據獲取→清洗→劃分→訓練→測試→預測,并提供了可直接運行的代碼和詳細操作說明。對于小白來說,建議先從簡單數據集(如花卉分類)入手,逐步熟悉每個環節,遇到問題可參考代碼中的TODO注釋和報錯信息排查。

?

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/84074.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/84074.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/84074.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

React Flow 邊事件處理實戰:鼠標事件、鍵盤操作及連接規則設置(附完整代碼)

本文為《React Agent:從零開始構建 AI 智能體》專欄系列文章。 專欄地址:https://blog.csdn.net/suiyingy/category_12933485.html。項目地址:https://gitee.com/fgai/react-agent(含完整代碼示?例與實戰源)。完整介紹…

java小結(一)

java(上) 模塊一 1.JDK,JRE,JVM 知識點 核心內容 易混淆點 JDK定義 Java Development Kit(Java開發工具包),包含開發所需全部工具 JDK包含JRE的關系容易混淆 JRE定義 Java Runtime Environment(Jav…

ddns-go安裝介紹-強大的ipv6動態域名解析神器-家庭云計算專家

ddns-go 是一款輕量級開源動態域名解析工具,專注于解決動態IP環境下的域名綁定問題,尤其適配IPv6網絡環境。其核心功能包括: 1.IPv6動態解析:自動檢測本地IPv6地址變化(支持網卡、接口或命令獲取)&#xf…

Docker-mongodb

拉取 MongoDB 鏡像: docker pull mongo 創建容器并設置用戶: 要掛載本地數據目錄,請替換此路徑: /Users/Allen/Env/AllenDocker/mongodb/data/db docker run -d --name local-mongodb \-e MONGO_INITDB_ROOT_USERNAMEadmin \-e MONGO_INITDB_ROOT_PA…

WooCommerce緩存教程 – 如何防止緩存破壞你的WooCommerce網站?

我們在以前的文章中探討過如何加快你的WordPress網站的速度,并研究過各種形式的緩存。 然而,像那些使用WooCommerce的動態電子商務網站,在讓緩存正常工作方面往往會面臨重大挑戰。 在本指南中,我們將告訴你如何為WooCommerce設置…

貪心算法 Part04

總結下重疊區間問題 LC 452. 用最少數量的箭引爆氣球 和 LC 435. 無重疊區間 本質上是一樣的。 LC 452. 用最少數量的箭引爆氣球 是求n個區間當中 , 區間的種類數量 k。此處可以理解為,重疊在一起的區間屬于同一品種,沒有重疊的區間當然…

云原生CD工具-Argocd+ArgoRollout入門到精通

第一章 Argo CD簡介 課時1.1 Argo產品介紹 ARGO官網地址:https://argoproj.github.io/ 旗下產品有: Argo Workflows、ArgoCD 、Argo Rollouts 、Argo Events 課時1.2 什么是Argo CD Argo CD 是一個開源的持續交付工具, 是 Kubernetes 的聲明式 GitOps 持續交付工具。專…

數據分析與應用---數據可視化基礎

目錄 Matplotlib基礎繪圖 (一)、pyplot繪圖基礎語法與常用參數 1、pyplot基礎語法 (1) 創建畫布與創建子圖 (2) 添加畫布內容 (3) 保存與顯示圖形 案例代碼 2. 設置pyplot的動態rc參數 (二)、使用Matplotlib繪制進階圖形 1. 繪制散點圖----scatter 2. 繪制折線…

PP-YOLOE-SOD學習筆記1

項目:基于PP-YOLOE-SOD的無人機航拍圖像檢測案例全流程實操 - 飛槳AI Studio星河社區 一、安裝環境 先準備新環境py>3.9 1.先cd到源代碼的根目錄下 2.pip install -r requirements.txt 3.python setup.py install 這一步需要看自己的GPU情況,去飛漿…

力扣HOT100之二叉樹:114. 二叉樹展開為鏈表

這道題自己嘗試著做了一下,感覺還是得用遞歸來做比較簡單,但是一直想的是用前序遍歷來構造鏈表,導致怎么做都不對,去看了下靈神的題解,然后問了下GPT,現在終于弄明白了。雖然構造出來的鏈表的排列順序是按照…

Spring Boot 注解 @ConditionalOnMissingBean是什么

一句話總結: ConditionalOnMissingBean 是 Spring Boot 提供的一個 條件注解(Conditional Annotation),意思是: 只有當 Spring 容器中 不存在 某個 Bean 時,當前的 Bean 或配置才會被加載。 這是一種典型的…

PyInstaller 如何在mac電腦上生成在window上可執行的exe文件

PyInstaller跨平臺打包限制 PyInstaller 無法直接從macOS生成Windows可執行文件,因為它需要訪問目標平臺的系統庫和Python環境來構建可執行文件。要在macOS上為Windows打包Python應用,需要通過以下方法之一: 方法一:使用虛擬機或…

零基礎設計模式——創建型模式 - 抽象工廠模式

第二部分:創建型模式 - 抽象工廠模式 (Abstract Factory Pattern) 我們已經學習了單例模式(保證唯一實例)和工廠方法模式(延遲創建到子類)。現在,我們來探討創建型模式中更為復雜和強大的一個——抽象工廠…

【通用智能體】Serper API 詳解:搜索引擎數據獲取的核心工具

Serper API 詳解:搜索引擎數據獲取的核心工具 一、Serper API 的定義與核心功能二、技術架構與核心優勢2.1 技術實現原理2.2 對比傳統方案的突破性優勢 三、典型應用場景與代碼示例3.1 SEO 監控系統3.2 競品廣告分析 四、使用成本與配額策略五、開發者注意事項六、替…

Flask-SQLAlchemy核心概念:模型類與數據庫表、類屬性與表字段、外鍵與關系映射

前置閱讀,關于Flask-SQLAlchemy支持哪些數據庫及基本配置,鏈接:Flask-SQLAlchemy_數據庫配置 摘要 本文以一段典型的 SQLAlchemy 代碼示例為引入,闡述以下核心概念: 模型類(Model Class) ? 數…

野火魯班貓(arrch64架構debian)從零實現用MobileFaceNet算法進行實時人臉識別(四)安裝RKNN Toolkit2

RKNN Toolkit2是用來將onnx模型轉成rknn專用模型,并可通過RKNN Toolkit Lite2或者RKNPU調用NPU進行加速計算的工具。 一開始我安裝很多次都無法成功安裝。后來跟售后技術對接,必須是PC平臺的Linux環境才可以。我的電腦是windows,所以我需要用…

基于深度學習的工件檢測系統設計與實現

在工業自動化領域,工件檢測一直是提高生產效率和產品質量的關鍵環節。傳統的人工檢測方法不僅效率低下,而且容易受到主觀因素的影響,導致誤判率較高。隨著深度學習技術的飛速發展,基于圖像識別的自動檢測系統逐漸成為研究熱點。今…

CyberSecAsia專訪CertiK首席安全官:區塊鏈行業亟需“安全優先”開發范式

近日,權威網絡安全媒體CyberSecAsia發布了對CertiK首席安全官Wang Tielei博士的專訪,雙方圍繞企業在進軍區塊鏈領域時所面臨的關鍵安全風險與防御策略展開深入探討。 Wang博士在采訪中指出,跨鏈橋攻擊、智能合約漏洞以及私鑰管理不當&#x…

Google C++ Style Guide 谷歌 C++編碼風格指南,深入理解華為與谷歌的編程規范——C和C++實踐指南

Google C 編程風格指南 Release Apr 07, 2017 0. ?享 ?? 4.45 ??? Benjy Weinberger, Craig Silverstein, Gregory Eitzmann, Mark Mentovai, Tashana Landray ?? YuleFox, Yang.Y, acgtyrant, lilinsanity 亯??享 ? Google Style Guide ? Google 開源…

當科技邂逅浪漫:在Codigger的世界里,遇見“愛”

520,一個充滿愛意的日子,人們用各種方式表達對彼此的深情。而在科技的世界里,我們也正經歷著一場特別的邂逅——Codigger,一個分布式操作系統的誕生,正在以它獨特的方式,重新定義我們與技術的關系。 Codigg…