定義損失函數并以此訓練和評估模型

基礎神經網絡模型搭建

【Pytorch】數據集的加載和處理(一)

【Pytorch】數據集的加載和處理(二)

損失函數計算模型輸出和目標之間的距離。通過torch.nn 包可以定義一個負對數似然損失函數,負對數似然損失對于訓練具有多個類的分類問題比較有效,負對數似然損失函數的輸入為對數概率,而在模型搭建的輸出層部分接觸過log_softmax,它能從模型中獲取對數概率

目錄

基礎模型搭建

數據集的加載和處理

定義損失函數

定義優化器

訓練并評估模型


基礎模型搭建

import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass
def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)
Net.__init__ = __init__
Net.forward = forward
model = Net()

檢查搭建情況?

print(model)

原位置為cpu?

?轉移至所需CUDA設備

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

數據集的加載和處理

導入MNIST訓練數據集和驗證數據集并處理

from torch import nn
from torchvision import datasets
from torch.utils.data import TensorDataset
path2data="./data"
train_data=datasets.MNIST(path2data, train=True, download=True)
x_train, y_train=train_data.data,train_data.targets
val_data=datasets.MNIST(path2data, train=False, download=True)
x_val,y_val=val_data.data, val_data.targets
if len(x_train.shape)==3:x_train=x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape)==3:x_val=x_val.unsqueeze(1)
print(x_val.shape)
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
for x,y in train_ds:print(x.shape,y.item())breakfrom torch.utils.data import DataLoader 
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)

定義損失函數

損失函數計算模型輸出和目標之間的距離。Pytorch 中的 optim 包提供了各種優化算法的實現,例如SGD、Adam、RMSprop 等。

通過torch.nn 包可以定義一個負對數似然損失函數,負對數似然損失對于訓練具有多個類的分類問題比較有效,負對數似然損失函數的輸入為對數概率,而在模型搭建的輸出層部分接觸過log_softmax,它能從模型中獲取對數概率。

loss_func = nn.NLLLoss(reduction="sum")
for xb, yb in train_dl:# move batch to cuda devicexb=xb.type(torch.float).to(device)yb=yb.to(device)out=model(xb)loss = loss_func(out, yb)print (loss.item())break

得到一個測試值?

定義優化器

定義一個Adam優化器,優化器的輸入是模型參數和學習率

from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-4)

通過opt .step()自動更新模型參數,同時需要注意計算下一批的梯度之前需將梯度歸0

opt.step()
opt.zero_grad()

訓練并評估模型

定義一個輔助函數?loss_batch來計算每個小批量的損失值。函數的 opt 參數引用優化器,如果給定,則計算梯度并按小批量更新模型參數。

def  loss_batch(loss_func,  xb,  yb,yb_h,  opt=None): loss = loss_func(yb_h, yb) metric_b =  metrics_batch(yb,yb_h) if opt is  not None: loss.backward()opt.step()opt.zero_grad()return loss.item(),metric_b

?定義一個輔助函數metrics_batch來計算每個小批量的性能指標,這里以準確率作為分類任務的性能指標,并使用 output.argmax 來獲取概率最高的預測類

def metrics_batch(target, output):pred = output.argmax(dim=1, keepdim=True)corrects=pred.eq(target.view_as(pred)).sum().item()return corrects

定義一個輔助函數loss_epoch來計算整個數據集的損失和指標值。使用數據加載器對象獲取小批量,將它們提供給模型,并計算每個小批量的損失和指標,通過兩個運行變量來分別添加損失值和指標值。

def loss_epoch(model,loss_func,dataset_dl,opt=None):loss=0.0metric=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.type(torch.float).to(device)yb=yb.to(device)yb_h=model(xb)loss_b,metric_b=loss_batch(loss_func, xb, yb,yb_h, opt)loss+=loss_bif metric_b is not None:metric+=metric_bloss/=len_datametric/=len_datareturn loss, metric

