深度學習實戰基礎案例——卷積神經網絡(CNN)基于SqueezeNet的眼疾識別|第1例

文章目錄

  • 前言
  • 一、數據準備
      • 1.1 數據集介紹
      • 1.2 數據集文件結構
  • 二、項目實戰
      • 2.1 數據標簽劃分
      • 2.2 數據預處理
      • 2.3 構建模型
      • 2.4 開始訓練
      • 2.5 結果可視化
  • 三、數據集個體預測

前言

SqueezeNet是一種輕量且高效的CNN模型,它參數比AlexNet少50倍,但模型性能(accuracy)與AlexNet接近。顧名思義,Squeeze的中文意思是壓縮和擠壓的意思,所以我們通過算法的名字就可以猜想到,該算法一定是通過壓縮模型來降低模型參數量的。當然任何算法的改進都是在原先的基礎上提升精度或者降低模型參數,因此該算法的主要目的就是在于降低模型參數量的同時保持模型精度。


我的環境:

  • 基礎環境:python3.7
  • 編譯器:pycharm
  • 深度學習框架:pytorch
  • 數據集代碼獲取:鏈接(提取碼:2357 )

一、數據準備

本案例使用的數據集是眼疾識別數據集iChallenge-PM。

1.1 數據集介紹

iChallenge-PM是百度大腦和中山大學中山眼科中心聯合舉辦的iChallenge比賽中,提供的關于病理性近視(Pathologic Myopia,PM)的醫療類數據集,包含1200個受試者的眼底視網膜圖片,訓練、驗證和測試數據集各400張。

  • training.zip:包含訓練中的圖片和標簽
  • validation.zip:包含驗證集的圖片
  • valid_gt.zip:包含驗證集的標簽

該數據集是從AI Studio平臺中下載的,具體信息如下:
在這里插入圖片描述

1.2 數據集文件結構

數據集中共有三個壓縮文件,分別是:

  • training.zip
├── PALM-Training400
│   ├── PALM-Training400.zip
│   │   ├── H0002.jpg
│   │   └── ...
│   ├── PALM-Training400-Annotation-D&F.zip
│   │   └── ...
│   └── PALM-Training400-Annotation-Lession.zip└── ...
  • valid_gt.zip:標記的位置 里面的PM_Lable_and_Fovea_Location.xlsx就是標記文件
├── PALM-Validation-GT
│   ├── Lession_Masks
│   │   └── ...
│   ├── Disc_Masks
│   │   └── ...
│   └── PM_Lable_and_Fovea_Location.xlsx
  • validation.zip:測試數據集
├── PALM-Validation
│   ├── V0001.jpg
│   ├── V0002.jpg
│   └── ...

二、項目實戰

項目結構如下:
在這里插入圖片描述

2.1 數據標簽劃分

該眼疾數據集格式有點復雜,這里我對數據集進行了自己的處理,將訓練集和驗證集寫入txt文本里面,分別對應它的圖片路徑和標簽。

import os
import pandas as pd
# 將訓練集劃分標簽
train_dataset = r"F:\SqueezeNet\data\PALM-Training400\PALM-Training400"
train_list = []
label_list = []train_filenames = os.listdir(train_dataset)for name in train_filenames:filepath = os.path.join(train_dataset, name)train_list.append(filepath)if name[0] == 'N' or name[0] == 'H':label = 0label_list.append(label)elif name[0] == 'P':label = 1label_list.append(label)else:raise('Error dataset!')with open('F:/SqueezeNet/train.txt', 'w', encoding='UTF-8') as f:i = 0for train_img in train_list:f.write(str(train_img) + ' ' +str(label_list[i]))i += 1f.write('\n')
# 將驗證集劃分標簽
valid_dataset = r"F:\SqueezeNet\data\PALM-Validation400"
valid_filenames = os.listdir(valid_dataset)
valid_label = r"F:\SqueezeNet\data\PALM-Validation-GT\PM_Label_and_Fovea_Location.xlsx"
data = pd.read_excel(valid_label)
valid_data = data[['imgName', 'Label']].values.tolist()with open('F:/SqueezeNet/valid.txt', 'w', encoding='UTF-8') as f:for valid_img in valid_data:f.write(str(valid_dataset) + '/' + valid_img[0] + ' ' + str(valid_img[1]))f.write('\n')

2.2 數據預處理

