python學習打卡day43

DAY 43 復習日

作業:
kaggle找到一個圖像數據集,用cnn網絡進行訓練并且用grad-cam做可視化

@浙大疏錦行

數據集使用貓狗數據集,訓練集中包含貓圖像4000張、狗圖像4005張。測試集包含貓圖像1012張,狗圖像1013張。以下是數據集的下載地址。

貓和狗 --- Cat and Dog

1.數據集加載與數據預處理

我這里對數據集文件路徑做了改變

C:\Users\vijay\Desktop\1\

├── train\

│? ? ? ├── cats\?

│? ? ? └── dogs\

└── test\

? ? ? ? ├── cats\?

? ? ? ? └── dags\?

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F# 設置隨機種子確保結果可復現
torch.manual_seed(42)
np.random.seed(42)# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 1. 數據預處理
# 訓練集:使用多種數據增強方法提高模型泛化能力
train_transform = transforms.Compose([# 新增:調整圖像大小為統一尺寸transforms.Resize((32, 32)),  # 確保所有圖像都是32x32像素transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])# 測試集:僅進行必要的標準化,保持數據原始特性
test_transform = transforms.Compose([# 新增:調整圖像大小為統一尺寸transforms.Resize((32, 32)),  # 確保所有圖像都是32x32像素transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])# 定義數據集根目錄
root = r'C:\Users\vijay\Desktop\1'train_dataset = datasets.ImageFolder(root=root + '/train',  # 指向 train 子文件夾transform=train_transform
)
test_dataset = datasets.ImageFolder(root=root + '/test',  # 指向 test 子文件夾transform=test_transform
)# 打印類別信息,確認數據加載正確
print(f"訓練集類別: {train_dataset.classes}")
print(f"測試集類別: {test_dataset.classes}")# 3. 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

2.模型訓練與評估?

# 定義一個簡單的CNN模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 第一個卷積層,輸入通道為3(彩色圖像),輸出通道為32,卷積核大小為3x3,填充為1以保持圖像尺寸不變self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)# 第二個卷積層,輸入通道為32,輸出通道為64,卷積核大小為3x3,填充為1self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)# 第三個卷積層,輸入通道為64,輸出通道為128,卷積核大小為3x3,填充為1self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)# 最大池化層,池化核大小為2x2,步長為2,用于下采樣,減少數據量并提取主要特征self.pool = nn.MaxPool2d(2, 2)# 第一個全連接層,輸入特征數為128 * 4 * 4(經過前面卷積和池化后的特征維度),輸出為512self.fc1 = nn.Linear(128 * 4 * 4, 512)# 第二個全連接層,輸入為512,輸出為2(對應貓和非貓兩個類別)self.fc2 = nn.Linear(512, 2)def forward(self, x):# 第一個卷積層后接ReLU激活函數和最大池化操作,經過池化后圖像尺寸變為原來的一半,這里輸出尺寸變為16x16x = self.pool(F.relu(self.conv1(x)))# 第二個卷積層后接ReLU激活函數和最大池化操作,輸出尺寸變為8x8x = self.pool(F.relu(self.conv2(x)))# 第三個卷積層后接ReLU激活函數和最大池化操作,輸出尺寸變為4x4x = self.pool(F.relu(self.conv3(x)))# 將特征圖展平為一維向量,以便輸入到全連接層x = x.view(-1, 128 * 4 * 4)# 第一個全連接層后接ReLU激活函數x = F.relu(self.fc1(x))# 第二個全連接層輸出分類結果x = self.fc2(x)return x# 初始化模型
model = SimpleCNN()
print("模型已創建")# 如果有GPU則使用GPU,將模型轉移到對應的設備上
model = model.to(device)# 訓練模型
def train_model(model, train_loader, test_loader, epochs=10):# 定義損失函數為交叉熵損失,用于分類任務criterion = nn.CrossEntropyLoss()# 定義優化器為Adam,用于更新模型參數,學習率設置為0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(epochs):# 訓練階段model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(train_loader, 0):# 從數據加載器中獲取圖像和標簽inputs, labels = data# 將圖像和標簽轉移到對應的設備(GPU或CPU)上inputs, labels = inputs.to(device), labels.to(device)# 清空梯度,避免梯度累加optimizer.zero_grad()# 模型前向傳播得到輸出outputs = model(inputs)# 計算損失loss = criterion(outputs, labels)# 反向傳播計算梯度loss.backward()# 更新模型參數optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if i % 100 == 99:# 每100個批次打印一次平均損失和準確率print(f'[{epoch + 1}, {i + 1}] 損失: {running_loss / 100:.3f} | 準確率: {100.*correct/total:.2f}%')running_loss = 0.0# 測試階段model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)test_loss += criterion(outputs, labels).item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()print(f'測試集 [{epoch + 1}] 損失: {test_loss/len(test_loader):.3f} | 準確率: {100.*correct/total:.2f}%')print("訓練完成")return model# 訓練模型
try:# 嘗試加載預訓練模型(如果存在)model.load_state_dict(torch.load('cat_classifier.pth'))print("已加載預訓練模型")
except:print("無法加載預訓練模型,訓練新模型")model = train_model(model, train_loader, test_loader, epochs=10)# 保存訓練后的模型參數torch.save(model.state_dict(), 'cat_classifier.pth')# 設置模型為評估模式
model.eval()

3.?Grad-CAM實現

# Grad-CAM實現
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注冊鉤子,用于獲取目標層的前向傳播輸出和反向傳播梯度self.register_hooks()def register_hooks(self):# 前向鉤子函數,在目標層前向傳播后被調用,保存目標層的輸出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向鉤子函數,在目標層反向傳播后被調用,保存目標層的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目標層注冊前向鉤子和反向鉤子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向傳播,得到模型輸出model_output = self.model(input_image)if target_class is None:# 如果未指定目標類別,則取模型預測概率最大的類別作為目標類別target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影響self.model.zero_grad()# 反向傳播,構造one-hot向量,使得目標類別對應的梯度為1,其余為0,然后進行反向傳播計算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 獲取之前保存的目標層的梯度和激活值gradients = self.gradientsactivations = self.activations# 對梯度進行全局平均池化,得到每個通道的權重,用于衡量每個通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加權激活映射,將權重與激活值相乘并求和,得到類激活映射的初步結果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留對目標類別有正貢獻的區域,去除負貢獻的影響cam = F.relu(cam)# 調整大小并歸一化,將類激活映射調整為與輸入圖像相同的尺寸(32x32),并歸一化到[0, 1]范圍cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_class# 可視化Grad-CAM結果的函數
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題
# 選擇一個隨機圖像
# idx = np.random.randint(len(test_dataset))
idx = 102  # 選擇測試集中的第101張圖片 (索引從0開始)
image, label = test_dataset[idx]
print(f"選擇的圖像類別: {test_dataset.classes[label]}")# 轉換圖像以便可視化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次維度并移動到設備
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(選擇最后一個卷積層)
grad_cam = GradCAM(model, model.conv3)# 生成熱力圖
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可視化
plt.figure(figsize=(12, 4))# 原始圖像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始圖像: {test_dataset.classes[label]}")
plt.axis('off')# 熱力圖
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM熱力圖: {test_dataset.classes[pred_class]}")
plt.axis('off')# 疊加的圖像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("疊加熱力圖")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()print("Grad-CAM可視化完成。已保存為grad_cam_result.png")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/908058.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/908058.shtml
英文地址,請注明出處:http://en.pswp.cn/news/908058.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大數據與數據分析【數據分析全棧攻略:爬蟲+處理+可視化+報告】

- 第 100 篇 - Date: 2025 - 05 - 25 Author: 鄭龍浩/仟墨 大數據與數據分析 文章目錄 大數據與數據分析一 大數據是什么?1 定義2 大數據的來源3 大數據4個方面的典型特征(4V)4 大數據的應用領域5 數據分析工具6 數據是五種生產要素之一 二 …

uniapp 開發企業微信小程序,如何區別生產環境和測試環境?來處理不同的服務請求

在 uniapp 開發企業微信小程序時,區分生產環境和測試環境是常見需求。以下是幾種可靠的方法,幫助你根據環境處理不同的服務請求: 一、通過條件編譯區分(推薦) 使用 uniapp 的 條件編譯 語法,在代碼中標記…

青少年編程與數學 02-020 C#程序設計基礎 15課題、異常處理

青少年編程與數學 02-020 C#程序設計基礎 15課題、異常處理 一、異常1. 異常的分類2. 異常的作用小結 二、異常處理1. 異常處理的定義2. 異常處理的主要組成部分3. 異常處理的作用小結 三、C#異常處理1. 異常的基本概念2. 異常處理的關鍵字3. 異常處理的流程4. 自定義異常5. 異…

云原生時代 Kafka 深度實踐:05性能調優與場景實戰

5.1 性能調優全攻略 Producer調優 批量發送與延遲發送 通過調整batch.size和linger.ms參數提升吞吐量: props.put(ProducerConfig.BATCH_SIZE_CONFIG, 16384); // 默認16KB props.put(ProducerConfig.LINGER_MS_CONFIG, 10); // 等待10ms以積累更多消息ba…

在 Dify 項目中的 Celery:異步任務的實現與集成

Celery 是一個強大而靈活的分布式任務隊列系統,旨在幫助應用程序在后臺異步運行耗時的任務,提高系統的響應速度和性能。在 Dify 項目中,Celery 被廣泛用于處理異步任務和定時任務,并與其他工具(如 Sentry、OpenTelemet…

Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py環境安裝教程及圖數據集制作

最近需要訓練圖卷積神經網絡(Graph Convolution Neural Network, GCNN),在配置GCNN環境上總結了一些經驗。 我覺得對于初學者而言,圖神經網絡的訓練會有2個難點: ①環境配置 ②數據集制作 一、環境配置 我最初光想…

2025年微信小程序開發:AR/VR與電商的最新案例

引言 微信小程序自2017年推出以來,已成為中國移動互聯網生態的核心組成部分。根據最新數據,截至2025年,微信小程序的日活躍用戶超過4.5億,總數超過430萬,覆蓋電商、社交、線下服務等多個領域(WeChat Mini …

互聯網向左,區塊鏈向右

2008年,中本聰首次提出了比特幣的設想,這打開了去中心化的大門。 比特幣白皮書清晰的描述了去中心化支付的解決方案,并分別從以下幾個方面闡述了他的理念: 一、由轉賬雙方點對點的通訊,而不通過中心化的第三方&#xf…

PV操作的C++代碼示例講解

文章目錄 一、PV操作基本概念(一)信號量(二)P操作(三)V操作 二、PV操作的意義三、C中實現PV操作的方法(一)使用信號量實現PV操作代碼解釋: (二)使…

《對象創建的秘密:Java 內存布局、逃逸分析與 TLAB 優化詳解》

大家好呀!今天我們來聊聊Java世界里那些"看不見摸不著"但又超級重要的東西——對象在內存里是怎么"住"的,以及JVM這個"超級管家"是怎么幫我們優化管理的。放心,我會用最接地氣的方式講解,保證連小學…

簡單實現Ajax基礎應用

Ajax不是一種技術,而是一個編程概念。HTML 和 CSS 可以組合使用來標記和設置信息樣式。JavaScript 可以修改網頁以動態顯示,并允許用戶與新信息進行交互。內置的 XMLHttpRequest 對象用于在網頁上執行 Ajax,允許網站將內容加載到屏幕上而無需…

詳解開漏輸出和推挽輸出

開漏輸出和推挽輸出 以上是 GPIO 配置為輸出時的內部示意圖,我們要關注的其實就是這兩個 MOS 管的開關狀態,可以組合出四種狀態: 兩個 MOS 管都關閉時,輸出處于一個浮空狀態,此時他對其他點的電阻是無窮大的&#xff…

Matlab實現LSTM-SVM回歸預測,作者:機器學習之心

Matlab實現LSTM-SVM回歸預測,作者:機器學習之心 目錄 Matlab實現LSTM-SVM回歸預測,作者:機器學習之心效果一覽基本介紹程序設計參考資料 效果一覽 基本介紹 代碼主要功能 該代碼實現了一個LSTM-SVM回歸預測模型,核心流…

Leetcode - 周賽 452

目錄 一,3566. 等積子集的劃分方案二,3567. 子矩陣的最小絕對差三,3568. 清理教室的最少移動四,3569. 分割數組后不同質數的最大數目 一,3566. 等積子集的劃分方案 題目列表 本題有兩種做法,dfs 選或不選…

【FAQ】HarmonyOS SDK 閉源開放能力 —Account Kit(5)

1.問題描述: 集成華為一鍵登錄的LoginWithHuaweiIDButton, 但是Button默認名字叫 “華為賬號一鍵登錄”,太長無法顯示,能否簡寫成“一鍵登錄”與其他端一致? 解決方案: 問題分兩個場景: 一、…

Asp.Net Core SignalR的分布式部署

文章目錄 前言一、核心二、解決方案架構三、實現方案1.使用 Azure SignalR Service2.Redis Backplane(Redis 背板方案)3.負載均衡配置粘性會話要求無粘性會話方案(僅WebSockets)完整部署示例(Redis Docker)性能優化技…

L2-054 三點共線 - java

L2-054 三點共線 語言時間限制內存限制代碼長度限制棧限制Java (javac)2600 ms512 MB16KB8192 KBPython (python3)2000 ms256 MB16KB8192 KB其他編譯器2000 ms64 MB16KB8192 KB 題目描述: 給定平面上 n n n 個點的坐標 ( x _ i , y _ i ) ( i 1 , ? , n ) (x\_i…

【 java 基礎知識 第一篇 】

目錄 1.概念 1.1.java的特定有哪些? 1.2.java有哪些優勢哪些劣勢? 1.3.java為什么可以跨平臺? 1.4JVM,JDK,JRE它們有什么區別? 1.5.編譯型語言與解釋型語言的區別? 2.數據類型 2.1.long與int類型可以互轉嗎&…

高效背誦英語四級范文

以下是結合認知科學和實戰驗證的 ??高效背誦英語作文五步法??,助你在30分鐘內牢固記憶一篇作文,特別適配考前沖刺場景: 📝 ??一、解構作文(5分鐘)?? ??拆解邏輯框架?? 用熒光筆標出&#xff…

RHEL7安裝教程

RHEL7安裝教程 下載RHEL7鏡像 通過網盤分享的文件:RHEL 7.zip 鏈接: https://pan.baidu.com/s/1ExLhdJigj-tcrHJxIca5XA?pwdjrrj 提取碼: jrrj --來自百度網盤超級會員v6的分享安裝 1.打開VMware,新建虛擬機,選擇自定義然后下一步 2.點擊…