[pytorch、學習] - 3.6 softmax回歸的從零開始實現

參考

3.6 softmax回歸的從零開始實現

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

3.6.1. 獲取和讀取數據

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.2. 初始化模型參數

num_inputs = 784
num_outputs = 10W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)  # torch.Size([784, 10])
b = torch.zeros(num_outputs, dtype=torch.float)   # torch.Size([10])# 同之前一樣,我們需要模型參數梯度。
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

在這里插入圖片描述

3.6.3. 實現softmax運算

def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp / partition

3.6.4. 定義模型

# 傳入特征,給出預測值
def net(X):return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

3.6.5. 定義損失函數

def cross_entropy(y_hat, y):return -torch.log(y_hat.gather(1, y.view(-1, 1)))

3.6.6. 計算分類準確率

def accuracy(y_hat, y):return (y_hat.argmax(dim=1) ==y).float().mean().item()def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum /n

3.6.7. 訓練模型

  • d2lzh
num_epochs, lr = 5, 0.1# 本函數已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:d2l.sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回歸的簡潔實現”一節將用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

在這里插入圖片描述

3.6.8. 預測

X, y = iter(test_iter).next()true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250191.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250191.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250191.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Django基礎必備三件套: HttpResponse render redirect

1. HttpResponse : 它的作用是內部傳入一個字符串參數, 然后發給瀏覽器 def index(request):return HttpResponse(ok) 2. render : 可以接收三個參數, 一是request參數, 二是待渲染的 html 模板文件, 三是保存具體數據的字典參數 def index(request):return render(request, …

React 簡單實例 (React-router + webpack + Antd )

React Demo Github 地址 經過React Native 的洗禮之后,寫了這個 demo ;React 是為了使前端的V層更具組件化,能更好的復用,同時可以讓你從操作dom中解脫出來,只需要操作數據就會改變相應的dom; 而React Nat…

[pytorch、學習] - 3.7 softmax回歸的簡潔實現

參考 3.7. softmax回歸的簡潔實現 使用pytorch實現softmax import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.7.1. 獲取和讀取數據 batch_size 256 train_iter…

【模板】NTT

NTT模板 #include<bits/stdc.h> using namespace std; #define LL long long const int MAXL22; const int MAXN1<<MAXL; const int Mod998244353; int rev[MAXN],A[MAXN],B[MAXN],C[MAXN]; int fast_pow(int a,int b){int ans1;while(b){if(b&1)ans1ll*ans*a%…

centos 7 php7 yum源

rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpmrpm -Uvh https://mirror.webtatic.com/yum/el7/webtatic-release.rpm 轉載于:https://www.cnblogs.com/myJuly/p/10008252.html

[pytorch、學習] - 3.9 多重感知機的從零開始實現

參考 3.9 多重感知機的從零開始實現 import torch import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.9.1. 獲取和讀取數據 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)3.9.2. 定義模型參…

C語言逗號運算符和逗號表達式基礎總結

逗號運算符的作用&#xff1a; 1&#xff0c;起分隔符的作用&#xff1a; 定義變量用于分隔變量&#xff1a;int a,b輸入或輸出時用于分隔輸出表列 printf("%d%d",a,b) 2,用于逗號表達式的順序運算符 語法&#xff1a;表達式1&#xff0c;表達式2&#xff0c;...,表達…

java基礎-泛型舉例詳解

泛型 泛型是JDK5.0增加的新特性&#xff0c;泛型的本質是參數化類型&#xff0c;即所操作的數據類型被指定為一個參數。這種類型參數可以在類、接口、和方法的創建中&#xff0c;分別被稱為泛型類、泛型接口、泛型方法。 一、認識泛型 在沒有泛型之前,通過對類型Object的引用來…

MySQL數據庫視圖(view),視圖定義、創建視圖、修改視圖

原文鏈接&#xff1a;https://blog.csdn.net/moxigandashu/article/details/63254901轉載于:https://www.cnblogs.com/chrdai/p/9131881.html

[pytorch、學習] - 3.10 多重感知機的簡潔實現

