Python day37

@浙大疏錦行?python day37.

內容:

  • 保存模型只需要保存模型的參數即可,使用的時候直接構建模型再導入參數即可
# 保存模型參數
torch.save(model.state_dict(), "model_weights.pth")# 加載參數(需先定義模型結構)
model = MLP()  # 初始化與訓練時相同的模型結構
model.load_state_dict(torch.load("model_weights.pth"))
# model.eval()  # 切換至推理模式(可選)
  • 也可以同時保存模型 + 參數
# 保存整個模型
torch.save(model, "full_model.pth")# 加載模型(無需提前定義類,但需確保環境一致)
model = torch.load("full_model.pth")
model.eval()  # 切換至推理模式(可選)
  • 保存訓練狀態,用于在訓練過程中保存中間狀態
# # 保存訓練狀態
# checkpoint = {
#     "model_state_dict": model.state_dict(),
#     "optimizer_state_dict": optimizer.state_dict(),
#     "epoch": epoch,
#     "loss": best_loss,
# }
# torch.save(checkpoint, "checkpoint.pth")# # 加載并續訓
# model = MLP()
# optimizer = torch.optim.Adam(model.parameters())
# checkpoint = torch.load("checkpoint.pth")# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# start_epoch = checkpoint["epoch"] + 1  # 從下一輪開始訓練
# best_loss = checkpoint["loss"]# # 繼續訓練循環
# for epoch in range(start_epoch, num_epochs):
#     train(model, optimizer, ...)
  • 對于跨框架保存時,需要保存為onnx文件
  • 早停策略:在訓練過程中,訓練集的損失不斷下降,但是驗證集或者測試集的損失反而上升,此時這種情況稱為過擬合;針對這種情況,我們可以引入早停策略,用于提前終止訓練;
  • 可以在訓練達到某一個epoch次數時進行一次驗證,將此次結果和歷史最佳結果進行比較,如果結果更好則保留,如果不好則會使臨時記錄變量自增,一旦連續自增到預設值,直接停止。這種設計是為了防止出現震蕩波動情況,避免偶然性
  • 可以結合上面的保存斷點 breakpoint
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # 導入tqdm庫用于進度條顯示
import warnings
warnings.filterwarnings("ignore")  # 忽略警告信息# 設置GPU設備
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 加載鳶尾花數據集
iris = load_iris()
X = iris.data  # 特征數據
y = iris.target  # 標簽數據# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 歸一化數據
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 將數據轉換為PyTorch張量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 輸入層到隱藏層self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隱藏層到輸出層def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 實例化模型并移至GPU
model = MLP().to(device)# 分類問題使用交叉熵損失函數
criterion = nn.CrossEntropyLoss()# 使用隨機梯度下降優化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 訓練模型
num_epochs = 20000  # 訓練的輪數# 用于存儲每200個epoch的損失值和對應的epoch數
train_losses = []  # 存儲訓練集損失
test_losses = []   # 存儲測試集損失
epochs = []# ===== 新增早停相關參數 =====
best_test_loss = float('inf')  # 記錄最佳測試集損失
best_epoch = 0                 # 記錄最佳epoch
patience = 50                # 早停耐心值(連續多少輪測試集損失未改善時停止訓練)
counter = 0                    # 早停計數器
early_stopped = False          # 是否早停標志
# ==========================start_time = time.time()  # 記錄開始時間# 創建tqdm進度條
with tqdm(total=num_epochs, desc="訓練進度", unit="epoch") as pbar:# 訓練模型for epoch in range(num_epochs):# 前向傳播outputs = model(X_train)  # 隱式調用forward函數train_loss = criterion(outputs, y_train)# 反向傳播和優化optimizer.zero_grad()train_loss.backward()optimizer.step()# 記錄損失值并更新進度條if (epoch + 1) % 200 == 0:# 計算測試集損失model.eval()with torch.no_grad():test_outputs = model(X_test)test_loss = criterion(test_outputs, y_test)model.train()train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)# 更新進度條的描述信息pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})# ===== 新增早停邏輯 =====if test_loss.item() < best_test_loss: # 如果當前測試集損失小于最佳損失best_test_loss = test_loss.item() # 更新最佳損失best_epoch = epoch + 1 # 更新最佳epochcounter = 0 # 重置計數器# 保存最佳模型torch.save(model.state_dict(), 'best_model.pth')else:counter += 1if counter >= patience:print(f"早停觸發!在第{epoch+1}輪,測試集損失已有{patience}輪未改善。")print(f"最佳測試集損失出現在第{best_epoch}輪,損失值為{best_test_loss:.4f}")early_stopped = Truebreak  # 終止訓練循環# ======================# 每1000個epoch更新一次進度條if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新進度條# 確保進度條達到100%if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)  # 計算剩余的進度并更新time_all = time.time() - start_time  # 計算訓練時間
print(f'Training time: {time_all:.2f} seconds')# ===== 新增:加載最佳模型用于最終評估 =====
if early_stopped:print(f"加載第{best_epoch}輪的最佳模型進行最終評估...")model.load_state_dict(torch.load('best_model.pth'))
# ================================# 可視化損失曲線
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()# 在測試集上評估模型
model.eval()
with torch.no_grad():outputs = model(X_test)_, predicted = torch.max(outputs, 1)correct = (predicted == y_test).sum().item()accuracy = correct / y_test.size(0)print(f'測試集準確率: {accuracy * 100:.2f}%')    

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/92155.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/92155.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/92155.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ORACLE進階操作

