RNN的理解

對于RNN的理解

import torch
import torch.nn as nn
import torch.nn.functional as F# 手動實現一個簡單的RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()# 定義權重矩陣和偏置項self.hidden_size = hidden_sizeself.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))  # 輸入到隱藏層的權重#

#注:
input_size = 4
hidden_size = 3
W_xh = torch.randn(input_size, hidden_size)
生成的 W_xh 會是一個形狀為 (4, 3) 的張量,可能是這樣的(數字是隨機生成的):
tensor([[ 0.2973, -1.1254, 0.7172],
[ 0.0983, 0.2856, -0.4586],
[-0.0105, 0.2317, 0.2716],
[ 1.0431, -1.3894, -0.1525]])
這個張量有 4 行 3 列。

        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))  # 隱藏層到隱藏層的權重self.b_h = nn.Parameter(torch.zeros(hidden_size))  # 隱藏層偏置self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))  # 隱藏層到輸出層的權重self.b_y = nn.Parameter(torch.zeros(output_size))  # 輸出層偏置def forward(self, x):# 初始化隱藏狀態為0h_t = torch.zeros(x.size(0), self.hidden_size)  # 初始隱藏狀態 [[5]]

注:
x 是輸入數據,形狀是 (3, 5, 4),其中:
3 是批量大小(batch_size),即我們一次性輸入網絡的樣本數是 3。
5 是序列長度(seq_len),每個樣本有 5 個時間步。
4 是每個時間步的輸入特征數量。
self.hidden_size 假設是 6,表示隱藏層的維度是 6。
x.size(0) 獲取輸入張量 x 的第一個維度的大小,也就是批量大小 3。
torch.zeros(3, 6) 會創建一個形狀為 (3, 6) 的張量,表示有 3 個樣本,每個樣本有 6 個隱藏狀態神經元(即隱狀態的維度是 6)。所有的元素都初始化為 0。

    # 遍歷時間步,逐個處理輸入序列for t in range(x.size(1)):  # x.size(1) 是序列長度x_t = x[:, t, :]  # 獲取當前時間步的輸入 (batch_size, input_size)

`
注:x = torch.tensor([[[0.1, 0.2, 0.3, 0.4], # 第 0 時間步的輸入 (第一個樣本)
[0.5, 0.6, 0.7, 0.8], # 第 1 時間步的輸入 (第一個樣本)
[0.9, 1.0, 1.1, 1.2]], # 第 2 時間步的輸入 (第一個樣本)
[[1.3, 1.4, 1.5, 1.6], # 第 0 時間步的輸入 (第二個樣本)
[1.7, 1.8, 1.9, 2.0], # 第 1 時間步的輸入 (第二個樣本)
[2.1, 2.2, 2.3, 2.4]]]) # 第 2 時間步的輸入 (第二個樣本)

第一次循環 t=0:
x_t = x[:, 0, :]
x[:, 0, :] 會提取出所有樣本在第 0 時間步的輸入:

第一個樣本在第 0 時間步的輸入是 [0.1, 0.2, 0.3, 0.4]。

第二個樣本在第 0 時間步的輸入是 [1.3, 1.4, 1.5, 1.6]。

因此,x_t 的值是:

tensor([[0.1, 0.2, 0.3, 0.4],
[1.3, 1.4, 1.5, 1.6]])

        # 更新隱藏狀態:h_t = tanh(W_xh * x_t + W_hh * h_t + b_h)h_t = torch.tanh(x_t @ self.W_xh + h_t @ self.W_hh + self.b_h)  # [[4]]

`
注:1. 計算 x_t @ W_xh
x_t @ W_xh 是輸入 x_t 和權重矩陣 W_xh 的矩陣乘法。我們有 2 個樣本,每個樣本有 3 個輸入特征,權重矩陣 W_xh 的形狀是 (3, 4),所以乘法的結果是一個形狀為 (2, 4) 的張量,即每個樣本的隱藏狀態更新的部分。

對于第一個樣本:

[0.5, 0.6, 0.7] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
我們可以計算它的結果:

= [0.5 * 0.1 + 0.6 * 0.3 + 0.7 * 0.7,
0.5 * 0.2 + 0.6 * 0.5 + 0.7 * (-0.1),
0.5 * -0.1 + 0.6 * 0.2 + 0.7 * 0.3,
0.5 * 0.4 + 0.6 * (-0.2) + 0.7 * 0.5]

= [0.05 + 0.18 + 0.49,
0.1 + 0.3 - 0.07,
-0.05 + 0.12 + 0.21,
0.2 - 0.12 + 0.35]

= [0.72, 0.33, 0.28, 0.43]
對于第二個樣本:

[1.0, 1.2, 1.3] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
計算結果:

