[高光譜]PyTorch使用CNN對高光譜圖像進行分類

項目原地址:

Hyperspectral-Classificationicon-default.png?t=N6B9https://github.com/eecn/Hyperspectral-ClassificationDataLoader講解:

[高光譜]使用PyTorch的dataloader加載高光譜數據icon-default.png?t=N6B9https://blog.csdn.net/weixin_37878740/article/details/130929358

一、模型加載

? ? ? ? 在原始項目中,提供了14種模型可供選擇,從最簡單的SVM到3D-CNN,這里以2D-CNN為例,在原項目中需要將model屬性設置為:sharma。

? ? ? ? ?模型通過一個get_model(.)函數獲得,該函數一共四個返回(model, optimizer, loss, hyperparams;分別為:模型,迭代器,損失函數,超參數),輸入為模型類別。

? ? ? ? 進入函數內部,找到對應的函數體如下:

elif name == 'sharma':kwargs.setdefault('batch_size', 60)        #batch_szieepoch = kwargs.setdefault('epoch', 30)     #迭代數lr = kwargs.setdefault('lr', 0.05)         #學習率center_pixel = True                        #是否開啟中心像素模型# We assume patch_size = 64kwargs.setdefault('patch_size', 64)        #patch_szie,即圖像塊大小model = SharmaEtAl(n_bands, n_classes, patch_size=kwargs['patch_size'])  #模型本體optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)    #迭代器criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])            #交叉熵損失函數kwargs.setdefault('scheduler', optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1))

? ? ? ? 這里設置了一部分超參數,同時設置了patch_size為64(此概念可以參見dataloader篇),采用的損失函數為常見的交叉熵損失函數,而模型本體則是使用SharmaEtAl(.)進行加載。

二、模型本體

? ? ? ? 跳轉至SharmaEtAl(nn.Module),其繼承自nn.model,輸入參數3個,分別為:輸入通道數、分類數、圖塊尺寸。

def __init__(self, input_channels, n_classes, patch_size=64):

? 該網絡的結構如圖,模型中里面包含3個卷積、2個bn、2個池化和2個全連接,如下:

# 卷積層1
self.conv1 = nn.Conv3d(1, 96, (input_channels, 6, 6), stride=(1,2,2))
self.conv1_bn = nn.BatchNorm3d(96)
self.pool1 = nn.MaxPool3d((1, 2, 2))
# 卷積層2
self.conv2 = nn.Conv3d(1, 256, (96, 3, 3), stride=(1,2,2))
self.conv2_bn = nn.BatchNorm3d(256)
self.pool2 = nn.MaxPool3d((1, 2, 2))
# 卷積層3
self.conv3 = nn.Conv3d(1, 512, (256, 3, 3), stride=(1,1,1))# 展平函數
self.features_size = self._get_final_flattened_size()# 由兩個全連接組成的分類器
self.fc1 = nn.Linear(self.features_size, 1024)
self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(1024, n_classes)

? ? ? ? 其中的展平函數_get_final_flattened_size(.),并不實際參與前向傳遞,僅計算轉換后的通道數。

    def _get_final_flattened_size(self):with torch.no_grad():x = torch.zeros((1, 1, self.input_channels,self.patch_size, self.patch_size))x = F.relu(self.conv1_bn(self.conv1(x)))x = self.pool1(x)print(x.size())b, t, c, w, h = x.size()x = x.view(b, 1, t*c, w, h) x = F.relu(self.conv2_bn(self.conv2(x)))x = self.pool2(x)print(x.size())b, t, c, w, h = x.size()x = x.view(b, 1, t*c, w, h) x = F.relu(self.conv3(x))print(x.size())_, t, c, w, h = x.size()return t * c * w * h

? ? ? ? 實際的前向傳遞如下:

    def forward(self, x):# 卷積塊1x = F.relu(self.conv1_bn(self.conv1(x)))x = self.pool1(x)# 獲取tensor尺寸b, t, c, w, h = x.size()# 調整tensor尺寸x = x.view(b, 1, t*c, w, h) # 卷積塊2x = F.relu(self.conv2_bn(self.conv2(x)))x = self.pool2(x)# 獲取tensor尺寸b, t, c, w, h = x.size()# 調整tensor尺寸x = x.view(b, 1, t*c, w, h) # 卷積塊3x = F.relu(self.conv3(x))# 調整tensor尺寸x = x.view(-1, self.features_size)# 分類器x = self.fc1(x)x = self.dropout(x)x = self.fc2(x)return x

