Python打卡第51天

@浙大疏錦行

作業:

day43的時候我們安排大家對自己找的數據集用簡單cnn訓練,現在可以嘗試下借助這幾天的知識來實現精度的進一步提高

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import torch.nn.functional as F
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import random# 設置隨機種子確保結果可復現
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 數據集路徑
data_dir = r"D:\archive (1)\MY_data"# 數據預處理和增強
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載數據集
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_transform)# 創建數據加載器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)# 獲取類別名稱
classes = train_dataset.classes
print(f"類別: {classes}")# CBAM注意力機制實現
class ChannelAttention(nn.Module):def __init__(self, in_channels, reduction_ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),nn.ReLU(),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x_cat = torch.cat([avg_out, max_out], dim=1)out = self.conv(x_cat)return self.sigmoid(out)class CBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):super(CBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):x = x * self.channel_attention(x)x = x * self.spatial_attention(x)return x# 定義改進的CNN模型(支持多種預訓練模型和CBAM注意力機制)
class EnhancedFruitClassifier(nn.Module):def __init__(self, num_classes=10, model_name='resnet18', use_cbam=True):super(EnhancedFruitClassifier, self).__init__()self.use_cbam = use_cbam# 根據選擇加載不同的預訓練模型if model_name == 'resnet18':self.model = models.resnet18(pretrained=True)in_features = self.model.fc.in_features# 保存原始層以便后續使用self.features = nn.Sequential(*list(self.model.children())[:-2])self.avgpool = self.model.avgpoolelif model_name == 'resnet50':self.model = models.resnet50(pretrained=True)in_features = self.model.fc.in_featuresself.features = nn.Sequential(*list(self.model.children())[:-2])self.avgpool = self.model.avgpoolelif model_name == 'efficientnet_b0':self.model = models.efficientnet_b0(pretrained=True)in_features = self.model.classifier[1].in_featuresself.features = nn.Sequential(*list(self.model.children())[:-1])self.avgpool = nn.AdaptiveAvgPool2d(1)else:raise ValueError(f"不支持的模型: {model_name}")# 凍結大部分預訓練層for param in list(self.model.parameters())[:-5]:param.requires_grad = False# 添加CBAM注意力機制if use_cbam:self.cbam = CBAM(in_features)# 修改最后一層以適應我們的分類任務self.fc = nn.Linear(in_features, num_classes)def forward(self, x):# 特征提取x = self.features(x)# 應用CBAM注意力機制if self.use_cbam:x = self.cbam(x)# 全局池化x = self.avgpool(x)x = torch.flatten(x, 1)# 分類x = self.fc(x)return x# 初始化模型 - 可以選擇不同的預訓練模型和是否使用CBAM
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = EnhancedFruitClassifier(num_classes=len(classes),model_name='resnet18',  # 可選: 'resnet18', 'resnet50', 'efficientnet_b0'use_cbam=True
).to(device)# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)# 訓練模型
def train_model(model, train_loader, criterion, optimizer, scheduler, device, epochs=10):model.train()for epoch in range(epochs):running_loss = 0.0correct = 0total = 0progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))for i, (inputs, labels) in progress_bar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()progress_bar.set_description(f"Epoch {epoch+1}/{epochs}, "f"Loss: {running_loss/(i+1):.4f}, "f"Acc: {100.*correct/total:.2f}%")scheduler.step()print(f"Epoch {epoch+1}/{epochs}, "f"Train Loss: {running_loss/len(train_loader):.4f}, "f"Train Acc: {100.*correct/total:.2f}%")return model# 評估模型
def evaluate_model(model, test_loader, device):model.eval()correct = 0total = 0class_correct = list(0. for i in range(len(classes)))class_total = list(0. for i in range(len(classes)))with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 計算每個類別的準確率for i in range(len(labels)):label = labels[i]class_correct[label] += (predicted[i] == label).item()class_total[label] += 1print(f"測試集整體準確率: {100.*correct/total:.2f}%")# 打印每個類別的準確率for i in range(len(classes)):if class_total[i] > 0:print(f"{classes[i]} 類別的準確率: {100.*class_correct[i]/class_total[i]:.2f}%")else:print(f"{classes[i]} 類別的樣本數為0")return 100.*correct/total# Grad-CAM實現
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注冊鉤子self.hook_handles = []# 保存梯度的鉤子def backward_hook(module, grad_in, grad_out):self.gradients = grad_out[0]return None# 保存激活值的鉤子def forward_hook(module, input, output):self.activations = outputreturn Noneself.hook_handles.append(target_layer.register_forward_hook(forward_hook))self.hook_handles.append(target_layer.register_backward_hook(backward_hook))def __call__(self, x, class_idx=None):# 前向傳播model_output = self.model(x)if class_idx is None:class_idx = torch.argmax(model_output, dim=1)# 構建one-hot向量one_hot = torch.zeros_like(model_output)one_hot[0, class_idx] = 1# 反向傳播self.model.zero_grad()model_output.backward(gradient=one_hot, retain_graph=True)# 計算權重(全局平均池化梯度)weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)# 加權組合激活映射cam = torch.sum(weights * self.activations, dim=1).squeeze()# ReLU激活,因為我們只關心對類別有正貢獻的區域cam = F.relu(cam)# 歸一化if torch.max(cam) > 0:cam = cam / torch.max(cam)# 調整大小到輸入圖像尺寸cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False).squeeze()return cam.detach().cpu().numpy(), class_idx.item()def remove_hooks(self):for handle in self.hook_handles:handle.remove()# 可視化Grad-CAM結果
def visualize_gradcam(img_path, model, target_layer, classes, device):# 加載并預處理圖像img = Image.open(img_path).convert('RGB')img_tensor = test_transform(img).unsqueeze(0).to(device)# 初始化Grad-CAMgrad_cam = GradCAM(model, target_layer)# 獲取Grad-CAM熱力圖cam, pred_class = grad_cam(img_tensor)# 反歸一化圖像以便顯示img_np = img_tensor.squeeze().cpu().numpy().transpose((1, 2, 0))img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])img_np = np.clip(img_np, 0, 1)# 調整熱力圖大小heatmap = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))# 創建彩色熱力圖heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)heatmap = np.float32(heatmap) / 255# 疊加原始圖像和熱力圖superimposed_img = heatmap * 0.4 + img_npsuperimposed_img = np.clip(superimposed_img, 0, 1)# 顯示結果plt.figure(figsize=(15, 5))plt.subplot(131)plt.imshow(img_np)plt.title('原始圖像')plt.axis('off')plt.subplot(132)plt.imshow(cam, cmap='jet')plt.title('Grad-CAM熱力圖')plt.axis('off')plt.subplot(133)plt.imshow(superimposed_img)plt.title(f'疊加圖像\n預測類別: {classes[pred_class]}')plt.axis('off')plt.tight_layout()plt.show()# 預測函數
def predict_image(img_path, model, classes, device):# 加載并預處理圖像img = Image.open(img_path).convert('RGB')img_tensor = test_transform(img).unsqueeze(0).to(device)# 預測model.eval()with torch.no_grad():outputs = model(img_tensor)probs = F.softmax(outputs, dim=1)top_probs, top_classes = probs.topk(5, dim=1)# 打印預測結果print(f"圖像: {os.path.basename(img_path)}")print("預測結果:")for i in range(top_probs.size(1)):print(f"{classes[top_classes[0, i]]}: {top_probs[0, i].item() * 100:.2f}%")return top_classes[0, 0].item()# 主函數
def main():# 訓練模型print("開始訓練模型...")trained_model = train_model(model, train_loader, criterion, optimizer, scheduler, device, epochs=5)# 評估模型print("\n評估模型...")evaluate_model(trained_model, test_loader, device)# 保存模型model_path = "fruit_classifier.pth"torch.save(trained_model.state_dict(), model_path)print(f"\n模型已保存至: {model_path}")# 可視化Grad-CAM結果print("\n可視化Grad-CAM結果...")# 從測試集中隨機選擇幾張圖像進行可視化predict_dir = os.path.join(data_dir, 'predict')if os.path.exists(predict_dir):# 使用predict目錄中的圖像image_files = [os.path.join(predict_dir, f) for f in os.listdir(predict_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]if len(image_files) > 0:# 隨機選擇2張圖像sample_images = random.sample(image_files, min(2, len(image_files)))for img_path in sample_images:print(f"\n處理圖像: {img_path}")# 預測圖像類別pred_class = predict_image(img_path, trained_model, classes, device)# 可視化Grad-CAMif hasattr(trained_model, 'model') and hasattr(trained_model.model, 'layer4'):# 對于ResNet系列模型visualize_gradcam(img_path, trained_model, trained_model.model.layer4[-1].conv2, classes, device)else:# 對于其他模型,使用最后一個特征層visualize_gradcam(img_path, trained_model, list(trained_model.features.children())[-1], classes, device)else:print(f"predict目錄為空,無法進行可視化")else:print(f"predict目錄不存在,無法進行可視化")if __name__ == "__main__":main()
類別: ['Apple', 'Banana', 'avocado', 'cherry', 'kiwi', 'mango', 'orange', 'pinenapple', 'strawberries', 'watermelon']
開始訓練模型...
Epoch 1/5, Loss: 0.8748, Acc: 74.23%: 100%|██████████| 72/72 [00:08<00:00,  8.66it/s]
Epoch 1/5, Train Loss: 0.8748, Train Acc: 74.23%
Epoch 2/5, Loss: 0.4802, Acc: 83.83%: 100%|██████████| 72/72 [00:07<00:00, 10.02it/s]
Epoch 2/5, Train Loss: 0.4802, Train Acc: 83.83%
Epoch 3/5, Loss: 0.4239, Acc: 86.35%: 100%|██████████| 72/72 [00:07<00:00,  9.69it/s]
Epoch 3/5, Train Loss: 0.4239, Train Acc: 86.35%
Epoch 4/5, Loss: 0.4179, Acc: 85.96%: 100%|██████████| 72/72 [00:07<00:00,  9.64it/s]
Epoch 4/5, Train Loss: 0.4179, Train Acc: 85.96%
Epoch 5/5, Loss: 0.3747, Acc: 87.44%: 100%|██████████| 72/72 [00:07<00:00,  9.68it/s]
Epoch 5/5, Train Loss: 0.3747, Train Acc: 87.44%評估模型...
測試集整體準確率: 66.83%
Apple 類別的準確率: 80.90%
Banana 類別的準確率: 0.00%
avocado 類別的準確率: 1.89%
cherry 類別的準確率: 93.33%
kiwi 類別的準確率: 93.33%
mango 類別的準確率: 48.57%
orange 類別的準確率: 97.94%
pinenapple 類別的準確率: 96.19%
strawberries 類別的準確率: 90.29%
watermelon 類別的準確率: 71.43%模型已保存至: fruit_classifier.pth可視化Grad-CAM結果...處理圖像: D:\archive (1)\MY_data\predict\img_341.jpeg
圖像: img_341.jpeg
預測結果:
mango: 90.52%
orange: 3.99%
kiwi: 2.45%
avocado: 1.86%
Apple: 0.98%

處理圖像: D:\archive (1)\MY_data\predict\1.jpeg
圖像: 1.jpeg
預測結果:
Apple: 95.86%
cherry: 2.94%
Banana: 0.71%
avocado: 0.24%
strawberries: 0.19%

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84538.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84538.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84538.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Notepad++ 官方下載

https://notepad-plus-plus.org/downloads/ 下載官網 1、https://github.com/notepad-plus-plus/notepad-plus-plus/releases 2、https://notepad-plus-plus.org/news/v881-we-are-with-ukraine/

運維之十個問題--2

目錄 1. 如果有ip惡意刷流量怎么辦 2. 標準端口范圍 3.內存16G&#xff0c;交換分區多大 4.請簡述非對稱加密算法&#xff0c;ping命令通過什么協議實現&#xff0c;icmp是什么協議 5.客戶訪問網站速度慢原因 6. 進程和線程的區別 7.zabbix監控是你搭建的嗎&#xff0c;平…

vue前端面試題——記錄一次面試當中遇到的題(1)

1.v-if和v-show的區別 v-if和v-show都是Vue中用于條件渲染的指令&#xff0c;但它們的實現機制和適用場景有所不同&#xff1a; v-if是真正的條件渲染&#xff0c;在條件切換時會銷毀和重建DOM元素&#xff0c;適合運行時條件變化不頻繁的場景&#xff1b; v-show只是通過CS…

