打卡第44天:無人機數據集分類

重復以下內容

作業:
kaggle找到一個圖像數據集,用cnn網絡進行訓練并且用grad-cam做可視化

進階:
并拆分成多個文件

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from PIL import Image  # 添加缺失的導入
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm# 設置隨機種子確保結果可復現
torch.manual_seed(42)
if torch.cuda.is_available():torch.cuda.manual_seed_all(42)class CustomDataset(Dataset):def __init__(self, image_dir, label_dir, transform=None):self.image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])self.label_paths = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.lower().endswith('.txt')])self.transform = transform# 確保圖像和標簽數量匹配assert len(self.image_paths) == len(self.label_paths), "圖像數量與標簽數量不匹配"def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('RGB')with open(self.label_paths[idx], 'r') as f:label = int(f.read().strip())  # 假設標簽文件中是整數類別if self.transform:image = self.transform(image)return image, label
class CustomDataset(Dataset):def __init__(self, image_dir, label_dir, transform=None, debug=False):self.image_dir = image_dirself.label_dir = label_dirself.transform = transformself.debug = debug# 獲取圖像和標簽文件列表self.image_files = sorted([f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])self.label_files = sorted([f for f in os.listdir(label_dir) if f.lower().endswith('.txt')])# 打印調試信息if self.debug:print(f"圖像目錄: {image_dir}")print(f"標簽目錄: {label_dir}")print(f"找到 {len(self.image_files)} 個圖像文件")print(f"找到 {len(self.label_files)} 個標簽文件")# 打印前10個文件檢查排序是否匹配print("\n前10個圖像文件:")for f in self.image_files[:10]:print(f"  {f}")print("\n前10個標簽文件:")for f in self.label_files[:10]:print(f"  {f}")# 確保圖像和標簽數量匹配assert len(self.image_files) == len(self.label_files), \f"圖像數量({len(self.image_files)})與標簽數量({len(self.label_files)})不匹配"# 創建文件映射(假設文件名除去擴展名后相同)self.image_to_label = {}for img_file in self.image_files:# 提取圖像文件名(不含擴展名)img_base = os.path.splitext(img_file)[0]# 查找對應的標簽文件found = Falsefor lbl_file in self.label_files:lbl_base = os.path.splitext(lbl_file)[0]if img_base == lbl_base:self.image_to_label[img_file] = lbl_filefound = Truebreakif not found and self.debug:print(f"警告: 找不到圖像 '{img_file}' 對應的標簽文件")# 再次確認所有圖像都有對應的標簽assert len(self.image_to_label) == len(self.image_files), \f"只有 {len(self.image_to_label)} 個圖像找到了對應的標簽,總數應為 {len(self.image_files)}"if self.debug:print(f"成功建立 {len(self.image_to_label)} 個圖像-標簽映射")def __len__(self):return len(self.image_files)def __getitem__(self, idx):img_file = self.image_files[idx]lbl_file = self.image_to_label[img_file]image_path = os.path.join(self.image_dir, img_file)label_path = os.path.join(self.label_dir, lbl_file)image = Image.open(image_path).convert('RGB')with open(label_path, 'r') as f:label = int(f.read().strip())if self.transform:image = self.transform(image)return image, label
# 數據路徑配置
data_dir = r"C:\Users\許蘭\Desktop\打卡文件\mix20230204"  # 替換為你的數據集路徑
train_image_dir = os.path.join(data_dir, 'train/images')
train_label_dir = os.path.join(data_dir, 'train/labels')
val_image_dir = os.path.join(data_dir, 'validation/images')
val_label_dir = os.path.join(data_dir, 'validation/labels')
test_image_dir = os.path.join(data_dir, 'test/images')
test_label_dir = os.path.join(data_dir, 'test/labels')# 數據預處理和增強
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.ColorJitter(brightness=0.2, contrast=0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 使用自定義數據集類加載數據
image_datasets = {'train': CustomDataset(train_image_dir, train_label_dir, data_transforms['train']),'val': CustomDataset(val_image_dir, val_label_dir, data_transforms['val']),'test': CustomDataset(test_image_dir, test_label_dir, data_transforms['test'])
}# 創建數據加載器
batch_size = 32
dataloaders = {'train': DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True, num_workers=4),'val': DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=False, num_workers=4),'test': DataLoader(image_datasets['test'], batch_size=batch_size, shuffle=False, num_workers=4)
}# 獲取類別數量(假設類別從0開始連續編號)
num_classes = len(set([label for _, label in image_datasets['train']]))
print(f"數據集包含 {num_classes} 個類別")# 檢查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 定義CNN模型
class CNNModel(nn.Module):def __init__(self, num_classes):super(CNNModel, self).__init__()# 特征提取部分self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)# 分類部分self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes))# 初始化權重self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)# 初始化模型
model = CNNModel(num_classes)
model = model.to(device)# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 訓練模型
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):best_model_wts = model.state_dict()best_acc = 0.0# 記錄訓練過程history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []}for epoch in range(num_epochs):print(f'Epoch {epoch+1}/{num_epochs}')print('-' * 10)# 每個epoch有訓練和驗證階段for phase in ['train', 'val']:if phase == 'train':model.train()  # 訓練模式else:model.eval()   # 評估模式running_loss = 0.0running_corrects = 0# 迭代數據for inputs, labels in tqdm(dataloaders[phase]):inputs = inputs.to(device)labels = labels.to(device)# 梯度歸零optimizer.zero_grad()# 前向傳播# 只有在訓練時才跟蹤歷史with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只有在訓練階段才反向傳播+優化if phase == 'train':loss.backward()optimizer.step()# 統計running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train' and scheduler is not None:scheduler.step()epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 記錄歷史history[f'{phase}_loss'].append(epoch_loss)history[f'{phase}_acc'].append(epoch_acc.item())# 保存最佳模型if phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = model.state_dict()print(f'保存最佳模型,準確率: {best_acc:.4f}')print()# 加載最佳模型權重model.load_state_dict(best_model_wts)return model, history# 訓練模型
num_epochs = 25
model, history = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs)# 在測試集上評估模型
def evaluate_model(model, dataloader):model.eval()running_corrects = 0all_preds = []all_labels = []with torch.no_grad():for inputs, labels in dataloader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)running_corrects += torch.sum(preds == labels.data)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())accuracy = running_corrects.double() / len(dataloader.dataset)print(f'測試集準確率: {accuracy:.4f}')return accuracy.item(), all_preds, all_labels# 評估模型
test_accuracy, predictions, true_labels = evaluate_model(model, dataloaders['test'])# 可視化訓練過程
plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Training Accuracy')
plt.plot(history['val_acc'], label='Validation Accuracy')
plt.title('Accuracy Over Time')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()plt.tight_layout()
plt.savefig('training_history.png')
plt.show()# 保存模型
torch.save(model.state_dict(), 'cnn_image_classifier.pth')
print("模型已保存為 'cnn_image_classifier.pth'")# 可選:使用預訓練模型進行遷移學習
def use_pretrained_model(num_classes):# 加載預訓練的ResNet18model_ft = models.resnet18(pretrained=True)# 凍結部分層for param in list(model_ft.parameters())[:-4]:param.requires_grad = False# 修改最后的全連接層num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)return model_ft.to(device)    

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/83622.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/83622.shtml
英文地址,請注明出處:http://en.pswp.cn/web/83622.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

