Pytorch-07 完整訓練測試過程

要在PyTorch中使用GPU進行數據集的加載、模型的訓練和最后模型的測試,需要將數據集和模型都移動到GPU上,并確保在訓練和測試過程中都在GPU上進行計算。以下是一個完整的示例代碼,展示了如何在PyTorch中使用GPU進行端到端的訓練和測試:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 準備訓練和測試數據,并將其移動到GPU
train_input = torch.randn(100, 10).to(device)
train_target = torch.randn(100, 1).to(device)
test_input = torch.randn(20, 10).to(device)
test_target = torch.randn(20, 1).to(device)# 創建數據集和數據加載器
train_dataset = TensorDataset(train_input, train_target)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)# 定義一個簡單的神經網絡模型,并將其移動到GPU
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(10, 5)self.relu = nn.ReLU()self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xmodel = SimpleModel().to(device)# 定義損失函數和優化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 訓練模型
model.train()
for epoch in range(100):for input_data, target_data in train_loader:optimizer.zero_grad()output = model(input_data)loss = criterion(output, target_data)loss.backward()optimizer.step()# 測試模型
model.eval()
with torch.no_grad():test_output = model(test_input)test_loss = criterion(test_output, test_target)print(f'Test Loss: {test_loss.item()}')

在這個示例中,我們首先檢查GPU是否可用,并將訓練和測試數據移動到GPU上。然后,我們創建了數據集和數據加載器,定義了神經網絡模型,并將模型移動到GPU。在訓練過程中,我們使用數據加載器加載數據進行訓練;在測試過程中,我們使用model.eval()將模型切換為評估模式,并使用torch.no_grad()上下文管理器關閉梯度計算,以避免在測試過程中更新模型參數。最后,我們計算了模型在測試集上的損失。整個訓練和測試過程都在GPU上進行,以加速計算和提高效率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/17935.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/17935.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/17935.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

六月后考研如何備考看這一篇就夠了

以下是考研六月后可以參考的規劃: 6 月至 8 月(強化階段): 英語:繼續背單詞,開始刷歷年真題中的閱讀部分,仔細分析錯題原因,總結解題技巧。數學:完成基礎階段的復習后&am…

接口作為返回類型與類作為返回類型一樣嗎?

LinkedList<Integer> q new LinkedList<>();和Queue<Integer> q new LinkedList<>();一樣嗎&#xff1f; 我現在想創建一個隊列對象&#xff0c;正常情況下我會這樣寫&#xff1a;Queue<Integer> q new Queue<>(); 但是你仔細想想&am…

使用chatglm.cpp本地部署ChatGLM3-6B模型

ChatGLM3模型介紹 ChatGLM3-6B 是 ChatGLM 系列最新一代的開源模型&#xff0c;在保留了前兩代模型對話流暢、部署門檻低等眾多優秀特性的基礎上&#xff0c;ChatGLM3-6B 引入了如下特性&#xff1a; 更強大的基礎模型&#xff1a; ChatGLM3-6B 的基礎模型 ChatGLM3-6B-Base …

Yourpassword does not satisfy the current policyrequirements

mysql 新增數據庫用戶失敗 解決方法&#xff1a; 修改校驗密碼策略等級 set global validate_password.policyLOW;

dataguard 備庫關閉后啟動流程

startup mount&#xff1b; ---開啟adg alter database recover managed standby database using current logfile disconnect from session; -- alter database recover managed standby database cancel; alter database recover managed standby database disconnect…

C++課程設計實驗杭州電子科技大學ACM題目(上)

題目一&#xff1a;2013.蟠桃季 題目描述 Problem Description&#xff1a;喜歡西游記的同學肯定都知道悟空偷吃蟠桃的故事&#xff0c;你們一定都覺得這猴子太鬧騰了&#xff0c;其實你們是有所不知&#xff1a;悟空是在研究一個數學問題&#xff01;什么問題&#xff1f;他…

【面試】PWM(脈沖寬度調制)相關問題 ——長期更新

1、PWM調節原理 答&#xff1a;通過改變信號的高電平和低電平的持續時間比例來控制輸出信號的平均功率或電壓。 2、PWM占空比定義 答&#xff1a;在一個脈沖周期內&#xff0c;高電平的時間占整個周期時間的比例。 3、PWM波形的周期和調節精度由誰決定 答&#xff1a;當計數…

全同態加密生態項目盤點:FHE技術的崛起以及應用

撰文&#xff1a;Chris&#xff0c;Techub News 在當今數字化的時代&#xff0c;隱私保護已成為一個全球性的焦點話題&#xff0c;特別是在加密貨幣和區塊鏈技術快速發展的背景下。雖然當前的隱私技術在保護數據安全方面多有欠缺&#xff0c;引發了廣泛的關注和批評&#xff0c…

BUUCTF-WEB3

