目錄
- 一、簡介:
- 二、圖片分類網絡
- 1.記載訓練數據(torch自帶的cifa10數據集)
- 2.數據增強
- 3.模型構建
- 4.模型訓練
- 三、完整源碼及文檔
一、簡介:
基于殘差連接的圖片分類網絡,本網絡使用ResNet18作為基礎模塊,根據cifa10的特點進行改進網絡,使用交叉熵損失函數和SGD優化器。本網絡在cifa10數據集上不使用預訓練參數,經過數據增強,訓練30輪達到了85%的分類準確率。
二、圖片分類網絡
1.記載訓練數據(torch自帶的cifa10數據集)
2.數據增強
數據增強防止過擬合,將圖像數據進行標準化、縮放
3.模型構建
改模型:原始的resnet18首層使用的7x7的卷積核,CIFAR10圖片太小不適合,要改成3x3的,步長和padding都要一并改成1。因為圖太小,最大池化層也同樣沒用,刪掉。最后一個全連接層輸出改成10。
先定義一個殘差類(繼承NN.module,后面重復使用殘差):
分類模型構建: