數據集下載
MINST_PNG_Training在github的項目目錄中的datasets
中有MNIST的png格式數據集的壓縮包
用于訓練的神經網絡模型
自定義數據集訓練
在前文【Pytorch】13.搭建完整的CIFAR10模型我們已經知道了基本搭建神經網絡的框架了,但是其中的數據集使用的torchvision
中的CIFAR10
官方數據集進行訓練的
train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,transform=torchvision.transforms.ToTensor())
本文將用圖片格式的數據集進行訓練
我們通過
# Dataset CIFAR10
# Number of datapoints: 60000
# Root location: ../datasets
# Split: Train
# StandardTransform
# Transform: ToTensor()
print(train_dataset)
可以看到我們下載的數據集是這種格式的,所以我們的主要問題就是如何將自定義的數據集獲取,并且轉化為這種形式,剩下的步驟就和上文相同了
數據類型進行轉化
我們的首要目的是,根據數據集的地址,分別將數據轉化為train_dataset
與test_dataset
我們需要調用ImageFolder
方法來進行操作
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 訓練集地址
train_root = "../datasets/mnist_png/training"
# 測試集地址
test_root = '../datasets/mnist_png/testing'# 進行數據的處理,定義數據轉換
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加載數據集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)
首先我們需要將數據進行處理,通過transforms.Compose
獲取對象data_transform
其中進行了三步操作
- 將圖片大小變為
28*28像素
便于輸入網絡模型 - 將圖片轉化為灰度格式,因為手寫數字識別不需要
三通道
的圖片,只需要灰度圖像就可以識別,而png格式
的圖片是四通道
的 - 將圖片轉化為
tensor
數據類型
然后通過ImageFolder
給出圖片的地址與轉化類型,就可以實現與我們在官方下載數據集相同的格式
# Dataset ImageFolder
# Number of datapoints: 60000
# Root location: ../datasets/mnist_png/training
# StandardTransform
# Transform: Compose(
# Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
# ToTensor()
# )
print(train_dataset)
其他與前文【Pytorch】13.搭建完整的CIFAR10模型基本相同
完整代碼
網絡模型
import torch
from torch import nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(3136, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.pool1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool2(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == "__main__":model = Net()input = torch.ones((1, 1, 28, 28))output = model(input)print(output.shape)
訓練過程
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 訓練集地址
train_root = "../datasets/mnist_png/training"
# 測試集地址
test_root = '../datasets/mnist_png/testing'# 進行數據的處理,定義數據轉換
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加載數據集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)# Dataset ImageFolder
# Number of datapoints: 60000
# Root location: ../datasets/mnist_png/training
# StandardTransform
# Transform: Compose(
# Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
# ToTensor()
# )
# print(train_dataset)# print(train_dataset[0])train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")model = Net().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)epoch = 10writer = SummaryWriter('../logs')
total_step = 0for i in range(epoch):model.train()pre_step = 0pre_loss = 0for data in train_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = loss_fn(outputs, labels)loss.backward()optimizer.step()pre_loss = pre_loss + loss.item()pre_step += 1total_step += 1if pre_step % 100 == 0:print(f"Epoch: {i+1} ,pre_loss = {pre_loss/pre_step}")writer.add_scalar('train_loss', pre_loss / pre_step, total_step)model.eval()pre_accuracy = 0with torch.no_grad():for data in test_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = model(images)pre_accuracy += outputs.argmax(1).eq(labels).sum().item()print(f"Test_accuracy: {pre_accuracy/len(test_dataset)}")writer.add_scalar('test_accuracy', pre_accuracy / len(test_dataset), i)torch.save(model, f'../models/model{i}.pth')writer.close()
參考文章
【CNN】搭建AlexNet網絡——并處理自定義的數據集(貓狗分類)
How to download MNIST images as PNGs