個人網站大更新,還是有個總站比較好

個人網站大更新,還是有個總站比較好 放棄了所有框架,用純htmlcssjs擼了個網站,這回可以想改啥改啥了。 選擇了黑紫作為主色調,暫時看著還算可以。 為什么不用那些框架了 幾個原因: 嘗試用vuepress、vitepress、not…

高精度算法詳解:從原理到加減乘除的完整實現

文章目錄 一、為什么需要高精度算法二、高精度算法的數據結構設計2.1 基礎工具函數2.2 高精度加法實現2.3 高精度減法實現2.4 高精度乘法實現2.5 高精度除法實現 三、完整測試程序四、總結 一、為什么需要高精度算法 在編程中,處理極大數值是常見需求,例…

排序--計數排序

一,引言 計數排序是一種針對整數數據的高效排序算法。其主要流程可分為三個步驟:首先計算整數數據的數值范圍;接著按大小順序統計各數值的出現次數;最后根據統計結果輸出排序后的數據序列。 二,求最值 遍歷現有數據,獲取最大值…

Kubernetes安全機制深度解析(四):動態準入控制和Webhook

#作者:程宏斌 文章目錄 動態準入控制什么是準入 Webhook? 嘗試準入Webhook先決條件編寫一個準入 Webhook 服務器部署準入 Webhook 服務即時配置準入 Webhook對 API 服務器進行身份認證 Webhook 請求與響應Webhook 配置匹配請求-規則匹配請求&#xff1a…

WDK 10.0.19041.685,可在32位win7 sp1系統下搭配vs2019使用,可以編譯出xp驅動。

(14)[驅動開發]配置環境 VS2019 WDK10 寫 xp驅動 (14)[驅動開發]配置環境 VS2019 WDK10 寫 xp驅動_microsoft visual 2019 wdk-CSDN博客文章瀏覽閱讀3k次,點贊8次,收藏17次。本文介紹了如何在VS2019環境下安裝和配置Windows Driver Kit(WDK)&#xff0…

論壇系統自動化測試

1、項目背景與測試目標 系統定位 論壇系統作為典型的高并發Web應用,需支持用戶注冊、登錄、發帖、評論、私信及個人中心管理等核心功能,是用戶公開交流與信息共享的核心平臺。其穩定性與響應效率直接影響用戶體驗及平臺活躍度。 測試必要性 功能可靠性&…

ChipWhisperer教程(一)

