[pytorch、學習] - 3.11 模型選擇、欠擬合和過擬合

參考

3.11 模型選擇、欠擬合和過擬合

3.11.1 訓練誤差和泛化誤差

在解釋上述現象之前,我們需要區分訓練誤差(training error)和泛化誤差(generalization error)。通俗來講,前者指模型在訓練數據集上表現出的誤差,后者指模型在任意一個測試數據樣本上表現出的誤差的期望,并常常通過測試數據集上的誤差來近似。計算訓練誤差和泛化誤差可以使用之前介紹過的損失函數,例如線性回歸用到的平方損失函數和softmax回歸用到的交叉熵損失函數。

讓我們以高考為例來直觀地解釋訓練誤差和泛化誤差這兩個概念。訓練誤差可以認為是做往年高考試題(訓練題)時的錯誤率,泛化誤差則可以通過真正參加高考(測試題)時的答題錯誤率來近似。假設訓練題和測試題都隨機采樣于一個未知的依照相同考綱的巨大試題庫。如果讓一名未學習中學知識的小學生去答題,那么測試題和訓練題的答題錯誤率可能很相近。但如果換成一名反復練習訓練題的高三備考生答題,即使在訓練題上做到了錯誤率為0,也不代表真實的高考成績會如此。

在機器學習里,我們通常假設訓練數據集(訓練題)和測試數據集(測試題)里的每一個樣本都是從同一個概率分布中相互獨立地生成的。基于該獨立同分布假設,給定任意一個機器學習模型(含參數),它的訓練誤差的期望和泛化誤差都是一樣的。例如,如果我們將模型參數設成隨機值(小學生),那么訓練誤差和泛化誤差會非常相近。但我們從前面幾節中已經了解到,模型的參數是通過在訓練數據集上訓練模型而學習出的,參數的選擇依據了最小化訓練誤差(高三備考生)。所以,訓練誤差的期望小于或等于泛化誤差。也就是說,一般情況下,由訓練數據集學到的模型參數會使模型在訓練數據集上的表現優于或等于在測試數據集上的表現。由于無法從訓練誤差估計泛化誤差,一味地降低訓練誤差并不意味著泛化誤差一定會降低。

機器學習模型應關注降低泛化誤差。

3.11.2 模型選擇

在機器學習中,通常需要評估若干候選模型的表現并從中選擇模型。這一過程稱為模型選擇(model selection)。以多層感知機為例,我們可以選擇隱藏層的個數,以及每個隱藏層中隱藏單元個數和激活函數。為了得到有效的模型,我們通常要在模型選擇上下一番功夫。下面,我們來描述模型選擇中經常使用的驗證數據集(validation data set)。

3.11.2.1 驗證數據集

從嚴格意義上來講,測試集只能在所有超參數和模型選定后使用一次。不可以使用測試數據集選擇模型,如調參。由于無法從訓練誤差估計泛化誤差,因此也不應只依賴訓練數據選擇模型。鑒于此,我們可以預留一部分訓練數據集和測試數據集以外的數據來進行模型選擇。這部分數據稱為驗證數據集,簡稱驗證集(validation set)。

然而在實際應用中,由于數據不容易獲取,測試數據極少只使用一次就丟棄。因此,實踐中驗證數據集和測試數據集的界限可能比較模糊。從嚴格意義上講,除非明確說明,否則本書中實驗所使用的測試集應為驗證集,實驗報告的測試結果(如測試準確率)應為驗證結果(如驗證準確率)。

3.11.2.2 K折交叉驗證

由于驗證數據集不參與模型訓練,當訓練數據不夠時,預留大量的驗證數據顯得太奢侈。一種改善的方法是KK折交叉驗證(K-fold cross-validation)。在K折交叉驗證中,我們把原始訓練數據集分割成K個不重合的子數據集,然后我們做K次模型訓練和驗證。每一次,我們使用一個子數據集驗證模型,并使用其他K-1個子數據集來訓練模型。在這K次訓練和驗證中,每次用來驗證模型的子數據集都不同。最后,我們對這K次訓練誤差和驗證誤差分別求平均。

3.11.3 欠擬合和過擬合

欠擬合: 模型無法得到較低的誤差
過擬合: 模型在訓練集上的誤差遠遠小于在測試集上的誤差

3.11.3.1 模型復雜度

在這里插入圖片描述

3.11.3.2 訓練數據集大小

影響欠擬合和過擬合的另一個重要因素是訓練數據集的大小。一般來說,如果訓練數據集中樣本數過少,特別是比模型參數數量(按元素計)更少時,過擬合更容易發生。此外,泛化誤差不會隨訓練數據集里樣本數量增加而增大。因此,在計算資源允許的范圍之內,我們通常希望訓練數據集大一些,特別是在模型復雜度較高時,例如層數較多的深度學習模型。

3.11.4 多項式函數擬合實驗

import torch
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

3.11.4.1 生產數據集

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-C06OGqrH-1594087907825)(attachment:image.png)]

