LeNet

概念

代碼

model

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()  # super()繼承父類的構造函數self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

forward:定義正向傳播的過程。

ReLU:激活哈數

觀察網絡中的參數傳遞:發現傳遞的都是channel通道數,最后output在softmax函數里展開的也是展開的通道數。

train

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000張訓練圖片# 第一次使用時要將download設置為True才會自動去下載數據集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000張驗證圖片# 第一次使用時要將download設置為True才會自動去下載數據集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg').convert('RGB')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()# predict = torch.softmax(outputs,dim=1)# print(predict)# tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,# 3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])print(classes[int(predict)])if __name__ == '__main__':main()

知識點:

增加新的維度:?

im = torch.unsqueeze(im, dim=0) ?# [N, C, H, W]?

predict = torch.max(outputs, dim=1)[1].numpy():

這一行代碼使用torch.max()函數找到outputs張量在第一個維度上的最大值,并返回最大值和對應的索引。dim=1表示在第一個維度上進行最大值的計算,即對每個樣本的輸出進行比較。[1]表示返回最大值對應的索引。最后,.numpy()將結果轉換為NumPy數組。?

更換:

predict = torch.softmax(outputs,dim=1)

print:tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,
? ? ? ? ?3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])

Pytorch使用

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/211018.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/211018.shtml
英文地址,請注明出處:http://en.pswp.cn/news/211018.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Bash腳本處理ogg、flac格式到mp3格式的批量轉換

現在下載的許多音樂文件是flac和ogg格式的,QQ音樂上下載的就是這樣的,這些文件尺寸比較大,在某些場合使用不便,比如在車機上播放還是mp3格式合適,音質這些在車機上播放足夠了,要求不高。比如本人就喜歡下載…

軟件接口安全設計規范

《軟件項目接口安全設計規范》 1.token授權機制 2.https傳輸加密 3.接口調用防濫用 4.日志審計里監控 5.開發測試環境隔離,脫敏處理 6.數據庫運維監控審計

卷王開啟驗證碼后無法登陸問題解決

問題描述 使用 docker 部署,后臺設置開啟驗證,重啟服務器之后,docker重啟,再次訪問系統,驗證碼獲取失敗,導致無法進行驗證,也就無法登陸系統。 如果不了解卷王的,可以去官網看下。…

飛天使-linux操作的一些技巧與知識點3

http工作原理 http1.0 協議 使用的是短連接,建立一次tcp連接,發起一次http的請求,結束,tcp斷開 http1.1 協議使用的是長連接,建立一次tcp的連接,發起多次http的請求,結束,tcp斷開ngi…

ky10 server x86 設置網卡開機自啟

