基于深度學習的毒蘑菇檢測

文章目錄

    • 任務介紹
    • 數據概覽
    • 數據處理
        • 數據讀取與拼接
        • 字符數據轉化
        • 標簽數據映射
        • 數據集劃分
        • 數據標準化
    • 模型構建與訓練
        • 模型構建
        • 數據批處理
        • 模型訓練
    • 文件提交
    • 結果
    • 附錄

任務介紹

本次任務為毒蘑菇的二元分類,任務本身并不復雜,適合初學者,主要亮點在于對字符數據的處理,還有嘗試了加深神經網絡深度的效果,之后讀者也可自行改變觀察效果,比賽路徑將于附錄中給出。

數據概覽

本次任務的數據集比較簡單

  • train.csv 訓練文件
  • test.csv 測試文件
  • sample_submission.csv 提交示例文件

具體內容就是關于毒蘑菇的各種特征,可在附錄中獲取數據集。

數據處理

數據讀取與拼接

這段代碼提取了數據文件,并且對兩個不同來源的數據集進行了拼接,當我們的數據集較小時,就可采用這種方法,獲取其他的數據集并將兩個數據集合并起來。

import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
file = pd.read_csv("/kaggle/input/playground-series-s4e8/train.csv", index_col="id")
file2 = pd.read_csv("/kaggle/input/mushroom-classification-edible-or-poisonous/mushroom.csv")
file_all = pd.concat([file, file2])
字符數據轉化

這段代碼主要就是提取出字符數據,因為字符是無法直接被計算機處理,所以我們提取出來后,再將字符數據映射為數字數據。

char_features = ['cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
for i in char_features:file_all[i] = LabelEncoder().fit_transform(file_all[i])
file_all = file_all.fillna(0)
train_col = ['cap-diameter', 'stem-height', 'stem-width', 'cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
X = file_all[train_col]
y = file_all['class']
標簽數據映射

除了用上述方法進行字符轉化外,還可以使用map函數,以下是具體操作。

y.unique()
# 構建映射字典
applying = {'e': 0, 'p': 1}
y = y.map(applying)
數據集劃分

這段代碼使用sklearn庫將數據集劃分為訓練集和測試集。

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
x_train.shape, y_train.shape
數據標準化

這段代碼將我們的數據進行歸一化,減小數字大小方便計算,但是仍然保持他們之間的線性關系,不會對結果產生影響。

scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.fit_transform(x_test)

模型構建與訓練

這段代碼使用torch庫構建了深度學習模型,主要運用了線性層,還進行了正則化操作,防止模型過擬合。

模型構建
import torch
import torch.nn as nn
class Model(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(20, 256)self.relu = nn.ReLU()self.dropout = nn.Dropout(p=0.2)self.linear1 = nn.Linear(256, 128)self.linear2 = nn.Linear(128, 64)self.linear3 = nn.Linear(64, 48)self.linear4 = nn.Linear(48, 32)self.linear5 = nn.Linear(32, 2)def forward(self, x):out = self.linear(x)out = self.relu(out)out = self.linear1(out)out = self.relu(out)out = self.dropout(out)out = self.linear2(out)out = self.relu(out)out = self.linear3(out)out = self.dropout(out)out = self.relu(out)out = self.linear4(out)out = self.relu(out)out = self.linear5(out)return out

對模型類進行實例化。

model = Model()
數據批處理

由于數據一條一條的處理起來很慢,因此我們可以將數據打包,一次給模型輸入多條數據,能有效節省時間。

import torch.nn.functional as F
class Dataset(torch.utils.data.Dataset):def __init__(self, x, y):self.x = xself.y = ydef __len__(self):return len(self.x)def __getitem__(self, i):x = torch.Tensor(self.x[i])y = torch.tensor(self.y.iloc[i])return x, y
train_data = Dataset(x_train, y_train)
test_data = Dataset(x_test, y_test)
loader = torch.utils.data.DataLoader(train_data, batch_size=64, drop_last=True, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, drop_last=True, shuffle=True)
模型訓練

這段代碼就是模型的訓練過程,包括創建優化器,定義損失函數等,還在訓練過程中測試準確率與損失函數值,動態的觀察訓練過程。

from tqdm import tqdm
import matplotlib.pyplot as plt
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-5)
from sklearn.metrics import matthews_corrcoef
flag = 0
for i in range(10):for x, label in tqdm(loader):out = model(x)loss = criterion(out, label)loss.backward()optimizer.step()optimizer.zero_grad()flag+=1if flag%500 == 0:test = next(iter(test_loader))t_out = model(test[0]).argmax(dim=1)print("loss=", loss.item())acc = (t_out == test[1]).sum().item()/len(test[1])mcc = matthews_corrcoef(t_out, test[1])print("acc=", acc)print("mcc=", mcc)

文件提交

這段代碼主要就是使用訓練好的模型在測試集上預測,并且將其整合成提交文件。

test_file = pd.read_csv("/kaggle/input/playground-series-s4e8/test.csv")
for i in char_features:test_file[i] = LabelEncoder().fit_transform(test_file[i])
test_file.fillna(0)
test_x = torch.Tensor(test_file[train_col].values)
test_x = torch.Tensor(scaler.fit_transform(test_x))
out = model(test_x)
out = pd.Series(out.argmax(dim=1))
map2 = {0: 'e', 1: 'p'}
result = out.map(map2)
answer = pd.DataFrame({'id': test_file['id'], "class": result})
answer.to_csv('submission.csv', index=False)

結果

將文件提交后,得到了0.97的成績,已經非常接近1了,證明模型的效果非常不錯。
在這里插入圖片描述

附錄

比賽鏈接:https://www.kaggle.com/competitions/playground-series-s4e8
額外數據集地址:https://www.kaggle.com/datasets/vishalpnaik/mushroom-classification-edible-or-poisonous

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/903931.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/903931.shtml
英文地址,請注明出處:http://en.pswp.cn/news/903931.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

時間給了我們什么?

時間給了我們什么? ?春秋易逝,青春難留,轉瞬之間已過半百。 ?過往中,有得有失,這就是人生。 ?一日三餐四季,日起日落里,成就了昨天、今天和明天,在歷史長河中,皆是…

軟件工程國考

軟件工程-同等學力計算機綜合真題及答案 (2004-2014、2017-2024) 2004 年軟工 第三部分 軟件工程 (共 30 分) 一、單項選擇題(每小題 1 分,共 5 分) 軟件可用性是指( &#xff09…

數據結構*棧

棧 什么是棧 這里的棧與我們之前常說的棧是不同的。之前我們說的棧是內存棧,它是JVM內存的一部分,用于存儲局部變量、方法調用信息等。每個線程都有自己獨立的棧空間,當線程啟動時,棧就會被創建;線程結束&#xff0c…

IntelliJ IDEA 保姆級使用教程

文章目錄 一、創建項目二、創建模塊三、創建包四、創建類五、編寫代碼六、運行代碼注意 七、IDEA 常見設置1、主題2、字體3、背景色 八、IDEA 常用快捷鍵九、IDEA 常見操作9.1、類操作9.1.1、刪除類文件9.1.2、修改類名稱注意 9.2、模塊操作9.2.1、修改模塊名快速查看 9.2.2、導…

HTTP 快速解析

一、HTTP請求結構 HTTP請求和響應報文由以下部分組成(以請求報文為例): 請求報文結構: 請求行:包含HTTP方法(如GET/POST)、請求URL和協議版本(如HTTP/1.1,HTTP/2.0&…

【AI學習】李宏毅新課《DeepSeek-R1 這類大語言模型是如何進行「深度思考」(Reasoning)的?》的部分紀要

針對推理模型,主要講了四種方法,兩種不需要訓練模型,兩種需要。 對于reason和inference,這兩個詞有不同的含義! 推理時計算不是新鮮事,AlphaGo就是如此。 這張圖片說明了將訓練和推理時計算綜合考慮的關系&…

Kotlin Flow流

一 Kotlin Flow 中的 stateIn 和 shareIn 一、簡單比喻理解 想象一個水龍頭(數據源)和幾個水杯(數據接收者): 普通 Flow(冷流):每個水杯來接水時,都要重新打開水龍頭從…

【嵌入式Linux】基于ARM-Linux的zero2平臺的智慧樓宇管理系統項目

目錄 1. 需求及項目準備(此項目對于虛擬機和香橙派的配置基于上一個垃圾分類項目,如初次開發,兩個平臺的環境變量,阿里云接入,攝像頭配置可參考垃圾分類項目)1.1 系統框圖1.2 硬件接線1.3 語音模塊配置1.4 …

Linux運維中常用的磁盤監控方式

在Linux運維中,磁盤監控是一項關鍵任務,因為它能幫助我們預防磁盤空間不足或性能問題導致的服務中斷或數據丟失。讓我們來看看有哪些常用的磁盤監控方法吧! 1. 查看磁盤使用情況(df命令) df命令用于顯示文件系統的…

OpenCV第6課 圖像處理之幾何變換(縮放)

1.簡述 圖像幾何變換又稱為圖像空間變換,它將一幅圖像中的坐標位置映射到另一幅圖像中的新坐標位置。幾何變換并不改變圖像的像素值,只是在圖像平面上進行像素的重新安排。 根據OpenCV函數的不同,本節課將映射關系劃分為縮放、翻轉、仿射變換、透視等。 2.縮放 2.1 函數…

(35)VTK C++開發示例 ---將圖片映射到平面2

文章目錄 1. 概述2. CMake鏈接VTK3. main.cpp文件4. 演示效果 更多精彩內容👉內容導航 👈👉VTK開發 👈 1. 概述 與上一個示例不同的是,使用vtkImageReader2Factory根據文件擴展名或內容自動創建對應的圖像文件讀取器&a…

【模型量化】量化基礎

目錄 一、認識量化 二、量化基礎原理 2.1 對稱量化和非對稱量化 2.1.1 對稱量化 2.1.2 非對稱量化 2.1.3 量化后的矩陣乘 2.2 神經網絡量化 2.2.1 動態量化 2.2.2 靜態量化 2.3 量化感知訓練 一、認識量化 量化的主要目的是節約顯存、提高計算效率以及加快通信 dee…

【零基礎入門】一篇掌握Python中的字典(創建、訪問、修改、字典方法)【詳細版】

?? 個人主頁:十二月的貓-CSDN博客 ?? 系列專欄: ??《PyTorch科研加速指南:即插即用式模塊開發》-CSDN博客 ???? 十二月的寒冬阻擋不了春天的腳步,十二點的黑夜遮蔽不住黎明的曙光 目錄 1. 前言 2. 字典 2.1 字典的創建 2.1.1 大括號+直接賦值 2.1.2 大括號…

PHP-session

PHP中,session(會話)是一種在服務器上存儲用戶數據的方法,這些數據可以在多個頁面請求或訪問之間保持。Session提供了一種方式來跟蹤用戶狀態,比如登錄信息、購物車內容等。當用戶首次訪問網站時,服務器會創…

第 5 篇:紅黑樹:工程實踐中的平衡大師

上一篇我們探討了為何有序表需要“平衡”機制來保證 O(log N) 的穩定性能。現在,我們要認識一位在實際工程中應用最廣泛、久經考驗的“平衡大師”——紅黑樹 (Red-Black Tree)。 如果你用過 Java 的 TreeMap? 或 TreeSet?,或者 C STL 中的 map? 或 s…

第十六屆藍橋杯 2025 C/C++組 客流量上限

目錄 題目: 題目描述: 題目鏈接: 思路: 打表找規律: 核心思路: 思路詳解: 得到答案的方式: 按計算器: 暴力求解代碼: 快速冪代碼: 位運…

一天學完JDBC!!(萬字總結)

文章目錄 JDBC是什么 1、環境搭建 && 入門案例2、核心API理解①、注冊驅動(Driver類)②、Connection③、statement(sql注入)④、PreparedStatement⑤、ResultSet 3、jdbc擴展(ORM、批量操作)①、實體類和ORM②、批量操作 4. 連接池①、常用連接池②、Durid連接池③、Hi…

從原理到實戰講解回歸算法!!!

哈嘍,大家好,我是我不是小upper, 今天系統梳理了線性回歸的核心知識,從模型的基本原理、參數估計方法,到模型評估指標與實際應用場景,幫助大家深入理解這一經典的機器學習算法,助力數據分析與預測工作。 …

【dify—10】工作流實戰——文生圖工具

目錄 一、創建工作流 應用 二、安裝硅基流動 三、配置硅基流動 四、API測試 (1)進入API文檔 (2)復制curl代碼 (3)Postman測試API 五、 建立文生圖工作流 (1)建立http請求 &…

Rust將結構導出到json如何處理小數點問題

簡述 標準的 serde_json 序列化器不支持直接對浮點數進行格式化限制。如果將浮點數轉換成字符串,又太low逼。這里重點推薦rust_decimal。 #[derive(Serialize)] pub struct StockTickRow {datetime: NaiveDateTime,code: String,name: String,#[serde(serialize_w…