Simply Patrick

tiny-llm-runner 深入解讀 (7):tokenizer.rs —— SentencePiece-BPE 與 Byte Fallback

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 forward pass 的編排,這一篇要看看「文字」是怎麼變成 LLM 看得懂的 token id 的:tokenizer.rs

很多人以為 tokenizer 就是「把字串切成 word」,但實際上現代 LLM 的 tokenizer 比這精緻得多——它是個learned algorithm,由訓練資料決定怎麼切。實作起來其實也只有 200 多行,但每一段都有得講。

概念一:什麼是 BPE?為什麼不直接用 word?

最早的 NLP 用 word tokenizer:["I", "love", "Rust"]。這有兩個問題:

  1. OOV(out of vocabulary):訓練時沒看過的 word(拼寫錯字、新詞)會變成 <unk>
  2. 詞彙爆炸:英文單字大概 60 萬個,加上人名、專業詞彙,vocabulary 動輒幾百萬。

BPE(Byte Pair Encoding) 解這個問題的方法是「從字元開始合併」:

  1. 初始化:每個字元是一個 token。
  2. 統計訓練資料裡哪一對相鄰 token 出現最多。
  3. 合併最常見的那一對成一個新 token。
  4. 重複直到達到目標 vocab size(典型 32k 或 128k)。

合併出來的 token 表會包含「字元」、「片段」、「常見字」、「常見片語」混合在一起,例如:

'a', 'b', ..., 'z',
'th', 'in', 'er', 're',
'the', 'and', 'ing', 'tion',
'▁the', '▁of', '▁to',
...

是 SentencePiece 用來標記 word boundary 的特殊字元(U+2581)。

概念二:SentencePiece 的 word boundary 處理

let prepared = format!("\u{2581}{}", text.replace(' ', "\u{2581}"));

SentencePiece 把空白統一替換成 ,並且在輸入最前面加一個 。為什麼?

因為這樣**「空格」就不再是特殊字元,而是 token 的一部分**。例如 "hello world" 變成 "▁hello▁world",tokenize 後可能是 ["▁hello", "▁world"]。decode 時把 換回空格就還原了原始字串。

這個設計的好處是 tokenization 變得 reversible——不會像 word tokenizer 那樣「I love Rust[I, love, Rust] → 不知道空格在哪」。

演算法核心:encode 的 greedy merge loop

loop {
    let mut best_score = f32::NEG_INFINITY;
    let mut best_idx: Option<usize> = None;
    let mut best_id: u32 = 0;
    for i in 0..ids.len().saturating_sub(1) {
        let merged = format!("{}{}",
            &self.tokens[ids[i] as usize],
            &self.tokens[ids[i + 1] as usize]
        );
        if let Some(&id) = self.token_to_id.get(&merged) {
            let s = self.scores[id as usize];
            if s > best_score {
                best_score = s;
                best_idx = Some(i);
                best_id = id;
            }
        }
    }
    match best_idx {
        Some(i) => {
            ids[i] = best_id;
            ids.remove(i + 1);
        }
        None => break,
    }
}

這是「最高分鄰接合併」的 SentencePiece 編碼演算法:

  1. 把輸入 split 成單字元 token 序列。
  2. 在所有相鄰 pair 中,找一個合併後是 vocab 裡的 token、且分數最高的。
  3. 合併它(用合併後的 id 取代第 i 個,刪掉第 i+1 個)。
  4. 重複直到沒有合併可以做。

分數是 GGUF metadata 裡 tokenizer.ggml.scores 提供的值——這是 SentencePiece 訓練時學到的「這個 token 多有用」的指標,越高越偏好。

複雜度分析

每輪 loop 是 $O(n)$(掃一次相鄰 pair),總共最多 n 輪(每輪至少縮一個 token),所以整個 encode 是 $O(n^2)$。對 1000 字元的 prompt 是 100 萬次 hashmap lookup——還好,hashmap 平均 $O(1)$。

更高效的做法是 priority queue(heap):每次取最高分的 pair $O(\log n)$,總共 $O(n \log n)$。但實作複雜,且 LLM 的 prompt 通常不會超過幾千字元,$O(n^2)$ 完全夠用。

format! 在 hot path 的成本

注意這個 format! 在每個 pair 都會 allocate 一個新 String。對長 prompt 來說這是萬次 allocation。如果要優化,可以:

  1. 預配一個 reusable buffer:let mut merged = String::with_capacity(64); ...; merged.clear(); merged.push_str(...);
  2. (u32, u32) → u32 的 lookup table(pair table),避開字串拼接。

但對 LLM runner 來說 tokenization 通常不是瓶頸——一次 encode 幾十毫秒,相對於 forward pass 幾秒可以忽略。

演算法核心:byte fallback —— 處理 vocab 外的字元

for ch in prepared.chars() {
    let s: String = ch.to_string();
    if let Some(&id) = self.token_to_id.get(&s) {
        ids.push(id);
    } else if let Some(bf) = &self.byte_fallback {
        for &byte in s.as_bytes() {
            let id = bf[byte as usize];
            ids.push(id);
        }
    }
}

如果某個字元不在 vocab 裡(例如中文、Emoji),就把它的 UTF-8 bytes 一個一個用 byte fallback token 編碼。Llama 的 vocab 裡有 256 個專門的 byte token,名字長這樣:

"<0x00>", "<0x01>", ..., "<0xFF>"

每個 byte token 對應一個 byte 值。例如 "我" 的 UTF-8 是 [0xE6, 0x88, 0x91],會被編碼成三個 byte token:<0xE6>, <0x88>, <0x91>

偵測 byte fallback 是否完整

let mut byte_fallback = [u32::MAX; 256];
let mut have_all = true;
for b in 0..=255u32 {
    let key = format!("<0x{:02X}>", b);
    if let Some(&id) = token_to_id.get(&key) {
        byte_fallback[b as usize] = id;
    } else {
        have_all = false;
    }
}
let byte_fallback = if have_all { Some(byte_fallback) } else { None };

只有 vocab 裡 256 個 byte token 全都存在才啟用 byte fallback。如果只有部分(例如缺幾個),就 disable 整個 byte fallback——因為 partial 的支援會讓 encode 不可預測。

Option<[u32; 256]> 是個漂亮的 Rust 表達:要嘛全有、要嘛全無,型別系統強制這個 invariant。

演算法核心:decode 的 byte 拼合

decode 比 encode 簡單,但有個微妙的細節——byte fallback token 必須 byte-level 拼回 UTF-8 codepoint

pub fn decode(&self, ids: &[u32]) -> String {
    let mut bytes: Vec<u8> = Vec::new();
    for &id in ids {
        let s = match self.tokens.get(id as usize) {
            Some(s) => s,
            None => continue,
        };
        if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
            if let Ok(b) = u8::from_str_radix(&s[3..5], 16) {
                bytes.push(b);
                continue;
            }
        }
        bytes.extend_from_slice(s.replace('\u{2581}', " ").as_bytes());
    }
    String::from_utf8_lossy(&bytes).into_owned()
}

關鍵是:先收集到 Vec<u8>,最後一次性轉 UTF-8

為什麼不能逐 token decode?因為一個 UTF-8 codepoint 可能跨多個 byte token。例如 "我" 是三個 byte token,如果你逐個 decode:

  • <0xE6>0xE6 不是合法 UTF-8 → 變成 ?
  • <0x88>0x88 不是合法 UTF-8 → 變成 ?
  • <0x91>0x91 不是合法 UTF-8 → 變成 ?

結果是 ??? 而不是 。先收集 byte、最後一次 from_utf8_lossy,三個 byte 才會被正確解析成一個中文字。

decode_piece 的 caveat

我也提供了一個 decode_piece(id) 單 token decode,但它有個 limitation——遇到 byte fallback 時只能 lossy emit。LLM 串流輸出時必須用 decode(&[id]) 或 buffer 起來,不能直接 decode_piece,否則會看到亂碼。

我的 main.rs 用的是 tokenizer.decode(&[next]),正確處理了這個——但這是個容易忘記的雷。理想的 API 應該是 Decoder struct,內部維護 byte buffer,每次餵 token 進來、適時 flush 完整的 UTF-8 codepoint。

Rust 用法:HashMap 的 entry pattern

雖然這個 tokenizer 沒用到,但是值得一提的 Rust 慣用法。我目前的初始化:

let mut token_to_id = HashMap::with_capacity(tokens.len());
for (i, t) in tokens.iter().enumerate() {
    token_to_id.insert(t.clone(), i as u32);
}

如果 vocab 有重複 token(理論上不應該),後面的會覆蓋前面的。如果想驗證沒有重複:

for (i, t) in tokens.iter().enumerate() {
    use std::collections::hash_map::Entry;
    match token_to_id.entry(t.clone()) {
        Entry::Vacant(v) => { v.insert(i as u32); }
        Entry::Occupied(_) => bail!("duplicate token at {i}"),
    }
}

entry API 是 Rust 處理 hash map 的標準慣用法——「我不知道 key 在不在,但我想根據它存在與否做不同事」。

Rust 用法:Option 的 chain

let bos = get_special(g, "tokenizer.ggml.bos_token_id").unwrap_or(1);

get_special 回傳 Option<u32>,如果 metadata 沒有這個 key 就 None。unwrap_or(1) 提供 default value——Llama 的 BOS token id 通常是 1,所以這是個合理的 fallback。

這是 Rust 對 null 的優雅處理——Option 強制你決定「沒有的時候怎麼辦」,而不是讓 NullPointerException 在 runtime 炸掉。

效能最佳化空間

1. encode 的 priority queue 優化

O(n^2) 改成 O(n log n)。對長 prompt 有顯著加速。但實作 priority queue + 處理 invalidation(合併後相鄰 pair 變了)有點複雜。tokenizers crate(HuggingFace 的 Rust tokenizer)有現成的優化版實作。

2. 預先建立 pair lookup table

