【PyTorch】多對象分割項目

對象分割任務的目標是找到圖像中目標對象的邊界。實際應用例如自動駕駛汽車和醫學成像分析。這里將使用PyTorch開發一個深度學習模型來完成多對象分割任務。多對象分割的主要目標是自動勾勒出圖像中多個目標對象的邊界。

對象的邊界通常由與圖像大小相同的分割掩碼定義,在分割掩碼中屬于目標對象的所有像素基于預定義的標記被標記為相同。

目錄

創建數據集

創建數據加載器

創建模型

部署模型

定義損失函數和優化器

訓練和驗證模型


創建數據集

from torchvision.datasets import VOCSegmentation
from PIL import Image   
from torchvision.transforms.functional import to_tensor, to_pil_imageclass myVOCSegmentation(VOCSegmentation):def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:augmented= self.transforms(image=np.array(img), mask=np.array(target))img = augmented['image']target = augmented['mask']                  target[target>20]=0img= to_tensor(img)            target= torch.from_numpy(target).type(torch.long)return img, targetfrom albumentations import (HorizontalFlip,Compose,Resize,Normalize)mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h,w=520,520transform_train = Compose([ Resize(h,w),HorizontalFlip(p=0.5), Normalize(mean=mean,std=std)])transform_val = Compose( [ Resize(h,w),Normalize(mean=mean,std=std)])            path2data="./data/"    
train_ds=myVOCSegmentation(path2data, year='2012', image_set='train', download=False, transforms=transform_train) 
print(len(train_ds)) val_ds=myVOCSegmentation(path2data, year='2012', image_set='val', download=False, transforms=transform_val)
print(len(val_ds)) 
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")def show_img_target(img, target):if torch.is_tensor(img):img=to_pil_image(img)target=target.numpy()for ll in range(num_classes):mask=(target==ll)img=mark_boundaries(np.array(img) , mask,outline_color=COLORS[ll],color=COLORS[ll])plt.imshow(img)def re_normalize (x, mean = mean, std= std):x_r= x.clone()for c, (mean_c, std_c) in enumerate(zip(mean, std)):x_r [c] *= std_cx_r [c] += mean_creturn x_r

?展示訓練數據集示例圖像

img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

展示驗證數據集示例圖像

img, mask = val_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

創建數據加載器

?通過torch.utils.data針對訓練和驗證集分別創建Dataloader,打印示例觀察效果

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)break

創建模型

創建并打印deeplab_resnet模型結構,使用預訓練權重

from torchvision.models.segmentation import deeplabv3_resnet101
import torchmodel=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)

部署模型

在驗證數據集的數據批次上部署模型觀察效果?

from torch import nnmodel.eval()
with torch.no_grad():for xb, yb in val_dl:yb_pred = model(xb.to(device))yb_pred = yb_pred["out"].cpu()print(yb_pred.shape)    yb_pred = torch.argmax(yb_pred,axis=1)break
print(yb_pred.shape)plt.figure(figsize=(20,20))n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

可見勾勒對象方面效果很好?

定義損失函數和優化器

from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)def loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), Nonefrom torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

訓練和驗證模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)["out"]loss_b, _ = loss_batch(loss_func, output, yb, opt)running_loss += loss_bif sanity_check is True:breakloss=running_loss/float(len_data)return loss, Noneimport copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f" %(train_loss))print("val loss: %.6f" %(val_loss))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 10,"optimizer": opt,"loss_func": criterion,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": True,"lr_scheduler": lr_scheduler,"path2weights": path2models+"sanity_weights.pt",
}model, loss_hist, _ = train_val(model, params_train)

繪制了訓練和驗證損失曲線?

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/93687.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/93687.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/93687.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SSH 使用密鑰登錄服務器

用這種方法遠程登陸服務器的時候無需手動輸入密碼 具體步驟 客戶端通過 ssh-keygen 生成公鑰和私鑰 ssh-keygen -t rsa 生成的時候會有一系列問題&#xff0c;根據自己的需要選擇就行。生成的結果為兩個文件&#xff1a; 上傳公鑰至服務器&#xff0c;上述兩個文件一般在客戶…

