Rabin-Karp 是啥呢 ? 它主要可以幫助我們進行以下的事情 :
它可以判斷 t 字串是否為 s 字串的子字串
子字串例的範例如下 :
s: "abcabee"
t: "cab"
t 為 s 的子字串,但注意如果 t 為 cbe,則不是喔。
那這裡問個問題。
用兩次 for loop 就能找到子字串了,為啥需要使用 Rabin-Karp 呢 ?
class Solution {
public:
/**
* @param source:
* @param target:
* @return: 回傳是否 target 是否為 source 的子字串
*/
bool strStr(string &source, string &target) {
for (int i=0; i <= (source.size() - target.size()); i++){
for (int j=0; j < target.size(); j++){
if(source[i+j] != target[j]) break;
if(j == (target.size()-1)){
return true;
}
}
}
return false;
}
};
因為兩次 for loop 的平均時間複雜度為 O(NxM),而 Rabin-Karp 為 O(N+M)
但這裡要注意 Rabin-Karp 嚴格來說不是最佳的『字串搜尋演算法』,更好的還有 KMP 之類的,不過這裡在本篇討論範圍,會拉 Rabin Karp 出來看是因為比較簡單……
其它字串搜尋演算法可以至以下連結來查查。
接下來將開始的那理解 Rabin-Karp 這個演算法,本篇文章分以為下幾個章節 :
- Rabin-Karp 基本原理
- Rabin-Karp 的 hash function
- LeetCode 的題目 Implement strStr()
Rabin-Karp 基本原理
Rabin-Karp 的核心為 :
用 hash 來進行字串判斷,並且只要 O(1) 的時間就可以計算出 next hash
使得 Rabin-Karp 為時間複雜度為 O(N+M) 優於爆力法 O(NxM)
這裡先簡單說一下 hash 是啥。簡單的說就是將『 某個東東 A 』透過 hash function,轉成『 某個東東 B 』。例如下範例,會根據字串的每個字元 ascii 加總起來,產生一個 hash 值。
"abc" => hash function => 294(97+98+99)
function hash(string s){
int res = 0;
for (int i=0; i < s.size(); i++){
// (int)字串,這裡會取得 s[i] 這字串的 ASCII 碼
res += (int)s[i];
}
return res;
}
這裡注意一下 hash function 不一定要長這個樣子,上述只是範例。
那為啥 rabin karp 要進行 hash 轉換呢 ? 等回兒會知道。接下來我們來看演算法的步驟。然後我們假設範例如下
source: "abcabee"
target: "ca"
source 長度: n
target 長度: m
Step 1. 將 Target 與 Source 前 M 個字串進行 Hash 轉換
假設咱們的 hash function 為如下 :
a => 97
b => 98
z => 122
hash(s) = ASCII(s[0]) + ASCII(s[1]) + ... + ASCII(S[n])
所以我們的 target 透過 hash 轉換會得到 :
source: "abcabee"
target: "ca"
hash("ab") = 97 + 98 = 195
hash("ca") = 99 + 97 = 196
注意 ! 應該有人有注意到這個 hash function 會有問題,這在下一章節會說說,這裡只是範例,只要有個 hash 方法就好。
Step 2. 計算每一個 Window 字串的 Hash 值
然後接下來會用移動視窗的模式來計算每 m 的 hash 值。由於我們要找的 target 為長度 2 的字串,因此我們的 m 就為 2。所以就代表要計算出每個連續 2 個字元的字串的 hash 值。
以下圖來看,它會照這以下的順序來計算每個字串的 hash 值,其中每個紅區塊就是一個 m 為為 2 的window。
那接下來要如何計算之後的字串的 hash 值呢 ? 不就是一樣 step 1 的方法嗎 ?
NoNo ~ 當然不是,你仔細想一下,這樣不就和爆力法一樣,每個字串都還是要 for loop 一次,那這樣對就代表時間複雜度還是 O(NxM),下面為 O(NXM) 的完整寫法,你看會需要兩個 for loop。
// source: "abcabee"
// 每一次 i 就是計算一個 window 的字串 hash
// i=0 => ab
// i=1 => bc
// i=2 => ca
int n = source.size();
int m = target.size();
for (int i=0; i < n; i++){
hash = 0;
for (int j=0; j < m; j++){
hash += ASCII(s[i+j]);
}
}
這就是為什麼要用 hash 的關係。
因為有了 hash 公式,我們可以用 O(1) 反推下一個 hash 值,這過程叫『 rehash 』
以我們範例的 hash 公式來看,反推的方法為 :
source: "abcabee"
a => 97
b => 98
c => 99
hash("ab") => 97 + 98 = 195
hash("bc") => 195 - 97(前 1 個字串的第一個字元) + 99(bc 的最後一個字元) = 197
這樣就可以使用 O(1) 的時間複雜度來進行運算了。
在設計 hash function 時,要記得設計出可以反推的 hash 公式,才能用在 Rabin-Karp。
Step 3. 判斷 target hash 是否等於 current hash
最後就只要判斷 target 的 hash 值,與某個 window 所計算出的 hash 值相等,那基本上它就是答案。
source: "abcabee"
target: "ca" => 99 + 97 = 196
window 1 : ab => 97 + 98 = 195
window 2 : bc => 195 - 97 + 99 = 197
window 3 : ca => 197 - 98 + 97 = 196
window 4 : ab => 196 - 99 + 98 = 195
window 5 : be => 195 - 97 + 100 = 198
window 6 : ee => 198 - 98 + 100 = 200
Ans: 196 = 196 = ca 所以 target 為 source 的子字串
Rabin-Karp 的 Hash Function
Rabin Karp 的 hash function 有兩個重點 :
要可以『 用 O(1) 進行 rehash 』 且『 不要 overflow 』
咱們接下來先從剛剛範例版本的來看看。
Version 1 : 範例用版本
我們上面範例使用的 hash 方法如下,就是簡單將每個字元的 ascii 碼加總起,來產生一個值。
a => 97
b => 98
z => 122
hash(s) = ASCII(s[0]) + ASCII(s[1]) + ... + ASCII(S[n])
但是這個版本的 hash 值有問題。
那就是 bc(98+99) 與 cb(99+98) 或 ad(97+100) 都會計算為 197
這就所謂的『 hash 碰撞 ( collision ) 』。在傳統資料結構中,有兩種方法可以解決這個問題 :
- 拉鏈法 : 也就是在這個 hash 對應下,會有一個 linklist 來儲放所有的值,就有點像是 197 : bc->cb->ad,但這個問題很明顯,如果 linklist 太長,就會由原本找 hash 對應值操作的時間複雜度 o(1) 變成 o(n( link list 的長度 ))。
- 開放地址法 : 假設第一次操作時 hash 值 197(bc),位置 1 插入 bc 沒碰撞,成功。第二次操作時 hash 值 197(cb),插入位置 1 發生碰撞失敗,再嘗試插入位置 2 成功。
這個大概就淺淺的說這兩種解決方法。詳細的內容建立去查查 hash 碰撞處理,應該就會知道更多了。
Version 2 : Hash 計算多考慮位置
事實上最理想的情況是儘可能的『 減少 hash 碰撞 』,因為不管是用那種方法,如果太常碰撞的話,都會讓搜尋或更新的時間複雜度大大幅的上升,所以接下來的 version 2 後就是開始往這方面來思考。
要解決上述的 bc、cb、ad 這種換個位置就碰撞的方法,最簡單的就是 :
Hash 值一起考慮每個字元的位置
所以事實上可以改成以下的 hash 計算。為了簡單我們就不用 ascii 來當每個字元的編號,而使用 1 ~ 26 來代表 a 至 z。
hash(s) = s[0]*10^(n-1) + s[1]*10^(n-2) + s[n]*10^0
{ a:1, b:2, c:3 }
Ex. hash("abc") = 1*10^2 + 2*10^1 + 3*10^0 = 123
為什麼要乘一個 10 ?
先說一下,不一定要乘 10,你可以乘其它數。而乘它的原因在於 :
減少 hash 碰撞
假設我們沒有乘某個數變成如下 :
hash(s) = s[0]^2 + s[1]^1
那這時如果 s[0] = 2 而 s[0] = 0 這樣 hash 值會為 4 值。那如果這時有那一個編碼為 s[0]^2 且只有一個字元,那就碰撞了。如下範例。
hash(s) = s[0]^2 + s[1]^1
"1" => 1
"2" => 2
hash("20") = 2^2 + 0 = 4
hash("4") = 4^1 = 4 ( ! 碰撞 )
Version 3 : 處理 overflow 問題
Version 2 事實上有個問題,如果 10^n 的 n 次方太大怎麼辦呢 ?
會炸掉,以 c++ 來說 int(16) 就大約 15 就炸了。
所以應該會有人想說,那就使用 mod 改成如下,mod 一個質數,越大越好,因為越小越容易發生 hash 衝突,至於為什麼是要選質數,請參考此篇,我覺得寫的很清楚了 :
hash(s) = (s[0]*10^(n-1) + s[1]*10^(n-2) + s[n]*10^0) mod prime
{ a:1, b:2, c:3 }
prime = 11
Ex. hash("abc") = (1*10^2 + 2*10^1 + 3*10^0) mod 11
= 123 mod 11
= 2
可是如果是 10^n 太大怎麼辦呢 ?
這裡就需要就上公式進行一下轉換。
先假設只有 2 位
(1) hash = (s[0]*10^1 + s[1]*10^0) mod p
可以轉換成
(2) hash = (((s[0] mod p) *10) mod p + s[1]) mod p
然後公式 2 寫成程式碼就是如下,以 c++ 為例 :
int hash = 0;
for (int i=0; i < s.size(); i++){
hash = ((hash * 10) % p + (int)s[i])) % p;
}
上面的公式(2) 就是 wiki rabin karp 上面所說的 hash function。
至於 (1) 為什麼可以轉成 (2) 呢 ? 呵,說實話我還真推不太出來。目前只知道和以下兩個數學原理有關 :
這個公式轉換就先這樣,唉……只能說我當初離散數學完全沒在唸…… 這裡只要先知道公式 (1) 與公式(2),兩個算出來是相等的,然後公式(2)可以讓我們避免發生 overflow。
Rehash 運算
好接下來是 rehash,咱們先不考慮 mod 的情況下,咱們先來看看如何將以下的 hash function 進行 rehash。
hash(s) = s[0]^2 + s[1]^1
就直接參考以下的圖會比較好理解,先說明一下這裡的範例 1234567 它是一個文字,然後假設你要取得 hash(“123”) 它的 function 如下 :
hash("123") = 1*10^2 + 2*10^1 + 3 = 123
然後接下來我們要進行 rehash 以 o(1) 的時間複雜度來取得,下一個字串 234 的方法就如下圖,這圖應該是很好理解囉。
然後接下來考慮 mod 的情況,下列為前提條件。
Ex. "12345"
now : 123
next : 234
now_hash : (1*10^2 + 2*10^1 + 3*10^0) mod 11 = 2
現在要求的為 next 的 hash。
pow = ((10 mod 11)*10) mod 11 = 1
備註: pow 等於於 (10^2) mod 11,會用這種方式是防止 10^n 炸掉
然後下列為 rehash 考慮 mod 的流程。
(1) next_hash = (2 - 1*pow) = 1
(2) next_hash = (next_hash mod 11 + 11) mod 11 = 1
(3) next_hash = (next_hash*10 + 4) mod 11 = 3
驗證 : (2*10^2 + 3*10^1 + 4*10^0) mod 11 = 3
其中需要注意上述流程(2),為什麼它要多加個 11 然後在 mod 一次呢 ?
因為流程(1)可能會負數
這裡我只能說負數的取餘數,不同的語言有可能會有不同的答案,而這個方法是比較好在程式碼比較通常的解法,詳細的負數取餘建立看看下面這篇文章,先理解一下問題在那。
LeetCode 的題目 Implement strStr()
Implement strStr().
Return the index of the first occurrence of needle in haystack, or -1 if needle is not part of haystack.
Input: source = "hello", target = "ll"
Output: 2
Input: source = "aaaaa", target = "bba"
Output: -1
Rabin Karp 解法
這一段程式碼就是標準的 rabin karp 來判斷 target 字串是否為 source 字串的寫法。
#include <math.h>
class Solution {
private:
int BASE = 10;
int MODE = 19260817;
// def: hash(s) = s[0]*BASE^n-1 + s[1]*BASE^n-2 ... s[n-1]*BASE^0
int hash(string s){
int hash = 0;
for (int i=0; i < s.size(); i++){
hash = ((hash * BASE) % MODE + (int)s[i]) % MODE;
}
return hash;
};
public:
int strStr(string source, string target) {
int n = source.size();
int t = target.size();
if((n == 0 && t == 0) || t == 0){
return 0;
}
if(n < t){
return -1;
}
long long target_hash = hash(target);
long long cur_hash = hash(source.substr(0, t));
if(target_hash == cur_hash){
return 0;
}
long long pow = 1;
for (int j=0; j < t-1; j++){
pow = (pow * BASE) % MODE;
}
// rehash
// Ex. source = "abcde", target = "cd"
// cur_hash = hash("ab")
// target_hash = hash("cd")
// loop
// "bc" => "cd" => "de"
for (int i=1; i <= n - t; i++){
cur_hash = (cur_hash - (long long)source[i-1]*pow);
// 注意下面這一段 !!!! 為了解決上一段可能會負數的問題。
cur_hash = (cur_hash % MODE + MODE) % MODE;
cur_hash = (cur_hash*BASE + (int)source[i+t-1]) % MODE;
if(cur_hash == target_hash){
return i;
}
}
return -1;
}
};