featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 forward pass 的編排,這一篇要看看「文字」是怎麼變成 LLM 看得懂的 token id 的:tokenizer.rs

很多人以為 tokenizer 就是「把字串切成 word」,但實際上現代 LLM 的 tokenizer 比這精緻得多——它是個learned algorithm,由訓練資料決定怎麼切。實作起來其實也只有 200 多行,但每一段都有得講。

概念一:什麼是 BPE?為什麼不直接用 word?

最早的 NLP 用 word tokenizer:["I", "love", "Rust"]。這有兩個問題:

  1. OOV(out of vocabulary):訓練時沒看過的 word(拼寫錯字、新詞)會變成 <unk>
  2. 詞彙爆炸:英文單字大概 60 萬個,加上人名、專業詞彙,vocabulary 動輒幾百萬。

BPE(Byte Pair Encoding) 解這個問題的方法是「從字元開始合併」:

  1. 初始化:每個字元是一個 token。
  2. 統計訓練資料裡哪一對相鄰 token 出現最多。
  3. 合併最常見的那一對成一個新 token。
  4. 重複直到達到目標 vocab size(典型 32k 或 128k)。

合併出來的 token 表會包含「字元」、「片段」、「常見字」、「常見片語」混合在一起,例如:

'a', 'b', ..., 'z',
'th', 'in', 'er', 're',
'the', 'and', 'ing', 'tion',
'▁the', '▁of', '▁to',
...

是 SentencePiece 用來標記 word boundary 的特殊字元(U+2581)。

概念二:SentencePiece 的 word boundary 處理

let prepared = format!("\u{2581}{}", text.replace(' ', "\u{2581}"));

SentencePiece 把空白統一替換成 ,並且在輸入最前面加一個 。為什麼?

因為這樣**「空格」就不再是特殊字元,而是 token 的一部分**。例如 "hello world" 變成 "▁hello▁world",tokenize 後可能是 ["▁hello", "▁world"]。decode 時把 換回空格就還原了原始字串。

這個設計的好處是 tokenization 變得 reversible——不會像 word tokenizer 那樣「I love Rust[I, love, Rust] → 不知道空格在哪」。

演算法核心:encode 的 greedy merge loop

loop {
    let mut best_score = f32::NEG_INFINITY;
    let mut best_idx: Option<usize> = None;
    let mut best_id: u32 = 0;
    for i in 0..ids.len().saturating_sub(1) {
        let merged = format!("{}{}",
            &self.tokens[ids[i] as usize],
            &self.tokens[ids[i + 1] as usize]
        );
        if let Some(&id) = self.token_to_id.get(&merged) {
            let s = self.scores[id as usize];
            if s > best_score {
                best_score = s;
                best_idx = Some(i);
                best_id = id;
            }
        }
    }
    match best_idx {
        Some(i) => {
            ids[i] = best_id;
            ids.remove(i + 1);
        }
        None => break,
    }
}

這是「最高分鄰接合併」的 SentencePiece 編碼演算法:

  1. 把輸入 split 成單字元 token 序列。
  2. 在所有相鄰 pair 中,找一個合併後是 vocab 裡的 token、且分數最高的。
  3. 合併它(用合併後的 id 取代第 i 個,刪掉第 i+1 個)。
  4. 重複直到沒有合併可以做。

分數是 GGUF metadata 裡 tokenizer.ggml.scores 提供的值——這是 SentencePiece 訓練時學到的「這個 token 多有用」的指標,越高越偏好。

複雜度分析

每輪 loop 是 $O(n)$(掃一次相鄰 pair),總共最多 n 輪(每輪至少縮一個 token),所以整個 encode 是 $O(n^2)$。對 1000 字元的 prompt 是 100 萬次 hashmap lookup——還好,hashmap 平均 $O(1)$。

更高效的做法是 priority queue(heap):每次取最高分的 pair $O(\log n)$,總共 $O(n \log n)$。但實作複雜,且 LLM 的 prompt 通常不會超過幾千字元,$O(n^2)$ 完全夠用。

format! 在 hot path 的成本

注意這個 format! 在每個 pair 都會 allocate 一個新 String。對長 prompt 來說這是萬次 allocation。如果要優化,可以:

  1. 預配一個 reusable buffer:let mut merged = String::with_capacity(64); ...; merged.clear(); merged.push_str(...);
  2. (u32, u32) → u32 的 lookup table(pair table),避開字串拼接。

