圖像分類應用

先留一段圖像分類代碼,空閑時間再做分析:

創建神經網絡:

import torch
from torch import nn
import torch.nn.functional as F
class MyAlexNet(nn.Module):def __init__(self):super(MyAlexNet, self).__init__()self.c1=nn.Conv2d(in_channels=3,out_channels=48,kernel_size=11,stride=4,padding=2)self.ReLU=nn.ReLU()self.c2=nn.Conv2d(in_channels=48,out_channels=128,kernel_size=5,stride=1,padding=2)self.s2=nn.MaxPool2d(2)self.c3=nn.Conv2d(in_channels=128,out_channels=192,kernel_size=3,stride=1,padding=1)self.s3=nn.MaxPool2d(2)self.c4=nn.Conv2d(in_channels=192,out_channels=192,kernel_size=3,stride=1,padding=1)self.c5=nn.Conv2d(in_channels=192,out_channels=128,kernel_size=3,stride=1,padding=1)self.s5=nn.MaxPool2d(kernel_size=3,stride=2)self.flatten=nn.Flatten()self.f6=nn.Linear(128*6*6,2048)self.f7=nn.Linear(2048,2048)self.f8=nn.Linear(2048,1000)self.f9=nn.Linear(1000,2)def forward(self,x):x=self.ReLU(self.c1(x))x=self.ReLU(self.c2(x))x=self.s2(x)x=self.ReLU(self.c3(x))x=self.s3(x)x=self.ReLU(self.c4(x))x=self.ReLU(self.c5(x))x=self.s5(x)x=self.flatten(x)x=self.f6(x)x=F.dropout(x,p=0.5)x=self.f7(x)x=F.dropout(x,p=0.5)x=self.f8(x)x=F.dropout(x,p=0.5)x=self.f9(x)return x
if __name__ == '__main__':x=torch.rand([1,3,224,224])model=MyAlexNet()y=model(x)

數據預處理:

import os
from shutil import copy
import random
def mkdir(file):if not os.path.exists(file):os.makedirs(file)
#獲取data文件夾下所有文件夾名(即需要分類的類名)
file_path='E:/BaiduNetdiskDownload/Kaggle貓狗大戰/train'
flower_class= [cla for cla in os.listdir(file_path)]
#創建訓練集train文件夾,并由類名在其目錄下創建子目錄
mkdir('data/train')
mkdir('data/train/cat')
mkdir('data/train/dog')
mkdir('data/val')
mkdir('data/val/cat')
mkdir('data/val/dog')
split_rate=0.1
for cla in flower_class:cla_path=file_path+'/'+cla#"E:\BaiduNetdiskDownload\Kaggle貓狗大戰\train\train\cat.0.jpg"images=os.listdir(cla_path)print(cla_path)num=len(images)eval_index=random.sample(images,k=int(num*split_rate))for index,image in enumerate(images):if image in eval_index:image_path = cla_path+'/'+imageif "cat" in image_path:new_path = 'data/val/cat/'else:new_path = 'data/val/dog/'copy(image_path,new_path)else:image_path=cla_path+'/'+imageif "cat" in image_path:new_path='data/train/cat/'else:new_path='data/train/dog/'copy(image_path,new_path)print("\r[{}]processing[{}/{}]".format(cla,index+1,num),end="")print()
print("processing done!")

訓練集用于 訓練權重:

