深度學習4.4筆記

《動手學深度學習》-4.4-筆記

驗證數據集:通常是從訓練集中劃分出來的一部分數據,不要和訓練數據混在一起,評估模型好壞的數據集

測試數據集:只用一次的數據集

k-折交叉驗證(k-Fold Cross-Validation)是一種統計方法,用于評估和比較機器學習模型的性能。它通過將數據集分成k個子集(或“折”)來實現,每個子集都作為一次測試集,而剩余的k-1個子集則作為訓練集。這個過程會重復k次,每次選擇不同的子集作為測試集,最終將k次測試結果的平均值作為模型的性能評估。常用k=5/10,在沒有足夠多數據使用時。

總結:

訓練數據集:訓練模型參數

驗證數據集:選擇模型超參數

非大型數據集上通常使用k-折交叉驗證

欠擬合(Underfitting)

欠擬合是指模型對訓練數據的擬合程度不夠,無法捕捉到數據中的規律和模式。換句話說,模型過于簡單,無法很好地描述數據的特征。

過擬合(Overfitting)

過擬合是指模型對訓練數據擬合得過于完美,以至于模型在訓練數據上表現很好,但在新的、未見過的數據上表現很差。換句話說,模型過度學習了訓練數據中的噪聲和細節,而無法泛化到新的數據。

模型容量的定義

表示容量:模型的最大擬合能力,即通過調節參數,模型能夠表示的函數族

  1. 模型參數數量:參數越多,模型容量通常越高。

  2. 模型結構復雜度:例如,神經網絡的層數和每層的神經元數量。

  3. 數據復雜度:數據的復雜度(如樣本數量、特征數量)也會影響模型容量的選擇

模型容量與過擬合、欠擬合的關系

  • 容量不足:模型無法很好地擬合訓練數據,導致欠擬合。

  • 容量過高:模型可能會過度擬合訓練數據中的噪聲,導致過擬合

總結;

模型容量需要匹配數據復雜度,否則可能過擬合或欠擬合

統計機器學習提供數學工具來衡量模型復雜度?

代碼部分:
?

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

引入需要的庫

max_degree = 20  # 多項式的最大階數
n_train, n_test = 100, 100  # 訓練和測試數據集大小
true_w = np.zeros(max_degree)  # 分配大量的空間
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))#隨機生成200個樣本點(服從標準正態分布的x值)。
np.random.shuffle(features)#打亂樣本順序。
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的維度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)#加上噪音

分析:

生成一個多項式回歸的訓練/測試數據集。也就是說,我們在模擬一個“隱藏函數”,然后加一點噪聲,生成一些數據,來用于模型訓練。

多項式階數:我們打算生成最多20階的多項式數據(比如 1, x, x2, ..., x1?)。

true_w = np.zeros(max_degree)  # 創建一個長度為20的權重數組,初始值全是0
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

這一步設置了我們想要“模擬”的真實多項式模型的參數。它實際上模擬了一個三階多項式:

y = 5 + 1.2x - 3.4x2 + 5.6x3

其余的高階項(x? ~ x1?)的系數為0。

poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))


這一步是關鍵!構造一個 多項式特征矩陣。

假設 features = [[x1], [x2], ..., [x200]],
我們把它轉化為:
[[1, x1, x12, x13, ..., x1^19],
?[1, x2, x22, x23, ..., x2^19],
?...
]

for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!

這一步是做多項式特征的縮放處理,用的是數學中的Gamma函數

舉例:

  • gamma(1) = 0! = 1

  • gamma(2) = 1! = 1

  • gamma(3) = 2! = 2

  • gamma(4) = 3! = 6 ...

所以這是在做歸一化的處理,讓高階項不會變得太大

labels = np.dot(poly_features, true_w)

這一步是最核心的:根據我們設定的權重 true_w 計算標簽 y 值

可以理解為:
對每一行的多項式特征向量和權重向量做內積(點乘),
也就是:

  • 所以最終的標簽是:
    真實標簽 + 小范圍擾動

