【PyTorch】單對象分割項目

對象分割是在圖像中找到目標對象的邊界的過程。單目標分割的重點是自動勾勒出圖像中一個目標對象的邊界。對象邊界通常由二進制掩碼定義。

通過二進制掩碼,可以在圖像上覆蓋輪廓以勾勒出對象邊界。例如以下圖片描繪了胎兒的超聲圖像、胎兒頭部的二進制掩碼以及覆蓋在超聲圖像上的胎兒頭部的圖像分割:

目錄

準備數據集

創建自定義數據集

劃分數據集

創建數據加載器

搭建模型

定義損失函數

定義優化器

訓練和評估模型


準備數據集

使用胎兒頭圍數據集Automated measurement of fetal head circumference,在懷孕期間,超聲成像用于測量胎兒頭圍,監測胎兒的生長。數據集包含標準平面的二維(2D)超聲圖像。Automated measurement of fetal head circumferenceFor more information about this dataset go to:?https://hc18.grand-challenge.org/https://zenodo.org/record/1322001#.XcX1jk9KhhE

import os
path2train="./data/training_set/"imgsList=[pp for pp in os.listdir(path2train) if "Annotation" not in pp]
anntsList=[pp for pp in os.listdir(path2train) if "Annotation" in pp]
print("number of images:", len(imgsList))
print("number of annotations:", len(anntsList))import numpy as np
np.random.seed(2024)
rndImgs=np.random.choice(imgsList,4)
rndImgsimport matplotlib.pylab as plt
from PIL import Image
from scipy import ndimage as ndi
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_tensor, to_pil_image
import torchdef show_img_mask(img, mask):if torch.is_tensor(img):img=to_pil_image(img)mask=to_pil_image(mask)img_mask=mark_boundaries(np.array(img), np.array(mask),outline_color=(0,1,0),color=(0,1,0))plt.imshow(img_mask)
for fn in rndImgs:path2img = os.path.join(path2train, fn)path2annt= path2img.replace(".png", "_Annotation.png")img = Image.open(path2img)annt_edges = Image.open(path2annt)mask = ndi.binary_fill_holes(annt_edges)        plt.figure()plt.subplot(1, 3, 1) plt.imshow(img, cmap="gray")plt.subplot(1, 3, 2) plt.imshow(mask, cmap="gray")plt.subplot(1, 3, 3) show_img_mask(img, mask)

plt.figure()
plt.subplot(1, 3, 1) 
plt.imshow(img, cmap="gray")
plt.axis('off')plt.subplot(1, 3, 2) 
plt.imshow(mask, cmap="gray")
plt.axis('off')    plt.subplot(1, 3, 3) 
show_img_mask(img, mask)
plt.axis('off')

# conda install conda-forge/label/cf202003::albumentations
from albumentations import (HorizontalFlip,VerticalFlip,    Compose,Resize,
)h,w=128,192
transform_train = Compose([ Resize(h,w), HorizontalFlip(p=0.5), VerticalFlip(p=0.5), ])transform_val = Resize(h,w)

創建自定義數據集

from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_imageclass fetal_dataset(Dataset):def __init__(self, path2data, transform=None):      imgsList=[pp for pp in os.listdir(path2data) if "Annotation" not in pp]anntsList=[pp for pp in os.listdir(path2train) if "Annotation" in pp]self.path2imgs = [os.path.join(path2data, fn) for fn in imgsList] self.path2annts= [p2i.replace(".png", "_Annotation.png") for p2i in self.path2imgs]self.transform = transformdef __len__(self):return len(self.path2imgs)def __getitem__(self, idx):path2img = self.path2imgs[idx]image = Image.open(path2img)path2annt = self.path2annts[idx]annt_edges = Image.open(path2annt)mask = ndi.binary_fill_holes(annt_edges)        image= np.array(image)mask=mask.astype("uint8")        if self.transform:augmented = self.transform(image=image, mask=mask)image = augmented['image']mask = augmented['mask']            image= to_tensor(image)            mask=255*to_tensor(mask)            return image, mask
fetal_ds1=fetal_dataset(path2train, transform=transform_train)
fetal_ds2=fetal_dataset(path2train, transform=transform_val)
img,mask=fetal_ds1[0]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))show_img_mask(img, mask)

劃分數據集

按照8:2的比例劃分訓練數據集和驗證數據集

