機器學習 - Kaggle項目實踐(3)Digit Recognizer 手寫數字識別

Digit Recognizer | Kaggle 題面

Digit Recognizer-CNN | Kaggle 下面代碼的kaggle版本

使用CNN進行手寫數字識別

學習到了網絡搭建手法+學習率退火+數據增廣 提高訓練效果。

使用混淆矩陣 以及對分類出錯概率最大的例子單獨拎出來分析。

最終以99.546%正確率 排在 86/1035

1. 數據準備

拆成X和Y 輸出每個數字 train中的數據量

import pandas as pd
import seaborn as sns
train = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")
test = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")
Y_train = train["label"]
X_train = train.drop(labels = ["label"],axis = 1) 
del train g = sns.countplot(x=Y_train)
Y_train.value_counts()

1.確認數據干凈 → 沒有缺失值

print(X_train.isnull().any().describe()) # 下面代表所有列 都是False
print(test.isnull().any().describe())

2.灰度歸一化(除以255) → 提升訓練速度和穩定性。

reshape → 轉換成 CNN 所需的 (height, width, channel) 格式。 ?

原來是(樣本數, 784) -> (樣本數, 高度, 寬度, 通道數)? 樣本數, 28, 28, 1

X_train, test = X_train / 255.0, test / 255.0
X_train = X_train.values.reshape(-1, 28, 28, 1)
test = test.values.reshape(-1, 28, 28, 1)

3.One-Hot 編碼 只有一個位置為1的標簽 → 后續 softmax 輸出10個類別的概率。

from tensorflow.keras.utils import to_categorical
Y_train = to_categorical(Y_train, num_classes=10)

4.劃分驗證集 → 用于調參、防止過擬合。

from sklearn.model_selection import train_test_split
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, random_state=2)

5.可視化一個樣本 → 檢查數據是否正確加載。

import matplotlib.pyplot as plt
plt.imshow(X_train[0][:,:,0], cmap='gray')
plt.title("Example Image")
plt.show()

2. CNN模型建構

卷積塊:2個Conv(5 * 5)+ BN + 池化? ? ? ~? ? ?2個Conv(3 * 3) + BN + 池化

Dropout舍棄一些神經元防止過擬合。

import numpy as np
import tensorflow as tffrom tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D as MaxPool2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGeneratormodel = Sequential()# 卷積塊 1
model.add(Conv2D(32, (5,5), padding='same', activation='relu', input_shape=(28,28,1)))
model.add(BatchNormalization())   # 批歸一化
model.add(Conv2D(32, (5,5), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.25))# 卷積塊 2
model.add(Conv2D(64, (3,3), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, (3,3), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Dropout(0.25))

輸出層:Flatten + Dense + Softmax

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

model.compile? ? ?Adam優化器 + 交叉熵損失 + accuracy評估

optimizer = Adam(learning_rate=0.001)model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])

學習率退火 若 3 個 epoch 內沒提升,就把 lr 減半

learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy', patience=3, verbose=1,factor=0.5, min_lr=1e-5
)

數據增廣(ImageDataGenerator) 輕微旋轉、縮放、平移等擴充數據

datagen = ImageDataGenerator(rotation_range=10,zoom_range=0.1,width_shift_range=0.1,height_shift_range=0.1
)

訓練 設置 epoch? batch_size? steps

datagen.fit(X_train)# ===== 訓練 =====
epochs = 30
batch_size = 86
steps = int(np.ceil(len(X_train) / batch_size)) # 一次訓練覆蓋所有樣本history = model.fit(datagen.flow(X_train, Y_train, batch_size=batch_size), # 數據增強epochs=epochs,steps_per_epoch=steps, validation_data=(X_val, Y_val), # 驗證數據callbacks=[learning_rate_reduction], # 回調學習率verbose=2 # 每個 epoch 輸出一次
)
440/440 - 144s - 327ms/step - accuracy: 0.9950 - loss: 0.0170 - val_accuracy: 0.9957 - val_loss: 0.0134 - learning_rate: 6.2500e-05
Epoch 30/30Epoch 30: ReduceLROnPlateau reducing learning rate to 3.125000148429535e-05.
440/440 - 145s - 329ms/step - accuracy: 0.9952 - loss: 0.0153 - val_accuracy: 0.9955 - val_loss: 0.0141 - learning_rate: 6.2500e-05

