文章目錄
- 前言
- 一、數據準備
- 二、項目實戰
- 2.1 設置GPU
- 2.2 數據加載
- 2.3 數據預處理
- 2.4 數據劃分
- 2.5 搭建網絡模型
- 2.6 構建densenet121
- 2.7 訓練模型
- 2.8 結果可視化
- 三、UI設計
- 四、結果展示
- 總結
前言
在當今社會,眼科疾病尤其是白內障對人們的視力健康構成了嚴重威脅。白內障是全球範圍內導致失明的主要原因之一,早期準確的診斷對於疾病的治療和患者的預後至關重要。傳統的白內障檢測方法主要依賴於眼科醫生的專業判斷,這不僅需要大量的人力和時間,而且診斷結果可能會受到醫生經驗和主觀因素的影響。
隨著深度學習技術的飛速發展,其在醫療圖像分析領域展現出了巨大的潛力。卷積神經網絡(CNN)作為深度學習中的重要模型,已經在多種醫療圖像識別任務中取得了顯著的成果,如腫瘤檢測、疾病分類等。利用 CNN 對眼科圖像進行分析,可以輔助醫生更快速、準確地進行疾病診斷。
本文將詳細介紹如何使用基於 DenseNet 的卷積神經網絡進行白內障疾病檢測。通過這個實戰案例,不僅可以幫助讀者了解 DenseNet 的原理和應用,還能掌握利用深度學習進行醫療圖像分析的基本流程和方法,為進一步開展相關研究和實踐提供參考。
一、數據準備
本案例使用的數據集是retina_dataset|眼科疾病數據集。
數據集下載地址:點擊這里
Retina Dataset 的構建基於眼底圖像的分類需求,涵蓋了四種主要的眼科疾病類別:正常、白內障、青光眼和視網膜疾病。數據集通過收集和整理不同患者的視網膜圖像,確保每類疾病均有代表性樣本。圖像數據經過標準化處理,以保證在不同設備和條件下獲取的圖像具有一致性,從而為後續的分類和分析提供了堅實的基礎。
二、項目實戰
我的環境:
- 基礎環境:Python3.9
- 開發工具(IDE):PyCharm 2024
- 深度學習框架:PyTorch 2.0
2.1 設置GPU
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os
import pathlib
import warnings

import PIL

# Silence library warnings so the training log stays readable.
warnings.filterwarnings("ignore")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression: only echoes the value in a REPL / notebook cell
2.2 數據加載
import os
import pathlib
import random

import PIL

# Presumably a placeholder: replace with the root folder of the dataset,
# which should contain one sub-directory per class.
data_dir = '數據路徑'
data_dir = pathlib.Path(data_dir)

# Sort the entries so the result does not depend on filesystem order.
data_paths = sorted(data_dir.glob('*'))

