PyTorch 張量與自動微分操作

筆記

1 張量索引操作

import torch
?
# 下標從左到右從0開始(0->第一個值), 從右到左從-1開始
# data[行下標, 列下標]
# data[0軸下標, 1軸下標, 2軸下標]
?
def dm01():# 創建張量torch.manual_seed(0)data = torch.randint(low=0, high=10, size=(4, 5))print('data->', data)# 根據下標值獲取對應位置的元素# 行數據 第一行print('data[0] ->', data[0])# 列數據 第一列print('data[:, 0]->', data[:, 0])# 根據下標列表取值# 第二行第三列的值和第四行第五列值print('data[[1, 3], [2, 4]]->', data[[1, 3], [2, 4]])# [[1], [3]: 第二行第三列 第二行第五列值 ? 第四行第三列 第四行第五列值print('data[[[1], [3]], [2, 4]]->', data[[[1], [3]], [2, 4]])# 根據布爾值取值# 第二列大于6的所有行數據print(data[:, 1] > 6)print('data[data[:, 1] > 6]->', data[data[:, 1] > 6])# 第三行大于6的所有列數據print('data[:, data[2]>6]->', data[:, data[2] > 6])# 根據范圍取值  切片  [起始下標:結束下標:步長]# 第一行第三行以及第二列第四列張量print('data[::2, 1::2]->', data[::2, 1::2])
?# 創建三維張量data2 = torch.randint(0, 10, (3, 4, 5))print("data2->", data2)# 0軸第一個值print(data2[0, :, :])# 1軸第一個值print(data2[:, 0, :])# 2軸第一個值print(data2[:, :, 0])
?
?
if __name__ == '__main__':dm01()

2 張量形狀操作

2.1 reshape

import torch
?
?
# reshape(shape=(行,列)): 修改連續或非連續張量的形狀, 不改數據
# -1: 表示自動計算行或列 ? 例如:  (5, 6) -> (-1, 3) -1*3=5*6 -1=10  (10, 3)
def dm01():torch.manual_seed(0)t1 = torch.randint(0, 10, (5, 6))print('t1->', t1)print('t1的形狀->', t1.shape)# 形狀修改為 (2, 15)t2 = t1.reshape(shape=(2, 15))t3 = t1.reshape(shape=(2, -1))print('t2->', t2)print('t2的形狀->', t2.shape)print('t3->', t3)print('t3的形狀->', t3.shape)
?
?
?
if __name__ == '__main__':dm01()
?

2.2 squeeze和unsqueeze

# squeeze(dim=): 刪除值為1的維度, dim->指定維度, 維度值不為1不生效  不設置dim,刪除所有值為1的維度
# 例如: (3,1,2,1) -> squeeze()->(3,2)  squeeze(dim=1)->(3,2,1)
# unqueeze(dim=): 在指定維度上增加值為1的維度  dim=-1:最后維度
def dm02():torch.manual_seed(0)# 四維t1 = torch.randint(0, 10, (3, 1, 2, 1))print('t1->', t1)print('t1的形狀->', t1.shape)# squeeze: 降維t2 = torch.squeeze(t1)print('t2->', t2)print('t2的形狀->', t2.shape)# dim: 指定維度t3 = torch.squeeze(t1, dim=1)print('t3->', t3)print('t3的形狀->', t3.shape)# unsqueeze: 升維# (3, 2)->(1, 3, 2)# t4 = t2.unsqueeze(dim=0)# 最后維度 (3, 2)->(3, 2, 1)t4 = t2.unsqueeze(dim=-1)print('t4->', t4)print('t4的形狀->', t4.shape)
?
?
if __name__ == '__main__':dm02()

2.3 transpose和permute

