[pytorch、學習] - 3.9 多重感知機的從零開始實現

參考

3.9 多重感知機的從零開始實現

import torch
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

3.9.1. 獲取和讀取數據

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.9.2. 定義模型參數

num_inputs, num_outputs, num_hiddens  = 784, 10, 256W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)), dtype=torch.float)
b1 = torch.zeros(num_hiddens, dtype=torch.float)
W2 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens, num_outputs)), dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)params = [W1, b1, W2, b2]
for param in params:param.requires_grad_(requires_grad=True)

3.9.3. 定義激活函數

def relu(X):return torch.max(input=X, other=torch.tensor(0.0))

3.9.4. 定義模型

def net(X):X = X.view((-1, num_inputs))H = relu(torch.matmul(X, W1) + b1)return torch.matmul(H, W2) + b2

3.9.5. 定義損失函數

loss = torch.nn.CrossEntropyLoss()

3.9.6. 訓練模型

num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

在這里插入圖片描述

3.9.7. 預測

X, y  = iter(test_iter).next()true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250185.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250185.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250185.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C語言逗號運算符和逗號表達式基礎總結

逗號運算符的作用: 1,起分隔符的作用: 定義變量用于分隔變量:int a,b輸入或輸出時用于分隔輸出表列 printf("%d%d",a,b) 2,用于逗號表達式的順序運算符 語法:表達式1,表達式2,...,表達…

java基礎-泛型舉例詳解

泛型 泛型是JDK5.0增加的新特性,泛型的本質是參數化類型,即所操作的數據類型被指定為一個參數。這種類型參數可以在類、接口、和方法的創建中,分別被稱為泛型類、泛型接口、泛型方法。 一、認識泛型 在沒有泛型之前,通過對類型Object的引用來…

MySQL數據庫視圖(view),視圖定義、創建視圖、修改視圖

原文鏈接:https://blog.csdn.net/moxigandashu/article/details/63254901轉載于:https://www.cnblogs.com/chrdai/p/9131881.html

[pytorch、學習] - 3.10 多重感知機的簡潔實現

參考 3.10. 多重感知機的簡潔實現 import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.10.1. 定義模型 num_inputs, num_outputs, num_hiddens 784, 10, 256 # 參…

【匯編語言】——第三章課后總結

第三章 的書本上主要有以下幾個內容: 1.內存中字的存儲 字單元:即存放一個字型數據(16位)的內存單元,由兩個地址連續的內存單元組成。 小端法:高地址內存單元中存放字型數據的高位字節,低地址內…

如何從 Android 手機免費恢復已刪除的通話記錄/歷史記錄?

有一個有合作意向的人給我打電話,但我沒有接聽。更糟糕的是,我錯誤地將其刪除,認為這是一個騷擾電話。那么有沒有辦法從 Android 手機恢復已刪除的通話記錄呢?” 塞繆爾問道。如何在 Android 上恢復已刪除的通話記錄?如…

springBoot 登錄攔截器

