Pytorch從零開始實戰12

Pytorch從零開始實戰——DenseNet算法實戰

本系列來源于365天深度學習訓練營

原作者K同學

文章目錄

  • Pytorch從零開始實戰——DenseNet算法實戰
    • 環境準備
    • 數據集
    • 模型選擇
    • 開始訓練
    • 可視化
    • 總結

環境準備

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需讀者自行配置好環境且有一些深度學習理論基礎。本次實驗的目的是理解并使用DenseNet模型,本次實驗由于參數較大,建議使用GPU進行計算。
第一步,導入常用包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import random
from time import time
import numpy as np
import pandas as pd
import datetime
import gc
import os
import copy
import warnings
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter環境突然關閉
torch.backends.cudnn.benchmark=True  # 用于加速GPU運算的代碼

設置隨機數種子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

檢查設備對象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count() # # (device(type='cuda'), 2)

數據集

本次實驗將探索在醫學領域使用深度學習,準確識別和分類乳腺癌亞型是一項重要的臨床任務,利用深度學習方法識別可以有效節省時間并減少錯誤。數據集是由多張以 40 倍掃描的乳腺癌 (BCa) 標本的完整載玻片圖像組成。
使用pathlib查看類別,本次類別只有0,1兩種類別分別代表不患癌和患癌

import pathlib
data_dir = './data/ill/'
data_dir = pathlib.Path(data_dir) # 轉成pathlib.Path對象
data_paths = list(data_dir.glob('*')) 
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['0', '1']

使用transforms對數據集進行統一處理,并且根據文件夾名映射對應標簽

all_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標準化
])total_data = datasets.ImageFolder("./data/ill/", transform=all_transforms)
total_data.class_to_idx # {'0': 0, '1': 1}

隨機查看5張圖片

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子圖for i in range(5):num = random.randint(0, len(data) - 1) #首先選取隨機數,隨機選取五次#抽取數據中對應的圖像對象,make_grid函數可將任意格式的圖像的通道數升為3,而不改變圖像原始的數據#而展示圖像用的imshow函數最常見的輸入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取標簽 #將圖像由(3, weight, height)轉化為(weight, height, 3),并放入imshow函數中讀取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #給每個子圖加上標簽axs[i].axis("off") #消除每個子圖的坐標軸plotsample(total_data)

在這里插入圖片描述
根據8比2劃分數據集和測試集,并且利用DataLoader劃分批次和隨機打亂

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True,)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=batch_size,shuffle=True,)len(train_dl.dataset), len(test_dl.dataset) # (10722, 2681)

模型選擇

本次實驗使用DenseNet模型,DenseNet的設計核心思想是通過密集連接來增強神經網絡的信息流動,促進梯度的傳播,以及提高參數的共享和重復使用。采用跨通道concat的形式來連接,會連接前面所有層作為輸入,這里的連接不是ResNet那樣的相加,而在channel維度的疊加。
核心公式為:
在這里插入圖片描述
DenseNet中的基本組成單元是DenseBlock,它由多個密集連接的DenseLayer組成。每個DenseLayer都接收所有前面的DenseLayer特征作為輸入,將其連接到自己的輸出上,并傳遞給后續的層。如圖所示,這是一個基本的DenseBlock模塊。
在這里插入圖片描述
整體網絡架構圖如下所示,借用K同學的圖片
在這里插入圖片描述

為了控制模型的復雜度并減少特征圖的大小,DenseNet引入了Transition Block。過渡塊包括批歸一化、ReLU激活和 1x1 卷積,以減小特征圖的通道數,并通過池化操作降低空間維度。
在這里插入圖片描述
首先對DenseLayer類定義,本次實驗使用add_module函數,默認是用于向類中添加一個子模塊,第一個參數為模塊名,第二個參數為模塊實例,其實相當于加到父類的nn.Sequential里面,所以調用的時候使用super().forward(x),這段的核心是將輸入 x 與新特征 t 進行通道維度上的連接,完成密集連接。

class DenseLayer(nn.Sequential):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super().__init__()self.add_module("norm1", nn.BatchNorm2d(num_input_features))self.add_module("relu1", nn.ReLU(inplace=True))self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False))self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))self.add_module("relu2", nn.ReLU(inplace=True))self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False))self.drop_rate = drop_ratedef forward(self, x):t = super().forward(x)if self.drop_rate > 0:t = F.dropout(t, p=self.drop_rate, training=self.training)return torch.cat([x, t], 1)

下面是DenseBlock的實現,通過循環創建了多個DenseLayer。其中的 num_input_features + i * growth_rate 用于指定輸入通道的數量,確保每個DenseLayer的輸入通道數逐漸增加。將新創建的DenseLayer添加為 DenseBlock 的子模塊。循環結束后,DenseBlock 就包含了多個DenseLayer,每個DenseLayer都具有逐漸增加的輸入通道數量。

class DenseBlock(nn.Sequential):def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):super().__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)self.add_module("denselayer%d" % (i + 1), layer)

下面是Transition,實現過渡的功能,是在塊之間降低通道數量和空間維度。

class Transition(nn.Sequential):def __init__(self, num_input_feature, num_output_features):super().__init__()self.add_module("norm", nn.BatchNorm2d(num_input_feature))self.add_module("relu", nn.ReLU(inplace=True))self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features, kernel_size=1, stride=1, bias=False))self.add_module("pool", nn.AvgPool2d(2, stride=2))

模型實現,self.features 是一個包含多個層的序列,包括初始卷積塊、多個DenseBlock和Transition,以及最后的全局平均池化和分類器。遍歷 block_config 中的配置,創建DenseBlock和Transition。參數初始化部分使用了 Kaiming 初始化和常數初始化。
其中,OrderedDict是Python中的一種有序字典數據結構,它保留了元素添加的順序。在神經網絡中,我們可以使用OrderedDict來指定模型的層次結構。

from collections import OrderedDict
class DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):super().__init__()self.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(3, stride=2, padding=1))]))num_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module("denseblock%d" % (i + 1), block)num_features += num_layers * growth_rateif i != len(block_config) - 1:transition = Transition(num_features, int(num_features * compression_rate))self.features.add_module("transition%d" % (i + 1), transition)num_features = int(num_features * compression_rate)self.features.add_module("norm5", nn.BatchNorm2d(num_features))self.features.add_module("relu5", nn.ReLU(inplace=True))self.classifier = nn.Linear(num_features, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)out = self.classifier(out)return out

使用summary查看網絡
在這里插入圖片描述

開始訓練

定義訓練函數

def train(dataloader, model, loss_fn, opt):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)opt.zero_grad()loss.backward()opt.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定義測試函數

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

定義學習率、損失函數、優化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)

開始訓練,epoch設置為20

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []T1 = time.time()best_acc = 0
best_model = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 確保模型不會進行訓練操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))T2 = time.time()
print('程序運行時間:%s秒' % (T2 - T1))PATH = './best_model.pth'  # 保存的參數文件名
if best_model is not None:torch.save(best_model.state_dict(), PATH)print('保存最佳模型')
print("Done")

效果還是不錯的
在這里插入圖片描述

可視化

可視化訓練過程與測試過程

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False      # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在這里插入圖片描述

總結

本次實驗學習到了一個更激進的密集連接機制,每個層都會包含前面層所有的輸入,而且與ResNet不同,層與層之間使用疊加的方式進行連接,來增強神經網絡的信息流動,促進梯度的傳播,以及提高參數的共享和重復使用,使得模型表現出不錯的效果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/207240.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/207240.shtml
英文地址,請注明出處:http://en.pswp.cn/news/207240.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Elasticsearch、Logstash、Kibana(ELK)環境搭建