最后,定義一個輔助函數train_val來訓練多個時期的模型。在每個時期使用驗證數據集評估模型的性能。訓練和評估需要分別使用 model.train()和 model.eval()模式。torch.no_grad()可以阻止 autograd 在評估期間計算梯度。

def train_val(epochs, model, loss_func, opt, train_dl, val_dl):for epoch in range(epochs):model.train()train_loss,train_metric=loss_epoch(model,loss_func,train_dl,opt)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl)accuracy=100*val_metricprint("epoch: %d, train loss: %.6f, val loss: %.6f,accuracy: %.2f" %(epoch, train_loss,val_loss,accuracy))

?設定時期數為5,調用函數進行訓練和評估

num_epochs=5
train_val(num_epochs, model, loss_func, opt, train_dl, val_dl)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92987.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92987.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92987.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

電子書轉PDF格式教程,實現epub轉PDF步驟

EPUB 格式屬于流式文檔,在屏幕尺寸各異的設備上都能自動適配顯示。然而,要是你使用的是特定的閱讀設備,像打印機、不支持 EPUB 格式的電子閱讀器(例如某些早期的 Kindle 型號),或者需要在固定尺寸的屏幕上展…

Java學習第六十九部分——RabbitMQ

目錄 一、前言提要 二、基本信息 1. 關鍵定義 2. 核心角色 3. 交換機類型 三、消息生命周期與可靠性機制 四、生態集成——與Java 五、應用場景 六、性能與選型對比 七、生產級最佳實踐——基于Java 八、應用場景 九、一句話總結 一、前言提要 Spring AMQP是…

MDAC2.6問題解決指南:解決.NET Framework數據訪問煩惱

MDAC2.6問題解決指南:解決.NET Framework數據訪問煩惱 【下載地址】MDAC2.6問題解決指南 MDAC 2.6 問題解決指南為您提供了針對.NET Framework數據提供程序要求使用Microsoft Data Access Components (MDAC) 2.6或更高版本的全面解決方案。本指南詳細介紹了如何在開…

會話跟蹤模式

一、圖片講了什么?這張圖片主要講的是“會話跟蹤技術”,也就是網站怎么記住你是誰、你做了什么。1. 什么是會話?會話(Session)就像你和網站的一次聊天,從你打開網頁到關閉網頁,這段時間就是一次…

C語言開發工具Win-TC

如你所知,WIN-TC是一個turbo C2 WINDOWS 平臺開發工具,最大特點是支持中文界面,支持鼠標操作,程序段復制,為初學 c 語言、對高等編程環境不熟悉的同志們非常有幫助。該軟件使用 turbo C2 為內核,提供 WINDO…

lwIP學習記錄5——裸機lwIP工程學習后的總結

1、ping包的TTL生存時間如何修改當我們把工程燒錄到板子上是,我們對板子的IP進行ping包,看到信息如下圖這時候我好奇TTL是什么作用,為什么有的設備是64有的設備是128有的是255?解:TTL(Time to Live&#xf…

利用Trae將原型圖轉換為可執行的html文件,感受AI編程的魅力

1、UI設計原型效果2、通過Tare對話生成的效果圖(5分鐘左右)3、查資料做的效果圖(30分鐘左右))通過以上對比,顯然差別不多能滿足要求,只需要在繼續優化就能搞定; 4、Trae生成的源碼&l…

Chessboard and Queens

題目描述Your task is to place eight queens on a chessboard so that no two queens are attacking each other. As an additional challenge, each square is either free or reserved, and you can only place queens on the free squares. However, the reserved squares …

菜鳥教程R語言一二章閱讀筆記

菜鳥教程R語言一二章閱讀筆記 一.R語言基礎教程 R 語言是為數學研究工作者設計的一種數學編程語言,主要用于統計分析、繪圖、數據挖掘。側重于數學工作者 R語言特點如下: R 語言環境軟件屬于 GNU 開源軟件,兼容性好、使用免費 語法十分有利于…

Tactile-VLA:解鎖視覺-語言-動作模型的物理知識,實現觸覺泛化

