pytorch自定義新層demo_從頭學pytorch(十一):自定義層

自定義layer

不含模型參數的layer

含模型參數的layer

核心都一樣,自定義一個繼承自nn.Module的類,在類的forward函數里實現該layer的計算,不同的是,帶參數的layer需要用到nn.Parameter

不含模型參數的layer

直接繼承nn.Module

import torch

from torch import nn

class CenteredLayer(nn.Module):

def __init__(self, **kwargs):

super(CenteredLayer, self).__init__(**kwargs)

def forward(self, x):

return x - x.mean()

layer = CenteredLayer()

layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

y = net(torch.rand(4, 8))

y.mean().item()

含模型參數的layer

Parameter

ParameterList

ParameterDict

Parameter類其實是Tensor的子類,如果一個Tensor是Parameter,那么它會自動被添加到模型的參數列表里。所以在自定義含模型參數的層時,我們應該將參數定義成Parameter,除了直接定義成Parameter類外,還可以使用ParameterList和ParameterDict分別定義參數的列表和字典。

ParameterList用法和list類似

class MyDense(nn.Module):

def __init__(self):

super(MyDense,self).__init__()

self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(4)])

self.params.append(nn.Parameter(torch.randn(4,1)))

def forward(self,x):

for i in range(len(self.params)):

x = torch.mm(x,self.params[i])

return x

net = MyDense()

print(net)

輸出

MyDense(

(params): ParameterList(

(0): Parameter containing: [torch.FloatTensor of size 4x4]

(1): Parameter containing: [torch.FloatTensor of size 4x4]

(2): Parameter containing: [torch.FloatTensor of size 4x4]

(3): Parameter containing: [torch.FloatTensor of size 4x4]

(4): Parameter containing: [torch.FloatTensor of size 4x1]

)

)

ParameterDict用法和python dict類似.也可以用.keys(),.items()

class MyDictDense(nn.Module):

def __init__(self):

super(MyDictDense, self).__init__()

self.params = nn.ParameterDict({

'linear1': nn.Parameter(torch.randn(4, 4)),

'linear2': nn.Parameter(torch.randn(4, 1))

})

self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

def forward(self, x, choice='linear1'):

return torch.mm(x, self.params[choice])

net = MyDictDense()

print(net)

print(net.params.keys(),net.params.items())

x = torch.ones(1, 4)

net(x, 'linear1')

輸出

MyDictDense(

(params): ParameterDict(

(linear1): Parameter containing: [torch.FloatTensor of size 4x4]

(linear2): Parameter containing: [torch.FloatTensor of size 4x1]

(linear3): Parameter containing: [torch.FloatTensor of size 4x2]

)

)

odict_keys(['linear1', 'linear2', 'linear3']) odict_items([('linear1', Parameter containing:

tensor([[-0.2275, -1.0434, -1.6733, -1.8101],

[ 1.7530, 0.0729, -0.2314, -1.9430],

[-0.1399, 0.7093, -0.4628, -0.2244],

[-1.6363, 1.2004, 1.4415, -0.1364]], requires_grad=True)), ('linear2', Parameter containing:

tensor([[ 0.5035],

[-0.0171],

[-0.8580],

[-1.1064]], requires_grad=True)), ('linear3', Parameter containing:

tensor([[-1.2078, 0.4364],

[-0.8203, 1.7443],

[-1.7759, 2.1744],

[-0.8799, -0.1479]], requires_grad=True))])

使用自定義的layer構造模型

layer1 = MyDense()

layer2 = MyDictDense()

net = nn.Sequential(layer2,layer1)

print(net)

print(net(x))

輸出

Sequential(

(0): MyDictDense(

(params): ParameterDict(

(linear1): Parameter containing: [torch.FloatTensor of size 4x4]

(linear2): Parameter containing: [torch.FloatTensor of size 4x1]

(linear3): Parameter containing: [torch.FloatTensor of size 4x2]

)

)

(1): MyDense(

(params): ParameterList(

(0): Parameter containing: [torch.FloatTensor of size 4x4]

(1): Parameter containing: [torch.FloatTensor of size 4x4]

(2): Parameter containing: [torch.FloatTensor of size 4x4]

(3): Parameter containing: [torch.FloatTensor of size 4x4]

(4): Parameter containing: [torch.FloatTensor of size 4x1]

)

)

)

tensor([[-4.7566]], grad_fn=)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/541568.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/541568.shtml
英文地址,請注明出處:http://en.pswp.cn/news/541568.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java日歷類add方法_Java日歷computeTime()方法及示例

