Python 訓練營打卡 Day 34

GPU訓練及類的call方法

一、GPU訓練

與day33采用的CPU訓練不同,今天試著讓模型在GPU上訓練,引入import time比較兩者在運行時間上的差異

import torch
# 設置GPU設備
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 仍然用4特征,3分類的鳶尾花數據集作為我們今天的數據集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
import time# 加載數據集
iris = load_iris()
x = iris.data
y = iris.target
# 劃分數據集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
# 歸一化數據,神經網絡對于輸入數據的尺寸敏感,歸一化是最常見的處理方式
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test) #確保訓練集和測試集使用相同的縮放因子
# 將數據轉換為PyTorch張量并移至GPU
# 分類問題交叉熵損失要求標簽為long類型
# 張量具有to(device)方法,可以將張量移動到指定的設備上
x_train = torch.FloatTensor(x_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
x_test = torch.FloatTensor(x_test).to(device)
y_test = torch.LongTensor(y_test).to(device)import torch.nn as nn # 導入PyTorch的神經網絡模塊
import torch.optim as optim # 導入PyTorch的優化器模塊
class MLP(nn.Module): # 定義一個多層感知機(MLP)模型,繼承父類nn.Moduledef __init__(self): # 初始化函數super(MLP, self).__init__() # 調用父類的初始化函數# 定義的前三行是八股文,后面的是自定義的self.fc1 = nn.Linear(4, 10) # 定義第一個全連接層,輸入維度為4,輸出維度為10self.relu = nn.ReLU() # 定義激活函數ReLUself.fc2 = nn.Linear(10, 3) # 定義第二個全連接層,輸入維度為10,輸出維度為3
# 輸出層不需要激活函數,因為后面會用到交叉熵函數cross_entropy,交叉熵函數內部有softmax函數,會把輸出轉化為概率def forward(self, x):out = self.fc1(x) # 輸入x經過第一個全連接層out = self.relu(out) # 激活函數ReLUout = self.fc2(out) # 輸入x經過第二個全連接層return out 
# 實例化模型并移至GPU
# MLP繼承nn.Module類,所以也具有to(device)方法
model = MLP().to(device)
# 定義損失函數和優化器
# 分類問題使用交叉熵損失函數,適用于多分類問題,應用softmax函數將輸出映射到概率分布,然后計算交叉熵損失
criterion = nn.CrossEntropyLoss()
# 使用隨機梯度下降優化器(SGD),學習率為0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)# 開始循環訓練
# 訓練模型
num_epochs = 20000 # 訓練的輪數# 用于存儲每個 epoch 的損失值
losses = []
start_time = time.time() # 記錄開始時間
for epoch in range(num_epochs): # 開始迭代訓練過程,range是從0開始,所以epoch是從0開始# 前向傳播:將數據輸入模型,計算模型預測輸出outputs = model.forward(x_train)   # 顯式調用forward函數# outputs = model(X_train)  # 常見寫法隱式調用forward函數,其實是用了model類的__call__方法loss = criterion(outputs, y_train) # output是模型預測值,y_train是真實標簽,計算兩者之間損失值# 反向傳播和優化optimizer.zero_grad() #梯度清零,因為PyTorch會累積梯度,所以每次迭代需要清零,梯度累計是那種小的bitchsize模擬大的bitchsizeloss.backward() # 反向傳播計算梯度,自動完成以下計算:# 1. 計算損失函數對輸出的梯度# 2. 從輸出層→隱藏層→輸入層反向傳播# 3. 計算各層權重/偏置的梯度optimizer.step() # 更新參數# 記錄損失值losses.append(loss.item())# 打印訓練信息if (epoch + 1) % 100 == 0: # range是從0開始,所以epoch+1是從當前epoch開始,每100個epoch打印一次print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')
import matplotlib.pyplot as plt
# 可視化損失曲線
plt.plot(range(num_epochs), losses) 
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

?結果顯示GPU訓練情況下訓練時間為44.96s,而昨天的CPU訓練時間為14.13s。

本質是因為GPU在計算的時候,相較于CPU多了3個時間上的開銷

1.數據傳輸開銷 (CPU 內存 <-> GPU 顯存)

  • 在 GPU 進行任何計算之前,數據(輸入張量 X_train、y_train,模型參數)需要從計算機的主內存 (RAM) 復制到 GPU 專用的顯存 (VRAM) 中。
  • 當結果傳回 CPU 時(例如,使用 loss.item() 獲取損失值用于打印或記錄,或者獲取最終預測結果),數據也需要從 GPU 顯存復制回 CPU 內存。
  • 對于少量數據和非常快速的計算任務,這個傳輸時間可能比 GPU 通過并行計算節省下來的時間還要長。

在上述代碼中,循環里的 loss.item() 操作會在每個 epoch 都進行一次從 GPU 到 CPU 的數據同步和傳輸,以便獲取標量損失值。對于20000個epoch來說,這會累積不少的傳輸開銷。

2.核心啟動開銷 (GPU 核心啟動時間)

  • GPU 執行的每個操作(例如,一個線性層的前向傳播、一個激活函數)都涉及到在 GPU 上啟動一個“核心”(kernel)——一個在 GPU 眾多計算單元上運行的小程序。
  • 啟動每個核心都有一個小的、固定的開銷。
  • 如果核心內的實際計算量非常小(本項目的小型網絡和鳶尾花數據),這個啟動開銷在總時間中的占比就會比較大。相比之下,CPU 執行這些小操作的“調度”開銷通常更低。

3.性能浪費:計算量和數據批次

  • 這個數據量太少,gpu的很多計算單元都沒有被用到,即使用了全批次也沒有用到的全部計算單元。

綜上,數據傳輸和各種固定開銷的總和,超過了 GPU 在這點計算量上通過并行處理所能節省的時間,導致了 GPU 比 CPU 慢的現象。這些特性導致GPU在處理鳶尾花分類這種“玩具級別”的問題時,它的優勢無法體現,反而會因為上述開銷顯得“笨重”。

那么什么時候 GPU 會發揮巨大優勢?

  • 大型數據集: 例如,圖像數據集成千上萬張圖片,每張圖片維度很高。
  • 大型模型: 例如,深度卷積網絡 (CNNs like ResNet, VGG) 或 Transformer 模型,它們有數百萬甚至數十億的參數,計算量巨大。
  • 合適的批處理大小: 能夠充分利用 GPU 并行性的 batch size,不至于還有剩余的計算量沒有被 GPU 處理。
  • 復雜的、可并行的運算: 大量的矩陣乘法、卷積等。

針對上面反應的3個時間開銷問題,能夠優化的只有數據傳輸時間,針對性解決即可,很容易想到2個思路:

  1. 直接不打印訓練過程的loss了,但是這樣會沒辦法記錄最后的可視化圖片,只能肉眼觀察loss數值變化
  2. 每隔200個epoch保存一下loss,不需要20000個epoch每次都打印

經試驗,思路一去除圖片打印部分的情況下訓練時間為:36.77s,思路二改變保存間隔情況下的訓練時長為:44.6s

二、__call__方法

在 Python 中,__call__方法是一個特殊的魔術方法,它允許類的實例像函數一樣被調用。這種特性使得對象可以表現得像函數,同時保留對象的內部狀態。

# 不帶參數的call方法
class Counter:def __init__(self):self.count = 0def __call__(self):self.count += 1return self.count# 使用示例
counter = Counter() # 實例化對象,通過__init__方法初始化count為0
print(counter())  # 調用__call__,輸出: 1
print(counter())  # counter=1再代入call方法,輸出: 2
print(counter.count)  # 輸出: 2# 帶參數的call方法
class Adder:def __call__(self, a, b):print("唱跳籃球rap")return a + badder = Adder()
print(adder(3, 5))  
# 輸出:
# 唱跳籃球rap 
# 8

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82302.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82302.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82302.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu22.04 系統安裝Docker教程

1.更新系統軟件包 #確保您的系統軟件包是最新的。這有助于避免安裝過程中可能遇到的問題 sudo apt update sudo apt upgrade -y 2.安裝必要的依賴 sudo apt install apt-transport-https ca-certificates curl software-properties-common -y 3.替換軟件源 原來/etc/apt/s…

深入解析前端 JSBridge:現代混合開發的通信基石與架構藝術

引言&#xff1a;被低估的通信革命 在移動互聯網爆發式增長的十年間&#xff0c;Hybrid App&#xff08;混合應用&#xff09;始終占據著不可替代的地位。作為連接 Web 與 Native 的神經中樞&#xff0c;JSBridge 的設計質量直接決定了應用的性能上限與開發效率。本文將突破傳…

ES 面試題系列「三」

1、在設計 Elasticsearch 索引時&#xff0c;如何考慮數據的建模和映射&#xff1f; 需要根據業務需求和數據特點來確定索引的結構。首先要分析數據的類型&#xff0c;對于結構化數據&#xff0c;如數字、日期等&#xff0c;要明確其數據格式和范圍&#xff0c;選擇合適的字段…

HTML5快速入門-常用標簽及其屬性(三)

HTML5快速入門-常用標簽及其屬性(三) 文章目錄 HTML5快速入門-常用標簽及其屬性(三)音視頻標簽&#x1f3a7; <audio> 標簽 — 插入音頻使用 <source> 提供多格式備選&#xff08;提高兼容性&#xff09;&#x1f3a5; <video> 標簽 — 插入視頻&#x1f3b5…

Qt文件:XML文件

XML文件 1. XML文件結構1.1 基本結構1.2 XML 格式規則1.3 XML vs HTML 2. XML文件操作2.1 DOM 方式&#xff08;QDomDocument&#xff09;讀取 XML寫入XML 2.2 SAX 方式&#xff08;QXmlStreamReader/QXmlStreamWriter&#xff09;讀取XML寫入XML 2.3 對比分析 3. 使用場景3.1 …

day24Node-node的Web框架Express

1. Express 基礎 1.1 什么是Express node的web框架有Express 和 Koa。常用Express 。 Express 是一個基于 Node.js 的快速、極簡的 Web 應用框架,用于構建 服務器端應用(如網站后端、RESTful API 等)。它是 Node.js 生態中最流行的框架之一,以輕量、靈活和易用著稱。 …

uniapp實現的簡約美觀的票據、車票、飛機票模板

采用 uniapp 實現的一款簡約美觀的票據模板&#xff0c;純CSS、HTML實現&#xff0c;用戶完全可根據自身需求進行更改、擴展&#xff1b;支持web、H5、微信小程序&#xff08;其他小程序請自行測試&#xff09;&#xff0c; 可到插件市場下載嘗試&#xff1a; https://ext.dclo…

esp32+IDF V5.1.1版本編譯freertos報錯

error: portTICK_RATE_MS undeclared (first use in this function); did you mean portTICK_PERIOD_MS 解決方法: 使用命令 idf.py menuconfig 打開配置界面配置freeRtos 使能configENABLE_BACKWARD_COMPATIBLITY

vue 水印組件

Watermark.vue <script setup lang"ts"> import { ref, onMounted, onUnmounted, watch } from vue;interface Props {text?: string;fontSize?: number;color?: string;rotate?: number;zIndex?: number;gap?: number; }const props withDefaults(def…

hbuilder中h5轉為小程序提交發布審核

【注意】 [HBuilder] 11:59:15.179 此應用 DCloud appid 為 __UNI__9F9CC77 &#xff0c;您不是這個應用的項目成員。1、聯系這個應用的所有者&#xff0c;請求加入項目成員&#xff08;https://dev.dcloud.net.cn "成員管理"-"添加項目成員"&#xff09;…

QT之INI、JSON、XML處理

文章目錄 INI文件處理寫配置文件讀配置文件 JSON 文件處理寫入JSON讀取JSON XML文件處理寫XML文件讀XML文件 INI文件處理 首先得引入QSettings QSettings 是用來存儲和讀取應用程序設置的一個類 #include "wrinifile.h"#include <QSettings> #include <QtD…

道德經總結

道德經 《道德經》是中國古代偉大哲學家老子所著&#xff0c;全書約五千字&#xff0c;共81章&#xff0c;分為“道經”&#xff08;1–37章&#xff09;和“德經”&#xff08;38–81章&#xff09;兩部分。 《道德經》是一部融合哲學、政治、人生智慧于一體的經典著作。它提…

行為型:迭代器模式

目錄 1、核心思想 2、實現方式 2.1 模式結構 2.2 實現案例 3、優缺點分析 4、適用場景 1、核心思想 目的&#xff1a;將遍歷邏輯與數據存儲結構解耦 概念&#xff1a;提供一種機制來按順序訪問集合中的各元素&#xff0c;而不需要知道集合內部的構造 舉例&#xff1a;…

人臉識別技術合規備案最新政策詳解

《人臉識別技術應用安全管理辦法》將于2025年6月1日正式實施&#xff0c;該辦法從技術應用、個人信息保護、技術替代、監管體系四方面構建了人臉識別技術的治理框架&#xff0c;旨在平衡技術發展與安全風險。 一、明確技術應用的邊界 公共場所使用限制&#xff1a;僅在“維護公…

如何把vue項目部署在nginx上

1&#xff1a;在vscode中把vue項目打包會出現dist文件夾 按照圖示內容即可把vue項目部署在nginx上

奇好 PDF安全加密 + 自由拆分合并批量處理 OCR 識別

各位辦公小能手們&#xff0c;你們好呀&#xff01;今天我要給大家介紹一款超厲害的軟件——奇好PDF。它就像是一個PDF文檔處理的超級大管家&#xff0c;啥功能都有&#xff0c;格式轉換、編輯、提取、安全保護這些統統不在話下&#xff0c;不管是辦公、學習&#xff0c;還是設…

Docker-Harbor 私有鏡像倉庫使用指南

1.用戶管理 為項目創建專用用戶&#xff0c;并配置權限&#xff0c;確保該用戶能夠順利推送鏡像到 Harbor 倉庫&#xff0c;確保鏡像推送操作的安全性和便捷性。 創建完成后可以根據需要選擇是否設置為管理員 角色 權限描述 適用場景 系統管理員 擁有系統的完全控制權限 運維…

HomeAssistant開源的智能家居docker快速部署實踐筆記(CentOS7)

1. SGCC_Electricity 應用介紹 SGCC_Electricity 是一個用于將國家電網&#xff08;State Grid Corporation of China&#xff0c;簡稱 SGCC&#xff09;的電費和用電量數據接入 Home Assistant 的自定義集成組件。通過該應用&#xff0c;用戶可以實時追蹤家庭用電量情況&…

maven 3.0多線程編譯提高編譯速度

mvn package 默認只使用 單線程 來執行構建生命周期&#xff08;即順序地構建每一個模塊&#xff09;。 如果你使用的是多模塊項目&#xff0c;Maven 從 3.0 開始提供了**并行構建&#xff08;parallel build&#xff09;**的能力&#xff0c;但它不是默認開啟的。 如何啟用多…

python模塊管理環境變量

概要 在 Python 應用中&#xff0c;為了將配置信息與代碼分離、增強安全性并支持多環境&#xff08;開發、測試、生產&#xff09;運行&#xff0c;使用專門的模塊來管理環境變量是最佳實踐。常見工具包括&#xff1a; 標準庫 os.environ&#xff1a;直接讀取操作系統環境變量…