理解深度學習pytorch框架中的線性層

在神經網絡或機器學習的線性層(Linear Layer / Fully Connected Layer)中,經常會見到兩種形式的公式:

  • 數學文獻或傳統線性代數寫法: y = W x + b \displaystyle y = W\,x + b y=Wx+b
  • 一些深度學習代碼中寫法: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

初次接觸時,很多人會覺得兩者“方向”不太一樣,不知該如何對照理解;再加上矩陣維度 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features) ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features) 的各種寫法常常讓人疑惑不已。本文將從數學角度和編程實現角度剖析它們的關系,并結合實際示例指出一些常見的坑與需要特別留意的下標對應問題。

1. 數學角度: y = W x + b \displaystyle y = W\,x + b y=Wx+b

在線性代數中,如果我們假設輸入 x x x 是一個列向量,通常會寫作 x ∈ R ( in_features ) \displaystyle x\in\mathbb{R}^{(\text{in\_features})} xR(in_features)(或者在更嚴格的矩陣形狀記法下寫作 ( in_features , 1 ) (\text{in\_features},\,1) (in_features,1))。那么一個最常見的全連接層可以表示為:

y = W x + b , y = W\,x + b, y=Wx+b,

其中:

  • W W W 是一個大小為 ( out_features , in_features ) \bigl(\text{out\_features},\,\text{in\_features}\bigr) (out_features,in_features) 的矩陣;
  • b b b 是一個 out_features \text{out\_features} out_features-維的偏置向量(形狀 ( out_features , 1 ) (\text{out\_features},\,1) (out_features,1));
  • y y y 則是輸出向量,大小為 out_features \text{out\_features} out_features

示例

假設 in_features = 3 \text{in\_features}=3 in_features=3 out_features = 2 \text{out\_features}=2 out_features=2。那么:
W ∈ R 2 × 3 , x ∈ R 3 × 1 , b ∈ R 2 × 1 . W \in \mathbb{R}^{2\times 3},\quad x \in \mathbb{R}^{3\times 1},\quad b \in \mathbb{R}^{2\times 1}. WR2×3,xR3×1,bR2×1.

矩陣寫開來就是:

W = [ w 11 w 12 w 13 w 21 w 22 w 23 ] , x = [ x 1 x 2 x 3 ] , b = [ b 1 b 2 ] . W = \begin{bmatrix} w_{11} & w_{12} & w_{13} \\[5pt] w_{21} & w_{22} & w_{23} \end{bmatrix},\quad x = \begin{bmatrix} x_{1}\\ x_{2}\\ x_{3} \end{bmatrix},\quad b = \begin{bmatrix} b_{1}\\ b_{2} \end{bmatrix}. W=[w11?w21??w12?w22??w13?w23??],x= ?x1?x2?x3?? ?,b=[b1?b2??].

那么線性變換結果 W x + b Wx + b Wx+b 可以展開為:

W x + b = [ w 11 x 1 + w 12 x 2 + w 13 x 3 w 21 x 1 + w 22 x 2 + w 23 x 3 ] + [ b 1 b 2 ] = [ w 11 x 1 + w 12 x 2 + w 13 x 3 + b 1 w 21 x 1 + w 22 x 2 + w 23 x 3 + b 2 ] . \begin{aligned} Wx + b &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix} \\ &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 + b_1 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 + b_2 \end{bmatrix}. \end{aligned} Wx+b?=[w11?x1?+w12?x2?+w13?x3?w21?x1?+w22?x2?+w23?x3??]+[b1?b2??]=[w11?x1?+w12?x2?+w13?x3?+b1?w21?x1?+w22?x2?+w23?x3?+b2??].?

這就是最傳統、在數學文獻或線性代數課程中最常見的表示方法。


2. 編程實現角度: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

在實際的深度學習代碼(例如 PyTorch、TensorFlow)中,經常看到的卻是下面這種寫法:

