機器學習 - 不同分類模型的比較

一、模型訓練

本案例中,我們將通過四種不同的模型來預測泰坦尼克號乘客的生存情況。
一下是訓練的具體步驟。

加載數據

從seaborn庫中加載目標數據。該數據集包括多個特征,如 PassengerId, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, 和 Embarked。我們訓練使用特征 Pclass, Age, Fare, 和 Sex,標簽列為 Survived

import pandas as pd
import seaborn as sns# Load the Titanic dataset
data = sns.load_dataset('titanic')
print(data.head())
ResultPassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   
3            4         1       1   
4            5         0       3   Name     Sex   Age  SibSp  \
0                            Braund, Mr. Owen Harris    male  22.0      1   
1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
2                             Heikkinen, Miss. Laina  female  26.0      0   
3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   
4                           Allen, Mr. William Henry    male  35.0      0   Parch            Ticket     Fare Cabin Embarked  
0      0         A/5 21171   7.2500   NaN        S  
1      0          PC 17599  71.2833   C85        C  
2      0  STON/O2. 3101282   7.9250   NaN        S  
3      0            113803  53.1000  C123        S  
4      0            373450   8.0500   NaN        S  

數據預處理

在本案例中,我們的目標是預測泰坦尼克號乘客的生存情況。首先,我將詳細介紹使用的數據預處理方法,這是確保模型表現良好的重要步驟。

1. 缺失值處理

在泰坦尼克號數據集中,Age 是存在缺失值的重要特征。處理缺失值是確保模型準確性的關鍵步驟之一。

# Handle missing values for 'Age'
imputer = SimpleImputer(strategy='mean')
features['Age'] = imputer.fit_transform(features[['Age']])
  • SimpleImputer(strategy='mean'): 這行代碼創建了一個填充器對象,指定使用均值(mean)來填充缺失值。
  • imputer.fit_transform(features[['Age']]): 這里應用填充器,計算所有已知年齡的均值,并填充到缺失的位置。這種方法假設年齡數據的缺失是隨機的,使用均值是合理的首選。
2. 類別特征編碼

Sex 列是分類數據,包含文本值(male/female),需要轉換為模型可處理的數值形式。

# Convert 'Sex' from categorical to numerical
encoder = LabelEncoder()
features['Sex'] = encoder.fit_transform(features['Sex'])
  • LabelEncoder(): 這行代碼創建了一個標簽編碼器,用于將文本標簽轉換為唯一的整數。
  • encoder.fit_transform(features['Sex']): 應用編碼器,將 malefemale 分別轉換為數值(例如 0 和 1)。這是必須的步驟,因為大多數機器學習算法在訓練過程中不能直接處理文本數據。
3. 特征縮放

由于KNN和許多其他機器學習算法對數據的尺度敏感,所以對特征進行標準化是很重要的。

# Standardize the features
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
  • StandardScaler(): 創建一個標準化器,用于將特征縮放到具有零均值和單位方差的范圍內。
  • scaler.fit_transform(features): 應用標準化處理,確保所有特征都處于相同的尺度,有助于改善模型的性能和收斂速度。
4. 數據集劃分

最后,數據被劃分為訓練集和測試集,用于訓練模型和評估其性能。

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features_scaled, target, test_size=0.2, random_state=42)
  • train_test_split(): 這個函數將數據隨機分為訓練集和測試集,test_size=0.2 表示 20% 的數據用于測試,剩下的 80% 用于訓練。random_state=42 確保每次數據分割的方式相同,這對于可復現性是很重要的。

這些預處理步驟確保了數據的一致性和適用性,是后續模型訓練和驗證的基礎。

