【PyTorch】torch.fmod使用截斷正態分布truncated normal distribution初始化神經網絡的權重

這個代碼片段展示了如何用 PyTorch 初始化神經網絡的權重,具體使用的是截斷正態分布(truncated normal distribution)。截斷正態分布意味著生成的值會在一定范圍內截斷,以防止出現極端值。這里使用 torch.fmod 作為一種變通方法實現這一效果。

詳細解釋

1. 截斷正態分布

截斷正態分布是對正態分布的一種修改,確保生成的值在一定范圍內。具體來說,torch.fmod 函數返回輸入張量除以 2 的余數(即使得生成的值在 -2 到 2 之間)。

2. 權重初始化

代碼中,四個權重張量按不同的標準差(init_sd_first, init_sd_middle, init_sd_last)從截斷正態分布中生成。具體的維度分別是:

  • 第一層的權重張量形狀為 (x_dim, width + n_double)
  • 中間層的兩個權重張量形狀為 (width, width + n_double)
  • 最后一層的權重張量形狀為 (width, 1)

這些權重張量的生成方式如下:

initial_weights = [torch.fmod(torch.normal(0, init_sd_first, size=(x_dim, width + n_double)), 2),torch.fmod(torch.normal(0, init_sd_middle, size=(width, width + n_double)), 2),torch.fmod(torch.normal(0, init_sd_middle, size=(width, width + n_double)), 2),torch.fmod(torch.normal(0, init_sd_last, size=(width, 1)), 2)
]

示例代碼

下面是一個完整的示例,展示如何使用上述權重初始化方式初始化一個簡單的神經網絡:

import torch
import torch.nn as nnclass CustomModel(nn.Module):def __init__(self, x_dim, width, n_double, init_sd_first, init_sd_middle, init_sd_last):super(CustomModel, self).__init__()self.linear1 = nn.Linear(x_dim, width + n_double)self.linear2 = nn.Linear(width + n_double, width + n_double)self.linear3 = nn.Linear(width + n_double, width + n_double)self.linear4 = nn.Linear(width + n_double, 1)self.init_weights(init_sd_first, init_sd_middle, init_sd_last)def init_weights(self, init_sd_first, init_sd_middle, init_sd_last):self.linear1.weight.data = torch.fmod(torch.normal(0, init_sd_first, size=self.linear1.weight.size()), 2)self.linear2.weight.data = torch.fmod(torch.normal(0, init_sd_middle, size=self.linear2.weight.size()), 2)self.linear3.weight.data = torch.fmod(torch.normal(0, init_sd_middle, size=self.linear3.weight.size()), 2)self.linear4.weight.data = torch.fmod(torch.normal(0, init_sd_last, size=self.linear4.weight.size()), 2)def forward(self, x):x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))x = torch.relu(self.linear3(x))x = self.linear4(x)return x# 定義超參數
x_dim = 10
width = 20
n_double = 5
init_sd_first = 0.1
init_sd_middle = 0.1
init_sd_last = 0.1# 初始化模型
model = CustomModel(x_dim, width, n_double, init_sd_first, init_sd_middle, init_sd_last)# 打印權重以驗證初始化
for name, param in model.named_parameters():if 'weight' in name:print(f"{name} initialized with values: \n{param.data}\n")

在這個示例中,我們定義了一個簡單的神經網絡 CustomModel,并在 init_weights 方法中使用截斷正態分布初始化權重。通過打印權重,我們可以驗證它們是否按預期初始化。

說明

  1. 定義網絡CustomModel 包含四個線性層。第一層輸入尺寸為 x_dim,輸出尺寸為 width + n_double。接下來的兩層也是同樣的輸出尺寸,最后一層輸出尺寸為 1。
  2. 初始化權重:在 init_weights 方法中,我們使用截斷正態分布(通過 torch.fmod)初始化每一層的權重。我們對生成的正態分布取模 2,使得權重在 -2 和 2 之間。
  3. 打印參數:我們通過 model.named_parameters() 方法遍歷模型的參數,并打印每層參數的名稱、尺寸和前兩個值。

進一步說明

  • 截斷正態分布:使用 torch.normal 生成正態分布的隨機數,然后使用 torch.fmod 將這些隨機數的范圍限制在 -2 到 2 之間。
  • 超參數x_dim 是輸入的特征維度,width 是每層的寬度(即神經元數量),n_double 是一個附加參數,用于增加每層的輸出維度。init_sd_firstinit_sd_middleinit_sd_last 是每層權重初始化的標準差。