# NumPy ndarray轉換為tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]
#這句用列表推導式,把之前的 NumPy 數組全部 轉換成 PyTorch 的 tensor(張量)格式,這樣就可以用 PyTorch 來訓練模型啦!
features[:2], poly_features[:2, :], labels[:2]#這個不是賦值語句,而是查看前兩個樣本的輸入特征、多項式特征和標簽的值,

已經把 NumPy 的數組轉成了 PyTorch 的張量

def evaluate_loss(net, data_iter, loss):  #@save"""評估給定數據集上模型的損失"""metric = d2l.Accumulator(2)  # 損失的總和,樣本數量for X, y in data_iter:out = net(X)#前向傳播 + 計算損失 讓模型對輸入 X 做預測,得到輸出 outy = y.reshape(out.shape)l = loss(out, y)#計算預測結果和真實值之間的損失metric.add(l.sum(), l.numel())#計算預測結果和真實值之間的損失return metric[0] / metric[1]#計算預測結果和真實值之間的損失

評估模型在某個數據集(data_iter)上的平均損失

分析:

  • net: 模型(PyTorch 中定義的神經網絡)

  • data_iter: 數據迭代器(通常是訓練集或測試集的 DataLoader

  • loss: 損失函數(比如 nn.MSELoss()

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):#定義了一個訓練函數loss = nn.MSELoss(reduction='none')#均方誤差損失函數(MSE),但不求平均,保留每個樣本的損失值。input_shape = train_features.shape[-1]# 不設置偏置,因為我們已經在多項式中實現了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)#把訓練和測試數據打包成 DataLoader,方便模型一批一批訓練test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)#使用隨機梯度下降(SGD)優化模型參數,學習率為 0.01animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])#用 D2L 里的 Animator 動態繪圖類,記錄訓練過程的 loss 曲線for epoch in range(num_epochs):#訓練一個 epoch,用的是 D2L 中封裝好的 train_epoch_ch3(每輪完整訓練一遍所有 batch)d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:#每隔20輪(或第1輪),就評估一下訓練集和測試集上的平均損失,然后加到圖上animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())  #打印最終訓練得到的權重

PyTorch + 多項式特征訓練一個線性模型,并可視化訓練過程的

  • 用線性模型擬合你設計的多項式數據(多階特征)

  • 使用 MSELoss + SGD 訓練

  • 可視化訓練和測試集上的損失變化

  • 打印最終訓練好的模型參數,看看學得準不準

按書中的報錯然后,你調用了 l.backward() 來反向傳播,但這個 l 是一個 不需要梯度 的張量(requires_grad=False),所以無法反向傳播!

還是之前的做法:

loss = nn.MSELoss(reduction='none')返回的是一個 每個樣本的損失 的張量,而不是所有樣本損失的平均或總和。

看看 train_epoch_ch3 的定義),它里面可能是直接用了 l = loss(y_hat, y),然后 l.backward()

修改后

正常:

欠擬合

如果用不同復雜度的模型來擬合這個函數,表現會怎樣?

# 只用 1 和 x 兩項(線性模型)
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

這意味著你在訓練一個線性模型

這個模型完全忽略了二階項 x2 和三階項 x3,所以它根本學不出原來的復雜模式

結果就是:

  • 訓練損失很高

  • 測試損失也高

  • 模型欠擬合:學得太簡單,跟不上真實的非線性函數

# 使用與真實模型相同的特征階數
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

?這次你用了前4項:

注意:你訓練的時候也會擬合這幾個特征,也就是:

而我們真實函數 y = 5 + 1.2x - 3.4x^2 + 5.6x^3,剛好就是3階多項式

  • 訓練損失下降得更快

  • 最終損失更低

  • 模型可以很好地擬合數據,不欠擬合也不過擬合

?

poly_features[:, :] 表示使用 所有20階的多項式特征,也就是:

  • 訓練了一個 20維輸入的線性模型

  • 訓練次數設為 1500 輪(比前面更多)

  • 但現在用一個 包含20階的模型 去擬合這些數據,雖然原函數只有3階,后面17個高階項都是“多余的”。

  • 訓練集表現很好(損失很低),但在測試集上 泛化能力變差

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/73808.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/73808.shtml
英文地址,請注明出處:http://en.pswp.cn/web/73808.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue 兩種路由模式

