21.過擬合和欠擬合示例

1. 背景介紹

在機器學習和深度學習中,過擬合和欠擬合是兩個非常重要的概念。過擬合指的是模型在訓練數據上表現很好,但在新的測試數據上效果變差的情況。欠擬合則是指模型無法很好地擬合訓練數據的情況。這兩種情況都會導致模型無法很好地泛化,影響最終的預測和應用效果。

為了幫助大家更好地理解過擬合和欠擬合的概念及其應對方法,我將通過一個基于PyTorch的代碼示例來演示這兩種情況的具體表現。我們將生成一個拋物線數據集,并定義三種不同復雜度的模型,分別對應欠擬合、正常擬合和過擬合的情況。通過可視化訓練和測試誤差的曲線圖,以及預測結果的散點圖,我們可以直觀地觀察到這三種情況下模型的擬合效果。

2. 核心概念與聯系

過擬合和欠擬合是機器學習和深度學習中兩個相互對應的概念:

1. 過擬合(Overfitting): 模型在訓練數據上表現很好,但在新的測試數據上效果變差的情況。這通常是由于模型過于復雜,過度擬合了訓練數據中的噪聲和細節,導致無法很好地推廣到未知數據。

2. 欠擬合(Underfitting): 模型無法很好地擬合訓練數據的情況。這通常是由于模型過于簡單,無法捕捉訓練數據中的復雜模式和關系。

這兩種情況都會導致模型在實際應用中無法很好地泛化,因此需要采取相應的措施來防止和緩解過擬合和欠擬合。常見的應對方法包括:

- 增加訓練樣本數量
- 減少模型復雜度(比如調整網絡層數、神經元個數等)
- 使用正則化技術(如L1/L2正則化、Dropout等)
- 調整超參數(如學習率、批量大小等)
- 特征工程(如特征選擇、降維等)

通過合理的模型設計和超參數調優,我們可以尋找到一個恰當的模型復雜度,使其既能很好地擬合訓練數據,又能在新數據上保持良好的泛化性能。這就是機器學習中的**bias-variance tradeoff**,也是我們在實際應用中需要權衡的一個關鍵點。

?3. 核心算法原理和具體操作步驟

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成數據
np.random.seed(42)
X = np.random.uniform(-5, 5, 500)
y = X**2 + 1 + np.random.normal(0, 1, 500)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 定義三種不同復雜度的模型
class UnderFitModel(nn.Module):def __init__(self):super(UnderFitModel, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)class NormalFitModel(nn.Module):def __init__(self):super(NormalFitModel, self).__init__()self.fc1 = nn.Linear(1, 8)self.fc2 = nn.Linear(8, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)return xclass OverFitModel(nn.Module):def __init__(self):super(OverFitModel, self).__init__()self.fc1 = nn.Linear(1, 32)self.fc2 = nn.Linear(32, 32)self.fc3 = nn.Linear(32, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)x = self.activation(x)x = self.fc3(x)return x# 訓練模型并記錄誤差
def train_and_evaluate(model, train_loader, test_loader):optimizer = torch.optim.SGD(model.parameters(), lr=0.005)criterion = nn.MSELoss()train_losses = []test_losses = []for epoch in range(100):model.train()train_loss = 0.0for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)train_losses.append(train_loss)model.eval()test_loss = 0.0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)loss = criterion(outputs, targets)test_loss += loss.item()test_loss /= len(test_loader)test_losses.append(test_loss)return train_losses, test_losses# 訓練三種模型并可視化
under_fit_model = UnderFitModel()
normal_fit_model = NormalFitModel()
over_fit_model = OverFitModel()under_fit_train_losses, under_fit_test_losses = train_and_evaluate(under_fit_model, train_loader, test_loader)
normal_fit_train_losses, normal_fit_test_losses = train_and_evaluate(normal_fit_model, train_loader, test_loader)
over_fit_train_losses, over_fit_test_losses = train_and_evaluate(over_fit_model, train_loader, test_loader)plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(under_fit_train_losses, label='Under-fit Train Loss')
plt.plot(under_fit_test_losses, label='Under-fit Test Loss')
plt.plot(normal_fit_train_losses, label='Normal-fit Train Loss')
plt.plot(normal_fit_test_losses, label='Normal-fit Test Loss')
plt.plot(over_fit_train_losses, label='Over-fit Train Loss')
plt.plot(over_fit_test_losses, label='Over-fit Test Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training and Test Loss Curves')
plt.legend()plt.subplot(1, 2, 2)
plt.scatter(X_test, y_test, label='True')
plt.scatter(X_test, under_fit_model(X_test).detach().numpy(), label='Under-fit Prediction')
plt.scatter(X_test, normal_fit_model(X_test).detach().numpy(), label='Normal-fit Prediction')
plt.scatter(X_test, over_fit_model(X_test).detach().numpy(), label='Over-fit Prediction')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Test Set Predictions')
plt.legend()plt.show()