n_train, n_nest, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
features = torch.randn((n_train + n_nest, 1))
poly_features = torch.cat((features, torch.pow(features, 2), torch.pow(features, 3)), 1)   # 按列拼起來
labels = (true_w[0] * poly_features[:,0] + true_w[1] * poly_features[:,1] + true_w[2] * poly_features[:, 2] + true_b)labels += torch.tensor(np.random.normal(0, 0.01, size = labels.size()), dtype=torch.float)

3.11.4.2 定義、訓練模型

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):d2l.set_figsize(figsize)d2l.plt.xlabel(x_label)d2l.plt.ylabel(y_label)d2l.plt.semilogy(x_vals, y_vals)if x2_vals and y2_vals:d2l.plt.semilogy(x2_vals, y2_vals, linestyle=":")d2l.plt.legend(legend)num_epochs, loss = 100, torch.nn.MSELoss()def fit_and_plot(train_features, test_features, train_labels, test_labels):net = torch.nn.Linear(train_features.shape[-1], 1)   # 線性,傳入輸入輸出即可batch_size = min(10, train_labels.shape[0])dataset = torch.utils.data.TensorDataset(train_features, train_labels)# Dataloader根據 TensorDataset、batch_size隨機取值返回train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True)# optim傳入模型的參數和學習率,返回一個優化器optimizer = torch.optim.SGD(net.parameters(), lr=0.01)train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:l = loss(net(X), y.view(-1, 1))optimizer.zero_grad()  # 將上一次的梯度清0l.backward()optimizer.step()train_labels = train_labels.view(-1, 1)test_labels = test_labels.view(-1, 1)train_ls.append(loss(net(train_features), train_labels).item())test_ls.append(loss(net(test_features), test_labels).item())print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('weight:', net.weight.data,'\nbias:', net.bias.data)

3.11.4.3 三階多項式函數擬合(正常)

fit_and_plot(poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:])

在這里插入圖片描述

3.11.4.4 線性函數擬合(欠擬合)

fit_and_plot(features[:n_train, :], features[n_train:, :], labels[:n_train],labels[n_train:])

在這里插入圖片描述

3.11.4.5 訓練樣本不足(過擬合)

fit_and_plot(poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],labels[n_train:])

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250177.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250177.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250177.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

關于'java' 不是內部或外部命令,也不是可運行的程序 或批處理文件 和 錯誤: 找不到或無法加載主類 helloworld的問題...

一、前幾天電腦重裝了一次系統將java配置的環境變量都弄沒了,自己添加了兩個新的變量JAVA_HOME(自己jdk的地址)以及在path中添加%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin; 然后因為這幾天都是用eclipse進行編程的,沒有出現問題&#…

spring-boot注解詳解(一)

spring-boot注解詳解(一) SpringBootApplication SpringBootApplication (默認屬性)Configuration EnableAutoConfiguration ComponentScan。 Configuration:提到Configuration就要提到他的搭檔Bean。使用這兩個注解就可以創建一個簡單的spring配置類&#xf…

前端基礎-jQuery的優點以及用法

一、jQuery介紹 jQuery是一個輕量級的、兼容多瀏覽器的JavaScript庫。jQuery使用戶能夠更方便地處理HTML Document、Events、實現動畫效果、方便地進行Ajax交互,能夠極大地簡化JavaScript編程。它的宗旨就是:“Write less, do more.“二、jQuery的優勢 一…

[pytorch、學習] - 3.12 權重衰減