from sklearn.model_selection import ShuffleSplitsss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices=range(len(fetal_ds1))
for train_index, val_index in sss.split(indices):print(len(train_index))print("-"*10)print(len(val_index))

from torch.utils.data import Subsettrain_ds=Subset(fetal_ds1,train_index)
print(len(train_ds))
val_ds=Subset(fetal_ds2,val_index)
print(len(val_ds))

展示訓練數據集示例圖像?

plt.figure(figsize=(5,5))
for img,mask in train_ds:show_img_mask(img,mask)break

展示驗證數據集示例圖像?

plt.figure(figsize=(5,5))
for img,mask in val_ds:show_img_mask(img,mask)break

創建數據加載器

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breaktorch.max(img_b)

搭建模型

基于編碼器-解碼器模型encoder–decoder model搭建分割任務模型

import torch.nn as nn
import torch.nn.functional as Fclass SegNet(nn.Module):def __init__(self, params):super(SegNet, self).__init__()C_in, H_in, W_in=params["input_shape"]init_f=params["initial_filters"] num_outputs=params["num_outputs"] self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3,stride=1,padding=1)self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3,stride=1,padding=1)self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3,padding=1)self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3,padding=1)self.conv5 = nn.Conv2d(8*init_f, 16*init_f, kernel_size=3,padding=1)self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv_up1 = nn.Conv2d(16*init_f, 8*init_f, kernel_size=3,padding=1)self.conv_up2 = nn.Conv2d(8*init_f, 4*init_f, kernel_size=3,padding=1)self.conv_up3 = nn.Conv2d(4*init_f, 2*init_f, kernel_size=3,padding=1)self.conv_up4 = nn.Conv2d(2*init_f, init_f, kernel_size=3,padding=1)self.conv_out = nn.Conv2d(init_f, num_outputs , kernel_size=3,padding=1)    def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv4(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv5(x))x=self.upsample(x)x = F.relu(self.conv_up1(x))x=self.upsample(x)x = F.relu(self.conv_up2(x))x=self.upsample(x)x = F.relu(self.conv_up3(x))x=self.upsample(x)x = F.relu(self.conv_up4(x))x = self.conv_out(x)return x params_model={"input_shape": (1,h,w),"initial_filters": 16, "num_outputs": 1,}model = SegNet(params_model)import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=model.to(device)

打印模型結構

print(model)

獲取模型摘要?

from torchsummary import summary
summary(model, input_size=(1, h, w))

定義損失函數

def dice_loss(pred, target, smooth = 1e-5):intersection = (pred * target).sum(dim=(2,3))union= pred.sum(dim=(2,3)) + target.sum(dim=(2,3)) dice= 2.0 * (intersection + smooth) / (union+ smooth)    loss = 1.0 - dicereturn loss.sum(), dice.sum()import torch.nn.functional as Fdef loss_func(pred, target):bce = F.binary_cross_entropy_with_logits(pred, target,  reduction='sum')pred= torch.sigmoid(pred)dlv, _ = dice_loss(pred, target)loss = bce  + dlvreturn lossfor img_v,mask_v in val_dl:mask_v= mask_v[8:]breakfor img_t,mask_t in train_dl:breakprint(dice_loss(mask_v,mask_v))
loss_func(mask_v,torch.zeros_like(mask_v))

import torchvisiondef metrics_batch(pred, target):pred= torch.sigmoid(pred)_, metric=dice_loss(pred, target)return metricdef loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)with torch.no_grad():pred= torch.sigmoid(output)_, metric_b=dice_loss(pred, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), metric_b

定義優化器

from torch import optim
opt = optim.Adam(model.parameters(), lr=3e-4)from torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

訓練和評估模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0running_metric=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)loss_b, metric_b=loss_batch(loss_func, output, yb, opt)running_loss += loss_bif metric_b is not None:running_metric+=metric_bif sanity_check is True:breakloss=running_loss/float(len_data)metric=running_metric/float(len_data)return loss, metric
import copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f, dice: %.2f" %(train_loss,100*train_metric))print("val loss: %.6f, dice: %.2f" %(val_loss,100*val_metric))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
opt = optim.Adam(model.parameters(), lr=3e-4)# 定義學習率調度器,當驗證集上的損失不再下降時,將學習率降低為原來的0.5倍,等待20個epoch后再次降低學習率
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"# 判斷path2models路徑是否存在,如果不存在則創建該路徑
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 100,"optimizer": opt,"loss_func": loss_func,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": False,"lr_scheduler": lr_scheduler,"path2weights": path2models+"weights.pt",
}model,loss_hist,metric_hist=train_val(model,params_train)

