PyTorch 與 TensorFlow 中基于自定義層的 DNN 實現對比

深度學習雙雄對決:PyTorch vs TensorFlow 自定義層大比拼



一、TensorFlow 實現 DNN

1. 核心邏輯

  • 直接繼承 tf.keras.layers.Layer:無需中間類,直接在 build 中定義多層結構。
  • 動態參數管理:通過 add_weight 注冊每一層的權重和偏置。
import tensorflow as tfclass CustomDNNLayer(tf.keras.layers.Layer):def __init__(self, hidden_units, output_dim, **kwargs):super(CustomDNNLayer, self).__init__(**kwargs)self.hidden_units = hidden_unitsself.output_dim = output_dimdef build(self, input_shape):# 輸入層到第一個隱藏層self.w1 = self.add_weight(name='w1', shape=(input_shape[-1], self.hidden_units[0]),initializer='random_normal',trainable=True)self.b1 = self.add_weight(name='b1',shape=(self.hidden_units[0],),initializer='zeros',trainable=True)# 隱藏層之間self.ws = []self.bs = []for i in range(len(self.hidden_units) - 1):self.ws.append(self.add_weight(name=f'w{i+2}', shape=(self.hidden_units[i], self.hidden_units[i+1]),initializer='random_normal',trainable=True))self.bs.append(self.add_weight(name=f'b{i+2}',shape=(self.hidden_units[i+1],),initializer='zeros',trainable=True))# 輸出層self.wo = self.add_weight(name='wo',shape=(self.hidden_units[-1], self.output_dim),initializer='random_normal',trainable=True)self.bo = self.add_weight(name='bo',shape=(self.output_dim,),initializer='zeros',trainable=True)def call(self, inputs):x = tf.matmul(inputs, self.w1) + self.b1x = tf.nn.relu(x)for i in range(len(self.hidden_units) - 1):x = tf.matmul(x, self.ws[i]) + self.bs[i]x = tf.nn.relu(x)x = tf.matmul(x, self.wo) + self.boreturn x

二、PyTorch 實現自定義層

1. 核心邏輯

  • 繼承 nn.Module:自定義層本質是模塊的組合。
  • 使用 nn.ModuleList:動態管理多個 nn.Linear 層。
import torch
import torch.nn as nnclass CustomPyTorchDNN(nn.Module):def __init__(self, input_size, hidden_sizes, output_size):super(CustomPyTorchDNN, self).__init__()self.hidden_layers = nn.ModuleList()prev_size = input_size# 動態添加隱藏層for hidden_size in hidden_sizes:self.hidden_layers.append(nn.Linear(prev_size, hidden_size))prev_size = hidden_size# 輸出層self.output_layer = nn.Linear(prev_size, output_size)def forward(self, x):for layer in self.hidden_layers:x = torch.relu(layer(x))x = self.output_layer(x)return x

三、關鍵差異對比

維度TensorFlow 實現PyTorch 實現
類繼承方式直接繼承 tf.keras.layers.Layer,無中間類。繼承 nn.Module,通過 nn.ModuleList 管理子模塊。
參數管理build 中顯式注冊每層權重(add_weight)。自動注冊所有 nn.Linear 參數(無需手動操作)。
前向傳播定義通過 call 方法逐層計算,需手動處理每層的權重和激活函數。通過 forward 方法逐層調用 nn.Linear,激活函數手動插入。
靈活性更底層,適合完全自定義邏輯(如非線性變換、特殊參數初始化)。更簡潔,適合快速構建標準網絡結構。
訓練流程需手動實現訓練循環(反向傳播 + 優化器)。需手動實現訓練循環(與 TensorFlow 類似)。

四、總結

  • TensorFlow:通過直接繼承 tf.keras.layers.Layer,可實現完全自定義的 DNN,但需手動管理多層權重和激活邏輯,適合對模型細節有嚴格控制需求的場景。
  • PyTorch:通過直接繼承 nn.Module,可實現完全自定義的 DNN;利用 nn.ModuleListnn.Linear 的組合,能高效構建標準 DNN 結構,代碼簡潔且易于擴展,適合快速原型開發和研究場景。

