PyTorch 中如何針對 GPU 和 TPU 使用不同的處理方式

一個簡單的矩陣乘法例子來演示在 PyTorch 中如何針對 GPU 和 TPU 使用不同的處理方式。

這個例子會展示核心的區別在于如何獲取和指定計算設備,以及(對于 TPU)可能需要額外的庫和同步操作。

示例代碼:

import torch
import time# --- GPU 示例 ---
print("--- GPU 示例 ---")
# 檢查是否有可用的 GPU (CUDA)
if torch.cuda.is_available():gpu_device = torch.device('cuda')print(f"檢測到 GPU。使用設備: {gpu_device}")# 創建張量并移動到 GPU# 在張量創建時直接指定 device='cuda' 或 .to('cuda')tensor_a_gpu = torch.randn(1000, 2000, device=gpu_device)tensor_b_gpu = torch.randn(2000, 1500, device=gpu_device)# 在 GPU 上執行矩陣乘法start_time = time.time()result_gpu = torch.mm(tensor_a_gpu, tensor_b_gpu)torch.cuda.synchronize() # 等待 GPU 計算完成end_time = time.time()print(f"在 GPU 上執行了矩陣乘法,結果張量大小: {result_gpu.shape}")print(f"GPU 計算耗時: {end_time - start_time:.4f} 秒")# print(result_gpu) # 可以打印結果,但對于大張量會很多else:print("未檢測到 GPU。無法運行 GPU 示例。")# --- TPU 示例 ---
print("\n--- TPU 示例 ---")
# 導入 PyTorch/XLA 庫
# 注意:這個庫需要在支持 TPU 的環境 (如 Google Colab TPU runtime 或 Cloud TPU VM) 中安裝和運行
try:import torch_xlaimport torch_xla.core.xla_model as xmimport torch_xla.distributed.parallel_loader as plimport torch_xla.distributed.xla_multiprocessing as xmp# 檢查是否在 XLA (TPU) 環境中if xm.xla_device() is not None:IS_TPU_AVAILABLE = Trueelse:IS_TPU_AVAILABLE = Falseexcept ImportError:print("未找到 torch_xla 庫。")IS_TPU_AVAILABLE = False
except Exception as e:print(f"初始化 torch_xla 失敗: {e}")IS_TPU_AVAILABLE = Falseif IS_TPU_AVAILABLE:# 獲取 TPU 設備tpu_device = xm.xla_device()print(f"檢測到 TPU。使用設備: {tpu_device}")# 創建張量并移動到 TPU (通過 XLA 設備)# 在張量創建時直接指定 device=tpu_device 或 .to(tpu_device)# 注意:TPU 操作通常是惰性的,數據和計算可能會在 xm.mark_step() 或其他同步點時才實際執行tensor_a_tpu = torch.randn(1000, 2000, device=tpu_device)tensor_b_tpu = torch.randn(2000, 1500, device=tpu_device)# 在 TPU 上執行矩陣乘法 (通過 XLA)start_time = time.time()result_tpu = torch.mm(tensor_a_tpu, tensor_b_tpu)# 觸發執行和同步 (TPU 操作通常是惰性的,需要顯式步驟來編譯和執行)# 在實際訓練循環中,通常在一個 minibatch 結束時調用 xm.mark_step()xm.mark_step()# 注意:TPU 的時間測量可能需要通過特定 XLA 函數,這里使用簡單的 time() 可能不精確反映 TPU 計算時間end_time = time.time()print(f"在 TPU 上執行了矩陣乘法,結果張量大小: {result_tpu.shape}")#print(f"TPU (包含編譯和同步) 耗時: {end_time - start_time:.4f} 秒") # 這里的計時僅供參考# print(result_tpu) # 可以打印結果else:print("無法運行 TPU 示例,因為未找到 torch_xla 庫 或 不在 TPU 環境中。")print("要在 Google Colab 中運行 TPU 示例,請在 'Runtime' -> 'Change runtime type' 中選擇 TPU。")

代碼解釋:

  1. 導入: 除了 torch,GPU 示例不需要額外的庫。但 TPU 示例需要導入 torch_xla 庫。
  2. 設備獲取:
    • GPU 使用 torch.device('cuda') 或更簡單的 'cuda' 字符串來指定設備。torch.cuda.is_available() 用于檢查 CUDA 是否可用。
    • TPU 使用 torch_xla.core.xla_model.xla_device() 來獲取 XLA 設備對象。通常需要檢查 torch_xla 是否成功導入以及 xm.xla_device() 是否返回一個非 None 的設備對象來確定 TPU 環境是否可用。
  3. 張量創建/移動:
    • 無論是 GPU 還是 TPU,都可以通過在創建張量時指定 device=... 或使用 .to(device) 方法將已有的張量移動到目標設備上。
  4. 計算: 執行矩陣乘法 torch.mm() 的代碼在兩個例子中看起來是相同的。這是 PyTorch 的一個優點,上層代碼在不同設備上可以保持相似。
  5. 同步:
    • GPU 操作在調用時通常是異步的,但 torch.cuda.synchronize() 會阻塞 CPU,直到所有 GPU 操作完成,這在計時時是必需的。
    • TPU 操作通過 XLA 編譯和執行,通常是惰性的 (lazy)。這意味著調用 torch.mm() 可能只是構建計算圖,實際計算可能不會立即發生。xm.mark_step() 是一個重要的同步點,它會觸發 XLA 編譯當前構建的計算圖并在 TPU 上執行,然后等待執行完成。在實際訓練循環中,這通常在每個 mini-batch 結束時調用。

核心區別在于設備層面的處理方式: 原生 PyTorch 直接通過 CUDA API 與 GPU 交互,而對 TPU 的支持則需要借助 torch_xla 庫作為中介,通過 XLA 編譯器來生成和管理 TPU 上的執行。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/80464.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/80464.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/80464.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

自主shell命令行解釋器

目標 能處理普通命令能處理內建命令 實現原理 用下面的時間軸來表示時間發生次序。時間從左向右。shell由標識為sh的方塊,它隨著時間從左向右移動。 shell從用戶讀入字符串“ls”。shell建立一個新的進程,然后等待進程中運行ls程序并等待進程結束。 …

如何在sheel中運行Spark

啟動hdfs集群,打開hadoop100:9870,在wcinput目錄下上傳一個包含很多個單詞的文本文件。 啟動之后在spark-shell中寫代碼。 // 讀取文件,得到RDD val rdd1 sc.textFile("hdfs://hadoop100:8020/wcinput/words.txt") // 將單詞進行切…

【入門】數字走向II

描述 輸入整數N&#xff0c;輸出相應方陣。 輸入描述 一個整數N。&#xff08; 0 < n < 10 ) 輸出描述 一個方陣&#xff0c;每個數字的場寬為3。 #include <bits/stdc.h> using namespace std; int main() {int n;cin>>n;for(int in;i>1;i--){for(…

Python自動化-python基礎(下)

六、帶參數的裝飾器 七、函數生成器 運行結果&#xff1a; 八、通過反射操作對象方法 1.添加和覆蓋對象方法 2.刪除對象方法 通過使用內建函數: delattr() # 刪除 x.a() print("通過反射刪除之后") delattr(x, "a") x.a()3 通過反射判斷對象是否有指定…

重新定義高性能:Hyperlane —— Rust生態中的極速HTTP服務器

重新定義高性能&#xff1a;Hyperlane —— Rust生態中的極速HTTP服務器 &#x1f680; 為什么選擇Hyperlane&#xff1f; 在追求極致性能的Web服務開發領域&#xff0c;Hyperlane 憑借其獨特的Rust基因和架構設計&#xff0c;在最新基準測試中展現出令人驚艷的表現&#xff…

通俗的理解MFC消息機制

1. 消息是什么&#xff1f; 想象你家的門鈴響了&#xff08;比如有人按門鈴、敲門、或者有快遞&#xff09;&#xff0c;這些都是“消息”。 在 MFC 中&#xff0c;消息就是系統或用戶觸發的各種事件&#xff0c;比如鼠標點擊&#xff08;WM_LBUTTONDOWN&#xff09;、鍵盤輸入…

騰訊開源SuperSonic:AI+BI如何重塑制造業數據分析?

目錄 一、四款主流ChatBI產品 二、ChatBI應用案例與實際落地情況 三、SuperSonic底層原理 3.1、Headless?BI 是什么 3.2、S2SQL?是什么 3.3、SuperSonic 平臺架構 四、ChatBI應用細節深挖 五、與現有系統的集成方案 六、部署和安全 七、開源生態、可擴展性與二次開…

AI生成視頻推薦

以下是一些好用的 AI 生成視頻工具&#xff1a; 國內工具 可靈 &#xff1a;支持文本生成視頻、圖片生成視頻&#xff0c;適用于廣告、電影剪輯和短視頻制作&#xff0c;能在 30 秒內生成 6 秒的高清視頻&#xff08;1440p&#xff09;&#xff0c;目前處于免費測試階段。 即…

OrangePi Zero 3學習筆記(Android篇)5 - usbutils編譯(更新lsusb)

目錄 1. Ubuntu中編譯 2. AOSP編譯 3. 去掉原來的配置 3. 打包 4. 驗證lsusb 在Ubuntu中&#xff0c;lsusb的源代碼源自usbutils。而OrangePi Zero 3中lsusb的位置可以看文件H618-Android12-Src/external/toybox/Android.bp&#xff0c; "toys/other/lsusb.c",…

