python打卡第37天

知識點回顧:

  1. 過擬合的判斷:測試集和訓練集同步打印指標
  2. 模型的保存和加載
    1. 僅保存權重
    2. 保存權重和模型
    3. 保存全部信息checkpoint,還包含訓練狀態
  3. 早停策略

作業:對信貸數據集訓練后保存權重,加載權重后繼續訓練50輪,并采取早停策略

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, roc_auc_score
import matplotlib.pyplot as pltdef set_seed(seed=42):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueset_seed(42)# 讀取數據
data = pd.read_csv('data.csv')
target_col = 'Credit Default'
data = data.fillna(data.median(numeric_only=True))
data = data.fillna('Unknown')categorical_features = ['Home Ownership', 'Purpose', 'Term', 'Years in current job']
numerical_features = [col for col in data.columns if col not in categorical_features + [target_col]]for col in categorical_features:le = LabelEncoder()data[col] = le.fit_transform(data[col])X = data[categorical_features + numerical_features]
y = data[target_col]scaler = StandardScaler()
X[numerical_features] = scaler.fit_transform(X[numerical_features])X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)class CreditDataset(Dataset):def __init__(self, X, y):self.X = torch.tensor(X.values, dtype=torch.float32)self.y = torch.tensor(y.values, dtype=torch.float32)def __len__(self):return len(self.X)def __getitem__(self, idx):return self.X[idx], self.y[idx]train_dataset = CreditDataset(X_train, y_train)
test_dataset = CreditDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)class CreditNet(nn.Module):def __init__(self, input_dim):super(CreditNet, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim, 64),nn.BatchNorm1d(64),nn.ReLU(),nn.Dropout(0.3),nn.Linear(64, 32),nn.BatchNorm1d(32),nn.ReLU(),nn.Dropout(0.2),nn.Linear(32, 1))def forward(self, x):return self.model(x).squeeze(1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CreditNet(X_train.shape[1]).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)def train(model, loader, criterion, optimizer):model.train()total_loss = 0for X_batch, y_batch in loader:X_batch, y_batch = X_batch.to(device), y_batch.to(device)optimizer.zero_grad()outputs = model(X_batch)loss = criterion(outputs, y_batch)loss.backward()optimizer.step()total_loss += loss.item() * X_batch.size(0)return total_loss / len(loader.dataset)def evaluate(model, loader):model.eval()preds, targets = [], []with torch.no_grad():for X_batch, y_batch in loader:X_batch = X_batch.to(device)outputs = torch.sigmoid(model(X_batch)).cpu().numpy()preds.extend(outputs)targets.extend(y_batch.numpy())preds = np.array(preds)targets = np.array(targets)preds_label = (preds > 0.5).astype(int)auc = roc_auc_score(targets, preds)report = classification_report(targets, preds_label, digits=4)return auc, report# 訓練主循環
epochs = 20
train_losses = []
test_aucs = []for epoch in range(epochs):train_loss = train(model, train_loader, criterion, optimizer)auc, _ = evaluate(model, test_loader)train_losses.append(train_loss)test_aucs.append(auc)print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Test AUC: {auc:.4f}")# 可視化訓練損失和AUC曲線
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(range(1, epochs+1), train_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Training Loss Curve')
plt.grid(True)
plt.subplot(1,2,2)
plt.plot(range(1, epochs+1), test_aucs, marker='o', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Test AUC')
plt.title('Test AUC Curve')
plt.grid(True)
plt.tight_layout()
plt.show()# 保存模型權重
torch.save(model.state_dict(), "credit_model.pth")
# 定義早停類
class EarlyStopping:def __init__(self, patience=5, delta=1e-4):self.patience = patienceself.delta = deltaself.best_score = Noneself.counter = 0self.early_stop = Falsedef __call__(self, score):if self.best_score is None or score > self.best_score + self.delta:self.best_score = scoreself.counter = 0else:self.counter += 1if self.counter >= self.patience:self.early_stop = True# 加載權重并繼續訓練
model.load_state_dict(torch.load("credit_model.pth"))
epochs_continue = 50
early_stopping = EarlyStopping(patience=5, delta=1e-4)
train_losses2 = []
test_aucs2 = []for epoch in range(epochs_continue):train_loss = train(model, train_loader, criterion, optimizer)auc, _ = evaluate(model, test_loader)train_losses2.append(train_loss)test_aucs2.append(auc)print(f"[Continue] Epoch {epoch+1}/{epochs_continue} - Train Loss: {train_loss:.4f} - Test AUC: {auc:.4f}")early_stopping(auc)if early_stopping.early_stop:print("Early stopping triggered!")break# 可視化繼續訓練的曲線
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(range(1, len(train_losses2)+1), train_losses2, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Continue Training Loss Curve')
plt.grid(True)
plt.subplot(1,2,2)
plt.plot(range(1, len(test_aucs2)+1), test_aucs2, marker='o', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Test AUC')
plt.title('Continue Test AUC Curve')
plt.grid(True)
plt.tight_layout()
plt.show()# 最終評估
auc, report = evaluate(model, test_loader)
print(f"\nFinal Test AUC: {auc:.4f}")
print("Classification Report:\n", report)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/81483.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/81483.shtml
英文地址,請注明出處:http://en.pswp.cn/web/81483.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【洛谷P9303題解】AC- [CCC 2023 J5] CCC Word Hunt

在CCC單詞搜索游戲中,單詞隱藏在一個字母網格中。目標是確定給定單詞在網格中隱藏的次數。單詞可以以直線或直角的方式排列。以下是詳細的解題思路及代碼實現: 傳送門: https://www.luogu.com.cn/problem/P9303 解題思路 輸入讀取與初始化&…

LangGraph + LLM + stream_mode

文章目錄 LLM 代碼valuesmessagesupdatesmessages updatesmessages updates 2 LLM 代碼 from dataclasses import dataclassfrom langchain.chat_models import init_chat_model from langgraph.graph import StateGraph, STARTfrom langchain_openai import ChatOpenAI # 初…

Pydantic 學習與使用

Pydantic 學習與使用 在 Fastapi 的 Web 開發中的數據驗證通常都是在使用 Pydantic 來進行數據的校驗,本文將對 Pydantic 的使用方法做記錄與學習。 **簡介:**Pydantic 是一個在 Python 中用于數據驗證和解析的第三方庫,它現在是 Python 使…

批量文件重命名工具

分享一個自己使用 python 開發的小軟件,批量文件重命名工具,主要功能有批量中文轉拼音,簡繁體轉換,大小寫轉換,替換文件名,刪除指定字符,批量添加編號,添加前綴/后綴。同時還有文件時…

多語言視角下的 DOM 操作:從 JavaScript 到 Python、Java 與 C#

多語言視角下的 DOM 操作:從 JavaScript 到 Python、Java 與 C# 在 Web 開發中,文檔對象模型(DOM)是構建動態網頁的核心技術。它將 HTML/XML 文檔解析為樹形結構,允許開發者通過編程方式訪問和修改頁面內容、結構和樣…

【C/C++】紅黑樹學習筆記

文章目錄 紅黑樹1 基本概念1.1 定義1.2 基本特性推理1.3 對比1.4 延伸1.4.1 簡單判別是否是紅黑樹1.4.2 應用 2 插入2.1 插入結點默認紅色2.2 插入結點2.2.1 插入結點是根結點2.2.2 插入結點的叔叔是紅色2.2.3 插入結點的叔叔是黑色場景分析LL型RR型LR型RL型 3 構建4 示例代碼 …

網絡通信的基石:深入理解幀與報文

在這個萬物互聯的時代,我們每天都在享受著網絡帶來的便利——從早晨查看天氣預報,到工作中的視頻會議,再到晚上刷著短視頻放松。然而,在這些看似簡單的網絡交互背后,隱藏著精密而復雜的數據傳輸機制。今天,…

STM32 SPI通信(硬件)

一、SPI外設簡介 STM32內部集成了硬件SPI收發電路,可以由硬件自動執行時鐘生成、數據收發等功能,減輕CPU的負擔 可配置8位/16位數據幀、高位先行/低位先行 時鐘頻率: fPCLK / (2, 4, 8, 16, 32, 64, 128, 256) 支持多主機模型、主或從操作 可…

尚硅谷redis7-11-redis10大類型之總體概述

前提:我們說的數據類型一般是value的數據類型,key的類型都是字符串。 redis字符串【String】 string類型是二進制安全的,意思是redis的string可以包含任何數據,比如jpg圖片或者序列化的對象。 string類型是Redis最基本的數據類型,一個redis中字符串va…

【遞歸、搜索與回溯算法】專題一 遞歸

文章目錄 0.理解遞歸、搜索與回溯1.面試題 08.06.漢諾塔問題1.1 題目1.2 思路1.3 代碼 2. 合并兩個有序鏈表2.1 題目2.2 思路2.3 代碼 3.反轉鏈表3.1 題目3.2 思路3.3 代碼 4.兩兩交換鏈表中的節點4.1 題目4.2 思路4.3 代碼 5. Pow(x, n) - 快速冪5.1 題目5.2 思路5.3 代碼 0.理…

C#實現List導出CSV:深入解析完整方案

C#實現List導出CSV:深入解析完整方案 在數據交互場景中,CSV文件憑借其跨平臺兼容性和簡潔性,成為數據交換的重要載體。本文將基于C#反射機制實現的通用CSV導出方案,結合實際開發中的痛點,從基礎實現、深度優化到生產級…

字符串day7

344 反轉字符串 字符串理論上也是一個數組&#xff0c;因此只需要用雙指針即可 class Solution { public:void reverseString(vector<char>& s) {for(int i0,js.size()-1;i<j;i,j--){swap(s[i],s[j]);}} };541 反轉字符串 自己實現一個反轉從start到end的字符串…

Grafana XSSOpenRedirectSSRF漏洞復現(CVE-2025-4123)

免責申明: 本文所描述的漏洞及其復現步驟僅供網絡安全研究與教育目的使用。任何人不得將本文提供的信息用于非法目的或未經授權的系統測試。作者不對任何由于使用本文信息而導致的直接或間接損害承擔責任。如涉及侵權,請及時與我們聯系,我們將盡快處理并刪除相關內容。 前…

私服 nexus 之間遷移 npm 倉庫

本文介紹如何將一個 Nexus 特定倉庫中的 npm 包內容遷移到另一個 Nexus 特定倉庫。此過程適用于需要重構倉庫結構或合并倉庫的場景。 遷移腳本 以下是完整的遷移腳本&#xff0c;它會自動完成以下操作&#xff1a; 從源倉庫獲取所有 npm 包列表下載每個包的 .tgz 文件解壓并…

Django ToDoWeb 服務

我們的任務是使用 Django 創建一個簡單的 ToDo 應用程序,允許用戶添加、查看和刪除筆記。我們將通過設置 Django 項目、創建 Todo 模型、設計表單和視圖來處理用戶輸入以及創建模板來顯示任務來構建它。我們將逐步實現核心功能以有效地管理 todo 項。 Django ToDoWeb 服務 …

阿里云服務器遭遇DDoS攻擊?低成本第三方高防解決方案全解析

阿里云服務器因高性能和穩定性備受青睞&#xff0c;但其DDoS高防服務的價格常讓中小企業望而卻步。面對動輒每月數萬元的防護成本&#xff0c;許多用戶不禁疑問&#xff1a;能否通過第三方高防服務保護阿里云服務器&#xff1f;如何實現低成本高效防御&#xff1f; 本文將結合技…

2025山東CCPC補題

2025山東CCPC補題 目錄 2025山東CCPC補題K - UNO&#xff01; &#xff08;雙端隊列的簡單應用&#xff09;M - 第九屆河北省大學生程序設計競賽 &#xff08;二進制枚舉模擬&#xff09;J - Generate 01 String 感覺這場比賽的題目挺不錯的&#xff1b;沒有說那些為了算法而算…

體繪制學習

一、基本概念 體繪制是對一個三維物體數據進行采樣與擬合的過程。 在體繪制中用vtkVolume渲染數據 渲染數據類數據轉換類幾何渲染vtkActorvtkPolyDataMapper體渲染vtkVolumevtkVolumeRayCastMapper 體繪制常用算法如下。 光線投射法。 優點是可視化結果質量好。缺點是計算…

告別“盤絲洞”車間:4-20mA無線傳輸如何重構工廠神經網?

4-20ma無線傳輸是利用無線模塊將傳統的溫度、壓力、液位等4-20mA電流信號轉換為無線信號進行傳輸。這一技術突破了有線傳輸的限制&#xff0c;使得信號可以在更廣泛的范圍內進行靈活、快速的傳遞&#xff0c;無線傳輸距離可達到50KM。達泰4-20ma無線傳輸模塊在實現工業現場應用…

VB.NET與SQL連接問題解決方案

1.基本連接步驟 使用SqlConnection、SqlCommand和SqlDataReader進行基礎操作&#xff1a; vb.net Imports System.Data.SqlClient Public Sub ConnectToDatabase() Dim connectionString As String "ServermyServerAddress;DatabasemyDataBase;Integrated Security…