1 事務 事務的任務便是使數據庫從一種狀態變換成為另一種狀態&#xff0c;這不同于文件系統&#xff0c;它是數據庫所特用的。 所有的數據庫中&#xff0c;事務只針對DML&#xff08;增刪改)&#xff0c;不針對select select只能查看其他事務提交或回滾的數據&#xff0c;不能查…

Modbus 的一些理解

疑問&#xff1a;&#xff08;使用的是Modbustcp&#xff09;我在 Modbus slave 上面設置了slave地址為1&#xff0c;位置為40001的位置的值為1&#xff0c;40001這個位置上面的值是怎么存儲的&#xff0c;存儲在哪里的&#xff1f;他們是怎么進行交互的&#xff1f;在Modbus協…

【運動控制框架】WPF運動控制框架源碼,可用于激光切割機,雕刻機,分板機,點膠機,插件機等設備,開箱即用

WPF運動控制框架源碼&#xff0c;可用于激光切割機&#xff0c;雕刻機&#xff0c;分板機&#xff0c;點膠機&#xff0c;插件機等設備&#xff0c;考慮到各運動控制硬件不同&#xff0c;視覺應用功能&#xff08;應用視覺軟件&#xff09;也不同&#xff0c;所以只開發各路徑編…

RabbitMQ-日常運維命令

作者介紹&#xff1a;簡歷上沒有一個精通的運維工程師。請點擊上方的藍色《運維小路》關注我&#xff0c;下面的思維導圖也是預計更新的內容和當前進度(不定時更新)。中間件&#xff0c;我給它的定義就是為了實現某系業務功能依賴的軟件&#xff0c;包括如下部分:Web服務器代理…

【Linux基礎知識系列】第九十篇 - 使用awk進行文本處理

在Linux系統中&#xff0c;文本處理是一個常見的任務&#xff0c;尤其是在處理日志文件、配置文件和數據文件時。awk是一個功能強大的文本處理工具&#xff0c;廣泛用于數據提取、分析和格式化。它不僅可以處理簡單的文本文件&#xff0c;還可以處理復雜的結構化數據&#xff0…

第二十七天(數據結構:圖)

圖&#xff1a;是一種非線性結構形式化的描述: G{V,R}V:圖中各個頂點元素(如果這個圖代表的是地圖&#xff0c;這個頂點就是各個點的地址)R:關系集合&#xff0c;圖中頂點與頂點之間的關系(如果是地圖&#xff0c;這個關系集合可能就代表的是各個地點之間的距離)在頂點與頂點…

數據賦能(386)——數據挖掘——迭代過程

概述重要性如下&#xff1a;提升挖掘效果&#xff1a;迭代過程能不斷優化數據挖掘模型&#xff0c;提高挖掘結果的準確性和有效性&#xff0c;從而更好地滿足業務需求。適應復雜數據&#xff1a;數據往往具有復雜性和多樣性&#xff0c;通過迭代可以逐步探索和適應數據的特點&a…

什么是鍵值緩存?讓 LLM 閃電般快速

一、為什么 LLMs 需要 KV 緩存&#xff1f;大語言模型&#xff08;LLMs&#xff09;的文本生成遵循 “自回歸” 模式 —— 每次僅輸出一個 token&#xff08;如詞語、字符或子詞&#xff09;&#xff0c;再將該 token 與歷史序列拼接&#xff0c;作為下一輪輸入&#xff0c;直到…

16.Home-懶加載指令優化

問題1&#xff1a;邏輯書寫位置不合理問題2&#xff1a;重復監聽問題已經加載完畢但是還在監聽

Day116 若依融合mqtt

MQTT 1.MQTT協議概述MQTT是一種基于發布/訂閱模式的輕量級消息傳輸協議&#xff0c;設計用于低帶寬、高延遲或不穩定的網絡環境&#xff0c;廣泛應用于物聯網領域1.1 MQTT協議的應用場景1.智能家居、車聯網、工業物聯網&#xff1a;MQTT可以用于連接各種家電設備和傳感器&#…

PyTorch + PaddlePaddle 語音識別

PyTorch PaddlePaddle 語音識別 目錄 概述環境配置基礎理論數據預處理模型架構設計完整實現案例模型訓練與評估推理與部署性能優化技巧總結 語音識別&#xff08;ASR, Automatic Speech Recognition&#xff09;是將音頻信號轉換為文本的技術。結合PyTorch和PaddlePaddle的…

施耐德 Easy Altivar ATV310 變頻器:高效電機控制的理想選擇(含快速調試步驟及常見故障代碼)

施耐德 Easy Altivar ATV310 變頻器&#xff1a;高效電機控制的理想選擇&#xff08;含快速調試步驟&#xff09;在工業自動化領域&#xff0c;變頻器作為電機控制的核心設備&#xff0c;其性能與可靠性直接影響整個生產系統的效率。施耐德電氣推出的 Easy Altivar ATV310 變頻…

搭建郵件服務器概述

一、電子郵件應用解析標準郵件服務器&#xff08;qq郵箱&#xff09;&#xff1a;1&#xff09;提供電子郵箱&#xff08;lvbuqq.com&#xff09;及存儲空間2&#xff09;為客戶端向外發送郵件給其他郵箱&#xff08;diaochan163.com&#xff09;3&#xff09;接收/投遞其他郵箱…

day28-NFS

1.每日復盤與今日內容1.1復盤Rsync:本地模式、遠程模式&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;、遠程守護模式&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;安裝、配置Rsync啟動、測試服務備份案例1.2今日內容NFS優缺點NFS服…

二叉搜索樹--通往高階數據結構的基石

目錄 前言&#xff1a; 1、二叉搜索樹的概念 2、二叉搜索樹性能分析 3、二叉搜索樹的實現 BinarySelectTree.h test.cpp 4、key 和 key / value&#xff08; map 和 set 的鋪墊 &#xff09; 前言&#xff1a; 又回到數據結構了&#xff0c;這次我們將要學習一些復雜的…

Profinet轉Ethernet IP網關接入五軸車床上下料機械手控制系統的配置實例

本案例為西門子1200PLC借助PROFINET轉EtherNet/IP網關與搬運機器人進行連接的配置案例。所需設備包括&#xff1a;西門子1200PLC、Profinet轉EtherNet/IP網關以及發那科&#xff08;Fanuc&#xff09;機器人。開啟在工業自動化控制領域廣泛應用、功能強大且專業的西門子博圖配置…

專題二_滑動窗口_長度最小的子數組

引入&#xff1a;滑動窗口首先&#xff0c;這是滑動窗口的第一道題&#xff0c;所以簡短的說一下滑動窗口的思路&#xff1a;當我們題目要求找一個滿足要求的區間的時候&#xff0c;且這個區間的left和right指針&#xff0c;都只需要同向移動的時候&#xff0c;就可以使用滑動窗…

解鎖高效開發:AWS 前端 Web 與移動應用解決方案詳解

告別繁雜的部署與運維&#xff0c;AWS 讓前端開發者的精力真正聚焦于創造卓越用戶體驗。在當今快速迭代的數字環境中&#xff0c;Web 與移動應用已成為企業與用戶交互的核心。然而&#xff0c;前端開發者常常面臨諸多挑戰&#xff1a;用戶認證的復雜性、后端 API 的集成難題、跨…

北京JAVA基礎面試30天打卡04

1. 單例模式的實現方式及線程安全 單例模式&#xff08;Singleton Pattern&#xff09;確保一個類只有一個實例&#xff0c;并提供一個全局訪問點。以下是常見的單例模式實現方式&#xff0c;以及如何保證線程安全&#xff1a; 單例模式的實現方式餓漢式&#xff08;Eager Init…

Redis 緩存三大核心問題:穿透、擊穿與雪崩的深度解析

引言在現代互聯網架構中&#xff0c;緩存是提升系統性能、降低數據庫壓力的核心手段之一。而 Redis 作為高性能的內存數據庫&#xff0c;憑借其豐富的數據結構、靈活的配置選項以及高效的網絡模型&#xff0c;已經成為緩存領域的首選工具。本文將從 Redis 的基本原理出發&#…