每次 format!("{}{}", a, b) 都是字串拼接 + hash lookup。如果預先建立 HashMap<(u32, u32), u32>

let mut pair_map: HashMap<(u32, u32), u32> = HashMap::new();
for (id, token) in tokens.iter().enumerate() {
    // 嘗試把 token 拆成兩個現有 token 的拼接
    for split in 1..token.len() {
        if let (Some(&a), Some(&b)) = (
            token_to_id.get(&token[..split]),
            token_to_id.get(&token[split..]),
        ) {
            pair_map.insert((a, b), id as u32);
        }
    }
}

這樣 encode 時就是 (u32, u32) → u32 的查表,不必字串拼接。但建立這個表本身是 $O(\sum_t |t|)$,需要在啟動時花一點時間。

3. SIMD 字串搜尋(次要)

replace('▁', ' ') 用的是逐字元 scan。對長字串可以用 SIMD 加速,但 tokenizer 處理的字串通常很短,沒必要。

4. 避免 ids.remove(i+1) 的 O(n)

ids[i] = best_id;
ids.remove(i + 1);    // 每次都是 O(n)

Vec::remove 是 $O(n)$(要把後面的 element 全部往前 shift)。一個更高效的資料結構是 linked list(doubly linked),合併是 $O(1)$。但 Rust 的 LinkedList cache 不友善,反而更慢——這是經典的「理論最優和實際最優不同」案例。

實務上 prompt 長度幾百到幾千 tokens,$O(n^2)$ 加上 cache friendly 的 VecLinkedList 更快。

5. 批次處理 byte fallback

對長 unicode 文字,目前每個 char 都會做一次 hashmap lookup。可以先 group 連續的 byte fallback 序列、批次處理。但這個的收益有限,因為 hashmap lookup 本身就很快。

一個我曾經踩過的實際雷

我第一次寫這個 tokenizer 的時候,有個 case 一直編碼錯:英文 prompt 後面接中文。問題是我忘記在 byte fallback 時仍然參與後續的合併迴圈。當時的程式邏輯把 byte fallback 當成「終態」——一旦變成 byte token 就不再合併,但 SentencePiece 的訓練資料裡可能有把連續 byte 合併成更高層 token 的規則。

幸運的是,我的當前實作放對了——byte fallback 只是「初始化」,後面的合併迴圈會把它們和其他 token 一起參與合併。但這個雷讓我學到一件事:寫 tokenizer 一定要拿至少幾十個 prompt 對拍 llama.cpp,不能只用一兩個 happy path 測試。

總結:tokenizer.rs 的角色

  • 概念上:把字串和 token id list 互相轉換。
  • 演算法上:SentencePiece 的 greedy merge encode + byte fallback for OOV。
  • 微妙細節:byte fallback 的存在性檢查、decode 時的 UTF-8 重組。

下一篇講 sampler.rs,從 logits 怎麼選下一個 token。

系列文章:


tiny-llm-runner 深入解讀 (6):runner.rs —— Forward Pass 與 KV Cache 的編排

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了所有 Transformer 原語,這一篇要看的是把它們編排成一次完整 forward pass 的核心檔案:runner.rs

整個專案最值得讀的就是這個檔案——大約 170 行的 Rust,把 Llama 的整個推論流程鋪展開來,每一步都看得清清楚楚。

概念一:autoregressive 推論的本質

LLM 推論不是「一次處理整段文字」,而是「一次處理一個 token」:

prompt: "The capital of France"
        → token ids [464, 3139, 286, 4881]

prefill (處理 prompt):
  forward(464)  → logits     (我們不用,但要建 KV cache)
  forward(3139) → logits
  forward(286)  → logits
  forward(4881) → logits     ← 用這個 logits 抽下一個 token

decode (生成):
  sampler(logits) → " is"
  forward(" is")  → logits
  sampler(logits) → " Paris"
  forward(" Paris") → logits
  ... 一直到 EOS

每次 forward(token) 是一個 step。step 之間透過 KV cache 累積上下文。

概念二:KV Cache —— 為什麼 attention 要 cache 過去?

Attention 的數學是這樣的:

$$\text{attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

在 autoregressive 推論時,第 t 個 token 的 Q 只需要和前 t 個 token 的 K/V 做運算(因為未來的 token 還沒出現)。如果每次 forward 都重新計算所有過去的 K/V,那 prefill 階段是 $O(n^2)$,decode 階段每次都是 $O(n)$ 的重做——太浪費了。

KV cache 的作法:每算過一次某個 position 的 K/V 就存起來,下次直接用。這樣每次 forward 只需要算當前這個 token 的 Q/K/V,K 和 V 寫進 cache、Q 拿來和整個 cache 點積。

概念三:Context 和 KV Cache 的關係

很多人聽到「context length 2048」會直覺想成「模型有個地方存著最近 2048 個 token」。這不正確——模型實際上只存了 2048 個 token 被各層 attention 算出來的 K 和 V。Context 的這個容量限制就是 KV cache 的容量限制。

關係可以這樣理解:

context window = 「我能看見多遠的過去」的能力上限
KV cache       = 這個能力實際的儲存形式
n_ctx          = KV cache 在 position 維度上的大小

當你看到 n_ctx = 2048,背後真正配出來的記憶體就是:

kcache: [n_layer][n_ctx × kv_dim] f32
vcache: [n_layer][n_ctx × kv_dim] f32

對 TinyLlama (n_layer=22, kv_dim=256) 來說,這是 22 × 2 × 2048 × 256 × 4 = 92 MB。對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說,是 32 × 2 × 2048 × 4096 × 4 = 2 GB——KV cache 比模型本身的量化權重還大。

Token 一旦進入 cache,就再也找不回來了。Context window 滿了之後,要嘛截斷舊 token、要嘛丟掉新 token,沒有第三個選擇——因為原始 token id 早就被 attention 投影成 K/V 之後丟掉了。

概念四:為什麼不存 token 本身?為什麼不存 embedding?為什麼是 K/V?

這個問題的答案會徹底解釋 KV cache 的設計。讓我們從「最樸素」的方案開始往下思考:

方案 A:只存 token id

「我把過去的 token id 都存起來,下次推論時整段重跑就好。」

問題:這個方案下,每次 forward 都要從 token id 開始,重算 embedding、重算 22 層 transformer。對第 t 個 token 來說是 $O(t \times \text{layers} \times \text{matmul})$ 的工作,整個 prompt 的 prefill 是 $O(n^2)$,decode 累積下來也是 $O(n^2)$。完全沒有 cache 的意義

方案 B:存 embedding(每個 token 對應一個 n_embd 向量)

「那我存 token embedding,這樣可以跳過 embedding lookup。」

問題:embedding 只是 forward pass 的輸入。從 embedding 到「attention 真正用的東西」,還要走完整層的 RMSNorm + Q/K/V 投影。也就是說每層的 attention 看到的是「自己這一層的輸入經過 W_k 和 W_v 投影後的 K/V」,不是 token embedding。如果只存 token embedding:

  • 每次 forward 都要重新跑 22 層的 RMSNorm、重新算 K/V 投影。
  • 但這些工作的結果和 token 位置 t 無關——每次重算都會得到同一個 K_t 和 V_t。
  • 純粹的浪費。

更糟的是:第 1 層的輸入是 token embedding,但第 2 層的輸入是「第 1 層的輸出」。所以你不能用「token embedding」這個單一概念來代表「每一層 attention 需要的東西」。每層之間隔著 attention + FFN 的非線性轉換,「embedding」這個詞只在最開始那一層有意義。

方案 C:存每層的 hidden state(每層都有自己的 [n_layer][n_ctx][n_embd]

「那我存每層的輸出 hidden state?」

問題:你存了 hidden state,但 attention 真正需要的是 K = W_k(hidden_state)V = W_v(hidden_state)。每次 forward 還是要重算 K/V 投影(一個 [n_embd, kv_dim] 的 matvec)。白存了,因為投影沒省到

而且這個方案佔的記憶體比 K/V cache 大:hidden state 是 n_embd 寬,K/V 在 GQA 下只有 kv_dim = head_dim × n_head_kv 寬(TinyLlama 是 256,比 n_embd=2048 小 8 倍)。

方案 D:直接存 K 和 V(最終答案)

K 和 V 才是 attention 真正消費的東西。$\text{attn}(Q, K, V)$ 的公式裡沒有 hidden state、沒有 embedding、沒有 token id——只有 Q、K、V。把過去的 K 和 V 存起來,attention 就完整了

而且 K/V 有兩個迷人的性質:

  1. 它們和 Q 是解耦的:K_t 和 V_t 一旦算出來就和「未來會用什麼 Q 來查它」無關。所以可以放心 cache。
  2. 它們不需要被未來修改:因為 LLM 是 causal——位置 t 的 K/V 只被 position ≥ t 的 Q 查詢,但這些查詢不會回頭改 K_t/V_t。一旦寫進 cache 就是 immutable 的。

這就是為什麼 KV cache 是 transformer 推論的「最佳化終點」——它剛好存了 attention 需要的最少資訊,再少就會破壞語義,再多就是浪費。

從「資訊保留鏈」看這件事

換個角度:模型做 forward pass 是一條把「token id」逐步轉成「下一個 token 機率分佈」的資訊流。中途有很多中間表示:

token id → embedding → layer 1 hidden → layer 1 K/V → ...
                                     → layer 2 hidden → layer 2 K/V → ...
                                     ...
                                     → final hidden → logits

每一個中間表示都包含了關於這個 token 的某種資訊。對 attention 來說,這條鏈上唯一一個「跨 token 互動」的點就是它把 K/V 拿來算內積。其他每一步(RMSNorm、Q/K/V 投影、FFN…)都是 pure-position 的——只處理當前這個 token 的資料。

所以:

  • 那些「pure-position」的步驟不需要 cache,因為每個 token 的計算和其他 token 無關。
  • 那個「跨 token 互動」的步驟必須 cache,因為它要看到所有過去 token 的資訊。

而 attention 跨 token 看到的東西就是 K 和 V。所以這就是該被 cache 的東西。

一個簡單的比喻

把 LLM 比喻成一個圖書館:

  • Token id 是書的索書號。
  • Embedding 是書的封面(提供入口)。
  • 每層 hidden state 是書頁的內容(給讀者看)。
  • K 和 V 是給其他書「互相引用」用的索引卡。

Attention 是書與書之間的對話。書本身(hidden state)很厚,但對話只透過索引卡(K/V)進行。把對話用過的卡片留在桌上,下次新書進來就可以直接和它們對話——不需要把整本書重新搬出來看。

概念五:為什麼 KV cache 必須是 per-layer,而不只是 per-token?

注意 Runner 的 cache 是 Vec<Vec<f32>>第一個維度是 layer,第二個才是 token position

kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,

對 22 層的 TinyLlama 來說,這是 22 份獨立的 cache。為什麼不能共用同一份?為什麼一個 token 不只對應一組 K/V,而是對應 n_layer 組?

因為「同一個 token」在每一層的 K/V 是完全不同的東西

關鍵在 forward pass 的結構。回頭看看每一層 attention 在算什麼:

layer 1: x¹  = embedding(token)
         k¹  = W_k¹(rmsnorm(x¹))     ← 第 1 層的 K
         v¹  = W_v¹(rmsnorm(x¹))     ← 第 1 層的 V
         x²  = x¹ + attn(...) + ffn(...)

layer 2: k²  = W_k²(rmsnorm(x²))     ← 第 2 層的 K(input 不同、權重不同)
         v²  = W_v²(rmsnorm(x²))
         x³  = x² + attn(...) + ffn(...)

layer 3: k³  = W_k³(rmsnorm(x³))     ← 第 3 層的 K
         ...

每一層的 K/V 同時受兩件事決定:

  1. 這一層的 input hidden state —— 上一層 attention + FFN 完成後傳下來的東西,每層都不同。
  2. 這一層自己的權重 W_k^l / W_v^l —— 每層獨立訓練的不同矩陣。

也就是說 layer 1 的 K 是「token 在剛 embed 時的樣子被 W_k¹ 看出來的特徵」,layer 5 的 K 是「token 經過 4 層 attention + FFN 整合上下文後的樣子被 W_k⁵ 看出來的特徵」。這兩個 K 既不是同一個東西、也不能互相替代

換個角度:每一層都有自己的 attention

Transformer 的設計本身就讓每層 attention 是獨立的計算單元。每層問「我這層的 query 和我這層 cache 裡的 key 像不像」,答案決定我這層怎麼把 V 加總起來。底下幾層通常學「文法、近距離 token 關係」,中層學「短語結構」,高層學「語意、長距離依賴」——這些功能只有在它們各自的 K/V 空間裡才有意義。

如果你硬要說「我只想存一份 K/V 給所有層用」,那等於是在說「所有層都做同一件事」——這就退化成一個非常淺的模型了。Transformer 的深度價值就在於每層做不同的轉換,而每層的 K/V 是這個轉換的 fingerprint。

從層之間的依賴鏈看為什麼不能省

更技術一點:layer l 的 K/V 是 layer l-1 的輸出的函式。整個 forward pass 是這樣的串聯:

        x¹              x²              x³
embed → ─→ attn¹+ffn¹ ─→ ─→ attn²+ffn² ─→ ─→ ...
        │                │                │
        └─→ k¹, v¹       └─→ k², v²       └─→ k³, v³

每一層的 K/V 都「只能在這一層用」——下一層的 attention 不會去看 k¹,因為它要看的是經過這一層 transform 過的世界。所以這 22 份 K/V 是 22 個獨立的記憶體,不是一份共用的。

一個 token 在 cache 裡到底佔多少?

把上面所有事情串起來,一個 token 進入 cache 後實際佔的記憶體是:

single token → n_layer × 2 × kv_dim × sizeof(f32)
              ↑          ↑   ↑
              層數       K + V

對 TinyLlama (n_layer=22, kv_dim=256, f32) 來說:

22 × 2 × 256 × 4 = 45 KB per token

對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說:

32 × 2 × 4096 × 4 = 1 MB per token

這就是為什麼長 context 推論這麼吃記憶體——不只是 token 數量乘上一個小的數字,而是 token 數量乘上 n_layer × 2 × kv_dim。一個 8K context 的 7B 模型光 KV cache 就要 8 GB,比模型本身還大。

如果 KV cache 可以「per-token only」(共用所有層),同樣的 8K context 只需要 32 MB——但那樣的模型已經不是 Transformer 了。這個「per-layer」的代價,就是 Transformer 之所以是 Transformer 的代價

「per-layer × per-token」 是 cache 的最小完備形式

可以這樣總結:要讓 attention 在每一層都能正確算出來,你需要的最少資訊是:

維度 為什麼需要
per-layer 每層 attention 看的是不同的轉換空間,不能共用
per-token 每個 token 在每層的特徵都不同(causal mask 之外都會被未來查到)
K + V attention 公式裡的兩個輸入(Q 不需要 cache,因為它只在當前 step 用一次)

把這三個維度切片乘起來,就是 [n_layer][n_ctx][2][kv_dim] 的 4D 張量——這就是 KV cache 的最小完備形狀。再省任何一個維度都會破壞 attention 的語義;多存任何一個維度都是浪費(因為其他資訊都可以重新計算)。

回到我的 Rust 程式碼:

kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]

外層 Vec 是 layer 維度、內層的扁平陣列裡 n_ctx × kv_dim 是「token position × K/V 寬度」。這兩個 Vec<Vec<f32>> 加起來就是 attention 完整需要的最小狀態,一塊不多、一塊不少。

Runner struct 的記憶體佈局

pub struct Runner<'a, 'm> {
    model: &'m LlamaModel<'a>,
    rope_style: RopeStyle,
    kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx * kv_dim]
    vcache: Vec<Vec<f32>>,
    pub pos: usize,

    // 預先配好的 scratch buffer
    x: Vec<f32>, xb: Vec<f32>, xb2: Vec<f32>,
    hb: Vec<f32>, hb2: Vec<f32>,
    q: Vec<f32>, att: Vec<f32>,
    logits: Vec<f32>,
}

幾個重要的設計決定:

兩個生命週期 'a'm

  • 'a 是 mmap 的生命週期。
  • 'mLlamaModel 的生命週期,且必須 'm: 'a(model 不能比 mmap 活得更久)。

Runner 借用 &'m LlamaModel<'a>,自己沒有持有任何模型資料,所以這個結構是 cheap to construct——所有大型 buffer 都是 scratch 用途。

KV cache 是 Vec<Vec<f32>> 而不是 Vec<f32>

每層一個獨立的 Vec,而不是把所有層拼在同一個 Vec 裡。這是為了:

  1. 生命週期分離:不同層的 cache 可以獨立管理(雖然目前我沒這麼做)。
  2. 記憶體大小靈活:理論上不同層可以有不同的 KV dim(例如某些 MoE 變體),雖然 Llama 沒有這個需求。

代價是多一層間接(access 要走 kcache[l][...])。對 LLM 推論來說 negligible。

一堆 scratch buffer 一次配好

x: vec![0.0; n_embd],     // 主 hidden state
xb: vec![0.0; n_embd],    // 暫存 a (attention/FFN 內部)
xb2: vec![0.0; n_embd],   // 暫存 b (殘差用)
hb: vec![0.0; n_ff],      // FFN gate 的中間結果
hb2: vec![0.0; n_ff],     // FFN up 的中間結果
q: vec![0.0; n_embd],     // 當前 token 的 Q
att: vec![0.0; n_head * n_ctx],  // attention scores
logits: vec![0.0; vocab],  // 輸出 logits

這些 buffer 在 new() 一次配好,整個推論過程不再 allocate。每次 forward() 進來都直接覆寫這些 buffer。沒有 GC、沒有 alloc 壓力。

對比一下 PyTorch/Candle 的做法:每個 op 回傳新 Tensor,靠 reference counting 回收。對訓練很合理(autograd 需要保留中間值),但對只做 inference 的 runner 來說,預配 + 覆寫是更精簡、更可預測的做法。

演算法核心:forward 的整體結構

forward(token) 的結構是這樣的:

flowchart TD A[token id] --> B[1. Embed → x] B --> C{2. 22 層 Transformer Block} C --> C1[2a. attn_norm: xb = norm·x] C1 --> C2[2b. Q/K/V projection
K/V 直接寫進 cache] C2 --> C3[2c. RoPE on Q & current K] C3 --> C4[2d. multi-head attention
Q · K^T / softmax / · V] C4 --> C5[2e. wo projection + 殘差] C5 --> C6[2f. ffn_norm + SwiGLU + 殘差] C6 --> C C --> D[3. final norm] D --> E[4. lm_head: logits]

每個 block 的內部精細化看起來會更清楚:

// attention norm
rmsnorm(&mut self.xb, &self.x, &layer.attn_norm, cfg.rms_eps);

// qkv projections
matvec(&mut self.q, &layer.wq, &self.xb);
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);    // K 直接寫進 cache
matvec(vrow, &layer.wv, &self.xb);    // V 也是

// RoPE
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);

// attention
for h in 0..cfg.n_head {
    let kv_head = h / gqa;          // GQA 對應
    // 算 attention scores
    // softmax
    // 加權 V
}

// 輸出投影 + 殘差
matvec(&mut self.xb2, &layer.wo, &self.xb);
add_inplace(&mut self.x, &self.xb2);

// FFN: x = x + Wdown(silu(Wgate(norm)) * Wup(norm))
rmsnorm(&mut self.xb, &self.x, &layer.ffn_norm, cfg.rms_eps);
matvec(&mut self.hb, &layer.w_gate, &self.xb);
matvec(&mut self.hb2, &layer.w_up, &self.xb);
for i in 0..self.hb.len() {
    self.hb[i] = silu(self.hb[i]) * self.hb2[i];
}
matvec(&mut self.xb2, &layer.w_down, &self.hb);
add_inplace(&mut self.x, &self.xb2);