MySQL 8.4 企業版啟用TDE功能和表加密

一、系統環境操作系統&#xff1a;Ubuntu 24.04 數據庫:8.4.4-commercial for Linux on x86_64 (MySQL Enterprise Server - Commercial)二、安裝TDE組件前提&#xff1a;檢查組件文件是否存在ls /usr/lib/mysql/plugin/component_keyring_encrypted_file.so1.配置全局清單文件…

【Altium designer】導出的原理圖PDF亂碼異常的解決方法

一、有些電源名字無法顯示或器件丟失 解決辦法 (1)首先AD18以及以上的新版本AD不存在該問題。 (2)其次AD17以及更舊版本的AD很可能遇到該問題,參考如下博客筆記進行操作即可: 大致的操作如下:DXP → Preferences → Schematic → Options里面“Render Text with GDI+”…

4.Ansible自動化之-部署文件到主機

4 - 部署文件到受管主機 實驗環境 先通過以下命令搭建基礎環境&#xff08;創建工作目錄、配置 Ansible 環境和主機清單&#xff09;&#xff1a; # 在控制節點&#xff08;controller&#xff09;上創建web目錄并進入&#xff0c;作為工作目錄 [bqcontroller ~]$ mkdir web &a…

Vuex的使用

Vuex 超詳細使用教程&#xff08;從入門到精通&#xff09;一、Vuex 是什么&#xff1f;Vuex 是專門為 Vue.js 設計的狀態管理庫&#xff0c;它采用集中式存儲管理應用的所有組件的狀態。簡單來說&#xff0c;Vuex 就是一個"全局變量倉庫"&#xff0c;所有組件都可以…

pytorch 數據預處理,加載,訓練,可視化流程

流程定義自定義數據集類定義訓練和驗證的數據增強定義模型、損失函數和優化器訓練循環&#xff0c;包括驗證訓練可視化整個流程模型評估高級功能擴展混合精度訓練?分布式訓練?{:width“50%” height“50%”} 定義自定義數據集類 # #1. 自定義數據集類 # class CustomImageD…

Prompt工程:OCR+LLM文檔處理的精準制導系統

在PDF OCR與大模型結合的實際應用中&#xff0c;很多團隊會發現一個現象&#xff1a;同樣的OCR文本&#xff0c;不同的Prompt設計會產生截然不同的提取效果。有時候準確率能達到95%&#xff0c;有時候卻只有60%。這背后的關鍵就在于Prompt工程的精細化程度。 &#x1f3af; 為什…

RecSys:粗排模型和精排特征體系

粗排 在推薦系統鏈路中&#xff0c;排序階段至關重要&#xff0c;通常分為召回、粗排和精排三個環節。粗排作為精排前的預處理階段&#xff0c;需要在效果和性能之間取得平衡。 雙塔模型 后期融合&#xff1a;把用戶、物品特征分別輸入不同的神經網絡&#xff0c;不對用戶、…

spring聲明式事務,finally 中return對事務回滾的影響

finally 塊中使用 return 是一個常見的編程錯誤&#xff0c;它會&#xff1a; 跳過正常的事務提交流程。吞掉異常&#xff0c;使錯誤處理失效 導致不可預測的事務行為Java 中 finally 和 return 的執行機制&#xff1a;1. finally 塊的基本特性 在 Java 中&#xff0c;finally …

WPF 打印報告圖片大小的自適應(含完整示例與詳解)

目標&#xff1a;在 FlowDocument 報告里&#xff0c;根據 1~6 張圖片的數量&#xff0c; 自動選擇 2 行 3 列 的最佳布局&#xff1b;在只有 1、2、4 張時保持“占滿感”&#xff0c;打印清晰且不變形。規則一覽&#xff1a;1 張 → 占滿 23&#xff08;大圖居中&#xff09;…

【AI大模型前沿】百度飛槳PaddleOCR 3.0開源發布,支持多語言、手寫體識別,賦能智能文檔處理

系列篇章&#x1f4a5; No.文章1【AI大模型前沿】深度剖析瑞智病理大模型 RuiPath&#xff1a;如何革新癌癥病理診斷技術2【AI大模型前沿】清華大學 CLAMP-3&#xff1a;多模態技術引領音樂檢索新潮流3【AI大模型前沿】浙大攜手阿里推出HealthGPT&#xff1a;醫學視覺語言大模…

迅為RK3588開發板Android12 制作使用系統簽名

在 Android 源碼 build/make/target/product/security/下存放著簽名文件&#xff0c;如下所示&#xff1a;將北京迅為提供的 keytool 工具拷貝到 ubuntu 中&#xff0c;然后將 Android11 或 Android12 源碼build/make/target/product/security/下的 platform.pk8 platform.x509…

Day08 Go語言學習

1.安裝Go和Goland 2.新建demo項目實踐語法并使用git實踐版本控制操作 2.1 Goland配置 路徑**&#xff1a;** GOPATH workspace GOROOT golang 文件夾&#xff1a; bin 編譯后的可執行文件 pkg 編譯后的包文件 src 源文件 遇到問題1&#xff1a;運行 ‘go build awesomeProject…

Linux-文件創建拷貝刪除剪切

文章目錄Linux文件相關命令ls通配符含義touch 創建文件命令示例cp 拷貝文件rm 刪除文件mv剪切文件Linux文件相關命令 ls ls是英文單詞list的簡寫&#xff0c;其功能為列出目錄的內容&#xff0c;是用戶最常用的命令之一&#xff0c;它類似于DOS下的dir命令。 Linux文件或者目…

RabbitMQ:交換機(Exchange)

目錄一、概述二、Direct Exchange &#xff08;直連型交換機&#xff09;三、Fanout Exchange&#xff08;扇型交換機&#xff09;四、Topic Exchange&#xff08;主題交換機&#xff09;五、Header Exchange&#xff08;頭交換機&#xff09;六、Default Exchange&#xff08;…

【實時Linux實戰系列】基于實時Linux的物聯網系統設計

隨著物聯網&#xff08;IoT&#xff09;技術的飛速發展&#xff0c;越來越多的設備被連接到互聯網&#xff0c;形成了一個龐大而復雜的網絡。這些設備從簡單的傳感器到復雜的工業控制系統&#xff0c;都在實時地產生和交換數據。實時Linux作為一種強大的操作系統&#xff0c;為…

第五天~提取Arxml中描述信息New_CanCluster--Expert

?? ARXML描述信息提取:挖掘汽車電子設計的"知識寶藏" 在AUTOSAR工程中,描述信息如同埋藏在ARXML文件中的金礦,而New_CanCluster--Expert正是打開這座寶藏的密鑰。本文將帶您深度探索ARXML描述信息的提取藝術,解鎖汽車電子設計的核心知識資產! ?? 為什么描述…

開源 C++ QT Widget 開發(一)工程文件結構

文章的目的為了記錄使用C 進行QT Widget 開發學習的經歷。臨時學習&#xff0c;完成app的開發。開發流程和要點有些記憶模糊&#xff0c;趕緊記錄&#xff0c;防止忘記。 相關鏈接&#xff1a; 開源 C QT Widget 開發&#xff08;一&#xff09;工程文件結構-CSDN博客 開源 C…

手寫C++ string類實現詳解

類定義cppnamespace ym {class string {private:char* _str; // 字符串數據size_t _size; // 當前字符串長度size_t _capacity; // 當前分配的內存容量static const size_t npos -1; // 特殊值&#xff0c;表示最大可能位置public:// 構造函數和析構函數string(…

C++信息學奧賽一本通-第一部分-基礎一-第3章-第2節

C信息學奧賽一本通-第一部分-基礎一-第3章-第2節 2057 星期幾 #include <iostream>using namespace std;int main() {int day; cin >> day;switch (day) {case 1:cout << "Monday";break;case 2:cout << "Tuesday";break;case 3:c…