深度學習Dropout實現

深度學習中的 Dropout 技術在代碼層面上的實現通常非常直接。其核心思想是在訓練過程中,對于網絡中的每個神經元(或者更精確地說,是每個神經元的輸出),以一定的概率 p 隨機將其輸出置為 0。在反向傳播時,這些被“drop out”的神經元也不會參與梯度更新。

以下是 Dropout 在代碼層面上的一個基本實現邏輯,以 Python 和 NumPy 為例進行說明,然后再展示在常見的深度學習框架(如 TensorFlow 和 PyTorch)中的實現方式。

1. NumPy 實現(概念演示)

假設我們有一個神經網絡的某一層輸出 activation,它是一個形狀為 (batch_size, num_neurons) 的 NumPy 數組。我們可以通過以下步驟實現 Dropout:

Python

import numpy as npdef dropout_numpy(activation, keep_prob):"""使用 NumPy 實現 Dropout。Args:activation: 神經網絡層的激活輸出 (NumPy array).keep_prob: 保留神經元的概率 (float, 0 到 1 之間).Returns:經過 Dropout 處理的激活輸出 (NumPy array).mask: 用于記錄哪些神經元被 drop out 的掩碼 (NumPy array)."""if keep_prob < 0. or keep_prob > 1.:raise ValueError("keep_prob must be between 0 and 1")# 生成一個和 activation 形狀相同的隨機掩碼,元素值為 True 或 Falsemask = (np.random.rand(*activation.shape) < keep_prob)# 將掩碼應用于激活輸出,被 drop out 的神經元輸出置為 0output = activation * mask# 在訓練階段,為了保證下一層的期望輸入不變,需要對保留下來的神經元輸出進行縮放output /= keep_probreturn output, mask# 示例
batch_size = 64
num_neurons = 128
activation = np.random.randn(batch_size, num_neurons)
keep_prob = 0.8dropout_output, dropout_mask = dropout_numpy(activation, keep_prob)print("原始激活輸出的形狀:", activation.shape)
print("Dropout 后的激活輸出的形狀:", dropout_output.shape)
print("Dropout 掩碼的形狀:", dropout_mask.shape)
print("被 drop out 的神經元比例:", np.sum(dropout_mask == False) / dropout_mask.size)

代碼解釋:

  • keep_prob: 這是保留神經元的概率。Dropout 的概率通常設置為 1 - keep_prob
  • 生成掩碼 (mask): 我們使用 np.random.rand() 生成一個和輸入 activation 形狀相同的隨機數數組,其元素值在 0 到 1 之間。然后,我們將這個數組與 keep_prob 進行比較,得到一個布爾類型的掩碼。True 表示對應的神經元被保留,False 表示被 drop out。
  • 應用掩碼 (output = activation * mask): 我們將掩碼和原始的激活輸出進行逐元素相乘。由于布爾類型的 TrueFalse 在數值運算中會被轉換為 1 和 0,所以掩碼中為 False 的位置對應的激活輸出會被置為 0。
  • 縮放 (output /= keep_prob): 這是一個非常重要的步驟。在訓練階段,由于一部分神經元被隨機置為 0,為了保證下一層神經元接收到的期望輸入與沒有 Dropout 時大致相同,我們需要對保留下來的神經元的輸出進行放大。放大的倍數是 1 / keep_prob

需要注意的是,在模型的評估(或推理)階段,通常不會使用 Dropout。這意味著 keep_prob 會被設置為 1,或者 Dropout 層會被禁用。這是因為 Dropout 是一種在訓練時使用的正則化技術,用于減少過擬合。在評估時,我們希望模型的所有神經元都參與計算,以獲得最準確的預測。

2. TensorFlow 實現

在 TensorFlow 中,Dropout 是一個內置的層:

Python

import tensorflow as tf# 在 Sequential 模型中添加 Dropout 層
model = tf.keras.models.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dropout(0.2), # Dropout 概率為 0.2 (即 keep_prob 為 0.8)tf.keras.layers.Dense(10, activation='softmax')
])# 或者在函數式 API 中使用
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model_functional = tf.keras.Model(inputs=inputs, outputs=outputs)# 訓練模型
# model.compile(...)
# model.fit(...)

在 TensorFlow 的 tf.keras.layers.Dropout(rate) 層中,rate 參數指定的是神經元被 drop out 的概率。在訓練時,這個層會隨機將一部分神經元的輸出置為 0,并對剩下的神經元進行縮放。在推理時,這個層不會有任何作用。TensorFlow 內部會自動處理訓練和推理階段的行為。

3. PyTorch 實現

在 PyTorch 中,Dropout 也是一個內置的模塊:

代碼段

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.dropout = nn.Dropout(p=0.2) # Dropout 概率為 0.2self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x)x = self.softmax(x)return xmodel = Net()# 設置模型為訓練模式 (啟用 Dropout)
model.train()# 設置模型為評估模式 (禁用 Dropout)
model.eval()# 在前向傳播中使用 Dropout
# output = model(input_tensor)

在 PyTorch 的 nn.Dropout(p) 模塊中,p 參數指定的是神經元被 drop out 的概率。與 TensorFlow 類似,PyTorch 的 Dropout 在訓練模式 (model.train()) 下會啟用,隨機將神經元置零并縮放輸出。在評估模式 (model.eval()) 下,Dropout 層會失效,相當于一個恒等變換。

總結

在代碼層面上,Dropout 的實現主要涉及以下幾個步驟:

  1. 生成一個隨機的二值掩碼,其形狀與神經元的輸出相同,掩碼中每個元素以一定的概率(Dropout 概率)為 0,以另一概率(保留概率)為 1。
  2. 將這個掩碼與神經元的輸出逐元素相乘,從而將一部分神經元的輸出置為 0。
  3. 在訓練階段,對保留下來的神經元的輸出進行縮放,通常除以保留概率。
  4. 在評估階段,禁用 Dropout,即不進行掩碼操作和縮放。

現代深度學習框架已經將 Dropout 的實現封裝在專門的層或模塊中,用戶只需要指定 Dropout 的概率即可,框架會自動處理訓練和評估階段的不同行為。這大大簡化了在模型中應用 Dropout 的過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905949.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905949.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905949.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AtCoder AT_abc406_c [ABC406C] ~

前言 除了 A 題&#xff0c;唯一一道一遍過的題。 題目大意 我們定義滿足以下所有條件的一個長度為 N N N 的序列 A ( A 1 , A 2 , … , A N ) A(A_1,A_2,\dots,A_N) A(A1?,A2?,…,AN?) 為波浪序列&#xff1a; N ≥ 4 N\ge4 N≥4&#xff08;其實滿足后面就必須滿足這…

Java Web 應用安全響應頭配置全解析:從單體到微服務網關的實踐

背景&#xff1a;為什么安全響應頭至關重要&#xff1f; 在 Web 安全領域&#xff0c;響應頭&#xff08;Response Headers&#xff09;是防御 XSS、點擊劫持、跨域數據泄露等攻擊的第一道防線。通過合理配置響應頭&#xff0c;可強制瀏覽器遵循安全策略&#xff0c;限制惡意行…

如何停止終端呢?ctrl+c不管用,其他有什么方法呢?

如果你在終端中運行了一個程序&#xff08;比如 Python GUI tkinter 應用&#xff09;&#xff0c;按下 Ctrl C 沒有作用&#xff0c;一般是因為該程序&#xff1a; 運行了主事件循環&#xff08;例如 tkinter.mainloop()&#xff09; 或 在子線程中運行&#xff0c;而 Ctrl …

深入解析 React 的 useEffect:從入門到實戰

文章目錄 前言一、為什么需要 useEffect&#xff1f;核心作用&#xff1a; 二、useEffect 的基礎用法1. 基本語法2. 依賴項數組的作用 三、依賴項數組演示1. 空數組 []&#xff1a;2.無依賴項&#xff08;空&#xff09;3.有依賴項 四、清理副作用函數實戰案例演示1. 清除定時器…

Ubuntu 更改 Nginx 版本

將 1.25 降為 1.18 先卸載干凈 # 1. 完全卸載當前Nginx sudo apt purge nginx nginx-common nginx-core# 2. 清理殘留配置 sudo apt autoremove sudo rm -rf /etc/apt/sources.list.d/nginx*.list修改倉庫地址 # 添加倉庫&#xff08;通用穩定版倉庫&#xff09; codename$(…

如何在 Windows 10 或 11 中安裝 PowerShellGet 模塊?

PowerShell 是微軟在其 Windows 操作系統上提供的強大腳本語言,可用于通過命令行界面自動化各種任務,適用于 Windows 桌面或服務器環境。而 PowerShellGet 是 PowerShell 中的一個模塊,提供了用于從各種來源發現、安裝、更新和發布模塊的 cmdlet。 本文將介紹如何在 PowerS…

NBA足球賽事直播源碼體育直播M33模板賽事源碼

源碼名稱&#xff1a;體育直播賽事扁平自適應M33直播模板源碼 開發環境&#xff1a;帝國cms7.5 空間支持&#xff1a;phpmysql 帶軟件采集&#xff0c;可以掛著自動采集發布&#xff0c;無需人工操作&#xff01; 演示地址&#xff1a;NBA足球賽事直播源碼體育直播M33模板賽事…

【Python】魔法方法是真的魔法! (第二期)

還不清楚魔術方法&#xff1f; 可以看看本系列開篇&#xff1a;【Python】小子&#xff01;是魔術方法&#xff01;-CSDN博客 【Python】魔法方法是真的魔法&#xff01; &#xff08;第一期&#xff09;-CSDN博客 在 Python 中&#xff0c;如何自定義數據結構的比較邏輯&…