一、兩種模式比較 在vue.js中,路由模式分為兩種:hash 模式和 history 模式。這兩種模式決定了URL的結構和瀏覽器歷史記錄的管理方式。 1. hash 模式帶 #,#后面的地址變化不會引起頁面的刷新。換句話說,hash模式不會將#后面的地址…

Android生態大變革,谷歌調整開源政策,核心開發不再公開

“開源”這個詞曾經是Android的護城河,如今卻成了谷歌的燙手山芋。最近谷歌宣布調整Android的開源政策,核心開發將全面轉向私有分支。翻譯成人話就是:以后Android的核心更新,不再公開共享了。 這操作不就是開源變節嗎,…

JavaScript中集合常用操作方法詳解

JavaScript中集合常用操作方法詳解 JavaScript中的集合主要包括數組(Array)、集合(Set)和映射(Map)。下面我將詳細介紹這些集合類型的常用操作方法。 數組(Array) 數組是JavaScript中最常用的集合類型,提供了豐富的操作方法。 創建數組 // 字面量創建 const ar…

【HC-05】藍牙串口通信模塊調試與應用(1)

一、HC-05 基礎學習視頻 HC-05藍牙串口通信模塊調試與應用1 二、HC-05學習視頻課件

【學Rust寫CAD】18 定點數2D仿射變換矩陣結構體(MatrixFixedPoint結構別名)

源碼 // matrix/fixed.rs use crate::fixed::Fixed; use super::generic::Matrix;/// 定點數矩陣類型別名 pub type MatrixFixedPoint Matrix<Fixed, Fixed, Fixed, Fixed, Fixed, Fixed>;代碼解析 這段代碼定義了一個定點數矩陣的類型別名 MatrixFixedPoint&#xff…

axios文件下載使用后端傳遞的名稱

java后端通過HttpServletResponse 返回文件流 在Content-Disposition中插入文件名 一定要設置Access-Control-Expose-Headers&#xff0c;代表跨域該Content-Disposition返回Header可讀&#xff0c;如果沒有&#xff0c;前端是取不到Content-Disposition的&#xff0c;可以在統…

HarmonyOS之深入解析如何根據url下載pdf文件并且在本地顯示和預覽

一、文件下載 ① 網絡請求配置 下載在線文件&#xff0c;需要訪問網絡&#xff0c;因此需要在 config.json 中添加網絡權限&#xff1a; {"module": {"requestPermissions": [{"name": "ohos.permission.INTERNET","reason&qu…

鴻蒙前后端項目源碼-點餐v3.0-原創!原創!原創!

鴻蒙前后端點餐項目源碼含文檔ArkTS語言. 原創作品.我半個月寫的原創作品&#xff0c;請尊重原創。 原創作品&#xff0c;盜版必究&#xff01;&#xff01;&#xff01;&#xff01; 原創作品&#xff0c;盜版必究&#xff01;&#xff01;&#xff01;&#xff01; 原創作…

VUE3+TypeScript項目,使用html2Canvas+jspdf生成PDF并實現--分頁--頁眉--頁尾

使用html2CanvasJsPDF生成pdf&#xff0c;并實現分頁添加頁眉頁尾 1.封裝方法htmlToPdfPage.ts /**path: src/utils/htmlToPdfPage.tsname: 導出頁面為PDF格式 并添加頁眉頁尾 **/ /*** 封裝思路* 1.將頁面根據A4大小分隔邊距&#xff0c;避免內容被中間截斷* 所有元素層級不要…

5.Excel:從網上獲取數據

一 用 Excel 數據選項卡獲取數據的方法 連接。 二 要求獲取實時數據 每1分鐘自動更新數據。 A股市場_同花順行情中心_同花順財經網 用上面方法將數據加載進工作表中。 在表格內任意區域右鍵&#xff0c;刷新。 自動刷新&#xff1a; 三 缺點 Excel 只能爬取網頁上表格類型的…

