基于PyTorch搭建CNN實現視頻動作分類任務代碼詳解

在這里插入圖片描述
數據及具體講解來源:
基于PyTorch搭建CNN實現視頻動作分類任務

import torch
import torch.nn as nn
import torchvision.transforms as T
import scipy.io
from torch.utils.data import DataLoader,Dataset
import os
from PIL import Image
from torch.autograd import Variable
import numpy as np"""
加載數據
"""
#獲得標簽
label_mat = scipy.io.loadmat('./datasets/q3_2_data.mat')
#獲得訓練集標簽
label_train = label_mat['trLb']
print(len(label_train))
#獲得驗證集標簽
label_val = label_mat['valLb']
print(len(label_val))"""
通過Dataset類進行數據預處理
"""
class ActionDataset(Dataset):def __init__(self,root_dir,labels = [],transform=None):"""Args::param root_dir: 數據路徑:param labels: 圖片標簽:param transform: 數據處理函數"""self.root_dir = root_dirself.transform = transformself.length = len(os.listdir(self.root_dir))self.labels = labelsdef __len__(self):  #返回數據數量return self.length*3    #一個視頻片段包含3幀(3個圖片)def __getitem__(self, idx): #圖片處理及返回數據folder = idx//3+1   #判斷該幀屬于第幾個視頻中imidx = idx%3 + 1   #判斷該幀在該視頻中是第幾幀folder = format(folder,'05d')   #將folder格式化,05d代表五位數,若不到五位用0填充imgname = str(imidx) + '.jpg'img_path = os.path.join(self.root_dir,folder,imgname)image = Image.open(img_path)"""當輸入標簽有值時,說明是訓練集和驗證集,輸出的樣本也是有標簽的,若沒有值,說明是測試集,輸出的樣本是沒有標簽的"""if len(self.labels)!=0:Label = self.labels[idx//3][0]-1#如果有對數據的處理先對數據進行處理if self.transform:image = self.transform(image)if len(self.labels)!=0:sample = {'image':image,'img_path':img_path,'Label':Label}else:sample = {'image':image,'img_path':img_path}return sampleimage_datast = ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
# torchvision.transforms中定義了非常多對圖像的預處理方法,這里使用的ToTensor方法為將0~255的RGB值映射到0~1的Tensor類型。
# #測試一下
# for i in range(3):
#     sample = image_datast[i]
#     print(sample['image'].shape)
#     print(sample['Label'])
#     print(sample['img_path'])"""
Dataloader類進行封裝
注意:Windows不要用num_works
"""
#image_dataloader = DataLoader(image_datast,batch_size=4,shuffle=True)
# for i , sample in enumerate(image_dataloader):
#     #enumerate(iteration, start):返回一個枚舉的對象
#     sample['image'] = sample['image']
#     print(sample[i,sample['image'].shape,sample['img_path'],'Label'])
#     if i == 6:
#         break
#訓練集
image_dataset_train=ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
image_dataloader_train = DataLoader(image_dataset_train, batch_size=32,shuffle=True)
#驗證集
image_dataset_val=ActionDataset(root_dir='./datasets/valClips/',labels=label_val,transform=T.ToTensor())
image_dataloader_val = DataLoader(image_dataset_val, batch_size=32,shuffle=False)
#測試集:沒有給定labels
image_dataset_test=ActionDataset(root_dir='./datasets/testClips/',labels=[],transform=T.ToTensor())
image_dataloader_test = DataLoader(image_dataset_test, batch_size=32,shuffle=False)"""
搭建模型
"""dtype = torch.FloatTensor # 這是pytorch所支持的cpu數據類型中的浮點數類型。print_every = 100   # 這個參數用于控制loss的打印頻率,因為我們需要在訓練過程中不斷的對loss進行檢測。def reset(m):   # 這是模型參數的初始化if hasattr(m, 'reset_parameters'):m.reset_parameters()#數據解釋和處理
class Flatten(nn.Module):def forward(self, x):N, C, H, W = x.size() # 讀取各個維度。return x.view(N, -1)  # -1代表除了特殊聲明過的以外的全部維度。fixed_model_base = nn.Sequential(nn.Conv2d(3,8,kernel_size=7,stride=1),   ##3*64*64 -> 8*58*58nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2),    # 8*58*58 -> 8*29*29nn.Conv2d(8, 16, kernel_size=7, stride=1), # 8*29*29 -> 16*23*23nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2), # 16*23*23 -> 16*11*11Flatten(),nn.ReLU(inplace=True),nn.Linear(1936, 10)     # 1936 = 16*11*11
)
fixed_model = fixed_model_base.type(dtype)  #將模型數據轉換成pytorch所支持的cpu數據類型中的浮點數類型。
# #測試:
# x = torch.randn(32, 3, 64, 64).type(dtype)
# x_var = Variable(x.type(dtype)) # 需要將其封裝為Variable類型。
# ans = fixed_model(x_var)
# print(np.array(ans.size())) # 檢查模型輸出。
# np.array_equal(np.array(ans.size()), np.array([32, 10]))"""
訓練步驟及模塊
"""
optimizer = torch.optim.RMSprop(fixed_model_base.parameters(), lr = 0.0001)
loss_fn = nn.CrossEntropyLoss()def train(model,loss_fn,optimizer,dataloader,num_epoch = 1):for epoch in range(num_epoch):check_accuracy(fixed_model,image_dataloader_val)    #在驗證集驗證模型效果model.train()   #模型的.train()方法讓模型進入訓練模式,參數保留梯度,dropout層等部分正常工作for t,sample in enumerate(dataloader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'].long())scores = model(x_var)   #得到輸出loss = loss_fn(scores,y_var)if (t+1)%print_every ==0:print('t = %d, loss = %.4f' % (t + 1, loss.item()))#三步更新optimizer.zero_grad()loss.backward()optimizer.step()def check_accuracy(model,loader):num_correct = 0num_samples = 0model.eval()    # 模型的.eval()方法切換進入評測模式,對應的dropout等部分將停止工作。for t,sample in enumerate(loader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'])scores = model(x_var)_,preds = scores.data.max(1)    # 找到可能最高的標簽作為輸出。num_correct += (preds.numpy() == y_var.numpy()).sum()num_samples += preds.size(0)acc = float(num_correct)/num_samplesprint('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))"""
訓練并驗證
"""
torch.random.manual_seed(54321)
fixed_model.cpu()
fixed_model.apply(reset)
#pytorch中的model.apply(fn)會遞歸地將函數fn應用到父模塊的每個子模塊submodule,也包括model這個父模塊自身
fixed_model.train()
train(fixed_model, loss_fn, optimizer,image_dataloader_train, num_epoch=5)
check_accuracy(fixed_model, image_dataloader_val)"""
測試
"""def predict_on_test(model, loader):model.eval()results = open('results.csv', 'w')  # 模型預測結果會被放在這里。count = 0results.write('Id' + ',' + 'Class' + '\n')for t, sample in enumerate(loader):x_var = Variable(sample['image'])scores = model(x_var)_, preds = scores.data.max(1)for i in range(len(preds)):results.write(str(count) + ',' + str(preds[i]) + '\n')count += 1results.close()return countcount = predict_on_test(fixed_model, image_dataloader_test)  # 放入你想要測試的訓練集,然后打開文件去看一看結果吧。
print(count)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389283.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389283.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389283.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

missforest_missforest最佳丟失數據插補算法

missforestMissing data often plagues real-world datasets, and hence there is tremendous value in imputing, or filling in, the missing values. Unfortunately, standard ‘lazy’ imputation methods like simply using the column median or average don’t work wel…

華碩猛禽1080ti_F-22猛禽動力回路的視頻分析

華碩猛禽1080tiThe F-22 Raptor has vectored thrust. This means that the engines don’t just push towards the front of the aircraft. Instead, the thrust can be directed upward or downward (from the rear of the jet). With this vectored thrust, the Raptor can …

聊天常用js代碼

<script languagejavascript>//轉意義字符與替換圖象以及字體HtmlEncode(text)function HtmlEncode(text){return text.replace(//"/g, &quot;).replace(/</g, <).replace(/>/g, >).replace(/#br#/g,<br>).replace(/IMGSTART/g,<IMG style…

溫故而知新:柯里化 與 bind() 的認知

什么是柯里化?科里化是把一個多參數函數轉化為一個嵌套的一元函數的過程。&#xff08;簡單的說就是將函數的參數&#xff0c;變為多次入參&#xff09; const curry (fn, ...args) > fn.length < args.length ? fn(...args) : curry.bind(null, fn, ...args); // 想要…

OPENVAS運行

https://www.jianshu.com/p/382546aaaab5轉載于:https://www.cnblogs.com/diyunpeng/p/9258163.html

Memory-Associated Differential Learning論文及代碼解讀

Memory-Associated Differential Learning論文及代碼解讀 論文來源&#xff1a; 論文PDF&#xff1a; Memory-Associated Differential Learning論文 論文代碼&#xff1a; Memory-Associated Differential Learning代碼 論文解讀&#xff1a; 1.Abstract Conventional…

大數據技術 學習之旅_如何開始您的數據科學之旅?

大數據技術 學習之旅Machine Learning seems to be fascinating to a lot of beginners but they often get lost into the pool of information available across different resources. This is true that we have a lot of different algorithms and steps to learn but star…

純API函數實現串口讀寫。

以最后決定用純API函數實現串口讀寫。 先從網上搜索相關代碼&#xff08;關鍵字&#xff1a;C# API 串口&#xff09;&#xff0c;發現網上相關的資料大約來源于一個版本&#xff0c;那就是所謂的msdn提供的樣例代碼&#xff08;msdn的具體出處&#xff0c;我沒有考證&#xff…

數據可視化工具_數據可視化

數據可視化工具Visualizations are a great way to show the story that data wants to tell. However, not all visualizations are built the same. My rule of thumb is stick to simple, easy to understand, and well labeled graphs. Line graphs, bar charts, and histo…

Android Studio調試時遇見Install Repository and sync project的問題

我們可以看到&#xff0c;報的錯是“Failed to resolve: com.android.support:appcompat-v7:16.”&#xff0c;也就是我們在build.gradle中最后一段中的compile項內容。 AS自動生成的“com.android.support:appcompat-v7:16.”實際上是根據我們的最低版本16來選擇16.x.x及以上編…

Apache Ignite 學習筆記(二): Ignite Java Thin Client

前一篇文章&#xff0c;我們介紹了如何安裝部署Ignite集群&#xff0c;并且嘗試了用REST和SQL客戶端連接集群進行了緩存和數據庫的操作。現在我們就來寫點代碼&#xff0c;用Ignite的Java thin client來連接集群。 在開始介紹具體代碼之前&#xff0c;讓我們先簡單的了解一下Ig…

VGAE(Variational graph auto-encoders)論文及代碼解讀

一&#xff0c;論文來源 論文pdf Variational graph auto-encoders 論文代碼 github代碼 二&#xff0c;論文解讀 理論部分參考&#xff1a; Variational Graph Auto-Encoders&#xff08;VGAE&#xff09;理論參考和源碼解析 VGAE&#xff08;Variational graph auto-en…

IIS7設置

IIS 7.0和IIS 6.0相比改變很大誰都知道&#xff0c;而且在IIS 7.0中用VS2005來調試Web項目也不是什么新鮮的話題&#xff0c;但是我還是第一次運用這個東東&#xff0c;所以在此記下我的一些過程&#xff0c;希望能給更多的后來者帶了一點參考。其實我寫這篇文章時也參考了其他…

tableau大屏bi_Excel,Tableau,Power BI ...您應該使用什么?

tableau大屏biAfter publishing my previous article on data visualization with Power BI, I received quite a few questions about the abilities of Power BI as opposed to those of Tableau or Excel. Data, when used correctly, can turn into digital gold. So what …

python 可視化工具_最佳的python可視化工具

python 可視化工具Disclaimer: I work for Datapane免責聲明&#xff1a;我為Datapane工作 動機 (Motivation) There are amazing articles on data visualization on Medium every day. Although this comes at the cost of information overload, it shouldn’t prevent you …

網絡編程 socket介紹

Socket介紹 Socket是應用層與TCP/IP協議族通信的中間軟件抽象層&#xff0c;它是一組接口。在設計模式中&#xff0c;Socket其實就是一個門面模式&#xff0c;它把復雜的TCP/IP協議族隱藏在Socket接口后面&#xff0c;對用戶來說&#xff0c;一組簡單的接口就是全部。 Socket通…

猿課python 第三天

字典 字典是python中唯一的映射類型,字典對象是可變的&#xff0c;但是字典的鍵是不可變對象&#xff0c;字典中可以使用不同的鍵值字典功能> dict.clear()          -->清空字典 dict.keys()          -->獲取所有key dict.values()      …

在C#中使用代理的方式觸發事件

事件&#xff08;event&#xff09;是一個非常重要的概念&#xff0c;我們的程序時刻都在觸發和接收著各種事件&#xff1a;鼠標點擊事件&#xff0c;鍵盤事件&#xff0c;以及處理操作系統的各種事件。所謂事件就是由某個對象發出的消息。比如用戶按下了某個按鈕&#xff0c;某…

BP神經網絡反向傳播手動推導

BP神經網絡過程&#xff1a; 基本思想 BP算法是一個迭代算法&#xff0c;它的基本思想如下&#xff1a; 將訓練集數據輸入到神經網絡的輸入層&#xff0c;經過隱藏層&#xff0c;最后達到輸出層并輸出結果&#xff0c;這就是前向傳播過程。由于神經網絡的輸出結果與實際結果…

使用python和pandas進行同類群組分析

背景故事 (Backstory) I stumbled upon an interesting task while doing a data exercise for a company. It was about cohort analysis based on user activity data, I got really interested so thought of writing this post.在為公司進行數據練習時&#xff0c;我偶然發…