【PyTorch】多對象分割

對象分割任務的目標是找到圖像中目標對象的邊界。實際應用例如自動駕駛汽車和醫學成像分析。這里將使用PyTorch開發一個深度學習模型來完成多對象分割任務。多對象分割的主要目標是自動勾勒出圖像中多個目標對象的邊界。

對象的邊界通常由與圖像大小相同的分割掩碼定義,在分割掩碼中屬于目標對象的所有像素基于預定義的標記被標記為相同。

目錄

創建數據集

創建數據加載器

創建模型

部署模型

定義損失函數和優化器

訓練和驗證模型


創建數據集

from torchvision.datasets import VOCSegmentation
from PIL import Image   
from torchvision.transforms.functional import to_tensor, to_pil_imageclass myVOCSegmentation(VOCSegmentation):def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:augmented= self.transforms(image=np.array(img), mask=np.array(target))img = augmented['image']target = augmented['mask']                  target[target>20]=0img= to_tensor(img)            target= torch.from_numpy(target).type(torch.long)return img, targetfrom albumentations import (HorizontalFlip,Compose,Resize,Normalize)mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h,w=520,520transform_train = Compose([ Resize(h,w),HorizontalFlip(p=0.5), Normalize(mean=mean,std=std)])transform_val = Compose( [ Resize(h,w),Normalize(mean=mean,std=std)])            path2data="./data/"    
train_ds=myVOCSegmentation(path2data, year='2012', image_set='train', download=False, transforms=transform_train) 
print(len(train_ds)) val_ds=myVOCSegmentation(path2data, year='2012', image_set='val', download=False, transforms=transform_val)
print(len(val_ds)) 
import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")def show_img_target(img, target):if torch.is_tensor(img):img=to_pil_image(img)target=target.numpy()for ll in range(num_classes):mask=(target==ll)img=mark_boundaries(np.array(img) , mask,outline_color=COLORS[ll],color=COLORS[ll])plt.imshow(img)def re_normalize (x, mean = mean, std= std):x_r= x.clone()for c, (mean_c, std_c) in enumerate(zip(mean, std)):x_r [c] *= std_cx_r [c] += mean_creturn x_r

?展示訓練數據集示例圖像

img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

展示驗證數據集示例圖像

img, mask = val_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))plt.figure(figsize=(20,20))img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

創建數據加載器

?通過torch.utils.data針對訓練和驗證集分別創建Dataloader,打印示例觀察效果

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)break

創建模型

創建并打印deeplab_resnet模型結構,使用預訓練權重

from torchvision.models.segmentation import deeplabv3_resnet101
import torchmodel=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)

部署模型

在驗證數據集的數據批次上部署模型觀察效果?

from torch import nnmodel.eval()
with torch.no_grad():for xb, yb in val_dl:yb_pred = model(xb.to(device))yb_pred = yb_pred["out"].cpu()print(yb_pred.shape)    yb_pred = torch.argmax(yb_pred,axis=1)break
print(yb_pred.shape)plt.figure(figsize=(20,20))n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))plt.subplot(1, 3, 2) 
plt.imshow(mask)plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

可見勾勒對象方面效果很好?

定義損失函數和優化器

from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)def loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), Nonefrom torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

訓練和驗證模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)["out"]loss_b, _ = loss_batch(loss_func, output, yb, opt)running_loss += loss_bif sanity_check is True:breakloss=running_loss/float(len_data)return loss, Noneimport copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f" %(train_loss))print("val loss: %.6f" %(val_loss))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 10,"optimizer": opt,"loss_func": criterion,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": True,"lr_scheduler": lr_scheduler,"path2weights": path2models+"sanity_weights.pt",
}model, loss_hist, _ = train_val(model, params_train)

繪制了訓練和驗證損失曲線?

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/96770.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/96770.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/96770.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

RabbitMQ---面試題

總結我們所學內容&#xff0c;這里推薦博客進行復習 RabbitMQ---面試題_rabbitmq常問面試題-CSDN博客

MasterGo自動布局(Auto Layout)