import torch
from torch import nn
from net import MyAlexNet
import numpy as np
from torch.optim import lr_scheduler
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus']=False
ROOT_TRAIN = 'C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/train'
ROOT_TEST='C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/val'
normalize=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
train_transform=transforms.Compose([transforms.Resize((224,224)),transforms.RandomVerticalFlip(),transforms.ToTensor(),normalize
])
val_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),normalize
])
train_dataset=ImageFolder(ROOT_TRAIN,transform=train_transform)
val_dataset=ImageFolder(ROOT_TEST,transform=val_transform)
train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader=DataLoader(val_dataset,batch_size=32,shuffle=True)
model=MyAlexNet()
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
lr_scheduler=lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.5)
def train(dataloader,model,loss_fn,optimizer):loss,current,n =0.0,0.0,0.0for batch,(x,y) in enumerate(dataloader):image,y =x,youtput=model(image)cur_loss=loss_fn(output,y)_,pred=torch.max(output,axis=1)cur_acc=torch.sum(y==pred)/output.shape[0]optimizer.zero_grad()cur_loss.backward()optimizer.step()loss+=cur_loss.item()current+=cur_acc.item()n+=1train_loss=loss/ntrain_acc=current/nprint('train_loss'+str(train_loss))print('train_acc'+str(train_acc))return train_loss,train_acc
def val(dataloader,model,loss_fn):model.eval()loss,current,n=0.0,0.0,0.0with torch.no_grad():for batch,(x,y) in enumerate(dataloader):image,y =x,youtput=model(image)cur_loss=loss_fn(output,y)_,pred=torch.max(output,axis=1)cur_acc=torch.sum(y==pred)/output.shape[0]loss+=cur_loss.item()current+=cur_acc.item()n+=1val_loss=loss/nval_acc=current/nprint('val_loss'+str(val_loss))print('val_acc'+str(val_acc))return val_loss,val_acc
def matplot_loss(train_loss,val_loss):plt.plot(train_loss,label='train_loss')plt.plot(val_loss,label='val_loss')plt.legend(loc='best')plt.ylabel('loss')plt.xlabel('epoch')plt.title("訓練集和驗證集loss值對比圖")plt.show()
def matplot_acc(train_loss,val_loss):plt.plot(train_acc,label='train_acc')plt.plot(val_acc,label='val_acc')plt.legend(loc='best')plt.ylabel('acc')plt.xlabel('epoch')plt.title("訓練集和驗證集acc值對比圖")plt.show()
loss_train=[]
acc_train=[]
loss_val=[]
acc_val=[]
epoch=20
min_acc=0
for t in range(epoch):lr_scheduler.step()print(f"epoch{t+1}\n----------------")train_loss,train_acc=train(train_dataloader,model,loss_fn,optimizer)val_loss,val_acc=val(val_dataloader,model,loss_fn)loss_train.append(train_loss)acc_train.append(train_acc)loss_val.append(val_loss)acc_val.append(val_acc)if val_acc>min_acc:folder='save_model'if not os.path.exists(folder):os.mkdir('save_model')min_acc=val_accprint(f"save best model,第{t+1}輪")torch.save(model.state_dict(),'save_model/best.model.pth')if t==epoch-1:torch.save(model.state_dict(),'save_model/last_model.pth')
print('Done')

測試集用于測試模型:

import torch
from net import MyAlexNet
from torch.autograd import variable
from torchvision import datasets,transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
ROOT_TRAIN = 'C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/train'
ROOT_TEST='C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/val'
normalize=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
train_transform=transforms.Compose([transforms.Resize((224,224)),transforms.RandomVerticalFlip(),transforms.ToTensor(),normalize
])
val_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),normalize
])
train_dataset=ImageFolder(ROOT_TRAIN,transform=train_transform)
val_dataset=ImageFolder(ROOT_TEST,transform=val_transform)
train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader=DataLoader(val_dataset,batch_size=32,shuffle=True)
model=MyAlexNet()
model.load_state_dict(torch.load("C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/save_model/best.model.pth"))
classes=["cat","dog",
]
show=ToPILImage()
model.eval()
for i in range(50):x,y = val_dataset[i][0],val_dataset[i][1]show(x).show()x=torch.tensor(torch.unsqueeze(x,dim=0).float(),requires_grad=True)x=torch.tensor(x)with torch.no_grad():pred=model(x)print(pred)predicted,actual=classes[torch.argmax(pred[0])],classes[y]print(f'predicted:"{predicted}",Actual:"{actual}"')

