引言
HyperLogLog算法經常在數據庫中被用來統計某一字段的Distinct Value(下文簡稱DV),比如Redis的HyperLogLog結構,出于好奇探索了一下這個算法的原理,無奈中文資料很少,只能直接去閱讀論文以及一些英文資料,總結成此文。
介紹
HyperLogLog算法來源于論文《HyperLogLog the analysis of a near-optimal cardinality estimation algorithm》(下載地址見文末的參考文獻),可以使用固定大小的字節計算任意大小的DV,本文先介紹該算法的原理,然后通過剖析stream-lib(一個Java實現的實時計算庫)對此算法的實現來進一步理解該算法。本文追求直觀理解,所以不會太過于糾結一些數學細節,如果關心數學細節的話可以直接去看論文,論文里會有具體的證明。
基數
基數就是指一個集合中不同值的數目,比如[a,b,c,d]的基數就是4,[a,b,c,d,a]的基數還是4,因為a重復了一個,不算。基數也可以稱之為Distinct Value,簡稱DV。下文中可能有時候稱呼為基數,有時候稱之為DV,但都是同一個意思。HyperLogLog算法就是用來計算基數的。
生活中的啟發-以拋硬幣為例
拋硬幣
HyperLogLog本質上來源于生活中一個小的發現,假設你拋了很多次硬幣,你告訴在這次拋硬幣的過程中最多只有兩次扔出連續的反面,讓我猜你總共拋了多少次硬幣,我敢打賭你拋硬幣的總次數不會太多,相反,如果你和我說最多出現了100次連續的反面,那么我敢肯定扔硬盤的總次數非常的多,甚至我還可以給出一個估計,這個估計要怎么給呢?其實是一個很簡單的概率問題,假設1代表拋出正面,0代表反面:
以序列1110100110為例
上圖中以拋硬幣序列"1110100110"為例,其中最長的反面序列是"00",我們順手把后面那個1也給帶上,也就是"001",因為它包括了序列中最長的一串0,所以在序列中肯定只出現過一次,而它在任意序列出現出現且僅出現一次的概率顯然是上圖所示的三個二分之一相乘,也就是八分之一,所以我可以給出一個估計值,你大概總共拋了8次硬幣。
很顯然,上面這種做法雖然能夠估計拋硬幣的總數,但是顯然誤差是比較大的,很容易受到突發事件(比如突然連續拋出好多0)的影響,HyperLogLog算法研究的就是如何減小這個誤差。
之前說過,HyperLogLog算法是用來計算基數的,這個拋硬幣的序列和基數有什么關系呢?比如在數據庫中,我只要在每次插入一條新的記錄時,計算這條記錄的hash,并且轉換成二進制,就可以將其看成一個硬幣序列了,如下(0b前綴表示二進制數):
計算hash
最簡單的想法
根據上面拋硬幣的啟發我可以想到如下的估計基數的算法(這里先給出偽代碼,后面會有Java實現):
輸入:一個集合
輸出:集合的基數
算法:
max = 0
對于集合中的每個元素:
hashCode = hash(元素)
num = hashCode二進制表示中最前面連續的0的數量
if num > max:
max = num
最后的結果是2的(max + 1)次冪
舉個例子,對于集合{ele1, ele2},先求hash(ele1)=0b00110111,它最前面的連續的0的數量為2(又稱為前導0),然后求hash(ele2)=0b10010000111,它的前導0數量為0,我們始終只保存前導零數量的最大值,所以最后max是2,我們估計的基數就是2的(2+1)次冪,即8。
為什么最后的max要加1呢?這是一個數學細節,具體要看論文,簡單的理解的話,可以像之前拋硬幣的例子那樣理解,把最長的一串零的后面的一個1或者前面的一個1"順手"帶上進行概率估計。
顯然這個算法是非常不準確的,但是這個想法還是很有啟發性的,從這個簡單的想法跟隨下文一步一步優化即可得到最終的比較高精度的HyperLogLog算法。
分桶
最簡單的一種優化方法顯然就是把數據分成m個均等的部分,分別估計其總數求平均后再乘以m,稱之為分桶。對應到前面拋硬幣的例子,其實就是把硬幣序列分成m個均等的部分,分別用之前提到的那個方法估計總數求平均后再乘以m,這樣就能一定程度上避免單一突發事件造成的誤差。
具體要怎么分桶呢?我們可以將每個元素的hash值的二進制表示的前幾位用來指示數據屬于哪個桶,然后把剩下的部分再按照之前最簡單的想法處理。
還是以剛剛的那個集合{ele1,ele2}為例,假設我要分2個桶,那么我只要去ele1的hash值的第一位來確定其分桶即可,之后用剩下的部分進行前導零的計算,如下圖:
假設ele1和ele2的hash值二進制表示如下:
hash(ele1) = 00110111
hash(ele2) = 10010001
分桶算法
到這里,你大概已經理解了LogLog算法的基本思想,LogLog算法是在HyperLogLog算法之前提出的一個基數估計算法,HyperLogLog算法其實就是LogLog算法的一個改進版。
LogLog算法完整的基數計算公式如下:
LogLog算法
其中m代表分桶數,R頭上一道橫杠的記號就代表每個桶的結果(其實就是桶中數據的最長前導零+1)的均值,相比我之前舉的簡單的例子,LogLog算法還乘了一個常數constant進行修正,這個constant具體是多少等我講到Java實現的時候再說。
調和平均數
前面的LogLog算法中我們是使用的是平均數來將每個桶的結果匯總起來,但是平均數有一個廣為人知的缺點,就是容易受到大的數值的影響,一個常見的例子是,假如我的工資是1000元一個月,我老板的工資是100000元一個月,那么我和老板的平均工資就是(100000 + 1000)/2,即50500元,顯然這離我的工資相差甚遠,我肯定不服這個平均工資。
用調和平均數就可以解決這一問題,調和平均數的結果會傾向于集合中比較小的數,x1到xn的調和平均數的公式如下:
調和平均數
再用這個公式算一下我和老板的平均工資:
使用調和平均數計算平均工資
最后的結果是1980元,這和我的工資水平還比較接近,這樣的平均工資水平我才比較信服。
再回到前面的LogLog算法,從前面的舉的例子可以看出,
影響LogLog算法精度的一個重要因素就是,hash值的前導零的數量顯然是有很大的偶然性的,經常會出現一兩數據前導零的數目比較多的情況,所以HyperLogLog算法相比LogLog算法一個重要的改進就是使用調和平均數而不是平均數來聚合每個桶中的結果,HyperLogLog算法的公式如下:
HyperLogLog算法
其中constant常數和m的含義和之前的LogLog算法公式中的含義一致,Rj代表(第j個桶中的數據的最大前導零數目+1),為了方便理解,我將公式再拆解一下:
HyperLogLog公式的理解
其實從算術平均數改成調和平均數這個優化是很容易想到的,但是為什么LogLog算法沒有直接使用調和平均數嗎?網上看到一篇英文文章里說大概是因為使用算術平均數的話證明比較容易一些,畢竟科學家們出論文每一步都是要證明的,不像我們這里簡單理解一下,猜一猜就可以了。
細節微調
關于HyperLogLog算法的大體思想到這里你就已經全部理解了。
不過算法中還有一些細微的校正,在數據總量比較小的時候,很容易就預測偏大,所以我們做如下校正:
(DV代表估計的基數值,m代表桶的數量,V代表結果為0的桶的數目,log表示自然對數)
if DV < (5 / 2) * m:
DV = m * log(m/V)
我再詳細解釋一下V的含義,假設我分配了64個桶(即m=64),當數據量很小時(比方說只有兩三個),那肯定有大量桶中沒有數據,也就說他們的估計值是0,V就代表這樣的桶的數目。
事實證明,這個校正的效果是非常好,在數據量小的時,估計得非常準確,有興趣可以去玩一下外國大佬制作的一個HyperLogLog算法的仿真:
http://content.research.neustar.biz/blog/hll.html
constant常數的選擇
constant常數的選擇與分桶的數目有關,具體的數學證明請看論文,這里就直接給出結論:
假設:m為分桶數,p是m的以2為底的對數
p
則按如下的規則計算constant
switch (p) {
case 4:
constant = 0.673 * m * m;
case 5:
constant = 0.697 * m * m;
case 6:
constant = 0.709 * m * m;
default:
constant = (0.7213 / (1 + 1.079 / m)) * m * m;
}
分桶數m的選擇
如果理解了之前的分桶算法,那么很顯然分桶數只能是2的整數次冪。
如果分桶越多,那么估計的精度就會越高,統計學上用來衡量估計精度的一個指標是“相對標準誤差”(relative standard deviation,簡稱RSD),RSD的計算公式這里就不給出了,百科上一搜就可以知道,從直觀上理解,RSD的值其實就是((每次估計的值)在(估計均值)上下的波動)占(估計均值)的比例(這句話加那么多括號是為了方便大家斷句)。RSD的值與分桶數m存在如下的計算關系:
RSD
有了這個公式,你可以先確定你想要達到的RSD的值,然后再推出分桶的數目m。
合并
假設有兩個數據流,分別構建了兩個HyperLogLog結構,稱為a和b,他們的桶數是一樣的,為n,現在要計算兩個數據流總體的基數。
數據流a:"a" "b" "c" "d" 基數:4
數據流b:"b" "c" "d" "e" 基數:4
兩個數據流的總體基數:5
從前文我們可以知道,HyperLogLog算法在內存中的結構其實就是一個桶數組,需要先用下面的算法從a和我b的桶數組中構建出新的桶數組c,其實就是從a,b的對應位置取最大的:
輸入:桶數組a,b。它們的長度都是n
輸出:新的桶數組c
算法:
c = c[n];
for (i=0; i
c[i]=max(a[i], b[i]);
}
return c;
之后用桶數組c代入前面的算法即可得到合并的總體基數。
Redis中的實現
Redis中和HyperLogLog相關的命令有三個:
PFADD hll ele:將ele添加進hll的基數計算中。流程:
先對ele求hash(使用的是一種叫做MurMurHash的算法)
將hash的低14位(因為總共有2的14次方個桶)作為桶的編號,選桶,記桶中當前的值為count
從的hash的第15位開始數0,假設從第15位開始有n個連續的0(即前導0)
如果n大于count,則把選中的桶的值置為n,否則不變
PFCOUNT hll:計算hll的基數。就是使用上面給出的DV公式根據桶中的數值,計算基數
PFMERGE hll3 hll1 hll2:將hll1和hll2合并成hll3。用的就是上面說的合并算法。
Redis的所有HyperLogLog結構都是固定的16384個桶(2的14次方),并且有兩種存儲格式:
稀疏格式:HyperLogLog算法在剛開始的時候,大多數桶其實都是0,稀疏格式通過存儲連續的0的數目,而不是每個0存一遍,大大減小了HyperLogLog剛開始時需要占用的內存
緊湊格式:用6個bit表示一個桶,需要占用12KB內存
如果還想更詳細地了解Redis中的實現細節的話,可以閱讀我的另一篇博客Redis源碼走馬觀花(5)HyperLogLog
HyperLogLog索引
之前在螞蟻實習的時候,用的一個自研數據庫號稱支持HyperLogLog索引.(目前還不知道有什么開源的數據庫支持這玩意,如果你知道,歡迎在評論里告訴我)。
所謂HyperLogLog索引,比如你在user列上建立了一個hyperLogLog索引,那么當你使用如下的查詢時:
SELECT COUNT(DISTINCT user) FROM users WHERE age >= 10 and city = "shanghai";
在計算COUNT(DISTINCT)時,會自動使用之前構建好的HyperLogLog索引來加速,據說能夠獲得數量級上的查詢速度提升。
如果仔細看了之前的算法,到這里可能會產生困惑,通過HyperLogLog似乎只能得到user的基數是多少,那又怎么能知道含有一定含有一定篩選條件(WHERE age > 10 and city = "shanghai")的user基數是多少呢?
其實再仔細想想,也很簡單,通過前面介紹過的“合并”就可以完成,對每個不同的city都構建了一個關于user的HyperLogLog結構,因為age的基數相對大一些,數據庫可以根據范圍在每個范圍構建了一個HyperLogLog結構,比如分別是0~10,10~20,20~30,這樣只需要將上面查詢涉及到的三個HyperLogLog結構合并即可(三個分別是指city為"guangzhou",age為10~20和age為20~30)。
這個只是我的個人猜測,也可能不是這樣。
Java實現分析
這個實現類中還包含很多與算法無關的序列化之類的代碼,所以不建議你直接去看,我把它的算法主干抽取了出來,變成了如下的三個類,你把這三個類的代碼復制下來放到項目的同一個包下即可,HyperLogLog類中還包含一個main函數,你可以運行一下看看代碼是否正確,代碼如下:
HyperLogLog.java
public class HyperLogLog {
private final RegisterSet registerSet;
private final int log2m; //log(m)
private final double alphaMM;
/**
*
* rsd = 1.04/sqrt(m)
* @param rsd 相對標準偏差
*/
public HyperLogLog(double rsd) {
this(log2m(rsd));
}
/**
* rsd = 1.04/sqrt(m)
* m = (1.04 / rsd)^2
* @param rsd 相對標準偏差
* @return
*/
private static int log2m(double rsd) {
return (int) (Math.log((1.106 / rsd) * (1.106 / rsd)) / Math.log(2));
}
private static double rsd(int log2m) {
return 1.106 / Math.sqrt(Math.exp(log2m * Math.log(2)));
}
/**
* accuracy = 1.04/sqrt(2^log2m)
*
* @param log2m
*/
public HyperLogLog(int log2m) {
this(log2m, new RegisterSet(1 << log2m));
}
/**
*
* @param registerSet
*/
public HyperLogLog(int log2m, RegisterSet registerSet) {
this.registerSet = registerSet;
this.log2m = log2m;
int m = 1 << this.log2m; //從log2m中算出m
alphaMM = getAlphaMM(log2m, m);
}
public boolean offerHashed(int hashedValue) {
// j 代表第幾個桶,取hashedValue的前log2m位即可
// j 介于 0 到 m
final int j = hashedValue >>> (Integer.SIZE - log2m);
// r代表 除去前log2m位剩下部分的前導零 + 1
final int r = Integer.numberOfLeadingZeros((hashedValue << this.log2m) | (1 << (this.log2m - 1)) + 1) + 1;
return registerSet.updateIfGreater(j, r);
}
/**
* 添加元素
* @param o 要被添加的元素
* @return
*/
public boolean offer(Object o) {
final int x = MurmurHash.hash(o);
return offerHashed(x);
}
public long cardinality() {
double registerSum = 0;
int count = registerSet.count;
double zeros = 0.0;
//count是桶的數量
for (int j = 0; j < registerSet.count; j++) {
int val = registerSet.get(j);
registerSum += 1.0 / (1 << val);
if (val == 0) {
zeros++;
}
}
double estimate = alphaMM * (1 / registerSum);
if (estimate <= (5.0 / 2.0) * count) { //小數據量修正
return Math.round(linearCounting(count, zeros));
} else {
return Math.round(estimate);
}
}
/**
* 計算constant常數的取值
* @param p log2m
* @param m m
* @return
*/
protected static double getAlphaMM(final int p, final int m) {
// See the paper.
switch (p) {
case 4:
return 0.673 * m * m;
case 5:
return 0.697 * m * m;
case 6:
return 0.709 * m * m;
default:
return (0.7213 / (1 + 1.079 / m)) * m * m;
}
}
/**
*
* @param m 桶的數目
* @param V 桶中0的數目
* @return
*/
protected static double linearCounting(int m, double V) {
return m * Math.log(m / V);
}
public static void main(String[] args) {
HyperLogLog hyperLogLog = new HyperLogLog(0.1325);//64個桶
//集合中只有下面這些元素
hyperLogLog.offer("hhh");
hyperLogLog.offer("mmm");
hyperLogLog.offer("ccc");
//估算基數
System.out.println(hyperLogLog.cardinality());
}
}
MurmurHash.java
/**
* 一種快速的非加密hash
* 適用于對保密性要求不高以及不在意hash碰撞攻擊的場合
*/
public class MurmurHash {
public static int hash(Object o) {
if (o == null) {
return 0;
}
if (o instanceof Long) {
return hashLong((Long) o);
}
if (o instanceof Integer) {
return hashLong((Integer) o);
}
if (o instanceof Double) {
return hashLong(Double.doubleToRawLongBits((Double) o));
}
if (o instanceof Float) {
return hashLong(Float.floatToRawIntBits((Float) o));
}
if (o instanceof String) {
return hash(((String) o).getBytes());
}
if (o instanceof byte[]) {
return hash((byte[]) o);
}
return hash(o.toString());
}
public static int hash(byte[] data) {
return hash(data, data.length, -1);
}
public static int hash(byte[] data, int seed) {
return hash(data, data.length, seed);
}
public static int hash(byte[] data, int length, int seed) {
int m = 0x5bd1e995;
int r = 24;
int h = seed ^ length;
int len_4 = length >> 2;
for (int i = 0; i < len_4; i++) {
int i_4 = i << 2;
int k = data[i_4 + 3];
k = k << 8;
k = k | (data[i_4 + 2] & 0xff);
k = k << 8;
k = k | (data[i_4 + 1] & 0xff);
k = k << 8;
k = k | (data[i_4 + 0] & 0xff);
k *= m;
k ^= k >>> r;
k *= m;
h *= m;
h ^= k;
}
// avoid calculating modulo
int len_m = len_4 << 2;
int left = length - len_m;
if (left != 0) {
if (left >= 3) {
h ^= (int) data[length - 3] << 16;
}
if (left >= 2) {
h ^= (int) data[length - 2] << 8;
}
if (left >= 1) {
h ^= (int) data[length - 1];
}
h *= m;
}
h ^= h >>> 13;
h *= m;
h ^= h >>> 15;
return h;
}
public static int hashLong(long data) {
int m = 0x5bd1e995;
int r = 24;
int h = 0;
int k = (int) data * m;
k ^= k >>> r;
h ^= k * m;
k = (int) (data >> 32) * m;
k ^= k >>> r;
h *= m;
h ^= k * m;
h ^= h >>> 13;
h *= m;
h ^= h >>> 15;
return h;
}
public static long hash64(Object o) {
if (o == null) {
return 0l;
} else if (o instanceof String) {
final byte[] bytes = ((String) o).getBytes();
return hash64(bytes, bytes.length);
} else if (o instanceof byte[]) {
final byte[] bytes = (byte[]) o;
return hash64(bytes, bytes.length);
}
return hash64(o.toString());
}
// 64 bit implementation copied from here: https://github.com/tnm/murmurhash-java
/**
* Generates 64 bit hash from byte array with default seed value.
*
* @param data byte array to hash
* @param length length of the array to hash
* @return 64 bit hash of the given string
*/
public static long hash64(final byte[] data, int length) {
return hash64(data, length, 0xe17a1465);
}
/**
* Generates 64 bit hash from byte array of the given length and seed.
*
* @param data byte array to hash
* @param length length of the array to hash
* @param seed initial seed value
* @return 64 bit hash of the given array
*/
public static long hash64(final byte[] data, int length, int seed) {
final long m = 0xc6a4a7935bd1e995L;
final int r = 47;
long h = (seed & 0xffffffffl) ^ (length * m);
int length8 = length / 8;
for (int i = 0; i < length8; i++) {
final int i8 = i * 8;
long k = ((long) data[i8 + 0] & 0xff) + (((long) data[i8 + 1] & 0xff) << 8)
+ (((long) data[i8 + 2] & 0xff) << 16) + (((long) data[i8 + 3] & 0xff) << 24)
+ (((long) data[i8 + 4] & 0xff) << 32) + (((long) data[i8 + 5] & 0xff) << 40)
+ (((long) data[i8 + 6] & 0xff) << 48) + (((long) data[i8 + 7] & 0xff) << 56);
k *= m;
k ^= k >>> r;
k *= m;
h ^= k;
h *= m;
}
switch (length % 8) {
case 7:
h ^= (long) (data[(length & ~7) + 6] & 0xff) << 48;
case 6:
h ^= (long) (data[(length & ~7) + 5] & 0xff) << 40;
case 5:
h ^= (long) (data[(length & ~7) + 4] & 0xff) << 32;
case 4:
h ^= (long) (data[(length & ~7) + 3] & 0xff) << 24;
case 3:
h ^= (long) (data[(length & ~7) + 2] & 0xff) << 16;
case 2:
h ^= (long) (data[(length & ~7) + 1] & 0xff) << 8;
case 1:
h ^= (long) (data[length & ~7] & 0xff);
h *= m;
}
;
h ^= h >>> r;
h *= m;
h ^= h >>> r;
return h;
}
}
RegisterSet.java
public class RegisterSet {
public final static int LOG2_BITS_PER_WORD = 6; //2的6次方是64
public final static int REGISTER_SIZE = 5; //每個register占5位,代碼里有一些細節涉及到這個5位,所以僅僅改這個參數是會報錯的
public final int count;
public final int size;
private final int[] M;
//傳入m
public RegisterSet(int count) {
this(count, null);
}
public RegisterSet(int count, int[] initialValues) {
this.count = count;
if (initialValues == null) {
/**
* 分配(m / 6)個int給M
*
* 因為一個register占五位,所以每個int(32位)有6個register
*/
this.M = new int[getSizeForCount(count)];
} else {
this.M = initialValues;
}
//size代表RegisterSet所占字的大小
this.size = this.M.length;
}
public static int getBits(int count) {
return count / LOG2_BITS_PER_WORD;
}
public static int getSizeForCount(int count) {
int bits = getBits(count);
if (bits == 0) {
return 1;
} else if (bits % Integer.SIZE == 0) {
return bits;
} else {
return bits + 1;
}
}
public void set(int position, int value) {
int bucketPos = position / LOG2_BITS_PER_WORD;
int shift = REGISTER_SIZE * (position - (bucketPos * LOG2_BITS_PER_WORD));
this.M[bucketPos] = (this.M[bucketPos] & ~(0x1f << shift)) | (value << shift);
}
public int get(int position) {
int bucketPos = position / LOG2_BITS_PER_WORD;
int shift = REGISTER_SIZE * (position - (bucketPos * LOG2_BITS_PER_WORD));
return (this.M[bucketPos] & (0x1f << shift)) >>> shift;
}
public boolean updateIfGreater(int position, int value) {
int bucket = position / LOG2_BITS_PER_WORD; //M下標
int shift = REGISTER_SIZE * (position - (bucket * LOG2_BITS_PER_WORD)); //M偏移
int mask = 0x1f << shift; //register大小為5位
// 這里使用long是為了避免int的符號位的干擾
long curVal = this.M[bucket] & mask;
long newVal = value << shift;
if (curVal < newVal) {
//將M的相應位置為新的值
this.M[bucket] = (int) ((this.M[bucket] & ~mask) | newVal);
return true;
} else {
return false;
}
}
public void merge(RegisterSet that) {
for (int bucket = 0; bucket < M.length; bucket++) {
int word = 0;
for (int j = 0; j < LOG2_BITS_PER_WORD; j++) {
int mask = 0x1f << (REGISTER_SIZE * j);
int thisVal = (this.M[bucket] & mask);
int thatVal = (that.M[bucket] & mask);
word |= (thisVal < thatVal) ? thatVal : thisVal;
}
this.M[bucket] = word;
}
}
int[] readOnlyBits() {
return M;
}
public int[] bits() {
int[] copy = new int[size];
System.arraycopy(M, 0, copy, 0, M.length);
return copy;
}
}
這里hash算法使用的是MurmurHash算法,可能很多人沒聽說過,其實在開源項目中使用的非常廣泛,這個算法在只追求速度和hash的隨機性,而不在意安全性和保密性的時候非常有效,我們不去深究這個算法的原理了,這個類的代碼也不必仔細看,就把它看成一個hash函數就好了。
還有需要稍微注意一下這里的RegisterSet類,我們把存放一個桶的結果的地方叫做一個register,類中M數組就是存放這些register內容的地方,在這里我們設置一個register占5位,所以每個int(32位)總共可以存放6個register。
重點去閱讀HyperLogLog類,我添加了相關注釋方便你閱讀,希望能夠幫助你了解更多細節。
參考資料
1.論文《HyperLogLog: the analysis of a near-optimal cardinality estimation algorithm》
http://algo.inria.fr/flajolet/Publications/FlFuGaMe07.pdf
如果有興趣去了解算法的數學證明的大佬可以去看一下