Ubuntu 系統：22.04
Python 版本：3.9
安裝依賴庫（使用阿里雲 PyPI 鏡像加速下載）：
pip install tensorflow==2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple
代碼實現：
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import numpy as np
import matplotlib.pyplot as plt# 加載MNIST數據集
# Load the MNIST handwritten-digit dataset bundled with Keras.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Preprocess: give every image an explicit single channel axis, i.e. shape
# (28, 28, 1), which is the input layout the Conv2D model expects. Then convert
# to float32 and scale the pixel values down by a factor of 255.
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
# Build the CNN: two Conv2D + max-pooling stages extract spatial features, then
# a fully connected head maps the flattened features to 10 output classes.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    # Softmax output layer: one probability per class (10 classes).
    Dense(10, activation='softmax'),
])
# Compile with the Adam optimizer. Sparse categorical cross-entropy consumes the
# integer labels directly, so no one-hot encoding is needed.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 5 epochs with mini-batches of 128 samples.
# NOTE(review): the test split is passed as validation_data, so the
# "validation" curves plotted later are really test-set curves. Consider
# validation_split=0.1 if a separate validation set is wanted.
history = model.fit(
    train_images,
    train_labels,
    batch_size=128,
    epochs=5,
    verbose=1,
    validation_data=(test_images, test_labels),
)
# Evaluate the trained model on the test split and report its accuracy.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f"\n測試準確率: {test_acc:.4f}")

# Persist the trained model as a single `.keras` file.
model_path = 'mnist_cnn_model.keras'
model.save(model_path)
print(f"模型已保存為 {model_path}")
# Visualize the training history (accuracy and loss per epoch) and save it.
#
# The titles, labels and legends are Chinese, which matplotlib's default font
# (DejaVu Sans) cannot render; they would show up as empty boxes. Put common
# CJK fonts first in the sans-serif list. Which one is available depends on the
# system (e.g. Ubuntu packages fonts-noto-cjk or fonts-wqy-zenhei). Missing
# families are skipped, with a findfont warning from matplotlib.
plt.rcParams['font.sans-serif'] = [
    'Noto Sans CJK TC', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei', 'DejaVu Sans',
]
# CJK fonts often lack the Unicode minus glyph; use an ASCII hyphen instead.
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize=(10, 5))

# Left panel: training vs. validation accuracy.
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='訓練準確率')
plt.plot(history.history['val_accuracy'], label='驗證準確率')
plt.title('模型準確率')
plt.ylabel('準確率')
plt.xlabel('訓練輪次')
plt.legend()

# Right panel: training vs. validation loss.
# (Previously fused onto the legend() call above, which was a SyntaxError.)
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='訓練損失')
plt.plot(history.history['val_loss'], label='驗證損失')
plt.title('模型損失')
plt.ylabel('損失')
plt.xlabel('訓練輪次')
plt.legend()

plt.tight_layout()
plt.savefig('training_history.png')
# The figure is only saved, never shown, so release it right away.
plt.close()
print("訓練過程圖表已保存為 training_history.png")
# Predict one randomly chosen test image and save a picture of it, titled with
# the true label and the predicted label.
sample_idx = np.random.randint(0, len(test_images))
# model.predict expects a batch, so keep a leading batch axis of size 1.
sample_image = test_images[sample_idx].reshape(1, 28, 28, 1)
prediction = model.predict(sample_image, verbose=0)
# One row of class probabilities; the index of the largest is the predicted class.
# (Previously plt.figure() was fused onto the predict() line: a SyntaxError.)
predicted_label = np.argmax(prediction)
true_label = test_labels[sample_idx]

# The title is Chinese. Prefer CJK-capable fonts so the glyphs do not render as
# empty boxes; missing families are skipped with a findfont warning.
plt.rcParams['font.sans-serif'] = [
    'Noto Sans CJK TC', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei', 'DejaVu Sans',
]
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize=(5, 3))
plt.imshow(test_images[sample_idx].reshape(28, 28), cmap='gray')
plt.title(f"真實標簽: {true_label}\n預測結果: {predicted_label}")
plt.axis('off')
plt.savefig('sample_prediction.png')
# The figure is only saved, never shown, so release it right away.
plt.close()
print(f"樣本預測圖已保存為 sample_prediction.png\n真實標簽: {true_label},預測結果: {predicted_label}")