【QT面試題】(三)

文章目錄 Qt信號槽的優點及缺點Qt中的文件流和數據流區別&#xff1f;Qt中show和exec區別QT多線程使用的方法 (4種)QString與基本數據類型如何轉換&#xff1f;QT保證多線程安全事件與信號的區別connect函數的連接方式&#xff1f;信號與槽的多種用法Qt的事件過濾器有哪些同步和…

Vscode下Go語言環境配置

前言 本文介紹了vscode下Go語言開發環境的快速配置&#xff0c;為新手小白快速上手Go語言提供幫助。 1.下載官方Vscode 這步比較基礎&#xff0c;已經安裝好的同學可以直接快進到第二步 官方安裝包地址&#xff1a;https://code.visualstudio.com/ 雙擊一直點擊下一步即可,記…

HTML 文本省略號

目錄 HTML 文本省略號超行省略號如何實現1. 單行文本溢出顯示省略號2. 多行文本溢出顯示省略號方法一&#xff1a;使用 -webkit-line-clamp&#xff08;推薦&#xff09;方法二&#xff1a;使用偽元素&#xff08;兼容性好&#xff09;方法三&#xff1a;使用 JavaScript 動態監…

Spring Boot 實現流式響應(兼容 2.7.x)

在實際開發中&#xff0c;我們可能會遇到一些流式數據處理的場景&#xff0c;比如接收來自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 內容&#xff0c;并將其原樣中轉給前端頁面或客戶端。這種情況下&#xff0c;傳統的 RestTemplate 緩存機制會…