但對 LLM runner 來說 tokenization 通常不是瓶頸——一次 encode 幾十毫秒,相對於 forward pass 幾秒可以忽略。

演算法核心:byte fallback —— 處理 vocab 外的字元

for ch in prepared.chars() {
    let s: String = ch.to_string();
    if let Some(&id) = self.token_to_id.get(&s) {
        ids.push(id);
    } else if let Some(bf) = &self.byte_fallback {
        for &byte in s.as_bytes() {
            let id = bf[byte as usize];
            ids.push(id);
        }
    }
}

如果某個字元不在 vocab 裡(例如中文、Emoji),就把它的 UTF-8 bytes 一個一個用 byte fallback token 編碼。Llama 的 vocab 裡有 256 個專門的 byte token,名字長這樣:

"<0x00>", "<0x01>", ..., "<0xFF>"

每個 byte token 對應一個 byte 值。例如 "我" 的 UTF-8 是 [0xE6, 0x88, 0x91],會被編碼成三個 byte token:<0xE6>, <0x88>, <0x91>

偵測 byte fallback 是否完整

let mut byte_fallback = [u32::MAX; 256];
let mut have_all = true;
for b in 0..=255u32 {
    let key = format!("<0x{:02X}>", b);
    if let Some(&id) = token_to_id.get(&key) {
        byte_fallback[b as usize] = id;
    } else {
        have_all = false;
    }
}
let byte_fallback = if have_all { Some(byte_fallback) } else { None };

只有 vocab 裡 256 個 byte token 全都存在才啟用 byte fallback。如果只有部分(例如缺幾個),就 disable 整個 byte fallback——因為 partial 的支援會讓 encode 不可預測。

Option<[u32; 256]> 是個漂亮的 Rust 表達:要嘛全有、要嘛全無,型別系統強制這個 invariant。

演算法核心:decode 的 byte 拼合

decode 比 encode 簡單,但有個微妙的細節——byte fallback token 必須 byte-level 拼回 UTF-8 codepoint

pub fn decode(&self, ids: &[u32]) -> String {
    let mut bytes: Vec<u8> = Vec::new();
    for &id in ids {
        let s = match self.tokens.get(id as usize) {
            Some(s) => s,
            None => continue,
        };
        if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
            if let Ok(b) = u8::from_str_radix(&s[3..5], 16) {
                bytes.push(b);
                continue;
            }
        }
        bytes.extend_from_slice(s.replace('\u{2581}', " ").as_bytes());
    }
    String::from_utf8_lossy(&bytes).into_owned()
}

關鍵是:先收集到 Vec<u8>,最後一次性轉 UTF-8

為什麼不能逐 token decode?因為一個 UTF-8 codepoint 可能跨多個 byte token。例如 "我" 是三個 byte token,如果你逐個 decode:

  • <0xE6>0xE6 不是合法 UTF-8 → 變成 ?
  • <0x88>0x88 不是合法 UTF-8 → 變成 ?
  • <0x91>0x91 不是合法 UTF-8 → 變成 ?

結果是 ??? 而不是 。先收集 byte、最後一次 from_utf8_lossy,三個 byte 才會被正確解析成一個中文字。

decode_piece 的 caveat

我也提供了一個 decode_piece(id) 單 token decode,但它有個 limitation——遇到 byte fallback 時只能 lossy emit。LLM 串流輸出時必須用 decode(&[id]) 或 buffer 起來,不能直接 decode_piece,否則會看到亂碼。

我的 main.rs 用的是 tokenizer.decode(&[next]),正確處理了這個——但這是個容易忘記的雷。理想的 API 應該是 Decoder struct,內部維護 byte buffer,每次餵 token 進來、適時 flush 完整的 UTF-8 codepoint。

Rust 用法:HashMap 的 entry pattern

雖然這個 tokenizer 沒用到,但是值得一提的 Rust 慣用法。我目前的初始化:

let mut token_to_id = HashMap::with_capacity(tokens.len());
for (i, t) in tokens.iter().enumerate() {
    token_to_id.insert(t.clone(), i as u32);
}

如果 vocab 有重複 token(理論上不應該),後面的會覆蓋前面的。如果想驗證沒有重複:

for (i, t) in tokens.iter().enumerate() {
    use std::collections::hash_map::Entry;
    match token_to_id.entry(t.clone()) {
        Entry::Vacant(v) => { v.insert(i as u32); }
        Entry::Occupied(_) => bail!("duplicate token at {i}"),
    }
}

entry API 是 Rust 處理 hash map 的標準慣用法——「我不知道 key 在不在,但我想根據它存在與否做不同事」。

Rust 用法:Option 的 chain

let bos = get_special(g, "tokenizer.ggml.bos_token_id").unwrap_or(1);

get_special 回傳 Option<u32>,如果 metadata 沒有這個 key 就 None。unwrap_or(1) 提供 default value——Llama 的 BOS token id 通常是 1,所以這是個合理的 fallback。

這是 Rust 對 null 的優雅處理——Option 強制你決定「沒有的時候怎麼辦」,而不是讓 NullPointerException 在 runtime 炸掉。

效能最佳化空間

1. encode 的 priority queue 優化

O(n^2) 改成 O(n log n)。對長 prompt 有顯著加速。但實作 priority queue + 處理 invalidation(合併後相鄰 pair 變了)有點複雜。tokenizers crate(HuggingFace 的 Rust tokenizer)有現成的優化版實作。

2. 預先建立 pair lookup table

每次 format!("{}{}", a, b) 都是字串拼接 + hash lookup。如果預先建立 HashMap<(u32, u32), u32>

let mut pair_map: HashMap<(u32, u32), u32> = HashMap::new();
for (id, token) in tokens.iter().enumerate() {
    // 嘗試把 token 拆成兩個現有 token 的拼接
    for split in 1..token.len() {
        if let (Some(&a), Some(&b)) = (
            token_to_id.get(&token[..split]),
            token_to_id.get(&token[split..]),
        ) {
            pair_map.insert((a, b), id as u32);
        }
    }
}

這樣 encode 時就是 (u32, u32) → u32 的查表,不必字串拼接。但建立這個表本身是 $O(\sum_t |t|)$,需要在啟動時花一點時間。

3. SIMD 字串搜尋(次要)

replace('▁', ' ') 用的是逐字元 scan。對長字串可以用 SIMD 加速,但 tokenizer 處理的字串通常很短,沒必要。

4. 避免 ids.remove(i+1) 的 O(n)

ids[i] = best_id;
ids.remove(i + 1);    // 每次都是 O(n)

Vec::remove 是 $O(n)$(要把後面的 element 全部往前 shift)。一個更高效的資料結構是 linked list(doubly linked),合併是 $O(1)$。但 Rust 的 LinkedList cache 不友善,反而更慢——這是經典的「理論最優和實際最優不同」案例。

實務上 prompt 長度幾百到幾千 tokens,$O(n^2)$ 加上 cache friendly 的 VecLinkedList 更快。

5. 批次處理 byte fallback

對長 unicode 文字,目前每個 char 都會做一次 hashmap lookup。可以先 group 連續的 byte fallback 序列、批次處理。但這個的收益有限,因為 hashmap lookup 本身就很快。

一個我曾經踩過的實際雷

我第一次寫這個 tokenizer 的時候,有個 case 一直編碼錯:英文 prompt 後面接中文。問題是我忘記在 byte fallback 時仍然參與後續的合併迴圈。當時的程式邏輯把 byte fallback 當成「終態」——一旦變成 byte token 就不再合併,但 SentencePiece 的訓練資料裡可能有把連續 byte 合併成更高層 token 的規則。

幸運的是,我的當前實作放對了——byte fallback 只是「初始化」,後面的合併迴圈會把它們和其他 token 一起參與合併。但這個雷讓我學到一件事:寫 tokenizer 一定要拿至少幾十個 prompt 對拍 llama.cpp,不能只用一兩個 happy path 測試。

總結:tokenizer.rs 的角色

  • 概念上:把字串和 token id list 互相轉換。
  • 演算法上:SentencePiece 的 greedy merge encode + byte fallback for OOV。
  • 微妙細節:byte fallback 的存在性檢查、decode 時的 UTF-8 重組。

下一篇講 sampler.rs,從 logits 怎麼選下一個 token。

系列文章: