在《提高模型性能,你可以嘗試這幾招...》一文中,我們給出了幾種提高模型性能的方法,但這篇文章是在訓練數據集不變的前提下提出的優化方案。其實對于深度學習而言,數據量的多寡通常對模型性能的影響更大,所以擴充數據規模一般情況是一個非常有效的方法。
對于Google、Facebook來說,收集幾百萬張圖片,訓練超大規模的深度學習模型,自然不在話下。但是對于個人或者小型企業而言,收集現實世界的數據,特別是帶標簽的數據,將是一件非常費時費力的事。本文探討一種技術,在現有數據集的基礎上,進行數據增強(data augmentation),增加參與模型訓練的數據量,從而提升模型的性能。
什么是數據增強
所謂數據增強,就是采用在原有數據上隨機增加抖動和擾動,從而生成新的訓練樣本,新樣本的標簽和原始數據相同。這個也很好理解,對于一張標簽為“狗”的圖片,做一定的模糊、裁剪、變形等處理,并不會改變這張圖片的類別。數據增強也不僅局限于圖片分類應用,比如有如下圖所示的數據,數據滿足正態分布:
我們在數據集的基礎上,增加一些擾動處理,數據分布如下:
數據就在原來的基礎上增加了幾倍,但整體上仍然滿足正態分布。有人可能會說,這樣的出來的模型不是沒有原來精確了嗎?考慮到現實世界的復雜性,我們采集到的數據很難完全滿足正態分布,所以這樣增加數據擾動,不僅不會降低模型的精確度,然而增強了泛化能力。
對于圖片數據而言,能夠做的數據增強的方法有很多,通常的方法是:
- 平移
- 旋轉
- 縮放
- 裁剪
- 切變(shearing)
- 水平/垂直翻轉
- ...
上面幾種方法,可能切變(shearing)比較難以理解,看一張圖就明白了:
我們要親自編寫這些數據增強算法嗎?通常不需要,比如keras就提供了批量處理圖片變形的方法。
keras中的數據增強方法
keras中提供了ImageDataGenerator類,其構造方法如下:
ImageDataGenerator(featurewise_center=False,samplewise_center=False,featurewise_std_normalization = False,samplewise_std_normalization = False,zca_whitening = False,rotation_range = 0.,width_shift_range = 0.,height_shift_range = 0.,shear_range = 0.,zoom_range = 0.,channel_shift_range = 0.,fill_mode = 'nearest',cval = 0.0,horizontal_flip = False,vertical_flip = False,rescale = None,preprocessing_function = None,data_format = K.image_data_format(),
)
復制代碼
參數很多,常用的參數有:
- rotation_range: 控制隨機的度數范圍旋轉。
- width_shift_range和height_shift_range: 分別用于水平和垂直移位。
- zoom_range: 根據[1 - zoom_range,1 + zoom_range]范圍均勻將圖像“放大”或“縮小”。
- horizontal_flip:控制是否水平翻轉。
完整的參數說明請參考keras文檔。
下面一段代碼將1張給定的圖片擴充為10張,當然你還可以擴充更多:
image = load_img(args["image"])
image = img_to_array(image)
image = np.expand_dims(image, axis=0)aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1,shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest")aug.fit(image)imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"],save_format="jpeg")total = 0
for image in imageGen:# increment out countertotal += 1if total == 10:break
復制代碼
需要指出的是,上述代碼的最后一個迭代是必須的,否在不會在output目錄下生成圖片,另外output目錄必須存在,否則會出現一下錯誤:
Traceback (most recent call last):File "augmentation_demo.py", line 35, in <module>for image in imageGen:File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__return self.next(*args, **kwargs)File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in nextreturn self._get_batches_of_transformed_samples(index_array)File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samplesimg.save(os.path.join(self.save_to_dir, fname))File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in savefp = builtins.open(filename, "w+b")
FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
復制代碼
如下一張狗狗的圖片:
經過數據增強技術處理之后,可以得到如下10張形態稍微不同的狗狗的圖片,這相當于在原有數據集上增加了10倍的數據,其實我們還可以擴充得最多:
數據增強之后的比較
我們以MiniVGGNet模型為例,說明在其在17flowers數據集上進行訓練的效果。17flowers是一個非常小的數據集,包含17中品類的花卉圖案,每個品類包含80張圖片,這對于深度學習而言,數據量實在是太小了。一般而言,要讓深度學習模型有一定的精確度,每個類別的圖片至少需要1000~5000張。這樣的數據集可以很好的說明數據增強技術的必要性。
從網站上下載的17flowers數據,所有的圖片都放在一個目錄下,而我們通常訓練時的目錄結構為:
{類別名}/{圖片文件}
復制代碼
為此我寫了一個organize_flowers17.py腳本。
在沒有使用數據增強的情況下,在訓練數據集和驗證數據集上精度、損失隨著訓練輪次的變化曲線圖:
可以看到,大約經過十幾輪的訓練,在訓練數據集上的準確率很快就達到了接近100%,然而在驗證數據集上的準確率卻無法再上升,只能達到60%左右。這個圖可以明顯的看出模型出現了非常嚴重的過擬合。
如果采用數據增強技術呢?曲線圖如下:
從圖中可以看到,雖然在訓練數據集上的準確率有所下降,但在驗證數據集上的準確率有比較明顯的提升,說明模型的泛化能力有所增強。
也許在我們看來,準確率從60%多增加到70%,只有10%的提升,并不是什么了不得的成績。但要考慮到我們采用的數據集樣本數量實在是太少,能夠達到這樣的提升已經是非常難得,在實際項目中,有時為了提升1%的準確率,都會花費不少的功夫。
總結
數據增強技術在一定程度上能夠提高模型的泛化能力,減少過擬合,但在實際中,我們如果能夠收集到更多真實的數據,還是要盡量使用真實數據。另外,數據增強只需應用于訓練數據集,驗證集上則不需要,畢竟我們希望在驗證集上測試真實數據的準確。
以上實例均有完整的代碼,點擊閱讀原文,跳轉到我在github上建的示例代碼。
另外,我在閱讀《Deep Learning for Computer Vision with Python》這本書,在微信公眾號后臺回復“計算機視覺”關鍵字,可以免費下載這本書的電子版。
參考閱讀
提高模型性能,你可以嘗試這幾招...
計算機視覺與深度學習,看這本書就夠了