Rabin-Karp 演算法的詳細解析
algorithm
Lastmod: 2020-06-26

Rabin-Karp 是啥呢 ? 它主要可以幫助我們進行以下的事情 :

它可以判斷 t 字串是否為 s 字串的子字串

子字串例的範例如下 :

s: "abcabee"
t: "cab"

t 為 s 的子字串,但注意如果 t 為 cbe,則不是喔。

那這裡問個問題。

用兩次 for loop 就能找到子字串了,為啥需要使用 Rabin-Karp 呢 ?

class Solution {
public:
    /**
     * @param source: 
     * @param target: 
     * @return: 回傳是否 target 是否為 source 的子字串
     */
    bool strStr(string &source, string &target) {
        for (int i=0; i <= (source.size() - target.size()); i++){
            for (int j=0; j < target.size(); j++){
                if(source[i+j] != target[j]) break;
                
                if(j == (target.size()-1)){
                    return true;
                }
            }
        }
        return false;
    }
};

因為兩次 for loop 的平均時間複雜度為 O(NxM),而 Rabin-Karp 為 O(N+M)

但這裡要注意 Rabin-Karp 嚴格來說不是最佳的『字串搜尋演算法』,更好的還有 KMP 之類的,不過這裡在本篇討論範圍,會拉 Rabin Karp 出來看是因為比較簡單……

其它字串搜尋演算法可以至以下連結來查查。

wiki-字串搜尋演算法

接下來將開始的那理解 Rabin-Karp 這個演算法,本篇文章分以為下幾個章節 :

  • Rabin-Karp 基本原理
  • Rabin-Karp 的 hash function
  • LeetCode 的題目 Implement strStr()

Rabin-Karp 基本原理

Rabin-Karp 的核心為 :

用 hash 來進行字串判斷,並且只要 O(1) 的時間就可以計算出 next hash

使得 Rabin-Karp 為時間複雜度為 O(N+M) 優於爆力法 O(NxM)

這裡先簡單說一下 hash 是啥。簡單的說就是將『 某個東東 A 』透過 hash function,轉成『 某個東東 B 』。例如下範例,會根據字串的每個字元 ascii 加總起來,產生一個 hash 值。

"abc" => hash function => 294(97+98+99)

function hash(string s){
   int res = 0;
   for (int i=0; i < s.size(); i++){
      // (int)字串,這裡會取得 s[i] 這字串的 ASCII 碼
      res += (int)s[i];
   }
   return res;
}

這裡注意一下 hash function 不一定要長這個樣子,上述只是範例。

那為啥 rabin karp 要進行 hash 轉換呢 ? 等回兒會知道。接下來我們來看演算法的步驟。然後我們假設範例如下

source: "abcabee"
target: "ca"

source 長度: n
target 長度: m

Step 1. 將 Target 與 Source 前 M 個字串進行 Hash 轉換

假設咱們的 hash function 為如下 :

a => 97
b => 98
z => 122

hash(s) = ASCII(s[0]) + ASCII(s[1]) + ... + ASCII(S[n])

所以我們的 target 透過 hash 轉換會得到 :

source: "abcabee"
target: "ca"

hash("ab") = 97 + 98 = 195
hash("ca") = 99 + 97 = 196

注意 ! 應該有人有注意到這個 hash function 會有問題,這在下一章節會說說,這裡只是範例,只要有個 hash 方法就好。

Step 2. 計算每一個 Window 字串的 Hash 值

然後接下來會用移動視窗的模式來計算每 m 的 hash 值。由於我們要找的 target 為長度 2 的字串,因此我們的 m 就為 2。所以就代表要計算出每個連續 2 個字元的字串的 hash 值。

以下圖來看,它會照這以下的順序來計算每個字串的 hash 值,其中每個紅區塊就是一個 m 為為 2 的window。

那接下來要如何計算之後的字串的 hash 值呢 ? 不就是一樣 step 1 的方法嗎 ?

NoNo ~ 當然不是,你仔細想一下,這樣不就和爆力法一樣,每個字串都還是要 for loop 一次,那這樣對就代表時間複雜度還是 O(NxM),下面為 O(NXM) 的完整寫法,你看會需要兩個 for loop。

// source: "abcabee"
// 每一次 i 就是計算一個 window 的字串 hash
// i=0 => ab
// i=1 => bc
// i=2 => ca
int n = source.size();
int m = target.size();
for (int i=0; i < n; i++){
  hash = 0;
  for (int j=0; j < m; j++){
     hash += ASCII(s[i+j]);
  }
}

這就是為什麼要用 hash 的關係。

因為有了 hash 公式,我們可以用 O(1) 反推下一個 hash 值,這過程叫『 rehash 』

以我們範例的 hash 公式來看,反推的方法為 :