自動布局是用來表示 子元素與子元素之間互相影響的一種排版方式,是一種響應式布局技術。一般是將所有元素設計完成后再使用自動布局進行設置。 自動布局就是響應式布局,就是在不同尺寸的手機上寬度不同都應該怎么展示。 一般頁面的一級元素使用約束進行相對定位,二級元素及里…

還在重啟應用改 Topic?Spring Boot 動態 Kafka 消費的“終極形態”

場景描述&#xff1a; 你的一個微服務正在穩定地消費 Kafka 的 order_topic。現在&#xff0c;上游系統為了做業務隔離&#xff0c;新增加了一個 order_topic_vip&#xff0c;并開始向其中投遞 VIP 用戶的訂單。你需要在不重啟、不發布新版本的情況下&#xff0c;讓你現有的消費…

使用vllm部署neo4j的text2cypher-gemma-2-9b-it-finetuned-2024v1模型

使用vllm部署neo4j的text2cypher-gemma-2-9b-it-finetuned-2024v1模型 系統環境準備 由于使用的基于 nvcr.io/nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 的 workbench,需要進行以下準備(其他系統環境可忽略) ldconfig -p | grep libcudnn 找到 libcudnn 的so庫,然…

Coze源碼分析-資源庫-創建知識庫-前端源碼-核心組件

概述 本文深入分析Coze Studio中用戶創建知識庫功能的前端實現。該功能允許用戶在資源庫中創建、編輯和管理知識庫資源&#xff0c;為開發者提供了強大的知識管理和數據處理能力。通過對源碼的詳細解析&#xff0c;我們將了解從資源庫入口到知識庫配置彈窗的完整架構設計、組件…

基于時空數據的網約車訂單需求預測與調度優化

一、引言隨著共享出行行業的蓬勃發展&#xff0c;網約車已成為城市交通的重要組成部分。如何精準預測訂單需求并優化車輛調度&#xff0c;是提升平臺運營效率、改善用戶體驗的關鍵。本文提出一種基于時空數據的網約車訂單需求預測與調度優化方案&#xff0c;通過網格化城市空間…

數據結構 Java對象的比較

在Java中&#xff0c;凡是涉及到比較的&#xff0c;可以分為兩類情況&#xff1a;一類是基本數據類型的比較&#xff0c;另一類是引用數據類型的比較。對于基本數據類型的比較&#xff0c;我們通過關系運算符&#xff08;、>、<、!、>、<&#xff09;進行它們之間的…

企智匯建筑施工項目管理系統:全周期數字化管控,賦能工程企業降本增效!?建筑工程項目管理軟件!建筑工程項目管理系統!建筑項目管理軟件企智匯軟件

在建筑施工行業&#xff0c;項目進度滯后、成本超支、質量安全隱患頻發、多方協同不暢等問題&#xff0c;一直是制約企業發展的痛點。傳統依賴人工記錄、Excel 統計的管理模式&#xff0c;不僅效率低下&#xff0c;更易因信息斷層導致決策失誤。企智匯建筑施工項目管理系統憑借…

k8s-臨時容器學習

臨時容器學習1. 什么是臨時容器2. 實驗1. 什么是臨時容器 在官網&#xff1a;https://kubernetes.io/zh-cn/docs/concepts/workloads/pods/ephemeral-containers/ 中有介紹 臨時容器是用于調試Pod中崩潰的容器或者不具備調試工具&#xff0c;比如在一個運行著業務的容器中&am…

Python 2025:低代碼開發與自動化運維的新紀元

從智能運維到無代碼應用&#xff0c;Python正在重新定義企業級應用開發范式在2025年的企業技術棧中&#xff0c;Python已經從一個"開發工具"演變為業務自動化的核心平臺。根據Gartner 2025年度報告&#xff0c;68%的企業在自動化項目中使用Python作為主要開發語言&am…

Netty 在 API 網關中的應用篇(請求轉發、限流、路由、負載均衡)

Netty 在 API 網關中的應用篇&#xff08;請求轉發、限流、路由、負載均衡&#xff09;隨著微服務架構的普及&#xff0c;API 網關成為服務之間通信和安全控制的核心組件。在構建高性能網關時&#xff0c;Netty 因其高吞吐、低延遲和異步非阻塞 IO 的特性&#xff0c;成為不少開…