# 調換維度
# torch.permute(input=,dims=): 改變張量任意維度順序
# input: 張量對象
# dims: 改變后的維度順序, 傳入軸下標值 (1,2,3)->(3,1,2)
# torch.transpose(input=,dim0=,dim1=): 改變張量兩個維度順序
# dim0: 軸下標值, 第一個維度
# dim1: 軸下標值, 第二個維度
# (1,2,3)->(2,1,3) 一次只能交換兩個維度
def dm03():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(3, 4, 5))print('t1->', t1)print('t1形狀->', t1.shape)# 交換0維和1維數據# t2 = t1.transpose(dim0=1, dim1=0)t2 = t1.permute(dims=(1, 0, 2))print('t2->', t2)print('t2形狀->', t2.shape)# t1形狀修改為 (5, 3, 4)t3 = t1.permute(dims=(2, 0, 1))print('t3->', t3)print('t3形狀->', t3.shape)
?
?
if __name__ == '__main__':dm03()

2.4 view和contiguous

# tensor.view(shape=): 修改連續張量的形狀, 操作等同于reshape()
# tensor.is_contiugous(): 判斷張量是否連續, 返回True/False  張量經過transpose/permute處理變成不連續
# tensor.contiugous(): 將張量轉為連續張量
def dm04():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(3, 4))print('t1->', t1)print('t1形狀->', t1.shape)print('t1是否連續->', t1.is_contiguous())# 修改張量形狀t2 = t1.view((4, 3))print('t2->', t2)print('t2形狀->', t2.shape)print('t2是否連續->', t2.is_contiguous())# 張量經過transpose操作t3 = t1.transpose(dim0=1, dim1=0)print('t3->', t3)print('t3形狀->', t3.shape)print('t3是否連續->', t3.is_contiguous())# 修改張量形狀# view# contiugous(): 轉換成連續張量t4 = t3.contiguous().view((3, 4))print('t4->', t4)t5 = t3.reshape(shape=(3, 4))print('t5->', t5)print('t5是否連續->', t5.is_contiguous())
?
?
if __name__ == '__main__':dm04()

3 張量拼接操作

3.1 cat/concat

import torch
?
?
# torch.cat()/concat(tensors=, dim=): 在指定維度上進行拼接, 其他維度值必須相同, 不改變新張量的維度, 指定維度值相加
# tensors: 多個張量列表
# dim: 拼接維度
def dm01():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(2, 3))t2 = torch.randint(low=0, high=10, size=(2, 3))t3 = torch.cat(tensors=[t1, t2], dim=0)print('t3->', t3)print('t3形狀->', t3.shape)t4 = torch.concat(tensors=[t1, t2], dim=1)print('t4->', t4)print('t4形狀->', t4.shape)?
if __name__ == '__main__':# dm01()

3.2 stack

# torch.stack(tensors=, dim=): 根據指定維度進行堆疊, 在指定維度上新增一個維度(維度值張量個數), 新張量維度發生改變
# tensors: 多個張量列表
# dim: 拼接維度
def dm02():torch.manual_seed(0)t1 = torch.randint(low=0, high=10, size=(2, 3))t2 = torch.randint(low=0, high=10, size=(2, 3))t3 = torch.stack(tensors=[t1, t2], dim=0)# t3 = torch.stack(tensors=[t1, t2], dim=1)print('t3->', t3)print('t3形狀->', t3.shape)
?
?
if __name__ == '__main__':dm02()

4 自動微分模塊

4.1 梯度計算

"""
梯度: 求導,求微分 上山下山最快的方向
梯度下降法: W1=W0-lr*梯度 ? lr是可調整已知參數  W0:初始模型的權重,已知  計算出W0的梯度后更新到W1權重
pytorch中如何自動計算梯度 自動微分模塊
注意點: ①loss標量和w向量進行微分  ②梯度默認累加,計算當前的梯度, 梯度值是上次和當前次求和  ③梯度存儲.grad屬性中
"""
import torch
?
?
def dm01():# 創建標量張量 w權重# requires_grad: 是否自動微分,默認False# dtype: 自動微分的張量元素類型必須是浮點類型# w = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)# 創建向量張量 w權重w = torch.tensor(data=[10, 20], requires_grad=True, dtype=torch.float32)# 定義損失函數, 計算損失值loss = 2 * w ** 2print('loss->', loss)print('loss.sum()->', loss.sum())# 計算梯度 反向傳播  loss必須是標量張量,否則無法計算梯度loss.sum().backward()# 獲取w權重的梯度值print('w.grad->', w.grad)w.data = w.data - 0.01 * w.gradprint('w->', w)
?
?
if __name__ == '__main__':dm01()

