廣播機制
Numpy的Universal functions中要求輸入的數組shape是一致的,當數組的shape不相等
時,則會使用廣播機制。不過,調整數組使得shape一樣,需要滿足一定的規則,否則將
出錯。這些規則可歸納為以下4條。
1)讓所有輸入數組都向其中shape最長的數組看齊,不足的部分則通過在前面加1補
齊,如:
a:2×3×2
b:3×2
則b向a看齊,在b的前面加1,變為:1×3×2
2)輸出數組的shape是輸入數組shape的各個軸上的最大值;
3)如果輸入數組的某個軸和輸出數組的對應軸的長度相同或者某個軸的長度為1時,
這個數組能被用來計算,否則出錯;
4)當輸入數組的某個軸的長度為1時,沿著此軸運算時都用(或復制)此軸上的第一
組值。
廣播在整個Numpy中用于決定如何處理形狀迥異的數組,涉及的算術運算包括
(+,-,*,/…)。這些規則說得很嚴謹,但不直觀,下面我們結合圖形與代碼來進一步
說明。
目的:A+B,其中A為4×1矩陣,B為一維向量(3,)。
要相加,需要做如下處理:
·根據規則1,B需要向看齊,把B變為(1,3)
·根據規則2,輸出的結果為各個軸上的最大值,即輸出結果應該為(4,3)矩陣,那
么A如何由(4,1)變為(4,3)矩陣?B又如何由(1,3)變為(4,3)矩陣?
·根據規則4,用此軸上的第一組值(要主要區分是哪個軸),進行復制(但在實際處
理中不是真正復制,否則太耗內存,而是采用其他對象如ogrid對象,進行網格處理)即
可,詳細處理過程如圖1-4所示。
import numpy as np
A=np.arange(0,40,10).reshape(4,1)
B=np.arange(0,3)
print("A矩陣的形狀:{},B矩陣的形狀:{}".format(A.shape,B.shape))
C=A+B
print("C矩陣的形狀:{}".format(C.shape))
print(C)
運行結果
矩陣的形狀:(4, 1),B矩陣的形狀:(3,)
C矩陣的形狀:(4, 3)
[[ 0 1 2][10 11 12][20 21 22][30 31 32]]