參考 3.12 權重衰減 本節介紹應對過擬合的常用方法 3.12.1 方法 正則化通過為模型損失函數添加懲罰項使學出的模型參數更小,是應對過擬合的常用手段。 3.12.2 高維線性回歸實驗 import torch import torch.nn as nn import numpy as np import sys sys.path.append("…

Scapy之ARP詢問

引言 校園網中,有同學遭受永恒之藍攻擊,但是被殺毒軟件查下,并知道了攻擊者的ip也是校園網。所以我想看一下,這個ip是PC,還是路由器。 在ip視角,路由器和pc沒什么差別。 實現 首先是構造arp報文&#xff0c…

spring-boot注解詳解(二)

ResponseBody 作用: 該注解用于將Controller的方法返回的對象,通過適當的HttpMessageConverter轉換為指定格式后,寫入到Response對象的body數據區。使用時機: 返回的數據不是html標簽的頁面,而是其他某種格式的數據時…

轉:org.apache.maven.archiver.MavenArchiver.getManifest錯誤

eclipse導入新的maven項目時,pom.xml第一行報錯: org.apache.maven.archiver.MavenArchiver.getManifest(org.apache.maven.project.MavenProject, org.apache.maven.archiver.MavenArchiveConfiguration) 解決辦法: 1、Help——>Install …

Codeforces Round #524 Div. 2 翻車記

A&#xff1a;簽到。room里有一個用for寫的&#xff0c;hack了一發1e8 1&#xff0c;結果用了大概600ms跑過去了。慘絕人寰。 #include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorith…

[pytorch、學習] - 3.13 丟棄法

參考 3.13 丟棄法 過擬合問題的另一種解決辦法是丟棄法。當對隱藏層使用丟棄法時,隱藏單元有一定概率被丟棄。 3.12.1 方法 3.13.2 從零開始實現 import torch import torch.nn as nn import numpy as np import sys sys.path.append("..") import d2lzh_pytorc…

springboot---request 中Parameter,Attribute區別

HttpServletRequest類既有getAttribute()方法&#xff0c;也由getParameter()方法&#xff0c;這兩個方法有以下區別&#xff1a; &#xff08;1&#xff09;HttpServletRequest類有setAttribute()方法&#xff0c;而沒有setParameter()方法 &#xff08;2&#xff09;當兩個…

Python之令人心煩意亂的字符編碼與轉碼

ASC-II碼&#xff1a;英文1個字節&#xff08;8 byte&#xff09;&#xff0c;不支持中文&#xff1b; 高大上的中國&#xff0c;擴展出自己的gbk、gb2312、gb2318等字符編碼。 由于各個國家都有自己的編碼&#xff0c;于是就需要統一的編碼形式用于國際流傳&#xff0c;防止亂…

[pytorch、學習] - 4.1 模型構造

參考 4.1 模型構造 讓我們回顧以下多重感知機的簡潔實現中包含單隱藏層的多重感知機的實現方法。我們首先構造Sequential實例,然后依次添加兩個全連接層。其中第一層的輸出大小為256,即隱藏層單元個數是256;第二層的輸出大小為10,即輸出層單元個數是10. 4.1.1 繼承Module類來…

springboot---基本模塊詳解

概述 1.基于Spring框架的“約定優先于配置&#xff08;COC&#xff09;”理念以及最佳實踐之路。 2.針對日常企業應用研發各種場景的Spring-boot-starter自動配置依賴模塊&#xff0c;且“開箱即用”&#xff08;約定spring-boot-starter- 作為命名前綴&#xff0c;都位于org.…

第二課 運算符(day10)

第二課 運算符(day10) 一、運算符 結果是值 算數運算 a 10 * 10 賦值運算 a a 1 a1 結果是布爾值 比較運算 a 1 > 5 邏輯運算 a 1>6 or 11 成員運算 a "蚊" in "鄭建文" 二、基本數據類型 1、數值…

[pytorch、學習] - 4.2 模型參數的訪問、初始化和共享

參考 4.2 模型參數的訪問、初始化和共享 在3.3節(線性回歸的簡潔實現)中,我們通過init模塊來初始化模型的參數。我們也介紹了訪問模型參數的簡單方法。本節將深入講解如何訪問和初始化模型參數,以及如何在多個層之間共享同一份模型參數。 import torch from torch import nn…

spring-boot注解詳解(三)

1.SpringBoot/spring SpringBootApplication: 包含Configuration、EnableAutoConfiguration、ComponentScan通常用在主類上&#xff1b; Repository: 用于標注數據訪問組件&#xff0c;即DAO組件&#xff1b; Service: 用于標注業務層組件&#xff1b; RestController: 用于…

IEnumerableT和IQueryableT區分

哎&#xff0c;看了那么多&#xff0c;這個知識點還是得開一個文章 IQueryable和IEnumerable都是延時執行(Deferred Execution)的&#xff0c;而IList是即時執行(Eager Execution) IQueryable和IEnumerable在每次執行時都必須連接數據庫讀取&#xff0c;而IList讀取一次后&…

表的轉置 行轉列: DECODE(Oracle) 和 CASE WHEN 的異同點

異同點 都可以對表行轉列&#xff1b;DECODE功能上和簡單Case函數比較類似&#xff0c;不能像Case搜索函數一樣&#xff0c;進行更復雜的判斷在Case函數中&#xff0c;可以使用BETWEEN, LIKE, IS NULL, IN, EXISTS等等&#xff08;也可以使用NOT IN和NOT EXISTS&#xff0c;但是…

[pytorch、學習] - 4.4 自定義層

參考 4.4 自定義層 深度學習的一個魅力在于神經網絡中各式各樣的層,例如全連接層和后面章節將要用介紹的卷積層、池化層與循環層。雖然PyTorch提供了大量常用的層,但有時候我們依然希望自定義層。本節將介紹如何使用Module來自定義層,從而可以被重復調用。 4.4.1 不含模型參…

樹的存儲

父親表示法 顧名思義&#xff0c;就是只記錄每個結點的父結點。 int n; int p[MAX_N]; // 指向每個結點的父結點 孩子表示法 如上&#xff0c;就是只記錄每個結點的子結點。 int n; int cnt[MAX_N]; // 記錄每個結點的子結點的數量 int p[MAX_N][MAX_CNT]; // 指向每個結點的子…