基于深度學習的胸部 X 光圖像肺炎分類系統(三)

目錄

二分類胸片判斷:

1. 數據加載時指定了兩類標簽

2. 損失函數用了二分類專用的

3. 輸出層只有 1 個神經元,用了sigmoid激活函數

4. 預測時用 0.5 作為分類閾值


二分類胸片判斷:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, confusion_matrix, classification_report
from imblearn.over_sampling import RandomOverSampler
import tensorflow as tf
from keras import layers
from keras import models
# 或者更常用的是直接導入Sequential類
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import os
import zipfile
import requests
from tensorflow.python.keras.callbacks import EarlyStopping
#  這個代碼執行 請切換環境到tf_env
plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用 SimHei 字體
plt.rcParams['axes.unicode_minus'] = False    # 解決負號顯示問題
plt.rcParams['font.size'] = 10  # 設置全局字體大小# 數據加載和預處理
def load_data(train_dir, test_dir, val_dir, img_size=(150, 150), batch_size=32):# 數據增強器 - 僅用于訓練集train_datagen = ImageDataGenerator(rescale=1. / 255,rotation_range=10,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.1,zoom_range=0.1,horizontal_flip=True)# 驗證集和測試集只需要重新縮放val_test_datagen = ImageDataGenerator(rescale=1. / 255)# 加載訓練數據train_generator = train_datagen.flow_from_directory(train_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=True)# 加載驗證數據val_generator = val_test_datagen.flow_from_directory(val_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=False)# 加載測試數據test_generator = val_test_datagen.flow_from_directory(test_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=False)return train_generator, val_generator, test_generator# 處理樣本不均衡(過采樣)
def handle_imbalance(generator):# 提取特征和標簽X, y = [], []num_batches = len(generator)# 重置生成器以確保從開始獲取數據generator.reset()for i in range(num_batches):batch_x, batch_y = generator.next()X.append(batch_x)y.append(batch_y)X = np.concatenate(X)y = np.concatenate(y)# 打印原始分布print(f"原始樣本分布: 正常={np.sum(y == 0)}, 肺炎={np.sum(y == 1)}")# 展平特征用于過采樣X_flat = X.reshape(X.shape[0], -1)# 過采樣少數類ros = RandomOverSampler(random_state=42)X_resampled, y_resampled = ros.fit_resample(X_flat, y)# 恢復圖像形狀X_resampled = X_resampled.reshape(-1, *X.shape[1:])print(f"過采樣后分布: 正常={np.sum(y_resampled == 0)}, 肺炎={np.sum(y_resampled == 1)}")return X_resampled, y_resampled, y# 構建改進的CNN模型
def build_model(input_shape):model = models.Sequential([# 第一個卷積塊layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.2),# 第二個卷積塊layers.Conv2D(64, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.3),# 第三個卷積塊layers.Conv2D(128, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.4),# 第四個卷積塊layers.Conv2D(256, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.5),# 分類器layers.Flatten(),layers.Dense(512, activation='relu'),layers.BatchNormalization(),layers.Dropout(0.5),layers.Dense(1, activation='sigmoid')])# 使用更穩定的優化器optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)model.compile(optimizer=optimizer,loss='binary_crossentropy',metrics=['accuracy',tf.keras.metrics.Precision(name='precision'),tf.keras.metrics.Recall(name='recall'),tf.keras.metrics.AUC(name='auc')])return model# 主函數
def main():# 假設數據集已經手動下載并解壓train_dir = "chest_xray/train"test_dir = "chest_xray/test"val_dir = "chest_xray/val"# 加載數據img_size = (150, 150)batch_size = 32train_generator, val_generator, test_generator = load_data(train_dir, test_dir, val_dir, img_size, batch_size)# 處理樣本不均衡X_train, y_train_resampled, y_train_original = handle_imbalance(train_generator)# 計算類別權重(基于原始分布)n_normal = np.sum(y_train_original == 0)n_pneumonia = np.sum(y_train_original == 1)total = n_normal + n_pneumoniaweight_for_normal = (1 / n_normal) * (total / 2.0)weight_for_pneumonia = (1 / n_pneumonia) * (total / 2.0)class_weights = {0: weight_for_normal, 1: weight_for_pneumonia}print(f"類別權重: 正常={weight_for_normal:.2f}, 肺炎={weight_for_pneumonia:.2f}")# 構建模型model = build_model((*img_size, 3))model.summary()# 提前停止回調early_stopping = EarlyStopping(monitor='val_loss',patience=5,restore_best_weights=True,verbose=1)# 訓練模型history = model.fit(X_train, y_train_resampled,epochs=30,batch_size=32,validation_data=val_generator,class_weight=class_weights,callbacks=[early_stopping],verbose=1)# 評估模型 - 使用完整測試集test_generator.reset()test_steps = len(test_generator)test_results = model.evaluate(test_generator, steps=test_steps, verbose=1)print("\n測試集評估結果:")print(f"準確率: {test_results[1]:.4f}")print(f"精確率: {test_results[2]:.4f}")print(f"召回率: {test_results[3]:.4f}")print(f"AUC: {test_results[4]:.4f}")# 獲取測試集所有預測結果test_generator.reset()y_true = []y_pred_prob = []for i in range(test_steps):batch_x, batch_y = test_generator.next()y_true.extend(batch_y)batch_pred = model.predict(batch_x, verbose=0).ravel()y_pred_prob.extend(batch_pred)y_true = np.array(y_true)y_pred_prob = np.array(y_pred_prob)y_pred = (y_pred_prob > 0.5).astype(int)# 計算額外指標f1 = f1_score(y_true, y_pred)auc = roc_auc_score(y_true, y_pred_prob)print(f"\nF1-score: {f1:.4f}")print(f"AUC-ROC: {auc:.4f}")# 分類報告print("\n分類報告:")print(classification_report(y_true, y_pred, target_names=['NORMAL', 'PNEUMONIA']))# 混淆矩陣cm = confusion_matrix(y_true, y_pred)print("混淆矩陣:")print(cm)# 繪制ROC曲線fpr, tpr, _ = roc_curve(y_true, y_pred_prob)plt.figure(figsize=(10, 6))plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲線 (AUC = {auc:.4f})')plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel('False Positive Rate')plt.ylabel('True Positive Rate')plt.title('接收者操作特征曲線(ROC)')plt.legend(loc="lower right")plt.savefig('roc_curve.png', dpi=300)plt.show()# 繪制訓練歷史plt.figure(figsize=(12, 8))plt.subplot(2, 2, 1)plt.plot(history.history['accuracy'], label='訓練準確率')plt.plot(history.history['val_accuracy'], label='驗證準確率')plt.title('準確率')plt.legend()plt.subplot(2, 2, 2)plt.plot(history.history['loss'], label='訓練損失')plt.plot(history.history['val_loss'], label='驗證損失')plt.title('損失')plt.legend()plt.subplot(2, 2, 3)plt.plot(history.history['precision'], label='訓練精確率')plt.plot(history.history['val_precision'], label='驗證精確率')plt.title('精確率')plt.legend()plt.subplot(2, 2, 4)plt.plot(history.history['recall'], label='訓練召回率')plt.plot(history.history['val_recall'], label='驗證召回率')plt.title('召回率')plt.legend()plt.tight_layout()plt.savefig('training_history.png', dpi=300)plt.show()if __name__ == "__main__":main()

這段代碼里有很多地方明確體現了這是一個二分類任務(判斷 “正常胸片” 和 “肺炎胸片” 兩類),最關鍵的有這幾個地方:

1. 數據加載時指定了兩類標簽

load_data 函數中,加載數據時明確指定了類別為兩類:

train_generator = train_datagen.flow_from_directory(

??? train_dir,

??? ...

??? class_mode='binary',? # 這里指定是二分類模式

??? classes=['NORMAL', 'PNEUMONIA'],? # 明確兩類:正常(NORMAL)和肺炎(PNEUMONIA

??? ...

)

  1. class_mode='binary':直接告訴程序 “這是二分類任務”,標簽會被處理成 0 和 1(0 代表正常,1 代表肺炎)。
  2. classes=['NORMAL', 'PNEUMONIA']:手動指定只有這兩個類別,沒有第三種情況。

2. 損失函數用了二分類專用的

在模型編譯時,損失函數用的是 binary_crossentropy(二分類交叉熵):

model.compile(

??? ...

??? loss='binary_crossentropy',? # 專門用于二分類的損失函數

??? ...

)

這個損失函數的作用是:計算 “模型判斷為 0 或 1 的概率” 與 “實際標簽(0 或 1)” 之間的差距,指導模型優化。如果是多分類任務,會用其他損失函數(比如 categorical_crossentropy)。

3. 輸出層只有 1 個神經元,用了sigmoid激活函數

模型的最后一層是:

layers.Dense(1, activation='sigmoid')? # 輸出層

  1. Dense(1):只輸出 1 個數值,這個數值經過 sigmoid 激活后,會被壓縮到 0~1 之間。
  2. 實際含義:
    1. 數值越接近 0 → 模型認為 “更可能是正常胸片(0 類)”;
    2. 數值越接近 1 → 模型認為 “更可能是肺炎胸片(1 類)”。

這是二分類任務的典型輸出方式(多分類會有多個神經元,對應多個類別)。

4. 預測時用 0.5 作為分類閾值

在生成最終判斷結果時:

y_pred = (y_pred_prob > 0.5).astype(int)? # 大于0.51類(肺炎),否則算0類(正常)

直接用 0.5 作為 “兩類的分界線”,把輸出概率分成 “0” 和 “1” 兩類,進一步說明這是二分類。

從 “數據標簽定義”“損失函數選擇”“輸出層設計” 到 “最終預測規則”,全流程都圍繞 “只能分成兩類” 展開,沒有任何支持多類別的設計。所以這段代碼是典型的二分類任務,目標就是區分 “正常胸片” 和 “肺炎胸片”。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92799.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92799.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92799.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

深入理解 BIO、NIO、AIO

目錄 一、同步與非同步 二、阻塞與非阻塞 三、BIO(Blocking I/O,阻塞I/O) 四、NIO(Non-blocking I/O,非阻塞I/O) 五、AIO(Asynchronous I/O,異步I/O) 同步阻塞&…

電腦無法識別固態硬盤怎么辦?

隨著固態硬盤(SSD)越來越普及,不少用戶在給電腦更換、加裝SSD時會遇到一個讓人頭大的問題——電腦識別不了固態硬盤。可能是開不了機,或者在“此電腦”中找不到硬盤,甚至連系統安裝界面都提示“找不到驅動器”。這時候…

Kingbasepostgis 安裝實踐

文章目錄前言一、安裝準備1.1 部署方案規劃1.2 SELINUX、防火墻狀態檢查1.3 操作系統時間檢查1.4 創建用戶及密碼1.5 目錄創建1.6 操作系統參數配置1.6.1 配置limits.conf文件二、安裝2.1 上傳安裝包以及license授權文件2.2 拷貝安裝文件2.3 命令行方式安裝2.3.1簡介2.3.2 許可…

移動端設備能部署的llm

mlc-llm 內置RedPajama hf示例模型 TheBloke/Mistral-7B-Instruct-v0.2-GGUF https://github.com/mlc-ai/mlc-llm/tree/main llama.cpp https://github.com/ggml-org/llama.cpp reference --- MLC-LLM:大模型如何部署到瀏覽器 / 手機?完整流程復現…

Ubuntu硬盤掛載

一、在 Ubuntu 中,你可以用以下命令快速查看 所有已連接但尚未掛載的硬盤和分區:lsblk -o NAME,SIZE,FSTYPE,MOUNTPOINT,UUID輸出中 MOUNTPOINT 為空的行,就是 未掛載的分區。sda ├─sda1 500M ext4 /boot ├─sda2 1.8T ntfs └─sda3 …

JavaScript -Socket5代理使用

axios 安裝兩個包 socks-proxy-agent,axios const { SocksProxyAgent } require(socks-proxy-agent); const axios require(axios);const socks5Axios axios.create();const socks5 () > {const socks5Agent new SocksProxyAgent("socks5://112.194.8…

[特殊字符] 從數據庫無法訪問到成功修復崩潰表:一次 MySQL 故障排查實錄

一次典型的 MySQL 故障排查與修復全過程,涵蓋登錄失敗、表崩潰、innodb_force_recovery 救援、壞表剔除與數據恢復等關鍵操作。一、問題背景某業務系統運行多年,數據庫使用的是 MySQL 8.0.18,近期在一次服務器重啟后,發現無法正常…

【Agent】API Reference Manual(API 參考手冊)

https://github.com/Intelligent-Internet/CommonGround/blob/main/docs/framework/03-api-reference.md 以下是這份 API Reference Manual(API 參考手冊) 的完整中文翻譯: API 參考手冊 版本:0.1 目錄 概覽 1.1 API 目的 1.2 通信協議與核心概念 HTTP API 2.1 POST /se…

LeetCode Hot 100 全排列

給定一個不含重復數字的數組 nums ,返回其 所有可能的全排列 。你可以 按任意順序 返回答案。示例 1:輸入:nums [1,2,3] 輸出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2:輸入:nums [0,1]…

AI大模型如何有效識別和糾正數據中的偏見?

當下,人工智能大模型已成為推動各行業發展的關鍵力量,廣泛應用于自然語言處理、圖像識別、醫療診斷、金融風控等領域,為人們的生活和工作帶來了諸多便利。然而,隨著其應用的不斷深入,數據偏見問題逐漸浮出水面&#xf…

如何通過內網穿透,訪問公司內部服務器?

“凌晨2點,銷售總監王姐在機場候機時突然接到客戶電話——對方要求立即查看產品庫存數據。她慌忙翻出筆記本電腦,卻發現公司內網數據庫沒有公網IP,VPN連接又卡在驗證環節……這樣的場景,是否讓你想起某個手忙腳亂的時刻&#xff1…

12. isaacsim4.2教程-ROS 導航

1. Teleport 示例 ROS 服務的作用: 提供了一種同步、請求-響應的通信方式,用于執行那些需要即時獲取結果或狀態反饋的一次性操作或查詢。 Teleport 服務在 ROS 仿真(尤其是 Gazebo)和某些簡單機器人控制中扮演著瞬移機器人或對象…

DeepSpeed-FastGen:通過 MII 和 DeepSpeed-Inference 實現大語言模型的高吞吐文本生成

溫馨提示: 本篇文章已同步至"AI專題精講" DeepSpeed-FastGen:通過 MII 和 DeepSpeed-Inference 實現大語言模型的高吞吐文本生成 摘要 隨著大語言模型(LLM)被廣泛應用,其部署與擴展變得至關重要&#xff0…

操作系統:操作系統的結構(Structures of Operating System)

目錄 簡單結構(Simple Structure) 整體式結構(Monolithic Structure) 什么是 Kernel(內核)? 層次結構(Layered Structure) 微內核結構(Microkernel&#x…

Python柱狀圖

1.各國GDP柱狀圖2.各國GDP時間線柱狀圖

FastGPT:企業級智能問答系統,讓知識庫觸手可及

在信息爆炸的時代,企業如何高效管理和利用海量知識?傳統搜索和文檔庫已難以滿足需求。FastGPT正成為企業構建智能知識核心的首選。一、FastGPT:不止于問答的智能知識引擎FastGPT 顛覆了傳統知識庫的局限,其核心優勢在于&#xff1…

探索 MyBatis-Plus

引言在當今的 Java 開發領域,數據庫操作是一個至關重要的環節。MyBatis 作為一款優秀的持久層框架,已經被廣泛應用。而 MyBatis-Plus 則是在 MyBatis 基礎上進行增強的工具,它簡化了開發流程,提高了開發效率。本文將詳細介紹 MyBa…

Hive【安裝 01】hive-3.1.2版本安裝配置(含 mysql-connector-java-5.1.47.jar 網盤資源)

我使用的安裝文件是 apache-hive-3.1.2-bin.tar.gz ,以下內容均以此版本進行說明。 以下環境測試安裝成功: openEuler 22.03 (LTS-SP1)系統 MySQL-8.0.40 1.前置條件 MySQL數據庫 我安裝的是 mysql-5.7.28 版本的,安裝方法可參考《Linux環境…

璞致 PZSDR-P101:ZYNQ7100+AD9361 架構軟件無線電平臺,重塑寬頻信號處理范式

璞致電子 PZSDR-P101 軟件無線電平臺以 "異構計算 寬頻射頻 工業級可靠性" 為核心設計理念,基于 Xilinx ZYNQ7100 處理器與 ADI AD9361 射頻芯片構建,為工程師提供從 70MHz 到 6GHz 的全頻段信號處理解決方案。無論是頻譜監測、無線通信原型…