y = x @ W.T + b

注意這里 W.shape 通常被定義為 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features),而 x.shape 在批量處理時則是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)。于是 (x @ W.T) 的結果是 ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)

為什么會出現轉置?
因為在數學里我們通常把 x x x 當作“列向量”放在右邊,于是公式變成 y = W x + b y = Wx + b y=Wx+b
但在編程里,尤其是處理批量輸入時,x 常寫成“行向量”的形式 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features),這就造成了在進行矩陣乘法時,需要將 W(大小 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features))轉置成 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features),才能滿足「行×列」的匹配關系。

從結果上來看,

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) . (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}). (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features).

所以,在代碼里就寫成 x @ W.T,再加上偏置 b(通常會廣播到 batch_size \text{batch\_size} batch_size 那個維度)。

本質上這和數學公式里 y = W x + b y = W\,x + b y=Wx+b 并無沖突,只是一個“列向量”和“行向量”的轉置關系。只要搞清楚最終你想讓輸出 y y y 的 shape 是多少,就能明白在代碼里為什么要寫 .T


3. 常見錯誤與易混點解析

有些教程或文檔,會不小心寫成:“如果我們有一個形狀為 ( in_features , out_features ) (\text{in\_features},\text{out\_features}) (in_features,out_features) 的權重矩陣 W W W……”——然后又要做 W x Wx Wx,想得到一個 out_features \text{out\_features} out_features-維的結果。但按照線性代數的常規寫法,行數必須和輸出維度匹配、列數必須和輸入維度匹配。所以 正確 的說法應該是

W ∈ R ( out_features ) × ( in_features ) . W\in\mathbb{R}^{(\text{out\_features}) \times (\text{in\_features})}. WR(out_features)×(in_features).

否則從矩陣乘法次序來看就對不上。
但這又可能讓人迷惑:為什么深度學習框架 torch.nn.Linear(in_features, out_features) 卻給出 weight.shape == (out_features, in_features) 其實正是同一個道理,它和上面“數學文獻里”用到的 W W W 形狀完全一致。


4. 小結

  1. 從數學角度
    最傳統的記號是
    y = W x + b , W ∈ R ( out_features ) × ( in_features ) , x ∈ R ( in_features ) , y ∈ R ( out_features ) . y = W\,x + b, \quad W \in \mathbb{R}^{(\text{out\_features})\times(\text{in\_features})},\, x \in \mathbb{R}^{(\text{in\_features})},\, y \in \mathbb{R}^{(\text{out\_features})}. y=Wx+b,WR(out_features)×(in_features),xR(in_features),yR(out_features).

  2. 從深度學習代碼角度

    • 由于批量數據常被視為行向量,每一行代表一個樣本特征,因此形狀通常是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)
    • 對應的權重 W 定義為 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features)。為了完成行乘以列的矩陣運算,需要對 W 做轉置:
      y = x @ W.T + b
      
    • 得到的 y.shape ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)
  3. 避免踩坑

    • 寫公式時,仔細確認 in_features \text{in\_features} in_features out_features \text{out\_features} out_features 的位置以及矩陣行列順序。
    • 編程實踐中理解“為什么要 .T”非常重要:那只是為了匹配「行×列」的矩陣乘法規則,本質上還是和 y = W x + b y = Wx + b y=Wx+b 相同。

通過理解并區分“列向量”與“行向量”的不同慣例,避免因為矩陣維度或轉置不當而導致莫名其妙的錯誤或 bug。


參考鏈接

  • PyTorch 文檔:torch.nn.Linear
  • 深度學習中的矩陣運算初步 —— batch_size 與矩陣乘法
  • 常見線性代數符號:行向量與列向量

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/66895.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/66895.shtml
英文地址,請注明出處:http://en.pswp.cn/web/66895.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C#Object類型的索引,序列化和反序列化