打印訓練驗證損失

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

?

打印訓練驗證精度

# plot accuracy progress
plt.title("Train-Val Accuracy")
plt.plot(range(1,num_epochs+1),metric_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),metric_hist["val"],label="val")
plt.ylabel("Accuracy")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/919847.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/919847.shtml
英文地址,請注明出處:http://en.pswp.cn/news/919847.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

esp dl

放下了好多年 又回到了dl 該忘的也忘的差不多了 其實沒啥復雜的 只是不習慣 熟悉而已 好吧 現代的人工智能體 還是存在著很大的問題 眼睛 耳朵 思考 雖然功能是正常的 但距離&#xff02;真正&#xff02;(&#xff09;意思上的獨立意識個體 還是差別很大 再等個幾十年 看看…

基于django/python的服裝銷售系統平臺/服裝購物系統/基于django/python的服裝商城

基于django/python的服裝銷售系統平臺/服裝購物系統/基于django/python的服裝商城

詳解ThreadLocal<HttpServletRequest> requestThreadLocal

public static ThreadLocal<HttpServletRequest> requestThreadLocal ThreadLocal.withInitial(() -> null);一、代碼逐部分詳解 1. public static public&#xff1a;表示這個變量是公開的&#xff0c;其他類可以訪問。static&#xff1a;表示這是類變量&#xff0c…

Vue2 響應式系統設計原理與實現

文章目錄Vue2 響應式系統設計原理與實現Vue2 響應式系統設計原理與實現 Vue2 的響應式原理主要基于以下幾點&#xff1a; 使用 Object.defineProperty () 方法對數據對象的屬性進行劫持 當數據發生變化時&#xff0c;通知依賴該數據的視圖進行更新 實現一個發布 - 訂閱模式&a…

探索 JUC:Java 并發編程的神奇世界

探索 JUC&#xff1a;Java 并發編程的神奇世界 在 Java 編程領域&#xff0c;隨著多核處理器的普及和應用場景復雜度的提升&#xff0c;并發編程變得愈發重要。Java 并發包&#xff08;JUC&#xff0c;Java.util.concurrent&#xff09;就像是一座寶藏庫&#xff0c;為開發者提…

selenium采集數據怎么應對反爬機制?

selenium是一個非常強大的瀏覽器自動化工具&#xff0c;通過操作瀏覽器來抓取動態網頁內容&#xff0c;可以很好的處理JavaScript和AJAX加載的網頁。 它能支持像點擊按鈕、懸停元素、填寫表單等各種自動化操作&#xff0c;所以很適合自動化測試和數據采集。 selenium與各種主流…

指定文件夾上的壓縮圖像格式tiff轉換為 jpg 批量腳本

文章大綱 背景簡介 代碼 背景簡介 隨著數字成像技術在科研、醫學影像和遙感等領域的廣泛應用,多頁TIFF(Tag Image File Format)文件因其支持多維數據存儲和高位深特性,成為存儲序列圖像、顯微鏡切片或衛星遙感數據的首選格式。然而在實際應用中,這類文件存在以下顯著痛點…

Docker 部署 MySQL 8.0 完整指南:從拉取鏡像到配置遠程訪問

目錄前言一、拉取鏡像二、查看鏡像三、運行容器命令參數說明&#xff1a;四、查看運行容器五、進入容器內部六、修改 MySQL 配置1. 創建配置文件2. 配置內容七、重啟 MySQL 服務八、設置 Docker 啟動時自動啟動 MySQL九、再次重啟 MySQL十、授權遠程訪問1. 進入容器內部2. 登錄…

IntelliJ IDEA 常用快捷鍵筆記(Windows)

前言&#xff1a;特別標注的快捷鍵&#xff08;Windows&#xff09;快捷鍵功能說明Ctrl Alt M將選中代碼提取成方法Ctrl Alt T包裹選中代碼塊&#xff08;try/catch、if、for 等&#xff09;Ctrl H查看類的繼承層次Alt 7打開項目結構面板Ctrl F12打開當前文件結構視圖Ct…

疏老師-python訓練營-Day54Inception網絡及其思考

浙大疏錦行 DAY54 一、 inception網絡介紹 今天我們介紹inception&#xff0c;也就是GoogleNet 傳統計算機視覺的發展史 從上面的鏈接&#xff0c;可以看到其實inceptionnet是在resnet之前的&#xff0c;那為什么我今天才說呢&#xff1f;因為他要引出我們后面的特征融合和…

LeetCode第3304題 - 找出第 K 個字符 I

題目 解答 class Solution {public char kthCharacter(int k) {int n 0;int v 1;while (v < k) {v << 1;n;}String target kthCharacterString(n);return target.charAt(k - 1);}public String kthCharacterString(int n) {if (n 0) {return "a";}Str…

Codeforces Round 1043 (Div. 3) D-F 題解

D. From 1 to Infinity 題意 有一個無限長的序列&#xff0c;是把所有正整數按次序拼接&#xff1a;123456789101112131415...\texttt{123456789101112131415...}123456789101112131415...。求這個序列前 k(k≤1015)k(k\le 10^{15})k(k≤1015) 位的數位和。 思路 二分出第 …

【C語言16天強化訓練】從基礎入門到進階:Day 7

&#x1f525;個人主頁&#xff1a;艾莉絲努力練劍 ?專欄傳送門&#xff1a;《C語言》、《數據結構與算法》、C語言刷題12天IO強訓、LeetCode代碼強化刷題、洛谷刷題、C/C基礎知識知識強化補充、C/C干貨分享&學習過程記錄 &#x1f349;學習方向&#xff1a;C/C方向學習者…

【AI基礎:神經網絡】16、神經網絡的生理學根基:從人腦結構到AI架構,揭秘道法自然的智能密碼

“道法自然,久藏玄冥”——人工神經網絡(ANN)的崛起并非偶然,而是對自然界最精妙的智能系統——人腦——的深度模仿與抽象。從單個神經元的信號處理到大腦皮層的層級組織,從突觸可塑性的學習機制到全腦并行計算的高效能效,生物大腦的“玄冥”智慧為AI提供了源源不斷的靈感…

容器安全實踐(一):概念篇 - 從“想當然”到“真相”

在容器化技術日益普及的今天&#xff0c;許多開發者和運維人員都將應用部署在 Docker 或 Kubernetes 中。然而&#xff0c;一個普遍存在的誤解是&#xff1a;“容器是完全隔離的&#xff0c;所以它是安全的。” 如果你也有同樣的想法&#xff0c;那么你需要重新審視容器安全了。…

騰訊開源WeKnora:新一代文檔理解與檢索框架

引言&#xff1a;文檔智能處理的新范式 在數字化時代&#xff0c;企業和個人每天都面臨著海量文檔的處理需求&#xff0c;從產品手冊到學術論文&#xff0c;從合同條款到醫療報告&#xff0c;非結構化文檔的高效處理一直是技術痛點。2025年8月&#xff0c;騰訊正式開源了基于大…

C++之list類的代碼及其邏輯詳解 (中)

接下來我會依照前面所說的一些接口以及list的結構來進行講解。1. list_node的結構1.1 list_node結構體list由于其結構為雙向循環鏈表&#xff0c;所以我們在這里要這么初始化_next&#xff1a;指向鏈表中下一個節點的指針_prev&#xff1a;指向鏈表中上一個節點的指針_val&…

新能源汽車熱管理仿真:蒙特卡洛助力神經網絡訓練

研究背景在新能源汽車的熱管理仿真研究中&#xff0c;神經網絡訓練技術常被應用于系統降階建模。通過這一方法&#xff0c;可以構建出高效準確的代理模型&#xff0c;進而用于控制策略的優化、系統性能的預測與評估&#xff0c;以及實時仿真等任務&#xff0c;有效提升開發效率…

第十九講:C++11第一部分

目錄 1、C11簡介 2、列表初始化 2.1、{}初始化 2.2、initializer_list 2.2.1、成員函數 2.2.2、應用 3、變量類型推導 3.1、auto 3.2、decltype 3.3、nullptr 4、范圍for 5、智能指針 6、STL的一些變化 7、右值引用和移動語義 7.1、右值引用 7.2、右值與左值引…

書寫本體論視域下的文字學理論重構

在符號學與哲學的交叉領域&#xff0c;文字學&#xff08;Grammatologie&#xff09;作為一門顛覆性學科始終處于理論風暴的中心。自德里達1967年發表《論文字學》以來&#xff0c;傳統語言學中"語音中心主義"的霸權地位遭遇根本性動搖&#xff0c;文字不再被視為語言…