【深度學習】神經網絡過擬合與欠擬合-part5

八、過擬合與欠擬合

訓練深層神經網絡時,由于模型參數較多,數據不足的時候容易過擬合,正則化技術就是防止過擬合,提升模型的泛化能力和魯棒性 (對新數據表現良好 對異常數據表現良好)

1、概念

1.1過擬合

在訓練模型數據擬合能力很強,表現很好,在測試數據上表現很差

原因

數據不足;模型太復雜;正則化強度不足

1.2 欠擬合

模型學習能力不足,無法捕捉數據中的關系

?1.3 如何判斷

過擬合:訓練時候誤差低,驗證時候誤差高 說明過度擬合了訓練數據中的噪聲或特定模式

欠擬合:訓練和測試的誤差都高,說明模型太簡單,無法捕捉到復雜模式

2.解決欠擬合

增加模型復雜度,增加特征,減少正則化強度,訓練更長時間

3.解決過擬合

考慮損失函數,損失函數的目的是使預測值與真實值無限接近,如果在原來的損失函數上添加一個非0的變量

其中f(w)是關于權重w的函數,f(w)>0

要使L1變小,就要使L變小的同時,也要使f(w)變小。從而控制權重w在較小的范圍內。

3.1 L2正則化

L2在損失函數中添加權重參數的平方和來實現,目標是懲罰過大的參數

3.1.1 數字表示

損失函數L(tθ),其中θ表示權重參數,加入L2正則化后

其中:

  • L(θ) 是原始損失函數(比如均方誤差、交叉熵等)。

  • λ是正則化強度,控制正則化的力度。

  • θi 是模型的第 i 個權重參數。

  • 是所有權重參數的平方和,稱為 L2 正則化項。

3.1.2 梯度更新

L2正則下,梯度更新時,不僅考慮原始損失函數梯度,還要考慮正則化的影響

其中:

η 是學習率。

是損失函數關于參數 \theta_t 的梯度。

是 L2 正則化項的梯度,對應的是參數值本身的衰減。

很明顯,參數越大懲罰力度就越大,從而讓參數逐漸趨向于較小值,避免出現過大的參數。

3.1.4 代碼
import torch 
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 設置支持中文的字體
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定字體為黑體
plt.rcParams['axes.unicode_minus'] = False   # 解決負號顯示問題
# 設置種子
torch.manual_seed(42)# 隨機數據
n_samples = 100
n_features = 20
x = torch.randn(n_samples, n_features)
y = torch.randn(n_samples,1)# 定義全鏈接神經網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet,self).__init__()self.fc1 = nn.Linear(n_features,50)self.fc2 = nn.Linear(50,1)def forward(self,x):x = torch.relu(self.fc1(x))return self.fc2(x)
# 訓練函數
def train_model(use_l2 = False,weight_decay = 0.01,n_epoches = 100):model = SimpleNet()criterion = nn.MSELoss()# 選擇優化器if use_l2:optimizer = optim.SGD(model.parameters(),lr = 0.01,weight_decay = weight_decay)else:optimizer = optim.SGD(model.parameters(),lr = 0.01,)train_losses = []for epoch in range(n_epoches):optimizer.zero_grad()outputs = model(x)loss = criterion(outputs,y)loss.backward()optimizer.step()train_losses.append(loss.item())if(epoch+1) % 10 == 0:print(F":Epoches[{epoch+1}/{n_epoches},Loss:{loss.item():.4f}]")return train_losses# 訓練比較兩種模型
train_losses_no_l2 = train_model(use_l2 = False)
train_losses_with_l2 = train_model(use_l2 = True,weight_decay=0.01)# 繪制損失曲線
plt.plot(train_losses_no_l2,label = "沒有L2正則化")
plt.plot(train_losses_with_l2,label = "有L2正則化")
plt.xlabel("Epoch")
plt.ylabel("損失")
plt.title('L2正則化vs無正則化')
plt.legend()
plt.show()
?

3.2 L1正則化

通過在損失函數中添加權重參數的絕對值來之和來約束模型復雜度

3.2.1 數學表示

設模型的原始損失函數為 L(θ),其中θ表示模型權重參數,則加入 L1 正則化后的損失函數表示為:

其中:

  • L(θ) 是原始損失函數。

  • λ是正則化強度,控制正則化的力度。

  • |θi| 是模型第i 個參數的絕對值。

  • 是所有權重參數的絕對值之和,這個項即為 L1 正則化項。

3.2.2 梯度更新

L1的正則化下 梯度更新公式

