Python學習Day43

學習來源:@浙大疏錦行

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
?
# 設置隨機種子確保結果可復現
torch.manual_seed(42)
np.random.seed(42)
?
# 將 'path/to/your_dataset' 替換為你的數據集所在的根目錄
data_dir = './data/10 Big Cats of the Wild - Image Classification'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
?
if not os.path.isdir(data_dir):
? ? raise FileNotFoundError(
? ? ? ? f"Dataset directory not found at '{data_dir}'. "
? ? ? ? f"Please update the 'data_dir' variable to your dataset's path."
? ? )
?
transform = transforms.Compose([
? ? transforms.Resize((32, 32)),
? ? transforms.ToTensor(),
? ? transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
?
# 加載訓練集和測試集
trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform)
testset = torchvision.datasets.ImageFolder(root=test_dir, transform=transform)
?
# 從訓練數據集中自動獲取類別名稱和數量
classes = trainset.classes
num_classes = len(classes)
print(f"從數據集中找到 {num_classes} 個類別: {classes}")
?
# 創建數據加載器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
?
?
# --- MODIFICATION 3: 動態調整CNN模型以適應你的數據集 ---
class SimpleCNN(nn.Module):
? ? def __init__(self, num_classes): # 將類別數量作為參數傳入
? ? ? ? super(SimpleCNN, self).__init__()
? ? ? ? self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
? ? ? ? self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
? ? ? ? self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
? ? ? ? self.pool = nn.MaxPool2d(2, 2)
? ? ? ? # 輸入特征數128 * 4 * 4取決于輸入圖像大小和網絡結構。
? ? ? ? # 由于我們將所有圖像調整為32x32,經過3次2x2的池化后,尺寸變為 32 -> 16 -> 8 -> 4。所以這里是4*4。
? ? ? ? self.fc1 = nn.Linear(128 * 4 * 4, 512)
? ? ? ? # **重要**: 輸出層的大小現在由num_classes決定
? ? ? ? self.fc2 = nn.Linear(512, num_classes)
? ? ? ??
? ? def forward(self, x):
? ? ? ? x = self.pool(F.relu(self.conv1(x))) ?
? ? ? ? x = self.pool(F.relu(self.conv2(x))) ?
? ? ? ? x = self.pool(F.relu(self.conv3(x))) ?
? ? ? ? x = x.view(-1, 128 * 4 * 4)
? ? ? ? x = F.relu(self.fc1(x))
? ? ? ? x = self.fc2(x)
? ? ? ? return x
?
# 初始化模型,傳入你的數據集的類別數量
model = SimpleCNN(num_classes=num_classes)
print("模型已創建")
?
# 如果有GPU則使用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
?
# 訓練模型函數 (現在使用傳入的trainloader)
def train_model(model, trainloader, epochs=5): # 增加訓練周期以獲得更好效果
? ? criterion = nn.CrossEntropyLoss()
? ? optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
? ??
? ? print("開始訓練...")
? ? for epoch in range(epochs):
? ? ? ? running_loss = 0.0
? ? ? ? for i, data in enumerate(trainloader, 0):
? ? ? ? ? ? inputs, labels = data
? ? ? ? ? ? inputs, labels = inputs.to(device), labels.to(device)
? ? ? ? ? ??
? ? ? ? ? ? optimizer.zero_grad()
? ? ? ? ? ? outputs = model(inputs)
? ? ? ? ? ? loss = criterion(outputs, labels)
? ? ? ? ? ? loss.backward()
? ? ? ? ? ? optimizer.step()
? ? ? ? ? ??
? ? ? ? ? ? running_loss += loss.item()
? ? ? ? ? ? if i % 100 == 99:
? ? ? ? ? ? ? ? print(f'[{epoch + 1}, {i + 1:5d}] 損失: {running_loss / 100:.3f}')
? ? ? ? ? ? ? ? running_loss = 0.0
? ??
? ? print("訓練完成")
?
# 定義模型保存路徑
model_save_path = 'my_custom_cnn.pth'
?
# 嘗試加載預訓練模型
try:
? ? model.load_state_dict(torch.load(model_save_path))
? ? print(f"已從 '{model_save_path}' 加載預訓練模型")
except FileNotFoundError:
? ? print("無法加載預訓練模型,將開始訓練新模型。")
? ? train_model(model, trainloader, epochs=5) # 訓練新模型
? ? torch.save(model.state_dict(), model_save_path) # 保存訓練好的模型
? ? print(f"新模型已訓練并保存至 '{model_save_path}'")
?
# 設置模型為評估模式
model.eval()
?
# Grad-CAM實現 (這部分無需修改)
class GradCAM:
? ? def __init__(self, model, target_layer):
? ? ? ? self.model = model
? ? ? ? self.target_layer = target_layer
? ? ? ? self.gradients = None
? ? ? ? self.activations = None
? ? ? ? self.register_hooks()
? ? ? ??
? ? def register_hooks(self):
? ? ? ? def forward_hook(module, input, output):
? ? ? ? ? ? self.activations = output.detach()
? ? ? ? def backward_hook(module, grad_input, grad_output):
? ? ? ? ? ? self.gradients = grad_output[0].detach()
? ? ? ? self.target_layer.register_forward_hook(forward_hook)
? ? ? ? self.target_layer.register_backward_hook(backward_hook)
? ??
? ? def generate_cam(self, input_image, target_class=None):
? ? ? ? model_output = self.model(input_image)
? ? ? ? if target_class is None:
? ? ? ? ? ? target_class = torch.argmax(model_output, dim=1).item()
? ? ? ??
? ? ? ? self.model.zero_grad()
? ? ? ? one_hot = torch.zeros_like(model_output)
? ? ? ? one_hot[0, target_class] = 1
? ? ? ? model_output.backward(gradient=one_hot, retain_graph=True) # retain_graph=True可能需要
? ? ? ??
? ? ? ? gradients = self.gradients
? ? ? ? activations = self.activations
? ? ? ? weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
? ? ? ? cam = torch.sum(weights * activations, dim=1, keepdim=True)
? ? ? ? cam = F.relu(cam)
? ? ? ??
? ? ? ? cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
? ? ? ? cam = cam - cam.min()
? ? ? ? cam = cam / cam.max() if cam.max() > 0 else cam
? ? ? ??
? ? ? ? return cam.cpu().squeeze().numpy(), target_class
?
grad_cam = GradCAM(model, model.conv3)
?
# 從測試集中獲取一張圖片
img, label = testset[0]
img_tensor = img.unsqueeze(0).to(device)
?
# 生成CAM
cam, predicted_class_idx = grad_cam.generate_cam(img_tensor)
?
# 可視化結果
def visualize_cam(img, cam, predicted_class, true_class):
? ? img = img.permute(1, 2, 0).numpy() # 轉換回 (H, W, C)
? ? # 反歸一化以便顯示
? ? img = img * 0.5 + 0.5
? ? img = np.clip(img, 0, 1)
?
? ? heatmap = plt.cm.jet(cam)
? ? heatmap = heatmap[:, :, :3] # 去掉alpha通道
?
? ? overlay = heatmap * 0.4 + img * 0.6
? ??
? ? plt.figure(figsize=(10, 5))
? ? plt.subplot(1, 3, 1)
? ? plt.imshow(img)
? ? plt.title(f'Original Image\nTrue: {true_class}')
? ? plt.axis('off')
?
? ? plt.subplot(1, 3, 2)
? ? plt.imshow(heatmap)
? ? plt.title('Grad-CAM Heatmap')
? ? plt.axis('off')
?
? ? plt.subplot(1, 3, 3)
? ? plt.imshow(overlay)
? ? plt.title(f'Overlay\nPredicted: {predicted_class}')
? ? plt.axis('off')
? ??
? ? plt.show()
?
# 顯示結果
predicted_class_name = classes[predicted_class_idx]
true_class_name = classes[label]
visualize_cam(img, cam, predicted_class_name, true_class_name)

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/86534.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/86534.shtml
英文地址,請注明出處:http://en.pswp.cn/web/86534.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

了解一下Unity AssetBundle 的幾種加載方式

Unity 的 AssetBundle 系統提供了多種加載方式,以滿足不同場景下的資源管理和性能需求。 同步加載(LoadFromFile) 同步加載使用 AssetBundle.LoadFromFile 方法從文件系統中直接加載 AssetBundle。這種方式會阻塞主線程,直到加載…

鴻蒙邊緣智能計算架構實戰:多線程圖像采集與高可靠緩沖設計

目錄 一、技術背景與挑戰二、鴻蒙邊緣計算架構的核心特性1. 分布式軟總線:打破設備孤島2. 輕量化多線程模型 三、多線程圖像采集的穩定性設計1. 分層緩沖隊列架構2. 線程優先級策略 四、邊緣側高可靠緩沖機制1. 基于分布式數據管理的容錯設計2. 動態帶寬調節 五、實…

excel中vba開發工具

1、支持單元格點擊出現彈框進行選擇 支持模多次模糊查詢 Private Sub CommandButton1_Click() Call vehicle_查詢 End SubPrivate Sub Worksheet_Activate()Call vehicle_取出車架號和公司名稱 取出不重復的車架號Sheet13.ComboBox1.Visible False 車架號顯示Sheet13.ComboB…

CatBoost:征服類別型特征的梯度提升王者

基于有序提升與對稱樹的下一代GBDT框架,重塑高維分類數據處理范式 一、CatBoost的誕生:解決類別特征的終極挑戰 2017年由俄羅斯Yandex團隊開源,CatBoost(Categorical Boosting)直指機器學習中的核心痛點:類…

使用 WSL 啟動ubuntu.tar文件

使用 WSL 啟動ubuntu.tar文件,可按以下步驟進行3: 檢查 WSL 版本:確保你的 WSL 版本為 2.4.8 或更高版本。可以在命令行中輸入wsl --update來更新 WSL 到最新版本。 設置默認 WSL 版本:如果還沒有將 WSL 2 設置為默認版本&#x…

vue-23(創建用于邏輯提取的可重用組合組件)

創建用于邏輯提取的可重用組合組件 可重用的組合式是 Vue 組合式 API 的基石,它使你能夠在多個組件中提取和重用有狀態邏輯。這有助于編寫更清晰的代碼,減少冗余,并提高可維護性。通過將特定功能封裝到組合式中,你可以輕松地共享…

數據透視表學習筆記

學習視頻:Excel數據透視表大全,3小時從小白到大神!_嗶哩嗶哩_bilibili 合并行標簽 初始數據透視表 不顯示分類匯總 以大綱形式顯示 在組的底部顯示所有分類匯總 以表格形式顯示 合并單元格-右鍵-數據透視表選項 選中-合并并劇中排列帶…

吃透 Golang 基礎:測試

文章目錄 go test測試函數隨機測試測試一個命令白盒測試外部測試包 測試覆蓋率基準測試剖析示例函數 go test go test命令是一個按照一定的約定和組織來測試代碼的程序。在包目錄內,所有以xxx_test.go為后綴名的源文件在執行go build時不會被構建為包的一部分&#…

酒店服務配置無門檻優惠券

1.查看酒店綁定的是那個倉庫; 凱里亞德酒店(深圳北站壹城中心店),綁定的是“龍華民治倉(睿嘀購” 2.“門店列表”選擇“龍華民治倉(睿嘀購””中的“綁定場所” 3.通過酒店名字查找綁定的商品模板; 凱里亞德酒店(深圳…

IoT創新應用場景,賦能海外市場拓展

在數字化浪潮席卷全球的當下,物聯網(Internet of Things, IoT)正以革命性的力量重塑產業生態。這項通過傳感器、通信技術及智能算法實現設備互聯的技術,不僅推動全球從“萬物互聯”邁向“萬物智聯”,更成為賦能企業開拓…

Idea中Docker打包流程記錄

1. maven項目,先打package 2.添加Dockerfile 3.執行打包命令 注意最后的路徑 . docker buildx build -t xxx-app:版本號 -f Dockerfile . 4.下載文件 docker save -o xxx-app-版本號.tar xxx-app:版本號 5.加載鏡像 docker load -i xxx-app-版本號.tar 6.編…

硬件工程師筆試面試高頻考點-電阻

目錄 1.1 電阻選型時一般從哪幾個方面進行考慮? 1.2上拉下拉電阻的作用 1.3 PTC熱敏電阻作為電源電路保險絲的工作原理 1.4 如果阻抗不匹配,有哪些后果 1.5 電阻、電容和電感0402、0603和0805封裝的含義 1.6 電阻、電容和電感的封裝大小與什么參數有關 1.7 …

小程序入門:小程序 API 的三大分類

在小程序開發中,API(Application Programming Interface)起著至關重要的作用,它為開發者提供了豐富的功能和能力,使我們能夠創建出功能強大、用戶體驗良好的小程序。小程序 API 大致可分為以下三大分類:事件…

算法第55天|冗余連接、冗余連接II

冗余連接 題目 思路與解法 #include <iostream> #include <vector> using namespace std; int n; // 節點數量 vector<int> father(1001, 0); // 按照節點大小范圍定義數組// 并查集初始化 void init() {for (int i 0; i < n; i) {father[i] i;} } //…

Docker單獨部署grafana

Docker單獨部署grafana 環境說明 操作前提&#xff1a; 先去搭建PC端的MySQL和虛擬機 自行找參考 Linux部署docker參考文章&#xff1a; 02-Docker安裝_docker安裝包下載-CSDN博客 本文參考文章&#xff1a; 運維小記 說明&#xff1a; 本文的操作均以搭建好的PC端的MySQL和虛…

【數據分析,相關性分析】Matlab代碼#數學建模#創新算法

【數據分析&#xff0c;相關性分析】118-matlab代碼 #數學建模#創新算法 相關性分析及繪圖 基于最大互信息系數的特征篩選 最大互信息系數 皮爾遜相關系數 spearman相關系數 kendall秩相關系數 請自帶預算時間與需求以便高效溝通&#xff0c;回復超快&#xff0c;可以加急…

淺談C++ 中泛型編程(模版編程)

C 是一種強大且靈活的編程語言&#xff0c;支持多種編程范式&#xff0c;使得開發者能夠選擇最適合特定問題的解決方案。在實際開發中&#xff0c;面向對象編程、泛型編程、函數式編程和元編程是最常用的幾種范式。 今天主要與大家一起來介紹和學習泛型編程&#xff08;即模版…

iOS開發中的KVO以及原理

KVO概述 KVO(Key-Value-Observing)是iOS開發中一種觀察者模式實現&#xff0c;允許對象監聽另一個對象屬性的變化。當被觀察屬性的值發生變化時&#xff0c;觀察者會收到通知。KVO基于NSKeyValueObserving協議實現&#xff0c;是Foundation框架的核心功能之一。 1.KVO的基本使…

雷卯針對靈眸科技EASY Orin-nano RK3516 開發板防雷防靜電方案

一、應用場景 1. 人臉檢測 2. 人臉識別 3. 安全帽檢測 4. 人員檢測 5. OCR文字識別 6. 人頭檢測 7. 表情神態識別 8. 人體骨骼點識別 9. 火焰檢測 10. 人臉姿態估計 11. 人手檢測 12. 車輛檢測 13. 二維碼識別 二、 功能概述 1 CPU&#xff1a;八核64位ARM v8處…

中國雙非高校經費TOP榜數據分析

當我們習慣性仰望985、211這些“國家隊”時&#xff0c;一批地方重點支持的高校正悄悄發力&#xff0c;手握重金&#xff0c;展現出不遜于名校的“鈔能力”。特別是“雙非”大學中的佼佼者&#xff0c;它們的年度經費預算&#xff0c;足以讓許多普通院校望塵莫及。 今天就帶大…