這是訓練日志的最后一部分輸出 準確率達到 99.5%

3. 模型評估 evaluation

1. history報告中 訓練集和驗證集的 Loss 和 Accuracy 可視化。

val為驗證,驗證不比訓練差 說明沒有太過擬合。

fig, ax = plt.subplots(2,1)# 繪制訓練集和驗證集的 Loss
ax[0].plot(history.history['loss'], color='b', label="Training loss")
ax[0].plot(history.history['val_loss'], color='r', label="validation loss")
ax[0].legend(loc='best', shadow=True)# 繪制訓練集和驗證集的 Accuracy
ax[1].plot(history.history['accuracy'], color='b', label="Training accuracy")
ax[1].plot(history.history['val_accuracy'], color='r', label="Validation accuracy")
ax[1].legend(loc='best', shadow=True)

2. 混淆矩陣 confusion_matrix

from sklearn.metrics import confusion_matrix
import itertoolsdef plot_confusion_matrix(cm, classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues):plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')Y_pred = model.predict(X_val) # 預測值
Y_pred_classes = np.argmax(Y_pred, axis = 1) # 預測對應類別
Y_true = np.argmax(Y_val,axis = 1) # 真實類別confusion_mtx = confusion_matrix(Y_true, Y_pred_classes) 
plot_confusion_matrix(confusion_mtx, classes = range(10)) 

3. 對于分類錯誤的位置 用預測(錯誤數字的概率 - 正確標簽的概率)

輸出預測錯誤差別最大的圖片 是什么樣的。

# 找出預測錯誤的位置 布爾數組
errors = (Y_pred_classes != Y_true )# 提取錯誤預測的詳細數據
Y_pred_classes_errors = Y_pred_classes[errors]
Y_pred_errors = Y_pred[errors]
Y_true_errors = Y_true[errors]
X_val_errors = X_val[errors]# 錯誤預測的最大概率
Y_pred_errors_prob = np.max(Y_pred_errors,axis = 1)# 正確標簽對應的預測概率
true_prob_errors = np.diagonal(np.take(Y_pred_errors, Y_true_errors, axis=1))# 差值:預測的最大概率 - 正確標簽概率
delta_pred_true_errors = Y_pred_errors_prob - true_prob_errors# 找出概率差最大的錯誤
most_important_errors = np.argsort(delta_pred_true_errors)[-6:]def display_errors(errors_index,img_errors,pred_errors, obs_errors):""" This function shows 6 images with their predicted and real labels"""n = 0nrows = 2ncols = 3fig, ax = plt.subplots(nrows,ncols,sharex=True,sharey=True)for row in range(nrows):for col in range(ncols):error = errors_index[n]ax[row,col].imshow((img_errors[error]).reshape((28,28)))ax[row,col].set_title("Predicted label :{}\nTrue label :{}".format(pred_errors[error],obs_errors[error]))n += 1display_errors(most_important_errors, X_val_errors, Y_pred_classes_errors, Y_true_errors)

發現預測錯誤的幾張圖片 本身就很容易誤解

4. 預測&提交結果

results = model.predict(test)
results = np.argmax(results,axis = 1)  # 概率轉類別
results = pd.Series(results,name="Label")submission = pd.concat([pd.Series(range(1,28001),name = "ImageId"),results],axis = 1)
submission.to_csv("cnn_mnist_datagen.csv",index=False)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93237.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93237.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93237.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

新手如何高效運營亞馬遜跨境電商:從傳統SP廣告到DeepBI智能策略

"為什么我的廣告點擊量很高但訂單轉化率卻很低?""如何避免新品期廣告預算被大詞消耗殆盡?""為什么手動調整關鍵詞和出價總是慢市場半拍?""競品ASIN投放到底該怎么做才有效?""有沒有…

