4.權重衰減(weight decay)

4.1 手動實現權重衰減

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):X=torch.normal(0,1,size=(num_inputs,w.shape[0]))y=X@w+by+=torch.normal(0,0.1,size=y.shape)return X,y
def load_array(data,batch_size,is_train=True):dataset=TensorDataset(*data)return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)b=torch.zeros(1,requires_grad=True)return [w,b]
def l2_penalty(w):return 0.5*torch.sum(w.pow(2))def linear_reg(X,w,b):return torch.matmul(X,w)+b
def mse_loss(y_hat,y):return (y_hat-y)**2/2
def sgd(params,lr,batch_size):for params in params:params.data-=lr*params.grad/batch_sizeparams.grad.zero_()
def evaluate_loss(net, data_iter, loss):total_loss, total_samples = 0.0, 0for X, y in data_iter:l = loss(net(X), y)total_loss += l.sum().item()total_samples += y.numel()return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=10,0.05,3
#animator=SimpleAnimator()
for epoch in range(num_epochs):for X,y in train_iter:l=loss(net(X),y)+lambd*l2_penalty(w)l.sum().backward()sgd([w,b],lr,batch_size)if (epoch+1)%5==0:train_loss=evaluate_loss(net,train_iter,loss)test_loss=evaluate_loss(net,test_iter,loss)#animator.add(epoch+1,train_loss,test_loss)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范數是:', torch.norm(w).item())
plt.show()

4.2 簡單實現權重衰減

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w,b,num_inputs):X=torch.normal(0,1,size=(num_inputs,w.shape[0]))y=X@w+by+=torch.normal(0,0.1,size=y.shape)return X,y
def load_array(data,batch_size,is_train=True):dataset=TensorDataset(*data)return DataLoader(dataset,batch_size=batch_size,shuffle=is_train)
def init_params(num_inputs):w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)b=torch.zeros(1,requires_grad=True)return [w,b]
def l2_penalty(w):return 0.5*torch.sum(w.pow(2))
def linear_reg(X,w,b):return torch.matmul(X,w)+b
def mse_loss(y_hat,y):return ((y_hat-y)**2).sum()/2
def evaluate_loss(net, data_iter, loss):total_loss, total_samples = 0.0, 0for X, y in data_iter:l = loss(net(X), y)total_loss += l.item()*y.shape[0]total_samples += y.numel()return total_loss / total_samples
n_train,n_test,num_inputs,batch_size=20,100,200,5
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
train_data=synthetic_data(true_w,true_b,n_train)
test_data=synthetic_data(true_w,true_b,n_test)
train_iter=load_array(train_data,batch_size)
test_iter=load_array(test_data,batch_size,is_train=False)
w,b=init_params(num_inputs)
net=lambda X:linear_reg(X,w,b)
loss=mse_loss
num_epochs,lr,lambd=100,0.001,3
optimizer=torch.optim.SGD([w,b],lr=lr,weight_decay=0.001)
#animator=SimpleAnimator()
for epoch in range(num_epochs):for X,y in train_iter:optimizer.zero_grad()l=loss(net(X),y)l.backward()#sgd([w,b],lr,batch_size)optimizer.step() if (epoch+1)%5==0:train_loss=evaluate_loss(net,train_iter,loss)test_loss=evaluate_loss(net,test_iter,loss)#animator.add(epoch+1,train_loss,test_loss)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范數是:', torch.norm(w).item())
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/88189.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/88189.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/88189.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

OpenCV開發-初始概念

第一章 OpenCV核心架構解析1.1 計算機視覺的基石OpenCV(Open Source Computer Vision Library)作為跨平臺計算機視覺庫,自1999年由Intel發起,已成為圖像處理領域的標準工具。其核心價值體現在:跨平臺性:支持…

LeetCode 930.和相同的二元子數組

給你一個二元數組 nums ,和一個整數 goal ,請你統計并返回有多少個和為 goal 的 非空 子數組。 子數組 是數組的一段連續部分。 示例 1: 輸入:nums [1,0,1,0,1], goal 2 輸出:4 解釋: 有 4 個滿足題目要求…

【論文解讀】Referring Camouflaged Object Detection

論文信息 論文題目:Referring Camouflaged Object Detection 論文鏈接:https://arxiv.org/pdf/2306.07532 代碼鏈接:https://github.com/zhangxuying1004/RefCOD 錄用期刊:TPAMI 2025 論文單位:南開大學 ps&#xff1a…

Spring中過濾器和攔截器的區別及具體實現

在 Spring 框架中,過濾器(Filter) 和 攔截器(Interceptor) 都是用于處理 HTTP 請求的中間件,但它們在作用范圍、實現方式和生命周期上有顯著區別。以下是詳細對比和實現方式:核心區別特性過濾器…

CANFD 數據記錄儀在新能源汽車售后維修中的應用

一、前言隨著新能源汽車市場如火如荼和新能源汽車電子系統的日益復雜,傳統維修手段在面對復雜和偶發故障時往往捉襟見肘,CANFD 數據記錄儀則憑借其獨特優勢,為售后維修帶來新的解決方案。二、 詳細介紹在新能源汽車領域,CANFD 數據…

某當CRM XlsFileUpload存在任意文件上傳(CNVD-2025-10982)

免責聲明 本文檔所述漏洞詳情及復現方法僅限用于合法授權的安全研究和學術教育用途。任何個人或組織不得利用本文內容從事未經許可的滲透測試、網絡攻擊或其他違法行為。使用者應確保其行為符合相關法律法規,并取得目標系統的明確授權。 前言: 我們建立了一個更多,更全的…

自然語言處理與實踐

文章目錄Lesson1:Introduction to NLP、NLP 基礎與文本預處理1.教材2.自然語言處理概述(1)NLP 的定義、發展歷程與應用場景(2)NLP 的主要任務:分詞、詞性標注、命名實體識別、句法分析等2.文本預處理3.文本表示方法:詞向量表示/詞表征Lesson2…

CSS揭秘:9.自適應的橢圓

前置知識:border-radius 用法前言 本篇目標是實現一個橢圓,半橢圓,四分之一橢圓。 一、圓形和橢圓 當我們想實現一個圓形時,通常只要指定 border-radius 為 width/height 的一半就可以了。 當我們指定的border-radius的值超過了 w…

善用關系網絡:開源AI大模型、AI智能名片與S2B2C商城小程序賦能下的成功新路徑

摘要:本文聚焦于關系在個人成功中的關鍵作用,指出關系即財富,善用關系、拓展人脈是成功的重要途徑。在此基礎上,引入開源AI大模型、AI智能名片以及S2B2C商城小程序等新興技術工具,探討它們如何助力個體在復雜的關系網絡…

2025年滲透測試面試題總結-2025年HW(護網面試) 34(題目+回答)

安全領域各種資源,學習文檔,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各種好玩的項目及好用的工具,歡迎關注。 目錄 2025年HW(護網面試) 34 一、網站信息收集 核心步驟與工具 二、CDN繞過與真實IP獲取 6大實戰方法 三、常…

螢石全新上線企業AI對話智能體,開啟IoT人機交互新體驗

一、什么是螢石AI對話智能體?如何讓設備聽得到、聽得懂?這次螢石發布的AI對話Agent,讓設備能進行自然、流暢、真人感的AI對話智能體,幫助開發者打造符合業務場景的AI對話智能體能力,實現全雙工、實時打斷、可擴展、對話…

智紳科技:以科技為翼,構建養老安全守護網

隨著我國老齡化進程加速,2025年60歲以上人口突破3.2億,養老安全問題成為社會關注的焦點。智紳科技作為智慧養老領域的領軍企業,以“科技賦能健康,智慧守護晚年”為核心理念,通過人工智能、物聯網、大數據等技術融合&am…

矩陣系統源碼部署實操指南:搭建全解析,支持OEM

矩陣系統源碼部署指南矩陣系統是一種高效的數據處理框架,適用于大規模分布式計算。以下為詳細部署步驟,包含OEM支持方案。環境準備確保服務器滿足以下要求:操作系統:Linux(推薦Ubuntu 18.04/CentOS 7)硬件配…

基于python的個人財務記賬系統

博主介紹:java高級開發,從事互聯網行業多年,熟悉各種主流語言,精通java、python、php、爬蟲、web開發,已經做了多年的畢業設計程序開發,開發過上千套畢業設計程序,沒有什么華麗的語言&#xff0…

從 CODING 停服到極狐 GitLab “接棒”,軟件研發工具市場風云再起

CODING DevOps 產品即將停服的消息,如同一顆重磅炸彈,在軟件研發工具市場炸開了鍋。從今年 9 月開始,CODING 將陸續下線其 DevOps 產品,直至 2028 年 9 月 30 日完全停服。這一變動讓眾多依賴 CODING 平臺的企業和個人開發者陷入了…

#滲透測試#批量漏洞挖掘#HSC Mailinspector 任意文件讀取漏洞(CVE-2024-34470)

免責聲明 本教程僅為合法的教學目的而準備,嚴禁用于任何形式的違法犯罪活動及其他商業行為,在使用本教程前,您應確保該行為符合當地的法律法規,繼續閱讀即表示您需自行承擔所有操作的后果,如有異議,請立即停…

深入解析C++驅動開發實戰:優化高效穩定的驅動應用

深入解析C驅動開發實戰:優化高效穩定的驅動應用 在現代計算機系統中,驅動程序(Driver)扮演著至關重要的角色,作為操作系統與硬件設備之間的橋梁,驅動程序負責管理和控制硬件資源,確保系統的穩定…

SNIProxy 輕量級匿名CDN代理架構與實現

🌐 SNIProxy 輕量級匿名CDN代理架構與實現 🏗? 1. 整體架構設計 🔹 1.1 系統架構概覽 #mermaid-svg-S4n74I2nPLGityDB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-S4n74I2nP…

Qt的信號與槽(一)

Qt的信號與槽(一)1.信號和槽的基本認識2.connect3.關閉窗口的按鈕4.函數的根源5.形參和實參的類型🌟hello,各位讀者大大們你們好呀🌟🌟 🚀🚀系列專欄:【Qt的學習】 &…

springMVC02-視圖解析器、RESTful設計風格,靜態資源訪問配置

一、SpringMVC 的視圖在 SpringMVC 中,視圖的作用渲染數據,將模型 Model (將控制器(Controller))中的數據展示給用戶。在 Java 代碼中,視圖由接口 org.springframework.web.servlet.View 表示SpringMVC 視圖的種類很多…