這里采用到的數據預處理,主要有調整圖像大小、隨機翻轉、歸一化等。

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transformstransform_BZ = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]
)class LoadData(Dataset):def __init__(self, txt_path, train_flag=True):self.imgs_info = self.get_images(txt_path)self.train_flag = train_flagself.train_tf = transforms.Compose([transforms.Resize(224),  # 調整圖像大小為224x224transforms.RandomHorizontalFlip(),  # 隨機左右翻轉圖像transforms.RandomVerticalFlip(),  # 隨機上下翻轉圖像transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉換為tensor,并歸一化到[0,1]之間transform_BZ  # 執行某些復雜變換操作])self.val_tf = transforms.Compose([transforms.Resize(224),  # 調整圖像大小為224x224transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉換為tensor,并歸一化到[0,1]之間transform_BZ  # 執行某些復雜變換操作])def get_images(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))return imgs_infodef padding_black(self, img):w, h = img.sizescale = 224. / max(w, h)img_fg = img.resize([int(x) for x in [w * scale, h * scale]])size_fg = img_fg.sizesize_bg = 224img_bg = Image.new("RGB", (size_bg, size_bg))img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,(size_bg - size_fg[1]) // 2))img = img_bgreturn imgdef __getitem__(self, index):img_path, label = self.imgs_info[index]img_path = os.path.join('', img_path)img = Image.open(img_path)img = img.convert("RGB")img = self.padding_black(img)if self.train_flag:img = self.train_tf(img)else:img = self.val_tf(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs_info)

2.3 構建模型

import torch
import torch.nn as nn
import torch.nn.init as initclass Fire(nn.Module):def __init__(self, inplanes, squeeze_planes,expand1x1_planes, expand3x3_planes):super(Fire, self).__init__()self.inplanes = inplanesself.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)self.squeeze_activation = nn.ReLU(inplace=True)self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,kernel_size=1)self.expand1x1_activation = nn.ReLU(inplace=True)self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,kernel_size=3, padding=1)self.expand3x3_activation = nn.ReLU(inplace=True)def forward(self, x):x = self.squeeze_activation(self.squeeze(x))return torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)class SqueezeNet(nn.Module):def __init__(self, version='1_0', num_classes=1000):super(SqueezeNet, self).__init__()self.num_classes = num_classesif version == '1_0':self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(96, 16, 64, 64),Fire(128, 16, 64, 64),Fire(128, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 32, 128, 128),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(512, 64, 256, 256),)elif version == '1_1':self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(64, 16, 64, 64),Fire(128, 16, 64, 64),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(128, 32, 128, 128),Fire(256, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),Fire(512, 64, 256, 256),)else:# FIXME: Is this needed? SqueezeNet should only be called from the# FIXME: squeezenet1_x() functions# FIXME: This checking is not done for the other modelsraise ValueError("Unsupported SqueezeNet version {version}:""1_0 or 1_1 expected".format(version=version))# Final convolution is initialized differently from the restfinal_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)self.classifier = nn.Sequential(nn.Dropout(p=0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1)))for m in self.modules():if isinstance(m, nn.Conv2d):if m is final_conv:init.normal_(m.weight, mean=0.0, std=0.01)else:init.kaiming_uniform_(m.weight)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = self.classifier(x)return torch.flatten(x, 1)

2.4 開始訓練

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import SqueezeNet
import torchsummary
from dataloader import LoadData
import copydevice = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = SqueezeNet(num_classes=2).to(device)
# print(model)
#print(torchsummary.summary(model, (3, 224, 224), 1))# 加載訓練集和驗證集
train_data = LoadData(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,shuffle=True, num_workers=0)
test_data = LoadData(r"F:\SqueezeNet\valid.txt", True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,shuffle=True, num_workers=0)# 編寫訓練函數
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 訓練集的大小num_batches = len(dataloader)  # 批次數目, (size/batch_size,向上取整)print('num_batches:', num_batches)train_loss, train_acc = 0, 0  # 初始化訓練損失和正確率for X, y in dataloader:  # 獲取圖片及其標簽X, y = X.to(device), y.to(device)# 計算預測誤差pred = model(X)  # 網絡輸出loss = loss_fn(pred, y)  # 計算網絡輸出和真實值之間的差距,targets為真實值,計算二者差值即為損失# 反向傳播optimizer.zero_grad()  # grad屬性歸零loss.backward()  # 反向傳播optimizer.step()  # 每一步自動更新# 記錄acc與losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss# 編寫驗證函數
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 測試集的大小num_batches = len(dataloader)  # 批次數目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 當不進行訓練時,停止梯度更新,節省計算內存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 計算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss# 開始訓練epochs = 20train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 設置一個最佳準確率,作為最佳模型的判別指標loss_function = nn.CrossEntropyLoss()  # 定義損失函數
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 定義Adam優化器for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_function, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_function)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 獲取當前的學習率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的參數文件名
torch.save(best_model.state_dict(), PATH)print('Done')