【論文閱讀 | CVPR 2024 | UniRGB-IR:通過適配器調優實現可見光-紅外語義任務的統一框架】

論文閱讀 | CVPR 2024 | UniRGB-IR:通過適配器調優實現可見光-紅外語義任務的統一框架?1&&2. 摘要&&引言3.方法3.1 整體架構3.2 多模態特征池3.3 補充特征注入器3.4 適配器調優范式4 實驗4.1 RGB-IR 目標檢測4.2 RGB-IR 語義分割4.3 RGB-IR 顯著目…

Hyperf 百度翻譯接口實現方案

保留 HTML/XML 標簽結構,僅翻譯文本內容,避免破壞富文本格式。采用「HTML 解析 → 文本提取 → 批量翻譯 → 回填」的流程。百度翻譯集成方案:富文本內容翻譯系統 HTML 解析 百度翻譯 API 集成 文件結構 app/ ├── Controller/ │ └──…

字節跳動 VeOmni 框架開源:統一多模態訓練效率飛躍!

資料來源:火山引擎-開發者社區 多模態時代的訓練痛點,終于有了“特效藥” 當大模型從單一語言向文本 圖像 視頻的多模態進化時,算法工程師們的訓練流程卻陷入了 “碎片化困境”: 當業務要同時迭代 DiT、LLM 與 VLM時&#xff0…

配置docker pull走http代理

之前寫了一篇自建Docker鏡像加速器服務的博客,需要用到境外服務器作為代理,但是一般可能沒有境外服務器,只有http代理,所以如果本地使用想走代理可以用以下方式 臨時生效(只對當前終端有效) 設置環境變量…

OpenAI 開源模型 gpt-oss 本地部署詳細教程

OpenAI 最近發布了其首個開源的開放權重模型gpt-oss,這在AI圈引起了巨大的轟動。對于廣大開發者和AI愛好者來說,這意味著我們終于可以在自己的機器上,完全本地化地運行和探索這款強大的模型了。 本教程將一步一步指導你如何在Windows和Linux…

力扣-5.最長回文子串