演算法核心一:K/V 直接寫進 cache

let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);
matvec(vrow, &layer.wv, &self.xb);

這四行有個非常重要的設計選擇:K/V 計算的輸出 buffer 直接是 cache 的某一行,而不是先算到 scratch buffer 再 copy 進 cache。這省下了一次 n_embd × 4 bytes 的記憶體拷貝。

matvec 的簽章 fn matvec(out: &mut [f32], ...) 接受任何 mutable slice,cache 的 slice 完全可以塞進去。Rust 的借用檢查器會驗證 kcxbq 等其他 buffer 不會 alias。

演算法核心二:RoPE 的兩階段套用

apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);

注意這裡 RoPE 是套在 Q 和當前的 K 上,不是套在 cache 裡所有 K 上。為什麼?

因為 RoPE 是「絕對位置」資訊(每個 K_t 套用 t 對應的 RoPE),而每個 K_t 在它被計算的那個 forward call 裡套用一次就夠了。等到後面要做 attention 時,cache 裡的 K_t 已經帶著 RoPE,不必再做。

V 不需要 RoPE,因為 attention 對 V 的處理是線性 weighted sum,position 資訊在 Q·K 那裡就已經注入。

演算法核心三:GQA 的 head 對應

for h in 0..cfg.n_head {
    let kv_head = h / gqa;          // GQA 對應
    let q_off = h * head_dim;
    let q = &self.q[q_off..q_off + head_dim];

    let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];
    for (t, score) in att.iter_mut().enumerate() {
        let k_off = t * kv_dim + kv_head * head_dim;
        let k = &self.kcache[l][k_off..k_off + head_dim];
        let mut s = 0.0f32;
        for i in 0..head_dim {
            s += q[i] * k[i];
        }
        *score = s * scale;
    }
    softmax(att);
    // ... 加權 V
}

GQA 的 trick:在 vanilla MHA(multi-head attention)裡,每個 query head 對應一個獨立的 K/V head。但 GQA 把 query head 分組,每組共用同一個 K/V head:

n_head = 32, n_head_kv = 4 → gqa = 8
query head 0..7   → K/V head 0
query head 8..15  → K/V head 1
query head 16..23 → K/V head 2
query head 24..31 → K/V head 3

kv_head = h / gqa 就是這個對應。這個技巧把 KV cache 的大小從 n_head × head_dim × n_ctx 縮成 n_head_kv × head_dim × n_ctx,對長 context 推論很重要——KV cache 是長 context 推論的記憶體大頭。

演算法核心四:attention scores 的記憶體佈局

let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];

att buffer 是 [n_head, n_ctx] 的展平。對 head h 的 attention scores 佔 att[h*n_ctx .. h*n_ctx + n_ctx]。但每次 forward 我們只用前 pos+1 個(因為只有過去到現在的 token 有 score)。

這個設計的好處是:buffer 是固定大小(為 worst case n_ctx 配好),不需要 dynamic resize。代價是有些空間沒用到——對 TinyLlama 來說 n_head * n_ctx * 4 = 32 * 2048 * 4 = 256 KB,可以接受。

Rust 用法:借用檢查器與切片

這個 forward pass 裡有大量 &mut 切片操作:

let kc = &mut self.kcache[l];
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);

注意我必須先取 &mut self.kcache[l] 存進局部變數 kc,再對 kc 切片。為什麼?因為 Rust 的借用檢查器需要看到「我們對 self.kcache 的 mutable 借用」是清晰的。

這是 Rust 寫多 mutable buffer 程式碼最常踩的坑之一。如果你寫:

matvec(&mut self.kcache[l][pos*kv_dim..(pos+1)*kv_dim], &layer.wk, &self.xb);

而其他地方還有 &self.x&self.xb 之類的不可變借用,編譯器可能會抗議「不能同時 mutable 和 immutable 借用 self」。先把 mutable slice 提取出來,再做後續操作,是讓借用範圍縮小的標準技巧。

split_at_mut 的那種招數

更進階的場合可能要 split_at_mut——例如同時拿 K cache 和 V cache 的 mutable ref:

let (kc, vc) = (&mut self.kcache[l], &mut self.vcache[l]);
// 這個寫法借用檢查器會接受,因為 self.kcache 和 self.vcache 是不同的 field

kcache[l]vcache[l] 是不同的 field,所以可以同時 mutable borrow。如果是同一個 Vec 的兩段切片,就需要 split_at_mut

let (left, right) = my_vec.split_at_mut(mid);
// left 和 right 都是 &mut [T],但編譯器知道它們不重疊

效能最佳化空間

runner.rs 是整個專案的效能熱點集中地,有大量改進空間:

1. Attention 的 head-parallel

我目前的 attention loop 是 sequential for h in 0..n_head

for h in 0..cfg.n_head {
    // 算 attention scores、softmax、加權 V → 寫進 self.xb 的某段
}

每個 head 寫進 self.xb 的不同 slice,是 disjoint 的。可以用 rayonchunks_exact_mut 平行:

self.xb.par_chunks_exact_mut(head_dim).enumerate().for_each(|(h, out)| {
    // 算這個 head 的 attention
});

att buffer 是共享的(self.att),平行版需要每個 thread 獨立的 attention scratch。head 數量通常很小(32-64),fork-join 的 overhead 可能吃掉收益——這個值得 benchmark 但不一定划算。

2. FlashAttention 風格的 fused attention

我目前的 attention 是「先算所有 score、再 softmax、再加權 V」三步。FlashAttention 把它們 fuse 成一個 streaming pass,記憶體占用從 O(n_ctx) 降到 O(head_dim),且 cache 友善。但實作複雜得多——這是 candle/ggml 級別的優化。

3. KV cache 的量化

KV cache 是長 context 的記憶體大頭。對 7B 模型、2k context、F32 cache,每層 cache 是 2 * 2048 * 1024 * 4 = 16 MB,32 層 = 512 MB。如果把 cache 也量化成 Q8(1 byte/element),可以降到 128 MB。llama.cpp 已經支援這個(-ctk q8_0 -ctv q8_0)。

4. Prefill batching

目前 prefill 是 sequential——一個 token 一個 forward。但 prompt 的所有 token 都已知,可以拼成一個矩陣做 batched forward:

seq forward: O(L * (matvec for each layer))
batch forward: O(matmul of [L, n_embd] for each layer)

GEMM 的 cache locality 比 GEMV 好得多,prefill 速度可以提升 5-10×。但要重新設計 attention(causal mask 變成必須的)、要修改 KV cache 寫入方式(一次寫 L 個 token)。這是中量級的改動。

5. Speculative decoding

更進階的優化:用一個小的 “draft model” 一次猜 K 個 token,用主模型 verify。如果猜對 K 個就只花了一次 forward 的時間做 K 個 token。這是 inference latency 的革命性技術,但需要兩個模型協同。

6. Continuous batching

如果 runner 要 serve 多個請求,可以做 continuous batching——不同請求的 token 進入同一個 batch,動態加減。但這要求 KV cache 是「per-request」的,記憶體管理更複雜。vLLM 是這條路徑的代表。

一個有趣的權衡:為什麼不寫得更 functional?

我可以把 forward 寫得更 functional,例如把每個 layer 寫成一個閉包、用 fold 串起來。但目前這個 indicative、命令式的寫法反而更容易讀——你能一眼看出每一步在改哪個 buffer、用哪個權重。

LLM 推論的程式碼有個特殊性:90% 的時間在做矩陣運算,10% 在做 control flow。寫得太 functional 會把 control flow 的成本顯露在前景,反而蓋住了真正重要的計算。

總結:runner.rs 的角色

  • 概念上:把 token id 映射到 logits 的單一函式,內部管理 KV cache 和層級遞進。
  • 實作上:fixed-size scratch buffers + sequential layer loop + manual KV slice manipulation。
  • 設計上:把所有複雜性集中在這個檔案裡,其他檔案(opstensor)保持簡單。

下一篇講 tokenizer.rs,把 byte 流變成 token id 的細節。

系列文章:


tiny-llm-runner 深入解讀 (5):ops.rs —— Transformer 的六種樂高積木

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 LlamaModel 怎麼把張量組起來,這一篇要看的是 Transformer 真正的「動詞」們:ops.rs

這個檔案只有 100 行,但它就是整個 Transformer 的所有原語。一個現代 LLM 的所有計算,最終都會被分解成這六種運算的組合。

樂高積木一:RMSNorm

pub fn rmsnorm(out: &mut [f32], x: &[f32], w: &[f32], eps: f32) {
    let mut ss = 0.0f64;
    for &v in x {
        ss += v as f64 * v as f64;
    }
    ss /= x.len() as f64;
    ss += eps as f64;
    let scale = (1.0 / ss.sqrt()) as f32;
    for i in 0..x.len() {
        out[i] = w[i] * (x[i] * scale);
    }
}

算法

$$\text{out}_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}}$$

也就是「先把 x 除以它自己的 RMS,再用 w 逐元素 scale」。RMSNorm 是 LayerNorm 的簡化版(沒有 mean-shift、沒有 bias),Llama 用它取代 LayerNorm 是 2019-2020 年的設計趨勢。

為什麼用 f64 累加?

注意 ssf64,不是 f32。這不是隨便寫的:累加大量 f32 值容易產生「精度吞噬」(catastrophic cancellation)。一個 4096 維的 hidden state,每個 f32 大約 1e-1 量級,平方後是 1e-2,4096 個累加起來大概 40-100。在這個量級下,每個新加進來的 1e-2 增量在 f32 上會有相對誤差 ~1e-7 × 100 = 1e-5,4096 次累積下來會放大。

llama.cpp 在這條路徑用了 f64 累加器,我必須跟它對齊——否則 RMSNorm 的輸出會和它差千分之一,雖然每層只差千分之一,但 22 層累積下來輸出 token 就會不一樣。這是「對齊既有實作」的另一個例子