三、訓練與測試

? ? ? ? 主函數中,訓練和測試結構如下:

        try:train(model, optimizer, loss, train_loader, hyperparams['epoch'],scheduler=hyperparams['scheduler'], device=hyperparams['device'],supervision=hyperparams['supervision'], val_loader=val_loader,display=viz)except KeyboardInterrupt:# Allow the user to stop the trainingpassprobabilities = test(model, img, hyperparams)prediction = np.argmax(probabilities, axis=-1)

? ? ? ? 訓練被封裝在train(.)函數中,測試封裝在test(.)函數中,下面逐一來看。

? ? ? ? 首先是train函數,這里省去外圍部分,僅看核心的循環控制段。

# 外循環控制,用于控制輪次(epoch)
for e in tqdm(range(1, epoch + 1), desc="Training the network"):# 進入訓練模式net.train()avg_loss = 0.# 從dataloader中取出圖像(data)和標簽(target)for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):# 如果是GPU模式則需要轉換為cuda格式data, target = data.to(device), target.to(device)#---實際的訓練部分---## 凍結梯度optimizer.zero_grad()# 訓練模式(監督訓練/半監督訓練)if supervision == 'full':# 前向傳遞output = net(data)#target = target - 1# 交叉熵損失函數loss = criterion(output, target)elif supervision == 'semi':outs = net(data)output, rec = outs#target = target - 1loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)#---實際的訓練部分---## 損失函數反向傳遞loss.backward()# 迭代器步進optimizer.step()# 記錄損失函數avg_loss += loss.item()losses[iter_] = loss.item()mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])iter_ += 1del(data, target, loss, output)

? ? ? ? 接下來是test函數,與train不同的是,其參數為:model, img, hyperparams。其中img,是一整張高光譜圖像,而不是由DataSet塊采樣后的圖像塊。故其結構也與train大不相同。

? ? ? ? 在進行測試的時候,需要一個滑動窗口(sliding_window)函數將其進行切塊以滿足圖像輸入的要求。同時還需要一個grouper函數將其組裝為batch送入神經網絡中。所以我們可以看到循環控制的最外層實際上就是上面兩個函數來組成的。

    # 圖像切塊iterations = count_sliding_window(img, **kwargs) // batch_sizefor batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),total=(iterations),desc="Inference on the image"):#  鎖定梯度with torch.no_grad():#  逐像素模式if patch_size == 1:data = [b[0][0, 0] for b in batch]data = np.copy(data)data = torch.from_numpy(data)# 其他模式else:data = [b[0] for b in batch]data = np.copy(data)data = data.transpose(0, 3, 1, 2)data = torch.from_numpy(data)data = data.unsqueeze(1)indices = [b[1:] for b in batch]# 類型轉換data = data.to(device)# 前向傳遞output = net(data)if isinstance(output, tuple):output = output[0]output = output.to('cpu')if patch_size == 1 or center_pixel:output = output.numpy()else:output = np.transpose(output.numpy(), (0, 2, 3, 1))for (x, y, w, h), out in zip(indices, output):# 將得到的像素平裝回原尺寸if center_pixel:probs[x + w // 2, y + h // 2] += outelse:probs[x:x + w, y:y + h] += outreturn probs

? ? ? ? 這個函數會使用上述的兩個函數,將圖像切割成可以放入神經網絡的尺寸并逐個進行前向傳遞,最后將得到的所有像素的結果按照原來的尺寸組成一個結果矩陣返回。

? ? ? ? 最后,這個結果由一個argmax函數得到其概率最大的預測結果:

prediction = np.argmax(probabilities, axis=-1)

四、結果計算

? ? ? ? 在完成上述步驟后,由metrics(.)函數計算最終的模型結果:

run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)

? ? ? ? 其函數體如下:

def metrics(prediction, target, ignored_labels=[], n_classes=None):"""Compute and print metrics (accuracy, confusion matrix and F1 scores).Args:prediction: list of predicted labelstarget: list of target labelsignored_labels (optional): list of labels to ignore, e.g. 0 for undefn_classes (optional): number of classes, max(target) by defaultReturns:accuracy, F1 score by class, confusion matrix"""ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)for l in ignored_labels:ignored_mask[target == l] = Trueignored_mask = ~ignored_mask#target = target[ignored_mask] -1target = target[ignored_mask]prediction = prediction[ignored_mask]results = {}n_classes = np.max(target) + 1 if n_classes is None else n_classescm = confusion_matrix(target,prediction,labels=range(n_classes))results["Confusion matrix"] = cm# Compute global accuracytotal = np.sum(cm)accuracy = sum([cm[x][x] for x in range(len(cm))])accuracy *= 100 / float(total)results["Accuracy"] = accuracy# Compute F1 scoreF1scores = np.zeros(len(cm))for i in range(len(cm)):try:F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))except ZeroDivisionError:F1 = 0.F1scores[i] = F1results["F1 scores"] = F1scores# Compute kappa coefficientpa = np.trace(cm) / float(total)pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \float(total * total)kappa = (pa - pe) / (1 - pe)results["Kappa"] = kappareturn results

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/41572.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/41572.shtml
英文地址,請注明出處:http://en.pswp.cn/news/41572.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用JMeter創建數據庫測試

好吧!我一直覺得我不聰明,所以,我用最詳細,最明了的方式來書寫這個文章。我相信,我能明白的,你們一定能明白。 我的環境:MySQL:mysql-essential-5.1.51-win32 jdbc驅動:…

mysql 03.查詢(重點)

先準備測試數據,代碼如下: -- 創建數據庫 DROP DATABASE IF EXISTS mydb; CREATE DATABASE mydb; USE mydb;-- 創建student表 CREATE TABLE student (sid CHAR(6),sname VARCHAR(50),age INT,gender VARCHAR(50) DEFAULT male );-- 向student表插入數據…

PHP 公交公司充電樁管理系統mysql數據庫web結構apache計算機軟件工程網頁wamp

一、源碼特點 PHP 公交公司充電樁管理系統是一套完善的web設計系統,對理解php編程開發語言有幫助,系統具有完整的源代碼和數據庫,系統主要采用B/S模式開發。 源碼下載 https://download.csdn.net/download/qq_41221322/88220946 論文下…

【面試問題】當前系統查詢接口需要去另外2個系統庫中實時查詢返回結果拼接優化思路

文章目錄 場景描述優化思路分享資源 場景描述 接口需要從系統1查詢數據,查出的每條數據需要從另一個系統2中再去查詢某些字段, 比如:從系統1中查出100條數據,每條數據需要去系統2中再去查詢出行數據,可能系統1一條數…

socks5 保障網絡安全與爬蟲需求的完美融合

Socks5代理:跨足網絡安全和爬蟲領域的全能選手 Socks5代理作為一種通用的網絡協議,為多種應用場景提供了強大的代理能力。它不僅支持TCP和UDP的數據傳輸,還具備更高級的安全特性,如用戶身份驗證和加密通信。在網絡安全中&#xf…

蘋果手機批量刪除聯系人的2個方法,請查收!

【想要清理通訊錄里的“僵尸號”,但是突然發現手機不能批量刪除。一個一個刪除太麻煩了,有什么辦法可以一次性多刪幾個人嗎?】 小編想問問果粉們平時都是怎么刪除聯系人的?特別是要刪除多個聯系人的時候,大家還是選擇…

matlab保存圖片