題目鏈接 5.最長回文子串 class Solution {public String longestPalindrome(String s) {boolean[][] dp new boolean[s.length()][s.length()];int maxLen 0;String str s.substring(0, 1);for (int i 0; i < s.length(); i) {dp[i][i] true;}for (int len 2; len …

Apache Ignite超時管理核心組件解析

這是一個非常關鍵且設計精巧的 定時任務與超時管理組件 —— GridTimeoutProcessor&#xff0c;它是 Apache Ignite 內核中負責 統一調度和處理所有異步超時事件的核心模塊。&#x1f3af; 一、核心職責統一管理所有需要“在某個時間點觸發”的任務或超時邏輯。它相當于 Ignite…

DAY 42 Grad-CAM與Hook函數

知識點回顧回調函數lambda函數hook函數的模塊鉤子和張量鉤子Grad-CAM的示例# 定義一個存儲梯度的列表 conv_gradients []# 定義反向鉤子函數 def backward_hook(module, grad_input, grad_output):# 模塊&#xff1a;當前應用鉤子的模塊# grad_input&#xff1a;模塊輸入的梯度…

基于 NVIDIA 生態的 Dynamo 風格分布式 LLM 推理架構

網羅開發&#xff08;小紅書、快手、視頻號同名&#xff09;大家好&#xff0c;我是 展菲&#xff0c;目前在上市企業從事人工智能項目研發管理工作&#xff0c;平時熱衷于分享各種編程領域的軟硬技能知識以及前沿技術&#xff0c;包括iOS、前端、Harmony OS、Java、Python等方…

《吃透 C++ 類和對象(中):拷貝構造函數與賦值運算符重載深度解析》

&#x1f525;個人主頁&#xff1a;草莓熊Lotso &#x1f3ac;作者簡介&#xff1a;C研發方向學習者 &#x1f4d6;個人專欄&#xff1a; 《C語言》 《數據結構與算法》《C語言刷題集》《Leetcode刷題指南》 ??人生格言&#xff1a;生活是默默的堅持&#xff0c;毅力是永久的…

Python 環境隔離實戰:venv、virtualenv 與 conda 的差異與最佳實踐

那天把項目部署到測試環境&#xff0c;結果依賴沖突把服務拉崩了——本地能跑&#xff0c;線上不能跑。折騰半天才發現&#xff1a;我和同事用的不是同一套 site-packages&#xff0c;版本差異導致運行時異常。那一刻我徹底明白&#xff1a;虛擬環境不是可選項&#xff0c;它是…

[ 數據結構 ] 時間和空間復雜度

1.算法效率算法效率分析分為兩種 : ①時間效率, ②空間效率 時間效率即為 時間復雜度 , 時間復雜度主要衡量一個算法的運行速度空間效率即為 空間復雜度 , 空間復雜度主要衡量一個算法所需要的額外空間2.時間復雜度2.1 時間復雜度的概念定義 : 再計算機科學中 , 算法的時間復雜…

一,設計模式-單例模式

目的設計單例模式的目的是為了解決兩個問題&#xff1a;保證一個類只有一個實例這種需求是需要控制某些資源的共享權限&#xff0c;比如文件資源、數據庫資源。為該實例提供一個全局訪問節點相較于通過全局變量保存重要的共享對象&#xff0c;通過一個封裝的類對象&#xff0c;…

AIStarter修復macOS 15兼容問題:跨平臺AI項目管理新體驗

AIStarter是全網唯一支持Windows、Mac和Linux的AI管理平臺&#xff0c;為開發者提供便捷的AI項目管理體驗。近期&#xff0c;熊哥在視頻中分享了針對macOS 15系統無法打開AIStarter的修復方案&#xff0c;最新版已完美兼容。本文基于視頻內容&#xff0c;詳解修復細節與使用技巧…

LabVIEW 紡織檢測數據傳遞

基于 LabVIEW 實現紡織檢測系統中上位機&#xff08;PC 機&#xff09;與下位機&#xff08;單片機&#xff09;的串口數據傳遞&#xff0c;成功應用于煮繭機溫度測量系統。通過采用特定硬件架構與軟件設計&#xff0c;實現了溫度數據的高效采集、傳輸與分析&#xff0c;操作簡…

ECCV-2018《Variational Wasserstein Clustering》

核心思想 該論文提出了一個基于最優傳輸(optimal transportation) 理論的新型聚類方法&#xff0c;稱為變分Wasserstein聚類(Variational Wasserstein Clustering, VWC)。其核心思想有三點&#xff1a;建立最優傳輸與k-means聚類的聯系&#xff1a;作者指出k-means聚類問題本質…

部署 Docker 應用詳解(MySQL + Tomcat + Nginx + Redis)

文章目錄一、MySQL二、Tomcat三、Nginx四、Redis一、MySQL 搜索 MySQL 鏡像下載 MySQL 鏡像創建 MySQL 容器 docker run -i -t/d -p 3307:3306 --namec_mysql -v $PWD/conf:/etc/mysql/conf.d -v $PWD/logs:/logs -v $PWD/data:/var/lib/mysql -e MYSQL_ROOT_PASSWORD123456 m…

VR全景導覽在大型活動中的應用實踐:優化觀眾體驗與現場管理

大型演出賽事往往吸引海量觀眾&#xff0c;但復雜的場館環境常帶來諸多困擾&#xff1a;如何快速找到座位看臺區域&#xff1f;停車位如何規劃&#xff1f;附近公交地鐵站在哪&#xff1f;這些痛點直接影響觀眾體驗與現場秩序。VR全景技術為解決這些問題提供了有效方案。通過在…

OpenJDK 17 JIT編譯器堆棧分析

##堆棧(gdb) bt #0 PhaseOutput::safepoint_poll_table (this0x7fffd0bfb950) at /home/yym/openjdk17/jdk17-master/src/hotspot/share/opto/output.hpp:173 #1 0x00007ffff689634e in PhaseOutput::fill_buffer (this0x7fffd0bfb950, cb0x7fffd0bfb970, blk_starts0x7fffb0…