輸入命令查看網卡名稱 ip a 輸入命令編輯網卡信息 vi /etc/sysconfig/network-scripts/*33改成yes 按ESC鍵,輸入:wq,保存

Aloha 機械臂的學習記錄2——AWE:AWE + ACT

繼續下一個階段: Train policy python act/imitate_episodes.py \ --task_name [TASK] \ --ckpt_dir data/outputs/act_ckpt/[TASK]_waypoint \ --policy_class ACT --kl_weight 10 --chunk_size 50 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \ --n…

F : A DS二分查找_尋找比目標字母大的最小字母

Description 給你一個字符串str,字符串中的字母都已按照升序排序,且只包含小寫字母。另外給出一個目標字母target,請你尋找在這一有序字符串里比目標字母大的最小字母。 在比較時,字母是依序循環出現的。例如,str“ab…

Python中鎖的常見用法

在 Python 中,可以使用線程鎖來控制多個線程對共享資源的訪問。以下是一些常見的 Python 中鎖的用法: 創建線程鎖 在 Python 中,可以使用 threading 模塊中的 Lock 類來創建線程鎖。例如: import threading# 創建線程鎖 lock …

Python網絡爬蟲環境的安裝指南

網絡爬蟲是一種自動化的網頁數據抓取技術,廣泛用于數據挖掘、信息搜集和互聯網研究等領域。Python作為一種強大的編程語言,擁有豐富的庫支持網絡爬蟲的開發。本文將為你詳細介紹如何在你的計算機上安裝Python網絡爬蟲環境。 一、安裝python開發環境 進…

什么是電壓紋波,造成不良,如何測量、如何抑制設計

1 引言 電源給電子產品提供能量同時也附帶了一些不好的影響成分,如紋波、噪聲等,這些對本振、、濾波、放大器、混頻器、檢波、A/D 轉換等電路都會產生影響,會直接影響電子產品正常工作,所以項目設計要合理、要有實測數據、要盡量減小系統電壓的紋波。 1.1 電壓紋波(volta…

bc-linux-歐拉重制root密碼

最近需要重新安裝虛擬機的系統 安裝之后發現對方提供的root密碼不對,無法進入系統。 上網搜了下發現可以進入單用戶模式進行密碼修改從而重置root用戶密碼。 在這個界面下按e鍵 找到圖中部分,把標紅的部分刪除掉,然后寫上rw init/bin/…

strftime(“%-m/%-d/%Y“) 報錯 ValueError: Invalid format string

問題 運行測試用例時,出現ValueError: Invalid format string的錯誤,代碼大致如下: from datetime import date .... current date.today() return current.strftime("%-m/%-d/%Y")原因 開發此代碼的時候是在mac上開發的&#…

24、文件上傳漏洞——Apache文件解析漏洞

文章目錄 一、環境簡介一、Apache與php三種結合方法二、Apache解析文件的方法三、Apache解析php的方法四、漏洞原理五、修復方法 一、環境簡介 Apache文件解析漏洞與用戶配置有密切關系。嚴格來說,屬于用戶配置問題,這里使用ubantu的docker來復現漏洞&am…

IOday7作業

1> 使用無名管道完成父子進程間的通信 #include<myhead.h>int main(int argc, const char *argv[]) {//創建存放兩個文件描述符的數組int fd[2];int pid -1;//打開無名管道if(pipe(fd) -1){perror("pipe");return -1;}//創建子進程pid fork();if(pid &g…

wordpress小記

1.插件市場搜索redis&#xff0c;并按照 Redis Object cache插件 2.開啟php的redis擴展 執行php -m|grep redis&#xff0c;沒有顯示就執行 yum -y install php-redis3.再次修改wp配置文件&#xff0c;增加redis的配置 define( WP_REDIS_HOST, 114.80.36.124 );define( WP_…

非標設計之電磁閥

電磁閥&#xff1a; 分類&#xff1a; 動畫演示兩位三通電磁閥&#xff1a; 兩位三通電磁閥動畫演示&#xff1a; 111&#xff1a; 氣缸回路的介紹&#xff1a; 失電狀態&#xff1a; 電磁閥得電狀態&#xff1a; 兩位五通電磁閥的回路&#xff1a;&#xff08;常用&#xf…

算數運算符和算數表達式

基本算數運算符 算數運算符&#xff1a; &#xff08;加法運算符或正值運算符&#xff09;、-&#xff08;減法運算符或負值運算符&#xff09;、*&#xff08;乘&#xff09;、/&#xff08;除&#xff09;、%&#xff08;求余數&#xff09; 雙目運算符&#xff1a; 雙目…

四則運算 .

輸入一個表達式&#xff08;用字符串表示&#xff09;&#xff0c;求這個表達式的值。 保證字符串中的有效字符包括[‘0’-‘9’],‘’,‘-’, ‘*’,‘/’ ,‘(’&#xff0c; ‘)’,‘[’, ‘]’,‘{’ ,‘}’。且表達式一定合法。字符串長度滿足1≤n≤1000 輸入描述&#x…

CGAL的2D符合規定的三角剖分和網格

1、符合規定的三角剖分 1.1、定義 如果三角形的任何面的外接圓在其內部不包含頂點&#xff0c;則該三角形是 Delaunay 三角形。 約束 Delaunay 三角形是一種盡可能接近 Delaunay 的約束三角形。 約束 Delaunay 三角形的任何面的外接圓在其內部不包含從該面可見的數據點。 如果…

陀螺儀LSM6DSV16X與AI集成(3)----讀取融合算法輸出的四元數

陀螺儀LSM6DSV16X與AI集成.2--姿態解算 概述視頻教學樣品申請完整代碼下載使用demo板生成STM32CUBEMX串口配置IIC配置CS和SA0設置串口重定向參考程序初始化SFLP步驟初始化SFLP讀取四元數數據演示 概述 LSM6DSV16X 特性涉及到的是一種低功耗的傳感器融合算法&#xff08;Sensor…