= [1.0 * 0.1 + 1.2 * 0.3 + 1.3 * 0.7,
1.0 * 0.2 + 1.2 * 0.5 + 1.3 * (-0.1),
1.0 * -0.1 + 1.2 * 0.2 + 1.3 * 0.3,
1.0 * 0.4 + 1.2 * (-0.2) + 1.3 * 0.5]

= [0.1 + 0.36 + 0.91,
0.2 + 0.6 - 0.13,
-0.1 + 0.24 + 0.39,
0.4 - 0.24 + 0.65]

= [1.37, 0.67, 0.53, 0.81]
因此,x_t @ W_xh 的結果是:

tensor([[0.72, 0.33, 0.28, 0.43],
[1.37, 0.67, 0.53, 0.81]])

x_t @ self.W_xh:
x_t 是當前時間步的輸入,形狀是 (batch_size, input_size)。
self.W_xh 是輸入到隱藏層的權重矩陣,形狀是 (input_size, hidden_size)。
h_t @ self.W_hh:
h_t 是前一時間步的隱藏狀態,形狀是 (batch_size, hidden_size)。
self.W_hh 是隱藏層到隱藏層的權重矩陣,形狀是 (hidden_size, hidden_size)。

    # 最后一個時間步的隱藏狀態通過全連接層得到輸出y_t = h_t @ self.W_hy + self.b_y  # 輸出層return y_t

超參數設置

input_size = 10 # 輸入特征維度
hidden_size = 20 # 隱藏層維度
output_size = 5 # 輸出類別數
seq_length = 5 # 序列長度
batch_size = 3 # 批量大小

實例化模型

model = RNN(input_size, hidden_size, output_size)

打印模型結構

print(model)

創建隨機輸入數據 (batch_size, seq_length, input_size)

x = torch.randn(batch_size, seq_length, input_size)

前向傳播

output = model(x)
print(“Output shape:”, output.shape) # 輸出形狀應為 (batch_size, output_size)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902269.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902269.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902269.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二叉查找樹和B樹

二叉查找樹(Binary Search Tree, BST)和 B 樹(B-tree)都是用于組織和管理數據的數據結構,但它們在結構、應用場景和性能方面有顯著區別。 二叉查找樹(Binary Search Tree, BST) 特點&#xff1…

一段式端到端自動駕駛:VAD:Vectorized Scene Representation for Efficient Autonomous Driving

論文地址:https://github.com/hustvl/VAD 代碼地址:https://arxiv.org/pdf/2303.12077 1. 摘要 自動駕駛需要對周圍環境進行全面理解,以實現可靠的軌跡規劃。以往的方法依賴于密集的柵格化場景表示(如:占據圖、語義…

OpenCV訓練題

一、創建一個 PyQt 應用程序,該應用程序能夠: 使用 OpenCV 加載一張圖像。在 PyQt 的窗口中顯示這張圖像。提供四個按鈕(QPushButton): 一個用于將圖像轉換為灰度圖一個用于將圖像恢復為原始彩色圖一個用于將圖像進行…

opencv函數展示4

一、形態學操作函數 1.基本形態學操作 (1)cv2.getStructuringElement() (2)cv2.erode() (3)cv2.dilate() 2.高級形態學操作 (1)cv2.morphologyEx() 二、直方圖處理函數 1.直方圖…

iPhone 13P 換超容電池,一年實記的“電池循環次數-容量“柱狀圖

繼上一篇 iPhone 13P 更換"移植電芯"和"超容電池"🔋體驗,詳細記錄了如何更換這兩種電池,以及各自的優略勢對比。 一晃一年過去,時間真快,這次分享下記錄了使用超容電池的 “循環次數 - 容量(mAh)…

基于 pnpm + Monorepo + Turbo + 無界微前端 + Vite 的企業級前端工程實踐

基于 pnpm Monorepo Turbo 無界微前端 Vite 的企業級前端工程實踐 一、技術演進:為什么引入 Vite? 在微前端與 Monorepo 架構落地后,構建性能成為新的優化重點: Webpack 構建瓶頸:復雜配置導致開發啟動慢&#…

(五)機器學習---決策樹和隨機森林