其中:

  • η是學習率。

  • 是損失函數關于參數 \theta_t 的梯度。

  • 是參數 \theta_t 的符號函數,表示當 \theta_t 為正時取值為 1,為負時取值為 -1,等于 0 時為 0。

L1正則化依賴參數的絕對值,梯度更新不說簡單的線性縮小,而是通過符號函數來調整參數的方向,這就是為什么L1正則化促使參數變為0

3.2.3 作用
  1. 稀疏性:L1 正則化的一個顯著特性是它會促使許多權重參數變為 。這是因為 L1 正則化傾向于將權重絕對值縮小到零,使得模型只保留對結果最重要的特征,而將其他不相關的特征權重設為零,從而實現 特征選擇 的功能。

  2. 防止過擬合:通過限制權重的絕對值,L1 正則化減少了模型的復雜度,使其不容易過擬合訓練數據。相比于 L2 正則化,L1 正則化更傾向于將某些權重完全移除,而不是減小它們的值。

  3. 簡化模型:由于 L1 正則化會將一些權重變為零,因此模型最終會變得更加簡單,僅依賴于少數重要特征。這對于高維度數據特別有用,尤其是在特征數量遠多于樣本數量的情況下。

  4. 特征選擇:因為 L1 正則化會將部分權重置零,因此它天然具有特征選擇的能力,有助于自動篩選出對模型預測最重要的特征。

3.2.4 與L2對比

L1 正則化 更適合用于產生稀疏模型,會讓部分權重完全為零,適合做特征選擇。

L2 正則化 更適合平滑模型的參數,避免過大參數,但不會使權重變為零,適合處理高維特征較為密集的場景。

3.3 Dropout

每次訓練迭代中,一部分神經元被丟棄(p為丟棄概率)

被選中的神經元不參與傳播

在測試階段,所有的神經元都參與計算,但對權重進行縮放(1-p),以保持輸出的期望值一致

Dropout是一種訓練過程中隨機丟棄部分神經元的計算,減少神經元之間的依賴防止模型過于復雜,避免過擬合

3.3.1 實現
import torch
import torch.nn as nndropout = nn.Dropout(p=0.5)
x = torch.randint(0, 10, (5, 6),dtype=torch.float)
print(x)print(dropout(x))

Dropout過程:

按照指定概率把部分神經元值設為0

為避免該操作帶來的影響,需要對非0的元素使用縮放因子1/(1-p)進行強化

假設某個神經元的輸出為 x,Dropout 的操作可以表示為:

  • 在訓練階段:

  • 在測試階段:
    y=x

為什么要使用縮放因子1/(1-p)?

在訓練階段,Dropout 會以概率 p隨機將某些神經元的輸出設置為 0,而以概率 1?p 保留這些神經元。

假設某個神經元的原始輸出是 x,那么在訓練階段,它的期望輸出值為:

通過這種縮放,訓練階段的期望輸出值仍然是 x,與沒有 Dropout 時一致。

3.3.2 權重影響
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import os
from matplotlib import pyplot as plt
torch.manual_seed(42)def load_img(path,resize = (224,224)):pil_img = Image.open(path).convert('RGB')print("Original Image Size: ", pil_img.size)transform = transforms.Compose([transforms.Resize(resize),transforms.ToTensor()])return transform(pil_img)# dirpath = os.path.dirname(__file__)
# path = os.path.join(dirpath,"img","100.jpg")
path = "./100.jpg"trans_img = load_img(path)
trans_img = trans_img.unsqueeze(0)
dropout = nn.Dropout2d(p=0.2)
drop_img = dropout(trans_img)
trans_img = trans_img.squeeze(0).permute(1,2,0).numpy()
drop_img = drop_img.squeeze(0).permute(1,2,0).numpy()drop_img = drop_img.clip(0,1)
fig = plt.figure(figsize = (10,5))ax1 = fig.add_subplot(1,2,1)
ax1.imshow(trans_img)ax2 = fig.add_subplot(1,2,2)
ax2.imshow(drop_img)plt.show()

?

nn.Dropout2d(p):Dropout2d 是針對二維數據設計的 Dropout 層,它在訓練過程中隨機將輸入張量的某些通道(二維平面)置為零。

參數要求格式示例形狀說明
輸入(N, C, H, W)(16, 64, 32, 32)批大小×通道×高×寬
輸出(N, C, H, W)(16, 64, 32, 32)與輸入同形,部分通道歸零

3.5 數據增強