source: "abcabee"
a => 97
b => 98
c => 99
hash("ab") => 97 + 98 = 195
hash("bc") => 195 - 97(前 1 個字串的第一個字元) + 99(bc 的最後一個字元) = 197

這樣就可以使用 O(1) 的時間複雜度來進行運算了。

在設計 hash function 時,要記得設計出可以反推的 hash 公式,才能用在 Rabin-Karp。

Step 3. 判斷 target hash 是否等於 current hash

最後就只要判斷 target 的 hash 值,與某個 window 所計算出的 hash 值相等,那基本上它就是答案。

source: "abcabee"
target: "ca" => 99 + 97 = 196

window 1 : ab => 97 + 98 = 195
window 2 : bc => 195 - 97 + 99 = 197
window 3 : ca => 197 - 98 + 97 = 196
window 4 : ab => 196 - 99 + 98 = 195
window 5 : be => 195 - 97 + 100 = 198
window 6 : ee => 198 - 98 + 100 = 200

Ans: 196 = 196 = ca 所以 target 為 source 的子字串

Rabin-Karp 的 Hash Function

Rabin Karp 的 hash function 有兩個重點 :

要可以『 用 O(1) 進行 rehash 』 且『 不要 overflow 』

咱們接下來先從剛剛範例版本的來看看。

Version 1 : 範例用版本

我們上面範例使用的 hash 方法如下,就是簡單將每個字元的 ascii 碼加總起,來產生一個值。

a => 97
b => 98
z => 122

hash(s) = ASCII(s[0]) + ASCII(s[1]) + ... + ASCII(S[n])

但是這個版本的 hash 值有問題。

那就是 bc(98+99) 與 cb(99+98) 或 ad(97+100) 都會計算為 197

這就所謂的『 hash 碰撞 ( collision ) 』。在傳統資料結構中,有兩種方法可以解決這個問題 :

  • 拉鏈法 : 也就是在這個 hash 對應下,會有一個 linklist 來儲放所有的值,就有點像是 197 : bc->cb->ad,但這個問題很明顯,如果 linklist 太長,就會由原本找 hash 對應值操作的時間複雜度 o(1) 變成 o(n( link list 的長度 ))。
  • 開放地址法 : 假設第一次操作時 hash 值 197(bc),位置 1 插入 bc 沒碰撞,成功。第二次操作時 hash 值 197(cb),插入位置 1 發生碰撞失敗,再嘗試插入位置 2 成功。

這個大概就淺淺的說這兩種解決方法。詳細的內容建立去查查 hash 碰撞處理,應該就會知道更多了。

Version 2 : Hash 計算多考慮位置

事實上最理想的情況是儘可能的『 減少 hash 碰撞 』,因為不管是用那種方法,如果太常碰撞的話,都會讓搜尋或更新的時間複雜度大大幅的上升,所以接下來的 version 2 後就是開始往這方面來思考。

要解決上述的 bc、cb、ad 這種換個位置就碰撞的方法,最簡單的就是 :

Hash 值一起考慮每個字元的位置

所以事實上可以改成以下的 hash 計算。為了簡單我們就不用 ascii 來當每個字元的編號,而使用 1 ~ 26 來代表 a 至 z。

hash(s) = s[0]*10^(n-1) + s[1]*10^(n-2) + s[n]*10^0

{ a:1, b:2, c:3 }
Ex. hash("abc") = 1*10^2 + 2*10^1 + 3*10^0 = 123

為什麼要乘一個 10 ?

先說一下,不一定要乘 10,你可以乘其它數。而乘它的原因在於 :

減少 hash 碰撞

假設我們沒有乘某個數變成如下 :

hash(s) = s[0]^2 + s[1]^1 

那這時如果 s[0] = 2 而 s[0] = 0 這樣 hash 值會為 4 值。那如果這時有那一個編碼為 s[0]^2 且只有一個字元,那就碰撞了。如下範例。

hash(s) = s[0]^2 + s[1]^1
"1" => 1
"2" => 2

hash("20") = 2^2 + 0 = 4
hash("4") = 4^1 = 4 ( ! 碰撞 )

Version 3 : 處理 overflow 問題

Version 2 事實上有個問題,如果 10^n 的 n 次方太大怎麼辦呢 ?

會炸掉,以 c++ 來說 int(16) 就大約 15 就炸了。

所以應該會有人想說,那就使用 mod 改成如下,mod 一個質數,越大越好,因為越小越容易發生 hash 衝突,至於為什麼是要選質數,請參考此篇,我覺得寫的很清楚了 :

hash(s) = (s[0]*10^(n-1) + s[1]*10^(n-2) + s[n]*10^0) mod prime

{ a:1, b:2, c:3 }
prime = 11
Ex. hash("abc") = (1*10^2 + 2*10^1 + 3*10^0) mod 11 
= 123 mod 11
= 2