一、ChipWhisperer介紹 ChipWhisperer 是一個完整的開源工具鏈,用于學習嵌入式設備上的側信道攻擊并驗證這些設備的側信道抗性。ChipWhisperer主要用于功耗分析,利用設備功耗泄露的信息進行攻擊,也可用于故障攻擊(電壓和時鐘毛刺…

【持續更新】計算機網絡試題

問題1 請簡要說明TCP/IP協議棧的四層結構,并分別舉出每一層出現的典型協議或應用。 答案 應用層:ping,telnet,dns 傳輸層:tcp,udp 網絡層:ip,icmp 數據鏈路層:arp,rarp 問題2 下列協議或應用分別屬于TCP/IP協議…

短劇系統開發:打造高效、創新的短視頻娛樂平臺 - 從0到1的完整解決方案

一、短劇市場迎來爆發式增長 - 不容錯過的萬億級藍海 隨著5G技術的普及和移動互聯網的深度滲透,短劇市場正在經歷前所未有的爆發式增長。根據權威機構艾瑞咨詢最新發布的《2023年中國網絡短劇行業發展報告》顯示: 市場規模:2023年中國短劇市…

ChipWhisperer教程(三)

——CW305目標板的波形采集 一、目標板介紹 CW305 是一款獨立的 FPGA 目標板,搭載的FPGA芯片為Xilinx Artix-7系列。 它具有與 FPGA 通信的 USB 接口、為 FPGA 提供時鐘的外部 PLL、編程 VCC-INT 電源以及用于故障注入環境的二極管保護。 CW305 電路板有多種配置&…

django中如何解析content-type=application/json的請求

django中如何解析content-typeapplication/json的請求 本文由「大千AI助手」原創發布,專注用真話講AI,回歸技術本質。拒絕神話或妖魔化。搜索「大千AI助手」關注我,一起撕掉過度包裝,學習真實的AI技術! 往期文章回顧: …

Chainlink VRF 深度解析與實戰

背景 在區塊鏈的去中心化應用中,隨機性是一個常見但難以實現的需求。例如,區塊鏈游戲需要隨機決定戰斗結果,NFT 項目需要隨機分配稀有屬性,去中心化抽獎需要公平選擇獲獎者。然而,傳統的鏈上隨機數生成方法&#xff0…

7. TypeScript接口

TypeScript 中的接口(Interfaces)用于定義對象的結構。它們允許開發者指定一個對象應具有哪些屬性以及這些屬性的類型。接口有助于確保對象遵循特定的結構,從而在整個應用中提供一致性,并提升代碼的可維護性。 一、認識接口 Typ…

UE 新版渲染器輸出視頻

安裝包解壓到C盤 打開UE插件 Movie Render Queue 進入UE引擎在項目設置找到 libx264 aac mp4 影片渲染隊列調用出 命令行編碼器安裝包路徑,序列輸出路徑,定序器不能有中文

基于用戶的協同過濾推薦算法實現(Java電商平臺)

在電商平臺中,基于用戶的協同過濾推薦算法是一種常見的推薦系統方法。它通過分析用戶之間的相似性來推薦商品。以下是一個簡單的實現思路和示例代碼,使用Java語言。 實現思路 數據準備:收集用戶的評分數據,通常以用戶-商品評分矩…

LeetCode - 904. 水果成籃

題目 904. 水果成籃 - 力扣(LeetCode) 思路 題目本質 你有一個整數數組,每個元素代表一種水果。你只能用兩個籃子,每個籃子只能裝一種水果。你要在數組中找一個最長的連續子數組,這個子數組里最多只包含兩種不同的…

發現 Kotlin MultiPlatform 的一點小變化

最近發現 Kotlin 官方已經開始首推 Idea 的社區版的 KMP 插件了. 以前有網頁創建 KMP 的項目的文檔也消失了. 雖然有 Android Studio 的選項. 但是卻不是在默認的位置上了. 足以說明官方是有意想讓大家直接使用 Idea 社區版或者專業版 所以我直接在社區版上安裝 KMP 插件. 嘗試…

【Photoshop】金屬字體制作

新建一個空白項目,選擇橫排文字工具,輸入想要的文件建立文字圖層 選擇橫排文字工具選擇出文字內容,在通知欄出點擊’拾色器‘,設置好需要的文字顏色 圖層面板右下角點擊‘添加圖層樣式’,選擇斜面和浮雕 樣式設置為內斜…

centos 7.9 升級ssh版本 7.4p1 升級到 8.2p1

centos 7.9 升級ssh版本 7.4p1 升級到 8.2p1 1、安裝包下載2、安裝telnet3、安裝openssl-OpenSSL_1_1_1f.tar.gz4、安裝openssh-8.2p1.tar.gz5、修改ssh服務的相關配置文件6、確定可以ssh連接服務器后,卸載telnet,因為telnet不安全 本文是離線環境下升級…

stm32---dma串口發送+fifo隊列框架

之前分享了一個關于gd32的fifo框架,這次就用stm32仿照寫一個,其實幾乎一樣,這次說的更詳細點,我全文都寫上了注釋,大家直接cv模仿我的調用方式即可 uasrt.c #include "stm32f10x.h" // D…