參考 3.10. 多重感知機的簡潔實現 import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.10.1. 定義模型 num_inputs, num_outputs, num_hiddens 784, 10, 256 # 參…

【匯編語言】——第三章課后總結

第三章 的書本上主要有以下幾個內容&#xff1a; 1.內存中字的存儲 字單元&#xff1a;即存放一個字型數據&#xff08;16位&#xff09;的內存單元&#xff0c;由兩個地址連續的內存單元組成。 小端法&#xff1a;高地址內存單元中存放字型數據的高位字節&#xff0c;低地址內…

如何從 Android 手機免費恢復已刪除的通話記錄/歷史記錄?

有一個有合作意向的人給我打電話&#xff0c;但我沒有接聽。更糟糕的是&#xff0c;我錯誤地將其刪除&#xff0c;認為這是一個騷擾電話。那么有沒有辦法從 Android 手機恢復已刪除的通話記錄呢&#xff1f;” 塞繆爾問道。如何在 Android 上恢復已刪除的通話記錄&#xff1f;如…

springBoot 登錄攔截器

1、首選創建一個繼承HandlerInterceptor的攔截器 import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; /*** 攔…

[pytorch、學習] - 3.11 模型選擇、欠擬合和過擬合

參考 3.11 模型選擇、欠擬合和過擬合 3.11.1 訓練誤差和泛化誤差 在解釋上述現象之前&#xff0c;我們需要區分訓練誤差&#xff08;training error&#xff09;和泛化誤差&#xff08;generalization error&#xff09;。通俗來講&#xff0c;前者指模型在訓練數據集上表現…

關于'java' 不是內部或外部命令,也不是可運行的程序 或批處理文件 和 錯誤: 找不到或無法加載主類 helloworld的問題...

一、前幾天電腦重裝了一次系統將java配置的環境變量都弄沒了&#xff0c;自己添加了兩個新的變量JAVA_HOME&#xff08;自己jdk的地址&#xff09;以及在path中添加%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin; 然后因為這幾天都是用eclipse進行編程的&#xff0c;沒有出現問題&#…

spring-boot注解詳解(一)

spring-boot注解詳解(一) SpringBootApplication SpringBootApplication (默認屬性)Configuration EnableAutoConfiguration ComponentScan。 Configuration&#xff1a;提到Configuration就要提到他的搭檔Bean。使用這兩個注解就可以創建一個簡單的spring配置類&#xf…

前端基礎-jQuery的優點以及用法

一、jQuery介紹 jQuery是一個輕量級的、兼容多瀏覽器的JavaScript庫。jQuery使用戶能夠更方便地處理HTML Document、Events、實現動畫效果、方便地進行Ajax交互&#xff0c;能夠極大地簡化JavaScript編程。它的宗旨就是&#xff1a;“Write less, do more.“二、jQuery的優勢 一…

[pytorch、學習] - 3.12 權重衰減

參考 3.12 權重衰減 本節介紹應對過擬合的常用方法 3.12.1 方法 正則化通過為模型損失函數添加懲罰項使學出的模型參數更小,是應對過擬合的常用手段。 3.12.2 高維線性回歸實驗 import torch import torch.nn as nn import numpy as np import sys sys.path.append("…

Scapy之ARP詢問

引言 校園網中&#xff0c;有同學遭受永恒之藍攻擊&#xff0c;但是被殺毒軟件查下&#xff0c;并知道了攻擊者的ip也是校園網。所以我想看一下&#xff0c;這個ip是PC&#xff0c;還是路由器。 在ip視角&#xff0c;路由器和pc沒什么差別。 實現 首先是構造arp報文&#xff0c…

spring-boot注解詳解(二)

ResponseBody 作用&#xff1a; 該注解用于將Controller的方法返回的對象&#xff0c;通過適當的HttpMessageConverter轉換為指定格式后&#xff0c;寫入到Response對象的body數據區。使用時機&#xff1a; 返回的數據不是html標簽的頁面&#xff0c;而是其他某種格式的數據時…