在這里插入圖片描述

2.5 結果可視化

import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False      # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()

可視化結果如下:
在這里插入圖片描述
可以自行調整學習率以及batch_size,這里我的超參數并沒有調整。

三、數據集個體預測

import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
from model import SqueezeNet
import torchdata_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((224, 224)),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = Image.open("F:\SqueezeNet\data\PALM-Validation400\V0008.jpg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
name = ['非病理性近視', '病理性近視']
model_weight_path = r"F:\SqueezeNet\best_model.pth"
model = SqueezeNet(num_classes=2)
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)# 獲得最大可能性索引predict_cla = torch.argmax(predict).numpy()print('索引為', predict_cla)
print('預測結果為:{},置信度為: {}'.format(name[predict_cla], predict[predict_cla].item()))
plt.show()
索引為 1
預測結果為:病理性近視,置信度為: 0.9768268465995789

在這里插入圖片描述

更詳細的請看paddle版本的實現:深度學習實戰基礎案例——卷積神經網絡(CNN)基于SqueezeNet的眼疾識別

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/39288.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/39288.shtml
英文地址,請注明出處:http://en.pswp.cn/news/39288.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linkedin為什么要退出中國市場?

在迅速發展的時代,職場也在不斷變換,只有不斷地提升專業技能和進行培訓,才能在職場中獲得成功。Linkedin作為一家專注于職業發展的平臺,專業的學習體驗以及熱門技能贏得了人們青睞。然而遺憾的是這個曾經讓人備受青睞的平臺,如今卻在中國市場中黯然落幕,究竟是何種原因讓曾經風…

大數據Flink(六十一):Flink流處理程序流程和項目準備

文章目錄 Flink流處理程序流程和項目準備 一、Flink流處理程序的一般流程

Springboot 設置統一的請求返回格式

現在開發過程中主要采用前后端分離的方式進行開發測試,也就是前端封裝請求,后端提供標準的API接口服務。一般現在json 格式受到開發者們的青睞,學習過程中我們可以設置接口的返回類型,那么怎么做到設置統一的返回格式呢&#xff1…

數據在內存中的存儲(二進制形式存儲)

計算機要處理的信息是多種多樣的,如數字、文字、符號、圖形、音頻、視頻等,這些信息在人們的眼里是不同的。但對于計算機來說,它們在內存中都是一樣的,都是以二進制的形式來表示。 要想學習編程,就必須了解二進制&…

Spark SQL優化:NOT IN子查詢優化解決

背景 有如下的數據查詢場景。 SELECT a,b,c,d,e,f FROM xxx.BBBB WHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} AND predict_type not IN ( SELECT distinct a FROM xxx.AAAAAWHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} ) 分析 通過查看SQL語句的執行計劃基本…

Dubbo基礎學習(筆記一)

目錄 第一章、概念介紹1.1)什么是RPC框架1.2)什么是分布式系統1.3)Dubbo概述1.3)Dubbo基本架構 第二章、服務提供者2.1)目錄結構和依賴2.2)model層2.3)service層2.4)resources配置文…

ARTS 挑戰打卡的第8天 ---volatile 關鍵字在MCU中的作用,四個實例講解(Tips)

前言 (1)volatile 關鍵字作為嵌入式面試的常考點,很多人都不是很了解,或者說一知半解。 (2)可能有些人會說了,volatile 關鍵字不就是防止編譯器優化的嗎?有啥好詳細講解的&#xff1…

HashMap底層相關內容

HashMap的底層結構: 1.7之前 數組加鏈表,當兩個值進行插入的時候 采用頭插法進行插入,可能會造成死循環 1.8之后 數組加鏈表/紅黑樹,當兩個值進行插入的時候,采用尾插法進行插入,不會造成死循環 HashMap底…

xml轉map工具類

背景&#xff1a;最近遇到接口返回是xml&#xff0c;所以需要整一個轉換的工具類&#xff0c;方便后續其他xml處理。 依賴引入&#xff1a; <dependency><groupId>dom4j</groupId><artifactId>dom4j</artifactId><version>1.1</versi…

澎峰科技|邀您關注2023 RISC-V中國峰會!

峰會概覽 2023 RISC-V中國峰會&#xff08;RISC-V Summit China 2023&#xff09;將于8月23日至25日在北京香格里拉飯店舉行。本屆峰會將以“RISC-V生態共建”為主題&#xff0c;結合當下全球新形勢&#xff0c;把握全球新時機&#xff0c;呈現RISC-V全球新觀點、新趨勢。 本…