樂高積木二:matvec —— 用 rayon 平行化

pub fn matvec(out: &mut [f32], w: &TensorView<'_>, x: &[f32]) {
    out.par_iter_mut().enumerate().for_each(|(i, o)| {
        *o = w.dot_row(i, x);
    });
}

這是整個專案的效能瓶頸所在,也是最簡潔的一段程式碼。讓我們拆解這幾行做了什麼:

  1. par_iter_mut() 來自 rayon,把 &mut [f32] 變成一個平行 iterator。
  2. enumerate() 加上 row index。
  3. for_each(|(i, o)| ...) 對每個 row 平行執行 closure。

每個 row 是獨立的(out[i] 只被當前 closure 寫入),所以不需要任何鎖rayon 的 work-stealing 排程器會自動把 row 切成 chunk 分給可用的執行緒。

為什麼 row-parallel 而不是 element-parallel?

如果改成 element-parallel(每個輸出元素一個 task),task 太細會被 rayon 的 task overhead 吃掉。row-parallel 的 granularity 剛好——每個 task 是一個完整的 dot product(4096 個乘加),夠粗到值得啟動執行緒,又夠細到能平均負載。

Rust 的 par_iter_mut 設計上避免了一個很容易踩的雷:你不能讓兩個 thread 同時寫 out 的同一個位置。par_iter_mut 的型別簽章保證每個 closure 拿到的是 &mut f32(單一元素的 mutable ref),編譯期就排除了 race condition。

樂高積木三:softmax —— max-shift 防止 overflow

pub fn softmax(x: &mut [f32]) {
    let mut max = f32::NEG_INFINITY;
    for &v in x.iter() {
        if v > max { max = v; }
    }
    let mut sum = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    let inv = 1.0 / sum;
    for v in x.iter_mut() {
        *v *= inv;
    }
}

數學上,softmax 是:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

e^{x} 對大 x 會 overflow(f32 的 exp 在 x > ~88 就溢位)。標準做法是先減 max:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

數學上等價(分子分母同乘 $e^{-\max(x)}$),但所有 exp 的輸入都在 $(-\infty, 0]$,永遠不會 overflow。這是寫 softmax 的不可省的步驟——少了它,遇到極端 logits 就會 NaN。

樂高積木四:RoPE —— 相對位置的旋轉

這是最微妙的一個運算。RoPE(Rotary Position Embedding)是 Llama 用來把位置資訊塞進 attention 的方法,數學上是把每個 head 的 Q/K 向量看成「複數對」並旋轉一個和位置相關的角度。

pub fn apply_rope(
    vec: &mut [f32], pos: usize, head_dim: usize,
    rot_dim: usize, base: f32, style: RopeStyle,
) {
    let half = rot_dim / 2;
    let n_heads = vec.len() / head_dim;
    for h in 0..n_heads {
        let off = h * head_dim;
        for i in 0..half {
            let freq = base.powf(-((2 * i) as f32) / rot_dim as f32);
            let theta = pos as f32 * freq;
            let (sin, cos) = theta.sin_cos();
            let (i0, i1) = match style {
                RopeStyle::Llama => (2 * i, 2 * i + 1),
                RopeStyle::Neox  => (i, i + half),
            };
            let v0 = vec[off + i0];
            let v1 = vec[off + i1];
            vec[off + i0] = v0 * cos - v1 * sin;
            vec[off + i1] = v0 * sin + v1 * cos;
        }
    }
}

核心想法:每對 (v[i0], v[i1]) 是一個 2D 向量,被旋轉一個角度 $\theta = \text{pos} \cdot \text{freq}_i$。不同的 i 用不同的 freq——i 越小 freq 越大(位置感越強)、i 越大 freq 越小(接近不變)。

頻率公式:

$$\text{freq}_i = \text{base}^{-2i / \text{rot\_dim}}$$

base 通常是 10000,rot_dim 是 head_dim(或某個比例)。i = 0 時 freq = 1(每移動一個 position 旋轉 1 弧度),i = rot_dim/2 - 1 時 freq ≈ 1/10000(要 6000 個 position 才旋轉 1 弧度)。

這個多尺度的旋轉讓 attention 既能感知近距離的細微位置差異、也能感知遠距離的大致位置。

為什麼有兩種 style?

let (i0, i1) = match style {
    RopeStyle::Llama => (2 * i, 2 * i + 1),         // 鄰接成對
    RopeStyle::Neox  => (i, i + half),              // 上下半成對
};

這兩個是完全不同的「複數對」配對方式

  • Llama:把 v[0], v[1] 當一對、v[2], v[3] 當一對……(鄰接 pair)
  • Neox:把 v[0], v[half] 當一對、v[1], v[half+1] 當一對……(上下 split)

兩者旋轉的數學是一樣的,但配對方式不同。為什麼兩者並存?

歷史上,HuggingFace 的 transformers 採用 NeoX 風格。但 llama.cpp 早期的 convert.py 在轉檔時會把 Q/K 矩陣的每對 row permute 成鄰接格式,這樣就可以用 Llama 風格做 RoPE,記憶體 access 更線性。

新版的 convert_hf_to_gguf.py 不做 permute,所以你必須用 NeoX 風格。搞錯了不會 crash 但模型會講胡話——這是我親自踩過的雷。

sin_cos()f32::sin_cos()——同時算 sin 和 cos,硬體上比分別算快一倍。

樂高積木五:SiLU —— SwiGLU 的活化函數

#[inline]
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

數學定義:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

$\sigma$ 是 sigmoid。SiLU 也叫 Swish。Llama 的 FFN 用的是 SwiGLU(Swish-Gated Linear Unit):

$$\text{FFN}(x) = W_{down}(\text{SiLU}(W_{gate}(x)) \odot W_{up}(x))$$

⊙ 是逐元素相乘。也就是先把 x 過 gate 和 up 兩個投影,gate 過 SiLU 之後逐元素乘上 up,再過 down 投影。比起傳統 ReLU FFN,SwiGLU 多了一個 gate 投影,但實驗顯示效果好得多。

#[inline] 是給編譯器的提示,因為這個函式會在 inner loop 被呼叫成千上萬次,inline 進去能避免函式呼叫的 overhead。

樂高積木六:add_inplace —— 殘差連接

pub fn add_inplace(a: &mut [f32], b: &[f32]) {
    for (x, y) in a.iter_mut().zip(b.iter()) {
        *x += *y;
    }
}

最樸素的一個原語:把 b 加進 a。Transformer 的每個 sublayer 都會做:

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

那兩個加號就是 add_inplace

注意我用 iter_mut().zip(iter()) 而不是用 index。這個寫法的好處是:

  1. 沒有越界檢查——iterator 自動知道何時停止。
  2. 編譯器更容易自動向量化——iterator 模式對 LLVM 友善。
  3. 零成本抽象——這個寫法和手寫 raw loop 編譯出來幾乎一樣的組合語言。

Rust 用法:#[derive] 的小幫手

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeStyle {
    Llama,
    Neox,
}

Copy 讓我可以把 RopeStyle 當值傳遞,不需要 &Eq 讓我可以在 match 比較。Debugeprintln!("{:?}", style) 可以印出來除錯。不到一秒的工作,省下 10 行 boilerplate——這是 Rust derive 機制的常規用法。

效能最佳化空間

ops.rs 是另一個重要的最佳化戰場。每個函式都有改進空間:

1. RMSNorm 的 SIMD

那個 for &v in x { ss += v * v; } 是個 reduction,可以用 SIMD 平行化:

use std::simd::f32x8;
let mut acc = f32x8::splat(0.0);
for chunk in x.chunks_exact(8) {
    let v = f32x8::from_slice(chunk);
    acc += v * v;
}
let ss = acc.reduce_sum() as f64;

f64 vs f32 累加器的精度問題在 SIMD 環境下變得有趣——AVX-512 有 f64 SIMD,AVX2 沒有。最務實的做法是維持 f32 SIMD reduction,因為實作簡單且誤差可接受(畢竟 4096 個 ~1e-2 量級的數,誤差不會大到讓 token 不一致)。

2. matvec 的 GEMV → GEMM 升級

我目前的 matvec 是 GEMV(matrix × vector),每次 forward pass 對每層做兩次(attention 的 Q + FFN 的 down)。但prefill 階段的 prompt 是序列 token,可以批次處理——把 N 個 token 的 hidden state 拼成 [seq_len, n_embd] 矩陣,做一次 GEMM。

GEMM 比 GEMV 快很多(典型 5-10×),因為它能 reuse 權重(每次 load 一個 row 的權重,可以對 seq_len 個輸入做計算)。但這需要重新設計 Runner——目前是單 token forward,要改成 batch forward。

3. softmax 的 SIMD exp

(*v - max).exp() 是逐元素的,可以 SIMD 化。但 SIMD exp 比較麻煩——需要用多項式逼近(典型用 Cephes 或 Padé approximant 的 SIMD 版本)。sleef crate 有現成的 SIMD math 函式,可以接上。

4. RoPE 的 sin/cos 預計算

每次 forward pass 都呼叫 theta.sin_cos()。但 freq_i 只和 i 有關,theta = pos * freq_i 也可以分解。如果預計算所有可能的 (pos, i) 的 sin/cos 並存成 table,runtime 只需要查表:

let cos_table: Vec<Vec<f32>> = (0..n_ctx).map(|pos| {
    (0..rot_dim/2).map(|i| {
        let freq = base.powf(-((2*i) as f32) / rot_dim as f32);
        (pos as f32 * freq).cos()
    }).collect()
}).collect();

對 TinyLlama 是 n_ctx × rot_dim/2 × 4 bytes = 2048 × 32 × 4 = 256 KB 的表,啟動時花一次性的時間算好,runtime 都是查表。L2 cache 就能裝下。

5. add_inplace 也可以 SIMD

