Python訓練營打卡Day43

kaggle找到一個圖像數據集,用cnn網絡進行訓練并且用grad-cam做可視化

進階:并拆分成多個文件

config.py

import os# 基礎配置類
class Config:def __init__(self):# Kaggle配置self.kaggle_username = ""  # Kaggle用戶名self.kaggle_key = ""  # Kaggle API密鑰# 數據集配置self.dataset_name = "chest-xray-pneumonia"  # 默認使用胸部X光數據集self.data_dir = "data"self.train_dir = os.path.join(self.data_dir, "train")self.val_dir = os.path.join(self.data_dir, "val")self.test_dir = os.path.join(self.data_dir, "test")# 模型配置self.model_save_path = "models/cnn_model.h5"self.img_width, self.img_height = 224, 224self.batch_size = 32self.epochs = 10self.learning_rate = 0.001# Grad-CAM配置self.gradcam_output_dir = "gradcam_output"self.target_layer = "block5_conv3"  # VGG16最后一個卷積層,根據模型調整    

data_loader.py

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from config import Configclass DataLoader:def __init__(self, config: Config):self.config = configself.train_generator = Noneself.val_generator = Noneself.test_generator = Noneself.class_indices = Nonedef setup_data_generators(self):# 數據增強配置train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=20,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')test_datagen = ImageDataGenerator(rescale=1./255)# 創建數據生成器self.train_generator = train_datagen.flow_from_directory(self.config.train_dir,target_size=(self.config.img_width, self.config.img_height),batch_size=self.config.batch_size,class_mode='categorical')self.val_generator = test_datagen.flow_from_directory(self.config.val_dir,target_size=(self.config.img_width, self.config.img_height),batch_size=self.config.batch_size,class_mode='categorical')self.test_generator = test_datagen.flow_from_directory(self.config.test_dir,target_size=(self.config.img_width, self.config.img_height),batch_size=self.config.batch_size,class_mode='categorical',shuffle=False)self.class_indices = self.train_generator.class_indicesreturn self.train_generator, self.val_generator, self.test_generatordef get_class_names(self):if self.class_indices is None:self.setup_data_generators()return list(self.class_indices.keys())    

grad_cam.py

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
from tensorflow.keras.models import Model
from config import Configclass GradCAM:def __init__(self, model, class_names, config: Config):self.model = modelself.class_names = class_namesself.config = configos.makedirs(self.config.gradcam_output_dir, exist_ok=True)def generate_heatmap(self, img_array, layer_name=None):if layer_name is None:layer_name = self.config.target_layer# 創建一個用于獲取輸出的模型grad_model = Model(inputs=[self.model.inputs],outputs=[self.model.get_layer(layer_name).output, self.model.output])# 計算梯度with tf.GradientTape() as tape:conv_outputs, predictions = grad_model(img_array)class_idx = np.argmax(predictions[0])class_name = self.class_names[class_idx]loss = predictions[:, class_idx]# 獲取梯度grads = tape.gradient(loss, conv_outputs)# 平均梯度pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))# 權重激活映射conv_outputs = conv_outputs[0]heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_outputs), axis=-1)# 歸一化熱圖heatmap = np.maximum(heatmap, 0) / np.max(heatmap)return heatmap, class_name, predictions[0][class_idx]def overlay_heatmap(self, heatmap, img_path, alpha=0.4):# 加載原始圖像img = cv2.imread(img_path)img = cv2.resize(img, (self.config.img_width, self.config.img_height))# 調整熱圖大小heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))# 將熱圖轉換為RGBheatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)# 將熱圖疊加到原圖superimposed_img = heatmap * alpha + imgsuperimposed_img = np.uint8(superimposed_img)return img, heatmap, superimposed_imgdef process_image(self, img_path, layer_name=None):# 加載和預處理圖像img = tf.keras.preprocessing.image.load_img(img_path, target_size=(self.config.img_width, self.config.img_height))img_array = tf.keras.preprocessing.image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array = img_array / 255.0# 生成熱圖heatmap, class_name, confidence = self.generate_heatmap(img_array, layer_name)# 疊加熱圖original_img, heatmap_img, superimposed_img = self.overlay_heatmap(heatmap, img_path)# 保存結果filename = os.path.basename(img_path)output_path = os.path.join(self.config.gradcam_output_dir, f"gradcam_{filename}")# 創建可視化fig, axes = plt.subplots(1, 3, figsize=(15, 5))axes[0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))axes[0].set_title('原始圖像')axes[0].axis('off')axes[1].imshow(heatmap)axes[1].set_title('Grad-CAM熱圖')axes[1].axis('off')axes[2].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))axes[2].set_title(f'疊加圖像 - {class_name} ({confidence:.2%})')axes[2].axis('off')plt.tight_layout()plt.savefig(output_path)plt.close()return output_path, class_name, confidence    

