前饋神經網絡正則化例子

直接看代碼:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

運行結果:

在這里插入圖片描述

疑問和心得:

  1. 畫圖的實現和細節還是有些模糊。
  2. 正則化系數一般是一個可以根據算法有一定變動的常數。
  3. 前饋神經網絡中,二分類最后使用logistic函數返回,多分類一般返回softmax值,若是一般的回歸任務,一般是直接relu返回。
  4. 前饋神經網絡的實現,從物理層上應該是全連接的,但是網上的代碼一般都是兩層單個神經元,這個容易產生誤解。個人感覺,還是要使用nn封裝的函數比較正宗。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/42759.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/42759.shtml
英文地址,請注明出處:http://en.pswp.cn/news/42759.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Docker:Windows container和Linux container

點擊"Switch to Windows containers"菜單時: 提示 然后 實際上是運行:com.docker.admin.exe start-service

成集云 | 樂享問題邀請同步企微提醒 | 解決方案

源系統成集云目標系統 方案介紹 騰訊樂享是騰訊公司開發的一款企業社區化知識管理平臺,它提供了包括知識庫、問答、課堂、考試、活動、投票和論壇等核心應用。這個平臺凝聚了騰訊10年的管理經驗,可以滿足政府、企業和學校在知識管理、學習培訓、文化建…

【gitkraken】gitkraken自動更新問題

GitKraken 會自動升級&#xff01;一旦自動升級&#xff0c;你的 GitKraken 自然就不再是最后一個免費版 6.5.1 了。 在安裝 GitKraken 之后&#xff0c;在你的安裝目錄&#xff08;C:\Users\<用戶名>\AppData\Local\gitkraken&#xff09;下會有一個名為 Update.exe 的…

Linux環境變量

環境變量 一.基本概念二.常見的環境變量1.PATH&#xff1a;指令搜索路徑2.HOME&#xff1a; 指定用戶的主工作目錄3.SHELL&#xff1a;當前Shell,它的值通常是/bin/bash 三.查看環境變量的方法四.命令行參數五.環境變量增加和刪除六.本地變量 一個問題&#xff1a;我們在寫一段…

Kotlin~Bridge橋接模式

概念 抽象和現實之間搭建橋梁&#xff0c;分離實現和抽象。 抽象&#xff08;What&#xff09;實現&#xff08;How&#xff09;用戶可見系統正常工作的底層代碼產品付款方式定義數據類型的類。處理數據存儲和檢索的類 角色介紹 Abstraction&#xff1a;抽象 定義抽象接口&…

《Go 語言第一課》課程學習筆記(五)

入口函數與包初始化&#xff1a;搞清 Go 程序的執行次序 main.main 函數&#xff1a;Go 應用的入口函數 Go 語言中有一個特殊的函數&#xff1a;main 包中的 main 函數&#xff0c;也就是 main.main&#xff0c;它是所有 Go 可執行程序的用戶層執行邏輯的入口函數。 Go 程序在…

一起創建Vue腳手架吧

目錄 一、安裝Vue CLI1.1 配置 npm 淘寶鏡像1.2 全局安裝1.3 驗證是否成功 二、創建vue_test項目2.1 cmd進入桌面2.2 創建項目2.3 運行項目2.4 查看效果 三、腳手架結構分析3.1 文件目錄結構分析3.2 vscode終端打開項目 一、安裝Vue CLI CLI&#xff1a;command-line interface…

日常BUG——微信小程序提交代碼報錯

&#x1f61c;作 者&#xff1a;是江迪呀??本文關鍵詞&#xff1a;日常BUG、BUG、問題分析??每日 一言 &#xff1a;存在錯誤說明你在進步&#xff01; 一、問題描述 在使用微信小程序開發工具進行提交代碼時&#xff0c;報出如下錯誤&#xff1a; Invalid a…

Git提交規范指南

在開發過程中&#xff0c;Git每次提交代碼&#xff0c;都需要寫Commit message&#xff08;提交說明&#xff09;&#xff0c;規范的Commit message有很多好處&#xff1a; 方便快速瀏覽查找&#xff0c;回溯之前的工作內容可以直接從commit 生成Change log(發布時用于說明版本…

5、flink任務中可以使用哪些轉換算子(Transformation)

1、什么是Flink中的轉換算子 在使用 Flink DataStream API 開發流式計算任務時&#xff0c;可以將一個或多個 DataStream 轉換成新的 DataStream&#xff0c;在應用程序中可以將多個數據轉換算子合并成一個復雜的數據流拓撲圖。 2、常用的轉換算子 Flink提供了功能各異的轉換算…

