【深度學習】——利用pytorch搭建一個完整的深度學習項目(構建模型、加載數據集、參數配置、訓練、模型保存、預測)

目錄

?一、深度學習項目的基本構成

二、實戰(貓狗分類)

1、數據集下載

2、dataset.py文件

3、model.py

4、config.py

5、predict.py


?一、深度學習項目的基本構成

一個深度學習模型一般包含以下幾個文件:

datasets文件夾:存放需要訓練和測試的數據集

dataset.py:加載數據集,將數據集轉換為固定的格式,返回圖像集和標簽集

model.py:根據自己的需求搭建一個深度學習模型,具體搭建方法參考

【深度學習】——pytorch搭建模型及相關模型icon-default.png?t=L892https://blog.csdn.net/qq_45769063/article/details/120246601config.py:將需要配置的參數均放在這個文件中,比如batchsize,transform,epochs,lr等超參數

train.py:加載數據集,訓練

predict.py:加載訓練好的模型,對圖像進行預測

requirements.txt:一些需要的庫,通過pip install -r requirements.txt可以進行安裝

readme:記錄一些log

log文件:存放訓練好的模型

loss文件夾:存放訓練記錄的loss圖像

二、實戰(貓狗分類)

1、數據集下載

下載數據

  • 訓練數據

鏈接: https://pan.baidu.com/s/1UOJUi-Wm6w0D7JGQduq7Ow 提取碼: 485q

  • 測試數據

鏈接: https://pan.baidu.com/s/1sSgLFkv9K3ciRVLAryWUKg 提取碼: gyvs

下載好之后解壓,可以發現訓練數據以cat或dog開頭,測試數據都以數字命名。

這里我重命名了,cats以0開始,dogs以1開始

創建dataset文件夾

一般習慣這樣構造目錄,直接人為劃分三個數據集,當然也可以用程序進行劃分

?

2、dataset.py文件

主要是繼承dataset類,然后在__getitem__方法中編寫代碼,得到一個可以通過字典key來取值的實例化對象

## 導入模塊
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#過濾警告信息
import warnings
warnings.filterwarnings("ignore")class MyDataset(Dataset):  # 繼承Datasetdef __init__(self, path_dir, transform=None,train=True,test=True,val=True):  # 初始化一些屬性self.path_dir = path_dir  # 文件路徑,如'.\data\cat-dog'self.transform = transform  # 對圖形進行處理,如標準化、截取、轉換等self.images = os.listdir(self.path_dir)  # 把路徑下的所有文件放在一個列表中self.train = trainself.test = testself.val = valif self.test:self.images = os.listdir(self.path_dir + r"\cats")self.images.extend(os.listdir(self.path_dir+r"\dogs"))if self.train:self.images = os.listdir(self.path_dir + r"\cats")self.images.extend(os.listdir(self.path_dir+r"\dogs"))if self.val:self.images = os.listdir(self.path_dir + r"\cats")self.images.extend(os.listdir(self.path_dir+r"\dogs"))def __len__(self):  # 返回整個數據集的大小return len(self.images)def __getitem__(self, index):  # 根據索引index返回圖像及標簽image_index = self.images[index]  # 根據索引獲取圖像文件名稱if image_index[0] == "0":img_path = os.path.join(self.path_dir,"cats", image_index)  # 獲取圖像的路徑或目錄else:img_path = os.path.join(self.path_dir,"dogs", image_index)  # 獲取圖像的路徑或目錄img = Image.open(img_path).convert('RGB')  # 讀取圖像# 根據目錄名稱獲取圖像標簽(cat或dog)# 把字符轉換為數字cat-0,dog-1label = 0 if image_index[0] == "0" else 1if self.transform is not None:img = self.transform(img)# print(type(img))# print(img.size)return img, label

3、model.py

模型是在VGG16的基礎上進行修改的,主要是增加了一層卷積層和兩層全連接層,將輸入的圖像resize成448,448大小