在分類問題中還有一個常用算法:就是決策樹。本文將會對決策樹和隨機森林進行介紹。 目錄 一.決策樹的基本原理 (1)決策樹 (2)決策樹的構建過程 (3)決策樹特征選擇 (4&#xff0…

Vue3使用AntvG6寫拓撲圖,可添加修改刪除節點和邊

npm安裝antv/g6 npm install antv/g6 --save 上代碼 <template><div id"tpt1" ref"container" style"width: 100%;height: 100%;"></div> </template><script setup>import { Renderer as SVGRenderer } from …

Arduino編譯和燒錄STM32——基于J-link SWD模式

一、安裝Stm32 Arduino支持 在arduino中添加stm32的開發板地址&#xff1a;https://github.com/stm32duino/BoardManagerFiles/raw/main/package_stmicroelectronics_index.json 安裝stm32開發板支持 二、安裝STM32CubeProgrammer 從stm32網站中安裝&#xff1a;https://ww…

智慧城市氣象中臺架構:多源天氣API網關聚合方案

在開發與天氣相關的應用時&#xff0c;獲取準確的天氣信息是一個關鍵需求。萬維易源提供的“天氣預報查詢”API為開發者提供了一個高效、便捷的工具&#xff0c;可以通過簡單的接口調用查詢全國范圍內的天氣信息。本文將詳細介紹如何使用該API&#xff0c;以及其核心功能和調用…

Vue 組件化開發

引言 在當今的 Web 開發領域&#xff0c;構建一個功能豐富且用戶體驗良好的博客是許多開發者的目標。Vue.js 作為一款輕量級且高效的 JavaScript 框架&#xff0c;其組件化開發的特性為我們提供了一種優雅的解決方案。通過將博客拆分成多個獨立的組件&#xff0c;我們可以提高代…

Deno 統一 Node 和 npm,既是 JS 運行時,又是包管理器

Deno 是一個現代的、一體化的、零配置的 JavaScript 運行時、工具鏈&#xff0c;專為 JavaScript 和 TypeScript 開發設計。目前已有數十萬開發者在使用 Deno&#xff0c;其代碼倉庫是 GitHub 上 star 數第二高的 Rust 項目。 Stars 數102620Forks 數5553 主要特點 內置安全性…

應用篇02-鏡頭標定(上)

本節主要介紹相機的標定方法&#xff0c;包括其內、外參數的求解&#xff0c;以及如何使用HALCON標定助手實現標定。 計算機視覺——相機標定(Camera Calibration)_攝像機標定-CSDN博客 1. 原理 本節介紹與相機標定相關的理論知識&#xff0c;不一定全&#xff0c;可以參考相…

PG CTE 遞歸 SQL 翻譯為 達夢版本

文章目錄 PG SQLDM SQL總結 PG SQL with recursive result as (select res_id,phy_res_code,res_name from tbl_res where parent_res_id (select res_id from tbl_res where phy_res_code org96000#20211203155858) and res_type_id 1 union all select t1.res_id, t1.p…

C# Where 泛型約束

在C#中&#xff0c;Where關鍵字主要有兩種用途 1、在泛型約束中限制類型參數 2、在LINQ查詢中篩選數據 本文主要介紹where關鍵字在在泛型約束中的使用 泛型定義中的 where 子句指定對用作泛型類型、方法、委托或本地函數中類型參數的參數類型的約束。通過使用 where 關鍵字和…

《MySQL:MySQL表的約束-主鍵/復合主鍵/唯一鍵/外鍵》

表的約束&#xff1a;表中一定要有各種約束&#xff0c;通過約束&#xff0c;讓未來插入數據庫表中的數據是符合預期的。約束本質是通過技術手段&#xff0c;倒逼程序員插入正確的數據。即&#xff0c;站在mysql的視角&#xff0c;凡是插入進來的數據&#xff0c;都是符合數據約…

Qt 創建QWidget的界面庫(DLL)

【1】新建一個qt庫項目 【2】在項目目錄圖標上右擊&#xff0c;選擇Add New... 【3】選擇模版&#xff1a;Qt->Qt設計師界面類&#xff0c;選擇Widget&#xff0c;填寫界面類的名稱、.h .cpp .ui名稱 【4】創建C調用接口&#xff08;默認是創建C調用接口&#xff09; #ifnd…

汽車免拆診斷案例 | 2011款雪鐵龍世嘉車刮水器偶爾自動工作

故障現象 一輛2011款雪鐵龍世嘉車&#xff0c;搭載1.6 L 發動機&#xff0c;累計行駛里程約為19.8萬km。車主反映&#xff0c;該車刮水器偶爾會自動工作&#xff0c;且前照燈偶爾會自動點亮。 故障診斷 接車后試車發現&#xff0c;除了上述故障現象以外&#xff0c;當用遙控器…

【Linux】NAT、代理服務、內網穿透

NAT、代理服務、內網穿透 一. NAT1. NAT 技術2. NAT IP 轉換過程3. NAPT 技術4. NAT 技術的缺陷 二. 代理服務器1. 正向代理2. 反向代理3. NAT 和代理服務器 內網穿透內網打洞 一. NAT NAT&#xff08;Network Address Translation&#xff0c;網絡地址轉換&#xff09;技術&a…

MobaXterm連接Ubuntu(SSH)

1.查看Ubuntu ip 打開終端,使用指令 ifconfig 由圖可知ip地址 2.MobaXterm進行SSH連接 點擊session,然后點擊ssh,最后輸入ubuntu IP地址以及用戶名