《深度剖析SQL之WHERE子句:數據過濾的藝術》

在當今數據驅動的時代&#xff0c;數據處理和分析能力已成為職場中至關重要的技能。SQL作為一種強大的結構化查詢語言&#xff0c;在數據管理和分析領域占據著核心地位。而WHERE子句&#xff0c;作為SQL中用于數據過濾的關鍵組件&#xff0c;就像是一把精準的手術刀&#xff0c…

華為eNSP-配置靜態路由與靜態路由備份

一、靜態路由介紹 靜態路由是指用戶或網絡管理員手工配置的路由信息。當網絡拓撲結構或者鏈路狀態發生改變時&#xff0c;需要網絡管理人員手工修改靜態路由信息。相比于動態路由協議&#xff0c;靜態路由無需頻繁地交換各自的路由表&#xff0c;配置簡單&#xff0c;比較適合…

Docker 快速入門指南

Docker 快速入門指南 1. Docker 常用指令 Docker 是一個輕量級的容器化平臺&#xff0c;可以幫助開發者快速構建、測試和部署應用程序。以下是一些常用的 Docker 命令。 1.1 鏡像管理 # 搜索鏡像 docker search <image_name># 拉取鏡像 docker pull <image_name>…

基礎認證-單選題(一)

單選題 1、下列關于request方法和requestlnStream方法說法錯誤的是(C) A 都支持取消訂閱響應事件 B 都支持訂閱HTTP響應頭事件 C 都支持HttpResponse返回值類型 D 都支持傳入URL地址和相關配置項 2、如需修改Text組件文本的透明度可通過以下哪個屬性方法進行修改 (C) A dec…

Logback使用和常用配置

Logback 是 Spring Boot 默認集成的日志框架&#xff0c;相比 Log4j&#xff0c;它性能更高、配置更靈活&#xff0c;并且天然支持 Spring Profile 多環境配置。以下是詳細配置步驟及常用配置示例。 一、添加依賴&#xff08;非 Spring Boot 項目&#xff09; 若項目未使用 Sp…

MySQL基礎語法DDLDML

目錄 #1.創建和刪除數據庫 ?#2.如果有lyt就刪除,沒有則創建一個新的lyt #3.切換到lyt數據庫下 #4.創建數據表并設置列及其屬性,name是關鍵詞要用name包圍 ?編輯 #5.刪除數據表 #5.查看創建的student表 #6.向student表中添加數據,數據要與列名一一對應 #7.查詢studen…

在windows下安裝windows+Ubuntu16.04雙系統(下)

這篇文章的內容主要來源于這篇文章&#xff0c;為正式安裝windowsUbuntu16.04雙系統部分。在正式安裝前&#xff0c;若還沒有進行前期準備工作&#xff08;1.分區2.制作啟動u盤&#xff09;&#xff0c;見《在windows下安裝windowsUbuntu16.04雙系統(上)》 二、正式安裝Ubuntu …

Ubuntu24.04 離線安裝 MySQL8.0.41

一、環境準備 1.1 官方下載MySQL8.0.41 完整包 1.2 上傳包 & 解壓 上傳包名稱是&#xff1a;mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundle.tar # 切換到上傳目錄 cd /home/MySQL8 # 解壓&#xff1a; tar -xvf mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundl…

記錄一次Dell服務器更換內存條報錯解決過程No memory found

文章目錄 問題問題分析解決流程總結 問題 今天給服務器添加了幾個內存條&#xff0c;開啟后報錯 No memory found No useable DlMMs found. Verify the DlMMsare properly seated and that they are installed in the correct sockets. 問題分析 這個錯誤說明服務器在啟動時沒…

Apache HttpClient使用

一、Apache HttpClient 基礎版 HttpClients 是 Apache HttpClient 庫中的一個工具類&#xff0c;用于創建和管理 HTTP 客戶端實例。Apache HttpClient 是一個強大的 Java HTTP 客戶端庫&#xff0c;用于發送 HTTP 請求并處理 HTTP 響應。HttpClients 提供了多種方法來創建和配…