Qt 強大的窗口停靠浮動

1、左邊&#xff1a; 示例代碼&#xff1a; CDockManager::setConfigFlags(CDockManager::DefaultOpaqueConfig); CDockManager::setConfigFlag(CDockManager::FocusHighlighting, true); dockManager new CDockManager(this); // Disabling the Internal Style S…

Linux進程異常退出排查指南

在 Linux 中&#xff0c;如果進程無法正常終止&#xff08;如 kill 命令無效&#xff09;或異常退出&#xff0c;可以按照以下步驟排查和解決&#xff1a; 1. 常規終止進程 嘗試普通終止&#xff08;SIGTERM&#xff09; kill PID # 發送 SIGTERM 信號&#xff08;…

使用tensorRT10部署低光照補償模型

1.低光照補償模型的簡單介紹 作者介紹一種Zero-Reference Deep Curve Estimation (Zero-DCE)的方法用于在沒有參考圖像的情況下增強低光照圖像的效果。 具體來說&#xff0c;它將低光照圖像增強問題轉化為通過深度網絡進行圖像特定曲線估計的任務。訓練了一個輕量級的深度網絡…

SLAM定位常用地圖對比示例

序號 地圖類型 概述 1 格柵地圖 將現實環境柵格化,每一個柵格用 0 和 1 分別表示空閑和占據狀態,初始化為未知狀態 0.5 2 特征地圖 以點、線、面等幾何特征來描繪周圍環境,將采集的信息進行篩選和提取得到關鍵幾何特征 3 拓撲地圖 將重要部分抽象為地圖,使用簡單的圖形表示…

【圖像生成1】Latent Diffusion Models 論文學習筆記

一、背景 本文主要記錄一下使用 LDMs 之前&#xff0c;學習 LDMs 的過程。 二、論文解讀 Paper&#xff1a;[2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models 1. 總體描述 LDMs 將傳統 DMs 在高維圖像像素空間&#xff08;Pixel Space&#x…

通信安全堡壘:profinet轉ethernet ip主網關提升冶煉安全與連接

作為鋼鐵冶煉生產線的安全檢查員&#xff0c;我在此提交關于使用profinet轉ethernetip網關前后對生產線連接及安全影響的檢查報告。 使用profinet轉ethernetip網關前的情況&#xff1a; 在未使用profinet轉ethernetip網關之前&#xff0c;我們的EtherNet/IP測溫儀和流量計與PR…

TIFS2024 | CRFA | 基于關鍵區域特征攻擊提升對抗樣本遷移性

Improving Transferability of Adversarial Samples via Critical Region-Oriented Feature-Level Attack 摘要-Abstract引言-Introduction相關工作-Related Work提出的方法-Proposed Method問題分析-Problem Analysis擾動注意力感知加權-Perturbation Attention-Aware Weighti…

day 20 奇異值SVD分解

一、什么是奇異值 二、核心思想&#xff1a; 三、奇異值的主要應用 1、降維&#xff1a; 2、數據壓縮&#xff1a; 原理&#xff1a;圖像可以表示為一個矩陣&#xff0c;矩陣的元素對應圖像的像素值。對這個圖像矩陣進行 SVD 分解后&#xff0c;小的奇異值對圖像的主要結構貢…

符合Python風格的對象(對象表示形式)

對象表示形式 每門面向對象的語言至少都有一種獲取對象的字符串表示形式的標準方 式。Python 提供了兩種方式。 repr()   以便于開發者理解的方式返回對象的字符串表示形式。str()   以便于用戶理解的方式返回對象的字符串表示形式。 正如你所知&#xff0c;我們要實現_…

springboot配置tomcat端口的方法

在Spring Boot中配置Tomcat端口可通過以下方法實現&#xff1a; 配置文件方式 properties格式 在application.properties中添加&#xff1a;server.port8081YAML格式 在application.yml中添加&#xff1a;server:port: 8082多環境配置 創建不同環境的配置文件&#xff08;如app…

DeepSeek指令微調與強化學習對齊:從SFT到RLHF

后訓練微調的重要性 預訓練使大模型獲得豐富的語言和知識表達能力,但其輸出往往與用戶意圖和安全性需求不完全匹配。業內普遍采用三階段訓練流程:預訓練 → 監督微調(SFT)→ 人類偏好對齊(RLHF)。預訓練階段模型在大規模語料上學習語言規律;監督微調利用人工標注的數據…

Maven 插件擴展點與自定義生命周期

&#x1f9d1; 博主簡介&#xff1a;CSDN博客專家&#xff0c;歷代文學網&#xff08;PC端可以訪問&#xff1a;https://literature.sinhy.com/#/?__c1000&#xff0c;移動端可微信小程序搜索“歷代文學”&#xff09;總架構師&#xff0c;15年工作經驗&#xff0c;精通Java編…