day36 python神經網絡訓練

目錄

一、數據準備與預處理

二、數據集劃分與歸一化

三、構建神經網絡模型

四、定義損失函數和優化器

五、訓練模型

六、評估模型


在機器學習和深度學習的實踐中,信貸風險評估是一個非常重要的應用場景。通過構建神經網絡模型,我們可以對客戶的信用狀況進行預測,從而幫助金融機構更好地管理風險。最近,我嘗試使用PyTorch框架來實現一個信貸風險預測的神經網絡模型,并在這個過程中鞏固了我對神經網絡的理解。以下是我在完成這個任務過程中的詳細記錄和總結。

一、數據準備與預處理

信貸數據集通常包含客戶的各種特征,如收入、信用評分、貸款金額等,以及是否違約的標簽。為了更好地訓練神經網絡模型,數據預處理是必不可少的步驟。

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler, OneHotEncoder, LabelEncoder
from imblearn.over_sampling import SMOTE
import matplotlib.pyplot as plt
from tqdm import tqdm# 設置GPU設備
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 加載信貸預測數據集
data = pd.read_csv('data.csv')# 丟棄掉Id列
data = data.drop(['Id'], axis=1)# 區分連續特征與離散特征
continuous_features = data.select_dtypes(include=['float64', 'int64']).columns.tolist()
discrete_features = data.select_dtypes(exclude=['float64', 'int64']).columns.tolist()# 離散特征使用眾數進行補全
for feature in discrete_features:if data[feature].isnull().sum() > 0:mode_value = data[feature].mode()[0]data[feature].fillna(mode_value, inplace=True)# 連續變量用中位數進行補全
for feature in continuous_features:if data[feature].isnull().sum() > 0:median_value = data[feature].median()data[feature].fillna(median_value, inplace=True)# 有順序的離散變量進行標簽編碼
mappings = {"Years in current job": {"10+ years": 10,"2 years": 2,"3 years": 3,"< 1 year": 0,"5 years": 5,"1 year": 1,"4 years": 4,"6 years": 6,"7 years": 7,"8 years": 8,"9 years": 9},"Home Ownership": {"Home Mortgage": 0,"Rent": 1,"Own Home": 2,"Have Mortgage": 3},"Term": {"Short Term": 0,"Long Term": 1}
}# 使用映射字典進行轉換
data["Years in current job"] = data["Years in current job"].map(mappings["Years in current job"])
data["Home Ownership"] = data["Home Ownership"].map(mappings["Home Ownership"])
data["Term"] = data["Term"].map(mappings["Term"])# 對沒有順序的離散變量進行獨熱編碼
data = pd.get_dummies(data, columns=['Purpose'])

在上述代碼中,我首先加載了信貸數據集,并對其進行了預處理。具體步驟包括:

  1. 丟棄無用的Id列。

  2. 區分連續特征和離散特征。

  3. 對離散特征使用眾數進行補全,對連續特征使用中位數進行補全。

  4. 對有順序的離散變量進行標簽編碼,對沒有順序的離散變量進行獨熱編碼。

二、數據集劃分與歸一化

在數據預處理完成后,我將數據集劃分為訓練集和測試集,并對特征數據進行歸一化處理。

# 分離特征數據和標簽數據
X = data.drop(['Credit Default'], axis=1)  # 特征數據
y = data['Credit Default']  # 標簽數據# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 對特征數據進行歸一化處理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)  # 確保訓練集和測試集是相同的縮放# 將數據轉換為PyTorch張量
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train.values).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test.values).to(device)

在上述代碼中,我使用了MinMaxScaler對特征數據進行歸一化處理,以確保所有特征的值都在0到1之間。這一步對于神經網絡的訓練非常重要,因為它可以加速模型的收斂速度并提高模型的性能。之后,我將數據轉換為PyTorch張量,并將其移動到指定的設備(GPU或CPU)上。

三、構建神經網絡模型

接下來,我定義了一個簡單的多層感知機(MLP)模型,包含一個輸入層、兩個隱藏層和一個輸出層。隱藏層使用了ReLU激活函數,并添加了Dropout層以防止過擬合。

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(X_train.shape[1], 64)  # 輸入層到第一隱藏層self.relu = nn.ReLU()self.dropout = nn.Dropout(0.3)  # 添加Dropout防止過擬合self.fc2 = nn.Linear(64, 32)  # 第一隱藏層到第二隱藏層self.fc3 = nn.Linear(32, 2)  # 第二隱藏層到輸出層def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x)x = self.relu(x)x = self.dropout(x)x = self.fc3(x)return x# 初始化模型
model = MLP().to(device)

在定義模型時,我使用了nn.Module作為基類,并通過forward方法定義了模型的前向傳播邏輯。這種模塊化的定義方式使得模型的結構清晰且易于擴展。

四、定義損失函數和優化器

損失函數和優化器是神經網絡訓練的兩個關鍵組件。對于分類任務,交叉熵損失函數(CrossEntropyLoss)是最常用的損失函數之一。優化器則負責根據損失函數的梯度更新模型的參數,我選擇了隨機梯度下降(SGD)優化器。

criterion = nn.CrossEntropyLoss()  # 使用交叉熵損失函數
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用SGD優化器

五、訓練模型

訓練模型的過程是一個迭代優化的過程。在每一輪迭代中,模型會計算損失函數的值,并通過反向傳播更新參數。為了監控訓練過程,我每10輪打印一次損失值。

num_epochs = 200  # 訓練輪數
for epoch in range(num_epochs):model.train()  # 設置為訓練模式optimizer.zero_grad()  # 清空梯度outputs = model(X_train)  # 前向傳播loss = criterion(outputs, y_train)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

通過上述代碼,我成功地訓練了模型,并觀察到損失值隨著訓練輪數的增加而逐漸降低。這表明模型正在逐步學習數據中的規律。

六、評估模型

訓練完成后,我使用測試集對模型的性能進行了評估。評估指標是準確率,即模型正確預測的樣本數占總樣本數的比例。

model.eval()  # 設置為評估模式
with torch.no_grad():correct = 0total = 0outputs = model(X_test)_, predicted = torch.max(outputs.data, 1)total += y_test.size(0)correct += (predicted == y_test).sum().item()accuracy = 100 * correct / total
print(f'Accuracy on test set: {accuracy:.2f}%')

最終,模型在測試集上的準確率達到了 [具體準確率]%。雖然這個結果還有提升的空間,但它已經證明了神經網絡在信貸風險評估任務中的有效性。

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/906909.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/906909.shtml
英文地址,請注明出處:http://en.pswp.cn/news/906909.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

如何確定是不是一個bug?

在軟件測試過程中,我們經常會遇到一些異常現象,但并非所有異常都是Bug。如何準確判斷一個問題是否屬于Bug?本文將從Bug的定義、判定標準、常見誤區和實戰技巧四個方面展開,幫助測試工程師提高Bug判定的準確性。 1. Bug的定義:什么情況下算Bug? 一個Bug(缺陷)通常指軟件…

Lombok與Jackson實現高效JSON序列化與反序列化

引言 在Java開發中&#xff0c;處理JSON數據是常見需求&#xff0c;而Jackson作為廣泛使用的JSON庫&#xff0c;能夠高效地將Java對象與JSON互相轉換。然而&#xff0c;傳統的POJO&#xff08;Plain Old Java Object&#xff09;需要手動編寫大量樣板代碼&#xff08;如getter…

論文閱讀:PURPLE: Making a Large Language Model a Better SQL Writer

論文地址&#xff1a;PURPLE: Making a Large Language Model a Better SQL Writer 摘要 大語言模型&#xff08;LLM&#xff09;技術在自然語言到 SQL&#xff08;NL2SQL&#xff09;翻譯中扮演著越來越重要的角色。通過大量語料訓練的 LLM 具有強大的自然語言理解能力和基本…

【圖像大模型】ControlNet:深度條件控制的生成模型架構解析

ControlNet&#xff1a;深度條件控制的生成模型架構解析 一、核心原理與技術突破1.1 基礎架構設計1.2 零卷積初始化1.3 多條件控制機制 二、系統架構與實現細節2.1 完整處理流程2.2 性能指標對比 三、實戰部署指南3.1 環境配置3.2 基礎推理代碼3.3 高級控制參數 四、典型問題解…

【從0到1搞懂大模型】chatGPT 中的對齊優化(RLHF)講解與實戰(9)

GPT系列模型的演進 chatgpt系列模型演進的重要節點包含下面幾個模型&#xff08;當然&#xff0c;這兩年模型發展太快了&#xff0c;4o這些推理模型我就先不寫了&#xff09; (Transformer) → GPT-1 → GPT-2 → GPT-3 → InstructGPT/ChatGPT(GPT-3.5) → GPT-4 下面介紹一…

2025年AEI SCI1區TOP,改進麻雀搜索算法MSSA+建筑三維重建,深度解析+性能實測

目錄 1.摘要2.麻雀搜索算法SSA原理3.整體框架4.改進SSA算法5.結果展示6.參考文獻7.代碼獲取8.讀者交流 1.摘要 對現有建筑進行高質量的三維重建對于其維護、修復和管理至關重要。圖像采集中的有效視角規劃會顯著影響基于攝影測量的三維重建質量。復雜的建筑結構常常導致傳統視…

鴻蒙開發:如何實現列表吸頂

前言 本文基于Api13 列表吸頂功能&#xff0c;在實際的開發中有著很大的作用&#xff0c;比如可以讓列表層級之間更加分明&#xff0c;減少一定程度上的視覺混亂&#xff0c;由于吸頂的標題會隨著滾動固定在頂部&#xff0c;可以讓用戶無需反復滑動回頂部確認分組位置&#xff…

使用Zotero的RSS訂閱功能快速了解感興趣領域最新文章

文章目錄 寫在前面中文期刊的RSS訂閱英文期刊的RSS訂閱回到Zotero有啥用&#xff1f; 寫在前面 作為一名研究生或者科研工作者&#xff0c;肯定需要經常檢索自己研究領域的最新文獻&#xff0c;相比于不定期的去各大數據庫檢索文獻&#xff0c;借助RSS訂閱功能則更加便捷。 R…

Windows安裝Docker Desktop開啟 Kubenetes制作并部署本地鏡像

1、安裝Docker Desktop docker desktop官方下載鏈接&#xff0c;下載后一路點下來安裝就好了。 2、制作本地鏡像 跟著docker步驟制作鏡像&#xff0c;需要先配置docker 鏡像源&#xff0c;因為網絡問題 {"builder": {"gc": {"defaultKeepStorage&…

嵌入式學習筆記 - freeRTOS 列表,鏈表,節點跟任務之間關系

一 下圖說明了 freeRTOS 就緒列表&#xff0c;鏈表&#xff0c;節點跟任務之間關系 一個任務對應一個節點&#xff0c;一個鏈表對應一個優先級&#xff0c;一個任務根據優先級可以插入任何一個鏈表中。 插入函數為&#xff0c;這也是freeRTOS的核心函數&#xff0c;對每個任務…

scikit-learn pytorch transformers 區別與聯系

以下是 scikit-learn、PyTorch 和 Transformers 的區別與聯系的表格形式展示: 特性/庫scikit-learnPyTorchTransformers主要用途傳統機器學習算法深度學習框架預訓練語言模型與自然語言處理任務核心功能分類、回歸、聚類、降維、模型選擇等張量計算、自動微分、神經網絡構建與…

【C/C++】從零開始掌握Kafka

文章目錄 從零開始掌握Kafka一、Kafka 基礎知識理解&#xff08;理論&#xff09;1. 核心組件與架構2. 重點概念解析 二、Kafka 面試重點知識梳理三、C 使用 Kafka 的實踐&#xff08;librdkafka&#xff09;1. librdkafka 簡介2. 安裝 librdkafka 四、實戰&#xff1a;高吞吐生…

Spyglass:目標文件(.spq)的結構

相關閱讀 Spyglasshttps://blog.csdn.net/weixin_45791458/category_12828934.html?spm1001.2014.3001.5482 預備知識 為了方便檢查&#xff0c;Spyglass向用戶提供Guideware作為檢查參考&#xff1b;Guideware又包含各種方法(Methodology)&#xff0c;應用于設計的不同階段&…

一些Dify聊天系統組件流程圖架構圖

分享一些有助于深入理解Dify聊天模塊的架構圖 整體組件架構圖 #mermaid-svg-0e2XalGLqrRbH1Jy {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-0e2XalGLqrRbH1Jy .error-icon{fill:#552222;}#mermaid-svg-0e2XalGLq…

地理空間索引:解鎖日志分析中的位置智慧

title: 地理空間索引:解鎖日志分析中的位置智慧 date: 2025/05/24 18:43:06 updated: 2025/05/24 18:43:06 author: cmdragon excerpt: 地理空間索引在日志分析中應用廣泛,涉及用戶登錄IP定位、移動端位置軌跡和物聯網設備位置上報等場景。MongoDB支持2dsphere和2d兩種地理…

分庫分表深度解析

一、為什么要分庫分表&#xff1f; 通常&#xff0c;數據庫性能受到如下幾個限制&#xff1a; 硬件瓶頸&#xff1a;單機的 CPU、內存、磁盤 I/O 等資源總是有限。例如&#xff0c;當單表中的記錄達到上億、甚至更高時&#xff0c;表掃描、索引維護和數據遷移會變得非常慢。單…

QListWidget的函數,信號介紹

前言 Qt版本:6.8.0 該類用于列表模型/視圖 QListWidgetItem函數介紹 作用 QListWidget是Qt框架中用于管理可交互列表項的核心組件&#xff0c;主要作用包括&#xff1a; 列表項管理 支持動態添加/刪除項&#xff1a;addItem(), takeItem()批量操作&#xff1a;addItems()…

ModbusRTU轉profibusDP網關與RAC400通訊報文解析

ModbusRTU轉profibusDP網關與RAC400通訊報文解析 在工業自動化領域&#xff0c;ModbusRTU和ProfibusDP是兩種常見的通信協議。ModbusRTU以其簡單、可靠、易于實現等特點&#xff0c;廣泛應用于各種工業設備之間的通信&#xff1b;而ProfibusDP則是一種高性能的現場總線標準&am…

Python容器

一、容器 1. 列表【】&#xff1a;有序可重復可混裝可修改 [元素1&#xff0c;元素2&#xff0c;元素3&#xff0c;...] ? 可以容納多個元素 ? 可以容納不同類型的元素&#xff08;混裝&#xff09; ? 數據是有序存儲的&#xff08;有下標序號&#xff09; ? 允許重復數…

webpack面試問題

一、核心概念 Webpack的構建流程是什么? 答案: 初始化:讀取配置,創建Compiler對象編譯:從入口文件開始,遞歸分析依賴關系,生成依賴圖模塊處理:調用Loader轉換模塊(如babel-loader)輸出:將處理后的模塊組合成Chunk,生成最終文件Loader和Plugin的區別? Loader:文件…