ffmpeg 新版本轉碼設置幀率上限

ffmpeg 新版本轉碼設置幀率上限 ffmpeg 在老版本比如 4.3的時候&#xff0c;轉碼設置幀率上限是通過vsync控制 # 設置動態控制最大幀率60 "-vsync 2 -r 60" 新版本這個參數沒辦法動態判斷控制幀率了 替換為使用filter中的fps進行設置 # 設置動態幀率最大60幀 -…

Qt繪制電池圖標源碼分享

一、效果展示 二、源碼分享 cell.h #ifndef CELL_WIDGET_H #define CELL_WIDGET_H #include <QWidget> #include <QPainter> #include <QPaintEngine> #include <QPaintEvent>/* 電池控件類 */ class CellWidget : public QWidget {Q_OBJECTQ_PROPERTY…

安卓基礎(生成APK)

??生成調試版&#xff08;Debug&#xff09;?? Build → Build Bundle(s)/APK(s) → Build APK輸出路徑&#xff1a;app/build/outputs/apk/debug/app-debug.apk ??生成發布版&#xff08;Release&#xff09;?? Build → Generate Signed Bundle/APK → 選擇 ??APK?…

如何在 TypeScript 中使用類型保護

前言 類型保護是一種 TypeScript 技術&#xff0c;用于獲取變量類型的信息&#xff0c;通常用于條件塊中。類型保護是返回布爾值的常規函數??&#xff0c;它接受一個類型并告知 TypeScript 是否可以將其縮小到更具體的值。類型保護具有獨特的屬性&#xff0c;可以根據返回的…