這個代碼示例涵蓋了我們之前討論的各個步驟:

數據生成: 我們生成了一個拋物線形狀的數據集,并使用train_test_split函數將其劃分為訓練集和測試集。
模型定義: 我們定義了三種不同復雜度的PyTorch模型,分別對應欠擬合、正常擬合和過擬合的情況。
訓練與評估: 我們實現了一個train_and_evaluate函數,該函數負責訓練模型并記錄訓練集和測試集上的損失。
可視化: 最后,我們使用matplotlib繪制了訓練損失和測試損失的曲線圖,以及在測試集上的預測結果。

欠擬合模型:訓練誤差和測試誤差都較大,說明模型無法很好地擬合數據。在測試集上的預測結果也存在較大偏差。
正常擬合模型:訓練誤差和測試誤差較為接近,說明模型的擬合效果較好。在測試集上的預測也比較準確。
過擬合模型:訓練誤差很小,但測試誤差較大,說明模型在訓練集上表現很好,但在新數據上泛化能力較差。在測試集上的預測結果存在一定偏差。
通過這個實例,我們可以直觀地觀察到不同復雜度模型在訓練和泛化性能上的差異。欠擬合模型在訓練集和測試集上的損失都較大,說明模型無法很好地擬合數據。正常擬合模型在訓練集和測試集上的損失較為接近,說明模型具有較好的泛化能力。而過擬合模型在訓練集上的損失很小,但在測試集上的損失較大,說明模型過于復雜,在新數據上泛化性能較差。

通過這種觀察訓練誤差和測試誤差的方法,我們可以及時發現模型存在的問題,并針對性地調整模型結構、添加正則化等手段來優化模型性能。這是機器學習和深度學習中非常基礎和重要的實踐技能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/22297.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/22297.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/22297.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

視頻號小店,常見的違規條例!98%的商家必犯的違規細節!

哈嘍~我是電商月月 做電商,不管哪個平臺都有屬于自己的規則條例,這些違規細節,一定要提前了解 所以今天,月月就給大家分享一下,做視頻號小店的話,有哪些常見的違規細節 這里我們分三點講解 一&#xff…

【分享】兩種方法禁止修改Word文檔

對于比較重要的Word文件,不想被隨意編輯修改,可以試試以下兩個方法,不清楚的小伙伴,一起來看看吧! 方法1:設置“只讀方式” 我們可以給Word文檔設置以“只讀方式”打開,這樣就算編輯修改了文檔…

如何通過SD-WAN提升企業溝通效率

在數字化飛速發展的今天,企業對大數據和實時商業數據傳輸的需求日益增長。傳統的專線連接技術已無法滿足企業對快速部署商業應用和高效網絡連接的需求。在這種背景下,SD-WAN成為提升企業網絡溝通效率的關鍵技術。 SD-WAN的靈活部署模式 SD-WAN提供了高度…

6月軟考新通知:24下集成大概率是中級蕞簡單的一門

2024下半年軟考6月新通知: 一、24下軟考考試時間安排: 24下半年軟考報名時間:8月19日-9月15日 24下半年軟考考試時間:11月9-12日 24下半年軟考成績查詢:12月中(預計) 二、考情分析 24上軟考…

09_JavaWeb會話

1.會話 HTTP是一種無狀態協議; HTTP協議對于發送過請求或者響應都不做持久化處理具體來說就是客戶端發送請求,服務器接收請求,但是服務器自身不會記錄每一條請求都是由哪一個客戶端發出的; 會話管理是通過Cookie和Session配合解…

【排序】插入排序,希爾排序

前面我們講述了冒泡排序和選擇排序,我們本章講的排序方法是插入排序,插入排序是希爾排序實現的基礎函數,大家一定要好好理解插入排序的邏輯,這樣才能在后面學習希爾排序的時候,更容易的去理解,我們直接開始…

關于無法通過腳本啟動Kafka集群的解決辦法

啟動Kafka集群時,需要在每臺個節點上啟動啟動服務,比較麻煩,通過寫了以下腳本來進行啟停;發現能正常使用停止功能,不能正常啟動Kafka; Kafka啟停腳本: ## 以防不能通過shell腳本啟動Kafka服務…

富格林:揭露黑幕平臺保障安全

富格林指出,很多黑幕平臺都會將自己包裝得光鮮亮麗后,再出來誘惑投資者,使得投資者資金安全得不到保障,有苦說不出。富格林表示,黑幕平臺的套路其實是非常常見的,只要投資者熟知并能夠分辨出,就…

C盤擴容——只能刪除C盤右邊的磁盤對C盤進行擴展