這個示例展示了如何使用截斷正態分布初始化神經網絡的權重,并打印每層的參數。如果您有更多問題或需要進一步的幫助,請告訴我!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/41826.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/41826.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/41826.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

配置linux net.ipv4.ip_forward數據包轉發

前言 出于系統安全考慮,在默認情況下,Linux系統是禁止數據包轉發的。數據包轉發指的是當主機擁有多個網卡時,通過一個網卡接收到的數據包,根據目的IP地址來轉發數據包到其他網卡。這個功能通常用于路由器。 如果在Linux系統中需要…

CVPR 2024最佳論文分享:通過解釋方法比較Transformers和CNNs的決策機制

CVPR(Conference on Computer Vision and Pattern Recognition)是計算機視覺領域最有影響力的會議之一,主要方向包括圖像和視頻處理、目標檢測與識別、三維視覺等。近期,CVPR 2024 公布了最佳論文。共有10篇論文獲獎,其…

計算組的妙用!!頁面權限控制

需求描述: 某些特殊的場景下,針對某頁看板,需要進行數據權限卡控,但是又不能對全部的數據進行RLS處理,這種情況下可以利用計算組來解決這個需求。 實際場景 事實表包含產品維度和銷售維度 兩個維度屬于同一公司下面的…

限幅濾波法

限幅濾波法 限幅濾波法:根據經驗判斷,確定兩次采樣允許的最大偏差值(設為A),每次檢測到新值時判斷:如果本次值與上次值之差<=A,則本次值有效,如果本次值與上次值之差>A,則本次值無效,放棄本次值,用上次值代替本次值。 優點: 能有效克服因偶然因素引起的脈沖…

【Python】已解決:FileNotFoundError: [Errno 2] No such file or directory: ‘./1.xml’

文章目錄 一、分析問題背景二、可能出錯的原因三、錯誤代碼示例四、正確代碼示例五、注意事項 已解決&#xff1a;FileNotFoundError: [Errno 2] No such file or directory: ‘./1.xml’ 一、分析問題背景 在Python編程中&#xff0c;FileNotFoundError是一個常見的異常&…

ChatGPT對話:Python程序自動模擬操作網頁,無法彈出下拉列表框

【編者按】需要編寫Python程序自動模擬操作網頁。編者有編程經驗&#xff0c;但沒有前端編程經驗&#xff0c;完全不知道如何編寫這種程序。通過與ChatGPT討論&#xff0c;1天完成了任務。因為沒有這類程序的編程經驗&#xff0c;需要邊學習&#xff0c;邊編程&#xff0c;遇到…

貝爾曼方程(Bellman Equation)

貝爾曼方程(Bellman Equation) 貝爾曼方程(Bellman Equation)是動態規劃和強化學習中的核心概念,用于描述最優決策問題中的價值函數的遞歸關系。它為狀態值函數和動作值函數提供了一個重要的遞推公式,幫助我們計算每個狀態或狀態-動作對的預期回報。 貝爾曼方程的原理 …

Python 自動化測試必會技能板塊—unittest框架

說到 Python 的單元測試框架&#xff0c;想必接觸過 Python 的朋友腦袋里第一個想到的就是 unittest。 的確&#xff0c;作為 Python 的標準庫&#xff0c;它很優秀&#xff0c;并被廣泛應用于各個項目。但其實在 Python 眾多項目中&#xff0c;主流的單元測試框架遠不止這一個…

西門子PLC1200--與電腦S7通訊

硬件構成 PLC為西門子1211DCDCDC 電腦上位機用PYTHON編寫 二者通訊用網線&#xff0c;通訊協議用S7 PLC上的數據 PLC上的數據是2個uint&#xff0c;在DB1&#xff0c;地址偏移分別是0和2 需要注意的是DB塊要關閉優化的塊訪問&#xff0c;否則是沒有偏移地址的 PLC中的數據內…

elementui中日期/時間的禁用處理,使用傳值的方式

項目中,經常會用到 在一個學年或者一個學期或者某一個時間段需要做的某件事情,則我們需要在創建這個事件的時候,需要設置一定的時間周期,那這個時間周期就需要給一定的限制處理,避免用戶的誤操作,優化用戶體驗 如下:需求為,在選擇學年后,學期的設置需要在學年中,且結束時間大…