kaggle_downloader.py

import os
import json
import kaggle
from kaggle.api.kaggle_api_extended import KaggleApi
from config import Config
import zipfileclass KaggleDownloader:def __init__(self, config: Config):self.config = configself.api = Nonedef authenticate(self):# 設置Kaggle API憑證os.environ['KAGGLE_USERNAME'] = self.config.kaggle_usernameos.environ['KAGGLE_KEY'] = self.config.kaggle_key# 初始化API客戶端self.api = KaggleApi()self.api.authenticate()def download_dataset(self):if not self.api:self.authenticate()# 創建數據目錄os.makedirs(self.config.data_dir, exist_ok=True)# 下載數據集print(f"正在下載數據集: {self.config.dataset_name}")self.api.dataset_download_files(self.config.dataset_name, path=self.config.data_dir, unzip=True)print(f"數據集下載完成,保存路徑: {self.config.data_dir}")# 解壓文件(如果需要)for file in os.listdir(self.config.data_dir):if file.endswith('.zip'):zip_path = os.path.join(self.config.data_dir, file)with zipfile.ZipFile(zip_path, 'r') as zip_ref:zip_ref.extractall(self.config.data_dir)os.remove(zip_path)    

main.py

import argparse
from config import Config
from kaggle_downloader import KaggleDownloader
from data_loader import DataLoader
from model_builder import ModelBuilder
from trainer import Trainer
from grad_cam import GradCAM
import tensorflow as tf
import osdef main():# 解析命令行參數parser = argparse.ArgumentParser(description='Kaggle圖像數據CNN訓練與Grad-CAM可視化')parser.add_argument('--download', action='store_true', help='下載Kaggle數據集')parser.add_argument('--train', action='store_true', help='訓練模型')parser.add_argument('--evaluate', action='store_true', help='評估模型')parser.add_argument('--visualize', action='store_true', help='運行Grad-CAM可視化')parser.add_argument('--dataset', type=str, help='Kaggle數據集名稱')parser.add_argument('--model_type', type=str, default='vgg16', choices=['simple', 'vgg16'], help='模型類型')parser.add_argument('--img_path', type=str, help='用于Grad-CAM可視化的圖像路徑')args = parser.parse_args()# 配置config = Config()# 更新配置if args.dataset:config.dataset_name = args.dataset# 1. 下載Kaggle數據集if args.download:downloader = KaggleDownloader(config)downloader.download_dataset()# 2. 加載數據data_loader = DataLoader(config)train_generator, val_generator, test_generator = data_loader.setup_data_generators()class_names = data_loader.get_class_names()print(f"分類類別: {class_names}")# 3. 構建模型model_builder = ModelBuilder(config, len(class_names))if args.model_type == 'simple':model = model_builder.build_simple_cnn()else:model = model_builder.build_vgg16_model()# 4. 訓練模型if args.train:trainer = Trainer(config)history = trainer.train(model, train_generator, val_generator)print("模型訓練完成")# 5. 評估模型if args.evaluate:if os.path.exists(config.model_save_path):model = tf.keras.models.load_model(config.model_save_path)print("加載已保存的模型")test_loss, test_acc = model.evaluate(test_generator)print(f"測試集準確率: {test_acc:.2%}")# 6. Grad-CAM可視化if args.visualize:if os.path.exists(config.model_save_path):model = tf.keras.models.load_model(config.model_save_path)print("加載已保存的模型用于可視化")if args.img_path and os.path.exists(args.img_path):grad_cam = GradCAM(model, class_names, config)output_path, class_name, confidence = grad_cam.process_image(args.img_path)print(f"可視化完成,結果保存在: {output_path}")print(f"預測類別: {class_name}, 置信度: {confidence:.2%}")else:print("請提供有效的圖像路徑")if __name__ == "__main__":main()    