from torch import nnclass VGG19(nn.Module):def __init__(self, num_classes=2):super(VGG19, self).__init__()  # 繼承父類屬性和方法# 根據前向傳播的順序,搭建各個子網絡模塊## 十四個卷積層,每個卷積模塊都有卷積層、激活層和池化層,用nn.Sequential()這個容器將各個模塊存放起來# [1,3,448,448]self.conv0 = nn.Sequential(nn.Conv2d(3, 32, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算nn.MaxPool2d((2, 2), (2, 2)))# [1,32,224,224]self.conv1 = nn.Sequential(nn.Conv2d(32, 64, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算)# [1,64,224,224]self.conv2 = nn.Sequential(nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算nn.MaxPool2d((2, 2), (2, 2)))# [1,64,112,112]self.conv3 = nn.Sequential(nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算)# [1,128,112,112]self.conv4 = nn.Sequential(nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算nn.MaxPool2d((2, 2), (2, 2)))# [1,128,56,56]self.conv5 = nn.Sequential(nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算)# [1,256,56,56]self.conv6 = nn.Sequential(nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算)# [1,256,56,56]self.conv7 = nn.Sequential(nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),  # inplace = True表示是否進行覆蓋計算nn.MaxPool2d((2, 2), (2, 2)))# [1,256,28,28]self.conv8 = nn.Sequential(nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True))# [1,512,28,28]self.conv9 = nn.Sequential(nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True))# [1,512,28,28]self.conv10 = nn.Sequential(nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),nn.MaxPool2d((2, 2), (2, 2)))# [1,512,14,14]self.conv11 = nn.Sequential(nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),)# [1,512,14,14]self.conv12 = nn.Sequential(nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),)# [1,512,14,14]-->[1,512,7,7]self.conv13 = nn.Sequential(nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),nn.ReLU(inplace=True),nn.MaxPool2d((2, 2), (2, 2)))# 五個全連接層,每個全連接層之間存在激活層和dropout層self.classfier = nn.Sequential(# [1*512*7*7]nn.Linear(1 * 512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),# 4096nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),# 4096-->1000nn.Linear(4096, 1000),nn.ReLU(True),nn.Dropout(),# 1000-->100nn.Linear(1000, 100),nn.ReLU(True),nn.Dropout(),nn.Linear(100, num_classes),nn.Softmax(dim=1))# 前向傳播函數def forward(self, x):# 十四個卷積層x = self.conv0(x)x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)x = self.conv6(x)x = self.conv7(x)x = self.conv8(x)x = self.conv9(x)x = self.conv10(x)x = self.conv11(x)x = self.conv12(x)x = self.conv13(x)# 將圖像扁平化為一維向量,[1,512,7,7]-->1*512*7*7x = x.view(x.size(0), -1)# 三個全連接層output = self.classfier(x)return outputif __name__ == '__main__':import torchnet = VGG19()print(net)input = torch.randn([1,3,448,448])output = net(input)print(output)

4、config.py

from torchvision import transforms as T# 數據集準備
trainFlag = True
valFlag = True
testFlag = Falsetrainpath = r".\datasets\train"
testpath = r".\datasets\test"
valpath = r".\datasets\val"transform_ = T.Compose([T.Resize(448),  # 縮放圖片(Image),保持長寬比不變,最短邊為224像素T.CenterCrop(448),  # 從圖片中間切出224*224的圖片T.ToTensor(),  # 將圖片(Image)轉成Tensor,歸一化至[0, 1]T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 標準化至[-1, 1],規定均值和標準差
])# 訓練相關參數
batchsize = 2
lr = 0.001
epochs = 100

5、predict.py

?加載訓練好的模型,對圖像進行預測

from pytorch.Cats_Dogs.model import VGG19
from PIL import Image
import torch
from pytorch.Cats_Dogs.configs import transform_def predict_(model, img):# 將輸入的圖像從array格式轉為imageimg = Image.fromarray(img)# 自己定義的pytorch transform方法img = transform_(img)# .view()用來增加一個維度# 我的圖像的shape為(1, 64, 64)# channel為1,H為64, W為64# 因為訓練的時候輸入的照片的維度為(batch_size, channel, H, W) 所以需要再新增一個維度# 增加的維度是batch size,這里輸入的一張圖片,所以為1img = img.view(1, 1, 64, 64)output = model(img)_, prediction = torch.max(output, 1)# 將預測結果從tensor轉為array,并抽取結果prediction = prediction.numpy()[0]return predictionif __name__ == '__main__':img_path = r"*.jpg"img = Image.open(img_path).convert('RGB')  # 讀取圖像device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')model = VGG19()# save_path,和模型的保存那里的save_path一樣# .eval() 預測結果前必須要做的步驟,其作用為將模型轉為evaluation模式# Sets the module in evaluation mode.model.load_state_dict(torch.load("*.pth"))model.eval()pred = predict_(model,img)print(pred)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/255764.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/255764.shtml
英文地址,請注明出處:http://en.pswp.cn/news/255764.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二叉樹的序遍歷

時間限制: 1 s空間限制: 32000 KB題目等級 : 白銀 Silver題目描述 Description求一棵二叉樹的前序遍歷,中序遍歷和后序遍歷 輸入描述 Input Description第一行一個整數n,表示這棵樹的節點個數。 接下來n行每行2個整數L和R。第i行的兩個整數Li和Ri代表編號…

GUI登錄界面

在這次的作業中,我先使用單選按鈕,輸入框,復選框設計了一個簡單地登錄界面。接著我使用了MouseListener將登陸按鈕與下一個“查詢界面”連接起來。最后我使用了我們本周所學的JFrame框架與事件處理機制設計了一個簡單地界面。我所設計的登錄界…

淺談ROS操作系統及其應用趨勢