下面是 Elasticsearch、Logstash、Kibana(ELK)環境搭建的具體操作步驟: 安裝 Java ELK 是基于 Java 編寫的,因此需要先安裝 Java。建議安裝 Java 8 或以上版本。 下載并安裝 Elasticsearch Elasticsearch 是一個基于 Lucene 的…

DevEco Studio 運行項目有時會自動出現.js和.map文件

運行的時候報錯了,發現多了.js和.map,而且還不是一個,很多個。 通過查詢,好像是之前已知問題了,給的建議是手動刪除(一個一個刪),而且有的評論還說,一周出現了3次,太可怕了。 搜的過…

【網絡編程】-- 02 端口、通信協議

網絡編程 3 端口 端口表示計算機上的一個程序的進程 不同的進程有不同的端口號!用來區分不同的軟件進程 被規定總共0~65535 TCP,UDP:65535 * 2 在同一協議下,端口號不可以沖突占用 端口分類: 公有端口:0~1023 HT…

【android開發-23】android中WebView的用法詳解

1,WabView的用法 在Android中,WebView是一個非常重要的組件,它允許我們在Android應用中嵌入網頁,展示HTML內容。WebView是Android SDK中提供的標準組件,使用它我們可以很方便地將web頁面直接嵌入到Android應用中。Web…

亞信安慧AntDB數據庫中級培訓ACP上線,中國移動總部首批客戶認證通過

近日,亞信安慧AntDB數據庫ACP(AntDB Certified Professional)中級培訓課程于官網上線。在中國移動總部客戶運維團隊、現場項目部伙伴和AntDB數據庫成員的協同組織下,首批中級認證學員順利完成相關課程的培訓,并獲得Ant…

自然語言處理22-基于本地知識庫的快速問答系統,利用大模型的中文訓練集為知識庫

大家好,我是微學AI,今天給大家介紹一下自然語言處理22-基于本地知識庫的快速問答系統,利用大模型的中文訓練集為知識庫。我們的快速問答系統是基于本地知識庫和大模型的最新技術,它利用了經過訓練的中文大模型,該模型使用了包括alpaca_gpt4_data的開源數據集。 一、本地…

C //例10.3 從鍵盤讀入若干個字符串,對它們按字母大小的順序排序,然后把排好序的字符串送到磁盤文件中保存。

C程序設計 (第四版) 譚浩強 例10.3 例10.3 從鍵盤讀入若干個字符串,對它們按字母大小的順序排序,然后把排好序的字符串送到磁盤文件中保存。 IDE工具:VS2010 Note: 使用不同的IDE工具可能有部分差異。 代碼塊 方法…

2023_Spark_實驗二十五:SparkStreaming讀取Kafka數據源:使用Direct方式

SparkStreaming讀取Kafka數據源:使用Direct方式 一、前提工作 安裝了zookeeper 安裝了Kafka 實驗環境:kafka zookeeper spark 實驗流程 二、實驗內容 實驗要求:實現的從kafka讀取實現wordcount程序 啟動zookeeper zk.sh start# zk.sh…

生成元(Digit Generator, ACM/ICPC Seoul 2005, UVa1583)

如果x加上x的各個數字之和得到y,就說x是y的生成元。 給出n(1≤n≤100000),求最小生成元。 無解輸出0。 例如,n216,121,2005時的解分別為198,0,1979。 我的思路很簡單&am…

element-UI中el-scrollbar的使用

在elment-ui中有這么一個滾動條&#xff0c;當鼠標over到內容部分才會顯示&#xff0c;移開鼠標之后滾動條就會隱藏起來&#xff0c;相較于原生的滾動條比較美觀。 <el-scrollbar> //將滾動條的內部的內容放在里面即可 </el-scrollbar> 在使用過程中&#xff…

SNMP陷阱監控工具

