神經網絡基礎-價格分類案例

文章目錄

    • 1. 需求分析
    • 2. 導入所需工具包
    • 3. 構建數據集
    • 4. 構建分類網絡模型
    • 5. 訓練模型
    • 6. 模型訓練
    • 7. 評估模型
    • 8. 模型優化

學習目標:

  1. 掌握構建分類模型流程
  2. 動手實踐整個過程

1. 需求分析

小明創辦了一家手機公司,他不知道如何估算手機產品的價格。為了解決這個問題,他收集了多家公司的手機銷售數據。該數據為二手手機的各個性能的數據,最后根據這些性能得到4個價格區間,作為這些二手手機售出的價格區間。主要包括:

battery_power電池一次可存儲的電量,單位:毫安/時
blue是否有藍牙
clock_speed微處理器執行指令的速度
dual_sim是否支持雙卡
fc前置攝像頭百萬像素
four_g是否有4G
int_memory內存(GB)
m_dep移動深度(cm)
mobile_wt手機重量
n_cores處理器內核數
pc主攝像頭百萬像素
px_height像素分辨率高度
px_width像素分辨率寬度
ram隨機存儲器(兆字節)
sc_h手機屏幕高度(cm)
sc_w手機屏幕寬度(cm)
talk_time一次充電持續時長
three_g是否有3G
touch_screen是否有觸屏控制
wifi是否能連wifi
price_range價格區間(0,1,2,3)

我們需要幫助小明找出手機的功能(例如:RAM等)與其售價之間的某種關系。我們可以使用機器學習的方法來解決這個問題,也可以構建一個全連接的網絡。

需要注意的是: 在這個問題中,我們不需要預測實際價格,而是一個價格范圍,它的范圍使用 0、1、2、3 來表示,所以該問題也是一個分類問題。接下來我們還是按照四個步驟來完成這個任務:

  • 準備訓練集數據

  • 構建要使用的模型

  • 模型訓練

  • 模型預測評估

2. 導入所需工具包

# 導入相關模塊
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time

3. 構建數據集

數據共有 2000 條, 其中 1600 條數據作為訓練集, 400 條數據用作測試集。 我們使用 sklearn 的數據集劃分工作來完成。并使用 PyTorch 的 TensorDataset 來將數據集構建為 Dataset 對象,方便構造數據集加載對象。

#1. 導入相關模塊
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time# 構建數據集
def load_dataset():# 使用pandas 讀取數據data = pd.read_csv('data/手機價格預測.csv')# 特征值和目標值x,y = data.iloc[:,:-1],data.iloc[:,-1]# 類型轉換:特征值,目標值x = x.astype(np.float32)y = y.astype(np.int64)# 劃分訓練集和測試集x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)# 構建數據集,轉換為pytorch格式train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.from_numpy(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.from_numpy(y_test.values))#返回結果return train_dataset, test_dataset,x_train.shape[1],len(np.unique(y))if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("輸入特征數:",input_dim)print("分類個數:",class_num)

輸出結果為:

輸入特征數: 20
分類個數: 4

4. 構建分類網絡模型

構建全連接神經網絡來進行手機價格分類,該網絡主要由三個線性層來構建,使用relu激活函數。

網絡共有 3 個全連接層, 具體信息如下:

  1. 第一層: 輸入為維度為 20, 輸出維度為: 128
  2. 第二層: 輸入為維度為 128, 輸出維度為: 256
  3. 第三層: 輸入為維度為 256, 輸出維度為: 4
# 構建網絡模型
class PhonePriceModel(nn.Module):def __init__(self,input_dim,output_dim):super(PhonePriceModel, self).__init__()# 1. 第一層: 輸入為維度為 20, 輸出維度為: 128self.linear1 = nn.Linear(input_dim, 128)# 2. 第二層: 輸入為維度為 128, 輸出維度為: 256self.linear2 = nn.Linear(128, 256)# 3. 第三層: 輸入為維度為 256, 輸出維度為: 4self.linear3 = nn.Linear(256, output_dim)def forward(self, x):# 前向傳播過程x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.linear3(x)# 獲取數據結果return outputif __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("輸入特征數:",input_dim)print("分類個數:",class_num)# 模型實例化model = PhonePriceModel(input_dim,class_num)

5. 訓練模型

網絡編寫完成之后,我們需要編寫訓練函數。所謂的訓練函數,指的是輸入數據讀取、送入網絡、計算損失、更新參數的流程,該流程較為固定。我們使用的是多分類交叉生損失函數、使用 SGD 優化方法。最終,將訓練好的模型持久化到磁盤中。

# 模型訓練過程
def train(train_dataset,input_dim,class_num):# 固定隨機數種子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim,class_num)# 損失函數criterion = nn.CrossEntropyLoss()# 優化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 訓練輪數num_epochs = 50# 遍歷輪數for epoch_idx in range(num_epochs):# 初始化數據加載器dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 訓練時間start = time.time()# 計算損失total_loss = 0.0total_num = 1# 遍歷每個batch數據進行處理for x,y in dataloader:# 將數據送入網絡中進行預測output = model(x)# 計算損失loss = criterion(output, y)#梯度清零optimizer.zero_grad()# 方向傳播loss.backward()# 參數更新optimizer.step()# 損失計算total_num += 1total_loss += loss.item()# 打印損失變換結果print('epoch: %4s loss: %.2f, time: %.2fs' % (epoch_idx + 1, total_loss / total_num, time.time() - start))# 保存模型torch.save(model.state_dict(), 'model/phone.ptn')

6. 模型訓練

if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("輸入特征數:",input_dim)print("分類個數:",class_num)# 模型訓練過程train(train_dataset,input_dim,class_num)

輸出結果:

epoch:    1 loss: 13.31, time: 0.25s
epoch:    2 loss: 0.96, time: 0.24s
epoch:    3 loss: 0.90, time: 0.24s
epoch:    4 loss: 0.89, time: 0.25s
epoch:    5 loss: 0.86, time: 0.26s
...
epoch:   46 loss: 0.68, time: 0.25s
epoch:   47 loss: 0.69, time: 0.26s
epoch:   48 loss: 0.68, time: 0.28s
epoch:   49 loss: 0.69, time: 0.24s
epoch:   50 loss: 0.69, time: 0.24s

7. 評估模型

使用訓練好的模型,對未知的樣本的進行預測的過程。我們這里使用前面單獨劃分出來的驗證集來進行評估。

# 4 評估模型
def test(test_dataset,input_dim,class_num):# 加載模型和訓練好的網絡參數model = PhonePriceModel(input_dim,class_num)model.load_state_dict(torch.load('model/phone.ptn',weights_only=False))# 構建加載器dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)# 評估測試集correct = 0# 遍歷測試集中的數據for x,y in dataloader:# 將其送入網絡中output = model(x)# 獲取類別結果y_pred = torch.argmax(output, dim=1)# 獲取預測正確的個數correct += (y_pred == y).sum().sum()# 求預測精度print('Acc: %.5f' % (correct.item() / len(test_dataset)))
if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("輸入特征數:",input_dim)print("分類個數:",class_num)# 評估模型test(test_dataset,input_dim,class_num)

輸出結果:

Acc: 0.62500

8. 模型優化

我們前面的網絡模型在測試集的準確率為: 0.54750, 我們可以通過以下方面進行調優:

  1. 優化方法由 SGD 調整為 Adam
  2. 學習率由 1e-3 調整為 1e-4
  3. 對數據數據進行標準化
  4. Dropout 正則化
  5. 調整訓練輪次
# 使用Adam方法優化網絡
#1. 導入相關模塊
import torch
from tensorboard import summary
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time#2. 構建數據集
def load_dataset():# 使用pandas 讀取數據data = pd.read_csv('data/手機價格預測.csv')# 特征值和目標值x,y = data.iloc[:,:-1],data.iloc[:,-1]# 類型轉換:特征值,目標值x = x.astype(np.float32)y = y.astype(np.int64)# 劃分訓練集和測試集x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)# 數據標準化scaler = StandardScaler()x_train = scaler.fit_transform(x_train)x_test = scaler.fit_transform(x_test)x_train = torch.tensor(x_train,dtype=torch.float32)x_test = torch.tensor(x_test,dtype=torch.float32)# 構建數據集,轉換為pytorch格式train_dataset = TensorDataset(x_train, torch.from_numpy(y_train.values))test_dataset = TensorDataset(x_test, torch.from_numpy(y_test.values))#返回結果return train_dataset, test_dataset,x_train.shape[1],len(np.unique(y))
#2. 構建網絡模型
class PhonePriceModel(nn.Module):def __init__(self,input_dim,output_dim,p_dropout=0.4):super(PhonePriceModel, self).__init__()# 第一層: 輸入為維度為 20, 輸出維度為: 128self.linear1 = nn.Linear(input_dim, 128)# Dropout優化self.dropout = nn.Dropout(p_dropout)# 第二層: 輸入為維度為 128, 輸出維度為: 256self.linear2 = nn.Linear(128, 256)# Dropout優化self.dropout = nn.Dropout(p_dropout)# 第三層: 輸入為維度為 256, 輸出維度為: 4self.linear3 = nn.Linear(256, output_dim)def forward(self, x):# 前向傳播過程x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.linear3(x)# 獲取數據結果return output# 3. 模型訓練過程
def train(train_dataset,input_dim,class_num):# 固定隨機數種子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim,class_num)# 損失函數criterion = nn.CrossEntropyLoss()# 優化方法# optimizer = optim.SGD(model.parameters(), lr=1e-3)# Adam優化方法 調整學習率為 lr=1e-4optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.99))# 訓練輪數 100 - 0.9075 50 - 0.9125num_epochs = 50# 遍歷輪數for epoch_idx in range(num_epochs):# 初始化數據加載器dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 訓練時間start = time.time()# 計算損失total_loss = 0.0total_num = 1# 遍歷每個batch數據進行處理for x,y in dataloader:# 將數據送入網絡中進行預測output = model(x)# 計算損失loss = criterion(output, y)#梯度清零optimizer.zero_grad()# 方向傳播loss.backward()# 參數更新optimizer.step()# 損失計算total_num += 1total_loss += loss.item()# 打印損失變換結果print('epoch: %4s loss: %.2f, time: %.2fs' % (epoch_idx + 1, total_loss / total_num, time.time() - start))# 保存模型torch.save(model.state_dict(), 'model/phone2.ptn')
# 4 評估模型
def test(test_dataset,input_dim,class_num):# 加載模型和訓練好的網絡參數model = PhonePriceModel(input_dim,class_num)model.load_state_dict(torch.load('model/phone2.ptn',weights_only=False))# 構建加載器dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)# 評估測試集correct = 0# 遍歷測試集中的數據for x,y in dataloader:# 將其送入網絡中output = model(x)# 獲取類別結果y_pred = torch.argmax(output, dim=1)# 獲取預測正確的個數correct += (y_pred == y).sum().sum()# 求預測精度print('Acc: %.5f' % (correct.item() / len(test_dataset)))
if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("輸入特征數:",input_dim)print("分類個數:",class_num)# 模型訓練train(train_dataset,input_dim,class_num)test(test_dataset,input_dim,class_num)

這里我們調整 Adam方法優化梯度下降,學習率調整為1e-4,樣本數據采用標準化處理。采用Dropout正則化。最后輸出結果:

Acc: 0.91250

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/66224.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/66224.shtml
英文地址,請注明出處:http://en.pswp.cn/web/66224.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SAP 固定資產常用的數據表有哪些,他們是怎么記錄數據的?

在SAP系統中,固定資產管理(FI-AA)涉及多個核心數據表,用于記錄資產主數據、折舊、交易等。以下是常用的數據表及其記錄數據的邏輯: 1. ANKT - 資產主數據表 功能:存儲資產主數據的文本描述。 字段&#x…

光伏儲能電解水制氫仿真模型Matlab/Simulink

今天更新的內容為光伏儲能制氫技術,這個方向我之前在21年就系統研究并發表過相關文章,經過這幾年的發展,綠色制氫技術也受到更多高校的注意,本篇博客也是在原先文章的基礎上進行更新。 首先讓大家熟悉一下綠氫制取技術這個概念&a…

Redis 3.2.1在Win10系統上的安裝教程

諸神緘默不語-個人CSDN博文目錄 這個文件可以跟我要,也可以從官網下載:https://github.com/MicrosoftArchive/redis/releases 這個是微軟以前維護的Windows版Redis安裝包,如果想要比較新的版本可以從別人維護的項目里下(https://…

基于springboot+vue.js+uniapp技術開發的一套大型企業MES生產管理系統源碼,支持多端管理

企業級智能制造MES系統源碼,技術架構:springboot vue-element-plus-admin 企業級云MES全套源碼,支持app、小程序、H5、臺后管理端 MES指的是制造企業生產過程執行系統,是一套面向制造企業車間執行層的生產信息化管理系統。MES系…

【Redis】Redis事務和Lua腳本的區別

Redis事務 概念 事務:Redis事務是一組命令的集合,這些命令會被序列化地執行,中間不會被其他命令插入。 MULTI/EXEC:Redis事務通過MULTI命令開始,通過EXEC命令執行所有已入隊的命令。 特點 原子性: 事務…

frameworks 之 AMS與ActivityThread交互

frameworks 之 AMS與ActivityThread交互 1. 類關系2. 流程2.1 AMS流程2.1 ActivityThread流程 3. 堆棧 講解AMS 如何和 ActivityThread 生命周期調用流程 涉及到的類如下 frameworks/base/core/java/android/app/servertransaction/ResumeActivityItem.javaframeworks/base/cor…

Jmeter 簡單使用、生成測試報告(一)

一、下載Jmter 去官網下載,我下載的是apache-jmeter-5.6.3.zip,解壓后就能用。 二、安裝java環境 JMeter是基于Java開發的,運行JMeter需要Java環境。 1.下載JDK、安裝Jdk 2.配置java環境變量 3.驗證安裝是否成功(java -versio…

如何使用淘寶URL采集商品詳情數據及銷量

一、通過淘寶開放平臺(如果有資質) 注冊成為淘寶開發者 訪問淘寶開放平臺官方網站,按照要求填寫開發者信息,包括企業或個人身份驗證等步驟。這一步是為了獲取合法的 API 使用權限。 了解商品詳情 API 淘寶開放平臺提供了一系列…

Unity3D中的Lua、ILRuntime與HybridCLR/huatuo熱更對比分析詳解

前言 在游戲開發中,熱更新技術是一項重要的功能,它允許開發者在不重新發布游戲客戶端的情況下,更新游戲內容。Unity3D作為廣泛使用的游戲引擎,支持多種熱更新方案,包括Lua、ILRuntime和HybridCLR/huatuo。本文將詳細介…

QT加載Ui文件信息方法(python)

在 PyQt 或 PySide 中,加載 Qt Designer 生成的 .ui 文件有兩種常見方法: 使用 pyuic 將 .ui 文件轉換為 Python 代碼。動態加載 .ui 文件。 以下是兩種方法的詳細說明和示例代碼。 方法 1:使用 pyuic 將 .ui 文件轉換為 Python 代碼 步驟…

javascript基礎從小白到高手系列一十二:JSON

本章內容 ? 理解JSON 語法 ? 解析JSON ? JSON 序列化 正如上一章所說,XML 曾經一度成為互聯網上傳輸數據的事實標準。第一代Web 服務很大程度上 是以XML 為基礎的,以服務器間通信為主要特征。可是,XML 也并非沒有批評者。有的人認為XML 過…

網絡編程 - - TCP套接字通信及編程實現

概述 TCP(Transmission Control Protocol,傳輸控制協議)是一種面向連接的、可靠的傳輸層協議。在網絡編程中,TCP常用于實現客戶端和服務器之間的可靠數據傳輸。本文將基于C語言實現TCP服務端和客戶端建立通信的過程。 三次握手 在…

2023-2024 學年 廣東省職業院校技能大賽(高職組)“信息安全管理與評估”賽題一

2023-2024 學年 廣東省職業院校技能大賽(高職組“信息安全管理與評估”賽題一) 模塊一:網絡平臺搭建與設備安全防護第一階段任務書任務 1:網絡平臺搭建任務 2:網絡安全設備配置與防護DCRS:DCFW:DCWS:DCBC:WAF: 模塊二:網絡安全事件…

thinkphp6 + redis實現大數據導出excel超時或內存溢出問題解決方案

redis下載安裝(window版本) 參考地址:https://blog.csdn.net/Ci1693840306/article/details/144214215 php安裝redis擴展 參考鏈接:https://blog.csdn.net/jianchenn/article/details/106144313 解決思路:&#xff0…

PT8M2302 觸控 A/D 型 8-Bit MCU

1. 產品概述 PT8M2302 是一款可多次編程( MTP ) A/D 型 8 位 MCU ,其包括 2K*16bit MTP ROM 、 256*8bit SRAM、 ADC 、 PWM 、 Touch 等功能,具有高性能精簡指令集、低工作電壓、低功耗特性且完全集 成觸控按鍵功能。為…

如何使用策略模式并讓spring管理

1、策略模式公共接口類 BankFileStrategy public interface BankFileStrategy {String getBankFile(String bankType) throws Exception; } 2、策略模式業務實現類 Slf4j Component public class ConcreteStrategy implements BankFileStrategy {Overridepublic String ge…

前端開發:盒子模型、塊元素

1.border邊框 *{box-sizing:border-box; } //使所有邊框不再撐大盒子模型 粗細 : border-width 樣式 : border-style, 默認沒邊框 . solid 實線邊框 dashed 虛線邊框 dotted 點線邊框 顏色 : border-color div { width : 200px ; height : 200px ; border : …

Nvidia Blackwell架構深度剖析:深入了解RTX 50系列GPU的升級

在CES 2025上,英偉達推出了基于Blackwell架構的GeForce RTX 50系列顯卡,包括RTX 5090、RTX 5080、RTX 5070 Ti和RTX 5070。一段時間以來,我們已經知曉了該架構的各種細節,其中許多此前還只是傳聞。不過,英偉達近日在20…

計算機網絡 (45)動態主機配置協議DHCP

前言 計算機網絡中的動態主機配置協議(DHCP,Dynamic Host Configuration Protocol)是一種網絡管理協議,主要用于自動分配IP地址和其他網絡配置參數給連接到網絡的設備。 一、基本概念 定義:DHCP是一種網絡協議&#xf…

“扣子”開發之四:與千帆AppBuilder比較

上一個專題——“扣子”開發——未能落地,開始抱著極大的熱情進入,但迅速被稚嫩的架構模型折磨打擊,硬著頭皮堅持了兩周,終究還是感覺不實用不趁手放棄了。今天詢問了下豆包,看看還有哪些比較好的AI開發平臺&#xff0…