25年7月來自清華、中科大和上海交大的論文“Tactile-VLA: Unlocking Vision-Language- Action Model’s Physical Knowledge For Tactile Generalization ”。 視覺-語言-動作 (VLA) 模型已展現出卓越的成就,這得益于其視覺-語言組件豐富的隱性知識。然而&#xff0…

HTML初學者第五天

<1>表格標簽1.1基本語法<table><tr><td>單元格內的文字</td>...</tr>... </table>1.<table></table>是用于定義表格的標簽。2.<tr></tr>標簽用于定義表格中的行&#xff0c;必須嵌套在<table></ta…

FastAPI入門:demo、路徑參數、查詢參數

demo from fastapi import FastAPIapp FastAPI()app.get("/") async def root():return {"message": "Hello World"}在終端運行 fastapi dev main.py結果如下&#xff1a;打開http://127.0.0.1:8000&#xff1a;交互式API文檔&#xff1a;位于h…

pytest中的rerunfailures的插件(失敗重試)

目錄 1-- 安裝rerunfailures插件 2-- rerunfailures的使用 3-- 重試案例 安裝rerunfailures插件 pip install pytest-rerunfailures點擊左下角的控制臺面板 輸入 pip install pytest-rerunfailures 出現上圖的情況就算安裝完成了 rerunfailures的使用 可以添加一下參數使用&…

SpringMVC——建立連接

建立連接 將用戶&#xff08;瀏覽器&#xff09;和java程序連接起來&#xff0c;也就是訪問一個地址能夠調用到我們的Spring程序。在 Spring MVC 中使用 RequestMapping來實現URL 路由映射&#xff0c;也就是瀏覽器連接程序的作用。 1.RequestMapping注解介紹 RequestMapping…

蘑菇云路由器使用教程

1: 手機連接路由器的Wi-Fi&#xff0c;在瀏覽器輸入背面IP地址&#xff1a;192.168.132.1進入路由管理界面1.1: 電腦連接路由器網線在瀏覽器輸入背面IP地址&#xff1a;192.168.132.1進入路由管理界面賬號&#xff1a;admin密碼&#xff1a;123456782:選擇上網模式2.1&#xff…

ubuntu的tar解壓指令相關

1. 指令說明參數作用-xextract&#xff0c;解包-z通過 gzip 解壓&#xff08;.tar.gz、.tgz&#xff09;-vverbose&#xff0c;顯示過程-ffile&#xff0c;后面緊跟壓縮包文件名2. 什么時候用z參數場景是否加 -z結果.tar.gz / .tgz? 必須加 -z正常解壓.tar.gz / .tgz? 沒加 -…

車載診斷刷寫 --- Flash關于擦除和寫入大小

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 簡單,單純,喜歡獨處,獨來獨往,不易合同頻過著接地氣的生活,除了生存溫飽問題之外,沒有什么過多的欲望,表面看起來很高冷,內心熱情,如果你身…

【Verilog HDL 入門教程】 —— 學長帶你學Verilog(基礎篇)

文章目錄一、Verilog HDL 概述1、Verilog HDL 是什么2、Verilog HDL產生的背景3、Verilog HDL 和 VHDL的區別二、Verilog HDL 基礎知識1、Verilog HDL 語言要素1.1、命名規則1.2、注釋符1.3、關鍵字1.4、數值1.4.1、整數及其表示1.4.2、實數及其表示1.4.3、字符串及其表示2、數…

SQL Developer Data Modeler:一款免費跨平臺的數據庫建模工具

SQL Developer Data Modeler 是由 Oracle 公司開發的一款免費的圖形化數據建模和數據庫設計工具&#xff0c;用于創建、瀏覽和編輯邏輯模型、關系模型、物理模型、多維模型和數據類型模型。 SQL Developer Data Modeler 既是一個獨立的應用程序&#xff0c;同時也被集成到了 Or…

CSS面試題及詳細答案140道之(21-40)

《前后端面試題》專欄集合了前后端各個知識模塊的面試題&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…