Spring Cloud Gateway如何匹配某路徑并進行路由轉發

本案例&#xff0c;將/helloworld-app/**的請求轉發到helloworld微服務的/**路徑&#xff08;既如lb://helloworld/**&#xff09; 配置如下&#xff08;見spring.cloud.gateway.routes配置&#xff09;&#xff1a; spring:application:name: SpringCloudGatewayDemocloud:n…

軟件架構之計算機組成與體系結構

1.1計算機系統組成 計算機系統是一個硬件和軟件的綜合體&#xff0c;可以把它看成按功能劃分的多級層次結構。 1.1.1 計算機硬件的組成 硬件通常是指一切看得見&#xff0c;摸得到的設備實體。原始的馮?諾依曼&#xff08;VonNeumann&#xff09;計算機在結構上是以運算器為…

2024年中國十大杰出起名大師排行榜,最厲害的易經姓名學改名字專家

在2024年揭曉的中國十大杰出易學泰斗評選中&#xff0c;一系列對姓名學與國學易經有深入研究的專家榮登榜單。其中&#xff0c;中國十大權威姓名學專家泰斗頂級杰出代表人物的師傅顏廷利大師以其在國際舞臺上的卓越貢獻和深邃學識&#xff0c;被公認為姓名學及易經起名領域的權…

C#程序調用Sql Server存儲過程異常處理:調用存儲過程后不返回、不拋異常的解決方案

目錄 一、代碼解析&#xff1a; 二、解決方案 1、增加日志記錄 2、異步操作 注意事項 3、增加超時機制 4、使用線程池 5、使用信號量或事件 6、監控數據庫連接狀態 在C#程序操作Sql Server數據庫的實際應用中&#xff0c;若異常就會拋出異常&#xff0c;我們還能找到異…

Leetcode 完美數

1.題目要求: 對于一個 正整數&#xff0c;如果它和除了它自身以外的所有 正因子 之和相等&#xff0c;我們稱它為 「完美數」。給定一個 整數 n&#xff0c; 如果是完美數&#xff0c;返回 true&#xff1b;否則返回 false。示例 1&#xff1a;輸入&#xff1a;num 28 輸出&a…

2024年6月份找工作和面試總結

轉眼間6月份已經過完了&#xff0c;2024年已經過了一半&#xff0c;希望大家都找到了合適的工作。 本人前段時間寫了5月份找工作的情況&#xff0c;請查看2024年5月份面試總結-CSDN博客 但是后續寫的總結被和諧了&#xff0c;不知道這篇文章能不能發出來。 1、6月份面試機會依…

網絡爬蟲基礎

網絡爬蟲基礎 網絡爬蟲&#xff0c;也被稱為網絡蜘蛛或爬蟲&#xff0c;是一種用于自動瀏覽互聯網并從網頁中提取信息的軟件程序。它們能夠訪問網站&#xff0c;解析頁面內容&#xff0c;并收集所需數據。Python語言因其簡潔的語法和強大的庫支持&#xff0c;成為實現網絡爬蟲…

verilog讀寫文件注意事項

想要的16進制數是文本格式提供的文件&#xff0c;想將16進制數提取到變量內&#xff0c; 可以使用 f s c a n f ( f d 1 , " 也可以使用 fscanf(fd1,"%h",rd_byte);實現 也可以使用 fscanf(fd1,"也可以使用readmemh(“./FILE/1.txt”,mem);//fe放在mem[0…

運用Redis作為設備注冊中心,解決20w+設備高并發讀寫,高性能讀寫異步把數據同步到mysql持久化。

使用 Redis 作為設備注冊中心&#xff0c;并通過高并發讀寫將數據異步同步到 MySQL 數據庫&#xff0c;可以采用以下策略&#xff1a; 1. **設備注冊與發現**&#xff1a; - 使用 Redis 的字符串或哈希表存儲設備信息&#xff0c;其中鍵可以是設備的唯一標識符。 2. **高并…

基于Android Studio零食工坊

目錄 項目介紹 圖片展示 運行環境 獲取方式 項目介紹 用戶 可以瀏覽商品 &#xff0c; 查詢商品 &#xff0c; 加入購物車 &#xff0c; 結算商品 &#xff0c; 查看瀏覽記錄 &#xff0c; 修改密碼 &#xff0c; 修改個人信息 &#xff0c; 查詢訂單 管理員 能夠實現商品的…