用邏輯回歸(Logistic Regression)處理鳶尾花(iris)數據集

# 導入必要的庫
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, confusion_matrix,classification_report, ConfusionMatrixDisplay)
from sklearn.preprocessing import StandardScaler# 1. 加載鳶尾花數據集
iris = load_iris()
# 轉換為DataFrame方便查看(特征+標簽)
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['species'] = [iris.target_names[i] for i in iris.target]  # 添加花名標簽# 2. 數據基本信息查看
print("數據集形狀:", iris.data.shape)  # 150個樣本,4個特征
print("\n特征名稱:", iris.feature_names)  # 花萼長度、寬度,花瓣長度、寬度
print("\n類別名稱:", iris.target_names)  # 山鳶尾、變色鳶尾、維吉尼亞鳶尾# 3. 數據劃分(特征X和標簽y)
X = iris.data  # 特征:4個植物學測量值
y = iris.target  # 標簽:0,1,2分別對應三種鳶尾花# 劃分訓練集(80%)和測試集(20%),隨機種子確保結果可復現
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y  # stratify=y保持類別比例
)# 4. 特征標準化(邏輯回歸對特征尺度敏感,標準化可提升性能)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # 訓練集擬合并標準化
X_test_scaled = scaler.transform(X_test)  # 測試集使用相同的標準化參數# 5. 訓練邏輯回歸模型(多分類任務)
model = LogisticRegression(max_iter=200, random_state=42)  # 增加迭代次數確保收斂
model.fit(X_train_scaled, y_train)# 6. 模型預測
y_pred = model.predict(X_test_scaled)  # 測試集預測標簽
y_pred_proba = model.predict_proba(X_test_scaled)  # 預測每個類別的概率# 7. 模型評估
print("\n===== 模型評估結果 =====")
print(f"訓練集準確率:{model.score(X_train_scaled, y_train):.4f}")
print(f"測試集準確率:{accuracy_score(y_test, y_pred):.4f}")print("\n混淆矩陣:")
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)
disp.plot(cmap=plt.cm.Blues)
plt.title("混淆矩陣(測試集)")
plt.show()print("\n分類報告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))# 8. 特征重要性分析(邏輯回歸系數)
feature_importance = pd.DataFrame({'特征': iris.feature_names,'系數絕對值': np.abs(model.coef_).mean(axis=0)  # 多分類取各系數的絕對值均值
}).sort_values(by='系數絕對值', ascending=False)print("\n特征重要性(系數絕對值):")
print(feature_importance)# 可視化特征重要性
plt.figure(figsize=(8, 4))
sns.barplot(x='系數絕對值', y='特征', data=feature_importance, palette='coolwarm')
plt.title("特征對分類的重要性")
plt.show()# 9. 新樣本預測示例
# 假設一個新的鳶尾花測量數據(花萼長、花萼寬、花瓣長、花瓣寬)
new_sample = np.array([[5.8, 3.0, 4.9, 1.6]])  # 接近變色鳶尾的特征
new_sample_scaled = scaler.transform(new_sample)  # 標準化# 預測結果
predicted_class = model.predict(new_sample_scaled)
predicted_prob = model.predict_proba(new_sample_scaled)print("\n===== 新樣本預測 =====")
print(f"預測類別:{iris.target_names[predicted_class[0]]}")
print("各類別概率:")
for i, prob in enumerate(predicted_prob[0]):print(f"{iris.target_names[i]}: {prob:.4f}")

這段代碼使用邏輯回歸算法對經典的鳶尾花數據集進行分類,是一個完整的機器學習項目流程。

1. 導入必要的庫

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

import seaborn as sns

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from sklearn.linear_model import LogisticRegression

from sklearn.metrics import (accuracy_score, confusion_matrix,

???????????????????????????? classification_report, ConfusionMatrixDisplay)

from sklearn.preprocessing import StandardScaler

  1. numpy/pandas:用于數據處理(如矩陣運算、表格操作)。
  2. matplotlib/seaborn:用于繪制圖表(如混淆矩陣、特征重要性)。
  3. sklearn:機器學習庫,提供數據集、模型、評估工具。

2. 加載和查看數據

iris = load_iris()? # 加載內置鳶尾花數據集

iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)

iris_df['species'] = [iris.target_names[i] for i in iris.target]

print("數據集形狀:", iris.data.shape)? # (150, 4) → 150個樣本,4個特征

print("特征名稱:", iris.feature_names)? # 花瓣/花萼的長度、寬度

print("類別名稱:", iris.target_names)? # ['setosa' 'versicolor' 'virginica']

  1. 鳶尾花數據集:包含 150 朵花的數據,分為 3 個品種(每個品種 50 朵)。
  2. 4 個特征:花瓣長度、花瓣寬度、花萼長度、花萼寬度(都是厘米)。
  3. 目標:根據這 4 個特征預測花的品種。

3. 數據劃分(訓練集和測試集)

X = iris.data? # 特征(花瓣/花萼的測量值)

y = iris.target? # 標簽(0/1/2對應3個品種)

X_train, X_test, y_train, y_test = train_test_split(

??? X, y, test_size=0.2, random_state=42, stratify=y

)

  1. train_test_split:將數據分為 80% 訓練集和 20% 測試集。
    1. stratify=y:確保訓練集和測試集中 3 個品種的比例相同(避免數據偏斜)。
    2. random_state=42:固定隨機種子,確保結果可復現(每次運行劃分結果相同)。

4. 特征標準化

scaler = StandardScaler()

X_train_scaled = scaler.fit_transform(X_train)? # 訓練集標準化

X_test_scaled = scaler.transform(X_test)? # 測試集用相同參數標準化

  1. 為什么標準化?:邏輯回歸對特征尺度敏感(例如,如果某個特征的數值范圍很大,會影響模型收斂)。
  2. StandardScaler:將特征轉換為均值為 0、標準差為 1 的標準正態分布。
    1. fit_transform:計算訓練集的均值 / 標準差,并應用轉換。
    2. transform:用訓練集的統計參數(均值 / 標準差)轉換測試集(不能重新計算)。

5. 訓練邏輯回歸模型

model = LogisticRegression(max_iter=200, random_state=42)

model.fit(X_train_scaled, y_train)

  1. LogisticRegression:邏輯回歸是分類算法(盡管名字帶 “回歸”)。
    1. max_iter=200:增加最大迭代次數,確保模型收斂(默認 100 可能不夠)。
  2. fit:用訓練數據學習模型參數(找到最佳分類邊界)。

6. 模型預測

y_pred = model.predict(X_test_scaled)? # 預測類別(0/1/2

y_pred_proba = model.predict_proba(X_test_scaled)? # 預測每個類別的概率

  1. predict:直接輸出預測的類別(例如 1 代表 versicolor)。
  2. predict_proba:輸出樣本屬于每個類別的概率(例如 [0.01, 0.95, 0.04] 表示 95% 概率是第二類)。

7. 模型評估

print(f"訓練集準確率:{model.score(X_train_scaled, y_train):.4f}")

print(f"測試集準確率:{accuracy_score(y_test, y_pred):.4f}")

  1. 準確率(Accuracy:預測正確的樣本比例。
    1. 訓練集準確率:約 0.99(模型對訓練數據的擬合程度)。
    2. 測試集準確率:約 0.97(模型對新數據的泛化能力)。
混淆矩陣(Confusion Matrix)

cm = confusion_matrix(y_test, y_pred)

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)

disp.plot()

  1. 混淆矩陣:可視化分類結果,對角線表示預測正確的樣本數。
    1. 例如:預測 setosa(0)的樣本全部分類正確;有 1 個 versicolor(1)被誤分類為 virginica(2)。
分類報告(Classification Report)

print(classification_report(y_test, y_pred, target_names=iris.target_names))

  1. 精確率(Precision:預測為某類的樣本中,實際屬于該類的比例。
  2. 召回率(Recall:實際屬于某類的樣本中,被正確預測的比例。
  3. F1 分數(F1-score:精確率和召回率的調和平均。

8. 特征重要性分析

feature_importance = pd.DataFrame({

??? '特征': iris.feature_names,

??? '系數絕對值': np.abs(model.coef_).mean(axis=0)

}).sort_values('系數絕對值', ascending=False)

  1. 邏輯回歸系數:系數絕對值越大,說明該特征對分類的影響越大。
    1. 通常petal width(花瓣寬度)和petal length(花瓣長度)對分類最重要。

9. 新樣本預測示例

new_sample = np.array([[5.8, 3.0, 4.9, 1.6]])? # 手動構造一個樣本

new_sample_scaled = scaler.transform(new_sample)? # 標準化

predicted_class = model.predict(new_sample_scaled)? # 預測類別

predicted_prob = model.predict_proba(new_sample_scaled)? # 預測概率

  1. 預測結果:輸出新樣本的預測類別和概率(例如 95% 概率是 versicolor)。

總結

這個代碼展示了一個完整的機器學習流程:

  1. 數據準備:加載數據、劃分訓練集 / 測試集。
  2. 特征工程:標準化特征,避免量綱影響。
  3. 模型訓練:用邏輯回歸學習分類規則。
  4. 模型評估:用準確率、混淆矩陣等指標衡量性能。
  5. 預測應用:對新樣本進行分類。

鳶尾花數據集是機器學習的 “Hello World”,適合入門。邏輯回歸是簡單但強大的分類算法,尤其適合特征與類別之間存在線性關系的場景。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/91839.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/91839.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/91839.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

華大北斗TAU1201-1216A00高精度雙頻GNSS定位模塊 自動駕駛專用

在萬物互聯的時代,您還在為定位不準、信號丟失而煩惱嗎?TAU1201-1216A00華大北斗高精度定位模塊TAU1201是一款高性能的雙頻GNSS定位模塊,搭載了華大北斗的CYNOSURE III GNSS SoC 芯片,該模塊支持新一代北斗三號信號體制&#xff0…

堅持繼續布局32位MCU,進一步完善產品陣容,96Mhz主頻CW32L012新品發布!

在全球MCU市場競爭加劇、國產替代加速的背景下,嵌入式設備對核心控制芯片的性能、功耗、可靠性及性價比提出了前所未有的嚴苛需求。為適應市場競爭,2025年7月16日,武漢芯源半導體正式推出基于CW32L01x系列低功耗微控制器家族的全新成員&#…

用線性代數推導碼分多址(CDMA)

什么是碼分多址 碼分多址:CDMA允許多個用戶同時、在同一頻率上傳輸數據。它通過給每個用戶分配唯一的、相互正交的二進制序列來實現區分。用戶的數據比特被這個碼片序列擴展成一個高速率的信號,然后在接收端通過相同的碼片序列進行相關運算來回復原數據 …

mac 配置svn

1.查看brew的版本:brew install subversion2.安裝brew命令:bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"3.把路徑添加到path環境變量:echo export PATH"/opt/homebrew/b…

使用 .NET Core 的原始 WebSocket

在 Web 開發中,后端存在一些值得注意的通信協議,用于將更改通知給已連接的客戶端。所有這些協議都用于處理同一件事。但鮮為人知的協議很少,鮮為人知的協議也很少。今天,將討論 WebSocket,它在開發中使用最少&#xff…

編程實現Word自動排版:從理論到實踐的全面指南

在現代辦公環境中,文檔排版是一項常見但耗時的工作。特別是對于需要處理大量文檔的專業人士來說,手動排版不僅費時費力,還容易出現不一致的問題。本文將深入探討如何通過編程方式實現Word文檔的自動排版,從理論基礎到實際應用&…

力扣經典算法篇-25-刪除鏈表的倒數第 N 個結點(計算鏈表的長度,利用棧先進后出特性,雙指針法)

1、題干 給你一個鏈表,刪除鏈表的倒數第 n 個結點,并且返回鏈表的頭結點。 示例 1:輸入:head [1,2,3,4,5], n 2 輸出:[1,2,3,5] 示例 2: 輸入:head [1], n 1 輸出:[] 示例 3&…

VIT速覽

當我們取到一張圖片,我們會把它劃分為一個個patch,如上圖把一張圖片劃分為了9個patch,然后通過一個embedding把他們轉換成一個個token,每個patch對應一個token,然后在輸入到transformer encoder之前還要經過一個class …

【服務器與部署 14】消息隊列部署:RabbitMQ、Kafka生產環境搭建指南

【服務器與部署 14】消息隊列部署:RabbitMQ、Kafka生產環境搭建指南 關鍵詞:消息隊列、RabbitMQ集群、Kafka集群、消息中間件、異步通信、微服務架構、高可用部署、消息持久化、生產環境配置、分布式系統 摘要:本文從實際業務場景出發&#x…

LeetCode中等題--167.兩數之和II-輸入有序數組

1. 題目 給你一個下標從 1 開始的整數數組 numbers &#xff0c;該數組已按 非遞減順序排列 &#xff0c;請你從數組中找出滿足相加之和等于目標數 target 的兩個數。如果設這兩個數分別是 numbers[index1] 和 numbers[index2] &#xff0c;則 1 < index1 < index2 <…

【C# in .NET】19. 探秘抽象類:具體實現與抽象契約的橋梁

探秘抽象類:具體實現與抽象契約的橋梁 在.NET類型系統中,抽象類是連接具體實現與抽象契約的關鍵橋梁,它既具備普通類的狀態承載能力,又擁有類似接口的行為約束特性。本文將從 IL 代碼結構、CLR 類型加載機制、方法調度邏輯三個維度,全面揭示抽象類的底層工作原理,通過與…

Apache RocketMQ + “太乙” = 開源貢獻新體驗

Apache RocketMQ 是 Apache 基金會托管的頂級項目&#xff0c;自 2012 年誕生于阿里巴巴&#xff0c;服務于淘寶等核心交易系統&#xff0c;歷經多次雙十一萬億級數據洪峰穩定性驗證&#xff0c;至今已有十余年發展歷程。RocketMQ 致力于構建低延遲、高并發、高可用、高可靠的分…

永磁同步電機控制算法--弱磁控制(變交軸CCR-VQV)

一、原理介紹CCR-FQV弱磁控制不能較好的利用逆變器的直流側電壓&#xff0c;造成電機的調速范圍窄、效率低和帶載能力差。為了解決CCR-FQV弱磁控制存在的缺陷&#xff0c;可以在電機運行過程中根據工況的不同實時的改變交軸電壓給定uq?的值&#xff0c;實施 CCR-VQV弱磁控制。…

達夢數據守護集群搭建(1主1實時備庫1同步備庫1異步備庫)

目錄 1 環境信息 1.1 目錄信息 1.2 其他環境信息 2 環境準備 2.1 新建dmdba用戶 2.2 關閉防火墻 2.3 關閉Selinux 2.4 關閉numa和透明大頁 2.5 修改文件打開最大數 2.6 修改磁盤調度 2.7 修改cpufreq模式 2.8 信號量修改 2.9 修改sysctl.conf 2.10 修改 /etc/sy…

電感與電容充、放電極性判斷和電感選型

目錄 一、電感 二、電容 三、電感選型 一、電感 充電&#xff1a;左右-為例 放電&#xff1a;極性相反&#xff0c;左-右 二、電容 充電&#xff1a;左右-為例 放電&#xff1a;左右-&#xff08;與充電極性一致&#xff09; 三、電感選型 主要考慮額定電流和飽和電流。…

新建模范式Mamba——“Selectivity is All You Need?”

目錄 一、快速走進和理解Mamba建模架構 &#xff08;一&#xff09;從Transformer的統治地位談起 &#xff08;二&#xff09;另一條道路&#xff1a;結構化狀態空間模型&#xff08;SSM&#xff09; &#xff08;三&#xff09;Mamba 的核心創新&#xff1a;Selective SSM…

Python實現Word文檔中圖片的自動提取與加載:從理論到實踐

在現代辦公和文檔處理中&#xff0c;Word文檔已經成為最常用的文件格式之一。這些文檔不僅包含文本內容&#xff0c;還經常嵌入各種圖片、圖表和其他媒體元素。在許多場景下&#xff0c;我們需要從Word文檔中提取這些圖片&#xff0c;例如進行內容分析、創建圖像數據庫、或者在…

Kafka、RabbitMQ 與 RocketMQ 高可靠消息保障方案對比分析

Kafka、RabbitMQ 與 RocketMQ 高可靠消息保障方案對比分析 在分布式系統中&#xff0c;消息隊列承擔著異步解耦、流量削峰、削峰填谷等重要職責。為了保證應用的數據一致性和業務可靠性&#xff0c;各大消息中間件都提供了多種高可靠消息保障機制。本文以Kafka、RabbitMQ和Rock…

四足機器人遠程視頻與互動控制的全鏈路方案

隨著機器人行業的快速發展&#xff0c;特別是四足仿生機器人在巡檢、探測、安防、救援等復雜環境中的廣泛部署&#xff0c;如何實現高質量、低延遲的遠程視頻監控與人機互動控制&#xff0c;已經成為制約其應用落地與規模化推廣的關鍵技術難題。 四足機器人常常面臨以下挑戰&a…

把leetcode官方題解自己簡單解釋一下

自用自用&#xff01;&#xff01;&#xff01;leetcode hot 100