礦物分類系統開發筆記(二):模型訓練[刪除空缺行]

目錄

一、階段銜接與開發目標

二、數據準備

三、模型選擇與訓練

1. 邏輯回歸(LR)

2. 隨機森林(RF)

3. 高斯樸素貝葉斯(GNB)

4. 支持向量機(SVM)

5. AdaBoost

6. XGBoost

四、模型評估與結果分析

評估指標

評估結果

結果分析

五、開發總結

六、后續計劃


一、階段銜接與開發目標

在《礦物分類系統開發筆記(一)》中,我們完成了礦物數據集的收集、清洗與預處理工作,重點對數據中的空缺值進行了分析,并采用 “刪除空缺行” 的方式生成了可供模型訓練的標準化數據集。本階段作為開發流程的延續,主要基于預處理后的數據完成以下目標:

  • 選取 6 種經典機器學習算法進行礦物分類模型訓練
  • 通過網格搜索優化模型參數,提升分類性能
  • 構建統一的評估體系,對比各模型在測試集上的表現
  • 記錄并分析實驗結果,為后續系統選型提供依據

二、數據準備

數據來源:使用預處理階段生成的訓練集(訓練數據集 [刪除空缺行].xlsx)和測試集(測試數據集 [刪除空缺行].xlsx)

數據劃分:

  • 特征集(X):所有樣本的屬性數據(除最后一列標簽外的所有列)
  • 標簽集(y):
    • 訓練集標簽:包含 0、1、3 三類(訓練集中標簽為 2 的樣本均存在數據空缺,已在預處理階段隨空缺行一同刪除)
    • 測試集標簽:包含 0、1、2、3 四類(保留了數據完整的標簽 2 樣本,用于驗證模型對未見過類別的泛化能力)

特殊處理:
針對 XGBoost 模型特性,構建標簽映射關系:{0:0, 1:1, 3:2},將原始標簽轉換為連續整數編碼;預測后通過反向映射{0:0, 1:1, 2:3}還原原始標簽,對測試集特有的標簽 2 單獨處理(預測結果中若出現未映射編碼則判定為 2)

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
import json# 數據讀取
train_data = pd.read_excel('..//temp_data//訓練數據集[刪除空缺行].xlsx')
test_data = pd.read_excel('..//temp_data//測試數據集[刪除空缺行].xlsx')# 特征與標簽分割
train_X = train_data.iloc[:, :-1]
train_y = train_data.iloc[:, -1]  # 訓練標簽:0、1、3
test_X = test_data.iloc[:, :-1]
test_y = test_data.iloc[:, -1]    # 測試標簽:0、1、2、3# XGBoost標簽映射處理
label_mapping = {0: 0, 1: 1, 3: 2}
reverse_mapping = {v: k for k, v in label_mapping.items()}
train_y_xgb = train_y.map(label_mapping)  # 轉換為連續編碼
test_y_xgb = test_y.map(label_mapping)# 結果存儲容器
result_data = {}

三、模型選擇與訓練

選取 6 種經典分類算法進行對比實驗,均采用網格搜索(GridSearchCV)進行參數優化,5 折交叉驗證確定最佳參數:

1. 邏輯回歸(LR)

核心參數:C=0.001, max_iter=100, multi_class='ovr', penalty='l1', solver='liblinear'
特點:采用 L1 正則化(Lasso),適合高維數據特征選擇,使用 ovr 策略處理多分類

# 網格搜索優化(實際運行時啟用)
# logreg = LogisticRegression()
# param_grid = [
#     {'penalty': ['l1'], 'solver': ['liblinear'], 'C': [0.001, 0.01, 0.1], 'multi_class': ['ovr']},
#     {'penalty': ['l2'], 'solver': ['lbfgs'], 'C': [0.001, 0.01, 0.1], 'multi_class': ['multinomial']}
# ]
# grid_search = GridSearchCV(logreg, param_grid, cv=5)
# grid_search.fit(train_X, train_y)
# print("LR最佳參數:", grid_search.best_params_)# 最佳模型訓練
LR_result = {}
lr = LogisticRegression(C=0.001, max_iter=100, multi_class='ovr', penalty='l1', solver='liblinear')
lr.fit(train_X, train_y)# 評估
train_pred = lr.predict(train_X)
test_pred = lr.predict(test_X)
print("LR訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("LR測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
LR_result['recall_0'] = float(report[6])
LR_result['recall_1'] = float(report[11])
LR_result['recall_2'] = float(report[16])
LR_result['recall_3'] = float(report[21])
LR_result['acc'] = float(report[25])
result_data['LR'] = LR_result

2. 隨機森林(RF)

核心參數:bootstrap=True, criterion='gini', max_depth=None, min_samples_leaf=1, min_samples_split=2, n_estimators=200
特點:集成多棵決策樹降低過擬合風險,Gini 系數作為不純度度量,保留完整決策樹深度

# 網格搜索優化(實際運行時啟用)
# rf = RandomForestClassifier(random_state=42)
# param_grid = {
#     'n_estimators': [100, 200],
#     'max_depth': [None, 20],
#     'min_samples_split': [2, 5],
#     'bootstrap': [True]
# }
# grid_search = GridSearchCV(rf, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("RF最佳參數:", grid_search.best_params_)# 最佳模型訓練
RF_result = {}
rf = RandomForestClassifier(bootstrap=True, criterion='gini', max_depth=None,min_samples_leaf=1, min_samples_split=2, n_estimators=200,random_state=42
)
rf.fit(train_X, train_y)# 評估
train_pred = rf.predict(train_X)
test_pred = rf.predict(test_X)
print("RF訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("RF測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
RF_result['recall_0'] = float(report[6])
RF_result['recall_1'] = float(report[11])
RF_result['recall_2'] = float(report[16])
RF_result['recall_3'] = float(report[21])
RF_result['acc'] = float(report[25])
result_data['RF'] = RF_result

3. 高斯樸素貝葉斯(GNB)

核心參數:var_smoothing=1e-06
特點:基于貝葉斯定理的概率模型,通過 var_smoothing 參數提高數值穩定性

# 網格搜索優化(實際運行時啟用)
# gnb = GaussianNB()
# param_grid = {'var_smoothing': [1e-9, 1e-6, 1e-3]}
# grid_search = GridSearchCV(gnb, param_grid, cv=5)
# grid_search.fit(train_X, train_y)
# print("GNB最佳參數:", grid_search.best_params_)# 最佳模型訓練
GNB_result = {}
gnb = GaussianNB(var_smoothing=1e-06)
gnb.fit(train_X, train_y)# 評估
train_pred = gnb.predict(train_X)
test_pred = gnb.predict(test_X)
print("GNB訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("GNB測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
GNB_result['recall_0'] = float(report[6])
GNB_result['recall_1'] = float(report[11])
GNB_result['recall_2'] = float(report[16])
GNB_result['recall_3'] = float(report[21])
GNB_result['acc'] = float(report[25])
result_data['GNB'] = GNB_result

4. 支持向量機(SVM)

核心參數:C=10, gamma=1, kernel='rbf', max_iter=1000
特點:采用 RBF 核函數處理非線性關系,較大的 C 值表示對誤分類懲罰更嚴格

# 網格搜索優化(實際運行時啟用)
# svm = SVC(random_state=42)
# param_grid = {
#     'kernel': ['rbf'],
#     'C': [1, 10],
#     'gamma': [0.1, 1],
#     'max_iter': [1000]
# }
# grid_search = GridSearchCV(svm, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("SVM最佳參數:", grid_search.best_params_)# 最佳模型訓練
SVM_result = {}
svm = SVC(C=10, gamma=1, kernel='rbf', max_iter=1000, random_state=42)
svm.fit(train_X, train_y)# 評估
train_pred = svm.predict(train_X)
test_pred = svm.predict(test_X)
print("SVM訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("SVM測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
SVM_result['recall_0'] = float(report[6])
SVM_result['recall_1'] = float(report[11])
SVM_result['recall_2'] = float(report[16])
SVM_result['recall_3'] = float(report[21])
SVM_result['acc'] = float(report[25])
result_data['SVM'] = SVM_result

5. AdaBoost

核心參數:algorithm='SAMME', learning_rate=0.5, n_estimators=50
特點:通過 SAMME 算法集成弱分類器,學習率 0.5 控制迭代步長,50 個基分類器

# 網格搜索優化(實際運行時啟用)
# ada = AdaBoostClassifier(random_state=42)
# param_grid = {
#     'n_estimators': [50, 100],
#     'learning_rate': [0.5, 1.0],
#     'algorithm': ['SAMME']
# }
# grid_search = GridSearchCV(ada, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("AdaBoost最佳參數:", grid_search.best_params_)# 最佳模型訓練
Ada_result = {}
ada = AdaBoostClassifier(algorithm='SAMME', learning_rate=0.5, n_estimators=50, random_state=42
)
ada.fit(train_X, train_y)# 評估
train_pred = ada.predict(train_X)
test_pred = ada.predict(test_X)
print("AdaBoost訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("AdaBoost測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
Ada_result['recall_0'] = float(report[6])
Ada_result['recall_1'] = float(report[11])
Ada_result['recall_2'] = float(report[16])
Ada_result['recall_3'] = float(report[21])
Ada_result['acc'] = float(report[25])
result_data['AdaBoost'] = Ada_result

6. XGBoost

核心參數:colsample_bytree=0.8, gamma=0, learning_rate=0.1, max_depth=3, n_estimators=200
特點:基于樹的集成模型,通過列采樣(80%)防止過擬合,深度 3 的樹結構控制復雜度

# 網格搜索優化(實際運行時啟用)
# xgb = XGBClassifier(random_state=42, use_label_encoder=False, eval_metric='mlogloss', num_class=3)
# param_grid = {
#     'n_estimators': [100, 200],
#     'max_depth': [3, 5],
#     'learning_rate': [0.1],
#     'colsample_bytree': [0.8]
# }
# grid_search = GridSearchCV(xgb, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y_xgb)
# print("XGBoost最佳參數:", grid_search.best_params_)# 最佳模型訓練
XGB_result = {}
xgb_best = XGBClassifier(colsample_bytree=0.8, gamma=0, learning_rate=0.1, max_depth=3,n_estimators=200, reg_alpha=0, reg_lambda=0, subsample=0.8,random_state=42, use_label_encoder=False, eval_metric='mlogloss', num_class=3
)
xgb_best.fit(train_X, train_y_xgb)# 評估(含標簽映射還原)
train_pred_encoded = xgb_best.predict(train_X)
train_pred = [reverse_mapping[code] for code in train_pred_encoded]
test_pred_encoded = xgb_best.predict(test_X)
test_pred = [reverse_mapping[code] if code in reverse_mapping else 2 for code in test_pred_encoded]print("XGBoost訓練集評估:\n", metrics.classification_report(train_y, train_pred))
print("XGBoost測試集評估:\n", metrics.classification_report(test_y, test_pred))# 結果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
XGB_result['recall_0'] = float(report[6])
XGB_result['recall_1'] = float(report[11])
XGB_result['recall_2'] = float(report[16])
XGB_result['recall_3'] = float(report[21])
XGB_result['acc'] = float(report[25])
result_data['XGBoost'] = XGB_result# 保存所有結果
with open('..//temp_data//結果數據[刪除空缺行].json', 'w', encoding='utf-8') as f:json.dump(result_data, f, ensure_ascii=False, indent=4)

四、模型評估與結果分析

評估指標

  • 各類別召回率(recall):分別記錄 0、1、2、3 四類的召回率
  • 整體準確率(acc):模型整體分類正確率

評估結果

模型召回率_0召回率_1召回率_2召回率_3準確率(acc)
LR1.00.00.00.00.6
RF0.9333330.3333330.01.00.68
GNB0.7333330.3333330.01.00.56
SVM0.8666670.00.01.00.56
AdaBoost0.80.6666670.01.00.68
XGBoost0.9333330.1666670.01.00.64

結果分析

  • 整體表現:隨機森林(RF)和 AdaBoost 模型表現最優,準確率均達到 0.68;邏輯回歸(LR)和支持向量機(SVM)對標簽 1 的識別能力較弱(召回率為 0)
  • 類別特異性:
    • 標簽 0:邏輯回歸(LR)識別效果最佳(召回率 1.0),XGBoost 和 RF 次之(0.933333)
    • 標簽 1:AdaBoost 表現最優(召回率 0.666667),顯著高于其他模型
    • 標簽 2:所有模型召回率均為 0,主要原因是訓練集中無該類別樣本(因數據空缺已刪除),模型無法學習該類特征
    • 標簽 3:RF、GNB、SVM、AdaBoost、XGBoost 均能 100% 識別,說明該類別特征與其他類別區分度較高
  • 泛化能力:由于訓練集缺失標簽 2 樣本,所有模型對該類別均無識別能力,反映了訓練數據完整性對模型泛化能力的關鍵影響

五、開發總結

  • 完成了 6 種分類模型的訓練與優化,驗證了不同算法在礦物分類任務上的適用性
  • 通過標簽映射機制解決了 XGBoost 對非連續標簽的處理問題,保證了模型間評估標準的一致性
  • 明確了訓練數據空缺對模型性能的影響:標簽 2 因訓練樣本缺失導致所有模型識別失敗
  • 確定了 RF 和 AdaBoost 為當前階段表現最優的模型,為后續系統開發提供了選型依據

后續將進行其他5種預處理生成的數據集(平均值填充、中位數填充、眾數填充、線性回歸填充、隨機森林填充)進行模型訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93853.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93853.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93853.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

通信方式:命名管道

一、命名管道 1. 命名管道的原理 有了匿名管道,理解命名管道就非常簡單了。 對于普通文件而言,兩個進程打開同一個文件,OS是不會將文件加載兩次的,這兩個進程都會指向同一個文件,那么,也就享有同一份 in…

如何將數據庫快速接入大模型實現智能問數,實現chatbi、dataagent,只需短短幾步,不需要配置工作流!

智能問數系統初始化操作流程 一、系統初始化與管理員賬號創建登錄與初始化提示:首次訪問系統登錄頁,若系統未初始化,會彈出 “系統未完成初始化,請初始化管理員賬號” 提示,點擊【去創建】。填寫管理員信息&#xff1a…

告別手寫文檔!Spring Boot API 文檔終極解決方案:SpringDoc OpenAPI

在前后端分離和微服務盛行的今天,API 文檔是團隊協作的“通用語言”。一份清晰、準確、實時同步的文檔,能極大提升開發和聯調效率。然而,手動編寫和維護 API 文檔(如 Word、Markdown 或 Postman)是一場永無止境的噩夢—…

N4200EX是一款全智能超聲波檢測儀產品簡析

N4200EX是一款全智能超聲波檢測儀,適用于石油、石化、天然氣、氣體生產等行業的壓力管路、閥門、設備的各種防爆場合氣體泄漏、真空泄漏、閥門內漏檢測。●本安防爆設計,防爆、防塵、防水、抗摔。●適應惡劣環境,可在-25℃超低溫環境檢測&…

NestJS @Inject 裝飾器入門教程

一、核心概念解析 1.1 依賴注入(DI)的本質 依賴注入是一種設計模式,通過 IoC(控制反轉)容器管理對象生命周期。在 NestJS 中,Injectable() 標記的類會被容器管理,而 Inject() 用于顯式指定依賴項…

網絡地址詳解

子網劃分詳解:從 IP 地址結構到實際應用 在計算機網絡中,子網劃分是一項關鍵的技術,它能幫助我們更高效地管理 IP 地址資源,優化網絡性能。要深入理解子網劃分,首先需要從 IP 地址的基本結構說起。 一、IPv4 地址的基…

吾日三省吾身 | 周反思 8.19

上周一覽總體來說,上個周是一個被項目驅使而險些喪失自主思考能力的危險階段。相比任何有機械化工作經驗的讀者都有類似的體驗,在手上打螺絲的無盡循環中,自己的腦子就會逐漸喪失對自身的感知以及自主思考的能力。而這個負循環一旦開始&#…

08.19總結

連通性 在無向圖中,若任意兩點間均存在路徑相連,則該圖稱為連通圖。 若刪除圖中任意一個頂點后,剩余圖仍保持連通性,則該圖為點雙連通圖。 若刪除圖中任意一條邊后,圖仍保持連通性,則該圖為邊雙連通圖。 在…

車e估牽頭正式啟動乘用車金融價值評估師編制

8月13日,汽車金融行業職業能力評價規范編制啟動工作會議在廣州圓滿落幕。本次會議由中國機械工業聯合會機械工業人才評價中心主辦,廣州穗圣信息科技有限公司(車e估)承辦。會議匯聚了眾多行業精英,包括中國機械工業聯合…

清空 github 倉庫的歷史提交記錄(創建新分支)

想在 現有倉庫中創建一個新分支 master,刪除原來的 main,然后把 master 重命名為 main,并且清空歷史。可以用下面一條完整的命令序列操作: # 1. 創建一個沒有歷史的新分支 master git checkout --orphan master# 2. 添加當前所有文…

使用B210在Linux下實時處理ETC專用短程通信數據(2)-CPU單核高速數據處理

在上一篇文章中,使用Octave初步驗證了ETC車聯數據的格式。然而,Octave無法實時處理20M的采樣帶寬。我們本節通過C語言,重寫 Octave程序,實現實時處理,涉及下面三個關鍵特點。 文章目錄1. 全靜態內存2. 使用環狀緩存3 無…

Spark 運行流程核心組件(二)任務調度

1、調度策略參數默認值說明spark.scheduler.modeFIFO調度策略(FIFO/FAIR)spark.locality.wait3s本地性降級等待時間spark.locality.wait.processspark.locality.waitPROCESS_LOCAL 等待時間spark.locality.wait.nodespark.locality.waitNODE_LOCAL 等待時…

Orbbec---setBoolProperty 快捷配置設備行為

在奧比中光(Orbbec)SDK(通常稱為ob庫)中,setBoolProperty函數是用于設置設備或傳感器的布爾類型屬性的核心接口。它主要用于開啟/關閉設備的某些功能或模式,是配置設備行為的重要方法。 函數原型與參數解析…

[OWASP]智能體應用安全保障指南

1.關鍵組件定義 KC1 生成式語言模型(Generative Language Models) KC1.1 大語言模型(LLMs):作為代理的“大腦”,基于預訓練基礎模型(如 GPT 系列、Claude、Llama、Gemini)&#xff…

【Vivado TCL 教程】從零開始掌握 Xilinx Vivado TCL 腳本編程(三)

【Vivado TCL 教程】從零開始掌握 Xilinx Vivado TCL 腳本編程(三) 系列文章目錄 1、VMware Workstation Pro安裝指南:詳細步驟與配置選項說明 2、VMware 下 Ubuntu 操作系統下載與安裝指南 3、基于 Ubuntu 的 Linux 系統中 Vivado 2020.1 下…

AI與大數據驅動下的食堂采購系統源碼:供應鏈管理平臺的未來發展

在數字化浪潮不斷加速的今天,很多企業和機構都在追求一個目標:如何把“效率”與“成本”做到最佳平衡。對于學校、企事業單位的食堂來說,采購環節就是重中之重。往小了說,它關系到食堂員工的工作體驗;往大了說&#xf…

HarmonyOS 實戰:學會在鴻蒙中使用第三方 JavaScript 庫(附完整 Demo)

摘要 在鴻蒙(HarmonyOS NEXT / ArkTS)開發中,我們大部分業務邏輯和 UI 都是用 ArkTS 寫的。不過在做一些數據處理、網絡請求、工具函數或者復雜算法時,完全沒必要“重復造輪子”。這時候就可以直接引入 JavaScript 的第三方庫。鴻…

C++實現教務管理系統,文件操作賬戶密碼登錄(附源碼)

教務管理系統項目介紹 項目概述 這是一個基于C開發的教務管理系統,提供了學生、教師和系統管理員三種角色的功能模塊,實現了教務信息的錄入、查詢、修改和刪除等基本操作。系統采用文件存儲方式保存數據,具有簡單易用、功能完備的特點。 項…

《C++進階之STL》【二叉搜索樹】

【二叉搜索樹】目錄前言:------------概念介紹------------1. 什么是二叉搜索樹?2. 二叉搜索樹的性能怎么樣?------------基本操作------------一、查找操作思想步驟簡述二、插入操作目標步驟簡述三、刪除操作目標步驟簡述------------代碼實現--------…

Orange的運維學習日記--47.Ansible進階之異步處理

Orange的運維學習日記–47.Ansible進階之異步處理 文章目錄Orange的運維學習日記--47.Ansible進階之異步處理Playbook 執行順序原理可選執行策略調整并發連接數:forks 參數查看與修改 forks性能調優建議分批執行全局任務:serial 關鍵字serial 用法示例應…