之前寫過加載數據集的一些小筆記,這里詳細內容就不再敘述了
詳細學習可以參考該博文二、PyTorch加載數據
一、分析
因為U-net網絡架構是輸入1通道,大小為(572,572)的灰度圖,圖片大小無所謂,我的思路是將三通道的圖像使用OpenCV進行相關的處理,轉換為單通道突圖片,之后再送入網絡模型中
二、準備數據集
數據集的采集和制作可以參考該篇博文
四、采集和制作數據集
三、完整加載數據集代碼
test_dataset
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import randomclass Beyond_loader(Dataset):def __init__(self, data_path):# 初始化函數,讀取所有data_path下的圖片self.data_path = data_pathself.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))def augment(self, image, flipCode):# 使用cv2.flip進行數據增強,filpCode為1水平翻轉,0垂直翻轉,-1水平+垂直翻轉flip = cv2.flip(image, flipCode)return flipdef __getitem__(self, index):# 根據index讀取圖片image_path = self.imgs_path[index]# 根據image_path生成label_pathlabel_path = image_path.replace('image', 'label')# 讀取訓練圖片和標簽圖片image = cv2.imread(image_path)label = cv2.imread(label_path)# 將數據轉為單通道的圖片image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)image = image.reshape(1, image.shape[0], image.shape[1])label = label.reshape(1, label.shape[0], label.shape[1])# 處理標簽,將像素值為255的改為1if label.max() > 1:label = label / 255# 隨機進行數據增強,為2時不做處理flipCode = random.choice([-1, 0, 1, 2])if flipCode != 2:image = self.augment(image, flipCode)label = self.augment(label, flipCode)return image, labeldef __len__(self):# 返回訓練集大小return len(self.imgs_path)if __name__ == "__main__":beyond_loader = Beyond_loader("./dataset/train")print("數據個數:", len(beyond_loader))train_loader = torch.utils.data.DataLoader(dataset=beyond_loader,batch_size=1,shuffle=True)for image, label in train_loader:print(image.shape)
一共有6張圖像,batch_size設為1,故train_loader有6組
數據個數: 6
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])