僅作為記錄,大佬請跳過。 文章目錄 用界面中的“另存為”用saveas 用界面中的“另存為” 即可。 參考 感謝大佬博主文章:傳送門 用saveas 必須在編輯器中的plot之后用saveas(也就是不能在命令行中單獨使用——比如在編輯器中plot&#xf…

神經網絡基礎-神經網絡補充概念-46-指數加權平均的偏差修正

由來 指數加權平均(Exponential Moving Average,EMA)在初始時可能會受到偏差的影響,特別是在數據量較小時,EMA的值可能會與實際數據有較大的偏差。為了修正這種偏差,可以使用偏差修正方法,通常…

基于平臺的城市排水泵站管理系統設計

安科瑞 耿敏花 近年來我國城市內澇災害頻發,造成人員傷亡以及經濟損失嚴重,嚴重威脅著城市的安全。數據顯示,2015-2018年我國平均每年受淹或發生內澇城市的數量約占我國城市數量的1/5;人民生命財產也損失嚴重,據統計&a…

基于YOLOv5n/s/m不同參數量級模型開發構建茶葉嫩芽檢測識別模型,使用pruning剪枝技術來對模型進行輕量化處理,探索不同剪枝水平下模型性能影響【續】

這里主要是前一篇博文的后續內容,簡單回顧一下:本文選取了n/s/m三款不同量級的模型來依次構建訓練模型,所有的參數保持同樣的設置,之后探索在不同剪枝處理操作下的性能影響。 在上一篇博文中保持30的剪枝程度得到的效果還是比較理…

C++ 學習系列3 -- 函數壓棧與出棧

在C中,函數壓棧(函數調用)和出棧(函數返回)是函數調用過程中的兩個關鍵步驟。下面將逐步解釋這兩個過程: 一 函數壓棧與出棧過程簡介 函數壓棧(函數調用)的過程如下: …

2020年3月全國計算機等級考試真題(C語言二級)

2020年3月全國計算機等級考試真題(C語言二級) 第1題 有以下程序 void fun1 (char*p) { char*q; qp; while(*q!\0) { (*Q); q; } } main() { char a[]{"Program"},*p; p&a[3]; fun1(p); print…

【C語言學習】本地變量

本地變量 1.函數每次運行,就會產生一個獨立的變量空間,在這個空間中的變量,是函數的這次運行所獨有的,稱之為本地變量。 2.定義在函數內部的變量就是本地變量。 3.參數也是本地變量 變量的生存期和作用域 1.生存期:變量…

新能源電動車充電樁控制主板安全特點

新能源電動車充電樁控制主板安全特點 你是否曾經擔心過充電樁的安全問題?充電樁主板又是什么樣的呢?今天我們就來聊聊這個話題。 充電樁主板采用雙重安全防護系統,包括防水、防護、防塵等,確保充電樁安全、可靠。不僅如此,充電樁主板采用先…

簡單的洗牌算法

目錄 前言 問題 代碼展現及分析 poker類 game類 Text類 前言 洗牌算法為ArrayList具體使用的典例,可以很好的讓我們快速熟系ArrayList的用法。如果你對ArrayList還不太了解除,推薦先看本博主的ArrayList的詳解。 ArrayList的詳解_WHabcwu的博客-CSD…

mysql mysql 容器 忽略大小寫配置

首先能夠連接上mysql,然后輸入下面這個命令查看mysql是否忽略大小寫 show global variables like %lower_case%; lower_case_table_names 0:不忽略大小寫 lower_case_table_names 1:忽略大小寫 mysql安裝分為兩種(根據自己的my…

sql server Varchar轉換為Datetime

將Varchar轉換為Datetime是一個常見的需求,在處理日期和時間數據時特別有用。在SQL Server中,可以使用CONVERT函數或CAST函數將Varchar轉換為Datetime。 使用CONVERT函數 CONVERT函數可以將一個值從一個類型轉換為另一個類型。以下是使用CONVERT函數將…

FPGA芯片IO口上下拉電阻的使用

FPGA芯片IO口上下拉電阻的使用 為什么要設置上下拉電阻一、如何設置下拉電阻二、如何設置上拉電阻為什么要設置上下拉電阻 這里以高云FPGA的GW1N-UV2QN48C6/I5來舉例,這個芯片的上電默認初始化階段,引腳是弱上來模式,且模式固定不能通過軟件的配置來改變。如下圖所示: 上…

centos 7.x 單用戶模式

最近碰到 centos 7.9 一些參數設置錯誤無法啟動系統的情況,研究后可以使用單用戶模式進入系統進行恢復操作。 進入啟動界面,按 e ro 替換為 rw init/sysroot/bin/sh 替換前 替換后 Ctrl-x 進行重啟進入單用戶模式 執行 chroot /sysroot 可以查看日…

【ARM64 常見匯編指令學習 19 -- ARM64 BEQ與B.EQ的區別】

文章目錄 ARM BEQ和B.EQ 上篇文章:ARM64 常見匯編指令學習 18 – ARM64 TST 指令與 條件標志位 Z ARM BEQ和B.EQ 在ARMv8匯編中,BEQ和B.EQ實際上是同一條指令的兩種不同表示方式,它們都表示條件分支指令,當某個條件滿足時&#x…