?樣本不足時過擬合的常見原因之一

  • 當訓練數據過少時,模型容易“記住”有限的樣本(包括噪聲和無關細節),而非學習通用的規律。

  • 簡單模型更可能捕捉真實規律,但數據不足時,復雜模型會傾向于擬合訓練集中的偶然性模式(噪聲)。

  • 樣本不足時,訓練集的分布可能與真實分布偏差較大,導致模型學到錯誤的規律。

  • 小數據集中,個別樣本的噪聲(如標注錯誤、異常值)會被放大,模型可能將噪聲誤認為規律。

數據增強的好處:

大幅度降低數據采集和標注成本;

降低過擬合風險,提高模型泛化能力

transforms:

常用變換類

transforms.Compose:將多個變換操作組合成一個流水線。

transforms.ToTensor:將 PIL 圖像或 NumPy 數組轉換為 PyTorch 張量,將圖像數據從 uint8 類型 (0-255) 轉換為 float32 類型 (0.0-1.0)。

transforms.Normalize:對張量進行標準化。

transforms.Resize:調整圖像大小。

transforms.CenterCrop:從圖像中心裁剪指定大小的區域。

transforms.RandomCrop:隨機裁剪圖像。

transforms.RandomHorizontalFlip:隨機水平翻轉圖像。

transforms.RandomVerticalFlip:隨機垂直翻轉圖像。

transforms.RandomRotation:隨機旋轉圖像。

transforms.ColorJitter:隨機調整圖像的亮度、對比度、飽和度和色調。

transforms.RandomGrayscale:隨機將圖像轉換為灰度圖像。

transforms.RandomResizedCrop:隨機裁剪圖像并調整大小。

3.5.1 圖片縮放
from PIL import Image
img1 = plt.imread('./img/100.jpg')
print(img1.shape)
plt.imshow(img1)
plt.show()img = Image.open('./img/100.jpg')
transorm = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
r_img = transorm(img)
print(r_img.shape)r_img = r_img.permute(1,2,0)plt.imshow(r_img)
plt.show()

?

?

3.5.2 隨機裁剪
# 裁剪
img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.RandomCrop(size=(224, 224)),transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)
plt.show()

?

3.5.3 隨機水平翻轉
img = Image.open("./img/100.jpg")
transform = transforms.Compose([transforms.RandomHorizontalFlip(p=1),transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)
plt.show()

?

3.5.4 調整圖片顏色

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

brightness表示亮度:

float或者(min,max)

    • 如果是 float(如 brightness=0.2),則亮度在 [max(0, 1 - 0.2), 1 + 0.2] = [0.8, 1.2] 范圍內隨機縮放。

    • 如果是 (min, max)(如 brightness=(0.5, 1.5)),則亮度在 [0.5, 1.5] 范圍內隨機縮放。

contrast:

  • 對比度調整的范圍。

  • 格式與 brightness 相同。

saturation:

  • 飽和度調整的范圍。

  • 格式與 brightness 相同。

hue:

  • 色調調整的范圍。

  • 可以是一個浮點數(表示相對范圍)或一個元組 (min, max)。

  • 取值范圍必須為 [-0.5, 0.5](因為色相在 HSV 色彩空間中是循環的,超出范圍會導致顏色異常)。

  • 例如,hue=0.1 表示色調在 [-0.1, 0.1] 之間隨機調整。

img = Image.open("./img/100.jpg")transform = transforms.Compose([transforms.ColorJitter(brightness=0.2,contrast= 0.2,saturation= 0.2,hue= 0.2),transforms.ToTensor()])
r_img = transform(img)
print(r_img.shape)
r_img = r_img.permute(1,2,0)plt.imshow(r_img)
plt.show()

3.5.5 隨機旋轉

RandomRotation用于對圖像進行隨機旋轉。

transforms.RandomRotation(degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0
)

degrees:

  • 旋轉角度的范圍,可以是一個浮點數或元組 (min_degree, max_degree)。

  • 例如,degrees=30 表示旋轉角度在 [-30, 30] 之間隨機選擇。

  • 例如,degrees=(30, 60) 表示旋轉角度在 [30, 60] 之間隨機選擇。

interpolation:

  • 插值方法,用于旋轉圖像。

  • 默認是 InterpolationMode.NEAREST(最近鄰插值)。

  • 其他選項包括 InterpolationMode.BILINEAR(雙線性插值)、InterpolationMode.BICUBIC(雙三次插值)等。

expand:

  • 是否擴展圖像大小以適應旋轉后的圖像。如:當需要保留完整旋轉后的圖像時(如醫學影像、文檔掃描)

  • 如果為 True,旋轉后的圖像可能會比原始圖像大。

  • 如果為 False,旋轉后的圖像大小與原始圖像相同。

center:

  • 旋轉中心點的坐標,默認為圖像中心。

  • 可以是一個元組 (x, y),表示旋轉中心的坐標。

fill:

  • 旋轉后圖像邊緣的填充值。

  • 可以是一個浮點數(用于灰度圖像)或一個元組(用于 RGB 圖像)。默認填充0(黑色)

image = Image.open("./img/100.jpg")
transform = transforms.RandomRotation(degrees=90)rotated_image = transform(image)plt.imshow(rotated_image)
plt.axis('off')
plt.show()

?

3.5.6 圖片轉Tensor
import torch
from PIL import Image
from torchvision import transforms
import osimg = Image.open('./img/100.jpg')
transform = transforms.ToTensor()
img_tensor = transform(img)
print(img_tensor)

?

3.5.7 Tensor轉圖片
# img_tensor = torch.rand(3, 224, 224)
img = Image.open('./img/100.jpg')
transform1 = transforms.ToTensor()
img_tensor = transform1(img)transform2 = transforms.ToPILImage()
img = transform2(img_tensor)plt.imshow(img)
plt.show()

?

3.5.8 歸一化

標準化:將圖像的像素值從原始范圍([0,255]或[0,1],轉化為均值為0,標準差為1的分布。

加速訓練:標準化后的數據分布更均勻,有利于訓練

提高模型性能

img = Image.open('./img/100.jpg')
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
r_img= transform(img)
print(r_img.shape)r_img = r_img.permute(1,2,0)plt.imshow(r_img)
plt.show()

?

3.5.9 數據增強整合

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914863.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914863.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914863.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JavaScript的“硬件窺探術”:瀏覽器如何讀取你的設備信息?

JavaScript的“硬件窺探術”:瀏覽器如何讀取你的設備信息? 在Web開發的世界里,JavaScript一直扮演著“幕后魔術師”的角色。從簡單的頁面跳轉到復雜的實時數據處理,它似乎總能用最輕巧的方式解決最棘手的問題。但你是否想過&#…

論安全架構設計(層次)

安全架構設計(層次) 摘要 2021年4月,我有幸參與了某保險公司的“優車險”項目的建設開發工作,該系統以車險報價、車險投保和報案理賠為核心功能,同時實現了年檢代辦、道路救援、一鍵挪車等增值服務功能。在本項目中&a…

滾珠導軌常見的故障有哪些?

在自動化生產設備、精密機床等領域,滾珠導軌就像是設備平穩運行的 “軌道”,為機械部件的直線運動提供穩準導向。但導軌使用時間長了,難免會出現這樣那樣的故障。滾珠脫落:可能由安裝不當、導軌損壞、超負荷運行、維護不當或惡劣環…

機器視覺的包裝盒絲印應用

在包裝盒絲網印刷領域,隨著消費市場對產品外觀精細化要求的持續提升,傳統印刷工藝面臨多重挑戰:多色套印偏差、曲面基材定位困難、異形結構印刷失真等問題。雙翌光電科技研發的WiseAlign視覺系統,通過高精度視覺對位技術與智能化操…

Redis學習-03重要文件及作用、Redis 命令行客戶端

Redis 重要文件及作用 啟動/停止命令或腳本 /usr/bin/redis-check-aof -> /usr/bin/redis-server /usr/bin/redis-check-rdb -> /usr/bin/redis-server /usr/bin/redis-cli /usr/bin/redis-sentinel -> /usr/bin/redis-server /usr/bin/redis-server /usr/libexec/red…

SVN客戶端(TortoiseSVN)和SVN-VS2022插件(visualsvn)官網下載

SVN服務端官網下載地址:https://sourceforge.net/projects/win32svn/ SVN客戶端工具(TortoiseSVN):https://plan.io/tortoise-svn/ SVN-VS2022插件(visualsvn)官網下載地址:https://www.visualsvn.com/downloads/

990. 等式方程的可滿足性

題目&#xff1a;第一次思考&#xff1a; 經典并查集 實現&#xff1a;class UnionSet{public:vector<int> parent;public:UnionSet(int n) {parent.resize(n);}void init(int n) {for (int i 0; i < n; i) {parent[i] i;}}int find(int x) {if (parent[x] ! x) {pa…

HTML--教程

<!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>菜鳥教程(runoob.com)</title> </head> <body><h1>我的第一個標題</h1><p>我的第一個段落。</p> </body> </html&g…

Leetcode刷題營第二十七題:二叉樹的最大深度

104. 二叉樹的最大深度 給定一個二叉樹 root &#xff0c;返回其最大深度。 二叉樹的 最大深度 是指從根節點到最遠葉子節點的最長路徑上的節點數。 示例 1&#xff1a; 輸入&#xff1a;root [3,9,20,null,null,15,7] 輸出&#xff1a;3示例 2&#xff1a; 輸入&#xff…

微信小程序翻書效果

微信小程序翻書效果 wxml <viewwx:for"{{imgList}}" hidden"{{pagenum > imgList.length - index - 1}}"wx:key"index"class"list-pape" style"{{index imgList.length - pagenum - 1 ? clipPath1 : }}"bindtouchst…

個人IP的塑造方向有哪些?

在內容創業和自媒體發展的浪潮下&#xff0c;個人IP的價值越來越受到重視。個人IP不僅是個人品牌的延伸&#xff0c;更是吸引流量來實現商業變現的重要工具。想要塑造個人IP&#xff0c;需要我們有明確的內容方向和策略&#xff0c;下面就讓我們來簡單了解下。一、展現自我形象…

Spring之【BeanDefinition】

目錄 BeanDefinition接口 代碼片段 作用 BeanDefinitionRegistry接口 代碼片段 作用 RootBeanDefinition實現類 GenericBeanDefinition實現類 BeanDefinition接口 代碼片段 public interface BeanDefinition {// ...void setScope(Nullable String scope);NullableSt…

GD32VW553-IOT LED呼吸燈項目

GD32VW553-IOT LED呼吸燈項目項目簡介這是一個基于GD32VW553-IOT開發板的LED呼吸燈演示項目。通過PWM技術控制LED亮度&#xff0c;實現多種呼吸燈效果&#xff0c;展示RISC-V MCU的PWM功能和實時控制能力。功能特性1. 多種呼吸燈效果正弦波呼吸&#xff1a;自然平滑的呼吸效果線…

Linux(Ubuntu)硬盤使用情況解析(已房子舉例)

文章目錄前言輸出字段詳解1.核心字段說明2.生活化的方式解釋&#xff08;已房間為例&#xff09;3.重點理解①主臥室 (/)??②??臨時房 (tmpfs)??總結前言 “df -h” 是在 Linux ??檢查磁盤空間狀態的最基本、最常用的命令之一??。當發現系統變慢、程序報錯說“磁盤空…

vue中的this.$set

在 Vue 2 中&#xff0c;this.$set 是一個用于響應式地添加新屬性到已有對象的全局 API。它的主要作用是解決 Vue 無法檢測到對象屬性添加或刪除的限制&#xff08;由于 Vue 2 的響應式系統基于 Object.defineProperty 實現&#xff09;。1. 為什么需要 this.$set&#xff1f; …

python爬蟲技術——基礎知識、實戰

參考文獻&#xff1a; Python爬蟲入門(一)&#xff08;適合初學者&#xff09;-CSDN博客 一、常用爬蟲工具包 Scrapy 語言: Python特點: 高效、靈活的爬蟲框架&#xff0c;適合大型爬蟲項目。 BeautifulSoup 語言: Python特點: 用于解析HTML和XML&#xff0c;簡單易用。 Sel…

QT 交叉編譯環境下,嵌入式設備顯示字體大小和QT Creator 桌面顯示不一致問題解決

第一步&#xff1a; 發送fc-list 命令 &#xff0c;查找嵌入式環境下支持的字庫第二步 為每個控件指定字庫文件&#xff0c;以label控件為例&#xff1a;int fontId QFontDatabase::addApplicationFont("/usr/share/fonts/source-han-sans-cn/SourceHanSansCN-Normal.otf…

php生成二維碼

<?php // 包含qrlib庫 require_once(qrlib.php);// 二維碼內容 $data https://www.example.com;// 生成二維碼圖片的文件名 $filename qrcode.png;// 二維碼參數 $errorCorrectionLevel L; // 錯誤糾正級別 $matrixPointSize 5; // 生成圖片大小// 生成二維碼 QR…

#systemverilog# 關鍵字之 變量聲明周期與靜態方法關系探討

我們先看來年下面的代碼: class test; task static bar(); …… endtask class test; static task bar(); …… endtask 在 SystemVerilog 中,這兩種聲明方式有本質區別,涉及方法的靜態/非靜態屬性以及局部變量的生命周期。 1. task static bar(); ... endt…