前言 最近在編寫一篇關于標準Mes接口框架的文章。其中有一個非常需要考究的內容時如果實現數據靈活和可使用性強。因為考慮數據靈活性,所以我一開始選取了Object類型作為數據類型,Object作為數據Value字段,String作為數據Key字段&#xff0c…

大模型應用與部署 技術方案

大模型應用與部署 技術方案 一、引言 人工智能蓬勃發展,Qwen 大模型在自然語言處理領域地位關鍵,其架構優勢盡顯,能處理文本創作等多類復雜任務,提供優質交互。Milvus 向量數據庫則是向量數據存儲檢索利器,有高效索引算法(如 IVF_FLAT、HNSWLIB 等)助力大規模數據集相似…

【Prometheus】Prometheus如何監控Haproxy

?? 歡迎大家來到景天科技苑?? 🎈🎈 養成好習慣,先贊后看哦~🎈🎈 🏆 作者簡介:景天科技苑 🏆《頭銜》:大廠架構師,華為云開發者社區專家博主,…

C# 控制打印機:從入門到實踐

在開發一些涉及打印功能的應用程序時,使用 C# 控制打印機是一項很實用的技能。這篇文章就來詳細介紹下如何在 C# 中實現對打印機的控制。 一、準備工作 安裝相關庫:在 C# 中操作打印機,我們可以借助System.Drawing.Printing命名空間&#x…

Go語言中的值類型和引用類型特點

一、值類型 值類型的數據直接包含值,當它們被賦值給一個新的變量或者作為參數傳遞給函數時,實際上是創建了原值的一個副本。這意味著對新變量的修改不會影響原始變量的值。 Go中的值類型包括: 基礎類型:int,float64…

GPT 結束語設計 以nanogpt為例

GPT 結束語設計 以nanogpt為例 目錄 GPT 結束語設計 以nanogpt為例 1、簡述 2、分詞設計 3、結束語斷點 1、簡述 在手搓gpt的時候,可能會遇到一些性能問題,即關于是否需要全部輸出或者怎么節約資源。 在輸出語句被max_new_tokens 限制&#xff0c…

《探秘:人工智能如何為鴻蒙Next元宇宙網絡傳輸與延遲問題破局》

在元宇宙的宏大愿景中,流暢的網絡傳輸和低延遲是保障用戶沉浸式體驗的關鍵。鴻蒙Next結合人工智能技術,為解決這些問題提供了一系列創新思路和方法。 智能網絡監測與預測 人工智能可以實時監測鴻蒙Next元宇宙中的網絡狀況,包括帶寬、延遲、…

深入MapReduce——計算模型設計

引入 通過引入篇,我們可以總結,MapReduce針對海量數據計算核心痛點的解法如下: 統一編程模型,降低用戶使用門檻分而治之,利用了并行處理提高計算效率移動計算,減少硬件瓶頸的限制 優秀的設計&#xff0c…

macOS安裝Gradle環境

文章目錄 說明安裝JDK安裝Gradle 說明 gradle8.5最高支持jdk21,如果使用jdk22建議使用gradle8.8以上版本 安裝JDK mac系統安裝最新(截止2024.9.13)Oracle JDK操作記錄 安裝Gradle 下載Gradle,解壓將其存放到資源java/env目錄…

五國十五校聯合巨獻!仿人機器人運動與操控:控制、規劃與學習的最新突破與挑戰

作者: Zhaoyuan Gu, Junheng Li, Wenlan Shen, Wenhao Yu, Zhaoming Xie, Stephen McCrory, Xianyi Cheng, Abdulaziz Shamsah, Robert Griffin, C. Karen Liu, Abderrahmane Kheddar, Xue Bin Peng, Yuke Zhu, Guanya Shi, Quan Nguyen, Gordon Cheng, Huijun Gao,…

CVPR 2024 無人機/遙感/衛星圖像方向總匯(航空圖像和交叉視角定位)