# Class names are the names of the sub-directories.  Using ``path.name``
# works with any path separator and any depth of ``data_dir`` (the previous
# ``str(path).split("\\")[1]`` only worked on Windows with a one-component
# root).  Keeping only directories, in sorted order, matches how
# ``datasets.ImageFolder`` assigns class indices.
classeNames = [path.name for path in data_paths if path.is_dir()]
2.3 數據預處理
# Pre-processing applied to every image loaded through ``total_data``.
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to a common size
    # transforms.RandomHorizontalFlip(),  # optional augmentation, disabled here
    transforms.ToTensor(),  # PIL Image / numpy.ndarray -> tensor scaled to [0, 1]
    transforms.Normalize(  # per-channel standardisation; helps the model converge
        # NOTE(review): the original note says these values were obtained by
        # random sampling from a dataset; they match the widely used ImageNet
        # channel statistics - confirm which dataset is meant.
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

# Same pipeline as above.  NOTE: it is defined but not used in the visible
# code - ``total_data`` below is built with ``train_transforms``, so both the
# training and the test split share that pipeline.
test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to a common size
    transforms.ToTensor(),  # PIL Image / numpy.ndarray -> tensor scaled to [0, 1]
    transforms.Normalize(  # per-channel standardisation; helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

# Load the whole dataset; each sub-directory of ``data_dir`` is one class.
total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
2.4 數據劃分
# Random 80 % / 20 % split of the full dataset into training and test subsets.
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
# Evaluation only sums losses and correct predictions, so the order of the
# test samples does not matter; no shuffling keeps the pass deterministic.
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=False)
2.5 搭建網絡模型
import torch.nn as nn
import torch
# NOTE(review): this shadows the builtin ``max``.  Neither name is used by the
# classes below; the import is kept in case other parts of the project rely on it.
from torch import mean, max


class _DenseLayer(nn.Module):
    """One dense layer whose output is concatenated to its input.

    Pipeline: BatchNorm -> ReLU -> 1x1 conv (to ``bn_size * growth_rate``
    channels) -> ``Inceptionnext`` -> ``CBAMBlock`` -> 1x1 conv (to
    ``growth_rate`` channels), followed by optional dropout.

    NOTE(review): ``Inceptionnext`` and ``CBAMBlock`` are not defined in the
    visible part of this file; they must be provided elsewhere.

    Args:
        num_input_features: number of channels of the incoming tensor.
        growth_rate: number of channels this layer adds.
        bn_size: multiplier for the width of the intermediate feature maps.
        drop_rate: dropout probability; dropout is skipped when it is 0.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate=0):
        super(_DenseLayer, self).__init__()
        self.drop_rate = drop_rate
        self.dense_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_input_features, out_channels=bn_size * growth_rate,
                      kernel_size=1, stride=1, padding=0),
            Inceptionnext(bn_size * growth_rate, bn_size * growth_rate, kernel_size=3),
            CBAMBlock("FC", 5, channels=bn_size * growth_rate, ratio=9),
            nn.Conv2d(in_channels=bn_size * growth_rate, out_channels=growth_rate,
                      kernel_size=1, stride=1, padding=0),
        )
        self.dropout = nn.Dropout(p=self.drop_rate)

    def forward(self, x):
        """Return ``x`` concatenated with the new features along dim 1."""
        y = self.dense_layer(x)
        if self.drop_rate > 0:
            y = self.dropout(y)
        return torch.concat([x, y], dim=1)


class _DenseBlock(nn.Module):
    """A stack of ``num_layers`` dense layers.

    Layer ``i`` receives ``num_input_features + i * growth_rate`` channels,
    because every preceding layer appends ``growth_rate`` channels.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate=0):
        super(_DenseBlock, self).__init__()
        layers = []
        for i in range(num_layers):
            layers.append(
                _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class _TransitionLayer(nn.Module):
    """BatchNorm -> ReLU -> 1x1 conv -> 2x2 average pooling (stride 2)."""

    def __init__(self, num_input_features, num_output_features):
        super(_TransitionLayer, self).__init__()
        self.transition_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_input_features, out_channels=num_output_features,
                      kernel_size=1, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.transition_layer(x)


class DenseNet(nn.Module):
    """DenseNet-style classifier built from the blocks defined above.

    Structure: convolutional stem -> four dense blocks separated by three
    transition layers (each halves the channel count) -> global average
    pooling -> linear classifier.

    Args:
        num_init_features: output channels of the stem convolution.
        growth_rate: channels added by every dense layer.
        blocks: number of dense layers in each of the four dense blocks.
        bn_size: width multiplier passed to every dense layer.
        drop_rate: dropout probability passed to every dense layer.
        num_classes: size of the output of the final linear layer.
    """

    def __init__(self, num_init_features=64, growth_rate=32, blocks=(6, 12, 24, 16), bn_size=4,
                 drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_init_features,
                      kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # ``num_features`` tracks the channel count while the network is built.
        num_features = num_init_features
        self.layer1 = _DenseBlock(num_layers=blocks[0], num_input_features=num_features,
                                  growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[0] * growth_rate
        # The attribute spelling "transtion" is kept on purpose: renaming it
        # would change the keys of the model's state_dict.
        self.transtion1 = _TransitionLayer(num_input_features=num_features,
                                           num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer2 = _DenseBlock(num_layers=blocks[1], num_input_features=num_features,
                                  growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[1] * growth_rate
        self.transtion2 = _TransitionLayer(num_input_features=num_features,
                                           num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer3 = _DenseBlock(num_layers=blocks[2], num_input_features=num_features,
                                  growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[2] * growth_rate
        self.transtion3 = _TransitionLayer(num_input_features=num_features,
                                           num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer4 = _DenseBlock(num_layers=blocks[3], num_input_features=num_features,
                                  growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[3] * growth_rate

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the ``num_classes`` raw scores (logits) for each input image."""
        x = self.features(x)

        x = self.layer1(x)
        x = self.transtion1(x)
        x = self.layer2(x)
        x = self.transtion2(x)
        x = self.layer3(x)
        x = self.transtion3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        y = torch.flatten(x, start_dim=1)
        x = self.fc(y)
        return x
2.6 構建densenet121
# Select the compute device (stored as a plain string here).
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# Block configuration (6, 12, 24, 16); one output per class found on disk.
densenet121 = DenseNet(blocks=(6, 12, 24, 16), num_classes=len(classeNames))
model = densenet121.to(device)
2.7 訓練模型
import copy


# Training loop
def train(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the module-level ``device`` to place each batch.

    Returns:
        ``(train_acc, train_loss)``: the fraction of correctly classified
        samples and the mean of the per-batch losses.
    """
    size = len(dataloader.dataset)  # number of training samples
    num_batches = len(dataloader)   # number of batches

    train_loss, train_acc = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass and loss between the network output and the labels.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the number of correct predictions and the batch loss.
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Uses the module-level ``device`` to place each batch.

    Returns:
        ``(test_acc, test_loss)``: the fraction of correctly classified
        samples and the mean of the per-batch losses.
    """
    size = len(dataloader.dataset)  # number of test samples
    num_batches = len(dataloader)   # number of batches
    test_loss, test_acc = 0, 0

    # No gradients are needed for evaluation; this saves memory and compute.
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # loss function

epochs = 20

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0       # best test accuracy seen so far; selects the best model
best_model = None  # snapshot of the model that reached ``best_acc``

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep a copy of the best model.  The first epoch is always stored so
    # that ``best_model`` exists even if the test accuracy never exceeds 0.
    if best_model is None or epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Current learning rate, read back from the optimizer.
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))

# Save the parameters of the best model to a file.
PATH = './best_model.pth'  # name of the saved parameter file
torch.save(best_model.state_dict(), PATH)

print('Done')
2.8 結果可視化
import matplotlib.pyplot as plt
# Hide warnings
import warnings

warnings.filterwarnings("ignore")  # suppress warning messages

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

# Left panel: accuracy per epoch.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

# Right panel: loss per epoch.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()
三、UI設計
這裡使用 Qt Designer 設計了一個簡易的 UI 界面,可以很方便地進行使用。
UI.py文件如下
import test
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import QIcon
from PyQt5.QtWidgets import QFileDialog, QMessageBox

# Path of the image currently selected by the user.  Kept as a module-level
# global so that other code can read the chosen path.
imgName = None


class Ui_Form(object):
    """Form with a file picker, an image preview and the recognition output."""

    def setupUi(self, Form):
        """Create all widgets on ``Form`` and connect the two buttons."""
        Form.setObjectName("Form")
        Form.resize(649, 559)

        # Title label.
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(120, 20, 331, 41))
        self.label.setStyleSheet("\n"
                                 "font: 22pt \"華文彩云\";")
        self.label.setObjectName("label")

        # Image preview.
        self.label_2 = QtWidgets.QLabel(Form)
        self.label_2.setGeometry(QtCore.QRect(120, 130, 311, 251))
        self.label_2.setStyleSheet("border-image: url(:/新前綴/img.png);")
        self.label_2.setText("")
        self.label_2.setObjectName("label_2")

        # Caption and text box below the preview.
        self.label_4 = QtWidgets.QLabel(Form)
        self.label_4.setGeometry(QtCore.QRect(60, 440, 72, 15))
        self.label_4.setObjectName("label_4")
        self.textEdit = QtWidgets.QTextEdit(Form)
        self.textEdit.setGeometry(QtCore.QRect(140, 440, 211, 91))
        self.textEdit.setObjectName("textEdit")

        # Row holding the result caption and the result line edit.
        self.layoutWidget = QtWidgets.QWidget(Form)
        self.layoutWidget.setGeometry(QtCore.QRect(60, 400, 221, 31))
        self.layoutWidget.setObjectName("layoutWidget")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.layoutWidget)
        self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setObjectName("label_3")
        self.horizontalLayout_2.addWidget(self.label_3)
        self.lineEdit_2 = QtWidgets.QLineEdit(self.layoutWidget)
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.horizontalLayout_2.addWidget(self.lineEdit_2)

        # Top row: path line edit, "open file" button, "start" button.
        self.layoutWidget1 = QtWidgets.QWidget(Form)
        self.layoutWidget1.setGeometry(QtCore.QRect(30, 70, 591, 41))
        self.layoutWidget1.setObjectName("layoutWidget1")
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.layoutWidget1)
        self.horizontalLayout_3.setContentsMargins(0, 0, 0, 0)
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget1)
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalLayout.addWidget(self.lineEdit)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget1)
        self.pushButton.setObjectName("pushButton")
        self.horizontalLayout.addWidget(self.pushButton)
        self.horizontalLayout_3.addLayout(self.horizontalLayout)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget1)
        self.pushButton_2.setObjectName("pushButton_2")
        self.horizontalLayout_3.addWidget(self.pushButton_2)

        self.pushButton.clicked.connect(self.openImage)
        self.pushButton_2.clicked.connect(self.inferImage)

        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        """Set the user-visible texts of the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label.setText(_translate("Form", "白內障檢測系統"))
        self.label_4.setText(_translate("Form", "診斷建議:"))
        self.label_3.setText(_translate("Form", "識別結果:"))
        self.pushButton.setText(_translate("Form", "打開文件"))
        self.pushButton_2.setText(_translate("Form", "開始識別"))

    def _dialog_parent(self):
        """Return ``self`` if it is a QWidget, else ``None``.

        Qt dialogs only accept a QWidget (or None) as parent.  This class
        derives from ``object``, so ``self`` is a valid parent only when the
        class is mixed into a QWidget subclass.
        """
        return self if isinstance(self, QtWidgets.QWidget) else None

    def openImage(self):
        """Let the user pick a local image, preview it and remember its path."""
        global imgName  # global so the chosen path can be read elsewhere

        # Returns the selected path and the filter that was active.
        selected, _img_type = QFileDialog.getOpenFileName(
            self._dialog_parent(), "打開圖片", "", "*.jpg;*.png;;All Files(*)")
        if not selected:
            # Dialog cancelled: keep the previous selection and preview.
            return
        imgName = selected

        # Scale the image to the size of the preview label and show it.
        jpg = QtGui.QPixmap(imgName).scaled(self.label_2.width(), self.label_2.height())
        self.label_2.setPixmap(jpg)
        self.lineEdit.setText(imgName)  # show the path of the chosen image

    def inferImage(self):
        """Run ``test.infer`` on the selected image and display both results."""
        global imgName
        if imgName is None or imgName == '':
            QMessageBox.information(self._dialog_parent(), "Error!", "請先選擇圖片!", QMessageBox.Ok)
            return

        a1, a2 = test.infer(imgName)
        self.lineEdit_2.setText(a1)
        self.textEdit.setText(a2)
import asd_rc
四、結果展示
總結
通過本次案例,我們可以對深度學習設計程序的流程有一個簡單清楚的認知,以便我們將來構建其它深度學習系統可以更加得心應手。