model_builder.py

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam
from config import Configclass ModelBuilder:def __init__(self, config: Config, num_classes: int):self.config = configself.num_classes = num_classesdef build_simple_cnn(self):# 構建簡單的CNN模型model = Sequential([Conv2D(32, (3, 3), activation='relu', input_shape=(self.config.img_width, self.config.img_height, 3)),MaxPooling2D((2, 2)),Conv2D(64, (3, 3), activation='relu'),MaxPooling2D((2, 2)),Conv2D(128, (3, 3), activation='relu'),MaxPooling2D((2, 2)),Flatten(),Dense(128, activation='relu'),Dropout(0.5),Dense(self.num_classes, activation='softmax')])model.compile(optimizer=Adam(learning_rate=self.config.learning_rate),loss='categorical_crossentropy',metrics=['accuracy'])return modeldef build_vgg16_model(self, fine_tune=False):# 構建基于VGG16的預訓練模型base_model = VGG16(weights='imagenet',include_top=False,input_shape=(self.config.img_width, self.config.img_height, 3))# 是否微調預訓練模型if not fine_tune:for layer in base_model.layers:layer.trainable = False# 添加自定義層x = base_model.outputx = Flatten()(x)x = Dense(256, activation='relu')(x)x = Dropout(0.5)(x)predictions = Dense(self.num_classes, activation='softmax')(x)model = Model(inputs=base_model.input, outputs=predictions)model.compile(optimizer=Adam(learning_rate=self.config.learning_rate),loss='categorical_crossentropy',metrics=['accuracy'])return model    

trainer.py

import os
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from config import Configclass Trainer:def __init__(self, config: Config):self.config = configdef train(self, model, train_generator, val_generator):# 創建模型保存目錄os.makedirs(os.path.dirname(self.config.model_save_path), exist_ok=True)# 定義回調函數callbacks = [ModelCheckpoint(self.config.model_save_path, monitor='val_accuracy', save_best_only=True, mode='max',verbose=1),EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True,verbose=1),ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.00001,verbose=1)]# 訓練模型history = model.fit(train_generator,steps_per_epoch=train_generator.samples // self.config.batch_size,validation_data=val_generator,validation_steps=val_generator.samples // self.config.batch_size,epochs=self.config.epochs,callbacks=callbacks)return history    

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/82888.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/82888.shtml
英文地址,請注明出處:http://en.pswp.cn/web/82888.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

hive 3集成Iceberg 1.7中的Java版本問題

hive 3.1.3 集成iceberg 1.7.2創建Iceberg表報錯如下: Exception in thread "main" java.lang.UnsupportedClassVersionError: org/apache/iceberg/mr/hive/HiveIcebergStorageHandler has been compiled by a more recent version of the Java Runtime …

文本切塊技術(Splitter)

為什么要分塊? 將長文本分解成適當大小的片段,以便于嵌入、索引和存儲,并提高檢索的精確度。 用ChunkViz工具可視化分塊 在線使用 ChunkViz github https://github.com/gkamradt/ChunkViz 如何確定大模型所能接受的最長上下文 可以從…