1、UAV、Remote Sensing、Satellite Image(無人機/遙感/衛星圖像) Unleashing Unlabeled Data: A Paradigm for Cross-View Geo-Localization ?codeRethinking Transformers Pre-training for Multi-Spectral Satellite Imagery ?codeAerial Lifting: Neural Urban Semantic …

【BQ3568HM開發板】如何在OpenHarmony上通過校園網的上網認證

引言 前面已經對BQ3568HM開發板進行了初步測試,后面我要實現MQTT的工作,但是遇到一個問題,就是開發板無法通過校園網的認證操作。未認證的話會,學校使用的深瀾軟件系統會屏蔽所有除了認證用的流量。好在我們學校使用的認證系統和…

(Java版本)基于JAVA的網絡通訊系統設計與實現-畢業設計

源碼 論文 下載地址: ????c??????c基于JAVA的網絡通訊系統設計與實現(源碼系統論文)https://download.csdn.net/download/weixin_39682092/90299782https://download.csdn.net/download/weixin_39682092/90299782 第1章 緒論 1.1 課題選擇的…

kafka學習筆記4-TLS加密 —— 筑夢之路

1. 準備證書文件 mkdir /opt/kafka/pkicd !$# 生成CA證書 openssl req -x509 -nodes -days 3650 -newkey rsa:4096 -keyout ca.key -out ca.crt -subj "/CNKafka-CA"# 生成私鑰 openssl genrsa -out kafka.key 4096# 生成證書簽名請求 (CSR) openssl req -new -key …

Node.js NativeAddon 構建工具:node-gyp 安裝與配置完全指南

Node.js NativeAddon 構建工具:node-gyp 安裝與配置完全指南 node-gyp Node.js native addon build tool [這里是圖片001] 項目地址: https://gitcode.com/gh_mirrors/no/node-gyp 項目基礎介紹及主要編程語言 Node.js NativeAddon 構建工具(node-gyp…

SpringCloud微服務Gateway網關簡單集成Sentinel

Sentinel是阿里巴巴開源的一款面向分布式服務架構的輕量級流量控制、熔斷降級組件。Sentinel以流量為切入點,從流量控制、熔斷降級、系統負載保護等多個維度來幫助保護服務的穩定性。 官方文檔:https://sentinelguard.io/zh-cn/docs/introduction.html …

vscode環境中用倉頡語言開發時調出覆蓋率的方法

在vscode中倉頡語言想得到在idea中利用junit和jacoco的覆蓋率,需要如下幾個步驟: 1.在vscode中搭建倉頡語言開發環境; 2.在源代碼中右鍵運行[cangjie]coverage. 思路1:編寫了測試代碼的情況(包管理工具) …

pikachu靶場-敏感信息泄露概述

敏感信息泄露概述 由于后臺人員的疏忽或者不當的設計,導致不應該被前端用戶看到的數據被輕易的訪問到。 比如: ---通過訪問url下的目錄,可以直接列出目錄下的文件列表; ---輸入錯誤的url參數后報錯信息里面包含操作系統、中間件、開發語言的版…

安卓動態設置Unity圖形API

命令行方式 Unity圖像api設置為自動,安卓動態設置Vulkan、OpenGLES Unity設置 安卓設置 創建自定義活動并將其設置為應用程序入口點。 在自定義活動中,覆蓋字符串UnityPlayerActivity。updateunitycommandlineararguments (String cmdLine)方法。 在該方法中,將cmdLine…

CICD集合(五):Jenkins+Git+Allure實戰(自動化測試)

CICD集合(五):Jenkins+Git+Allure實戰(自動化測試) 前提: 已安裝好Jenkins安裝好git,maven,allure報告插件配置好Git,Maven,allure參考:CICD集合(一至四) https://blog.csdn.net/fen_fen/article/details/131476093 https://blog.csdn.net/fen_fen/article/details/1213…