目錄
1、序言
2、主要代碼
3、運行結果展示
(1)展示cifar10里面的20張圖片
(2)預測的圖片
(3)模型評估
- 本文為🔗365天深度學習訓練營?中的學習記錄博客
- 🍖 原作者:K同學啊 | 接輔導、項目定制
1、序言
這篇博客,主要是講的是彩色圖片分類,數據集為:cifar10。
和以往的流程一樣:
① 判斷是否在GPU上運行
② 導入cifar10數據集
③ 歸一化(將像素的值標準化至0-1的區間內)
④ 可視化數據集里的數據(展示了2行,共20張圖片)
⑤ 搭建神經網絡
⑥ 編譯
⑦ 訓練模型
⑧ 預測并顯示預測的圖片
⑨ 模型評估?
2、主要代碼
import tensorflow as tf
from keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
# 1、設置GPU
gpus = tf.config.list_physical_devices('GPU')if gpus:gpu0 = gpus[0] # 如果有多個GPU,僅使用第0個GPUtf.config.experimental.set_memory_growth(gpu0,True) # 設置GPU顯存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")# 2、導入數據
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()# 3、歸一化
# 將像素的值標準化至0到1的區間內。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)
# (50000, 32, 32, 3) (10000, 32, 32, 3) (50000, 1) (10000, 1)# 4、可視化
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()# 二、搭建神經網絡
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), # 卷積層1,卷積核3*3layers.MaxPooling2D((2, 2)), # 池化層1,2*2采樣layers.Conv2D(64, (3, 3), activation='relu'), # 卷積層2,卷積核3*3layers.MaxPooling2D((2, 2)), # 池化層2,2*2采樣layers.Conv2D(64, (3, 3), activation='relu'), # 卷積層3,卷積核3*3layers.Flatten(), # Flatten層,連接卷積層與全連接層layers.Dense(64, activation='relu'), # 全連接層,特征進一步提取layers.Dense(10) # 輸出層,輸出預期結果
])# model.summary() # 打印網絡結構# 三、編譯
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 四、訓練模型
history = model.fit(train_images, train_labels, epochs=10,validation_data=(test_images, test_labels))# 五、預測
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])# 六、模型評估
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)print("test_loss:", test_loss, "---------test_acc:", test_acc)
3、運行結果展示
(1)展示cifar10里面的20張圖片
(2)預測的圖片
遺憾,我的結果預測錯誤,預測的是:automobile
(3)模型評估
上圖可以看到,測試精確度達到70%左右。