雖然這是個簡單的逐元素加,但 LLVM 對 iter_mut().zip() 模式的 auto-vectorization 不一定總是會發生。明確的 f32x8::from_slice(...) + f32x8::from_slice(...) 能保證 SIMD 化。

6. 把 forward pass 內整層 fuse 起來

最大的最佳化潛力是 kernel fusion。例如 attention + softmax 可以 fuse(FlashAttention 就是這個思路);rmsnorm + matvec 也可以 fuse 成一個 pass。但 fusion 會大幅複雜化程式碼,違背 tiny-llm-runner 的「易讀」目標——所以這是 candle/ggml 的世界,不是這裡。

一個 Rust 特有的優勢:debug 與 release 的 debug_assert

我每個函式裡都有 debug_assert_eq!(out.len(), x.len()) 之類的。Rust 的 debug_assert! 在 release 會被完全消除——既能在 dev 抓 bug,又不付出 runtime 成本。對 hot path 來說這是個重要的工具。

C/C++ 也有 assert + NDEBUG,但 Rust 把這個習慣做成了語言預設——你不需要記得 #define NDEBUGcargo build --release 自動處理。

總結:ops.rs 的角色

  • 概念上:Transformer 的所有原語(norm、matmul、softmax、RoPE、activation、add)。
  • 實作上:純函式、輸出參數、明確簽章、零隱藏狀態。
  • 設計上:把派發(量化型別 match)藏在 TensorView::dot_row 裡,ops.rs 自己只看到 &[f32],乾乾淨淨。

下一篇講 runner.rs,那是把所有 ops 編排起來的指揮中心。

系列文章:


tiny-llm-runner 深入解讀 (4):model.rs —— 把扁平張量組成 Llama 結構

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 TensorView 那層薄殼,這一篇要看的是 model.rs:把這些 view 組合成一個有結構、有層級LlamaModel

GGUF 把所有張量存成一個扁平 list,每個張量靠字串名字識別。但 Llama 是一個層級結構:每一層有 9 個張量、整個模型有 N 層加上幾個全域張量。model.rs 的工作就是這個 flat list → tree 的對應。

概念一:Llama 的張量命名規則

打開任何一個 Llama 架構的 GGUF,你會看到這些張量名字:

token_embd.weight              # token embedding 表
output_norm.weight             # 最後的 RMSNorm
output.weight                  # lm_head(有時不存在 = tied embeddings)

blk.0.attn_norm.weight         # 第 0 層的 attention 前 RMSNorm
blk.0.attn_q.weight            # Q 投影
blk.0.attn_k.weight            # K 投影
blk.0.attn_v.weight            # V 投影
blk.0.attn_output.weight       # output 投影
blk.0.ffn_norm.weight          # FFN 前 RMSNorm
blk.0.ffn_gate.weight          # SwiGLU 的 gate
blk.0.ffn_up.weight            # SwiGLU 的 up
blk.0.ffn_down.weight          # SwiGLU 的 down
blk.1.attn_norm.weight
... (重複 N 次)

這個命名約定是 llama.cpp 訂的。每一層有 9 個張量,前綴是 blk.{layer_id}.,後綴對應 attention 或 FFN 的某個元件。model.rs 的工作就是按這個規則去 GGUF 裡找。

概念二:n_ff、n_layer 與 Transformer block 的關係

從上面的張量命名可以看到一個很整齊的結構:每一個 blk.{l} 是一個完整的 Transformer block,整個模型就是 n_layer 個這種 block 疊起來。但每個 block 內部到底是什麼?以及 n_ff 在這裡扮演什麼角色?

一個 Transformer block 的兩個子層

每個 block 內部有兩個 sublayer,依序執行:

  1. Attention sublayerattn_norm → Q/K/V 投影 → RoPE → attention → attn_output 投影 → 殘差。
  2. FFN sublayer (SwiGLU)ffn_normffn_gateffn_up 兩個並列投影 → silu(gate) * upffn_down 投影 → 殘差。

attention 負責「跨 token 的資訊互換」,FFN 負責「在每個 token 內做非線性變換」。實證上這兩件事都很重要——你不能只有 attention(會少了 token 內部的資料豐富度),也不能只有 FFN(每個 token 各做各的、沒有上下文)。

n_ff 是什麼?

n_ff(也寫作 feed_forward_lengthintermediate_size)是FFN 的中間維度。具體來說:

hidden state x        : 形狀 [n_embd]            = TinyLlama 的 2048

ffn_gate(x)  : matmul → 形狀 [n_ff]              = 5632   ↑
ffn_up(x)    : matmul → 形狀 [n_ff]              = 5632   │ FFN 內部「膨脹」的維度
silu(gate) * up       : 形狀 [n_ff]              = 5632   ↓

ffn_down(...) : matmul → 形狀 [n_embd]           = 2048   (回到 hidden 大小)

也就是說 FFN 把資料先膨脹再壓回去

n_embd → n_ff → n_embd
 2048    5632    2048   (TinyLlama)
 4096    14336   4096   (Llama-3 8B)

膨脹後的 n_ff 維度才是 FFN 真正「思考」的空間。在這個更高維的空間裡套 SiLU 非線性,再投影回 n_embd

為什麼 n_ff 比 n_embd 大?

理論上 FFN 沒有規定中間維度要多大。但實證上幾乎所有 Transformer 都選 n_ff ≈ 4 × n_embd(或對 SwiGLU 來說是 ≈ 8/3 × n_embd ≈ 2.7 ×,因為 SwiGLU 比傳統 FFN 多一個投影,作者要讓總參數量持平)。直覺是:

  1. 更高維度提供更多「特徵組合」的空間。一層 FFN 等於是「在更寬的空間裡做一個 lookup table」,太窄就學不到複雜的模式。
  2. 大部分模型容量集中在 FFN,不是 attention。Attention 的權重矩陣是 [n_embd × n_embd](每個 Q/K/V/O 一個),FFN 是 [n_embd × n_ff](gate/up/down 三個)。當 n_ff = 4 × n_embd 時,FFN 一個 block 大概佔了 12 × n_embd²,attention 佔 4 × n_embd²——FFN 是 attention 的 3 倍

n_layern_ff 是兩個獨立的 scaling 維度

這兩個參數常常被一起調,但它們的角色完全不同:

維度 作用 像什麼
n_layer 深度」——疊幾層 block 思考的「步驟數」
n_ff 寬度」——每個 FFN 內部多大 每一步「能展開多少特徵」

兩者的權重總量都會直接影響模型大小:

總參數量 ≈ n_layer × (4·n_embd² + 3·n_embd·n_ff)
                    ↑                ↑
                 attention         FFN (SwiGLU 三個矩陣)

對 TinyLlama 1.1B:

22 × (4·2048² + 3·2048·5632)
= 22 × (16.8M + 34.6M)
= 22 × 51.4M
≈ 1.13B

換成 Llama-3 8B:

32 × (4·4096² + 3·4096·14336) ≈ 32 × (67M + 176M) ≈ 7.8B

數字對得上實際模型大小(差距是 token embedding 和 norm 那些佔的份量)。

在 GGUF 命名裡的對應

n_ff 直接決定了 FFN 三個權重矩陣的形狀:

blk.{l}.ffn_gate.weight  : [n_embd, n_ff]    → 投影到中間維度
blk.{l}.ffn_up.weight    : [n_embd, n_ff]    → 投影到中間維度
blk.{l}.ffn_down.weight  : [n_ff, n_embd]    → 投影回 hidden 維度

n_layer 則決定了這組張量會出現幾次——blk.0.*blk.{n_layer-1}.*,每層獨立一份權重。所以這兩個數字的乘積(加上 attention 那部分)決定了整個模型 9 × n_layer 個 block 級權重的數量。

理解了這個關係,你再看 LlamaConfign_layer = 22n_ff = 5632,就能在腦子裡馬上算出「FFN 中間有 5632 維、總共有 22 個這樣的 block 串起來、每個 block 的 FFN 佔 ~34.6M 參數、整個 FFN 部分大概是 760M 參數」——這就是 Llama 模型 size 的來源。

演算法核心:建立 name → TensorInfo 的索引

第一步是把扁平 list 轉成可查的 HashMap:

let by_name: HashMap<&str, &TensorInfo> =
    tensors.iter().map(|t| (t.name.as_str(), t)).collect();

這個 one-liner 用了 Rust collection 的一個漂亮特性:collect() 會根據目標型別自動選擇收集方式HashMap<K, V>Iterator<Item = (K, V)> 收集就會建一個 hash table。

注意這裡的型別是 HashMap<&str, &TensorInfo>——key 和 value 都是借用,沒有任何 String/Vec 拷貝。整個索引大概只佔幾 KB(每個 entry 是兩個指標 + hash 雜湊),相對於 4 GB 的權重來說可以忽略。

演算法核心:tied embeddings 的容錯

let output = match by_name.get("output.weight") {
    Some(info) => TensorView::from_info(info, blob)?,
    None => token_embd,
};

這段在處理一個 LLM 的細節:有些模型會把 lm_head 和 token embedding 共用同一份權重(叫做 “tied embeddings”)。這樣可以省下 vocab_size × n_embd 的權重——TinyLlama 1.1B 來說大概是 32000 × 2048 × 4 bytes = 262 MB,省下來是非常實在的。

GGUF 在這種情況下會省略 output.weight,runner 必須自己知道要 fallback 到 token_embd.weight。注意我直接 = token_embd——因為 TensorViewCopy,這只是一個 32-byte 拷貝,沒有生命週期問題。如果是 Box<TensorView>Rc<TensorView>,這行就會麻煩很多。

演算法核心:分批載入 norm 權重

fn load_f32_vec(by_name: &HashMap<&str, &TensorInfo>, name: &str, blob: &[u8])
    -> Result<Vec<f32>>
{
    let info = by_name.get(name)?;
    if info.tensor_type != GgmlType::F32 {
        bail!("expected {name} to be F32, got {:?}", info.tensor_type);
    }
    let elems: u64 = info.dimensions.iter().product();
    let mut out = vec![0.0f32; elems as usize];
    dequant::dequant_row_f32(&blob[start..end], &mut out);
    Ok(out)
}

注意這裡的策略:RMSNorm 的權重直接 dequant 成 Vec<f32>,但 Q/K/V/FFN 矩陣保持成 TensorView

為什麼?因為它們的大小差距很大:

  • RMSNorm 權重:n_embd 個 f32,TinyLlama 是 2048 × 4 bytes = 8 KB,整個模型 N 層 × 2 個 norm + 1 個 final = 約 360 KB。
  • 注意力矩陣:[n_embd, n_embd] 個 Q4_0 元素,TinyLlama 是 2048 × 2048 × 0.5 bytes = 2 MB,整個模型大概 4 GB。

對 norm 來說,多花 360 KB 換來「forward pass 不必每次重新解碼」是極划算的。分清楚什麼東西該 eager dequant、什麼該 lazy dequant,是這個檔案的關鍵設計決定

而且 norm 權重必須是 F32(這個 bail 不只是規範性檢查,而是事實——llama.cpp 從不量化 norm,因為 norm 的數值範圍小到不值得量化)。

Rust 用法:lifetime 在 struct 上的擴散

pub struct LayerWeights<'a> {
    pub attn_norm: Vec<f32>,
    pub wq: TensorView<'a>,
    pub wk: TensorView<'a>,
    pub wv: TensorView<'a>,
    pub wo: TensorView<'a>,
    pub ffn_norm: Vec<f32>,
    pub w_gate: TensorView<'a>,
    pub w_up: TensorView<'a>,
    pub w_down: TensorView<'a>,
}

pub struct LlamaModel<'a> {
    pub config: LlamaConfig,
    pub token_embd: TensorView<'a>,
    pub output_norm: Vec<f32>,
    pub output: TensorView<'a>,
    pub layers: Vec<LayerWeights<'a>>,
}

每個 struct 都帶 'a,因為它們都間接持有 TensorView<'a>。這個 'a 一路傳到 LlamaModel 的層級,意思是:整個 LlamaModel 不能比 mmap 活得久

這在實務上會長這樣:

fn main() -> Result<()> {
    let mmap = unsafe { Mmap::map(&file)? };       // mmap: 'mmap
    let model = LlamaModel::load(..., &mmap[...])?; // model: LlamaModel<'mmap>
    let mut runner = Runner::new(&model, ...);     // runner 借用 &model

    // 現在 runner、model、mmap 全部都借用鏈到 mmap 上
    // 編譯器保證它們的生命週期是 mmap ⊃ model ⊃ runner
}

如果你不小心想把 modelmain 回傳出去:

fn load_model() -> LlamaModel<'???> { ... }   // 編譯失敗

編譯器會直接拒絕——因為 'a 必須繫結到一個外部生命週期。這是 Rust 在「結構體借用資料」這個模式上的招牌。

Rust 用法:用 format! 拼字串 vs 用 const generics

for l in 0..config.n_layer {
    layers.push(LayerWeights {
        attn_norm: load_f32_vec(&by_name, &format!("blk.{l}.attn_norm.weight"), blob)?,
        wq:        view(&by_name, &format!("blk.{l}.attn_q.weight"), blob)?,
        ...
    });
}

這段每呼叫一次 format! 都會 allocate 一個 String。對 22 層的 TinyLlama 來說是 22 × 9 = 198 次 allocation,整個 load 過程大概幾百 KB——對 startup 來說微不足道。

但有個更省的寫法:用 write! macro 寫進一個複用的 buffer:

let mut buf = String::with_capacity(64);
for l in 0..config.n_layer {
    buf.clear();
    write!(buf, "blk.{l}.attn_norm.weight").unwrap();
    let attn_norm = load_f32_vec(&by_name, &buf, blob)?;
    ...
}

不過這個改動會犧牲可讀性換 0.1 ms 啟動時間,這就是過度最佳化的典型例子。我選擇可讀性。

演算法核心:embed 函式的 row-major 觀念

pub fn embed(&self, token: u32, out: &mut [f32]) {
    self.token_embd.dequant_row(token as usize, out);
}

這只有一行,但有個微妙細節:token embedding 矩陣的 row 是 n_embd、token 數量是 dim1。也就是說:

token_embd shape (GGUF dim 順序): [n_embd, vocab_size]
                                    ^^^^^^  ^^^^^^^^^^
                                    row 寬度  row 數量

我們對 token id t 的 lookup 就是「拿第 t 個 row」。所以 dequant_row(token, out) 直接把 token 的 embedding vector 寫進 out

這個 embedding lookup 是整個 forward pass 中唯一會做 dequant 而不是 dot 的地方。原因很合理:embedding 是「直接拿值」而不是「拿值算內積」,所以 dequant 是必要的。

效能最佳化空間

model.rs 不在 hot path 上——它只在啟動時跑一次。但有幾個方向值得想:

1. 平行載入 norm 權重

目前 load_f32_vec 是序列的,22 層 × 2 個 norm + 1 個 = 45 次小 dequant。可以用 rayon 把它們平行掉:

use rayon::prelude::*;
let norms: Result<Vec<_>> = (0..n_layer).into_par_iter()
    .map(|l| load_f32_vec(&by_name, &format!("blk.{l}.attn_norm.weight"), blob))
    .collect();

但 norm 權重很小(每個 8 KB 左右),平行化的 overhead 可能比實際工作還大——不見得會更快。

2. 把 token_embd 也 eager dequant 成 f32

目前 embed 每次都會去 dequant 一個 row。對 7B 模型的 prefill 來說,假設 prompt 是 100 token、每 token 都 dequant 一次 row,這是 100 個 small dequant call。

但如果你 eager 把整個 token embedding 解壓成 f32,會佔用 vocab_size × n_embd × 4 bytes——TinyLlama 是 256 MB,但 Llama-3 8B 是 128k × 4096 × 4 = 2 GB這個交換在小 vocab 模型划算,在大 vocab 模型不划算。一個折衷方案是 LRU cache:只 cache 最近用過的 K 個 token 的 embedding。

3. 自訂 hashmap 改善 locality

HashMap<&str, &TensorInfo> 的 default hasher 是 SipHash,安全但慢。對啟動時的 lookup 來說,可以換成 ahashfxhash

use std::collections::HashMap;
type FastMap<K, V> = HashMap<K, V, ahash::RandomState>;

不過這同樣是「啟動時間 0.5 → 0.3 秒」的微優化,相對於 mmap warmup 的時間(取決於檔案大小,可能是幾秒)算不上痛點。

4. 預先驗證所有張量都存在

目前如果某層的 attn_q.weight 缺失,會在 load 那一層時才 bail。可以在 LlamaModel::load 一開始就用一個 list 驗證所有預期張量都在:

let expected: Vec<String> = (0..n_layer).flat_map(|l| {
    [
        format!("blk.{l}.attn_norm.weight"),
        format!("blk.{l}.attn_q.weight"),
        ...
    ]
}).collect();
for name in &expected {
    if !by_name.contains_key(name.as_str()) {
        bail!("missing {name}");
    }
}

這對使用者體驗有改善——一次回報所有缺失,而不是一個一個踩雷。

一個被隱藏的設計選擇:為什麼 norm 是 Vec<f32> 而不是 &[f32]

理論上 norm 也可以是個 view:

pub attn_norm: &'a [f32],   // 直接借用 mmap 的 bytes(如果 alignment 對齊)

這樣連 dequant 都不必做,只要把 mmap 的 bytes 重新解釋成 &[f32] 就行。但有兩個問題:

  1. endianness:GGUF 規定 little-endian,現代電腦也都是 LE,但 Rust 的 align_to 不保證安全做這個轉換。
  2. alignment:mmap 是 page-aligned 的,但 norm 張量的 offset 不一定 4-byte 對齊。

更安全的做法是 dequant_row_f32f32::from_le_bytes,每個元素都明確做 byte 轉換。代價是一次性的解碼成本(一個 7B 模型大概 1 MB f32 norm,在啟動時花幾毫秒解掉)。

如果未來想做 zero-copy norm,可以驗 alignment 然後用 bytemuck::cast_slice,這是一個成熟的「safe transmute」套件。

總結:model.rs 的角色

  • 概念上:扁平的 GGUF 張量 list 對應到結構化的 Llama 模型樹。
  • 實作上:HashMap 索引 + 字串拼接 + 條件性 fallback(tied embeddings)。
  • 設計上:分清楚什麼 eager dequant(norm,小且常用)、什麼 lazy dequant(大矩陣)。

下一篇講 ops.rs,那是 Transformer 的「樂高積木」。

系列文章:


tiny-llm-runner 深入解讀 (3):tensor.rs —— 用生命週期蓋一層薄殼

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇講完 dequant.rs 的量化核心,這一篇要看的是 tensor.rs——一個只有 150 多行的小檔案,但它展示了 Rust 在系統程式設計上最有特色的能力之一:用生命週期把「資料在誰手上」這件事編譯期就講清楚。

這個檔案要解決什麼問題?

dequant.rs 提供的所有函式都是「裸的」:你給它一段 &[u8] 和一個 &[f32],它做點積。但在 forward pass 裡,我們不想讓 runner.rs 直接看到 byte slice——那太底層、太容易出錯。我們想要一個更高階的抽象,例如:

「這是一個 [K, N] 的 Q4_0 矩陣,請對它的第 i row 和輸入向量 x 做點積。」

TensorView 就是這個抽象。

概念一:什麼是「view」?為什麼不擁有資料?

#[derive(Clone, Copy)]
pub struct TensorView<'a> {
    pub data: &'a [u8],
    pub ggml_type: GgmlType,
    pub dims: [u64; 4],
    pub n_dims: usize,
}

注意這個 'a 生命週期參數。TensorView 不是「擁有」這些 bytes,它只是「看著」這些 bytes。實際的 bytes 還在 mmap 裡。