4.2 梯度下降法求最優解

"""
① 創建自動微分w權重張量
② 自定義損失函數 loss=w**2+20  后續無需自定義,導入不同問題損失函數模塊
③ 前向傳播 -> 先根據上一版模型計算預測y值, 根據損失函數計算出損失值
④ 反向傳播 -> 計算梯度
⑤ 梯度更新 -> 梯度下降法更新w權重
"""
import torch
?
?
def dm01():# ① 創建自動微分w權重張量w = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)print('w->', w)# ② 自定義損失函數 后續無需自定義, 導入不同問題損失函數模塊loss = w ** 2 + 20print('loss->', loss)# 0.01 -> 學習率print('開始 權重x初始值:%.6f (0.01 * w.grad):無 loss:%.6f' % (w, loss))for i in range(1, 1001):# ③ 前向傳播 -> 先根據上一版模型計算預測y值, 根據損失函數計算出損失值loss = w ** 2 + 20# 梯度清零 -> 梯度累加, 沒有梯度默認Noneif w.grad is not None:w.grad.zero_()# ④ 反向傳播 -> 計算梯度loss.sum().backward()# ⑤ 梯度更新 -> 梯度下降法更新w權重# W = W - lr * W.grad# w.data -> 更新w張量對象的數據, 不能直接使用w(將結果重新保存到一個新的變量中)w.data = w.data - 0.01 * w.gradprint('w.grad->', w.grad)print('次數:%d 權重w: %.6f, (0.01 * w.grad):%.6f loss:%.6f' % (i, w, 0.01 * w.grad, loss))
?print('w->', w, w.grad, 'loss最小值', loss)
?
?
if __name__ == '__main__':dm01()

4.3 梯度計算注意點

# 自動微分的張量不能轉換成numpy數組, 可以借助detach()方法生成新的不自動微分張量
import torch
?
?
def dm01():x1 = torch.tensor(data=10, requires_grad=True, dtype=torch.float32)print('x1->', x1)# 判斷張量是否自動微分 返回True/Falseprint(x1.requires_grad)# 調用detach()方法對x1進行剝離, 得到新的張量,不能自動微分,數據和原張量共享x2 = x1.detach()print(x2.requires_grad)print(x1.data)print(x2.data)print(id(x1.data))print(id(x2.data))# 自動微分張量轉換成numpy數組n1 = x2.numpy()print('n1->', n1)
?
?
if __name__ == '__main__':dm01()

4.4 自動微分模塊應用

import torch
import torch.nn as nn ?# 損失函數,優化器函數,模型函數
?
?
def dm01():# todo:1-定義樣本的x和yx = torch.ones(size=(2, 5))y = torch.zeros(size=(2, 3))print('x->', x)print('y->', y)# todo:2-初始模型權重 w b 自動微分張量w = torch.randn(size=(5, 3), requires_grad=True)b = torch.randn(size=(3,), requires_grad=True)print('w->', w)print('b->', b)# todo:3-初始模型,計算預測y值y_pred = torch.matmul(x, w) + bprint('y_pred->', y_pred)# todo:4-根據MSE損失函數計算損失值# 創建MSE對象, 類創建對象criterion = nn.MSELoss()loss = criterion(y_pred, y)print('loss->', loss)# todo:5-反向傳播,計算w和b梯度loss.sum().backward()print('w.grad->', w.grad)print('b.grad->', b.grad)
?
?
if __name__ == '__main__':dm01()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/79486.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/79486.shtml
英文地址,請注明出處:http://en.pswp.cn/web/79486.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

