基于UNet算法的農業遙感圖像語義分割——補充版

前言

本案例希望建立一個UNET網絡模型,來實現對農業遙感圖像語義分割的任務。本篇博客主要包括對上一篇博客中的相關遺留問題進行解決,并對網絡結構進行優化調整以適應個人的硬件設施——NVIDIA GeForce RTX 3050。

本案例的前兩篇博客直達鏈接基于UNet算法的農業遙感圖像語義分割(下)和基于UNet算法的農業遙感圖像語義分割(上)

1.模型簡化

1.1 二分類語義分割效果解答

上一篇博客最終的預測結果為二分類的語義分割,即經過彩色映射后,結果只有黑和藍兩種顏色。原因是因為模型雖然參數更新了1400多次,但其實從遍歷數據集的角度考慮也就65個epoch.在這里插入圖片描述
同時網絡模型參數量約7.7M,模型并未充分學習到訓練集上的信息。之所以會出現二分類的預測結果,是與模型初始化權重有關。

1.2網絡模型調整

因此針對上述情況,我將模型改成了單層的編碼器-解碼器架構,同時將Block模塊中做進一步特征融合的卷積層移除,具體結構如下所示:

class Block(nn.Module):def __init__(self, in_channels, out_channels):super(Block, self).__init__()self.relu = nn.ReLU(inplace=False)self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)x = self.relu(x)# x = self.conv2(x)# x = self.relu(x)return xclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.relu = nn.ReLU(inplace=False)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 編碼器部分self.conv1 = Block(3, 32)# 解碼器部分self.up2 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)self.conv2 = Block(32, 32)self.conv3 = nn.Conv2d(32, 4, kernel_size=1)def forward(self, x):# 編碼器conv1 = self.conv1(x)  # 32, 512, 512pool1 = self.pool(conv1)  # 32, 256, 256# 解碼器up2 = self.up2(pool1)  # 32, 512, 512conv2 = self.conv2(up2)  # 32, 512, 512conv3 = self.conv3(conv2)  # 4, 512, 512return conv3

此時查看模型的信息如下所示:
在這里插入圖片描述
模型的參數量已經減少至14.4k,可以預見結果并不會很好。因為輸入的圖像尺寸就已經512×512×3,相比而言,該模型顯然不能充分擬合該任務。

2.訓練策略調整

2.1訓練損失波動解答

因為統計的損失是按照每個iter進行統計的,每次的迭代過程在該批次下的參數更新朝著當前批次損失變小的方向進行,但對其他批次可能損失會升高,因此損失波動劇烈,但整體呈下降趨勢。
這里的解決方案如下:

  1. 將參數更新過程中記錄的iter次數進行減少,如將iter%10==0調整成iter%200==0
  2. 將參數更新過程中的記錄的結果轉換成累積量,即將10個iter中損失進行累加或者將一個epoch中的所有損失進行累加(本案例后續改進采用該方式)。
  3. 將參數更新過程中的記錄的結果轉換成平均量,即將10個iter中損失進行平均或者將一個epoch中的所有損失進行平均。

2.2訓練過程調整

因為本案例的數據集本身就很小,所以這里采用的是將一個epoch中的所有損失進行累加統計進行輸出可視化。同時為了避免模型參數保存冗余問題,將模型保存策略進行調整,只保存在驗證集上損失最小的模型,同時使用覆蓋原則將之前的保存模型進行覆蓋,以節省空間開銷,具體代碼調整如下:

    # 創建一個 SummaryWriter 對象,用于將數據寫入 TensorBoardwriter = SummaryWriter("dataset/logs")epoch = 0best_val_loss = float('inf')# best_val_loss = 7.899# model.load_state_dict(torch.load('./models/secweights_40.pth'))while epoch < 500:epoch += 1print("---------第{}輪訓練開始---------".format(epoch))train_loss = 0for i, (img, label) in tqdm(enumerate(dataloader_train)):img = img.to(device).float()label = label.long().to(device)model.train()output = model(img)# output = torch.argmax(output, dim=1).double()# iter_num += 1loss = getLoss(output, label)train_loss += loss.item()loss.backward()optimizer.step()optimizer.zero_grad()# print("---------第{}輪訓練結束---------".format(epoch))print("第{}輪訓練的損失為:{}".format(epoch, train_loss))writer.add_scalar('Training Loss3', train_loss, epoch)if epoch % 10 == 0:# torch.save(model.state_dict(), './models/thirdweights_{}.pth'.format(epoch))val_loss = 0with torch.no_grad():model.eval()for i, (img, label) in tqdm(enumerate(dataloader_val)):img = img.to(device).float()label = label.long().to(device)output = model(img)loss = getLoss(output, label)val_loss += loss.item()print("第{}輪驗證的損失為:{}".format(epoch, val_loss))if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), './models/best_model2.pth')print("Saved new best model")writer.add_scalar('Validation Loss3', val_loss, epoch)writer.close()

3.結果分析

3.1訓練過程損失

在訓練過程中的損失記錄如下:

在這里插入圖片描述
通過結果可以看出上述修改方式確實取得了不錯的效果,模型訓練集的抖動已大幅度減小。
從曲線角度考慮,訓練集損失已經趨向于平穩,同時驗證集上損失也趨向于平穩,由此判斷模型已經基本收斂,但訓練集的損失仍停留在較高水平,大概率是因為模型過于簡單,難以擬合該任務的需求。

3.2模型預測結果

這里將模型最終保存的結果加載進來,對未知圖片進行預測,代碼如下:

import matplotlib.pyplot as plt
import torch
import cv2
import numpy as np
from torch.utils.tensorboard import SummaryWriterfrom Net2 import Net# I=cv2.imread('dataset/0.9/image/16213.png')#dataset/test.png
I=cv2.imread('dataset/test.png')
I=np.transpose(I, (2, 0, 1))
I=I/255.0
I=I.reshape(1,3,512,512)
I=torch.tensor(I)
model=Net().double()
model.load_state_dict(torch.load('models/best_model2.pth'))
output=model(I)
# print(output.shape)
# print(output[0,:,:5,:5])
predicted_classes = torch.argmax(output, dim=1).squeeze(0).numpy()color_map = {0: [0, 0, 0],  # 黑色1: [255, 0, 0],  # 紅色2: [0, 255, 0],  # 綠色3: [0, 0, 255]  # 藍色
}height, width = predicted_classes.shape
colored_image = np.zeros((height, width, 3), dtype=np.uint8)
for i in range(height):for j in range(width):class_id = predicted_classes[i, j]colored_image[i, j] = color_map[class_id]plt.imshow(colored_image)
plt.axis('off')
plt.show()
print(colored_image.shape)
colored_image=np.transpose(colored_image, (2, 0, 1))
writer=SummaryWriter('dataset/logs')writer.add_image('test3',colored_image)
writer.close()

預測結果如下:
在這里插入圖片描述
從結果角度考慮,確實實現了四分類的語義分割效果,但預測的效果并不是很好,因此需要進一步修改網絡結構。

4.網絡模型優化

具體修改主要包括引入批量規范化BatchNormalization的處理和增加了Dropout的機制以及對網絡結構調整為三層的編碼器-解碼器架構。

4.1 BatchNormalization

批量規范化的核心思想是對每一層的輸入進行歸一化處理,使得每一層的輸入分布在訓練過程中保持相對穩定。具體來說,它將輸入數據的每個特征維度都歸一化到均值為 0、方差為 1 的標準正態分布。這樣可以減少內部協變量偏移的影響,加快訓練速度。

這里還有其他的逐層歸一化方式,這里不做詳細介紹。因為BatchNormalization聚焦于小批量層面,更適用于該任務,或者說更適用視覺圖像處理方面

在這里插入圖片描述
圖片來源:本校《深度學習》課程的PPT

4.2 Dropout的機制

Dropout的機制能有效防止過擬合,在訓練神經網絡時,它通過以一定的概率隨機將神經元的輸出設置為0,即暫時“丟棄”這些神經元及其連接,每次迭代訓練時在訓練一個不同的子網絡,通過多個子網絡的綜合效果來提高模型的泛化能力。類似于基學習器集成學習的思想。

4.3網絡模型代碼

上述的兩種方式是針對Block模塊的,這里為了更好的擬合語義分割的任務,需要進一步加深網絡結構,考慮到硬件資源有限,于是使用的是三層編碼器-解碼器架構,修改后的網絡模型完整代碼如下:

class Block(nn.Module):def __init__(self, in_channels, out_channels, dropout_rate=0.1):super(Block, self).__init__()self.relu = nn.ReLU(inplace=False)self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.dropout1 = nn.Dropout2d(p=dropout_rate)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.dropout2 = nn.Dropout2d(p=dropout_rate)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.dropout1(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.dropout2(x)return xclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.relu = nn.ReLU(inplace=False)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 編碼器部分self.conv1 = Block(3, 32)self.conv2 = Block(32, 64)self.conv3 = Block(64, 128)# 解碼器部分self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.conv4 = Block(128, 64)self.up5 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)self.conv5 = Block(64, 32)self.conv6 = nn.Conv2d(32, 4, kernel_size=1)def forward(self, x):# 編碼器conv1 = self.conv1(x)  # 32, 512, 512pool1 = self.pool(conv1)  # 32, 256, 256conv2 = self.conv2(pool1)  # 64, 256, 256pool2 = self.pool(conv2)  # 64, 128, 128conv3 = self.conv3(pool2)  # 128, 128, 128# 解碼器up4 = self.up4(conv3)  # 64, 256, 256conv4 = torch.cat([up4, conv2], dim=1)  # 128, 256, 256conv4 = self.conv4(conv4)  # 64, 256, 256up5 = self.up5(conv4)  # 32, 512, 512conv5 = torch.cat([up5, conv1], dim=1)  # 64, 512, 512conv5 = self.conv5(conv5)  # 32, 512, 512conv6 = self.conv6(conv5)  # 4, 512, 512return conv6

5.改進模型結果分析

訓練策略和之前保持不變,這里就不重復解釋,只對結果進行說明。

5.1訓練過程損失

訓練過程損失記錄如下:
在這里插入圖片描述
通過結果看到,訓練集和驗證集損失也基本趨于平穩,因此判斷模型基本收斂。

5.2模型預測結果

將之前訓練好的模型參數加載進來,對未知圖片進行預測,結果如下:
在這里插入圖片描述
通過結果可以看出,預測結果相對于之前有了很大的改善,基本實現了語義分割的效果,只是在微小內容上,識別的并不準確。可能是因為模型還是不夠復雜,不足以擬合該任務。

6.結語

至此,基于UNET算法的農業遙感圖像語義分割任務到此結束,期望能夠對你有所幫助。同時該項目也是我接觸的第一個語義分割項目,解釋的如有不足還請批評指出!!!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/903624.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/903624.shtml
英文地址,請注明出處:http://en.pswp.cn/news/903624.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Compose筆記(二十一)--AnimationVisibility

這一節主要了解一下Compose的AnimationVisibility,AnimatedVisibility 是 Jetpack Compose 里用于實現組件可見性動畫效果的組件&#xff0c;借助它能讓組件在顯示和隱藏時帶有平滑的過渡動畫&#xff0c;從而提升用戶體驗。現總結如下: API 1. visible 含義&#xff1a;這是一…

基于 HT 構建 2D 智慧倉儲可視化系統的技術解析

在當今數字化時代&#xff0c;倉儲管理對于企業的運營效率和成本控制愈發關鍵。圖撲軟件&#xff08;Hightopo&#xff09;憑借其強大的 HT for Web 產品&#xff0c;打造出 2D 智慧倉儲可視化平臺&#xff0c;為倉儲管理帶來了全新的技術解決方案。 HT 是一款基于 WebGL、can…

HTML ASCII 編碼詳解

HTML ASCII 編碼詳解 引言 HTML&#xff08;HyperText Markup Language&#xff09;是一種用于創建網頁的標準標記語言。在HTML中&#xff0c;字符的表示方式非常重要&#xff0c;因為它直接影響到網頁內容的顯示效果。ASCII編碼作為一種基本的字符編碼方式&#xff0c;在HTM…

pinia-plugin-persistedstate的使用

pinia持久化存儲的使用 安裝 npm install pinia-plugin-persistedstate 注冊 import { createPinia } from pinia import piniaPluginPersistedstate from pinia-plugin-persistedstateconst pinia createPinia() pinia.use(piniaPluginPersistedstate)export default pinia …

Vue:el-table-tree懶加載數據

目錄 一、出現場景二、具體使用三、修改時重新加載樹節點四、新增、刪除重新加載樹節點 一、出現場景 在項目的開發過程中&#xff0c;我們經常會使用到表格樹的格式&#xff0c;但是猶豫數據較多&#xff0c;使用分頁又不符合項目需求時&#xff0c;就需要對樹進行懶加載的操…

ChipCN IDE KF32 導入工程后,無法編譯的問題

使用ChipON IDE for KungFu32 導入已有的工程是時&#xff0c;發現能夠編譯&#xff0c;但是點擊&#xff0c;同時選擇硬件調試時 沒有任何響應。查看工程調試配置時&#xff0c;發現如下問題&#xff1a; 沒有看到添加有啟動配置&#xff0c;說明就是這里的問題了(應該是IDE的…

前端筆記-Element-Plus

結束了vue的基礎學習&#xff0c;現在進一步學習組件 Element-Plus部分學習目標&#xff1a; Element Plus1、查閱官方文檔指南2、學習常用組件的使用方法3、Table、Pagination、Form4、Input、Input Number、Switch、Select、Date Picker、Button5、Message、MessageBox、N…

C++入門小館: 模板

嘿&#xff0c;各位技術潮人&#xff01;好久不見甚是想念。生活就像一場奇妙冒險&#xff0c;而編程就是那把超酷的萬能鑰匙。此刻&#xff0c;陽光灑在鍵盤上&#xff0c;靈感在指尖跳躍&#xff0c;讓我們拋開一切束縛&#xff0c;給平淡日子加點料&#xff0c;注入滿滿的pa…

強化學習之基于無模型的算法之基于值函數的深度強化學習算法

3、基于值函數的深度強化學習算法 1&#xff09;深度Q網絡&#xff08;DQN&#xff09; 核心思想 DQN是一種將Q學習與深度神經網絡結合的方法&#xff0c;用于解決高維狀態空間的問題。 它以環境的狀態作為輸入&#xff0c;通過神經網絡輸出每個動作的 Q 值&#xff0c;智能體…

網絡規劃和設計

1.結構化綜合布線系統包括建筑物綜合布線系統PDS&#xff0c;智能大夏布線系統IBS和工業布線系統IDS 2.GB 50311-2016綜合布線系統工程設計規范 GB/T 50312-2016綜合布線系統工程驗收規范 3.結構化布線系統分為6個子系統&#xff1a; 工作區子系統&#xff1b;水平布線子系…

軟件設計師-錯題筆記-計算機硬件和體系

1. 解析&#xff1a;循環冗余校驗碼也叫CRC校驗碼&#xff0c;其中運算包括了模2&#xff08;異或&#xff09;來構造校驗位。別的三種沒有用到模2的方法。 2. 解析&#xff1a;如果是正數&#xff0c;則是首位為0&#xff0c;其余位全為1&#xff0c;這時最大數(2^(n-1))-1…

OpenCV 4.7企業級開發實戰:從圖像處理到目標檢測的全方位指南

簡介 OpenCV作為工業級計算機視覺開發的核心工具庫,其4.7版本在圖像處理、視頻分析和深度學習模型推理方面實現了顯著優化。 本文將從零開始,系統講解OpenCV 4.7的核心特性和功能更新,同時結合企業級應用場景,提供詳細代碼示例和實戰項目,幫助讀者掌握從基礎圖像處理到復…

LeetCode算法題 (除自身以外數組的乘積)Day14!!!C/C++

https://leetcode.cn/problems/product-of-array-except-self/description/ 一、題目分析 給你一個整數數組 nums&#xff0c;返回 數組 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘積 。 題目數據 保證 數組 nums之中任意元素的全部前綴…

如何寫好Verilog狀態機

還記得之前軟件的同事說過的一句話。怎么凸顯自己的工作量&#xff0c;就是自己給自己寫BUG。 看過夏宇聞老師書的都知道&#xff0c;verilog的FSM有moore和mealy,然后有一段&#xff0c;二段&#xff0c;三段式。記得我還是學生的時候&#xff0c;看到這里的時候&#xff0c;感…

晶振頻率/穩定度/精度/溫度特性的深度解析與測量技巧

在電子設備的精密世界里&#xff0c;晶振如同跳動的心臟&#xff0c;為各類系統提供穩定的時鐘信號。晶振的頻率、穩定度、精度以及溫度特性&#xff0c;這些關鍵參數不僅決定了設備的性能&#xff0c;更在不同的應用場景中發揮著至關重要的作用。 一、頻率選擇的本質&#xff…

Kafka-可視化工具-Offset Explorer

安裝&#xff1a; 下載地址&#xff1a;Offset Explorer 安裝好后如圖&#xff1a; 1、下載安裝完畢&#xff0c;進行新增連接&#xff0c;啟動offsetexplorer.exe&#xff0c;在Add Cluster窗口Properties 選項下填寫Cluster name 和 kafka Cluster Version Cluster name (集…

LabVIEW模板之溫度監測應用

這是一個溫度監測應用程序&#xff0c;基于 Continuous Measurement and Logging 示例項目構建&#xff0c;用于讀取模擬溫度值&#xff0c;當溫度超出給定范圍時發出警報 。 這個。 詳細說明 運行操作&#xff1a;直接運行該 VI 程序。點擊 “Start” 按鈕&#xff0c;即可開…

后端[特殊字符][特殊字符]看前端之Row與Col

是的&#xff0c;在 Ant Design 的柵格布局系統中&#xff0c;每個 <Row> 組件確實對應頁面上的一個獨立行。以下是更詳細的解釋&#xff1a; 核心概念 組件作用類比現實場景<Row>橫向容器&#xff0c;定義一行內容類似 Excel 表格中的一行<Col>縱向分割&am…

[特殊字符] SpringCloud項目中使用OpenFeign進行微服務遠程調用詳解(含連接池與日志配置)

&#x1f4da; 目錄 為什么要用OpenFeign&#xff1f; 在cart-service中整合OpenFeign 2.1 引入依賴 2.2 啟用OpenFeign 2.3 編寫Feign客戶端 2.4 調用Feign接口 開啟連接池&#xff0c;優化Feign性能 3.1 引入OkHttp 3.2 配置啟用OkHttp連接池 3.3 驗證連接池生效 Feign最佳…