滾動hash
滾動哈希(rolling hash)也叫 Rabin-Karp 字符串哈希算法,它是將某個字符串看成某個進制下的整數,并將其對應的十進制整數作為hash值。
滾動hash算法的推導
假設有一個長度為n的數組a[0],a[1],a[2],…a[n-1],數組中的最大值為ma, 我們選取進制k滿足k>ma,將數組a看成是n位k進制整數,那么其對應的10進制整數為:
∑ i = 0 n ? 1 a [ i ] ? k n ? 1 ? i \sum_{i=0}^{n-1} a[i] * k^{n-1-i} i=0∑n?1?a[i]?kn?1?i
這樣一來,在子數組長度固定的前提下,給定進制 k,子數組與其十進制值滿足「一一對應」的關系,即不會有兩個不同的子數組,它們的十進制值相同。因此滾動哈希得到的哈希值是可以表示原子數組的。
滾動哈希的一大優勢在于,如果我們需要求出一個數組中長度為 len 的所有子數組的哈希值,需要的時間僅為線性,即如果我們已經計算出數組中以 j 開始的子數組的哈希值:
h a s h ( j ) = ∑ i = 0 l e n ? 1 a [ j + i ] ? k l e n ? 1 ? i hash(j) = \sum_{i=0}^{len-1} a[j+i] * k^{len-1-i} hash(j)=i=0∑len?1?a[j+i]?klen?1?i
那么要計算以 j+1 開始的子數組的哈希值,我們通過公式推導:
h a s h ( j + 1 ) = ∑ i = 0 l e n ? 1 a [ j + 1 + i ] ? k l e n ? 1 ? i = ∑ i = 1 l e n a [ j + i ] ? k l e n ? i = k ( ∑ i = 1 l e n a [ j + i ] ? k l e n ? 1 ? i ) = k ( h a s h ( j ) ? a [ j ] ? k l e n ? 1 + a [ j + l e n ] ? k ? 1 ) = k ? h a s h ( j ) ? a [ j ] ? k l e n + a [ j + l e n ] \begin{aligned} hash(j+1) &= \sum_{i=0}^{len-1} a[j+1+i] * k^{len-1-i} \\ &= \sum_{i=1}^{len} a[j+i]*k^{len-i} \\ &= k(\sum_{i=1}^{len} a[j+i]*k^{len-1-i}) \\ &= k(hash(j) - a[j]*k^{len-1} + a[j+len]*k^{-1}) \\ &= k*hash(j) - a[j]*k^{len} + a[j+len] \end{aligned} hash(j+1)?=i=0∑len?1?a[j+1+i]?klen?1?i=i=1∑len?a[j+i]?klen?i=k(i=1∑len?a[j+i]?klen?1?i)=k(hash(j)?a[j]?klen?1+a[j+len]?k?1)=k?hash(j)?a[j]?klen+a[j+len]?
就可以在 ? ( 1 ) \phi(1) ?(1)的時間內得到該值。
利用滾動hash算法計算最長公共子路徑的代碼示例如下:
上述代碼的執行效率較低,以下代碼通過二分法優化,可以有效降低代碼的時間復雜度:
def longest_common_subpath_2(n: int, paths: List[List[int]]) -> int:mod = (10 ** 9 + 7) * (10 ** 9 + 9)base = 10 ** 6 + 3# get min len of pathsmin_len = len(min(paths, key=lambda x: len(x)))def check(x: int) -> bool:k = pow(base, x, mod)hash_values = defaultdict(int)for path in paths:cnt = Counter()hash_value = 0for i in range(x):hash_value = (hash_value * base + path[i]) % modcnt[hash_value] += 1hash_values[hash_value] += 1for i in range(x, len(path)):hash_value = (hash_value * base + path[i] - path[i - x] * k) % modif hash_value not in cnt:cnt[hash_value] += 1hash_values[hash_value] += 1return max(hash_values.values(), default=0) == len(paths)l, r, ans = 1, min_len, 0while l <= r:mid = (l + r) >> 1if check(mid):ans = midl = mid + 1else:r = mid - 1return ans