[極客大挑戰 2019]Knife1 1.打開附件鏈接 一句話木馬eval($_POST["Syc"]); 2.中國蟻劍 用中國蟻劍連接 在根目錄下找到一個名為flag的文件 3.得到flag [極客大挑戰 2019]Upload1 1.打開附件鏈接 是一個文件上傳 2.一句話木馬 經過多次嘗試都被繞過&#xff0c;更…

【MySQL】數據庫的開始

前言 數據庫是我們學習編程中一個非常重要的內容&#xff0c;像一些什么什么管理系統&#xff0c;如果想要存儲數據都是需要連接數據庫的。博主之前寫過一篇圖書管理系統的博客&#xff0c;那時的我還沒接觸過數據庫&#xff0c;所有的數據都是現成創建的&#xff0c;感興趣的…

JavaScript面試 題

1.延時加載JS有哪些方式 延時加載 :async defer 例如:<script defer type"type/javascript" srcscript.js></ script> defer:等html全部解析完成,才會執行js代碼,順次執行的 async: js和html解析是同步的,不是順次執行js腳本(誰先加載完先執行誰)2.JS數…

【C++】菱形繼承、菱形虛擬繼承、繼承與組合

目錄 01.概念 02.虛擬繼承 原理 03.繼承和組合 01.概念 單繼承&#xff1a; 一個子類只有一個父類時&#xff0c;稱這種繼承關系為單繼承。 多繼承&#xff1a; 一個子類同時有兩個及以上的父類時&#xff0c;稱這種繼承關系為多繼承。 菱形繼承&#xff1a; 菱形繼承是…

一文搞懂oracle事務提交以及臟數據落盤的原則

本文基于oracle 19c 做事務提交以及oracle臟數據落盤的相關解讀 第一章 相關進程及組件介紹&#xff1a; 1.LGWR&#xff1a; 重做日志條目在系統全局區域 &#xff08;SGA&#xff09; 的重做日志緩沖區中生成。LGWR 按順序將重做日志條目寫入重做日志文件。如果數據庫具有…

【MySQL精通之路】MySQL的使用(3)-命令行連接

本節介紹使用命令行選項來指定如何為MySQL或mysqldump等客戶端建立到MySQL服務器的連接。 有關使用類似URI的連接字符串或鍵值對建立連接的信息&#xff0c;對于MySQL Shell等客戶端&#xff0c;請參閱“使用類似URI字符串或鍵值配對連接到服務器”。 有關無法連接的其他信息&a…

期望薪資26K,北京瘋狂游戲golang一面

北京瘋狂游戲一面 1、自我介紹 2、財務業務中&#xff0c;你做了哪些設計來保證金額數據的準確性&#xff1f;&#xff08;例如&#xff0c;業務涉及多步驟&#xff0c;某一步出了問題怎么解決&#xff09; 3、如何解決單個業務直接報錯的數據準確性問題 4、分布式場景下&a…

理解Vue 3響應式系統原理

title: 理解Vue 3響應式系統原理 date: 2024/5/28 15:44:47 updated: 2024/5/28 15:44:47 categories: 前端開發 tags: Vue3.xTypeScriptSFC優化Composition-APIRef&Reactive性能提升響應式原理 第一章&#xff1a;Vue 3簡介 1.1 Vue 3概述 Vue 3的誕生背景&#xff1…

怎么把電腦上的文件傳到手機上?可保存文檔的云筆記

在職場中&#xff0c;我們經常需要將電腦上的重要文件、資料傳到手機上&#xff0c;以便隨時查閱和使用。比如&#xff0c;當你在公司完成了一份關鍵報告&#xff0c;但即將外出與客戶溝通&#xff0c;這時如果能將報告傳到手機上&#xff0c;就能在移動中隨時準備應對客戶的咨…

uniapp Androud 離線打包升級APK,覆蓋安裝不更新問題

Android 打包時在assets/data/dcloud_control.xml文件中&#xff0c;如果配置debug"true" syncDebug"true"&#xff0c;則consle打印有效&#xff0c;不然沒有打印數據 <hbuilder debug"true" syncDebug"true"> <apps> …

破解App渠道歸因難題,Xinstall助你實現精準數據追蹤!

在移動互聯網時代&#xff0c;App的推廣和運營面臨著諸多挑戰。其中&#xff0c;渠道歸因問題一直困擾著眾多推廣者。如何準確追蹤用戶來源&#xff0c;分析不同渠道的推廣效果&#xff0c;成為了擺在推廣者面前的一大難題。然而&#xff0c;有了Xinstall的出現&#xff0c;這一…

C++網絡編程——實現一個簡單的echo服務器

在前面講完了服務器從建立套接字、綁定、監聽和提取&#xff0c;以及客戶端的連接&#xff0c;我們已經可以動手實現一個簡單的鏡像服務器。 錯誤處理 在那之前&#xff0c;我們先封裝一個錯誤處理函數 errif 可以定義一個uitl.cpp放里面&#xff0c;需要的地方引用即可 ut…