兩種實現均滿足用戶對“直接繼承核心類 + 使用基礎組件”的要求,可根據具體任務選擇框架。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/79688.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/79688.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/79688.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

1ms城市算網穩步啟航,引領數字領域的“1小時經濟圈”效應

文 | 智能相對論 作者 | 陳選濱 為什么近年來國產動畫、國產3A大作迎來了井噴式爆發?拋開制作水平以及市場需求的升級不談,還有一個重要原因往往被大多數人所忽視,那就是新型信息的完善與成熟。 譬如,現階段驚艷用戶的云游戲以及…

【計算機視覺】語義分割:Segment Anything (SAM):通用圖像分割的范式革命

Segment Anything:通用圖像分割的范式革命 技術突破與架構創新核心設計理念關鍵技術組件 環境配置與快速開始硬件要求安裝步驟基礎使用示例 深度功能解析1. 多模態提示融合2. 全圖分割生成3. 高分辨率處理 模型微調與定制1. 自定義數據集準備2. 微調訓練配置 常見問…

機器學習例題——預測facebook簽到位置(K近鄰算法)和葡萄酒質量預測(線性回歸)

一、預測facebook簽到位置 代碼展示: import pandas as pd from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import…

對ubuntu的簡單介紹

目錄 1. 簡介 2. 核心特點 3. 系統架構與技術亮點 4. 適用場景 5. 優缺點分析 6. 安裝與配置建議 7. 未來發展方向 總結 1. 簡介 Ubuntu 是基于 Debian 的開源 Linux 操作系統,由 Canonical 公司(創始人 Mark Shuttleworth)提供商業支…

多商戶電商系統整套源碼開源,支持二次開發,構建多店鋪高效聯動運營方案

在數字化浪潮席卷全球的今天,電商行業競爭愈發激烈,多商戶電商平臺憑借其獨特的生態優勢,成為眾多企業和創業者的熱門選擇。一套優質的多商戶電商系統不僅能為商家提供穩定的銷售渠道,還能為平臺運營者創造巨大的商業價值。分享一…

Qwen3與Deepseek R1對比(截止20250506)

Qwen3和DeepSeek R1都是在AI領域內備受關注的大規模語言模型。根據最近的評測和報道,以下是Qwen3與DeepSeek R1的一些對比要點: 全面性能: Qwen3被描述為在數學、推理、代碼等核心能力上全面超越了DeepSeek R1。特別是在編程能力方面&#x…

Linux56 YUM源配置

epel未啟動 顯示系統未通過注冊 配置YUM倉庫 本地YUM倉庫 1.備份 tar -zcf repo.tar.gz *.repo 2.掛載 mount -o ro /dev/sr0 /mnt 3.開機自啟 chmod x /etc/rc.local echo ‘mount -o ro /dec/sr0 /mnt’ /etc/rc.local 4.編寫本地YUM倉庫 local.repo [local] namelocal yum …

二叉樹—中序遍歷—非遞歸

初始狀態 假設當前從根節點 b 開始,此時棧為空 。 第一步:處理根節點 b 的左子樹 調用 goAlongLeftBranch 函數,從節點 b 開始,因為 b 有左子樹(節點 a ),將 b 入棧,此時棧&#…

R 語言科研繪圖第 45 期 --- 桑基圖-和弦

在發表科研論文的過程中,科研繪圖是必不可少的,一張好看的圖形會是文章很大的加分項。 為了便于使用,本系列文章介紹的所有繪圖都已收錄到了 sciRplot 項目中,獲取方式: R 語言科研繪圖模板 --- sciRplothttps://mp.weixin.qq.c…

ARM 流控制指令

計算機按照嚴格的順序執行指令。流控制改變了默認的順序執行方式。前面已 經介紹了強制跳轉到程序中某個非順序位置的無條件分支。以及依據測試結果 進行跳轉的條件分支。這里將介紹子程序調用和返回指令,它們會跳轉到一個 指令塊、執行這些指令,然后返回…

PDF內容搜索--支持跨文件夾多文件、組合詞搜索

平時我們接觸到的PDF文檔特別多,需要對PDF文檔做一些處理,那么今天給大家帶來的這兩個軟件非常的棒,可以幫你提升處理文檔的效率。 PDF內容搜索 快速檢索 我用夸克網盤分享了「PDF搜索PDF 轉長圖.zip」,點擊鏈接即可保存。打開「…

個人Unity自用面經(未完)

目錄標題 1.在 2D 平臺跳躍游戲項目中,你使用了對象池來生成和回收怪物包含陣亡的動畫預制件。在對象池回收對象時,如何確保動畫狀態被正確重置,避免下次使用時出現異常?2.在僵尸吃腦子模擬項目中,你創建了繼承于IAspe…

【計網】ICMP、IP、Ethernet PDU之間的封裝關系

TCP/IP體系結構 應用層RIP、OSPF、FTP運輸層TCP、UDP網際層IP、ARP、ICMP網絡接口層底層協議(Ethernet) 數據鏈路層 Ethernet報文格式 6Byte6Byte2Byte46~1500Byte4Byte目的MAC地址源MAC地址類型/長度數據FCS 其中,類型 / 長度值小于 1536…

前端取經路——入門取經:初出師門的九個CSS修行

大家好,我是老十三,一名前端開發工程師。CSS就像前端修行路上的第一道關卡,看似簡單,實則暗藏玄機。在今天的文章中,我將帶你一起應對九大CSS難題,從Flexbox布局到響應式設計,從選擇器優先級到B…

n8n工作流自動化平臺的實操:Cannot find module ‘iconv-lite‘

解決問題: 1.在可視化界面,執行const iconv require(iconv-lite);,報Cannot find module iconv-lite [line 2]錯誤; 查看module的路徑 進入docker容器 #docker exec -it n8n /bin/sh 構建一個test.js,并寫入如何代碼 vi tes…

如何在 PowerEdge 服務器上設置 NIC 分組

以下文章提供了有關 Windows、VMware 和 Linux 中的 NIC 分組的信息。 什么是網絡適配器分組?設置 NIC 分組 Windows設置 NIC 分組 VMware設置 NIC 分組 Linux 什么是網絡適配器分組(綁定)? 網絡適配器分組是一個術語&#xff0…

【Java ee初階】多線程(5)

一、wait 和 notify wait notify 是兩個用來協調線程執行順序的關鍵字,用來避免“線程餓死”的情況。 wait 和 notify 其實都是 Object 這個類的方法,而 Object這個類是所有類的“祖宗類”,也就是說明,任何一個類,都…

基于k8s的Jenkins CI/CD平臺部署實踐(二):流水線構建與自動部署全流程

基于k8s的Jenkins CI/CD平臺部署實踐(二):流水線構建與自動部署全流程 文章目錄 基于k8s的Jenkins CI/CD平臺部署實踐(二):流水線構建與自動部署全流程一、Jenkins簡介二、系統架構與環境說明1. 系統架構2.…

《Windows 環境下 Qt C++ 項目升級 GCC 版本的完整指南》

Windows 環境下 Qt C++ 項目升級 GCC 版本的完整指南 在 Windows 系統中升級 Qt C++ 項目的 GCC 版本需要同時考慮 Qt 工具鏈、MinGW 環境以及項目配置的調整。以下是詳細的升級步驟和注意事項: 一、升級前的準備工作 1. 確認當前環境 檢查 Qt 版本(建議使用 Qt 5.15+ 以獲…

【coze】故事卡片(圖片、音頻、文字)

【coze】故事卡片(圖片、音頻、文字) 1、創建智能體2、添加人設與回復邏輯3、添加工作流(1)創建工作流(2)添加大模型節點(3)添加提示詞優化節點(4)添加豆包圖…