接口的基礎定義與屬性約束

在 TypeScript 中,接口(Interface)是一個非常強大且常用的特性。接口定義了對象的結構,包括對象的屬性和方法,可以為對象提供類型檢查和約束。通過接口,我們可以清晰地描述一個對象應該具備哪些屬性和方法。…

高效全能PDF工具,支持OCR識別

軟件介紹 PDF XChange Editor是一款功能強大的PDF編輯工具,支持多種操作功能,不僅可編輯PDF內容與圖片,還具備OCR識別表單信息的能力,滿足多種場景下的需求。 軟件特點 這款PDF編輯器完全免費,用戶下載后直接…

OpenCV 中用于背景分割的一個類cv::bgsegm::BackgroundSubtractorGMG

操作系統:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 編程語言:C11 算法描述 cv::bgsegm::BackgroundSubtractorGMG 是 OpenCV 中用于背景分割的一個類,它實現了基于貝葉斯推理的背景建模算法(Bayesi…

MongoDB知識框架

簡介:MongoDB 是一個基于分布式文件存儲的數據庫,屬于 NoSQL 數據庫產品,以下是其知識框架總結: 一、數據模型 文檔:MongoDB 中的數據以 BSON(二進制形式的 JSON)格式存儲在集合中,…

WEBSTORM前端 —— 第2章:CSS —— 第8節:網頁制作2(小兔鮮兒)

目錄 1.項目目錄 2.SEO 三大標簽 3.Favicon 圖標 4.版心 5.快捷導航(shortcut) 6.頭部(header) 7.底部(footer) 8.banner 9.banner – 圓點 10.新鮮好物(goods) 11.熱門品牌(brand) 12.生鮮(fresh) 13.最新專題(topic) 1.項目目錄 【xtx-pc】 ima…

1、RocketMQ 核心架構拆解

1. 為什么要使用消息隊列? 消息隊列(MQ)是分布式系統中不可或缺的中間件,主要解決系統間的解耦、異步和削峰填谷問題。 解耦:生產者和消費者通過消息隊列通信,彼此無需直接依賴,極大提升系統靈…

[Linux網絡_71] NAT技術 | 正反代理 | 網絡協議總結 | 五種IO模型

目錄 1.NAT技術 NAPT 2.NAT和代理服務器 3.網線通信各層協議總結 補充說明 4.五種 IO 模型 1.什么是IO?什么是高效的IO? 2.有那些IO的方式?這么多的方式,有那些是高效的? 異步 IO 🎣 關鍵缺陷類比…

Unity基礎學習(八)時間相關內容Time

眾所周知,每一個游戲都會有自己的時間。這個時間可以是內部,從游戲開始的時間,也可以是外部真實的物理時間,時間相關內容 主要用于游戲中 參與位移計時 時間暫停等。那么我們今天就來看看Unity中和時間相關的內容。 Unity時間功能…

Java游戲服務器開發流水賬(1)游戲服務器的架構淺析

新項目立項停滯,頭大。近期讀老項目代碼看到Java,筆記記錄一下。 為什么要做服務器的架構 游戲服務器架構設計具有多方面的重要意義,它直接關系到游戲的性能、可擴展性、穩定性以及用戶體驗等關鍵因素 確保游戲的流暢運行 優化數據處理&a…

計算機視覺與深度學習 | 基于Transformer的低照度圖像增強技術

基于Transformer的低照度圖像增強技術通過結合Transformer的全局建模能力和傳統圖像增強理論(如Retinex),在保留顏色信息、抑制噪聲和平衡亮度方面展現出顯著優勢。以下是其核心原理、關鍵公式及典型代碼實現: 一、原理分析 1. 全局依賴建模與局部特征融合 Transformer的核…

Linux 文件目錄管理常用命令

pwd 顯示當前絕對路徑 cd 切換目錄 指令備注cd -回退cd …返回上一層cd ~切換到用戶主目錄 ls 列出目錄的內容 指令備注ls -a顯示當前目錄中的所有文件和目錄,包括隱藏文件ls -l以長格式顯示當前目錄中的文件和目錄ls -hl以人類可讀的方式顯示當前目錄中的文…

【Linux 系統調試】性能分析工具perf使用與調試方法

目錄 一、perf基本概念 1?. 事件類型? 2?. 低開銷高精度 3?. 工具定位? 二、安裝與基礎配置 1. 安裝方法 2. 啟用符號調試 三、perf工作原理 1. 數據采集機制 2. 硬件事件轉化流程 四、perf應用場景 1. 系統瓶頸定位 2. 鎖競爭優化 3. 緩存優化 五、perf高級…

嵌入式中屏幕的通信方式

LCD屏通信方式詳解 LCD屏(液晶顯示屏)的通信方式直接影響其數據傳輸效率、顯示刷新速度及硬件設計復雜度。根據應用場景和需求,LCD屏的通信方式主要分為以下三類,每種方式在協議類型、數據速率、硬件成本及適用場景上存在顯著差異…

【el-admin】el-admin關聯數據字典

數據字典使用 一、新增數據字典1、新增【圖書狀態】和【圖書類型】數據字典2、編輯字典值 二、代碼生成配置1、表單設置2、關聯字典3、驗證關聯數據字典 三、查詢操作1、模糊查詢2、按類別查詢(下拉框) 四、數據校驗 一、新增數據字典 1、新增【圖書狀態…

【Spring】Spring MVC筆記

文章目錄 一、SpringMVC簡介1、什么是MVC2、什么是SpringMVC3、SpringMVC的特點 二、HelloWorld1、開發環境2、創建maven工程a>添加web模塊b>打包方式:warc>引入依賴 3、配置web.xmla>默認配置方式b>擴展配置方式 4、創建請求控制器5、創建springMVC…

如何在大型項目中解決 VsCode 語言服務器崩潰的問題

在大型C/C項目中,VS Code的語言服務器(如C/C擴展)可能因內存不足或配置不當頻繁崩潰。本文結合系統資源分析與實戰技巧,提供一套完整的解決方案。 一、問題根源診斷 1.1 內存瓶頸分析 通過top命令查看系統資源使用情況&#xff…

LeetCode百題刷002摩爾投票法

遇到的問題都有解決的方案,希望我的博客可以為你提供一些幫助 圖片源自leetcode 題目:169. 多數元素 - 力扣(LeetCode) 一、排序法 題目要求需要找到多數值(元素個數>n/2)并返回這個值。一般會想到先…

Android Studio Gradle 中 只顯示 Tasks 中沒有 build 選項解決辦法

一、問題描述 想把項目中某一個模塊的代碼單獨打包成 aar ,之前是點擊 AndroidStudio 右側的 Gradle 選項,然后再點擊需要打包的模塊找到 build 進行打包,但是卻發現沒有 build 選項。 二、解決辦法 1、設置中勾選 Configure all Gradle tasks… 選項 …

深入淺出之STL源碼分析2_stl與標準庫,編譯器的關系

引言 在第一篇博客中,深入淺出之STL源碼分析1_vector基本操作-CSDN博客 我們將引出下面的幾個問題 1.剛才我提到了我的編譯器版本是g 11.4.0,而我們要講解的是STL(標準模板庫),那么二者之間的關系是什么?…

(十二)深入了解AVFoundation-采集:人臉識別與元數據處理

(一)深入了解AVFoundation:框架概述與核心模塊解析-CSDN博客 (二) 深入了解AVFoundation - 播放:AVFoundation 播放基礎入門-CSDN博客 (三)深入了解AVFoundation-播放&#xff1…