pytorch -- GPU優化寫法套路

1. GPU優化的點

網絡模型
數據(輸入、標注)
損失函數

  1. .cuda方式
    代碼:
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 1. 準備數據集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 數據集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('訓練數據集的長度為{}'.format(train_data_size))
print('測試數據集的長度為{}'.format(test_data_size))# 2 利用DataLoader加載數據集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神經網絡
# 3 搭建神經網絡
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x# 4 創建網絡模型
tudui = Tudui()
# --------------------------
if torch.cuda.is_available():tudui = tudui.cuda()# 5 損失函數
loss_fn = nn.CrossEntropyLoss()
# ---------------------------
if torch.cuda.is_available():loss_fn = loss_fn.cuda()# 6 優化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 設置訓練網絡的一些參數
total_train_step = 0 # 記錄訓練次數
total_test_step = 0 # 記錄測試次數
epoch = 10 #訓練輪數
# 添加tensorboard
writer = SummaryWriter('logs_model')
for i in range(epoch):print('-----------第{}輪訓練開始-----------'.format(i+1))# 訓練開始# 訓練步驟開始 dropout batchNorm僅對某些層次有作用tudui.train()for data in train_dataloader:imgs, targets = data# ---------------------------if torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()output = tudui(imgs) #訓練模型的預測輸出loss = loss_fn(output,targets)# 優化器優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('訓練次數是{}時,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor變成了數字writer.add_scalar('train_loss',loss.item(),total_train_step)# 訓練完一輪,看是否訓練好,有沒有達到想要的需求,測試數據集中跑一篇看準確率或者損失# 測試步驟開始tudui.eval()total_test_loss = 0total_accuracy = 0# 測試不需要對梯度進行調整with torch.no_grad():for data in test_dataloader:imgs,targets = data# ---------------------------if torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整體測試集上的loss是{}'.format(total_test_loss))print('整體測試集上的正確率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()

3…to(device)方式

device = torch.device("cpu")
# 第一張顯卡
torch.device("cuda")
torch.device("cuda:0")
# 第二張
torch.device("cuda:1")

代碼:

import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 定義訓練的設備
device = torch.device('cuda')# 1. 準備數據集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 數據集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('訓練數據集的長度為{}'.format(train_data_size))
print('測試數據集的長度為{}'.format(test_data_size))# 2 利用DataLoader加載數據集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神經網絡
# 3 搭建神經網絡
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x# 4 創建網絡模型
tudui = Tudui()
# --------------------------
# if torch.cuda.is_available():
#     tudui = tudui.cuda()
tudui = tudui.to(device)# 5 損失函數
loss_fn = nn.CrossEntropyLoss()
# ---------------------------
# if torch.cuda.is_available():
#     loss_fn = loss_fn.cuda()
loss_fn = loss_fn.to(device)# 6 優化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 設置訓練網絡的一些參數
total_train_step = 0 # 記錄訓練次數
total_test_step = 0 # 記錄測試次數
epoch = 10 #訓練輪數
# 添加tensorboard
writer = SummaryWriter('logs_model')
for i in range(epoch):print('-----------第{}輪訓練開始-----------'.format(i+1))# 訓練開始# 訓練步驟開始 dropout batchNorm僅對某些層次有作用tudui.train()for data in train_dataloader:imgs, targets = data# ---------------------------# if torch.cuda.is_available():#     imgs = imgs.cuda()#     targets = targets.cuda()imgs = imgs.to(device)targets = targets.to(device)output = tudui(imgs) #訓練模型的預測輸出loss = loss_fn(output,targets)# 優化器優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('訓練次數是{}時,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor變成了數字writer.add_scalar('train_loss',loss.item(),total_train_step)# 訓練完一輪,看是否訓練好,有沒有達到想要的需求,測試數據集中跑一篇看準確率或者損失# 測試步驟開始tudui.eval()total_test_loss = 0total_accuracy = 0# 測試不需要對梯度進行調整with torch.no_grad():for data in test_dataloader:imgs,targets = data# ---------------------------# if torch.cuda.is_available():#     imgs = imgs.cuda()#     targets = targets.cuda()imgs = imgs.to(device)targets = targets.to(device)outputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整體測試集上的loss是{}'.format(total_test_loss))print('整體測試集上的正確率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/710979.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/710979.shtml
英文地址,請注明出處:http://en.pswp.cn/news/710979.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C++實現XOR加解器

#include <Windows.h> #include <iostream> #include <fstream> #include <string>// 加解密函數&#xff0c;使用XOR運算 void XORCrypt(char* data, int size, const std::string& key) {int keyLength key.length();for (int i 0; i < siz…

日志系統項目實現

日志系統的功能也就是將一條消息格式化后寫入到指定位置&#xff0c;這個指定位置一般是文件&#xff0c;顯示器&#xff0c;支持拓展到數據庫和服務器&#xff0c;后面我們就知道如何實現拓展的了&#xff0c;支持不同的寫入方式(同步異步)&#xff0c;同步:業務線程自己寫到文…

萬卡集群:字節搭建12288塊GPU的單一集群

文章目錄 論文Reference 論文 MegaScale: Scaling Large Language Model Training to More Than 10,000 GPUs 論文鏈接&#xff1a;https://arxiv.org/abs/2402.15627 從結構上講&#xff0c;網絡是基于Clos的“胖樹”結構。其中一個改進是在頂層交換機上把上行與下行鏈路分開&…

三、《任務列表案例》前端程序搭建和運行

本章概要 整合案例介紹和接口分析 案例功能預覽接口分析 前端工程導入 前端環境搭建導入前端程序 啟動測試 3.1 整合案例介紹和接口分析 3.1.1 案例功能預覽 3.1.2 接口分析 學習計劃分頁查詢 /* 需求說明查詢全部數據頁數據 請求urischedule/{pageSize}/{currentPage} 請…

stm32觸發硬件錯誤位置定位

1.背景 1. 項目中&#xff0c;調試過程或者測試中都會出現程序跑飛問題&#xff0c;這個時候問題特別難查找。 2. 觸發硬件錯誤往往是因為內存錯誤。這種問題特別難查找&#xff0c;尤其是產品到了測試階段&#xff0c;而這個異常復現又比較難的情況下&#xff0c;簡直頭疼。…

初學JavaScript總結

0 JavaScript html完成了架子&#xff0c;css做了美化&#xff0c;但是網頁是死的&#xff0c;需要給他注入靈魂&#xff0c;所以接下來需要學習JavaScript&#xff0c;這門語言會讓頁面能夠和用戶進行交互。JavaScript又稱為腳本語言&#xff0c;可以通過腳本實現用戶和頁面的…

每日shell腳本之打印99乘法表

每日shell腳本之打印99乘法表 #!/bin/bash for i in $(seq 1 9); dofor j in $(seq 1 9); doecho -n "$i * $j $(($i * $j)) "doneecho done

Programming Abstractions in C閱讀筆記:p306-p307

《Programming Abstractions in C》學習第75天&#xff0c;p306-p307總結&#xff0c;總計2頁。 一、技術總結 1.Quicksort algorithm(快速排序) 由法國計算機科學家C.A.R(Charles Antony Richard) Hoare&#xff08;東尼.霍爾&#xff09;在1959年開發(develop), 1961年發表…

Mac 制作可引導安裝器

Mac 使用U盤或移動固態硬盤制作可引導安裝器&#xff08;以 Monterey 為例&#xff09; 本教程參考 Apple 官網相關教程 創建可引導 Mac OS 安裝器 重新安裝 Mac OS 相關名詞解釋 磁盤分區會將其劃分為多個單獨的部分&#xff0c;稱為分區。分區也稱為容器&#xff0c;不同容器…

VR虛擬現實技術應用到豬抗原體檢測的好處

利用VR虛擬仿真技術開展豬瘟檢測實驗教學確保生豬產業健康發展 為了有效提高豬場豬瘟防控意識和檢測技術&#xff0c;避免生豬養殖業遭受豬瘟危害&#xff0c;基于VR虛擬仿真技術開展豬瘟檢測實驗教學數據能大大推動基層畜牧養殖業持續穩步發展保駕護航。 一、提高實驗效率 VR虛…

鯤鵬arm64架構下安裝KubeSphere

鯤鵬arm64架構下安裝KubeSphere 官方參考文檔: https://kubesphere.io/zh/docs/quick-start/minimal-kubesphere-on-k8s/ 在Kubernetes基礎上最小化安裝 KubeSphere 前提條件 官方參考文檔: https://kubesphere.io/zh/docs/installing-on-kubernetes/introduction/prerequi…

基于大模型思維鏈(Chain-of-Thought)技術的定制化思維鏈提示和定向刺激提示的心理咨詢場景定向ai智能應用

本篇為個人筆記 記錄基于大模型思維鏈&#xff08;Chain-of-Thought&#xff09;技術的定制化思維鏈提示和定向刺激提示的心理咨詢場景定向ai智能應用 人工智能為個人興趣領域 業余研究 如有錯漏歡迎指出&#xff01;&#xff01;&#xff01; 目錄 本篇為個人筆記 記錄基…

價格腰斬,騰訊云2024優惠活動云服務器62元一年,多配置報價

騰訊云服務器多少錢一年&#xff1f;62元一年起&#xff0c;2核2G3M配置&#xff0c;騰訊云2核4G5M輕量應用服務器218元一年、756元3年&#xff0c;4核16G12M服務器32元1個月、312元一年&#xff0c;8核32G22M服務器115元1個月、345元3個月&#xff0c;騰訊云服務器網txyfwq.co…

Node.js中的并發和多線程處理

在Node.js中&#xff0c;處理并發和多線程是一個非常重要的話題。由于Node.js是單線程的&#xff0c;這意味著它在任何給定時間內只能執行一個任務。然而&#xff0c;Node.js的事件驅動和非阻塞I/O模型使得處理并發和多線程變得更加高效和簡單。在本文中&#xff0c;我們將探討…

【排坑】搭建 Karmada 環境

git clone 報錯 問題&#xff1a;Failed to connect to github.com port 443:connection timed out 解決&#xff1a; git config --global --unset http.proxy【hack/local-up-karmada.sh】 1. karmada ca-certificates (no such package) 問題&#xff1a;fetching http…

老化的電動車與高層電梯樓的結合,將是巨大的安全隱患

中國是全球最大的電動汽車市場&#xff0c;其實中國還是全球最大的電動兩輪車市場&#xff0c;而電動兩輪車的老化比電動汽車更快&#xff0c;電動汽車的電池壽命可以達到10年&#xff0c;而電動兩輪車的電池壽命只有3-5年&#xff0c;而首批電動兩輪車至今已老化得相當嚴重&am…

【Pytorch深度學習開發實踐學習】【AlexNet】經典算法復現-Pytorch實現AlexNet神經網絡(1)model.py

算法簡介 AlexNet是人工智能深度學習在CV領域的開山之作&#xff0c;是最先把深度卷積神經網絡應用于圖像分類領域的研究成果&#xff0c;對后面的諸多研究起到了巨大的引領作用&#xff0c;因此有必要學習這個算法并能夠實現它。 主要的創新點在于&#xff1a; 首次使用GPU…

AI語音識別的技術解析

從語音識別算法的發展來看&#xff0c;語音識別技術主要分為三大類&#xff0c;第一類是模型匹配法&#xff0c;包括矢量量化(VQ) 、動態時間規整(DTW)等&#xff1b;第二類是概率統計方法&#xff0c;包括高斯混合模型(GMM) 、隱馬爾科夫模型(HMM)等&#xff1b;第三類是辨別器…

golang gin單獨部署vue3.0前后端分離應用

概述 因為公司最近的項目前端使用vue 3.0&#xff0c;后端api使用golang gin框架。測試通過后&#xff0c;博文記錄&#xff0c;用于備忘。 步驟 npm run build&#xff0c;構建出前端項目的dist目錄&#xff0c;dist目錄的結構具體如下圖 將dist目錄復制到后端程序同級目錄…

嵌入式軟件bug從哪里來,到哪里去

摘要&#xff1a;軟件從來不是一次就能完美的&#xff0c;需要以包容的眼光看待它的殘缺。那問題究竟為何產生&#xff0c;如何去除呢&#xff1f; 1、軟件問題從哪來 軟件缺陷問題千千萬萬&#xff0c;主要是需求、實現、和運行環境三方面。 1.1 需求描述偏差 客戶角度的描…