沒有顯卡慢的跟狗屎一樣。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/719530.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/719530.shtml
英文地址,請注明出處:http://en.pswp.cn/news/719530.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二刷代碼隨想錄算法訓練營第十天 | 232.用棧實現隊列、 225. 用隊列實現棧

目錄 一、232. 用棧實現隊列 二、225. 用隊列實現棧 一、232. 用棧實現隊列 題目鏈接:力扣 文章講解:代碼隨想錄 視頻講解: 棧的基本操作! | LeetCode:232.用棧實現隊列 題目: 請你僅使用兩個棧實現先…

Vision Pro開發者學習路線

官方給到的Vision Pro開發者學習路線: 1. 學習基礎知識: - 學習 Xcode、Swift 和 SwiftUI 的基礎知識,包括語法、UI 設計等。 - 掌握 ARKit 和 SwiftUI 的使用,了解如何創建沉浸式增強現實體驗。 2. 學習 3D 建模&#xf…

『Linux從入門到精通』第 ? 期 - System V 共享內存

文章目錄 💐專欄導讀💐文章導讀🐧共享內存原理🐧共享內存相關函數🐦key 與 shmid 區別 🐧代碼實例 💐專欄導讀 🌸作者簡介:花想云 ,在讀本科生一枚&#xff0…

CentOS7安裝DockerCompose和Docker鏡像倉庫的配置

CentOS7安裝DockerCompose 1.下載 Linux下需要通過命令下載: # 安裝 curl -L https://github.com/docker/compose/releases/download/1.23.1/docker-compose-uname -s-uname -m > /usr/local/bin/docker-compose2.修改文件權限 修改文件權限: # …

YOLOv9獨家原創改進|加入幽靈卷積Ghost Convolution模塊,輕量化!

專欄介紹:YOLOv9改進系列 | 包含深度學習最新創新,主力高效漲點!!! 一、論文摘要 由于內存和計算資源有限,在嵌入式設備上部署卷積神經網絡是困難的。特征圖中的冗余是那些成功的細胞神經網絡的一個重要特征…

【網站項目】158企業人事管理系統

🙊作者簡介:擁有多年開發工作經驗,分享技術代碼幫助學生學習,獨立完成自己的項目或者畢業設計。 代碼可以私聊博主獲取。🌹贈送計算機畢業設計600個選題excel文件,幫助大學選題。贈送開題報告模板&#xff…

突破編程_C++_字符串算法(判斷字符串是否包含)

