day50 隨機函數與廣播機制

目錄

一、隨機張量的生成

1.1?torch.randn()?函數

1.2 其他隨機函數

1.3 輸出維度測試

二、廣播機制

2.1 廣播機制的規則

2.2 加法的廣播機制

二維張量與一維向量相加

三維張量與二維張量相加

二維張量與標量相加

高維張量與低維張量相加

2.3 乘法的廣播機制

批量矩陣與單個矩陣相乘

批量矩陣與批量矩陣相乘(部分廣播)

三維張量與二維張量相乘(高維廣播)


一、隨機張量的生成

在深度學習中,我們經常需要隨機生成張量,例如用于模型參數的初始化、生成測試數據或模擬輸入特征。PyTorch 提供了多種隨機張量生成函數,其中 torch.randn() 是最常用的一種。

1.1?torch.randn()?函數

torch.randn() 可以創建一個由標準正態分布(均值為 0,標準差為 1)隨機數填充的張量。它的參數如下:

  • size:必選參數,表示輸出張量的形狀。

  • dtype:可選參數,指定張量的數據類型。

  • device:可選參數,指定張量存儲的設備。

  • requires_grad:可選參數,是否需要計算梯度。

以下是生成不同維度張量的示例代碼:

import torch# 生成標量(0維張量)
scalar = torch.randn(())
print(f"標量: {scalar}, 形狀: {scalar.shape}")# 生成向量(1維張量)
vector = torch.randn(5)  # 長度為5的向量
print(f"向量: {vector}, 形狀: {vector.shape}")# 生成矩陣(2維張量)
matrix = torch.randn(3, 4)  # 3行4列的矩陣
print(f"矩陣:{matrix},矩陣形狀: {matrix.shape}")# 生成3維張量(常用于圖像數據的通道、高度、寬度)
tensor_3d = torch.randn(3, 224, 224)  # 3通道,高224,寬224
print(f"3維張量形狀: {tensor_3d.shape}")# 生成4維張量(常用于批量圖像數據:[batch, channel, height, width])
tensor_4d = torch.randn(2, 3, 224, 224)  # 批量大小為2,3通道,高224,寬224
print(f"4維張量形狀: {tensor_4d.shape}")

1.2 其他隨機函數

除了 torch.randn(),PyTorch 還提供了其他隨機函數,例如:

  • torch.rand():生成在 [0, 1) 范圍內均勻分布的隨機數。

  • torch.randint():生成指定范圍內的隨機整數。

  • torch.normal():生成指定均值和標準差的正態分布隨機數。

以下是示例代碼:

# 生成均勻分布隨機數
x = torch.rand(3, 2)  # 生成3x2的張量
print(f"均勻分布隨機數: {x}, 形狀: {x.shape}")# 生成隨機整數
x = torch.randint(low=0, high=10, size=(3,))  # 生成3個0到9之間的整數
print(f"隨機整數: {x}, 形狀: {x.shape}")# 生成正態分布隨機數
mean = torch.tensor([0.0, 0.0])
std = torch.tensor([1.0, 2.0])
x = torch.normal(mean, std)  # 生成兩個正態分布隨機數
print(f"正態分布隨機數: {x}, 形狀: {x.shape}")

1.3 輸出維度測試

在實際的深度學習任務中,我們通常需要計算輸入張量經過不同層后的輸出維度。以下是卷積層、池化層、線性層等的維度變化示例:

import torch
import torch.nn as nn# 生成輸入張量 (批量大小, 通道數, 高度, 寬度)
input_tensor = torch.randn(1, 3, 32, 32)  # 例如CIFAR-10圖像
print(f"輸入尺寸: {input_tensor.shape}")# 卷積層操作
conv1 = nn.Conv2d(in_channels=3,        # 輸入通道數out_channels=16,      # 輸出通道數(卷積核數量)kernel_size=3,        # 卷積核大小stride=1,             # 步長padding=1             # 填充
)
conv_output = conv1(input_tensor)  # 由于 padding=1 且 stride=1,空間尺寸保持不變
print(f"卷積后尺寸: {conv_output.shape}")# 池化層操作 (減小空間尺寸)
pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 創建一個最大池化層
pool_output = pool(conv_output)
print(f"池化后尺寸: {pool_output.shape}")# 將多維張量展平為向量
flattened = pool_output.view(pool_output.size(0), -1)
print(f"展平后尺寸: {flattened.shape}")# 線性層操作
fc1 = nn.Linear(in_features=4096,     # 輸入特征數out_features=128      # 輸出特征數
)
fc_output = fc1(flattened)
print(f"線性層后尺寸: {fc_output.shape}")# 再經過一個線性層(例如分類器)
fc2 = nn.Linear(128, 10)  # 假設是10分類問題
final_output = fc2(fc_output)
print(f"最終輸出尺寸: {final_output.shape}")

二、廣播機制

PyTorch 的廣播機制(Broadcasting)是一種強大的張量運算特性,允許在不同形狀的張量之間進行算術運算,而無需顯式地擴展張量維度或復制數據。這種機制使得代碼更簡潔高效,尤其在處理多維數據時非常實用。

2.1 廣播機制的規則

當對兩個形狀不同的張量進行運算時,PyTorch 會按以下規則自動處理維度兼容性:

  1. 從右向左比較維度:從張量的最后一個維度(最右側)開始向前逐維比較。

  2. 維度擴展條件

    • 相等維度:若兩個張量在某一維度上大小相同,則繼續比較下一維度。

    • 一維擴展:若其中一個張量在某一維度上大小為 1,則該維度會被擴展為另一個張量對應維度的大小。

    • 不兼容錯誤:若某一維度大小既不相同也不為 1,則拋出 RuntimeError

  3. 維度補全規則:若一個張量的維度少于另一個,則在其左側補 1 直至維度數匹配。

2.2 加法的廣播機制

以下是幾個加法廣播的例子:

二維張量與一維向量相加
a = torch.tensor([[10], [20], [30]])  # 形狀: (3, 1)
b = torch.tensor([1, 2, 3])           # 形狀: (3,)
result = a + bprint("原始張量a:")
print(a)print("\n原始張量b:")
print(b)print("\n加法結果:")
print(result)
三維張量與二維張量相加
a = torch.tensor([[[1], [2]], [[3], [4]]])  # 形狀: (2, 2, 1)
b = torch.tensor([[10, 20]])               # 形狀: (1, 2)
result = a + bprint("原始張量a:")
print(a)print("\n原始張量b:")
print(b)print("\n加法結果:")
print(result)
二維張量與標量相加
a = torch.tensor([[1, 2], [3, 4]])  # 形狀: (2, 2)
b = 10                              # 標量,形狀視為 ()
result = a + bprint("原始張量a:")
print(a)print("\n標量b:")
print(b)print("\n加法結果:")
print(result)
高維張量與低維張量相加
a = torch.tensor([[[1, 2], [3, 4]]])  # 形狀: (1, 2, 2)
b = torch.tensor([[5, 6]])            # 形狀: (1, 2)
result = a + bprint("原始張量a:")
print(a)print("\n原始張量b:")
print(b)print("\n加法結果:")
print(result)

2.3 乘法的廣播機制

矩陣乘法(@)的廣播機制除了遵循通用廣播規則外,還需要滿足矩陣乘法的維度約束:

  • 最后兩個維度必須滿足:A.shape[-1] == B.shape[-2](即 A 的列數等于 B 的行數)。

  • 其他維度(批量維度):遵循通用廣播規則。

以下是幾個矩陣乘法廣播的例子:

批量矩陣與單個矩陣相乘
A = torch.randn(2, 3, 4)  # 形狀: (2, 3, 4)
B = torch.randn(4, 5)     # 形狀: (4, 5)
result = A @ B            # 結果形狀: (2, 3, 5)print("A形狀:", A.shape)
print("B形狀:", B.shape)
print("結果形狀:", result.shape)
批量矩陣與批量矩陣相乘(部分廣播)
A = torch.randn(3, 2, 4)  # 形狀: (3, 2, 4)
B = torch.randn(1, 4, 5)  # 形狀: (1, 4, 5)
result = A @ B            # 結果形狀: (3, 2, 5)print("A形狀:", A.shape)
print("B形狀:", B.shape)
print("結果形狀:", result.shape)
三維張量與二維張量相乘(高維廣播)
A = torch.randn(2, 3, 4, 5)  # 形狀: (2, 3, 4, 5)
B = torch.randn(5, 6)        # 形狀: (5, 6)
result = A @ B               # 結果形狀: (2, 3, 4, 6)print("A形狀:", A.shape)
print("B形狀:", B.shape)
print("結果形狀:", result.shape)

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/86267.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/86267.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/86267.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java持久層技術對比:Hibernate、MyBatis與JPA的選擇與應用

目錄 簡介持久層技術概述Hibernate詳解MyBatis詳解JPA詳解技術選型對比最佳實踐與應用場景性能優化策略未來發展趨勢總結與建議 簡介 在Java企業級應用開發中,持久層(Persistence Layer)作為連接業務邏輯與數據存儲的橋梁,其技…

【2025CVPR】模型融合新范式:PLeaS算法詳解(基于排列與最小二乘的模型合并技術)

本文深入解析ICLR 2025頂會論文《PLeaS: Merging Models with Permutations and Least Squares》,揭示模型融合領域突破性進展. 一、問題背景:模型合并的核心挑戰 隨著開源模型的爆發式增長,如何高效合并多個專用模型成為關鍵挑戰。傳統方法存在三大痛點: ?初始化依賴?…

磁盤空間清道夫FolderSize 系列:可視化分析 + 重復文件識別,

各位電腦小能手們,今天來給大家嘮嘮Folder類軟件!這玩意兒主要是為了文件夾管理、監控、安全還有優化這些需求設計的,不同工具的功能各有側重。下面我就結合多個搜索結果,給大家分類介紹一下。 軟件下載地址安裝包 首先是文件夾空…

嵌入式全棧面試指南:TCP/IP、C 語言基礎、STM32 外設與 RT?Thread

作為嵌入式工程師,面試時往往不僅要展示基礎編程能力,還要兼具網絡協議、硬件驅動、實時操作系統(RTOS)等方面的知識深度。本文將從TCP/IP 協議、C 語言核心基礎、STM32 IO 與外設驅動、RT?Thread 及其多任務/IPC四大模塊進行全面…

Git 命令全流程總結

以下是從初始化到版本控制、查看記錄、撤回操作的 Git 命令全流程總結,按操作場景分類整理: 一、初始化與基礎操作 操作命令初始化倉庫git init添加所有文件到暫存區git add .提交到本地倉庫git commit -m "提交描述"首次提交需配置身份git c…

軟件功能測試報告都包含哪些內容?

軟件功能測試報告是軟件開發生命周期中的重要文檔,主要涵蓋以下關鍵內容:    1.測試概況:概述測試目標、范圍和方法,確保讀者對測試背景有清晰了解。 2.測試環境:詳細描述測試所用的硬件、軟件環境,確保…

OpenCV CUDA模塊圖像處理------雙邊濾波的GPU版本函數bilateralFilter()

操作系統:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 編程語言:C11 算法描述 該函數在 GPU 上執行雙邊濾波操作,是一種非線性平滑濾波器,能夠在 保留邊緣的同時去除噪聲。 函數原型 void cv::cuda:…

Perplexity AI:重塑你的信息探索之旅

在信息爆炸的時代,如何快速、精準地獲取所需知識,并將其轉化為行動力?答案或許就藏在 Perplexity AI 這款強大的智能工具中。它不僅僅是一個搜索引擎,更是一個能理解你、與你對話、為你深度解析信息的智能伙伴。告別繁瑣的信息篩選…

Java高級反射實戰:15個場景化編程技巧與底層原理解析

引用 在Java的世界里,反射機制如同賦予開發者一把“萬能鑰匙”,它打破了靜態編程的邊界,讓代碼在運行時擁有動態獲取類信息、操作對象屬性和方法的能力。從Spring框架的依賴注入,到MyBatis的SQL映射生成;從JSON序列化…