這個設計的意義是:

  1. 零拷貝:建構一個 TensorView 只需要拷貝 32 bytes 的中繼資料(dims + type + slice header),沒有任何資料搬運。
  2. 生命週期安全:因為 'a 綁定到 mmap,編譯器會保證這個 view 不會比 mmap 活得更久。
  3. 可以是 Copy:32 bytes 的 struct(沒有 owned data)可以實作 Copy,意味著傳遞它就像傳一個整數那麼便宜。

對比一下,如果我用 Tensor { data: Vec<u8>, ... }(擁有資料),那建構一個 7B 模型的 LlamaModel 就會把整個 4 GB 從 mmap 複製到 heap 上——啟動時間和記憶體用量都會炸掉。

概念二:行優先佈局與 dim 順序

GGUF 的張量是 row-major 儲存,但 dim 順序和我們直覺可能相反:

GGUF 規定:dimensions = [K, N] 表示 K 個欄、N 個 row
也就是 dim[0] = "row 寬度"、dim[1] = "row 數量"

這個約定在 dim0()dim1() 兩個 getter 上反映出來:

pub fn dim0(&self) -> usize { self.dims[0] as usize }   // 每個 row 有幾個元素
pub fn dim1(&self) -> usize { self.dims[1] as usize }   // 有幾個 row

這個 GGUF convention 可能是受 BLAS 影響——很多線性代數函式庫的 [m, n] 矩陣 m 是 row count、n 是 col count,但實際上記憶體是 column-major。GGUF 採用了一個和 NumPy 直覺相反的約定,第一維是 row 寬度,第二維是 row 數量。每次寫程式都要小心這個。

演算法核心:每種量化的 row size 計算

fn bytes_per_row(t: GgmlType, cols: usize) -> usize {
    match t {
        GgmlType::F32  => cols * 4,
        GgmlType::F16  => cols * 2,
        GgmlType::Q8_0 => (cols / dequant::QK8_0) * dequant::Q8_0_BLOCK_SIZE,
        GgmlType::Q4_0 => (cols / dequant::QK4_0) * dequant::Q4_0_BLOCK_SIZE,
        GgmlType::Q6_K => (cols / dequant::QK_K)  * dequant::Q6_K_BLOCK_SIZE,
        other => panic!("unsupported tensor type: {other}"),
    }
}

對 fp 類型很直觀(每個元素 N bytes,乘起來就好)。但對量化類型,重點是「block 整除」:每個 block 是固定大小(32、32、256 個元素),所以一個 row 必須是 block 大小的整數倍。bytes_for 函式裡的 is_multiple_of 檢查就是在驗證這件事。

這也意味著 LLM 的 hidden dim 通常是 32、64、128、256 的倍數——不是巧合,是為了和量化 block 對齊。

演算法核心:dot_row 的派發策略

TensorView::dot_rowtensor.rs 暴露給上層的最重要 API:

pub fn dot_row(&self, i: usize, x: &[f32]) -> f32 {
    let row = self.row(i);
    match self.ggml_type {
        GgmlType::F32  => dequant::dot_f32(row, x),
        GgmlType::F16  => dequant::dot_f16(row, x),
        GgmlType::Q8_0 => dequant::dot_q8_0(row, x),
        GgmlType::Q4_0 => dequant::dot_q4_0(row, x),
        GgmlType::Q6_K => dequant::dot_q6_k(row, x),
        t => panic!("unsupported tensor type for matmul: {t}"),
    }
}

注意這個 match:派發只發生在 row 邊界,一旦進到 inner loop(dot_q4_0 內部),就沒有任何條件判斷了。這是個重要的微結構選擇——CPU 的分支預測器會討厭 inner loop 裡的條件分支,把派發拉到外面才能讓內部迴圈純粹是計算。

為什麼用 match 而不是 trait object?

如果你看過更 OOP 的設計,可能會用:

trait QuantOps {
    fn dot(&self, q: &[u8], x: &[f32]) -> f32;
}

然後 TensorView 持有一個 Box<dyn QuantOps>。這個設計的問題是:

  1. 每次 dot 都要走 vtable,無法 inline。
  2. 每個 row 多一次間接跳轉,CPU 預測會變差。
  3. Copy 寫不出來——Box 不是 Copy

直接 match 反而是最快的。Rust 編譯器看到 match 是窮舉的、且 arms 都會 inline,會把這段 match 編譯成一個跳躍表(jump table)或一連串的條件比較,比 trait object 快一個量級

Rust 用法:Copy 的隱含好處

TensorViewCopy,意味著:

let view = TensorView::from_info(...)?;
some_function(view);    // 不需要 .clone()
let view2 = view;       // 不會 invalidate `view`

這讓 runner.rs 可以放心地把 view 當值傳。32 bytes 的拷貝在 x86 上是一條 SSE move 指令,比走參考還可能快——因為避免了 alias 分析的複雜度。

Copy 不是免費的——它要求所有欄位都是 Copy&[u8] 是(slice header 是個 fat pointer),GgmlType 是(plain enum),[u64; 4] 是(陣列),usize 是。所以 TensorView 自然就能 Copy

Lifetime elision 的小細節

TensorView<'a>'a 看起來很煩,但用起來其實大多數時候不用寫——Rust 的 lifetime elision 規則會自動幫你補。例如 from_info 的簽章:

pub fn from_info(info: &TensorInfo, blob: &'a [u8]) -> Result<Self> { ... }

這裡只有 blob'a(因為 SelfTensorView<'a>,必須繫結到某個來源)。info: &TensorInfo 用的是省略掉的另一個生命週期,編譯器會自己補。

Rust 用法:用 trait 加 ergonomics

pub trait TensorInfoExt {
    fn ggml_type_or_panic(&self) -> GgmlType;
}

impl TensorInfoExt for TensorInfo {
    fn ggml_type_or_panic(&self) -> GgmlType {
        self.tensor_type
    }
}

這是個小技巧:TensorInfo 是上游 crate (llm-gguf-parser) 定義的型別,我不能直接給它加方法。但我可以用 extension trait——在我的 crate 裡定義一個 trait,並為上游型別實作它。引入這個 trait 後,info.ggml_type_or_panic() 就能用了。

這比每次都寫 info.tensor_type 更具表達性——名字明確說明「我假設這個值已經被驗證過了」。雖然這個例子裡 trait 帶的方法非常薄,重點是表達 intent,不是省字數

效能最佳化空間

tensor.rs 本身沒有 hot path——所有 hot 工作都在 dequant.rs 裡。但有幾個值得想的方向:

1. 把 dim 從 [u64; 4] 改成 [u32; 4]

LLM 沒有單一維度超過 40 億的張量。u32 已經足夠,可以把 TensorView 從 64 bytes(slice + type + 4×u64 + n_dims)縮到 48 bytes,更友善 L1 cache。

2. 預先 cache row_bytes

每次呼叫 row(i) 都會算一次 row_bytes(),內部又是個 match:

pub fn row(&self, i: usize) -> &'a [u8] {
    let rb = self.row_bytes();   // match self.ggml_type
    &self.data[i * rb..(i + 1) * rb]
}

雖然編譯器可能會把這個算一次然後 hoist 出迴圈,但 matvecpar_iter_mut 是平行的、每個執行緒看到的是新的 closure scope,不一定會 hoist。如果在建構 TensorView 時就 cache 一個 row_bytes: u32,就能避免每次重算:

pub struct TensorView<'a> {
    pub data: &'a [u8],
    pub row_bytes: u32,        // 預先算好
    pub ggml_type: GgmlType,
    pub dims: [u32; 4],
    pub n_dims: u8,
}

不過這是 micro-optimization,在 Q4_0 inner loop 已經吃掉 99% 時間的情況下,這個改動量化效益很小。

3. 改用 &'a [QXBlock] 而不是 &'a [u8]

如果把 data 從 raw bytes 改成「具型 block 切片」(例如 &[Q40Block]),就能避免在 dot_q4_0 內部的指針算術,編譯器也更容易做 alias 分析。但這需要每種量化都定義一個 repr(C, packed) 的 struct,代價是更多 boilerplate。

4. 對齊:強制 16-byte 或 32-byte 對齊

mmap 給我們的 byte slice 是 page-aligned(4 KB),但每個張量的起始 offset 不一定對齊到 SIMD 邊界。如果未來要 SIMD 化,可能需要:

  • from_info 時驗 offset 是 32-byte 對齊。
  • 對沒對齊的張量做一次性 copy 到對齊 buffer。

GGUF 規格其實有 general.alignment metadata 來控制這個,目前我沒在驗。

一個更深入的設計問題:「view」和「ownership」的分界

tensor.rs 體現了 Rust 一個非常獨特的設計優勢:生命週期允許你寫「比 C++ 還靈活、比 Java/Go 還安全」的 view 型別

對比 C++:

  • 你可以寫 string_view,但生命週期只能靠註解和文件——編譯器不會幫你檢查。
  • 一旦底層 string 被 free 了,view 就是懸空指標,但編譯期看不出來。

對比 Java/Go:

  • 你可以「分享 reference」,但 GC 會強迫底層物件保持活著——你失去了「這個 view 隨時會死」的精確控制。

Rust 的 lifetime 是這兩者之間的中道:底層物件的生命週期由其他人控制,但編譯器會保證所有 view 都不會比它活得更久

這個能力對 mmap 場景特別重要——mmap 的「資料」是一個極端的例子(GB 等級、來自硬碟、隨時可以被 munmap),如果沒有靜態保證 view 不會懸空,bug 會非常難 debug。

總結:tensor.rs 的角色

  • 概念上:把 &[u8] 包成一個有型別、有形狀、知道怎麼點積的「視圖」。
  • 實作上:32 bytes 的 struct + 兩個方法 + 一個 row size 表。
  • 語言上:lifetime + Copy + match 派發 + extension trait,Rust 慣用法的小展示。

下一篇講 model.rs,看怎麼把「一堆 view」組合成一個有結構的 LlamaModel

系列文章: