Pytorch中加載自定義數據集 - VOC
其中需要pip install xmltodict
#voc_dataset.pyimport os
import torch
import xmltodict
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transformsclass VOCDataset(Dataset): def __init__(self,img_dir,label_dir,transform,label_transform): #定義一些后面會用的參數 self.img_dir = img_dir #img地址 self.label_dir = label_dir #label文件地址 self.transform = transform #是否要做一些變換 self.label_transform = label_transform #是否要對label做一些變換self.img_names = os.listdir(self.img_dir) #os.listdir 獲取文件夾下的所有文件名稱,列表形式self.label_names = os.listdir(self.label_dir) #獲取label文件夾下的所有文件名稱 self.classes_list = ["no helmet","motor","number","with helmet"]#為了轉化標記為 : 0,1,2,3def __len__(self):return len(self.img_names) #返回照片文件的個數def __getitem__(self, index):img_name = self.img_names[index] #圖片列表[序號] 獲取文件名img_path = os.path.join(self.img_dir, img_name) #對地址進行拼接 獲取文件的路徑image = Image.open(img_path).convert('RGB') #通過文件地址打開文件,轉化為RGB三通道格式#new1.png -> new1.xml#new1.png -> [new1,png] -> new1 + ".xml"label_name = img_name.split('.')[0] + ".xml" #獲取標注的文件名label_path = os.path.join(self.label_dir, label_name) #拼接獲取標注的路徑with open(label_path, 'r',encoding="utf-8") as f: #打開標注文件label_content = f.read() #讀出標注文件所有的內容label_dict = xmltodict.parse(label_content) #因為內容是XML格式,xmltodict.parse 將內容轉化為 dict 格式target = [] #將要返回的數組,定義總體返回容器objects = label_dict["annotation"]["object"] #獲取dict里的標注對象for obj in objects: #獲取每個標注里面的信息obj_name = obj["name"]obj_class_id = self.classes_list.index(obj_name) #將標注的名字(no helmet)轉化為數字(0)obj_xmax = float(obj["bndbox"]["xmax"])obj_ymax = float(obj["bndbox"]["ymax"])obj_xmin = float(obj["bndbox"]["xmin"])obj_ymin = float(obj["bndbox"]["ymin"])target.extend([obj_class_id,obj_xmin,obj_ymin,obj_xmax,obj_ymax]) #將信息保存到總體返回容器target = torch.Tensor(target) #轉為tensor數據類型if self.transform is not None:image = self.transform(image) #對定義對象時寫的對image的操作return image,targetif __name__ == '__main__':train_dataset = VOCDataset(r"E:\HelmetDataset-VOC\train\images",r"E:\HelmetDataset-VOC\train\labels",transforms.Compose([transforms.ToTensor()]),None)print(len(train_dataset))print(train_dataset[11])
Pytorch中加載自定義數據集 - YOLO
如過VOC弄懂了的話,那這個代碼會非常簡單
#YOLO_dataset.pyimport os
import torchfrom PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transformsclass YOLODataset(Dataset):def __init__(self,img_dir,label_dir,transform,label_transform): #定義一些后面會用的參數self.img_dir = img_dir #img地址self.label_dir = label_dir #label文件地址self.transform = transform #是否要做一些變換self.label_transform = label_transform #是否要對label做一些變換self.img_names = os.listdir(self.img_dir) #os.listdir 獲取文件夾下的所有文件名稱,列表形式self.label_names = os.listdir(self.label_dir) #獲取label文件夾下的所有文件名稱
# self.classes_list = ["no helmet","motor","number","with helmet"]#為了轉化標記為 : 0,1,2,3def __len__(self):return len(self.img_names) #返回照片文件的個數def __getitem__(self, index):img_name = self.img_names[index] #圖片列表[序號] 獲取文件名img_path = os.path.join(self.img_dir, img_name) #對地址進行拼接 獲取文件的路徑image = Image.open(img_path).convert('RGB') #通過文件地址打開文件,轉化為RGB三通道格式#new1.png -> new1.xml#new1.png -> [new1,png] -> new1 + ".txt"label_name = img_name.split('.')[0] + ".txt" #獲取標注的文件名label_path = os.path.join(self.label_dir, label_name) #拼接獲取標注的路徑with open(label_path, 'r',encoding="utf-8") as f: #打開標注文件label_content = f.read() #讀出標注文件所有的內容target = []object_infos = label_content.strip().split("\n")for object_info in object_infos:info_list = object_info.strip().split(" ")class_id = float(info_list[0])center_x = float(info_list[1])center_y = float(info_list[2])width = float(info_list[3])height = float(info_list[4])target.extend([class_id,center_x,center_y,width,height])# label_dict = xmltodict.parse(label_content) #因為內容是XML格式,xmltodict.parse 將內容轉化為 dict 格式# target = [] #將要返回的數組,定義總體返回容器# objects = label_dict["annotation"]["object"] #獲取dict里的標注對象# for obj in objects: #獲取每個標注里面的信息# obj_name = obj["name"]# obj_class_id = self.classes_list.index(obj_name) #將標注的名字(no helmet)轉化為數字(0)# obj_xmax = float(obj["bndbox"]["xmax"])# obj_ymax = float(obj["bndbox"]["ymax"])# obj_xmin = float(obj["bndbox"]["xmin"])# obj_ymin = float(obj["bndbox"]["ymin"])# target.extend([obj_class_id,obj_xmin,obj_ymin,obj_xmax,obj_ymax]) #將信息保存到總體返回容器target = torch.Tensor(target) #轉為tensor數據類型if self.transform is not None:image = self.transform(image) #對定義對象時寫的對image的操作return image,targetif __name__ == '__main__':train_dataset = YOLODataset(r"E:\HelmetDataset-YOLO\HelmetDataset-YOLO-Train\images", r"E:\HelmetDataset-YOLO\HelmetDataset-YOLO-Train\labels", transforms.Compose([transforms.ToTensor()]), None)print(len(train_dataset))print(train_dataset[11])
模型的 nn.model &模型的可視化
?#model.py
import torch
import torch.nn as nn
from torchvision import transformsfrom yolo_dataset import VOCDatasetclass TuduiModel(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3 , out_channels=20, kernel_size=5)self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)def forward(self, x):x = torch.nn.functional.relu(self.conv1(x))return torch.nn.functional.relu(self.conv2(x))if __name__ == '__main__':model = TuduiModel()dataset = VOCDataset(r"E:\HelmetDataset-VOC\train\images",r"E:\HelmetDataset-VOC\train\labels",transforms.Compose([transforms.ToTensor(),transforms.Resize((512, 512)),]),None)img,target = dataset[0]output = model(img)# print(output)# print(model)torch.onnx.export(model,img,"tudui.onnx") #模型可視化
ONNX模型格式?
在環境中
pip install onnx
然后
torch.onnx.export(model,img,"tudui.onnx") #(模型,圖片,名字)再用瀏覽器打開 netron.app
把生成好的onnx文件拖進網頁
?