構建 MCP 服務器:第 3 部分 — 添加提示

這是我們構建 MCP 服務器的四部分教程的第三部分。在第一部分中,我們使用基本資源創建了第一個MCP 服務器;在第二部分中,我們添加了資源模板并改進了代碼組織。現在,我們將進一步重構代碼并添加提示功能。 什么是 MCP 提示&#…

MySQL 索引優化(Explain執行計劃) 詳細講解

🤟致敬讀者 🟩感謝閱讀🟦笑口常開🟪生日快樂?早點睡覺 📘博主相關 🟧博主信息🟨博客首頁🟫專欄推薦🟥活動信息 文章目錄 MySQL 索引優化(Explain執行計劃…

使用 IntelliJ IDEA 安裝通義靈碼(TONGYI Lingma)插件,進行后端 Java Spring Boot 項目的用戶用例生成及常見問題處理

一、什么是通義靈碼(TONGYI Lingma)? 通義靈碼是阿里巴巴推出的智能代碼輔助工具,結合大模型技術,支持代碼生成、用例生成、代碼補全等功能,能極大提升開發效率。 二、在 IDEA 中安裝通義靈碼插件 打開 In…

AI編程在BOSS項目的實踐經驗分享

前言 在人工智能技術革新浪潮的推動下,智能編程助手正以前所未有的速度重塑開發領域。這些基于AI的代碼輔助工具通過智能提示生成、實時錯誤檢測和自動化重構等功能,顯著提升了軟件工程的全流程效率。無論是初入行業的開發者還是資深程序員,…

JVM 類加載器 詳解

類加載器 兩個類來源于同一個 Class文件,被同一個Java虛擬機加載,只要加載它們的類加載器不同,那這兩個類就必定不相等 這里所指的“相等”,包括代表類的Class對象的equals()方法、isAssignableFrom()方法、isInstance()方法的返…

Javascript 編程基礎(5)面向對象 | 5.1、構造函數實例化對象

文章目錄 一、構造函數實例化對象1、基本語法2、構造函數與原型的關系3、完整的原型鏈4、構造函數的特點5、prototype與__proto__屬性5.1、對象實例的__proto__屬性5.2、prototype屬性僅存在于函數對象5.3、實例與原型的關系5.4、獲取對象原型 6、注意事項 前言: 在…

自動駕駛科普(百度Apollo)學習筆記

1. 寫在前面 在過去的幾年里,自動駕駛技術取得飛速發展,人類社會正逐漸走向一個新時代,這個時代中,汽車不僅僅是一個交通工具,更是一個智能的、能夠感知環境、做出決策并自主導航的機器伙伴。現在正好也從事這塊的工作…

Windows應用-音視頻捕獲

下載“Windows應用-音視頻捕獲”項目 本應用可以同時捕獲4個視頻源和4個音頻源,可以監視視頻源圖像,監聽音頻源;可以將視頻源圖像寫入MP4文件,將音頻源寫入MP3或WAV文件;還可以錄制系統播放的聲音。本應用使用MFC對話框…

MATLAB生成大規模無線通信網絡拓撲(任意節點數量)

功能: 生成任意節點數量的網絡拓撲,符合現實世界節點空間分布和連接規律 效果: 30節點: 100節點: 500節點: 程序: %創建時間:2025年6月8日 %zhouzhichao %自然生長出n節點的網絡% …

TDengine 開發指南—— UDF函數

UDF 簡介 在某些應用場景中,應用邏輯需要的查詢功能無法直接使用內置函數來實現,TDengine 允許編寫用戶自定義函數(UDF),以便解決特殊應用場景中的使用需求。UDF 在集群中注冊成功后,可以像系統內置函數一…

C#提取CAN ASC文件時間戳:實現與性能優化

C#提取CAN ASC文件時間戳:實現與性能優化 在汽車電子和工業控制領域,CAN總線是最常用的通信協議之一。而ASC(ASCII)文件作為CAN總線數據的標準日志格式,廣泛應用于數據記錄和分析場景。本文將深入探討如何高效地從CAN…