Pytorch入門實戰 P10-使用pytorch實現車牌識別

目錄

前言

一、MyDataset文件

二、完整代碼:

三、結果展示:

四、添加accuracy值


  • 🍨 本文為🔗365天深度學習訓練營?中的學習記錄博客
  • 🍖 原作者:K同學啊 | 接輔導、項目定制

本周的學習內容是,使用pytorch實現車牌識別。

前言

????????之前的案例里面,我們大多是使用的是datasets.ImageFolder函數,直接導入已經分類好的數據集形成Dataset,然后使用DataLoader加載Dataset,但是如果對無法分類的數據集,我們應該如何導入呢。

????????這篇文章主要就是介紹通過自定義的一個MyDataset加載車牌數據集并完成車牌識別。

一、MyDataset文件

數據文件是這樣的,沒有進行分類的。

# 加載數據文件
class MyDataset(data.Dataset):def __init__(self, all_labels, data_paths_str, transform):self.img_labels = all_labels  # 獲取標簽信息self.img_dir = data_paths_str  # 圖像目錄路徑self.transform = transform   # 目標轉換函數def __len__(self):return len(self.img_labels)   # 返回數據集的長度,即標簽的數量def __getitem__(self, index):image = Image.open(self.img_dir[index]).convert('RGB')   # 打開指定索引的圖像文件,并將其轉換為RGB模式label = self.img_labels[index]  # 獲取圖像對應的標簽if self.transform:image = self.transform(image)   # 如果設置了轉換函數,則對圖像進行轉換(如,裁剪、縮放、歸一化等)return image, label  # 返回圖像和標簽

二、完整代碼:

import pathlibimport matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib as mpl
mpl.use('Agg')  # 在服務器上運行的時候,打開注釋device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)data_dir = './data'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1].split('_')[1].split('.')[0] for path in data_paths]
# print(classNames)  # '滬G1CE81', '云G86LR6', '鄂U71R9F', '津G467JR'....data_paths_str = [str(path) for path in data_paths]# 數據可視化
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(14,5))
plt.suptitle('data show', fontsize=15)
for i in range(18):plt.subplot(3, 6, i+1)# 顯示圖片images = plt.imread(data_paths_str[i])plt.imshow(images)
plt.show()# 3、標簽數字化
char_enum = ["京","滬","津","渝","冀","晉","蒙","遼","吉","黑","蘇","浙","皖","閩","贛","魯","豫","鄂","湘","粵","桂","瓊","川","貴","云","藏","陜","甘","青","寧","新","軍","使"]number = [str(i) for i in range(0, 10)]  # 0-9 的數字
alphabet = [chr(i) for i in range(65, 91)]  # A到Z的字母
char_set = char_enum + number + alphabet
char_set_len = len(char_set)
label_name_len = len(classNames[0])# 將字符串數字化
def text2vec(text):vector = np.zeros([label_name_len, char_set_len])for i, c in enumerate(text):idx = char_set.index(c)vector[i][idx] = 1.0return vectorall_labels = [text2vec(i) for i in classNames]# 加載數據文件
class MyDataset(data.Dataset):def __init__(self, all_labels, data_paths_str, transform):self.img_labels = all_labels  # 獲取標簽信息self.img_dir = data_paths_str  # 圖像目錄路徑self.transform = transform   # 目標轉換函數def __len__(self):return len(self.img_labels)def __getitem__(self, index):image = Image.open(self.img_dir[index]).convert('RGB')label = self.img_labels[index]  # 獲取圖像對應的標簽if self.transform:image = self.transform(image)return image, label  # 返回圖像和標簽total_datadir = './data/'
train_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])total_data = MyDataset(all_labels, data_paths_str, train_transforms)
# 劃分數據
train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size, test_size])
print(train_size, test_size)  # 10940 2735# 數據加載
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)for X, y in test_loader:print('Shape of X [N,C,H,W]:',X.shape)   # ([16, 3, 224, 224])print('Shape of y:', y.shape, y.dtype)    # torch.Size([16, 7, 69]) torch.float64break# 搭建網絡模型
class Network_bn(nn.Module):def __init__(self):super(Network_bn, self).__init__()"""nn.Conv2d()函數:第一個參數(in_channels)是輸入的channel數量第二個參數(out_channels)是輸出的channel數量第三個參數(kernel_size)是卷積核大小第四個參數(stride)是步長,默認為1第五個參數(padding)是填充大小,默認為0"""self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2, 2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24 * 50 * 50, label_name_len * char_set_len)self.reshape = Reshape([label_name_len, char_set_len])def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.pool(x)x = F.relu(self.bn4(self.conv4(x)))x = F.relu(self.bn5(self.conv5(x)))x = self.pool(x)x = x.view(-1, 24 * 50 * 50)x = self.fc1(x)# 最終reshapex = self.reshape(x)return xclass Reshape(nn.Module):def __init__(self,shape):super(Reshape, self).__init__()self.shape = shapedef forward(self, x):return x.view(x.size(0), *self.shape)model = Network_bn().to(device)
print(model)# 優化器與損失函數
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0001)
loss_model = nn.CrossEntropyLoss()def test(model, test_loader, loss_model):size = len(test_loader.dataset)num_batches = len(test_loader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in test_loader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_model(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchesprint(f'Avg loss: {test_loss:>8f}\n')return correct, test_lossdef train(model,train_loader, loss_model, optimizer):model = model.to(device)model.train()for i, (images, labels) in enumerate(train_loader, 0):   # 0 是標起始位置的值images = Variable(images.to(device))labels = Variable(labels.to(device))optimizer.zero_grad()outputs = model(images)loss = loss_model(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print('[%5d] loss: %.3f' % (i, loss))# 模型的訓練
test_acc_list = []
test_loss_list = []
epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n-----------------------")train(model,train_loader, loss_model,optimizer)test_acc,test_loss = test(model, test_loader, loss_model)test_acc_list.append(test_acc)test_loss_list.append(test_loss)print('Done!!!')# 結果分析
x = [i for i in range(1,31)]
plt.plot(x, test_loss_list, label="Loss", alpha = 0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss')plt.legend()
plt.show()
plt.savefig("/data/jupyter/deep_demo/p10_car_number/resultImg.jpg")  # 保存圖片在服務器的位置
plt.show()

三、結果展示:

總結:從剛開始損失為0.077 到,訓練30輪后,損失到了0.026。

四、添加accuracy值

需求:對在上面的代碼中,對loss進行了統計更新,請補充acc統計更新部分,即獲取每一次測試的ACC值。

添加accuracy的運行過程:

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/13168.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/13168.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/13168.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SEO:搜索引擎蜘蛛名稱UA(user-agent)

最近網站在做統計功能,想著統計下蜘蛛爬行記錄,看看都有哪些搜索引擎蜘蛛經常關顧,故而好進行相應的對策改變。都知道搜索引擎對一個網站很重要,是很多網站重要的流量來源。熟悉各大搜索引擎的蜘蛛就顯得必要。 做SEO優化的通常會說蜘蛛爬得越…

國網698.45報文解析工具

本文分享一個698.45協議的報文解析工具,此報文解析工具功能強大,可以解析多種國網數據協議。 下載鏈接: https://pan.baidu.com/s/1ngbBG-yL8ucRWLDflqzEnQ 提取碼: y1de 主要界面如下: 本工具內置698.45數據協議, 即可調用word…

win編寫bat腳本啟動java服務

新建txt,編寫,前臺啟動,出現cmd黑窗口 echo off start java -jar zhoao1.jar start java -jar zhoao2.jar pause完成后,重命名.bat 1、后臺啟動,不出現cmd黑窗口,app是窗口名稱 echo off start "名…

美團小程序mtgsig1.2逆向

聲明 本文章中所有內容僅供學習交流使用,不用于其他任何目的,抓包內容、敏感網址、數據接口等均已做脫敏處理,嚴禁用于商業用途和非法用途,否則由此產生的一切后果均與作者無關!wx a15018601872 本文章未…

VMware虛擬機沒有網,無法設置網絡為橋接狀態

今天需要使用Ubuntu18但現有虛擬機是Ubuntu20,由于硬盤空間不夠大,所以刪除了原來的虛擬機并重新搭建Ubuntu18的環境,然后發現虛擬機沒有網絡,而我之前的虛擬機這一切都是正常的。 在網絡設置里勾選的是橋接模式但無法聯網&#x…

由讀寫arrow引發的對時間時區的思考

arrow是apache開發的一種高壓縮的數據結構,發現用來存儲K線還是很不錯的選擇。 測試用python讀寫很方便,關鍵是足夠小,A股1支票1分鐘的數據,1個月大約是140多K吧。 結果從數據庫取出來存入arrow中,再用C進行讀取&…

Cow Exhibition G的來龍去脈

[USACO03FALL] Cow Exhibition G - 洛谷 曲折經過 爆搜 一開始沒什么好的想法&#xff0c;就針對每頭奶牛去or不去進行了爆搜。 #include <cstdio> #include <algorithm> using namespace std;#define maxn 405 int iq[maxn], eq[maxn]; int ans; int n;void d…

留學資訊 | 2024英國學生簽證申請需要滿足哪些條件?

英國移民局于2020年9月10日發布了《移民規則變更聲明: HC 707》&#xff0c;對學生簽證制度進行了全面改革。該法案于2020年10月5日正式生效。根據此法案&#xff0c;新的學生簽證——The Student and Child Student Routes學生和兒童學生路線&#xff0c;將替代原先的Tier 4學…

短視頻賽道有哪些:成都鼎茂宏升文化傳媒公司

短視頻賽道有哪些&#xff1a;探索多元化的內容領域 隨著科技的飛速發展和人們生活節奏的加快&#xff0c;短視頻已成為現代人生活中不可或缺的一部分。它以其簡短、直觀、易于分享的特點&#xff0c;迅速占領了各個年齡層和社會群體的心智。然而&#xff0c;短視頻的賽道并非…

層次式體系結構概述

1.軟件體系結構 軟件體系結構可定義為&#xff1a;軟件體系結構為軟件系統提供了結構、行為和屬性的高級抽象&#xff0c;由構成系統的元素描述、這些元素的相互作用、指導元素集成的模式以及這些模式的約束組成。軟件體系結構不僅指定了系統的組織結構和拓撲結構&#xff0c;并…

小程序框架是智能融媒體平臺構建的最佳線路

過去5年&#xff0c;媒體行業一直都在進行著信息化建設向融媒體平臺建設的轉變。一些融媒體的建設演變總結如下&#xff1a; 新聞終端的端側內容矩陣建設&#xff0c;如App新聞端&#xff0c;社交平臺上的官方媒體等新聞本地生活雙旗艦客戶端&#xff0c;兼顧主流媒體核心宣傳…

TopOn 正式聚合Kwai 旗下程序化廣告平臺——Kwai Network

**我們非常高興的宣布&#xff0c;TopOn SDK 近日已正式聚合Kwai Network。**作為Kwai 旗下的程序化廣告平臺&#xff0c;Kwai Network 通過優質的變現能力及產品能力&#xff0c;為廣大開發者提供高效及時的服務。 TopOn 聚合平臺與Kwai Network 正式完成接入后&#xff0c;開…

實戰+代碼!Selenium + Phantom JS爬取天天基金數據

功能&#xff1a; 通過程序實現從基金列表頁&#xff0c;獲取指定頁數內所有基金的近一周收益率以及每支基金的詳情頁鏈接。再進入每支基金的詳情頁獲取其余的基金信息&#xff0c;將所有獲取到的基金詳細信息按近6月收益率倒序排列寫入一個Excel表格。 思路&#xff1a; 1.…

vm 虛擬機 Debian12 開啟 root、ssh 登錄功能

前言&#xff0c;安裝的時候語言就選中文就好了。選擇中文&#xff0c;在安裝的時候就可以選擇國內 163 的源。 開啟 ssh 功能 先提權&#xff0c;用 root 賬戶 su安裝 ssh 安裝 ssh-server apt install openssh-server啟動 ssh systemctl start ssh查看 ssh 狀態 systemctl st…

【Flutter 面試題】 如何讓圖片重復堆疊容器?

【Flutter 面試題】 如何讓圖片重復堆疊容器? 文章目錄 寫在前面口述回答補充說明寫在前面 ?? 關于我 ,小雨青年 ?? CSDN博客專家,GitChat專欄作者,阿里云社區專家博主,51CTO專家博主。2023博客之星TOP153。 ???? 正在學 Flutter 的同學,你好! ?? Flutter …

根據web訪問日志,封禁請求量異常的IP,如IP在半小 時后恢復正常則解除封禁

在網絡安全日益受到重視的今天&#xff0c;如何有效防范惡意流量和攻擊成為了每個網站管理員必須面對的問題。惡意流量不僅會影響網站的正常運行&#xff0c;還可能導致服務器崩潰&#xff0c;給網站帶來不可估量的損失。為了應對這一問題&#xff0c;我們特別推出了一款實用的…

u3d的ab文件注意事項

//----------------LoadAllAB.cs--------------------- using System.Collections;using UnityEngine;namespace System.IO{public class LoadAllAB : MonoBehaviour{ //讀取本地string path "Assets/Actors/lznh/ab/animation/t_bl/";// Use this for initiali…

SQL注入之數據庫基礎

數據庫基礎 創建數據庫 create 數據庫名稱;創建表 create table if not exists mobile(ID int(10) primary key auto_increment comment 手機編號 主鍵自增,Brand varchar(50) not null comment 手機品牌 非空約束,Model varchar(50) not null comment 手機型號 非空約束,Pr…

Keil手動安裝編譯器V5版本

V5編譯器下載&#xff1a;免積分下載 新版的keil不會自動幫你安裝V5版本的編譯器&#xff0c;但是很多教程很多比賽所用單片機都是V5的編譯器&#xff0c;所以用來開以前的或者開源的很多東西編譯直接一大堆報錯。 吐槽說完了接下來教你怎么解決 打開installer&#xff08;在…

vue使用postcss-pxtorem實現自適應

安裝&#xff1a; npm install postcss-pxtorem -Dvue.config.js文件設置&#xff1a; css: {loaderOptions: {scss: {additionalData: import "~element-ui/packages/theme-chalk/src/common/var.scss";import"/styles/variables.scss";},postcss: {po…