linux下nginx配置https和反向代理本地端口

1 修改配置文件/etc/nginx/sites-enabled/default 在配置文件中增加一個server用來做https端口監聽&#xff0c; ssl_certificate和ssl_certificate_key修改為自己申請的https認證文件 server{listen 443 ssl;server_name www.dogrich.net;#root /var/www/html;# 上面配置的…

《3D 數學基礎》12 幾何圖元

目錄 1 表達圖元的方法 1.1 隱式表示法 1.2 參數表示 1.3 直接表示 2. 直線和射線 2.1 射線的不同表示法 2.1.1 兩點表示 2.1.2 參數表示 2.1.3 相互轉換 2.2 直線的不同表示法 2.2.1 隱式表示法 2.2.2 斜截式 2.2.3 相互轉換 3. 球 3.1 隱式表示 1 表達圖元的方…

C語言的使用技巧--在IO操作中的移位和快速配置

在WB32F103&#xff08;ARM cortex m3內核&#xff0c;96Mhz&#xff09;的gpio初始化中有一段代碼&#xff0c;充分的結合了硬件特征并使用C語言的技巧來快速的配置對應的GPIO的功能&#xff0c;堪稱經典和楷模&#xff0c;代碼異常簡潔&#xff0c;執行速度快&#xff0c;配置…

【深度學習所有損失函數】在 NumPy、TensorFlow 和 PyTorch 中實現(2/2)

一、說明 在本文中&#xff0c;討論了深度學習中使用的所有常見損失函數&#xff0c;并在NumPy&#xff0c;PyTorch和TensorFlow中實現了它們。 (二-五)見 六、稀疏分類交叉熵損失 稀疏分類交叉熵損失類似于分類交叉熵損失&#xff0c;但在真實標簽作為整數而不是獨熱編碼提…

Python pycparser(c文件解析)模塊使用教程

文章目錄 安裝 pycparser 模塊模塊開發者網址獲取抽象語法樹1. 需要導入的模塊2. 獲取 不關注預處理相關 c語言文件的抽象語法樹ast3. 獲取 預處理后的c語言文件的抽象語法樹ast 語法樹組成1. 數據類型定義 Typedef2. 類型聲明 TypeDecl3. 標識符類型 IdentifierType4. 變量聲明…

語聚AI公測發布,大語言模型時代下新的生產力工具

語聚AI 公測發布 距離語聚AI內測上線已經過去近1個月。 這期間&#xff0c;我們共邀請了近百位資深用戶與行業專家加入語聚AI產品體驗。通過大家的熱情參與積極反饋&#xff0c;我們不斷優化并完善了語聚AI的功能與使用體驗。 經過研發團隊不懈的努力&#xff0c;今天語聚AI終…

【Leetcode】88.合并兩個有序數組

一、題目 1、題目描述 給你兩個按 非遞減順序 排列的整數數組 nums1 和 nums2,另有兩個整數 m 和 n ,分別表示 nums1 和 nums2 中的元素數目。 請你 合并 nums2 到 nums1 中,使合并后的數組同樣按 非遞減順序 排列。 注意:最終,合并后數組不應由函數返回,而是存儲在數…

梅賽德斯-奔馳將成為首家集成ChatGPT的汽車制造商

ChatGPT的受歡迎程度毋庸置疑。OpenAI這個基于人工智能的工具&#xff0c;每天能夠吸引無數用戶使用&#xff0c;已成為當下很受歡迎的技術熱點。因此&#xff0c;有許多公司都在想方設法利用ChatGPT來提高產品吸引力&#xff0c;賣點以及性能。在汽車領域&#xff0c;梅賽德斯…

代碼隨想錄算法訓練營第59天|動態規劃part16|583. 兩個字符串的刪除操作、72. 編輯距離、編輯距離總結篇

代碼隨想錄算法訓練營第59天&#xff5c;動態規劃part16&#xff5c;583. 兩個字符串的刪除操作、72. 編輯距離、編輯距離總結篇 583. 兩個字符串的刪除操作 583. 兩個字符串的刪除操作 思路&#xff1a; 思路見代碼 代碼&#xff1a; python class Solution(object):de…

[國產MCU]-BL602開發實例-I2C與總線設備地址掃描

I2C與總線設備掃描 文章目錄 I2C與總線設備掃描1、I2C介紹2、I2C驅動API介紹3、I2C使用實例I2C (Inter-Intergrated Circuit)是一種串行通訊總線,使用多主從架構,用來連接低速外圍裝置。 每個器件都有一個唯一的地址識別,并且都可以作為一個發送器或接收器。每個連接到總線的…