山東大學軟件學院項目實訓-基于大模型的模擬面試系統-面試對話標題自動總結

面試對話標題自動總結 主要實現思路&#xff1a;每當AI回復用戶之后&#xff0c;調用方法查看當前對話是否大于三條&#xff0c;如果大于則將用戶的兩條和AI回復的一條對話傳給DeepSeek讓其進行總結&#xff08;后端&#xff09;&#xff0c;總結后調用updateChatTopic進行更新…

Spring Cloud與Alibaba微服務架構全解析

Spring Cloud與Spring Cloud Alibaba微服務架構解析 1. Spring Boot概念 Spring Boot并不是新技術&#xff0c;而是基于Spring框架下“約定優于配置”理念的產物。它幫助開發者更容易、更快速地創建獨立運行和產品級別的基于Spring框架的應用。Spring Boot中并沒有引入新技術…

AI 賦能 Java 開發:從通宵達旦到高效交付的蛻變之路

作為一名深耕 Java 開發領域多年的從業者&#xff0c;相信很多同行都與我有過相似的經歷&#xff1a;在 “996” 甚至 “007” 的高壓模式下&#xff0c;被反復修改的需求、復雜的架構設計、無休止的代碼編寫&#xff0c;以及部署時層出不窮的問題折磨得疲憊不堪。長期以來&…

06. C#入門系列【自定義類型】:從青銅到王者的進階之路

C#入門系列【自定義類型】&#xff1a;從青銅到王者的進階之路 一、引言&#xff1a;為什么需要自定義類型&#xff1f; 在C#的世界里&#xff0c;系統自帶的類型&#xff08;如int、string、bool&#xff09;就像是基礎武器&#xff0c;能解決一些簡單問題。但當你面對復雜的…

使用 PyTorch 和 TensorBoard 實時可視化模型訓練

在這個教程中&#xff0c;我們將使用 PyTorch 訓練一個簡單的多層感知機&#xff08;MLP&#xff09;模型來解決 MNIST 手寫數字分類問題&#xff0c;并且使用 TensorBoard 來可視化訓練過程中的不同信息&#xff0c;如損失、準確度、圖像、參數分布和學習率變化。 步驟 1&…

第十五章 15.OSPF(CCNA)

第十五章 15.OSPF(CCNA) 介紹了大家都能用的OSPF動態路由協議 注釋&#xff1a; 學習資源是B站的CCNA by Sean_Ning CCNA 最新CCNA 200-301 視頻教程(含免費實驗環境&#xff09; PS&#xff1a;喜歡的可以去買下他的課程&#xff0c;不貴&#xff0c;講的很細 To be cont…

手機連接windows遇到的問題及解決方法

文章目錄 寫在前面一、手機與windows 連接以后 無法在win端打開手機屏幕,提示801方法零、檢查連接方法一、系統修復方法二、斷開重連方法三、軟件更新方法四、關閉防火墻 寫在前面 本文主要記錄所遇到的問題以及解決方案&#xff0c;以備后用。 所用機型&#xff1a;win11 專業…

Spring Boot + MyBatis Plus 項目中,entity和 XML 映射文件的查找機制

在 Spring Boot MyBatis - Plus 項目中&#xff0c;entity&#xff08;實體類&#xff09;和 XML 映射文件的查找機制有其默認規則&#xff0c;也可通過配置調整&#xff0c;以下詳細說明&#xff1a; 一、實體類&#xff08;entity&#xff09;的查找 MyBatis - Plus 能找到…

itvbox綠豆影視tvbox手機版影視APP源碼分享搭建教程

我們先來看看今天的主題&#xff0c;tvbox手機版&#xff0c;然后再看看如何搭建&#xff1a; 很多愛好者都希望搭建自己的影視平臺&#xff0c;那該如何搭建呢&#xff1f; 后端開發環境&#xff1a; 1.易如意后臺管理優化版源碼&#xff1b; 2.寶塔面板&#xff1b; 3.ph…