本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇看完了 forward pass 的編排,這一篇要看看「文字」是怎麼變成 LLM 看得懂的 token id 的:tokenizer.rs。
很多人以為 tokenizer 就是「把字串切成 word」,我以前也是這樣想的。不過現代 LLM 的 tokenizer 實在是比這精緻得多——它是個 learned algorithm,由訓練資料決定要怎麼切。實作起來其實也才 200 多行而已,但每一段都很有得講喔。
概念一:什麼是 BPE?為什麼不直接用 word?
最早的 NLP 用的是 word tokenizer:["I", "love", "Rust"]。看起來很直覺對吧?只是它有兩個麻煩:
- OOV(out of vocabulary):訓練時沒看過的 word(拼寫錯字、新詞)會變成
<unk>。 - 詞彙爆炸:英文單字大概 60 萬個,加上人名、專業詞彙,vocabulary 動輒幾百萬。
BPE(Byte Pair Encoding) 解這個問題的招數很妙——「從字元開始合併」:
- 初始化:每個字元是一個 token。
- 統計訓練資料裡哪一對相鄰 token 出現最多。
- 合併最常見的那一對成一個新 token。
- 重複直到達到目標 vocab size(典型 32k 或 128k)。
合併出來的 token 表會包含「字元」、「片段」、「常見字」、「常見片語」混合在一起,例如:
'a', 'b', ..., 'z',
'th', 'in', 'er', 're',
'the', 'and', 'ing', 'tion',
'▁the', '▁of', '▁to',
...▁ 是 SentencePiece 用來標記 word boundary 的特殊字元(U+2581)。
概念二:SentencePiece 的 word boundary 處理
let prepared = format!("\u{2581}{}", text.replace(' ', "\u{2581}"));SentencePiece 把空白統一替換成 ▁,並且在輸入最前面也加一個 ▁。為什麼要這樣搞?
因為這樣一來**「空格」就不再是特殊字元,而是 token 的一部分了**。例如 "hello world" 會變成 "▁hello▁world",tokenize 後可能就是 ["▁hello", "▁world"]。decode 時把 ▁ 換回空格,原始字串就還原了。
我覺得這個設計頗聰明的地方在於:tokenization 變成 reversible——不會像 word tokenizer 那樣「I love Rust → [I, love, Rust] → 咦空格到底在哪?」最後拼不回來。
演算法核心:encode 的 greedy merge loop
loop {
let mut best_score = f32::NEG_INFINITY;
let mut best_idx: Option<usize> = None;
let mut best_id: u32 = 0;
for i in 0..ids.len().saturating_sub(1) {
let merged = format!("{}{}",
&self.tokens[ids[i] as usize],
&self.tokens[ids[i + 1] as usize]
);
if let Some(&id) = self.token_to_id.get(&merged) {
let s = self.scores[id as usize];
if s > best_score {
best_score = s;
best_idx = Some(i);
best_id = id;
}
}
}
match best_idx {
Some(i) => {
ids[i] = best_id;
ids.remove(i + 1);
}
None => break,
}
}這就是 SentencePiece 的「最高分鄰接合併」編碼演算法:
- 把輸入 split 成單字元 token 序列。
- 在所有相鄰 pair 中,找一個合併後是 vocab 裡的 token、且分數最高的。
- 合併它(用合併後的 id 取代第 i 個,刪掉第 i+1 個)。
- 重複直到沒有合併可以做。
這裡的分數是 GGUF metadata 裡 tokenizer.ggml.scores 提供的值——它是 SentencePiece 訓練時學到的「這個 token 到底多好用」的指標,分數越高就越偏好。
複雜度分析
每輪 loop 是 $O(n)$(掃一次相鄰 pair),總共最多 n 輪(每輪至少縮掉一個 token),所以整個 encode 是 $O(n^2)$。對 1000 字元的 prompt 來說是 100 萬次 hashmap lookup——聽起來嚇人,不過還好 hashmap 平均是 $O(1)$,沒事啦。
要更快的話可以用 priority queue(heap):每次取最高分的 pair $O(\log n)$,總共 $O(n \log n)$。只是實作起來複雜,而且 LLM 的 prompt 通常也不會超過幾千字元,$O(n^2)$ 其實完全夠用了。
format! 在 hot path 的成本
注意喔,這個 format! 在每個 pair 都會 allocate 一個新 String,對長 prompt 來說就是上萬次 allocation。真要優化的話,可以:
- 預配一個 reusable buffer:
let mut merged = String::with_capacity(64); ...; merged.clear(); merged.push_str(...); - 用
(u32, u32) → u32的 lookup table(pair table),直接避開字串拼接。
不過對 LLM runner 來說 tokenization 通常不是瓶頸啦——一次 encode 才幾十毫秒,相對於 forward pass 的好幾秒,根本可以忽略。
演算法核心:byte fallback —— 處理 vocab 外的字元
for ch in prepared.chars() {
let s: String = ch.to_string();
if let Some(&id) = self.token_to_id.get(&s) {
ids.push(id);
} else if let Some(bf) = &self.byte_fallback {
for &byte in s.as_bytes() {
let id = bf[byte as usize];
ids.push(id);
}
}
}如果某個字元不在 vocab 裡(像是中文、Emoji),就把它的 UTF-8 bytes 一個一個用 byte fallback token 編碼。Llama 的 vocab 裡準備了 256 個專門的 byte token,名字長這樣:
"<0x00>", "<0x01>", ..., "<0xFF>"每個 byte token 對應一個 byte 值。例如 "我" 的 UTF-8 是 [0xE6, 0x88, 0x91],會被編碼成三個 byte token:<0xE6>, <0x88>, <0x91>。
偵測 byte fallback 是否完整
let mut byte_fallback = [u32::MAX; 256];
let mut have_all = true;
for b in 0..=255u32 {
let key = format!("<0x{:02X}>", b);
if let Some(&id) = token_to_id.get(&key) {
byte_fallback[b as usize] = id;
} else {
have_all = false;
}
}
let byte_fallback = if have_all { Some(byte_fallback) } else { None };只有當 vocab 裡 256 個 byte token 全都在的時候才啟用 byte fallback。如果只缺了幾個,就整個 disable 掉——因為 partial 的支援會讓 encode 變得不可預測,那種半調子狀態最麻煩了。
Option<[u32; 256]> 我覺得是個頗漂亮的 Rust 表達:要嘛全有、要嘛全無,型別系統直接幫你把這個 invariant 釘死。
演算法核心:decode 的 byte 拼合
decode 比 encode 單純多了,不過有個微妙的小細節要注意——byte fallback token 必須在 byte-level 拼回 UTF-8 codepoint:
pub fn decode(&self, ids: &[u32]) -> String {
let mut bytes: Vec<u8> = Vec::new();
for &id in ids {
let s = match self.tokens.get(id as usize) {
Some(s) => s,
None => continue,
};
if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
if let Ok(b) = u8::from_str_radix(&s[3..5], 16) {
bytes.push(b);
continue;
}
}
bytes.extend_from_slice(s.replace('\u{2581}', " ").as_bytes());
}
String::from_utf8_lossy(&bytes).into_owned()
}關鍵是:先收集到 Vec<u8>,最後一次性轉 UTF-8。
為什麼不能逐 token decode 呢?因為一個 UTF-8 codepoint 可能會跨好幾個 byte token。例如 "我" 是三個 byte token,你要是逐個 decode:
<0xE6>→0xE6不是合法 UTF-8 → 變成?<0x88>→0x88不是合法 UTF-8 → 變成?<0x91>→0x91不是合法 UTF-8 → 變成?
結果就是 ??? 而不是 我,慘。先把 byte 全收集起來、最後再一次 from_utf8_lossy,這三個 byte 才會被正確拼成一個中文字。
decode_piece 的 caveat
我也提供了一個 decode_piece(id) 做單 token decode,但它有個 limitation——遇到 byte fallback 時只能 lossy emit。所以LLM 串流輸出的時候一定要用 decode(&[id]) 或自己 buffer 起來,千萬別直接 decode_piece,不然就會看到一堆亂碼。
我的 main.rs 用的是 tokenizer.decode(&[next]),這樣才有正確處理——不過老實說這真的是個很容易忘記的雷。理想的 API 應該是個 Decoder struct,內部自己維護 byte buffer,每次餵 token 進來、適時 flush 出完整的 UTF-8 codepoint,這樣最省心。
Rust 用法:HashMap 的 entry pattern
這個 tokenizer 雖然沒用到,但這個 Rust 慣用法還是值得一提。我目前的初始化長這樣:
let mut token_to_id = HashMap::with_capacity(tokens.len());
for (i, t) in tokens.iter().enumerate() {
token_to_id.insert(t.clone(), i as u32);
}如果 vocab 有重複 token(理論上不應該),後面的會覆蓋前面的。如果想驗證沒有重複:
for (i, t) in tokens.iter().enumerate() {
use std::collections::hash_map::Entry;
match token_to_id.entry(t.clone()) {
Entry::Vacant(v) => { v.insert(i as u32); }
Entry::Occupied(_) => bail!("duplicate token at {i}"),
}
}entry API 是 Rust 處理 hash map 的標準慣用法——白話說就是「我不知道 key 在不在,但想根據它在不在來做不同的事」,一次 lookup 搞定,不必查兩遍。
Rust 用法:Option 的 chain
let bos = get_special(g, "tokenizer.ggml.bos_token_id").unwrap_or(1);get_special 回傳 Option<u32>,如果 metadata 沒有這個 key 就是 None。unwrap_or(1) 給了一個 default value——Llama 的 BOS token id 通常都是 1,所以這算是個合理的 fallback。
這就是 Rust 對 null 的優雅處理——Option 逼著你當場決定「沒有的時候怎麼辦」,而不是放任 NullPointerException 在 runtime 才給你來個措手不及。
效能最佳化空間
1. encode 的 priority queue 優化
把 O(n^2) 改成 O(n log n),對長 prompt 有顯著加速。只是實作 priority queue 還要處理 invalidation(合併後相鄰 pair 都變了)就有點麻煩了。想偷懶的話,tokenizers crate(HuggingFace 的 Rust tokenizer)已經有現成的優化版可以抄。
2. 預先建立 pair lookup table
每次 format!("{}{}", a, b) 都得做字串拼接加上 hash lookup。如果改成預先建好一張 HashMap<(u32, u32), u32>:
let mut pair_map: HashMap<(u32, u32), u32> = HashMap::new();
for (id, token) in tokens.iter().enumerate() {
// 嘗試把 token 拆成兩個現有 token 的拼接
for split in 1..token.len() {
if let (Some(&a), Some(&b)) = (
token_to_id.get(&token[..split]),
token_to_id.get(&token[split..]),
) {
pair_map.insert((a, b), id as u32);
}
}
}這樣 encode 時就只是 (u32, u32) → u32 的查表,完全不必字串拼接。代價是建表本身是 $O(\sum_t |t|)$,得在啟動時先花一點時間,算是拿啟動換執行速度吧。
3. SIMD 字串搜尋(次要)
replace('▁', ' ') 用的是逐字元 scan。長字串理論上可以用 SIMD 加速,不過 tokenizer 處理的字串通常都很短,這個真的沒必要,純粹列出來給大家參考一下。
4. 避免 ids.remove(i+1) 的 O(n)
ids[i] = best_id;
ids.remove(i + 1); // 每次都是 O(n)
Vec::remove 是 $O(n)$(後面的 element 全部要往前 shift)。理論上更好的資料結構是 linked list(doubly linked),合併只要 $O(1)$。不過 Rust 的 LinkedList 對 cache 很不友善,跑起來反而更慢——這就是經典的「理論最優和實際最優不一樣」的案例呢。
實務上 prompt 長度頂多幾百到幾千 tokens,$O(n^2)$ 配上 cache friendly 的 Vec,反而比 LinkedList 還快。所以別被 big-O 騙了。
5. 批次處理 byte fallback
對長的 unicode 文字,目前每個 char 都會做一次 hashmap lookup。理論上可以先把連續的 byte fallback 序列 group 起來、批次處理。不過收益其實有限,畢竟 hashmap lookup 本身就快得很。
一個我曾經踩過的實際雷
我第一次寫這個 tokenizer 的時候,有個 case 怎麼編碼都不對:英文 prompt 後面接中文。搞了半天才發現,問題是我忘了讓 byte fallback 的 token繼續參與後面的合併迴圈。當時我的邏輯把 byte fallback 當成「終態」——一旦變成 byte token 就不再動它了,但 SentencePiece 的訓練資料裡,其實可能有把連續 byte 合併成更高層 token 的規則啊。
還好現在的實作放對位置了——byte fallback 只是「初始化」,後面的合併迴圈會把它們跟其他 token 一視同仁地一起合併。不過這個雷讓我學到一件事:寫 tokenizer 一定要拿至少幾十個 prompt 去對拍 llama.cpp,光靠一兩個 happy path 測試,遲早出包。
總結:tokenizer.rs 的角色
- 概念上:把字串和 token id list 互相轉換。
- 演算法上:SentencePiece 的 greedy merge encode 加上給 OOV 用的 byte fallback。
- 微妙細節:byte fallback 的存在性檢查、還有 decode 時的 UTF-8 重組。
200 多行的程式,看似簡單,魔鬼卻全藏在 byte fallback 跟 UTF-8 拼合那些細節裡——這大概就是 tokenizer 最有意思的地方吧。下一篇來看 sampler.rs,聊聊拿到 logits 之後到底要怎麼選下一個 token。
系列文章:
- tiny-llm-runner 介紹
- (1) config ・ (2) dequant ・ (3) tensor ・ (4) model ・ (5) ops ・ (6) runner
- (7) tokenizer.rs(本篇)