SNMP&#xff08;簡單網絡管理協議&#xff09;是網絡管理的一個重要方面&#xff0c;其中網絡設備&#xff08;包括路由器、交換機和服務器&#xff09;在滿足預定義條件時將SNMP陷阱作為異步通知發送到中央管理系統。簡而言之&#xff0c;每當發生關鍵服務器不可用或硬件高溫…

microblaze仿真

verdivcs (1) vlogan/vcs增加編譯選項 -debug_accessall -kdb -lca (2) 在 simulation 選項中加入下面三個選項 -guiverdi UVM_VERDI_TRACE"UVM_AWARERALHIERCOMPWAVE" UVM_TR_RECORD 這里 -guiverdi是啟動verdi 和vcs聯合仿真。UVM_VERDI_TRACE 這里是記錄 U…

第四十二篇,MATLAB on Linux

最近在Ubuntu上安裝了一把MATLAB&#xff0c;以下操作親測有效。 一、版本 Linux&#xff1a;Ubuntu 18.04 MATLAB&#xff1a;R2021a Linux版&#xff0c;910 MATLAB下載鏈接&#xff1a;提取碼MUYU&#xff0c;感謝大佬無私奉獻&#xff01; 二、安裝 詳細的安裝步驟不…

linux高級篇基礎理論七(Tomcat)

??作者&#xff1a;小劉在C站 ??個人主頁&#xff1a; 小劉主頁 ??不能因為人生的道路坎坷,就使自己的身軀變得彎曲;不能因為生活的歷程漫長,就使求索的 腳步遲緩。 ??學習兩年總結出的運維經驗&#xff0c;以及思科模擬器全套網絡實驗教程。專欄&#xff1a;云計算技…

算法題,文本左右對齊

/*** 給定一個單詞數組 words 和一個長度 maxWidth &#xff0c;重新排版單詞&#xff0c;使其成為每行恰好有 maxWidth 個字符&#xff0c;且左右兩端對齊的文本。** 你應該使用 “貪心算法” 來放置給定的單詞&#xff1b;也就是說&#xff0c;盡可能多地往每行中放置單詞。必…

ubuntu22.04系統更改完resolv.conf后 重啟網絡服務后resolv.conf被重置

vi /etc/systemd/resolved.conf&#xff0c; [Resolve] DNS8.8.8.8 114.114.114.114 192.168.4.2 2.重啟域名解析服務 systemctl restart systemd-resolved systemctl enable systemd-resolved 3.備份當前的/etc/resolve.conf&#xff0c;并重新設置/run/systemd/resolve/res…

Docker 安裝 Centos和寶塔

1. 安裝centos docker pull centos:centos7 2. 創建docker容器&#xff1a;newbt 代表容器名 docker run -i -t -d --name newbt -p 2000:20 -p 2100:21 -p 8000:80 -p 4430:443 -p 8880:888 -p 8888:8888 -p 38444:38444 -p 2200:22 -p 2300:23 -p 2500:25 -p 3306:3306 -p 6…

c++ 解析zip文件,實現對流式文件pptx內容的修改

libzip 官網地址&#xff1a;示例代碼 #include <iostream> #include <cstdlib> #include <cstring> #include <ctime> #include <zip.h>//解析原始zip內容&#xff0c;保存為新的zip文件 int ziptest(const char* inputPath, const char* out…

vue pc官網頂部導航欄組件

官網頂部導航分為一級導航和二級導航 導航的樣子 文件的層級 router 文件層級 header 組件代碼 <h1 class"logo-wrap"><router-link to"/"><img class"logo" :src"$config.company.logo" alt"" /><i…

直面雙碳目標,優維科技攜手奧意建筑打造綠色低碳建筑數智云平臺

優維“雙碳”戰略合作建筑 為落實創新驅動發展戰略&#xff0c;增強深圳工程建設領域科技創新能力&#xff0c;促進技術進步、科技成果轉化和推廣應用&#xff0c;根據《深圳市工程建設領域科技計劃項目管理辦法》《深圳市住房和建設局關于組織申報2022年深圳市工程建設領域科…