DAY 37 早停策略和模型權重的保存

  1. 早停策略

import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
from tqdm import tqdm# Define the MLP model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(X_train.shape[1], 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 2)  # Binary classificationdef forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# Instantiate the model
model = MLP().to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Training settings
num_epochs = 20000
early_stop_patience = 50  # Epochs to wait for improvement
best_loss = float('inf')
patience_counter = 0
best_epoch = 0
early_stopped = False# Track losses
train_losses = []
test_losses = []
epochs = []# Start training
start_time = time.time()
with tqdm(total=num_epochs, desc="Training Progress", unit="epoch") as pbar:for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)train_loss = criterion(outputs, y_train)train_loss.backward()optimizer.step()# Evaluate on the test setmodel.eval()with torch.no_grad():outputs_test = model(X_test)test_loss = criterion(outputs_test, y_test)if (epoch + 1) % 200 == 0:train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)# Early stopping checkif test_loss.item() < best_loss:  # If current test loss is better than the bestbest_loss = test_loss.item()  # Update best lossbest_epoch = epoch + 1  # Update best epochpatience_counter = 0  # Reset counter# Save the best modeltorch.save(model.state_dict(), 'best_model.pth')else:patience_counter += 1if patience_counter >= early_stop_patience:print(f"Early stopping triggered! No improvement for {early_stop_patience} epochs.")print(f"Best test loss was at epoch {best_epoch} with a loss of {best_loss:.4f}")early_stopped = Truebreak  # Stop the training loop# Update the progress barpbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})# Update progress bar every 1000 epochsif (epoch + 1) % 1000 == 0:pbar.update(1000)# Ensure progress bar reaches 100%
if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)time_all = time.time() - start_time  # Calculate total training time
print(f'Training time: {time_all:.2f} seconds')# If early stopping occurred, load the best model
if early_stopped:print(f"Loading best model from epoch {best_epoch} for final evaluation...")model.load_state_dict(torch.load('best_model.pth'))# Continue training for 50 more epochs after loading the best model
num_extra_epochs = 50
for epoch in range(num_extra_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)train_loss = criterion(outputs, y_train)train_loss.backward()optimizer.step()# Evaluate on the test setmodel.eval()with torch.no_grad():outputs_test = model(X_test)test_loss = criterion(outputs_test, y_test)train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(num_epochs + epoch + 1)# Print progress for the extra epochsprint(f"Epoch {num_epochs + epoch + 1}: Train Loss = {train_loss.item():.4f}, Test Loss = {test_loss.item():.4f}")# Plot the loss curves
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()# Evaluate final accuracy on the test set
model.eval()
with torch.no_grad():outputs = model(X_test)_, predicted = torch.max(outputs, 1)correct = (predicted == y_test).sum().item()accuracy = correct / y_test.size(0)print(f'Test Accuracy: {accuracy * 100:.2f}%')

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/86169.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/86169.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/86169.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

零基礎搭建Spring AI本地開發環境指南

Spring AI 是一個 Spring 官方團隊主導的開源項目&#xff0c;旨在將生成式人工智能&#xff08;Generative AI&#xff09;能力無縫集成到 Spring 應用程序中。它提供了一個統一的、Spring 風格的抽象層&#xff0c;簡化了與各種大型語言模型&#xff08;LLMs&#xff09;、嵌…

windows登錄系統配置雙因子認證的解決方案

在數字化浪潮席卷全球的今天&#xff0c;安全如同氧氣般不可或缺。Verizon《2023年數據泄露調查報告》指出&#xff0c;80%的黑客攻擊與登錄憑證失竊直接相關。當傳統密碼防護變得千瘡百孔&#xff0c;企業如何在身份驗證的戰場上贏得主動權&#xff1f;答案就藏在"雙保險…

Java數據結構——線性表Ⅱ

一、鏈式存儲結構概述 1. 基本概念&#xff08;邏輯分析&#xff09; 核心思想&#xff1a;用指針將離散的存儲單元串聯成邏輯上連續的線性表 設計動機&#xff1a;解決順序表 "預先分配空間" 與 "動態擴展" 的矛盾 關鍵特性&#xff1a; 結點空間動態…

技術基石:SpreadJS 引擎賦能極致體驗

在能源行業數字化轉型的浪潮中&#xff0c;青島國瑞信息技術有限公司始終以技術創新為核心驅動力&#xff0c;不斷探索前沿技術在能源領域的深度應用。其推出的 RCV 行列視生產數據應用系統之所以能夠在行業內脫穎而出&#xff0c;離不開背后強大的技術基石 ——SpreadJS 引擎。…

Typora - Typora 打字機模式

Typora 打字機模式 1、基本介紹 Typora 打字機模式&#xff08;Typewriter Mode&#xff09;是一種專注于當前寫作行的功能 打字機模式會自動將正在編輯的行保持在屏幕中央&#xff0c;讓用戶更集中注意力&#xff0c;類似于傳統打字機的體驗 2、開啟方式 點擊 【視圖】 -…

3.0 compose學習:MVVM框架+Hilt注解調用登錄接口

文章目錄 前言&#xff1a;1、添加依賴1.1 在settings.gradle.kts中添加1.2 在應用級的build.gradle.kts添加插件依賴1.3 在module級的build.gradle.kts添加依賴 2、實體類2.1 request2.2 reponse 3、網絡請求3.1 ApiService3.2 NetworkModule3.3 攔截器 添加token3.4 Hilt 的 …

git學習資源

動畫演示&#xff1a;Learn Git Branching 終極目標&#xff08;能看懂即入門&#xff09;&#xff1a;git 簡明指南 Git 教程 | 菜鳥教程

C++ 第二階段:模板編程 - 第一節:函數模板與類模板

目錄 一、模板編程的核心概念 1.1 什么是模板編程&#xff1f; 二、函數模板詳解 2.1 函數模板的定義與使用 2.1.1 基本語法 2.1.2 示例&#xff1a;通用交換函數 2.1.3 類型推導規則 2.2 函數模板的注意事項 2.2.1 普通函數與函數模板的調用規則 2.2.2 隱式類型轉換…

Docker 報錯“x509: certificate signed by unknown authority”的排查與解決實錄

目錄 &#x1f527;Docker 報錯“x509: certificate signed by unknown authority”的排查與解決實錄 &#x1f4cc; 問題背景 &#x1f9ea; 排查過程 步驟 1&#xff1a;確認加速器地址是否可訪問 步驟 2&#xff1a;檢查 Docker 是否真的使用了鏡像加速器 步驟 3&…

達夢以及其他圖形化安裝沒反應或者報錯No more handles [gtk_init_check() failed]

本人安裝問題和解決步驟如下&#xff0c;僅供參考 執行 DMInstall.bin 報錯 按照網上大部分解決方案 export DISPLAY:0.0 xhost 重新執行 DMInstall.bin&#xff0c;無報錯也無反應 安裝xclock測試也是同樣效果&#xff0c;無報錯也無反應 最開始猜測可能是連接工具問題&a…

項目節奏不一致時,如何保持全局平衡

項目節奏不一致時&#xff0c;如何保持全局平衡的關鍵在于&#xff1a;構建跨項目協調機制、合理配置資源、建立共享節奏看板、優先明確戰略驅動、引入緩沖與預警機制。其中&#xff0c;構建跨項目協調機制尤為關鍵&#xff0c;它能將各項目的排期、優先級和風險實時聯動&#…

macOS - 安裝微軟雅黑字體

文章目錄 1、下載資源2、安裝3、查看字體 app4、卸載字體 macOS 中打開 Windows 傳輸過來的文件的時候&#xff0c;經常會提示 xxx 字體缺失。下面以安裝 微軟雅黑字體為例。 1、下載資源 https://github.com/BronyaCat/Win-Fonts-For-Mac 2、安裝 雙擊 Fonts 文件夾下的 msy…

ArkUI-X資源分類與訪問

應用開發過程中&#xff0c;經常需要用到顏色、字體、間距、圖片等資源&#xff0c;在不同的設備或配置中&#xff0c;這些資源的值可能不同。 應用資源&#xff1a;借助資源文件能力&#xff0c;開發者在應用中自定義資源&#xff0c;自行管理這些資源在不同的設備或配置中的…

11-StarRocks故障診斷FAQ

StarRocks故障診斷FAQ 概述 本文檔整理了StarRocks故障診斷過程中常見的問題和解決方案,涵蓋了故障排查、日志分析、性能診斷、問題定位等各個方面,幫助用戶快速定位和解決StarRocks相關問題。 故障排查FAQ Q1: 如何排查連接故障? A: 連接故障排查方法: 1. 網絡連通性…

敏捷項目管理怎么做?4大主流方法論對比及工具適配方案

在傳統瀑布式項目管理中&#xff0c;需求定義、設計、開發、測試等環節如同工業流水線般嚴格線性推進&#xff0c;展現出強大的流程控制能力。不過今天的軟件迭代周期已壓縮至周級乃至日級&#xff0c;瀑布式管理難以應對需求的快速變化&#xff0c;敏捷式項目管理則以“小步快…

解決YOLO模型從Python遷移到C++時目標漏檢問題——跨語言部署中的關鍵陷阱與解決方案

問題背景 當我們將Python訓練的YOLO模型部署到C環境時&#xff0c;常遇到部分目標漏檢問題。這通常源于預處理/后處理差異、數據類型隱式轉換或模型轉換誤差。本文通過完整案例解析核心問題并提供可落地的解決方案。 一、常見原因分析 預處理不一致 Python常用OpenCV&#xff…

【2025CCF中國開源大會】開放注冊與會議通知(第二輪)

點擊藍字 關注我們 CCF Opensource Development Committee 2025 CCF中國開源大會 由中國計算機學會主辦的 2025 CCF中國開源大會&#xff08;CCF ChinaOSC&#xff09;擬于 2025年8月2日-3日 在上海召開。本屆大會以“蓄勢引領、眾行致遠”為主題&#xff0c;由上海交通大學校長…

本地聊天室

測試版還沒測試過&#xff0c;后面更新不會繼續開源&#xff0c;有問題自行修復 開發環境: PHP版本7.2 Swoole擴展 本地服務器環境&#xff08;如XAMPP、MAMP&#xff09; 功能說明: 注冊/登錄系統&#xff0c;支持本地用戶數據存儲 ? 發送文本、圖片和語音消息 ? 實…

golang學習隨便記x-調試與雜類(待續)

編譯與調試 調試時從終端鍵盤輸入 調試帶有需要用戶鍵盤輸入的程序時&#xff0c;VSCode報錯&#xff1a;Unable to process evaluate: debuggee is running&#xff0c;因為調試器不知道具體是哪個終端輸入。需要配置啟動文件 .vscode/launch.json 類似如下&#xff08;注意…

MultipartFile、File 和 Mat

1. MultipartFile (來自 Spring Web) 用途&#xff1a; 代表通過 multipart 形式提交&#xff08;通常是 HTTP POST 請求&#xff09;接收到的文件。 它是 Spring Web 中用于處理 Web 客戶端文件上傳的核心接口。 關鍵特性&#xff1a; 抽象&#xff1a; 這是一個接口&#xf…