C++:用 libcurl 發送一封帶有附件的郵件

編寫mingw C 程序&#xff0c;用 libcurl 發送一封帶有附件的郵件 下面是一個使用 MinGW 編譯的 C 程序&#xff0c;使用 libcurl 發送帶附件的郵件。這個程序完全通過代碼實現 SMTP 郵件發送&#xff0c;不依賴外部郵件客戶端&#xff1a; // send_email.cpp #include <i…

tensorflow image_dataset_from_directory 訓練數據集構建

以數據集 https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset 為例 目錄結構 訓練圖像數據集要求&#xff1a; 主目錄下包含多個子目錄&#xff0c;每個子目錄代表一個類別。每個子目錄中存儲屬于該類別的圖像文件。 例如 main_directory/ ...cat/ ...…

遨游Spring AI:第一盤菜Hello World

Spring AI的正式版已經發布了&#xff0c;很顯然&#xff0c;接下來我們要做的事情就是寫一個Hello World。 總體思路就是在本地搭建一個簡單的大模型&#xff0c;然后編寫Spring AI代碼與模型進行交互。 分五步&#xff1a; 1. 安裝Ollama&#xff1b; 2. 安裝DeepSeek&…

華為云Flexus+DeepSeek征文|基于華為云Flexus X和DeepSeek-R1打造個人知識庫問答系統

目錄 前言 1 快速部署&#xff1a;一鍵搭建Dify平臺 1.1 部署流程詳解 1.2 初始配置與登錄 2 構建專屬知識庫 2.1 進入知識庫模塊并創建新庫 2.2 選擇數據源導入內容 2.3 上傳并識別多種文檔格式 2.4 文本處理與索引構建 2.5 保存并完成知識庫創建 3接入ModelArts S…

Java優化:雙重for循環

在工作中&#xff0c;經常性的會出現在兩張表中查找相同ID的數據&#xff0c;許多開發者會使用兩層for循環嵌套&#xff0c;雖然實現功能沒有問題&#xff0c;但是效率極低&#xff0c;一下是一個簡單的優化過程&#xff0c;代碼耗時湊從26856ms優化到了748ms。 功能場景 有兩…

Prompt Tuning:生成的模型文件有什么構成

一、為什么Prompt Tuning會生成模型文件? 1. Prompt Tuning的本質:優化可訓練的「提示參數」 核心邏輯:Prompt Tuning(提示調優)是一種輕量級的微調技術,僅優化模型輸入層的提示向量(Prompt Embedding)或少量額外參數,而非更新整個預訓練模型的權重。生成模型文件的原…

ARM SMMUv3簡介(一)

1.概述 SMMU&#xff08;System Memory Management Unit&#xff0c;系統內存管理單元&#xff09;是ARM架構中用于管理設備訪問系統內存的硬件模塊。SMMU和MMU的功能類似&#xff0c;都是將虛擬地址轉換成物理地址&#xff0c;不同的是MMU轉換的虛擬地址來自CPU&#xff0c;S…

在 Windows 系統上運行 Docker 容器中的 Ubuntu 鏡像并顯示 GUI

在 Windows 上安裝一個 X Server&#xff08;如 VcXsrv 或 X410&#xff09;&#xff0c;Ubuntu 容器通過網絡將圖形界面轉發到 Windows。 步驟&#xff1a; 安裝 X Server&#xff1a; 推薦使用VcXsrv&#xff0c;免費開源。 安裝后運行 XLaunch&#xff0c;選擇&#xff1…

Vue3學習(4)- computed的使用

1. 簡述與使用 作用&#xff1a;computed 用于基于響應式數據派生出新值&#xff0c;其值會自動緩存并在依賴變化時更新。 ?緩存機制?&#xff1a;依賴未變化時直接返回緩存值&#xff0c;避免重復計算&#xff08;通過 _dirty 標志位實現&#xff09;。?響應式更新?&…

【HarmonyOS 5】出行導航開發實踐介紹以及詳細案例

以下是 ?HarmonyOS 5? 出行導航的核心能力詳解&#xff08;無代碼版&#xff09;&#xff0c;聚焦智能交互、多端協同與場景化創新&#xff1a; 一、交互革新&#xff1a;從被動響應到主動服務 ?意圖驅動導航? ?自然語義理解?&#xff1a;用戶通過語音指令&#xff08;如…

csrf攻擊學習

原理 csrf又稱跨站偽造請求攻擊&#xff0c;現代網站利用Cookie、Session 或 Token 等機制識別用戶身份&#xff0c;一旦用戶訪問某個網站&#xff0c;瀏覽器在之后請求會自動帶上這些信息來識別用戶身份。用戶在網站進行請求或者操作時服務器會給出對應的內容&#xff0c;比如…

深入剖析MySQL鎖機制,多事務并發場景鎖競爭

一、隱藏字段對 InnoDB 的行鎖&#xff08;Record Lock&#xff09;與間隙鎖&#xff08;Gap Lock&#xff09;的影響 1. 隱藏字段與鎖的三大核心影響 類型影響維度描述DB_TRX_IDMVCC 可見性控制決定是否讀取當前版本&#xff0c;或在加鎖時避開不可見版本&#xff08;影響加鎖…

以SMMUv2為例,使用Trace32可視化操作SMMU的常用命令詳解

Trace32支持一系列的SMMU命令&#xff0c;可以幫助用戶更好地配置、查看和分析SMMU。換句話說&#xff0c;就是讓SMMU的配置變得可視化。 在添加SMMU實例之前&#xff0c;需要選擇一個CPU來激活該SMMU實例的相關命令。Trace32讓SMMU的配置可視化的本質是&#xff0c;操縱CPU讀取…

將數據庫表導出為C#實體對象

數據庫方式 use 數據庫;declare TableName sysname 表名 declare Result varchar(max) /// <summary> /// TableName /// </summary> public class TableName {select Result Result /// <summary>/// CONVERT(NVARCHAR(500), ISNULL(ColN…

CSS 預處理器與工具

目錄 CSS 預處理器與工具1. Less主要特性 2. Sass/SCSS主要特性 3. Tailwind CSS主要特性 4. 其他工具PostCSSCSS Modules 5. 選擇建議 CSS 預處理器與工具 1. Less Less 是一個 CSS 預處理器&#xff0c;它擴展了 CSS 語言&#xff0c;添加了變量、嵌套規則、混合&#xff0…

this.$set() 的用法詳解(Vue響應式系統相關)

1. 什么是 this.$set()&#xff1f; this.$set(target, key, value) 是 Vue 2 中提供的一個方法&#xff0c;用于向響應式對象中動態添加屬性&#xff0c;確保新加的屬性同樣是響應式的。 2. 為什么需要它&#xff1f; Vue 2 的響應式系統基于 Object.defineProperty&#…

【HarmonyOS Next之旅】DevEco Studio使用指南(三十)

目錄 1 -> 部署云側工程 2 -> 通過CloudDev面板獲取云開發資源支持 3 -> 通用云開發模板 3.1 -> 適用范圍 3.2 -> 效果圖 4 -> 總結 1 -> 部署云側工程 可以選擇在云函數和云數據庫全部開發完成后&#xff0c;將整個云工程資源統一部署到AGC云端。…

如何配置nginx解決前端跨域請求問題

我們以一個簡單的例子模擬不同情況下產生的跨域問題以及解決方案。假設在http://127.0.0.1:8000的頁面調用接口 fetch(http://127.0.0.1:8003/api/data)常看到的錯誤“Access to fetch at ‘http://127.0.0.1:8003/api/data’ from origin ‘http://localhost:8000’ has been…