java日歷類add方法日歷類computeTime()方法 (Calendar Class computeTime() method) computeTime() method is available in java.util package. java.util包中提供了computeTime()方法 。 computeTime() method is for conversion of current field values to the ms(millisec…

C++——智能指針和RAII

該文章代碼均在gitee中開源 C智能指針hpphttps://gitee.com/Ehundred/cpp-knowledge-points/tree/master/%E6%99%BA%E8%83%BD%E6%8C%87%E9%92%88??????? 智能指針 傳統指針的問題 在C自定義類型中,我們為了避免內存泄漏,會采用析構函數的方法釋…

移除元素所有事件監聽_DOM 事件模型或 DOM 事件機制

DOM 事件模型DOM 的事件操作(監聽和觸發),都定義在EventTarget接口。所有節點對象都部署了這個接口,其他一些需要事件通信的瀏覽器內置對象(比如,XMLHttpRequest、AudioNode、AudioContext)也部…

gettimezone_Java日歷getTimeZone()方法與示例

gettimezone日歷類的getTimeZone()方法 (Calendar Class getTimeZone() method) getTimeZone() method is available in java.util package. getTimeZone()方法在java.util包中可用。 getTimeZone() method is used to return this Calendar time zone. getTimeZone()方法用于返…

cass展點不在原位置_cass展點之步驟及方法

cass展點之步驟及方法cass展點是根據手工或坐標正反算軟件自動計算的結果,利用cass軟件將點號、坐標及其高程自動展示到圖紙上的一種方法。其基本步驟和方法如下:一、將井下測點的點號、以及計算好的Y坐標、X坐標、及高程由sheet1復制并粘貼到sheet2上面…

Java BufferedWriter close()方法與示例

BufferedWriter類close()方法 (BufferedWriter Class close() method) close() method is available in java.io package. close()方法在java.io包中可用。 close() method is used to flushes the characters from the stream and later will close it by using close() metho…

ISCC2014-reverse

這是我做reverse的題解。在咱逆向之路上的mark一下,,水平有限,大牛見笑。題目及題解鏈接:http://pan.baidu.com/s/1gd3k2RL 宗女齊姜 果然是僅僅有50分的難度,OD直接找到了flag. 找到殺手 這題用OD做非常麻煩。我改用I…

python 獲取當前時間再往前幾個月_Python 中的時間和日期操作

Python中,對日期和時間的操作,主要使用這3個內置模塊: datetime 、 time 和 calendar 獲取當前時間對應的數字 開發程序時,經常需要獲取兩個代碼位置在執行時的時間差,比如,我們想知道某個函數執行大概耗費了多少時間,就可以使用time.time()來做。 import time before =…

Java BigDecimal restder()方法與示例

BigDecimal類的restder()方法 (BigDecimal Class remainder() method) Syntax: 句法: public BigDecimal remainder(BigDecimal divsr);public BigDecimal remainder(BigDecimal divsr, MathContext ma_co);remainder() method is available in java.math package.…

python程序需要編譯么_python需要編譯么

一個經常聽見的問題,那就是:Python是解釋型的語言嗎?它會被編譯嗎?這個問題沒有想象中那么好回答。和很多人認識世界一樣,習慣以一個簡單的模型去評判一些事物。而事實上,里面包含了很多很多的細節。通常的…

DevOps平臺中的自動化部署框架設計

本文目錄: 一、背景 二、我們的需求是什么? 三、概念澄清 四、概念模型 五、總體設計 六、關鍵點設計 七、總結 一、背景 說到自動化部署,大家肯定都會想到一些配置管理工具,像ansible,chef,puppet, saltstack等等。雖然這些工具給…

插入排序算法 ,遞歸實現_C程序實現遞歸插入排序

插入排序算法 ,遞歸實現The only difference between Insertion sort and Recursive Insertion Sort is that in the Recursive method, we start from placing the last element in its correct position in the sorted array instead of starting from the first. 插入排序和…

python虛擬機直接加載字節碼運行程序_第二章 python如何運行程序

一.python解釋器介紹Python解釋器是一種讓程序運行起來的程序。實際上,解釋器是代碼與機器的計算機硬件之間的軟件邏輯層。當Python包安裝在機器上后,它包含了一些最小化的組件:一個解釋器和支持的庫。二.python的視角當Python運行腳本時&…

Java LocalDate類| 帶示例的format()方法

LocalDate類format()方法 (LocalDate Class format() method) format() method is available in java.time package. format()方法在java.time包中可用。 format() method is used to format this LocalDate object by using the given DateTimeFormatter object. format()方法…

胃癌2019csco指南_2019 CSCO胃癌診療指南精華來了!

一文輕松get 2019 CSCO胃癌診療指南更新要點!文丨青青子衿 中山大學腫瘤防治中心來源丨醫學界腫瘤頻道近日,2019年CSCO指南發布會于南京召開。今天為大家推送的是2019 CSCO胃癌診療指南的最新更新,在發布專場中,來自華中科技大學同…

001_docker-compose構建elk環境

由于打算給同事分享elk相關的東西,搭建配置elk環境太麻煩了,于是想到了docker。docker官方提供了docker-compose編排工具,elk集群一鍵就可以搞定,真是興奮。好了下面咱們開始吧。 一、 https://github.com/deviantony/docker-elk $ cd /006_xxxallproject/005_docker/001_e…

Java即時類| toString()方法與示例

即時類toString()方法 (Instant Class toString() method) toString() method is available in java.time package. toString()方法在java.time包中可用。 toString() method is used to represent this Instant as a String by using the standards ISO-8601 format. toString…

learn opengl 中文_LearnOpenGL CN

歡迎來到OpenGL的世界歡迎來到OpenGL的世界。這個工程只是我(Joey de Vries)的一次小小的嘗試,希望能夠建立起一個完善的OpenGL教學平臺。無論你學習OpenGL是為了學業,找工作,或僅僅是因為興趣,這個網站都將能夠教會你現代(Core-p…

MYSQL5.7 日志管理

2019獨角獸企業重金招聘Python工程師標準>>> 慢查詢日志slow-query-log1 slow-query-log-filefile_name long_query_time1 #SQL執行多長時間以上會記錄到慢查詢日志,0~10s log_slow_admin_statementsOFF #在寫入慢查詢日志的語句中包含緩慢的管理語句。 …

duration java_Java Duration類| ofHours()方法與示例

duration javaDuration Class of Hours()方法 (Duration Class ofHours() method) ofHours() method is available in java.time package. ofHours()方法在java.time包中可用。 ofHours() method is used to represent the given hours in this Duration. ofHours()方法用于表示…