1 算法題 :判斷一個字符串是否包含另一個字符串的所有字符(不一定連續) 1.1 題目含義 判斷一個字符串(稱為“主字符串”或“大字符串”)是否包含另一個字符串(稱為“子字符串”或“小字符串”&#xff09…

代碼隨想錄算法訓練營第31天—貪心算法05 | ● 435. 無重疊區間 ● *763.劃分字母區間 ● *56. 合并區間

435. 無重疊區間 https://programmercarl.com/0435.%E6%97%A0%E9%87%8D%E5%8F%A0%E5%8C%BA%E9%97%B4.html 考點 貪心算法重疊區間 我的思路 先按照區間左坐標進行排序,方便后續處理進行for循環,循環范圍是0到倒數第二個元素如果當前區間和下一區間重疊…

在Linux以命令行方式(靜默方式/非圖形化方式)安裝MATLAB(正版)

1.根據教程,下載windows版本matlab,打開圖形化界面,選擇linux版本的只下載不安裝 2.獲取安裝文件夾 3.獲取許可證 4.安裝 (1)跳過引用文章的2.2章節 (2)本文的安裝文件夾代替引用文章的解壓IS…

Java進階(鎖)——鎖的升級,synchronized與lock鎖區別

目錄 引出Java中鎖升級synchronized與lock鎖區別 緩存三兄弟:緩存擊穿、穿透、雪崩緩存擊穿緩存穿透緩存雪崩 總結 引出 Java進階(鎖)——鎖的升級,synchronized與lock鎖區別 Java中鎖升級 看一段代碼: public class…

Fastwhisper + Pyannote 實現 ASR + 說話者識別

文章目錄 前言一、faster-whisper簡單介紹二、pyannote.audio介紹三、faster-whisper pyannote.audio 實現語者識別四、多說幾句 前言 最近在研究ASR相關的業務,也是調研了不少模型,踩了不少坑,ASR這塊,目前中文普通話效果最好的…

Scrapy與分布式開發(1.1):課程導學

Scrapy與分布式開發:從入門到精通,打造高效爬蟲系統 課程大綱 在這個專欄中,我們將一起探索Scrapy框架的魅力,以及如何通過Scrapy-Redis實現分布式爬蟲的開發。在本課程導學中,我們將為您簡要介紹課程的學習目標、內容…

Verilog Coding Styles For Improved Simulation Efficiency論文學習記錄

原文基于Verilog-XL仿真器,測試了以下幾種方式對仿真效率的影響。 1. 使用 Case 語句而不是 if / else if 語句 八選一多路選擇器 case 實現效率比 if / else if 提升 6% 。 2. 如果可以盡量不使用 begin end 語句 使用 begin end 的 ff 觸發器比不使用 begin end …

初學者學習51還是STM32

初學者學習51還是STM32 在嵌入式系統領域,51和STM32是兩種常見的單片機架構。對于初學者來說,選擇學習哪種架構可能會成為一個難題。本文將對初學者學習51和STM32進行比較,以幫助讀者做出明智的選擇。 1. 51架構 51架構是指Intel 8051系列…

深度相機xyz點云文件三維坐標和jpg圖像文件二維坐標的相互變換函數

深度相機同時拍攝xyz點云文件和jpg圖像文件。xyz文件里面包含三維坐標[x,y,z]和jpg圖像文件包含二維坐標[x,y],但是不能直接進行變換,需要一定的步驟來推演。 下面函數是通過box二維框[xmin, ymin, xmax, ymax, _, _ ]去截取xyz文件中對應box里面的點云…

MyCAT學習——在openEuler22.03中安裝MyCAT2(網盤下載版)

準備工作 因為MyCAT 2基于JDK 1.8開發。也需要在虛擬機中安裝JDK(JDK官網就能下載,我這提供一個捷徑) jdk-8u401-linux-x64.rpmhttps://pan.baidu.com/s/1ywcDsxYOmfZONpmH9oDjfw?pwdrhel下載對應的tar安裝包,以及對應的jar包 安裝程序包…

九州金榜|孩子厭學要怎么辦?

孩子從小學到初中再到高中,孩子出現厭學情緒很正常,但是孩子出現厭學情緒后,就必然會影響到孩子學習成績,孩子產生厭學情緒的原因有哪些呢?只有找準孩子厭學原因才能去幫助孩子怎樣去克服孩子厭學情緒,下面…

ajax請求servlet成功但接收不到返回數據問題

ajax請求servlet成功但接收不到返回數據問題 javaweb初學者,最近老師布置的課設,所有功能都完成了,唯獨ajax與servlet交互出現問題,無論怎么調試都收不到數據 查詢兩天無果,剛才無意間看到 Crabime前輩的文章才恍然大…

深入解析YOLO:實時目標檢測技術的革命者

深入解析YOLO:實時目標檢測技術的革命者 目標檢測作為計算機視覺領域的一個核心任務,一直以來都是研究的熱點。而YOLO(You Only Look Once)技術作為其中的杰出代表,以其獨特的處理方式和卓越的性能,成為了…

day34貪心算法 part03

1005. K 次取反后最大化的數組和 簡單 給你一個整數數組 nums 和一個整數 k ,按以下方法修改該數組: 選擇某個下標 i 并將 nums[i] 替換為 -nums[i] 。 重復這個過程恰好 k 次。可以多次選擇同一個下標 i 。 以這種方式修改數組后,返回數…