機器學習中,針對不同的問題選用不同的損失函數非常重要,而均方誤差就是最基本,也是在解決回歸問題時最常用的損失函數。本文就keras模塊均方誤差的計算梳理了一些細節。
首先看一下均方誤差的數學定義 :
均方誤差是預測向量與真實向量差值的平方然后求平均,其中n為兩個向量所包含的元素的個數。
MSE=1n∑i=1n(Yi?Y^i)2MSE = \frac{1}{n}\sum_{i=1}^n(Y_i - \hat Y_i)^2MSE=n1?i=1∑n?(Yi??Y^i?)2
需要注意的是,我們發現在上述定義中,真實值與預測值都是兩個向量。下面以一個小例子來說明keras中的均方誤差計算。
以keras源代碼中的注釋為例,在選擇默認reduction type 為 auto/sum_over_batch_size時,計算結果為(1+1)/ 4 = 0.5
y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
# Using 'auto'/'sum_over_batch_size' reduction type.
mse = tf.keras.losses.MeanSquaredError()
mse(y_true, y_pred).numpy()0.5
例如,在一類線性回歸問題中,輸出為一個包含a個元素的向量,對于每一個batch,我們計算一次均方誤差。此時輸入和輸出的均為 (batch_size,a)的矩陣,此時的均方誤差計算實際上是針對flatten平鋪后的兩個向量進行運算的,即n = batch_size * a,計算出的均方誤差仍然是一個一維的數值,表示了每個元素的平均誤差。