pytorch神經網絡因素預測_實戰:使用PyTorch構建神經網絡進行房價預測

微信公號:ilulaoshi / 個人網站:lulaoshi.info

本文將學習一下如何使用PyTorch創建一個前饋神經網絡(或者叫做多層感知機,Multiple-Layer Perceptron,MLP),文中會使用PyTorch提供的自動求導功能,訓練一個神經網絡。

本文的數據集來自Kaggle競賽:房價預測(https://www.kaggle.com/c/house-prices-advanced-regression-techniques/)。這份數據分為訓練數據集和測試數據集。兩個數據集都包括每棟房子的特征,如建造年份、地下室狀況等特征值。這些特征中,有連續的數值型(Numerical)特征,有離散的分類(Categorical)特征。這些特征中,有些特征值是缺失值“na”。訓練數據集包括了每棟房子的價格,也就是需要預測的目標值(Label)。我們應該用訓練數據集訓練一個模型,并對測試數據集進行預測,然后將結果提交到Kaggle。

數據探索和預處理

首先,我們下載并加載數據集:

train_data_path ='./dataset/train.csv'

train = pd.read_csv(train_data_path)

num_of_train_data = train.shape[0]

test_data_path ='./dataset/test.csv'

test = pd.read_csv(test_data_path)

訓練數據集共1460個樣本,81個維度,其中,Id是每個樣本的唯一編號,SalePrice是房價,也是我們要擬合的目標值。其他維度(列)有數值類特征,也有非數值列,或者叫分類特征。

先查看訓練數據集的維度:

train.shape

輸出為:

(1460, 81)

或者通過train.describe()來查看整個數據集各個特征的一些統計情況。

接著,我們要把訓練數據集和測試數據集合并。將訓練數據集和測試數據集合并主要是為了統一特征處理的流程,或者說對訓練數據集和測試數據集使用同樣的方法,進行同樣的特征工程處理。

# 房價,要擬合的目標值

target = train.SalePrice

# 輸入特征,可以將SalePrice列扔掉

train.drop(['SalePrice'],axis = 1 , inplace = True)

# 將train和test合并到一起,一塊進行特征工程,方便預測test的房價

combined = train.append(test)

combined.reset_index(inplace=True)

combined.drop(['index', 'Id'], inplace=True, axis=1)

接著就要開始進行特征工程了。本文沒有進行任何復雜的特征工程,只做了兩件事:1、過濾掉了含有缺失值的列;2、對分類特征進行了One-Hot編碼。缺失值會在一定程度上影響算法的預測效果,一般可以使用一些默認值或者一些臨近值來填充缺失值。對于MLP模型,分類特征必須經過編碼,轉換成數值才能進行模型訓練,One-Hot編碼是一種最常見的分類特征處理的方法。

我們用下面的函數過濾非空列:

# 選出非空列

def get_cols_with_no_nans(df,col_type):

'''

Arguments :

df : The dataframe to process

col_type :

num : to only get numerical columns with no nans

no_num : to only get nun-numerical columns with no nans

all : to get any columns with no nans

'''

if (col_type == 'num'):

predictors = df.select_dtypes(exclude=['object'])

elif (col_type == 'no_num'):

predictors = df.select_dtypes(include=['object'])

elif (col_type == 'all'):

predictors = df

else :

print('Error : choose a type (num, no_num, all)')

return 0

cols_with_no_nans = []

for col in predictors.columns:

if not df[col].isnull().any():

cols_with_no_nans.append(col)

return cols_with_no_nans

分別對數值特征和分類特征進行處理:

num_cols = get_cols_with_no_nans(combined, 'num')

cat_cols = get_cols_with_no_nans(combined, 'no_num')

# 過濾掉含有缺失值的特征

combined = combined[num_cols + cat_cols]

print(num_cols[:5])

print ('Number of numerical columns with no nan values: ',len(num_cols))

print(cat_cols[:5])

print ('Number of non-numerical columns with no nan values: ',len(cat_cols))

經過過濾,數值特征共有25列,分類特征共有20列,共45列。

# 對分類特征進行One-Hot編碼

def oneHotEncode(df,colNames):

for col in colNames:

if( df[col].dtype == np.dtype('object')):

# pandas.get_dummies 可以對分類特征進行One-Hot編碼

dummies = pd.get_dummies(df[col],prefix=col)

df = pd.concat([df,dummies],axis=1)

# drop the encoded column

df.drop([col],axis = 1 , inplace=True)

return df

對于分類特征,還需要進行One-Hot編碼,pandas.get_dummies可以幫我們自動完成One-Hot編碼過程。經過One-Hot編碼后,數據增加了很多列,共有149列。

至此,我們完成了一次非常簡單的特征工程,將這些數據轉化為PyTorch模型所能接受的Tensor形式:

# 訓練數據集特征

train_features = torch.tensor(combined[:num_of_train_data].values, dtype=torch.float)

# 訓練數據集目標

train_labels = torch.tensor(target.values, dtype=torch.float).view(-1, 1)

# 測試數據集特征

test_features = torch.tensor(combined[num_of_train_data:].values, dtype=torch.float)

print("train data size: ", train_features.shape)

print("label data size: ", train_labels.shape)

print("test data size: ", test_features.shape)

構建神經網絡

接著,我們開始構建神經網絡。

在PyTorch中構建神經網絡有兩種方式。比較簡單的前饋網絡,可以使用nn.Sequential。nn.Sequential是一個存放神經網絡的容器,直接在nn.Sequential里面添加我們需要的層即可。整個模型的輸入為特征數,輸出為一個標量。模型的隱藏層使用了ReLU激活函數,最后一層是一個線性層,得到的是一個預測的房價值。

model_sequential = nn.Sequential(

nn.Linear(train_features.shape[1], 128),

nn.ReLU(),

nn.Linear(128, 256),

nn.ReLU(),

nn.Linear(256, 256),

nn.ReLU(),

nn.Linear(256, 256),

nn.ReLU(),

nn.Linear(256, 1)

)

另一種構建神經網絡的方式是繼承nn.Module類,我們將子類起名為Net類。__init__()方法為Net類的構造函數,用來初始化神經網絡各層的參數;forward()也是我們必須實現的方法,主要用來實現神經網絡的前向傳播過程。

class Net(nn.Module):

def __init__(self, features):

super(Net, self).__init__()

self.linear_relu1 = nn.Linear(features, 128)

self.linear_relu2 = nn.Linear(128, 256)

self.linear_relu3 = nn.Linear(256, 256)

self.linear_relu4 = nn.Linear(256, 256)

self.linear5 = nn.Linear(256, 1)

def forward(self, x):

y_pred = self.linear_relu1(x)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu2(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu3(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu4(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear5(y_pred)

return y_pred

我們已經定義好了一個神經網絡的Net類,還要初始化一個Net類的對象實例model,表示某個具體的模型。然后定義損失函數,這里使用MSELoss,MSELoss使用了均方誤差(Mean Square Error)來衡量損失函數。對于模型model的訓練過程,這里使用Adam算法。Adam是優化算法中的一種,在很多場景中效率要優于SGD。

model = Net(features=train_features.shape[1])

# 使用均方誤差作為損失函數

criterion = nn.MSELoss(reduction='mean')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

訓練模型

接著,我們使用Adam算法進行多輪的迭代,更新模型model中的參數。這里對模型進行500輪的迭代。

losses = []

# 訓練500輪

for t in range(500):

y_pred = model(train_features)

loss = criterion(y_pred, train_labels)

# print(t, loss.item())

losses.append(loss.item())

if torch.isnan(loss):

break

# 將模型中各參數的梯度清零。

# PyTorch的backward()方法計算梯度會默認將本次計算的梯度與緩存中已有的梯度加和。

# 必須在反向傳播前先清零。

optimizer.zero_grad()

# 反向傳播,計算各參數對于損失loss的梯度

loss.backward()

# 根據剛剛反向傳播得到的梯度更新模型參數

optimizer.step()

每次迭代使用訓練數據集中的所有樣本train_features。model(train_features)實際是執行的model.forward(train_features),即forward()方法中定義的前向傳播邏輯,輸入數據在神經網絡模型中前向傳播,得到預測值y_pred。criterion(y_pred, train_labels)方法計算了預測值y_pred和目標值train_labels之間的損失。

每次迭代時,我們要先對模型中各參數的梯度清零:optimizer.zero_grad()。PyTorch中的backward()默認是把本次計算的梯度和緩存中已有的梯度加和,因此必須在反向傳播前先將梯度清零。接著執行backward()方法,完成反向傳播過程,PyTorch會幫我們計算各參數對于損失函數的梯度。optimizer.step()會根據剛剛反向傳播得到的梯度,更新模型參數。

至此,一個簡單的預測房價的模型就訓練好了。

測試模型

我們可以使用模型對測試數據集進行預測,將得到的預測值保存成文件,提交到Kaggle上。

predictions = model(test_features).detach().numpy()

my_submission = pd.DataFrame({'Id':pd.read_csv('./dataset/test.csv').Id,'SalePrice': predictions[:, 0]})

my_submission.to_csv('{}.csv'.format('./dataset/submission'), index=False)

參考資料

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/457549.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/457549.shtml
英文地址,請注明出處:http://en.pswp.cn/news/457549.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL基本操作

SQL 操作 檢索數據 SELECT 檢索數據 -- 檢索單個列 SELECT 列名 FROM table_name;-- 檢索多個列 SELECT 列1, 列2 FROM table_name;-- 檢索所有列 SELECT * FROM table_name;-- 檢索不同的值 SELECT DISTINCT 列名 FROM table_name;限制檢索結果 -- SQL Server / Access SE…

git 忽略 部分文件夾_git提交忽略某些文件或文件夾

記得第一次用 github 提交代碼,node_modules 目錄死活傳不上去,哈哈哈,后來才知道在 .gitignore 文件里設置了忽略 node_modules 目錄上傳。是的, .gitignore 文件就是設置那些你不想用 git 一起上傳的文件和文件夾。比如剛接觸到…

Ajax實現原理詳解

Ajax:Asynchronous javascript and xml,實現了客戶端與服務器進行數據交流過程。使用技術的好處是:不用頁面刷新,并且在等待頁面傳輸數據的同時可以進行其他操作。 這就是異步調用的很好體現。首先得了解什么是異步和同步的概念。…

SpringJDBC解析3-回調函數(update為例)

PreparedStatementCallback作為一個接口,其中只有一個函數doInPrepatedStatement,這個函數是用于調用通用方法execute的時候無法處理的一些個性化處理方法,在update中的函數實現: protected int update(final PreparedStatementCr…

python上下文管理器

DAY 23. python上下文管理器 Python 的 with 語句支持通過上下文管理器所定義的運行時上下文這一概念。 此對象的實現使用了一對專門方法,允許用戶自定義類來定義運行時上下文,在語句體被執行前進入該上下文,并在語句執行完畢時退出該上下文&…

勾股定理python思路_趣叮咚編程數學揭秘:為什么勾股定理a+b=c?

我們都知道:三角形3個外角之和360度可是誰知道為什么等于360度呢?其實利用編程制作動圖演繹了解啦:那勾股定理abc又是為什么呢?還有很多有趣的數學公式都可以演繹:圓的面積公式、圓周長...通過動圖演繹原來晦澀難懂的定…

System.InvalidOperationException : 不應有 Response xmlns=''。

xml如下&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <Response version"2"><datacash_reference>4700203048783633</datacash_reference><information>Failed to identify the card scheme of the supp…

Navicat Premium連接SQL Server

Navicat Premium連接SQL Server 步驟&#xff1a; 激活SQL Server 服務配置SQL Server網絡配置連接SQL Server 激活SQLServer服務 直接搜索 計算機管理 點 服務和應用程序&#xff0c; 點 SQL Server配置管理器&#xff0c; 雙擊第一個SQL Server服務 不出意外的話&#xf…

mysql 單標遞歸_MySql8 WITH RECURSIVE遞歸查詢父子集的方法

背景開發過程中遇到類似評論的功能是&#xff0c;需要時用查詢所有評論的子集。不同數據庫中實現方式也不同&#xff0c;本文使用Mysql數據庫&#xff0c;版本為8.0Oracle數據庫中可使用START [Param] CONNECT BY PRIORMysql 中需要使用 WITH RECURSIVE需求找到name為張三的孩子…

processon完全裝逼指南

一、引言 作為一名IT從業者&#xff0c;不僅要有扎實的知識儲備&#xff0c;出色的業務能力&#xff0c;還需要具備一定的軟實力。軟實力體現在具體事務的處理能力&#xff0c;包括溝通&#xff0c;協作&#xff0c;團隊領導&#xff0c;問題的解決方案等&#xff0c;這些能力在…

mysql在空閑8小時之后會斷開連接(默認情況)

調試程序的過程發現&#xff0c;在mysql連接空閑一定時間&#xff08;默認8小時&#xff09;之后會斷開連接&#xff0c;需要重新連接&#xff0c;也引發我對重連機制的思考。轉載于:https://www.cnblogs.com/ppzbty/p/5707576.html

selector多路復用_多路復用器Selector

Unix系統有五種IO模型分別是阻塞IO(blocking IO)&#xff0c;非阻塞IO( non-blocking IO)&#xff0c;IO多路復用(IO multiplexing)&#xff0c;信號驅動(SIGIO/Signal IO)和異步IO(Asynchronous IO)。而IO多路復用通常有select&#xff0c;poll&#xff0c;epoll&#xff0c;k…

解決svn log顯示no author,no date的方法之一

只要把svnserve.conf中的anon-access read 的read 改為none&#xff0c;也不需要重啟svnserve就行 sh-4.1# grep "none" /var/www/html/svn/pro/conf/svnserve.conf ### and "none". The sample settings below are the defaults. anon-access none轉載…

REST framework 權限管理源碼分析

REST framework 權限管理源碼分析 同認證一樣&#xff0c;dispatch()作為入口&#xff0c;從self.initial(request, *args, **kwargs)進入initial() def initial(self, request, *args, **kwargs):# .......# 用戶認證self.perform_authentication(request)# 權限控制self.che…

解決larave-dompdf中文字體顯示問題

0、使用MPDF dompdf個人感覺沒有那么好用&#xff0c;最終的生產環境使用的是MPDF&#xff0c;github上有文檔說明。如果你堅持使用&#xff0c;下面是解決辦法。可以明確的說&#xff0c;中文亂碼是可以解決的。 1、安裝laravel-dompdf依賴。 Packagist&#xff1a;https://pa…

mfc程序轉化為qt_小峰的QT學習筆記

我的專業是輸電線路&#xff0c;上個學期&#xff0c;我們開了一門架空線路設計基礎的課&#xff0c;當時有一個大作業是計算線路的比載&#xff0c;臨界檔距&#xff0c;弧垂最低點和安裝曲線。恰逢一門結課考試結束&#xff0c;大作業ddl快到&#xff0c;我和另外兩個同專業的…

MS SQL的存儲過程

-- -- Author: -- Create date: 2016-07-01 -- Description: 注冊信息 -- ALTER PROCEDURE [dbo].[sp_MebUser_Register]( UserType INT, MobileNumber VARCHAR(11), MobileCode VARCHAR(50), LoginPwd VARCHAR(50), PayPwd VARCHAR(50), PlateNumber VARCHAR(20), UserTr…

mysql 中 all any some 用法

-- 建表語句 CREATE TABLE score(id INT PRIMARY KEY AUTO_INCREMENT,NAME VARCHAR(20),SUBJECT VARCHAR(20),score INT);-- 添加數據 INSERT INTO score VALUES (NULL,張三,語文,81), (NULL,張三,數學,75), (NULL,李四,語文,76), (NULL,李四,數學,90), (NULL,王五,語文,81), (…

REST framework 用戶認證源碼

REST 用戶認證源碼 在Django中&#xff0c;從URL調度器中過來的HTTPRequest會傳遞給disatch(),使用REST后也一樣 # REST的dispatch def dispatch(self, request, *args, **kwargs):""".dispatch() is pretty much the same as Djangos regular dispatch,but w…

scrapyd部署_如何通過 Scrapyd + ScrapydWeb 簡單高效地部署和監控分布式爬蟲項目

來自 Scrapy 官方賬號的推薦需求分析初級用戶&#xff1a;只有一臺開發主機能夠通過 Scrapyd-client 打包和部署 Scrapy 爬蟲項目&#xff0c;以及通過 Scrapyd JSON API 來控制爬蟲&#xff0c;感覺 命令行操作太麻煩 &#xff0c;希望能夠通過瀏覽器直接部署和運行項目專業用…