bcm5482 phy 場景總結

1,BCM5482是一款雙端口10/100/1000BASE-T以太網PHY芯片,支持多種速率和雙工模式。其配置主要通過MDIO(Management Data Input/Output)接口進行,MDIO接口用于訪問PHY芯片內部的寄存器,從而配置網絡速率、雙工模式以及其他相關參數。 a,具體以下面兩種場景舉例 2. 寄存器和…

RedHat磁盤的添加和擴容

前情提要 &#x1f9f1; 磁盤結構流程概念圖&#xff1a; 物理磁盤 (/dev/sdX) └── 分區&#xff08;如 /dev/sdX1&#xff09;或整塊磁盤&#xff08;直接使用&#xff09; └── 物理卷 (PV, 用 pvcreate) └── 卷組 (VG, 用 vgcreate) …

Lua—元表(Metatable)

原表解析 在 Lua table 中我們可以訪問對應的 key 來得到 value 值&#xff0c;但是卻無法對兩個 table 進行操作(比如相加)。 因此 Lua 提供了元表(Metatable)&#xff0c;允許我們改變 table 的行為&#xff0c;每個行為關聯了對應的元方法。 setmetatable(table,metatable…

一種運動平臺掃描雷達超分辨成像視場選擇方法——論文閱讀

一種運動平臺掃描雷達超分辨成像視場選擇方法 1. 專利的研究目標與意義1.1 研究目標1.2 實際意義2. 專利的創新方法與技術細節2.1 核心思路與流程2.1.1 方法流程圖2.2 關鍵公式與模型2.2.1 回波卷積模型2.2.2 最大后驗概率(MAP)估計2.2.3 統計約束模型2.2.4 迭代優化公式2.3 …

Listremove數據時報錯:Caused by: java.lang.UnsupportedOperationException

看了二哥的foreach陷阱后&#xff0c;自己也遇見了需要循環刪除元素的情況&#xff0c;立馬想到了當時自己陰差陽錯的避開所有坑的解決方式&#xff1a;先倒序遍歷&#xff0c;再刪除。之前好使&#xff0c;但是這次不好使了&#xff0c;報錯Caused by: java.lang.UnsupportedO…

Ceph集群OSD運維手冊:基礎操作與節點擴縮容實戰

#作者&#xff1a;stackofumbrella 文章目錄 一、Ceph集群的OSD基礎操作查看osd的ID編號查看OSD的詳細信息查看OSD的狀態信息查看OSD的統計信息查看OSD在主機上的存儲信息查看OSD延遲的統計信息查看各個OSD使用率集群暫停接收數據集群取消暫停 OSD寫入權重操作查看默認OSD操作…

PHP框架在分布式系統中的應用!

隨著互聯網業務的快速發展&#xff0c;分布式系統因其高可用性、可擴展性和容錯性成為現代應用架構的主流選擇。而PHP作為一門成熟的Web開發語言&#xff0c;憑借其簡潔的語法、豐富的框架生態和持續的性能優化&#xff0c;逐漸在分布式系統中嶄露頭角。本文將深入探討PHP框架在…

MySQL 索引(一)

文章目錄 索引&#xff08;重點&#xff09;硬件理解磁盤盤片和扇區定位扇區磁盤的隨機訪問和連續訪問 軟件方面的理解建立共識索引的理解 索引&#xff08;重點&#xff09; 索引可以提高數據庫的性能&#xff0c;它的價值&#xff0c;在于提高一個海量數據的檢索速度。 案例…

環境搭建-復現ST-GCN輸出動作分類視頻(win10+openpose1.7.0+VS2019+CMake3.30.1+cuda11.1)

這次我們安裝github.com/yysijie/st-gcn這個作者源碼環境&#xff0c;安裝流程十分復雜這里介紹大體流程。 1.首先編譯openpose的python API接口這個編譯難度較大&#xff0c;具體參考博文&#xff1a;windows編譯openpose及在python中調用_python openpose-CSDN博客 這個博…

HTML屬性

HTML&#xff08;HyperText Markup Language&#xff09;是網頁開發的基石&#xff0c;而屬性&#xff08;Attribute&#xff09;則是HTML元素的重要組成部分。它們為標簽提供附加信息&#xff0c;控制元素的行為、樣式或功能。本文將從基礎到進階&#xff0c;全面解析HTML屬性…

2025年“深圳杯”數學建模挑戰賽C題國獎大佬萬字思路助攻

完整版1.5萬字論文思路和Python代碼下載&#xff1a;https://www.jdmm.cc/file/2712073/ 引言 本題目旨在分析分布式能源 (Distributed Generation, DG) 接入配電網系統后帶來的風險。核心風險評估公式為&#xff1a; R P_{loss} \times C_{loss} P_{over} \times C_{over}…