可是如果是 10^n 太大怎麼辦呢 ?

這裡就需要就上公式進行一下轉換。

先假設只有 2 位
(1) hash = (s[0]*10^1 + s[1]*10^0) mod p

可以轉換成
(2) hash = (((s[0] mod p) *10) mod p + s[1]) mod p

然後公式 2 寫成程式碼就是如下,以 c++ 為例 :

int hash = 0;
for (int i=0; i < s.size(); i++){
    hash = ((hash * 10) % p + (int)s[i])) % p;
}

上面的公式(2) 就是 wiki rabin karp 上面所說的 hash function。

至於 (1) 為什麼可以轉成 (2) 呢 ? 呵,說實話我還真推不太出來。目前只知道和以下兩個數學原理有關 :

這個公式轉換就先這樣,唉……只能說我當初離散數學完全沒在唸…… 這裡只要先知道公式 (1) 與公式(2),兩個算出來是相等的,然後公式(2)可以讓我們避免發生 overflow。

Rehash 運算

好接下來是 rehash,咱們先不考慮 mod 的情況下,咱們先來看看如何將以下的 hash function 進行 rehash。

hash(s) = s[0]^2 + s[1]^1 

就直接參考以下的圖會比較好理解,先說明一下這裡的範例 1234567 它是一個文字,然後假設你要取得 hash(“123”) 它的 function 如下 :

hash("123") = 1*10^2 + 2*10^1 + 3 = 123

然後接下來我們要進行 rehash 以 o(1) 的時間複雜度來取得,下一個字串 234 的方法就如下圖,這圖應該是很好理解囉。

然後接下來考慮 mod 的情況,下列為前提條件。

Ex. "12345"
now : 123 
next : 234
now_hash : (1*10^2 + 2*10^1 + 3*10^0) mod 11 = 2

現在要求的為 next 的 hash。
pow = ((10 mod 11)*10) mod 11 = 1
備註: pow 等於於 (10^2) mod 11,會用這種方式是防止 10^n 炸掉

然後下列為 rehash 考慮 mod 的流程。

(1) next_hash = (2 - 1*pow) = 1
(2) next_hash = (next_hash mod 11 + 11) mod 11 = 1
(3) next_hash = (next_hash*10 + 4) mod 11 = 3

驗證 : (2*10^2 + 3*10^1 + 4*10^0) mod 11 = 3

其中需要注意上述流程(2),為什麼它要多加個 11 然後在 mod 一次呢 ?

因為流程(1)可能會負數

這裡我只能說負數的取餘數,不同的語言有可能會有不同的答案,而這個方法是比較好在程式碼比較通常的解法,詳細的負數取餘建立看看下面這篇文章,先理解一下問題在那。

程式語言中負數取餘的問題

LeetCode 的題目 Implement strStr()

Implement strStr().

Return the index of the first occurrence of needle in haystack, or -1 if needle is not part of haystack.

Input: source = "hello", target = "ll"
Output: 2

Input: source = "aaaaa", target = "bba"
Output: -1

Rabin Karp 解法

這一段程式碼就是標準的 rabin karp 來判斷 target 字串是否為 source 字串的寫法。

#include <math.h>  
class Solution {
private:
    int BASE = 10;
    int MODE = 19260817;
    // def: hash(s) = s[0]*BASE^n-1 + s[1]*BASE^n-2 ... s[n-1]*BASE^0 
    int hash(string s){
        int hash = 0;
        for (int i=0; i < s.size(); i++){
            hash = ((hash * BASE) % MODE + (int)s[i]) % MODE;
        }
        return hash;
    };
public:
    int strStr(string source, string target) {
        int n = source.size();
        int t = target.size();
        if((n == 0 && t == 0) || t == 0){
            return 0;
        }
        
        if(n < t){
            return -1;
        }
        
        long long target_hash = hash(target);
        long long cur_hash = hash(source.substr(0, t)); 
        if(target_hash == cur_hash){
            return 0;
        }
        
        long long pow = 1;
        for (int j=0; j < t-1; j++){
            pow = (pow * BASE) % MODE;
        }
        
        // rehash
        // Ex. source = "abcde", target = "cd"
        // cur_hash = hash("ab")
        // target_hash = hash("cd")
        // loop 
        // "bc" => "cd" => "de"
        for (int i=1; i <= n - t; i++){
            cur_hash = (cur_hash - (long long)source[i-1]*pow);
            // 注意下面這一段 !!!! 為了解決上一段可能會負數的問題。
            cur_hash = (cur_hash % MODE + MODE) % MODE;
            cur_hash = (cur_hash*BASE + (int)source[i+t-1]) % MODE;
            if(cur_hash == target_hash){
                return i;
            }
        }
        return -1;
    }
};

參考資料

comments powered by Disqus