基于STM32設計的青少年學習監控系統(華為云IOT)_282

文章目錄 一、前言 1.1 項目介紹 【1】項目開發背景 【2】設計實現的功能 【3】項目硬件模塊組成 【4】設計意義 【5】國內外研究現狀 【6】摘要 1.2 設計思路 1.3 系統功能總結 1.4 開發工具的選擇 【1】設備端開發 【2】上位機開發 1.5 參考文獻 1.6 系統框架圖 1.7 系統原理…

手寫Spring底層機制的實現【初始化IOC容器+依賴注入+BeanPostProcesson機制+AOP】

摘要&#xff1a;建議先看“JAVA----Spring的AOP和動態代理”這個文章&#xff0c;解釋都在代碼中&#xff01;一&#xff1a;提出問題依賴注入1.單例beans.xml<?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframe…

5G NR-NTN協議學習系列:NR-NTN介紹(2)

NTN網絡作為依賴衛星的通信方式&#xff0c;需要面對的通信距離&#xff0c;通信雙方的移動速度都和之前TN網絡存在巨大差異。在距離方面相比蜂窩地面網絡Terrestrial Network通信距離從最小幾百米到最大幾十km的情況&#xff0c;NTN非地面網絡的通信距離即使是近地軌道的LEO衛…

線掃相機采集圖像起始位置不正確原因總結

1、幀觸發開始時間問題 問題描述: 由于幀觸發決定了線掃相機的開始采集圖像位置,比如正確的位置是A點開始采集,結果你從B點開始觸發幀信號,這樣出來的圖像起始位置就不對 解決手段: 軟件需要記錄幀觸發時軸的位置 1)控制卡控制軸 一般使用位置比較觸發,我們可以通過監…

校園管理系統練習項目源碼-前后端分離-【node版】

今天給大家分享一個校園管理系統&#xff0c;前后端分離項目。這是最近在練習前端編程&#xff0c;結合 node 寫的一個完整的項目。 使用的技術&#xff1a; Node.js&#xff1a;版本要求16.20以上。 后端框架&#xff1a;Express框架。 數據庫&#xff1a; MySQL 8.0。 Vue2&a…

【項目】 :C++ - 仿mudou庫one thread one loop式并發服務器實現(模塊劃分)

【項目】 &#xff1a;C - 仿mudou庫one thread one loop式并發服務器實現一、HTTP 服務器與 Reactor 模型1.1、HTTP 服務器概念實現步驟難點1.2、Reactor 模型概念分類1. 單 Reactor 單線程2. 單 Reactor 多線程3. 多 Reactor 多線程目標定位總結二、功能模塊劃分2.1、SERVER …

浴室柜市占率第一,九牧重構數智衛浴新生態

作者 | 曾響鈴文 | 響鈴說2025年上半年&#xff0c;家居市場在政策的推動下展現出獨特的發展態勢。國家出臺的一系列鼓勵家居消費的政策&#xff0c;如“以舊換新”國補政策帶動超6000萬件廚衛產品煥新&#xff0c;以及我國超2.7億套房齡超20年的住宅進入改造周期&#xff0c;都…

源碼分析之Leaflet中TileLayer

概述 TileLayer 是 Layer 的子類&#xff0c;繼承自GridLayer基類&#xff0c;用于加載和顯示瓦片地圖。它提供了加載和顯示瓦片地圖的功能&#xff0c;支持自定義瓦片的 URL 格式和參數。 源碼分析 源碼實現 TileLayer的源碼實現如下&#xff1a; export var TileLayer GridL…

php學習(第二天)

一.網站基本概念-服務器 1.什么是服務器? 1.1定義 服務器&#xff08;server&#xff09;,也稱伺服器&#xff0c;是提供計算服務的設備。 供計算服務的設備” 這里的“設備”不僅指物理機器&#xff08;如一臺配有 CPU、內存、硬盤的計算機&#xff09;&#xff0c;也可以指…