DAY43打卡

@浙大疏錦行

kaggle找到一個圖像數據cnn網絡進行訓練并且grad-cam可視化

進階并拆分成多個文件

fruit_cnn_project/
├─ data/                # 存放數據集(需手動創建,后續放入圖片)
│  ├─ train/            # 訓練集圖像
│  └─ val/              # 驗證集圖像
├─ models/              # 模型定義
│  └─ cnn_model.py      # CNN網絡結構
├─ utils/               # 工具函數
│  ├─ dataset_utils.py  # 數據加載與預處理
│  ├─ grad_cam.py       # Grad-CAM可視化
│  └─ train_utils.py    # 訓練與評估
├─ main.py              # 主程序
└─ requirements.txt     # 依賴列表(可選)
# 第一部分:導入庫
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline# 第二部分:數據加載與預處理
def load_data():data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)return train_loader, test_loader# 第三部分:模型定義
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.fc1 = nn.Linear(32 * 56 * 56, 128)self.fc2 = nn.Linear(128, 2)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 32 * 56 * 56)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 第四部分:模型訓練
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')torch.save(model.state_dict(), 'trained_model.pth')# 第五部分:模型測試
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the test images: {100 * correct / total}%')# 第六部分:Grad-CAM可視化(修復版)
def get_activation():activation = {}def hook(model, input, output):activation['target_layer'] = output.detach()return hook, activationdef grad_cam(model, image, target_class_index):hook, activation = get_activation()target_layer = model.conv2target_layer.register_forward_hook(hook)model.eval()image = image.unsqueeze(0)image.requires_grad_(True)output = model(image)one_hot = torch.zeros(1, output.size()[-1]).to(image.device)one_hot[0][target_class_index] = 1output.backward(gradient=one_hot, retain_graph=True)gradients = image.grad[0].cpu().numpy()# 從activation字典中獲取激活圖activation_map = activation['target_layer'].cpu().numpy()[0]weights = np.mean(gradients, axis=(1, 2))cam = np.zeros(activation_map.shape[1:], dtype=np.float32)for i, w in enumerate(weights):cam += w * activation_map[i]cam = np.maximum(cam, 0)cam = F.interpolate(torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)[0][0].numpy()cam = (cam - cam.min()) / (cam.max() - cam.min())return cam# 可視化前幾張測試圖片
dataiter = iter(test_loader)
images, labels = dataiter.next()for i in range(5):  # 可視化前5張圖片image = images[i]label = labels[i].item()cam = grad_cam(model, image, label)plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.imshow(image.permute(1, 2, 0).numpy())plt.title(f'Original Image (Class: {label})')plt.axis('off')plt.subplot(1, 2, 2)plt.imshow(image.permute(1, 2, 0).numpy())plt.imshow(cam, cmap='jet', alpha=0.5)plt.title('Grad-CAM Visualization')plt.axis('off')plt.tight_layout()plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/83327.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/83327.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/83327.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[藍橋杯C++ 2024 國 B ] 立定跳遠(二分)

題目描述 在運動會上,小明從數軸的原點開始向正方向立定跳遠。項目設置了 n n n 個檢查點 a 1 , a 2 , ? , a n a_1, a_2, \cdots , a_n a1?,a2?,?,an? 且 a i ≥ a i ? 1 > 0 a_i \ge a_{i?1} > 0 ai?≥ai?1?>0。小明必須先后跳躍到每個檢查…

LINUX530 rsync定時同步 環境配置

rsync定時代碼同步 環境配置 關閉防火墻 selinux systemctl stop firewalld systemctl disable firewalld setenforce 0 vim /etc/selinux/config SELINUXdisable設置主機名 hostnamectl set-hostname code hostnamectl set-hostname backup設置靜態地址 cd /etc/sysconfi…

鴻蒙OSUniApp結合機器學習打造智能圖像分類應用:HarmonyOS實踐指南#三方框架 #Uniapp

UniApp結合機器學習打造智能圖像分類應用:HarmonyOS實踐指南 引言 在移動應用開發領域,圖像分類是一個既經典又充滿挑戰的任務。隨著機器學習技術的發展,我們現在可以在移動端實現高效的圖像分類功能。本文將詳細介紹如何使用UniApp結合Ten…

【Redis】大key問題詳解

目錄 1、什么是大key2、大key的危害【1】阻塞風險【2】網絡阻塞【3】內存不均【4】持久化問題 3、如何發現大key【1】使用內置命令【2】使用memory命令(Redis 4.0)【3】使用scan命令【4】監控工具 4、解決方案【1】拆分大key【2】使用合適的數據結構【3】…

redis核心知識點

Redis是一種基于內存的數據庫,對數據的讀寫操作都是在內存中完成,因此讀寫速度非常快,常用于緩存,消息隊列、分布式鎖等場景。 Redis 提供了多種數據類型來支持不同的業務場景,比如 String(字符串)、Hash(哈希)、 Lis…

vscode不滿足先決條件問題的解決——vscode的老版本安裝與禁止更新(附安裝包)

目錄 起因 vscode更新設置的關閉 安裝包 結語 起因 由于主包用的系統是centos的,且版本有點老了,再加上vscode現在不支持老版本的,這對主包來說更是雪上加霜啊 但是主包看了網上很多教程,眼花繚亂,好多配置要改&…

如何成為一名優秀的產品經理(自動駕駛)

一、 夯實核心基礎 深入理解智能駕駛技術棧: 感知: 攝像頭、雷達(毫米波、激光雷達)、超聲波傳感器的工作原理、優缺點、融合策略。了解目標檢測、跟蹤、SLAM等基礎算法概念。 定位: GNSS、IMU、高精地圖、輪速計等定…

【ISAQB大綱解讀】信息隱藏指的是什么

在軟件架構中,信息隱藏(Information Hiding) 是核心設計原則之一,由 David Parnas 在 1972 年提出。它強調通過限制對模塊內部實現細節的訪問,來降低系統復雜度、提高可維護性和可擴展性。在 ISAQB 的學習目標&#xf…

網頁前端開發(基礎進階2--JS)

前面學習了html與css,接下來學習JS(JavaScript與Java無關)。 web標準(網頁標準)分為3個部分: 1.html主要負責網頁的結構(頁面的元素和內容) 2.css主要負責網頁的表現(…

完全移除內聯腳本

說明 日期&#xff1a;2025年5月9日。 內聯腳本給跨站腳本攻擊&#xff08;XSS&#xff09;留了條路。 示例 日期&#xff1a;2025年5月9日。 如下網頁文件a.html&#xff1a; <!-- 內聯腳本塊 --> <script> function handleClick{ alert("Hello")…

[藍橋杯]約瑟夫環

約瑟夫環 題目描述 nn 個人的編號是 1 ~ nn&#xff0c;如果他們依編號按順時針排成一個圓圈&#xff0c;從編號是 1 的人開始順時針報數。 &#xff08;報數是從 1 報起&#xff09;當報到 kk 的時候&#xff0c;這個人就退出游戲圈。下一個人重新從 1 開始報數。 求最后剩…

電子電氣架構 --- 如何應對未來區域式電子電氣(E/E)架構的挑戰?

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 做到欲望極簡,了解自己的真實欲望,不受外在潮流的影響,不盲從,不跟風。把自己的精力全部用在自己。一是去掉多余,凡事找規律,基礎是誠信;二是…

isp中的 ISO代表什么意思

isp中的 ISO代表什么意思 在攝影和圖像信號處理&#xff08;ISP&#xff0c;Image Signal Processor&#xff09;領域&#xff0c;ISO是一個用于衡量相機圖像傳感器對光線敏感度的標準參數。它最初源于膠片攝影時代的 “國際標準化組織&#xff08;International Organization …

第十二節:第五部分:集合框架:Set集合的特點、底層原理、哈希表、去重復原理

Set系列集合特點 哈希值 HashSet集合的底層原理 HashSet集合去重復 代碼 代碼一&#xff1a;整體了解一下Set系列集合的特點 package com.itheima.day20_Collection_set;import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; import java.util.…

邁向分布式智能:解析MCP到A2A的通信范式遷移

智能體與外部世界的橋梁之言&#xff1a; 在深入探討智能體之間的協作機制之前&#xff0c;我們有必要先厘清一個更基礎的問題&#xff1a;**單個智能體如何與外部世界建立連接&#xff1f;** 這就引出了我們此前介紹過的 **MCP&#xff08;Model Context Protocol&…

Android Studio 配置之gitignore

1.創建或編輯.gitignore文件 在項目根目錄下檢查是否已有.gitignore文件。如果沒有&#xff0c;創建一個新文件&#xff0c;命名為.gitignore&#xff08;注意文件名前有個點&#xff09;。 添加忽略規則&#xff1a;在.gitignore中添加以下內容&#xff1a; 忽略整個 .idea …

算法:二分查找

1.二分查找 704. 二分查找 - 力扣&#xff08;LeetCode&#xff09; 二分查找算法要確定“二段性”&#xff0c;時間復雜度為O(lonN)。為了防止數據溢出&#xff0c;所以求mid時要用防溢出的方式。 class Solution { public:int search(vector<int>& nums, int tar…

day62—DFS—太平洋大西洋水流問題(LeetCode-417)

題目描述 有一個 m n 的矩形島嶼&#xff0c;與 太平洋 和 大西洋 相鄰。 “太平洋” 處于大陸的左邊界和上邊界&#xff0c;而 “大西洋” 處于大陸的右邊界和下邊界。 這個島被分割成一個由若干方形單元格組成的網格。給定一個 m x n 的整數矩陣 heights &#xff0c; hei…

Langchaine4j 流式輸出 (6)

Langchaine4j 流式輸出 大模型的流式輸出是指大模型在生成文本或其他類型的數據時&#xff0c;不是等到整個生成過程完成后再一次性 返回所有內容&#xff0c;而是生成一部分就立即發送一部分給用戶或下游系統&#xff0c;以逐步、逐塊的方式返回結果。 這樣&#xff0c;用戶…

自動駕駛與智能交通:構建未來出行的智能引擎

隨著人工智能、物聯網、5G和大數據等前沿技術的發展&#xff0c;自動駕駛汽車和智能交通系統正以前所未有的速度改變人類的出行方式。這一變革不僅是技術的融合創新&#xff0c;更是推動城市可持續發展的關鍵支撐。 一、自動駕駛與智能交通的定義 1. 自動駕駛&#xff08;Auto…