ROS操作系統是最先由斯坦福開發的開源機器人操作系統,目前由willowgarage公司開發和維護,相關的開發社區也很成熟( http://www.ros.org , http://answers.ros.org, http://www.willowgarage.com), 經過幾年的發展API也逐漸穩定&a…

Raft學習傳送門

Raft官網 官方可視化動畫1 官方可視化動畫2 論文中文翻譯 論文英文地址 Paxos Made Simple論文翻譯 Raft理解 技術分享 《分布式一致性raft算法實現原理》 狀態機 MIT: raft實現 分布式系統學習2-Raft算法分析與實現 分布式系統MIT 6.824學習資源 知乎大神的&#…

【Python生成器與迭代器的區別】

目錄 一、迭代 二、迭代器 1)創建迭代器——兩種方法 iter()方法 利用()和range結合使用 2)具體案例 3、生成器 4、二者的異同 1)、共同點 2)、不同點 a、語法上 b、用法上 一、迭代 首先理解一…

BZOJ4426 : [Nwerc2015]Better Productivity最大生產率

如果一個區間包含另一個區間,那么這兩個區間是否在一起的生產率是一樣的。 將所有這種包含了其他區間的區間放入數組$b$,其余的放入數組$c$,有多個相同的時候則從$b$移一個到$c$。 那么$c$里所有區間左端點遞增,右端點也遞增&…

[codevs1105][COJ0183][NOIP2005]過河

[codevs1105][COJ0183][NOIP2005]過河 試題描述 在河上有一座獨木橋,一只青蛙想沿著獨木橋從河的一側跳到另一側。在橋上有一些石子,青蛙很討厭踩在這些石子上。由于橋的長度和青蛙一次跳過的距離都是正整數,我們可以把獨木橋上青蛙可能到達的…

ABB機器人套接口通信 機器人部分

ABB機器人套接口通信 機器人部分 文章機器人部分,描述如何運行機器人從機程序,使機器人根據上位機節點發送的命令,執行具體運動。 ABB機器人執行3個任務。這些任務都配置為SEMISTATIC(背景程序)的任務,第三任務是正常動作任務。下文描述如…

CRM項目總結

CRM項目總結 一:開發背景 在公司日益擴大的過程中,不可避免的會伴隨著更多問題出現。 對外 : 如何更好的管理客戶與公司的關系?如何更及時的了解客戶日益發展的需求變化?公司的產品是否真的符合客戶需求?以…

【面經——《速騰聚創科技有限公司——深度學習算法工程師》】

自我介紹 實習項目 1)項目主要應用的領域? 2)難點在哪?——機械臂吸盤大小和目標大小之間坐標的協調 3)難點不在于算法,在于數據的處理和均衡性?對于數據均衡方面有什么理解&#xf…

js變量和數據類型

轉載于:https://www.cnblogs.com/songyinan/p/6181421.html

offline .net3.5

1.加載虛擬光驅 2.dism.exe /online /enable-feature /featurename:netfx3 /Source:D:\sources\sxs轉載于:https://www.cnblogs.com/BillLei/p/5294082.html

(九)模板方法模式詳解(包含與類加載器不得不說的故事)

作者:zuoxiaolong8810(左瀟龍),轉載請注明出處,特別說明:本博文來自博主原博客,為保證新博客中博文的完整性,特復制到此留存,如需轉載請注明新博客地址即可。 模板方法模…

阿里云openapi接口使用,PHP,視頻直播

1.下載sdk放入項目文件夾中 核心就是aliyun-php-sdk-core,它的配置文件會自動加載相應的類 2.引入文件 include_once LIB_PATH . ORG/aliyun-openapi/aliyun-php-sdk-core/Config.php; 3.配置客戶端對象,需要Access Key ID,Access Key Secret $iClientProfile Defa…

【面經——《廣州敏視數碼科技有限公司》——圖像處理算法工程師-深度學習方向】

目錄 筆試 HR面 專業面——60多分鐘 主管面 反問: 筆試 8道題——簡答題 1道編程 蘋果、香蕉、梨、菠蘿,彩色圖像如何進行分類?一輛帶車牌的汽車,圖像亮度整體呈現偏亮狀態,如何…

Android之網絡編程利用PHP操作MySql插入數據(四)

因為最近在更新我的項目,就想著把自己在項目中用到的一些的簡單的與網絡交互的方法總結一下,所以最近Android網絡編程方面的博文會比較多一些,我盡量以最簡單的方法給大家分享,讓大家明白易懂。如果有什么不對的地方,還…

RAPID 信號的互鎖和同步 WaitTestAndSet 和 TestAndSet

RAPID 信號的互鎖和同步 WaitTestAndSet 指令等待指定的持久型 BOOL 變量變成 FALSE.當變量值變為 FALSE, 該指令將設置變量為 TRUE 并繼續執行. 該持久型變量可被作為同步或者互斥時的一個 BOOL 信號量。 這個指令與 TestAndSet 有著同樣的基本功能。但是 WaitTestAnd…

【常用網址】——opencv等

opencv官網Releases - OpenCVhttps://opencv.org/releases/

(五):C++分布式實時應用框架——微服務架構的演進

C分布式實時應用框架——微服務架構的演進 技術交流合作QQ群:436466587 歡迎討論交流 上一篇:(四):C分布式實時應用框架——狀態中心模塊 版權聲明:本文版權及所用技術歸屬smartguys團隊所有,對于抄襲,非經同意轉載等…