[論文筆記]ON LAYER NORMALIZATION IN THE TRANSFORMER ARCHITECTURE

引言 這是論文ON LAYER NORMALIZATION IN THE TRANSFORMER ARCHITECTURE的閱讀筆記。本篇論文提出了通過Pre-LN的方式可以省掉Warm-up環節,并且可以加快Transformer的訓練速度。 通常訓練Transformer需要一個仔細設計的學習率warm-up(預熱)階段:在訓練開始階段學習率需要設…

JDK 1.6與JDK 1.8的區別

ArrayList使用默認的構造方式實例 jdk1.6默認初始值為10jdk1.8為0,第一次放入值才初始化&#xff0c;屬于懶加載 Hashmap底層 jdk1.6與jdk1.8都是數組鏈表 jdk1.8是鏈表超過8時&#xff0c;自動轉為紅黑樹 靜態方式不同 jdk1.6是先初始化static后執行main方法。 jdk1.8是懶加…

設置PHP的fpm的系統性能參數pm.max_children

1 介紹 PHP從Apache module換成了Fpm&#xff0c;跑了幾天突然發現網站打不開了。 頁面顯示超時&#xff0c;檢查MySQL、Redis一眾服務都正常。 進入Fpm容器查看日志&#xff0c;發現了如下的錯誤信息&#xff1a; server reached pm.max_children setting (5), consider r…

python中的svm:介紹和基本使用方法

python中的svm&#xff1a;介紹和基本使用方法 支持向量機&#xff08;Support Vector Machine&#xff0c;簡稱SVM&#xff09;是一種常用的分類算法&#xff0c;可以用于解決分類和回歸問題。SVM通過構建一個超平面&#xff0c;將不同類別的數據分隔開&#xff0c;使得正負樣…

2023全國大學生數學建模競賽A題B題C題D題E題思路+模型+代碼+論文

目錄 一. 2023國賽數學建模思路&#xff1a; 賽題發布后會第一時間發布選題建議&#xff0c;思路&#xff0c;模型代碼等 詳細思路獲取見文末名片&#xff0c;9.7號第一時間更新 二.國賽常用的模型算法&#xff1a; 三、算法簡介 四.超重要&#xff01;&#xff01;&…

msvcp140.dll丟失的解決方法,如何預防msvcp140.dll丟失

在電腦操作系統中經常會彈出類似msvcp140.dll丟失的錯誤提示窗口&#xff0c;導致軟件無法正常運行。為什么會出現msvcp140.dll丟失的情況呢&#xff1f;出現這種情況應該如何解決呢&#xff1f;小編有三種解決方法分享給大家。 一.msvcp140.dll丟失的原因 1.安裝過程中受損:在…

前端框架學習-ES6新特性(尚硅谷web筆記)

ECMASript是由 Ecma 國際通過 ECMA-262 標準化的腳本程序設計語言。javaScript也是該規范的一種實現。 新特性目錄 筆記出處&#xff1a;b站ES6let 關鍵字const關鍵字變量的解構賦值模板字符串簡化對象寫法箭頭函數rest參數spread擴展運算符Promise模塊化 ES8async 和 await E…

云原生周刊:Kubernetes v1.28 新特性一覽 | 2023.8.14

推薦一個 GitHub 倉庫&#xff1a;Fast-Kubernetes。 Fast-Kubernetes 是一個涵蓋了 Kubernetes 的實驗室&#xff08;LABs&#xff09;的倉庫。它提供了關于 Kubernetes 的各種主題和組件的詳細內容&#xff0c;包括 Kubectl、Pod、Deployment、Service、ConfigMap、Volume、…

CF1013B And 題解

題目傳送門 題目意思&#xff1a; 給你一個長度為 n n n 的序列 a i a_i ai?&#xff0c;再給一個數 x x x。每一步你可以將序列中的一個數與上 x x x。請問最少要多少步才可以使得序列中出現兩個相同的數&#xff0c;如果無解輸出 ? 1 -1 ?1。 思路&#xff1a; 首…

Vue頁面刷新常用的4種方法

Vue項目里,有時候我們需要刷新頁面,重新加載頁面數據,常用方法如下: 方法一:location.reload() 方法全局刷新 使用 location.reload() 方法可以簡單地實現當前頁面的刷新,這個方法會重新加載當前頁面,類似于用戶點擊瀏覽器的刷新按鈕。 在 Vue 中,可以將該方法綁定到…