winR彈出命令框 輸入:compmgmt.msc 進入磁盤管理頁面 注意:被刪除盤如果有重要數據信息,請備份。 或者刪除之前轉移至其他盤,否則刪除之后,則無法找回。 尤其是安裝的軟件。 規范安裝目錄十分重要。 將C盤右邊的磁盤&a…

最全 Inno Setup 教程-[FILE] Flag參數

【1】此參數是一個附加選項的集合。可以使用空格將多個選項分隔開。 【2】支持以下選項: 32位 當在“Source”和“DestDir”參數中使用{sys}常量時,將該常量映射到32位系統目錄。將“regserver”和“regtypelib”標志設置為將文件視為32位,…

安防綜合管理系統EasyCVR視頻匯聚平臺GA/T 1400協議中的關鍵消息交互示例

在當今的信息化時代,公共安全防范日益成為保障社會和諧穩定的關鍵。視頻監控系統作為現代安全防范的重要手段,正不斷在公安、交通、城市管理等領域發揮著越來越重要的作用。而GA/T 1400協議視圖庫,作為公安視頻圖像信息應用系統的標準&#x…

Vue3 子組件訪問父組件的方法 - 父組件訪問子組件的屬性或方法 - 子組件修改父組件的值

一。子組件訪問父組件的方法 //父組件 <DialogEditing close-dialog"handleClose" /> const handleClose () > {};//子組件 const emit defineEmits(["closeDialog"]); const close () > {emit("closeDialog"); // 使用 };二。父…

健身日記之倒立俯臥撐學習——起始日2024.6.4

文章目錄 前言 自我介紹 昔日計劃 新目標計劃 瓶頸突破嘗試 參考視頻及文章 前言 有輕微健身基礎&#xff0c;正式接觸街健五大神技&#xff0c;立志在兩年內解鎖全部&#xff0c;將有機會的進行日常訓練和目標肌群鍛煉&#xff0c;這里向大家展示我的計劃和安排&#xf…

opencv-python(五)

opencv的顏色通道中順序是B&#xff0c;G&#xff0c;R。 圖像屬性 import cv2img cv2.imread(jk.jpg) print(fshape{img.shape}) print(fsize{img.size}) print(fdtype{img.dtype}) shape&#xff1a;圖像像素的行&#xff0c;列&#xff0c;通道 size&#xff1a;行數 X …

YonSuite收款通,助力企業618更快收款

隨著電商節日“618”的臨近&#xff0c;各大企業紛紛摩拳擦掌&#xff0c;準備在這場年中大促中大展身手。然而&#xff0c;隨著銷售額的激增&#xff0c;收款管理問題也愈發凸顯&#xff0c;成為制約企業快速發展的重要瓶頸。在這個關鍵時刻&#xff0c;YonSuite收款通憑借其卓…

Python實現登錄到遠程主機,然后在遠程主機上繼續連接遠程主機

實現功能 登錄到遠程主機&#xff0c;然后在遠程主機上繼續連接遠程主機&#xff0c;執行命令。 import paramiko import time# 第二個遠程主機的連接信息&#xff08;在第一個遠程主機上執行SSH連接時使用&#xff09; second_remote_host 192.168.xx.xxx # 創建SSH客…

通過命令行將tar壓縮文件解壓縮到指定目錄|Linux

要將all.tar文件解壓縮到指定目錄下&#xff0c;你可以使用Linux命令行中的tar命令。以下是具體步驟&#xff1a; 打開終端&#xff08;Terminal&#xff09;。 使用cd命令切換到你想要解壓縮文件的目標目錄。例如&#xff1a; cd /path/to/your/directory將/path/to/your/dir…

echarts圖例formatter配置添加百分比

echarts圖例如何添加百分比 const pieChart async () > {const myChart echarts.init(piepic.value)const piedata await getPieData(); // 等待數據返回myChart.setOption({title: {},grid: {},tooltip: {trigger: item,},legend: {top: middle,align:left,icon: circl…

都可以寫好后端接口

在后端工程師的日常開發中&#xff0c;我們都曾想過 怎么設計一個良好的接口呢&#xff1f;需要考慮的點有哪些。來 給您。 1、請求參數校驗 這個是大家都能想到的&#xff0c;也是一個良好的接口必備的前提條件&#xff0c;通過入參的校驗我們可以過濾掉許多無效的請求&…

零基礎學Java第二十七天之前端-HTML5詳解

前端-HTML5詳解 一、概述 HTML5是HTML的第五個版本&#xff0c;它對HTML進行了許多改進和擴展&#xff0c;使得網頁開發更加豐富和便利。HTML5是Web標準的重要組成部分&#xff0c;旨在提高瀏覽器兼容性&#xff0c;統一網頁開發標準。HTML5不僅包括了HTML的基本元素和標簽&am…