1、首選創建一個繼承HandlerInterceptor的攔截器 import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; /*** 攔…

[pytorch、學習] - 3.11 模型選擇、欠擬合和過擬合

參考 3.11 模型選擇、欠擬合和過擬合 3.11.1 訓練誤差和泛化誤差 在解釋上述現象之前,我們需要區分訓練誤差(training error)和泛化誤差(generalization error)。通俗來講,前者指模型在訓練數據集上表現…

關于'java' 不是內部或外部命令,也不是可運行的程序 或批處理文件 和 錯誤: 找不到或無法加載主類 helloworld的問題...

一、前幾天電腦重裝了一次系統將java配置的環境變量都弄沒了,自己添加了兩個新的變量JAVA_HOME(自己jdk的地址)以及在path中添加%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin; 然后因為這幾天都是用eclipse進行編程的,沒有出現問題&#…

spring-boot注解詳解(一)

spring-boot注解詳解(一) SpringBootApplication SpringBootApplication (默認屬性)Configuration EnableAutoConfiguration ComponentScan。 Configuration:提到Configuration就要提到他的搭檔Bean。使用這兩個注解就可以創建一個簡單的spring配置類&#xf…

前端基礎-jQuery的優點以及用法

一、jQuery介紹 jQuery是一個輕量級的、兼容多瀏覽器的JavaScript庫。jQuery使用戶能夠更方便地處理HTML Document、Events、實現動畫效果、方便地進行Ajax交互,能夠極大地簡化JavaScript編程。它的宗旨就是:“Write less, do more.“二、jQuery的優勢 一…

[pytorch、學習] - 3.12 權重衰減

參考 3.12 權重衰減 本節介紹應對過擬合的常用方法 3.12.1 方法 正則化通過為模型損失函數添加懲罰項使學出的模型參數更小,是應對過擬合的常用手段。 3.12.2 高維線性回歸實驗 import torch import torch.nn as nn import numpy as np import sys sys.path.append("…

Scapy之ARP詢問

引言 校園網中,有同學遭受永恒之藍攻擊,但是被殺毒軟件查下,并知道了攻擊者的ip也是校園網。所以我想看一下,這個ip是PC,還是路由器。 在ip視角,路由器和pc沒什么差別。 實現 首先是構造arp報文&#xff0c…

spring-boot注解詳解(二)

ResponseBody 作用: 該注解用于將Controller的方法返回的對象,通過適當的HttpMessageConverter轉換為指定格式后,寫入到Response對象的body數據區。使用時機: 返回的數據不是html標簽的頁面,而是其他某種格式的數據時…

轉:org.apache.maven.archiver.MavenArchiver.getManifest錯誤

eclipse導入新的maven項目時,pom.xml第一行報錯: org.apache.maven.archiver.MavenArchiver.getManifest(org.apache.maven.project.MavenProject, org.apache.maven.archiver.MavenArchiveConfiguration) 解決辦法: 1、Help——>Install …

Codeforces Round #524 Div. 2 翻車記

A&#xff1a;簽到。room里有一個用for寫的&#xff0c;hack了一發1e8 1&#xff0c;結果用了大概600ms跑過去了。慘絕人寰。 #include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorith…

[pytorch、學習] - 3.13 丟棄法

參考 3.13 丟棄法 過擬合問題的另一種解決辦法是丟棄法。當對隱藏層使用丟棄法時,隱藏單元有一定概率被丟棄。 3.12.1 方法 3.13.2 從零開始實現 import torch import torch.nn as nn import numpy as np import sys sys.path.append("..") import d2lzh_pytorc…

springboot---request 中Parameter,Attribute區別

HttpServletRequest類既有getAttribute()方法&#xff0c;也由getParameter()方法&#xff0c;這兩個方法有以下區別&#xff1a; &#xff08;1&#xff09;HttpServletRequest類有setAttribute()方法&#xff0c;而沒有setParameter()方法 &#xff08;2&#xff09;當兩個…

Python之令人心煩意亂的字符編碼與轉碼

ASC-II碼&#xff1a;英文1個字節&#xff08;8 byte&#xff09;&#xff0c;不支持中文&#xff1b; 高大上的中國&#xff0c;擴展出自己的gbk、gb2312、gb2318等字符編碼。 由于各個國家都有自己的編碼&#xff0c;于是就需要統一的編碼形式用于國際流傳&#xff0c;防止亂…

[pytorch、學習] - 4.1 模型構造

參考 4.1 模型構造 讓我們回顧以下多重感知機的簡潔實現中包含單隱藏層的多重感知機的實現方法。我們首先構造Sequential實例,然后依次添加兩個全連接層。其中第一層的輸出大小為256,即隱藏層單元個數是256;第二層的輸出大小為10,即輸出層單元個數是10. 4.1.1 繼承Module類來…