5.數據預處理代碼
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import classification_report# Select the required features and the target
features = data[['Pclass', 'Age', 'Fare', 'Sex']]
target = data['Survived']# Handle missing values for 'Age'
imputer = SimpleImputer(strategy='mean')
features['Age'] = imputer.fit_transform(features[['Age']])# Convert 'Sex' from categorical to numerical
encoder = LabelEncoder()
features['Sex'] = encoder.fit_transform(features['Sex'])# Standardize the features
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features_scaled, target, test_size=0.2, random_state=42)

KNN模型訓練

from sklearn.neighbors import KNeighborsClassifier# Train the KNN model
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)# Predictions and evaluation
knn_predictions = knn.predict(X_test)
knn_report = classification_report(y_test, knn_predictions)knn_report

KNN模型評估報告:

              precision    recall  f1-score   support0       0.82      0.89      0.85       1051       0.82      0.72      0.76        74accuracy                           0.82       179macro avg       0.82      0.80      0.81       179
weighted avg       0.82      0.82      0.81       179

這個報告顯示了模型在測試集上的表現,包括精確度(precision)、召回率(recall)、F1分數和總體準確度(accuracy)。

線性回歸模型訓練

from sklearn.linear_model import LinearRegression# Train the Linear Regression model
# Note: Linear regression is not typically used for classification tasks, but we'll demonstrate it here for learning purposes.
linear_reg = LinearRegression()
linear_reg.fit(X_train, y_train)# Predictions
linear_reg_predictions = linear_reg.predict(X_test)# Convert predictions to binary to evaluate (0 if < 0.5 else 1)
linear_reg_predictions_binary = [1 if x >= 0.5 else 0 for x in linear_reg_predictions]# Evaluation
linear_reg_report = classification_report(y_test, linear_reg_predictions_binary)linear_reg_report

線性回歸評估報告:

              precision    recall  f1-score   support0       0.81      0.84      0.82       1051       0.76      0.72      0.74        74accuracy                           0.79       179macro avg       0.78      0.78      0.78       179
weighted avg       0.79      0.79      0.79       179

盡管線性回歸通常不用于分類任務,這里我們通過將輸出閾值設為0.5來將其用于二分類問題。

邏輯回歸模型訓練

from sklearn.linear_model import LogisticRegression# Train the Logistic Regression model
logistic_reg = LogisticRegression(random_state=42)
logistic_reg.fit(X_train, y_train)# Predictions
logistic_reg_predictions = logistic_reg.predict(X_test)# Evaluation
logistic_reg_report = classification_report(y_test, logistic_reg_predictions)logistic_reg_report

邏輯回歸評估報告:

              precision    recall  f1-score   support0       0.82      0.86      0.84       1051       0.78      0.73      0.76        74accuracy                           0.80       179macro avg       0.80      0.79      0.80       179
weighted avg       0.80      0.80      0.80       179

決策樹模型訓練

from sklearn.tree import DecisionTreeClassifier# Train the Decision Tree model
decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree.fit(X_train, y_train)# Predictions
decision_tree_predictions = decision_tree.predict(X_test)# Evaluation
decision_tree_report = classification_report(y_test, decision_tree_predictions)decision_tree_report

決策樹評估報告:

              precision    recall  f1-score   support0       0.79      0.77      0.78       1051       0.69      0.72      0.70        74accuracy                           0.75       179macro avg       0.74      0.74      0.74       179
weighted avg       0.75      0.75      0.75       179

模型比較

在這項分析中,我們使用了四種不同的機器學習模型來處理同一數據集,下面是每個模型的性能總結和對比:

KNN (K-Nearest Neighbors)

  • 精確度 (Precision): 0.82 (平均)
  • 召回率 (Recall): 0.80 (平均)
  • F1 分數: 0.81 (平均)
  • 總體準確率 (Accuracy): 82%
  • 優點: 相對直觀易懂,不需要假設數據分布。
  • 缺點: 對異常值敏感,計算量較大,需要調整超參數(如K值)。

線性回歸 (Linear Regression)

  • 精確度: 0.78 (平均)
  • 召回率: 0.78 (平均)
  • F1 分數: 0.78 (平均)
  • 總體準確率: 79%
  • 優點: 實現簡單,解釋性強。
  • 缺點: 不適合用于分類任務,需要轉換為分類輸出,容易受到異常值的影響。

邏輯回歸 (Logistic Regression)

  • 精確度: 0.80 (平均)
  • 召回率: 0.79 (平均)
  • F1 分數: 0.80 (平均)
  • 總體準確率: 80%
  • 優點: 輸出可解釋性強,輸出值具有概率意義。
  • 缺點: 非線性問題表現一般。

決策樹 (Decision Tree)

  • 精確度: 0.74 (平均)
  • 召回率: 0.74 (平均)
  • F1 分數: 0.74 (平均)
  • 總體準確率: 75%
  • 優點: 不需要數據預處理,對非線性關系處理得好,易于理解和解釋。
  • 缺點: 容易過擬合,對于數據變化較敏感。

總結

  • 性能最佳: KNN 和邏輯回歸在本次任務中表現最佳,具有較高的準確率和平衡的精確度與召回率。
  • 適用性: 邏輯回歸提供了概率輸出,更適合需要概率解釋的場景。決策樹在解釋性和處理非線性數據方面有優勢。
  • 資源消耗: KNN在大數據集上運行較慢,因為需要計算每個實例之間的距離。決策樹和邏輯回歸相對資源消耗較低。

很好,我們可以進一步探討一些可能的改進方法或者根據模型性能的具體分析來提供額外的見解。

模型優化和選擇

改進模型性能的策略

  1. KNN:

    • 參數調整: KNN的K值對模型的性能有重大影響。通過交叉驗證來找到最佳的K值可以改進模型的精確度和召回率。
    • 距離度量: 選擇不同的距離度量(如歐氏距離、曼哈頓距離)可能會對結果產生影響,特別是在特征差異性較大的數據集中。
  2. 線性回歸和邏輯回歸:

    • 特征工程: 引入多項式特征或交互特征可以幫助模型捕捉更復雜的關系,尤其是在邏輯回歸中處理非線性邊界時。
    • 正則化: 對邏輯回歸使用L1或L2正則化可以幫助避免過擬合,同時選擇合適的正則化強度是關鍵。
  3. 決策樹:

    • 剪枝: 對決策樹進行剪枝(限制樹的深度、葉節點的最小樣本數等)可以減少過擬合,提高模型的泛化能力。
    • 集成方法: 使用隨機森林或梯度提升決策樹(Gradient Boosting Decision Trees, GBDT)等集成方法可以顯著提升決策樹的性能和穩定性。

模型選擇的考慮因素

  • 數據大小和特征數: 對于大規模數據集,計算密集型的模型(如KNN)可能不是最佳選擇。相反,決策樹和邏輯回歸在大數據集上的表現通常更優。
  • 預測時間要求: 如果應用場景對預測速度有嚴格要求,需要考慮模型的預測效率。例如,決策樹的預測速度通常非常快。
  • 模型的解釋性: 在需要解釋模型決策的應用中(如醫療、金融領域),決策樹和邏輯回歸的可解釋性優勢可能更為重要。

更多問題咨詢

Cos機器人

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:
http://www.pswp.cn/web/11696.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/11696.shtml
英文地址,請注明出處:http://en.pswp.cn/web/11696.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

科技查新中的工法查新點如何確立與提煉?案例講解!

按《工程建設工法管理辦法》( 建 質&#xff3b;2014&#xff3d;103 號) &#xff0c;工法&#xff0c;是指以工程為對象&#xff0c;以工藝為核心&#xff0c;運用系 統工程原理&#xff0c;把先進技術和科學管理結合起來&#xff0c;經過一定工程實踐形成的綜合配套的施工方…

探索美國動態IP池:技術賦能下的網絡安全新篇章

在數字化飛速發展的今天&#xff0c;網絡安全成為了各行各業關注的焦點。特別是在跨國業務中&#xff0c;如何保障數據的安全傳輸和合規性成為了企業面臨的重要挑戰。美國動態IP池作為一種新興的網絡技術&#xff0c;正逐漸走進人們的視野&#xff0c;為網絡安全提供新的解決方…

黑馬甄選離線數倉項目day02(數據采集)

datax介紹 官網&#xff1a; https://github.com/alibaba/DataX/blob/master/introduction.md DataX 是阿里云 DataWorks數據集成 的開源版本&#xff0c;在阿里巴巴集團內被廣泛使用的離線數據同步工具/平臺。 DataX 實現了包括 MySQL、Oracle、OceanBase、SqlServer、Postgre…

Java中List接口中方法的使用(初學者指南)

Java中List接口中方法的使用&#xff08;初學者指南&#xff09; 在Java中&#xff0c;List接口是Collection接口的子接口&#xff0c;它表示一個有序的集合&#xff0c;其中的元素都可以重復。List接口提供了許多額外的方法&#xff0c;用于對元素進行插入、刪除、查詢等操作…

計算機Java項目|Springboot學生讀書筆記共享

作者主頁&#xff1a;編程指南針 作者簡介&#xff1a;Java領域優質創作者、CSDN博客專家 、CSDN內容合伙人、掘金特邀作者、阿里云博客專家、51CTO特邀作者、多年架構師設計經驗、騰訊課堂常駐講師 主要內容&#xff1a;Java項目、Python項目、前端項目、人工智能與大數據、簡…

C++通過json文件配置參數

一、安裝nlohmann json nlohmann json&#xff1a;安裝_nlohmann安裝-CSDN博客 依次執行下面指令&#xff1a; git clone https://gitee.com/cuihongxi/mov_from_github.gitcd json-developmkdir buildcd buildcmake ..makesudo make install 二、安裝完成后使用 #include…

華為設備display查看命令

display version //查看版本信息 display current-configuration //查看配置詳情 display this //查看當前視圖有效配置 display ip routing-table //查看路由表 display ip routing-table 192.168.3.1 //查看去往3.1的路由 display ip interface brief //查看接口下ip信息 dis…

想跨境出海?云手機提供了一種可能性

全球化時代&#xff0c;越來越多的中國電商開始將目光投向了海外市場。這并不是偶然&#xff0c;而是他們在長期的市場運營中&#xff0c;看到了出海的必要性和潛在的機會。 中國的電商市場無疑是全球最大也最發達的之一。然而&#xff0c;隨著市場的不斷發展和競爭的日益加劇…

visual studio2022 JNI極簡開發流程

文章目錄 1 創建java類2 生成JNI頭文件3 使用visual studio2022創建DLL項目3.1 選擇模板中&#xff08;Windows桌面向導&#xff09;3.2 為項目命名3.3 選擇應用程序類型為動態鏈接庫3.4 項目概覽 4 導入需要的頭文件4.1 導入需要的頭文件4.2 修改頭文件 5 編寫C實現6 生成dll文…

服務器3389端口,服務器3389端口風險提示的應對措施

3389端口是Windows操作系統中遠程桌面協議&#xff08;RDP&#xff09;的默認端口。一旦該端口被惡意攻擊者利用&#xff0c;可能會導致未經授權的遠程訪問和數據泄露等嚴重安全問題。 針對此風險&#xff0c;強烈建議您采取以下措施&#xff1a; 1. 修改默認端口&#xff1a;…

Java面試之抽象類和接口

Java的一個重要特性就是抽象&#xff0c;抽象是指將具體的事物抽象成更一般化、更抽象化的概念或模型。在Java中&#xff0c;抽象可以通過抽象類和接口來實現&#xff0c;它們讓你能夠定義一些方法但不提供具體實現&#xff0c;從而讓子類去實現具體細節。 一、抽象類&#xf…

springboot3 集成spring-authorization-server (一 基礎篇)

官方文檔 Spring Authorization Server 環境介紹 java&#xff1a;17 SpringBoot&#xff1a;3.2.0 SpringCloud&#xff1a;2023.0.0 引入maven配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter…

識別AI論文生成內容,降低論文高AI率

AI寫作工具能幫我們在短時間內高效生成一篇畢業論文、開通報告、文獻綜述、任務書、調研報告、期刊論文、課程論文等等&#xff0c;導致許多人開始使用AI寫作工具作為撰寫學術論文的輔助手段。而學術界為了杜絕此行為&#xff0c;開始使用AIGC檢測系統來判斷文章是由AI生成還是…

解鎖商業AI,賦能新質生產力發展——思愛普中國峰會探展全紀錄

ITValue 鈦媒體獨家探秘思愛普中國峰會&#xff0c;帶你深刻感受SAP助力企業利用以商業AI為代表的數字化技術&#xff0c;實現質的飛躍&#xff0c;通過全數據、全球化、全綠色賦能新型中國企業發展新質生產力。 首發&#xff5c;鈦媒體APP ITValue 5月10日&#xff0c;一年一度…

基于NTP服務器獲取網絡時間的實現

文章目錄 1 NTP1.1 簡介1.2 包結構1.3 UNIX 時間戳和NTP時間戳 2 代碼實現2.1 實現步驟2.2 完整代碼 3 結果 在某些場景下&#xff0c;單片機需要通過網絡獲取準確的時間進行數據同步&#xff0c;例如日志記錄、定時任務等。然而&#xff0c;單片機本身無法直接獲得準確的標準時…

Vue的學習 —— <vue指令>

目錄 前言 正文 內容渲染指令 內容渲染指令的使用方法 v-text v-html 屬性綁定指令 雙向數據綁定指令 事件綁定指令 條件渲染指令 循環列表渲染指令 偵聽器 前言 在完成Vue開發環境的搭建后&#xff0c;若想將Vue應用于實際項目&#xff0c;首要任務是學習Vue的基…

ORA-00932: inconsistent datatypes: expected - got CLOB的分析解決方案

最近在項目中遇到查詢數據時報ORA-00932: inconsistent datatypes: expected - got CLOB錯誤&#xff0c;這個錯誤很明顯是由于查詢時類型的不匹配造成的。 問題分析&#xff1a; 一、檢查你的查詢的實體的類型是否于數據庫的保持一致&#xff0c;如果不一致&#xff0c;那么需…

333_C++_編寫一個go函數每次從文件中讀取固定大小數據,且go作為回調,傳遞給其他函數中,多次調用,完成逐塊傳輸數據

(core工程文件) tick_transfer_all_t類是一個用于異步傳輸數據的輔助類,它在某個異步操作完成后將_tick的值設置為0,并返回傳輸的結果 namespace hl {namespace http{namespace __detail{class tick_transfer_all_t{boost::shared_ptr<unsigned long long> _tick

MySQL 查詢庫 和 表 占用空間大小的 語句

查看mysql 數據庫的大小 SELECT table_schema AS 數據庫名稱, ROUND(SUM(data_length index_length) / 1024 / 1024, 2) AS 數據庫大小(MB) FROM information_schema.tables GROUP BY table_schema;查詢數據庫中表的 數據量&#xff08;這個方法 有緩存延遲&#xff0c;只能用…

[力扣題解] 96. 不同的二叉搜索樹

題目&#xff1a;96. 不同的二叉搜索樹 思路 動態規劃 f[i]&#xff1a;有i個結點有多少種二叉搜索樹 狀態轉移方程&#xff1a; 以n3為例&#xff1a; 以1為頭節點&#xff0c;左子樹有0個結點&#xff0c;右子樹有2個結點&#xff1b; 以2為頭節點&#xff0c;左子樹有1個…