Simply Patrick

tiny-llm-runner 深入解讀 (9):main.rs —— CLI、Prefill、Decode 與整體效能

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

歷經八篇深入解讀,我們終於來到 tiny-llm-runner 的最後一塊——main.rs。前面拆了那麼多零件,總得有人把它們兜起來吧?這個檔案就 130 行,是把所有元件串起來的「指揮台」。

這也是整個系列的最後一篇了。除了講 main.rs,我想在最後對整個專案做一個全局的效能最佳化清單,把每個檔案散落的優化點串起來看——算是給這趟旅程一個交代。

概念一:CLI 參數設計

#[derive(Parser, Debug)]
#[command(version, about = "Pure-Rust llama-architecture inference over a GGUF model")]
struct Args {
    #[arg(short, long)]                              model: PathBuf,
    #[arg(short, long, default_value = "Once upon a time")] prompt: String,
    #[arg(short, long, default_value_t = 64)]        n_predict: usize,
    #[arg(short, long, default_value_t = 0.8)]       temperature: f32,
    #[arg(long, default_value_t = 40)]               top_k: usize,
    #[arg(long, default_value_t = 42)]               seed: u64,
    #[arg(long)]                                     no_bos: bool,
    #[arg(long, default_value = "llama")]            rope: String,
}

clap 的 derive macro 把 CLI parsing 變成宣告式:每個欄位加上 #[arg(...)] 就自動產生 --model--prompt 之類的 flag。這比手寫 argument parser 短得多,而且 --help、type validation、default value 全都免費送你,實在是頗划算。

clap 的 zero-cost abstraction

clap 的 macro 在 compile-time 就生成好 parsing code,runtime 沒有任何 reflection。也就是說啟動時 Args::parse() 是純 native code,速度比 Python 的 argparse 快上好幾個量級。

對 LLM runner 來說 CLI 啟動時間其實不是什麼大問題,不過這個習慣很 Rust——把 metadata 處理推到 compile time,runtime 只留下純粹的計算。看多了你會發現整個語言都在貫徹這件事。

概念二:unsafe Mmap

let file = File::open(&args.model)?;
let mmap = unsafe { Mmap::map(&file)? };

整個專案唯一一個 unsafe,就這麼一行。為什麼 mmap 非得 unsafe 不可?

因為 mmap 違反了 Rust 的記憶體模型假設:Rust 假設一個 &[u8] 的內容在它的生命週期內不會被外部修改。但 mmap 對應的檔案如果被另一個 process 改掉(甚至 truncate),這個 &[u8] 就會看到變了樣的資料、最慘還會吃到 SIGBUS。

unsafe 說穿了就是程式設計師對編譯器的一句承諾:「我知道這違反一般規則,使用過程中檔案不會被外部動到,我自己負責」。對 LLM 模型檔來說這承諾其實很好守——模型檔通常就是 read-only 的,誰會去動它呢。

這也呼應了 Rust 的一個設計哲學:unsafe 不是禁忌,而是被精準框定的工具。整個 codebase 只有這一行 unsafe,但它被框得清清楚楚——出了問題,責任就在這一行,跑不掉。

演算法核心:Prefill / Decode 二段式

// 1. Prefill —— 處理 prompt
let prefill_start = Instant::now();
let mut last_logits: Option<Vec<f32>> = None;
for &tok in &prompt_ids {
    let logits = runner.forward(tok);
    last_logits = Some(logits.to_vec());
}
let prefill_elapsed = prefill_start.elapsed();

// 2. Decode —— 生成 token
let decode_start = Instant::now();
let mut generated: Vec<u32> = Vec::with_capacity(args.n_predict);
let mut logits = last_logits.expect("empty prompt");
for _ in 0..args.n_predict {
    let next = sampler.sample(&mut logits);
    if next == tokenizer.eos { break; }
    generated.push(next);
    let piece = tokenizer.decode(&[next]);
    print!("{piece}");
    std::io::stdout().flush().ok();
    logits = runner.forward(next).to_vec();
}

為什麼分兩階段?

LLM 推論天然分兩個階段:

  • Prefill:把使用者的 prompt 餵進去,建立 KV cache。logits 只有最後一個 token 的有用——前面的丟掉。
  • Decode:每次 forward 一個 token、抽下一個。每個 logits 都會用到。

這兩個階段的特性差異很有意思:

  • Prefill 的 token 全都是已知的,理論上可以批次處理(用 GEMM 取代 GEMV)。
  • Decode 就只能乖乖 sequential(下一個 token 取決於上一個,沒得偷懶)。

不過我目前 prefill 也是 sequential(一個 token 一個 forward)的,這就是個明擺著的優化機會了——後面清單會再回來算這筆帳。

演算法核心:tok/s 的計算

eprintln!("[prefill] {} tok in {:.2}s ({:.1} tok/s)",
    prompt_ids.len(),
    prefill_elapsed.as_secs_f64(),
    prompt_ids.len() as f64 / prefill_elapsed.as_secs_f64().max(1e-9),
);

max(1e-9) 是用來防止 0 秒(極短 prompt)導致除以零。f64::max(self, other) 回傳兩者較大者,所以 0.0.max(1e-9) = 1e-9,分母就保證不會是 0 了。

這種小細節很容易忘記寫喔——prompt 只有一個 token 時,prefill 可能是 0.001 秒,算出來還有意義;但要是快到變成 0 秒(測試環境有時就是這麼誇張),分母歸零你就會收到一個漂亮的 NaN。

Rust 用法:streaming output 的 flush

print!("{piece}");
std::io::stdout().flush().ok();

print! 寫進 stdout buffer,但不會馬上顯示——一般 stdout 是 line-buffered,要等到 \n 才 flush。LLM 串流輸出又沒有 \n,所以非得手動 flush 不可,不然你會傻等半天什麼都看不到,還以為當機了。

flush().ok()Result<(), Error> 轉成 Option<()> 然後丟掉——白話講就是「這個 flush 失不失敗我才懶得管」。stdout 寫入失敗本來就極罕見(例如 pipe 被人砍掉),就算真的失敗我們也無能為力,silent ignore 反而是最合理的處理。

Rust 用法:anyhow 的錯誤處理

fn main() -> Result<()> {
    // ... 整個 main 都是 Result-friendly 的,用 ? 早期返回
    Ok(())
}

fn main() -> Result<()> 是 Rust 處理 CLI errors 最乾淨的寫法。任何 ? 失敗都會把錯誤往 main 外面丟,runtime 自動 print 出來再 exit 1,連 error handling 的 boilerplate 都省了。

anyhow::Result<T> 不過就是 Result<T, anyhow::Error> 的別名。anyhow::Error 可以從任何 std::error::Error 自動轉換——這就是為什麼我能把 std::io::Error、parser error、自定義的 bail! 全混在一起,通通用一個 ? 打發掉,實在是頗舒服。

Rust 用法:環境變數和 stderr

eprintln!("[loaded] n_layer={} ...", config.n_layer, ...);

eprintln! 寫到 stderr,println! 寫到 stdout。我刻意把 metadata 印在 stderr、生成內容印在 stdout——這樣你用 ./tiny-llm-runner > out.txt 時,out.txt 裡就只有乾淨的生成內容,那些 metadata 還是乖乖留在 console 上,不會污染你的檔案。

這是 Unix 工具的老慣例了。在 Rust 裡用兩個不同的 macro 就自然支援,不必特別費心。

整個專案的端到端 forward pass 流程

flowchart TD A[CLI Args] --> B[File::open + Mmap] B --> C[parse_gguf] C --> D[LlamaConfig::from_gguf] C --> E[LlamaModel::load
建 TensorView] C --> F[Tokenizer::from_gguf] F --> G[encode prompt] D --> H[Runner::new
配 KV cache + scratch] E --> H G --> I[Prefill loop
forward each token] H --> I I --> J[Decode loop
sample → forward → repeat] J --> K[print tokens]

從 CLI 進來到 token 吐出去,整條流水線就這樣。仔細看會發現,圖裡每一個方框幾乎都對應到前面九篇文章其中一篇的主題——拼到這裡,整張地圖才算完整。

全局效能最佳化清單

到這裡,所有檔案都翻過一遍了。我想趁記憶猶新,把整個專案的最佳化機會匯總成一張 prioritized list。要強調的是:我不建議盲目地照著順序硬幹,還是得看你自己最想練哪一塊。

Tier 1(最大效能槓桿,10× 級的改進)

  1. SIMD 化 dot kernelsdequant.rs

    • Q4_0、Q8_0、Q6_K 的 inner loop 用 AVX2/AVX-512/NEON
    • 預期:matvec 加速 8-16×
    • 工作量:中—需要小心和 llama.cpp 對拍正確性
  2. Prefill batching(GEMV → GEMM)runner.rsops.rs

    • 把 prompt N 個 token 的 forward 拼成一個 batched 計算
    • 預期:prefill 加速 5-10×(decode 不變)
    • 工作量:大—涉及 attention 的 mask、KV cache 的 batched 寫入
  3. 支援 K-quants(Q4_K、Q5_K、Q4_K_M)dequant.rs

    • 不是加速 per se,而是讓現代 GGUF 都能跑
    • 工作量:中—實作複雜但有 ggml C 程式碼可參考

Tier 2(顯著改進,2-3× 級)

  1. Multi-row matvec fusionops.rs

    • 一次處理多個 row,減少 x 的 cache miss
    • 預期:matvec 加速 2-4×
  2. KV cache 量化runner.rs

    • 把 KV cache 從 F32 改成 Q8_0
    • 預期:記憶體用量 4×、速度可能略有提升(cache miss 變少)
    • 工作量:中
  3. f16 / bf16 全程(多個檔案)

    • 不要每次都 dequant 成 f32,scratch buffer 也用 f16
    • 預期:記憶體頻寬減半
    • 工作量:大—需要全程 f16 的 numerical stability 驗證

Tier 3(小但容易的改進)

  1. RoPE sin/cos 表預計算ops.rs

    • 不要每次 forward 都算 sin/cos
    • 預期:每層省幾十 μs,整體可能 1-2%
  2. Tokenizer 的 pair lookup tabletokenizer.rs

    • 避免 format! 字串拼接
    • 預期:encode 加速 5-10×(但 encode 不在 hot path)
  3. Top-P samplingsampler.rs

    • 提升 sampling 品質(不是速度,是輸出品質)
  4. Repetition penaltysampler.rs

    • 同上

Tier 4(架構級重構,可能不值得)

  1. GPU backend

    • wgpu 或 CUDA 支援
    • 工作量:極大—基本上是另一個專案
  2. FlashAttention

    • Fused attention with online softmax
    • 工作量:大—但 candle/ggml 有現成實作可學
  3. Speculative decoding

    • 用小模型加速大模型推論
    • 工作量:大—需要兩個模型協作

一個整體觀察:抽象與效能的權衡

寫完整個系列,我最強烈的感受是:tiny-llm-runner 的「易讀」,其實是拿「不可擴充」換來的。每個檔案都針對單一情境寫得直白到不行,但代價就是——要加新功能(新架構、新量化、新後端)往往得同時動好幾個檔案。

這跟 candle 的設計哲學根本是兩條路。candle 透過一層厚厚的抽象(TensorModuleVarBuilder)讓擴充變得很便宜,代價則是「想看懂一次 forward pass,得在好幾個 trait 之間跳來跳去」。

那到底哪個對?老實說,取決於你的目標。 你要做生產級框架,candle 那套抽象就是必要之惡;你要做一個「能跑、能讀、能 hack」的學習版,那 tiny-llm-runner 的扁平結構反而才是對的。沒有標準答案,只有適不適合。

把九個檔案的學習收穫匯總

回到我們最一開始問的那個問題:「寫一個會跑 LLM 的專案,到底需要哪些零件?」攤開來看,答案就是這張表:

檔案 你需要學會
config.rs metadata 解析 + 不變式檢查
dequant.rs block-wise 量化 + fused dot product
tensor.rs lifetime + view 抽象 + Copy struct
model.rs 樹形權重組織 + tied embeddings
ops.rs RMSNorm、softmax、RoPE、SwiGLU、rayon
runner.rs KV cache + GQA + 殘差連接 + scratch buffer
tokenizer.rs SentencePiece-BPE + byte fallback + UTF-8 重組
sampler.rs top-k partial sort + xorshift + CDF sampling
main.rs CLI design + prefill/decode split + tok/s metric

如果你真的把整個系列啃完了,那「LLM 推論引擎到底在幹嘛」這件事,你心裡應該已經有一張完整的 mental model 了。而且接下來能玩的還多著呢:自己加 SIMD、自己刻 GEMM、自己補 K-quants、自己接 GPU backend……每一條我都覺得夠格獨立寫成一篇工程旅程。

結語

回頭看,tiny-llm-runner 對我來說從來就不只是「一個專案」,而是「一次把 LLM 推論從頭到尾看透的旅程」。從一個 mmap 出來的 byte slice 開始,經過一連串型別、抽象、運算的層層組合,最後居然真的長成一個能跑、能跟 llama.cpp 對拍、又讀得懂的 Rust 程式——對我這種改不掉、就是愛把黑盒子拆開看裡面齒輪的工程師來說,這份滿足感,比把效能再快一倍都還過癮。 :-)

九篇走下來,最大的收穫其實不是哪個 kernel 怎麼寫,而是那種「啊,原來這裡面沒有魔法」的踏實感。LLM 聽起來玄,拆開來不過就是量化、矩陣、softmax、sampling 這些老朋友排排站而已。

謝謝你一路陪我把這九篇啃完。下次再見的時候,但願我已經把上面那張清單裡的優化,至少落地了幾個——不然這篇結語可就有點心虛了。我們,下個專案見。

系列文章:


tiny-llm-runner 深入解讀 (8):sampler.rs —— Greedy、Top-K 與 xorshift PRNG

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 tokenizer,這一篇要看的是模型吐出結果之後最後一道工序:怎麼把一堆數字變成「下一個 token」。sampler.rs 不到 100 行,看起來很不起眼,但它實在是決定了 LLM 的「個性」——同一個模型換不同 sampler,生出來的內容可以差很多喔

概念一:什麼是 logits?怎麼從 logits 拿 token?

forward pass 的最終輸出是 [vocab_size] 個浮點數,叫做 logits。每個位置代表「下一個 token 是 i 的 unnormalized log-probability」,講白話就是「模型覺得這個 token 有多順眼」的分數。

最直接的選法就是 greedy(argmax)

fn argmax(x: &[f32]) -> usize {
    let mut best = 0usize;
    let mut best_v = f32::NEG_INFINITY;
    for (i, &v) in x.iter().enumerate() {
        if v > best_v { best_v = v; best = i; }
    }
    best
}

就是無腦挑分數最高的那個。Greedy 的問題是輸出太 deterministic,模型每次都選同一個最高分 token,內容會變得很單調,還很容易卡在迴圈裡跳不出來。

概念二:Temperature —— 給機率分佈加溫

Temperature 是這樣運作的:

let inv_t = 1.0 / self.temperature;
for v in logits.iter_mut() {
    *v *= inv_t;
}

把所有 logits 除以 temperature。然後過 softmax 拿到機率分佈:

$$P(i) = \frac{e^{\ell_i / T}}{\sum_j e^{\ell_j / T}}$$

T 的影響:

  • T → 0:分佈會變成尖峰(最大值佔 1,其他 0)→ 等同於 greedy。
  • T = 1:原始分佈。
  • T → ∞:分佈會變平(每個 token 機率接近 1/vocab)→ 完全隨機。

實務上 T = 0.7 ~ 1.0 大概就是 LLM 寫作的甜蜜點吧——夠隨機讓內容有點變化,又不至於整個脫線講起夢話。

概念三:Top-K —— 只從前 K 個候選裡抽

純 temperature 抽樣有個惱人的地方:它不會幫你排除「明顯不該選」的 token。比方說分佈裡有個位置是「機率 0.001、但其實是個亂碼字元」,溫度一高,它還是有那 0.001 的機會被抽到——然後一個 token 就毀掉一整段文字,前功盡棄。

Top-K 的解法很直接:只保留機率最高的 K 個,其他全部歸零

if self.top_k > 0 && self.top_k < logits.len() {
    let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    indexed.select_nth_unstable_by(self.top_k, |a, b| {
        b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
    });
    let cutoff = indexed[..self.top_k]
        .iter()
        .map(|(_, v)| *v)
        .fold(f32::INFINITY, f32::min);
    for v in logits.iter_mut() {
        if *v < cutoff {
            *v = f32::NEG_INFINITY;
        }
    }
}

select_nth_unstable_by:partial sort 的妙用

我這裡沒用 sort,而是用 select_nth_unstable_by。為什麼呢?

sort 是 $O(n \log n)$。但說穿了,我們只關心「前 K 個是哪些」、根本不在乎這 K 個之間誰排前誰排後。select_nth_unstable_by 就是 partial sort:把第 K 個元素放到對的位置,前 K 個都在它前面(內部順序不保證),後面的都在它後面。複雜度平均是 $O(n)$、worst case 才 $O(n \log n)$,比乖乖整個排序快得多

拿 vocab_size = 32k、top_k = 40 來算,sort 要做 32k × log(32k) ≈ 480k 次比較;partial sort 平均只要 ~32k 次。差了 15 倍耶,這種白吃的午餐不拿白不拿。

找 cutoff 的小技巧

let cutoff = indexed[..self.top_k]
    .iter()
    .map(|(_, v)| *v)
    .fold(f32::INFINITY, f32::min);

partial sort 之後 indexed[..K] 就是「最大的 K 個」,只是內部順序未知。我用 fold(f32::INFINITY, f32::min) 撈出它們之中最小的那個——這就是 cutoff。任何 logit 小於 cutoff 的,就請它出局。

順帶一提,我這裡用的是 f32::min(function)而不是 method 版的 min,當作 fold 的參數。f32::min 碰到 NaN 的處理方式是「忽略它」,比較不會出事。

NaN 處理

b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)

f32::partial_cmp 因為 NaN 沒辦法比大小,回傳的是 Option<Ordering>。我用 unwrap_or(Ordering::Equal) 把 NaN 當成「相等」來處理——不算漂亮啦,但至少不會 crash。理論上 logits 是不該冒出 NaN 的,不過量化模型碰到某些極端輸入還真有可能生出 NaN 來,所以這個 defensive 的小動作我覺得很值得。

演算法核心:累積分佈抽樣

Softmax + uniform random + cumulative sum:

softmax(logits);
let r = self.next_f32();   // [0, 1)
let mut acc = 0.0f32;
for (i, &p) in logits.iter().enumerate() {
    acc += p;
    if acc >= r { return i as u32; }
}
(logits.len() - 1) as u32

這是經典的 inverse CDF sampling

  1. 把機率分佈累加成 CDF:[p0, p0+p1, p0+p1+p2, ..., 1.0]
  2. 隨機抽一個 r ∈ [0, 1)
  3. 找第一個 CDF 大於 r 的位置就是抽中的 token。

這裡有個容易被忽略的小細節:理論上最後一個 cumulative sum 應該剛好是 1.0,但 floating point error 可能讓它差那麼一點點、略小於 1.0。萬一 r 不偏不倚就落在那個誤差縫隙裡,迴圈跑完卻沒人 return,那就尷尬了。所以最後補了一行 (logits.len() - 1) as u32 兜底,當作保險。

演算法核心:xorshift64 PRNG

fn next_u64(&mut self) -> u64 {
    let mut s = self.rng_state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    self.rng_state = s;
    s
}

這是 xorshift64——一個簡單到有點不可思議的偽隨機數產生器。三條 shift + xor 指令,就足以通過大部分隨機性測試了。週期是 $2^{64} - 1$,對 LLM 推論來說綽綽有餘。

為什麼不用標準函式庫的 rand

rand crate 確實提供了 high-quality PRNG(mersenne twister、ChaCha20…),可是它畢竟是個外部依賴。對 tiny-llm-runner 來說,我是刻意把依賴壓到最小的,xorshift64 自己手寫 5 行就搞定,何必為了亂數多拖一個 crate 進來?

而且對 LLM sampling 而言,PRNG 的品質根本不是重點——你只需要每次抽樣有合理的 entropy 就好,又不是要拿來做密碼學。Xorshift 真的夠用了。

避免 all-zero state

let s = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };

xorshift 有個很有名的小陷阱:state 一旦是 0 就會永遠卡在 0(0 ^ 0 還是 0 嘛)。所以我把 seed = 0 偷偷換成 0x9E3779B97F4A7C15——這是黃金比例的 fixed point,常被拿來當 hash seed。

這個 magic number 的來頭:黃金比例 $\phi = (\sqrt{5} + 1) / 2$,它的二進位部分是個無限不循環序列,被認為「隨機性最強」。hash crate 的 FxHasher 也是用這個數字,算是業界公認的好朋友了。

next_f32 的位元操作

fn next_f32(&mut self) -> f32 {
    let bits = (self.next_u64() >> 40) as u32;
    bits as f32 / (1u32 << 24) as f32
}

只取 64 bits 裡的 24 bits(» 40 之後留下 24 bits),轉成 f32 再除以 $2^{24}$。為什麼偏偏是 24?因為 f32 的 mantissa 就 23 bits(加上那個 implicit leading 1 才湊到 24 bits 精度),多塞 bits 進去也是白搭——反正精度後面就被截掉了,何必呢。

Rust 用法:mutable self 的 sampler

pub fn sample(&mut self, logits: &mut [f32]) -> u32 { ... }

注意那個 &mut self——sampler 內部藏了 mutable state(rng_state),每次 sample 都會更新它。同時 logits: &mut [f32] 也是 mutable,因為我們直接就地改 logits(套 temperature、top-k、softmax 一條龍)。

&mut self 在 Rust 裡其實是個蠻重的承諾——它等於宣告「這個 call 期間整個物件是我獨佔的」。這也順帶解釋了為什麼 PRNG 天生就是 thread-unsafe:你沒辦法多執行緒同時呼叫 sample

想要 thread-safe 的話,就得套個 Arc<Mutex<Sampler>> 之類的——不過對單一 forward pass 來說實在沒必要,sampler 本來就是乖乖 sequential 跑的嘛。

Rust 用法:early return 的 idiom

pub fn sample(&mut self, logits: &mut [f32]) -> u32 {
    if self.temperature <= 0.0 {
        return argmax(logits) as u32;
    }
    // 完整 sampling 邏輯
}

把「greedy 短路」直接擺在最前面,就不用後面寫一堆 if-else 或巢狀結構,看起來清爽多了。Rust 對 early return 一向友善,這個 idiom 在處理 Result/Option 的時候特別常見(就是那個 ? 運算子)。

效能最佳化空間

1. softmax + sample 的 fusion

我現在的作法是「先 softmax → 再 sample」。但仔細想想,sample 其實只需要「累積分佈累到第一個超過 r 的位置」,根本不必先把整個分佈算完。如果你很早就抽中了,那後面的 softmax 不就白算了嗎?

不過呢,這個優化能省的有限——softmax 是 $O(\text{vocab})$,sample 平均也是 $O(\text{vocab})$,兩個加起來其實不會比 fuse 之後快多少。而且硬要 fusion 會犧牲程式碼的清晰度,我覺得不划算。

2. top-p 而不是 top-k

Top-K 其實有個罩門:每個位置的分佈陡峭程度都不一樣。有些位置可能前 5 個就吃掉了 99% 機率(超陡),有些卻要前 100 個才湊到 99%(很平緩)。固定一個 K 值,碰到前者太鬆、碰到後者又太緊,怎麼喬都不對。

Top-P(nucleus sampling) 的解法就聰明多了:保留累積機率剛好達到 P 的那一小撮 token 就好。實作上是先 softmax、排序、再累加到 P 為止。比 top-k 多了一次 sort,但效果穩定得多。

我目前還沒做 top-p,這算是個值得補上的功能吧。

3. Repetition penalty

LLM 很容易陷入「鬼打牆」的迴圈,一直重複自己(像那種「我是、我是、我是…」講不停的)。Repetition penalty 的招數就是把最近 N 個 token 的 logits 乘上一個懲罰因子(< 1),壓低它們再次被選中的機會。llama.cpp 預設用 1.1。

實作其實很簡單:

for &id in last_n_tokens {
    if logits[id] > 0.0 { logits[id] /= penalty; }
    else                { logits[id] *= penalty; }
}

4. Mirostat

Mirostat 是個比較進階的 sampling 算法,會動態調整 cutoff,讓「驚奇度」(perplexity)維持在一個固定值附近。實作起來頗複雜,但對長文本生成的品質提升很有感。llama.cpp 也支援。

5. SIMD softmax

softmax 裡的 exp 是 element-wise 運算,理論上可以 SIMD 化。但對 vocab=32k 的單次 sample 而言,softmax 大概也才幾百微秒——這點時間早就被 forward pass 那幾十毫秒給吃乾抹淨了,這個優化的邊際效益實在低到可以忽略

6. Speculative sampling

最有潛力的優化我想應該是 speculative decoding(在 sampler 這一層實作):用小模型一口氣猜 K 個 token,主模型同時 verify。等於「一次 forward 就算出 K 個 token」,聽起來真是頗誘人。只是這需要兩個模型搭配演出,已經超出 sampler 自己能管的範圍了。

一個哲學問題:抽樣的 reproducibility

我把這個 sampler 設計成 deterministic(給定同一個 seed,輸出就能重現)。為什麼要這樣搞?

因為要驗證 LLM runner 對不對,得拿來「對拍」。如果我用真隨機,每次跑出來都不一樣,那要怎麼跟 llama.cpp 的輸出比對?給定 seed 的 deterministic sampling 讓我可以:

  1. 設 temperature = 0:對拍 greedy 輸出(純 deterministic)。
  2. 設 temperature = 0.8、seed = 42:對拍 sampler 的隨機性(兩邊用同樣的 seed 應該產生同樣的 token 序列)。

這個道理其實對所有需要復現實驗的 ML 程式碼都成立——先確保 reproducibility,才有辦法好好 debug。少了這個前提,你連「到底是哪裡跑掉了」都搞不清楚。

總結:sampler.rs 的角色

  • 概念上:把 logits 分佈轉成單一 token 抽樣。
  • 演算法上:argmax / temperature / top-k / xorshift / inverse CDF。
  • 設計上:mutable state 集中在 Sampler struct、deterministic by design、依賴最少。

短短不到 100 行的檔案,背後居然牽扯到機率、數值穩定、亂數品質、reproducibility 這麼多眉角,仔細想想還真是有點妙。下一篇就是這個系列的壓軸了——main.rs,把前面這一路拆解過的東西通通串起來,我們最後一篇見囉。

系列文章:


tiny-llm-runner 深入解讀 (7):tokenizer.rs —— SentencePiece-BPE 與 Byte Fallback

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 forward pass 的編排,這一篇要看看「文字」是怎麼變成 LLM 看得懂的 token id 的:tokenizer.rs

很多人以為 tokenizer 就是「把字串切成 word」,我以前也是這樣想的。不過現代 LLM 的 tokenizer 實在是比這精緻得多——它是個 learned algorithm,由訓練資料決定要怎麼切。實作起來其實也才 200 多行而已,但每一段都很有得講喔。

概念一:什麼是 BPE?為什麼不直接用 word?

最早的 NLP 用的是 word tokenizer:["I", "love", "Rust"]。看起來很直覺對吧?只是它有兩個麻煩:

  1. OOV(out of vocabulary):訓練時沒看過的 word(拼寫錯字、新詞)會變成 <unk>
  2. 詞彙爆炸:英文單字大概 60 萬個,加上人名、專業詞彙,vocabulary 動輒幾百萬。

BPE(Byte Pair Encoding) 解這個問題的招數很妙——「從字元開始合併」:

  1. 初始化:每個字元是一個 token。
  2. 統計訓練資料裡哪一對相鄰 token 出現最多。
  3. 合併最常見的那一對成一個新 token。
  4. 重複直到達到目標 vocab size(典型 32k 或 128k)。

合併出來的 token 表會包含「字元」、「片段」、「常見字」、「常見片語」混合在一起,例如:

'a', 'b', ..., 'z',
'th', 'in', 'er', 're',
'the', 'and', 'ing', 'tion',
'▁the', '▁of', '▁to',
...

是 SentencePiece 用來標記 word boundary 的特殊字元(U+2581)。

概念二:SentencePiece 的 word boundary 處理

let prepared = format!("\u{2581}{}", text.replace(' ', "\u{2581}"));

SentencePiece 把空白統一替換成 ,並且在輸入最前面也加一個 。為什麼要這樣搞?

因為這樣一來**「空格」就不再是特殊字元,而是 token 的一部分了**。例如 "hello world" 會變成 "▁hello▁world",tokenize 後可能就是 ["▁hello", "▁world"]。decode 時把 換回空格,原始字串就還原了。

我覺得這個設計頗聰明的地方在於:tokenization 變成 reversible——不會像 word tokenizer 那樣「I love Rust[I, love, Rust] → 咦空格到底在哪?」最後拼不回來。

演算法核心:encode 的 greedy merge loop

loop {
    let mut best_score = f32::NEG_INFINITY;
    let mut best_idx: Option<usize> = None;
    let mut best_id: u32 = 0;
    for i in 0..ids.len().saturating_sub(1) {
        let merged = format!("{}{}",
            &self.tokens[ids[i] as usize],
            &self.tokens[ids[i + 1] as usize]
        );
        if let Some(&id) = self.token_to_id.get(&merged) {
            let s = self.scores[id as usize];
            if s > best_score {
                best_score = s;
                best_idx = Some(i);
                best_id = id;
            }
        }
    }
    match best_idx {
        Some(i) => {
            ids[i] = best_id;
            ids.remove(i + 1);
        }
        None => break,
    }
}

這就是 SentencePiece 的「最高分鄰接合併」編碼演算法:

  1. 把輸入 split 成單字元 token 序列。
  2. 在所有相鄰 pair 中,找一個合併後是 vocab 裡的 token、且分數最高的。
  3. 合併它(用合併後的 id 取代第 i 個,刪掉第 i+1 個)。
  4. 重複直到沒有合併可以做。

這裡的分數是 GGUF metadata 裡 tokenizer.ggml.scores 提供的值——它是 SentencePiece 訓練時學到的「這個 token 到底多好用」的指標,分數越高就越偏好。

複雜度分析

每輪 loop 是 $O(n)$(掃一次相鄰 pair),總共最多 n 輪(每輪至少縮掉一個 token),所以整個 encode 是 $O(n^2)$。對 1000 字元的 prompt 來說是 100 萬次 hashmap lookup——聽起來嚇人,不過還好 hashmap 平均是 $O(1)$,沒事啦。

要更快的話可以用 priority queue(heap):每次取最高分的 pair $O(\log n)$,總共 $O(n \log n)$。只是實作起來複雜,而且 LLM 的 prompt 通常也不會超過幾千字元,$O(n^2)$ 其實完全夠用了。

format! 在 hot path 的成本

注意喔,這個 format! 在每個 pair 都會 allocate 一個新 String,對長 prompt 來說就是上萬次 allocation。真要優化的話,可以:

  1. 預配一個 reusable buffer:let mut merged = String::with_capacity(64); ...; merged.clear(); merged.push_str(...);
  2. (u32, u32) → u32 的 lookup table(pair table),直接避開字串拼接。

不過對 LLM runner 來說 tokenization 通常不是瓶頸啦——一次 encode 才幾十毫秒,相對於 forward pass 的好幾秒,根本可以忽略。

演算法核心:byte fallback —— 處理 vocab 外的字元

for ch in prepared.chars() {
    let s: String = ch.to_string();
    if let Some(&id) = self.token_to_id.get(&s) {
        ids.push(id);
    } else if let Some(bf) = &self.byte_fallback {
        for &byte in s.as_bytes() {
            let id = bf[byte as usize];
            ids.push(id);
        }
    }
}

如果某個字元不在 vocab 裡(像是中文、Emoji),就把它的 UTF-8 bytes 一個一個用 byte fallback token 編碼。Llama 的 vocab 裡準備了 256 個專門的 byte token,名字長這樣:

"<0x00>", "<0x01>", ..., "<0xFF>"

每個 byte token 對應一個 byte 值。例如 "我" 的 UTF-8 是 [0xE6, 0x88, 0x91],會被編碼成三個 byte token:<0xE6>, <0x88>, <0x91>

偵測 byte fallback 是否完整

let mut byte_fallback = [u32::MAX; 256];
let mut have_all = true;
for b in 0..=255u32 {
    let key = format!("<0x{:02X}>", b);
    if let Some(&id) = token_to_id.get(&key) {
        byte_fallback[b as usize] = id;
    } else {
        have_all = false;
    }
}
let byte_fallback = if have_all { Some(byte_fallback) } else { None };

只有當 vocab 裡 256 個 byte token 全都在的時候才啟用 byte fallback。如果只缺了幾個,就整個 disable 掉——因為 partial 的支援會讓 encode 變得不可預測,那種半調子狀態最麻煩了。

Option<[u32; 256]> 我覺得是個頗漂亮的 Rust 表達:要嘛全有、要嘛全無,型別系統直接幫你把這個 invariant 釘死。

演算法核心:decode 的 byte 拼合

decode 比 encode 單純多了,不過有個微妙的小細節要注意——byte fallback token 必須在 byte-level 拼回 UTF-8 codepoint

pub fn decode(&self, ids: &[u32]) -> String {
    let mut bytes: Vec<u8> = Vec::new();
    for &id in ids {
        let s = match self.tokens.get(id as usize) {
            Some(s) => s,
            None => continue,
        };
        if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
            if let Ok(b) = u8::from_str_radix(&s[3..5], 16) {
                bytes.push(b);
                continue;
            }
        }
        bytes.extend_from_slice(s.replace('\u{2581}', " ").as_bytes());
    }
    String::from_utf8_lossy(&bytes).into_owned()
}

關鍵是:先收集到 Vec<u8>,最後一次性轉 UTF-8

為什麼不能逐 token decode 呢?因為一個 UTF-8 codepoint 可能會跨好幾個 byte token。例如 "我" 是三個 byte token,你要是逐個 decode:

  • <0xE6>0xE6 不是合法 UTF-8 → 變成 ?
  • <0x88>0x88 不是合法 UTF-8 → 變成 ?
  • <0x91>0x91 不是合法 UTF-8 → 變成 ?

結果就是 ??? 而不是 ,慘。先把 byte 全收集起來、最後再一次 from_utf8_lossy,這三個 byte 才會被正確拼成一個中文字。

decode_piece 的 caveat

我也提供了一個 decode_piece(id) 做單 token decode,但它有個 limitation——遇到 byte fallback 時只能 lossy emit。所以LLM 串流輸出的時候一定要用 decode(&[id]) 或自己 buffer 起來,千萬別直接 decode_piece,不然就會看到一堆亂碼。

我的 main.rs 用的是 tokenizer.decode(&[next]),這樣才有正確處理——不過老實說這真的是個很容易忘記的雷。理想的 API 應該是個 Decoder struct,內部自己維護 byte buffer,每次餵 token 進來、適時 flush 出完整的 UTF-8 codepoint,這樣最省心。

Rust 用法:HashMap 的 entry pattern

這個 tokenizer 雖然沒用到,但這個 Rust 慣用法還是值得一提。我目前的初始化長這樣:

let mut token_to_id = HashMap::with_capacity(tokens.len());
for (i, t) in tokens.iter().enumerate() {
    token_to_id.insert(t.clone(), i as u32);
}

如果 vocab 有重複 token(理論上不應該),後面的會覆蓋前面的。如果想驗證沒有重複:

for (i, t) in tokens.iter().enumerate() {
    use std::collections::hash_map::Entry;
    match token_to_id.entry(t.clone()) {
        Entry::Vacant(v) => { v.insert(i as u32); }
        Entry::Occupied(_) => bail!("duplicate token at {i}"),
    }
}

entry API 是 Rust 處理 hash map 的標準慣用法——白話說就是「我不知道 key 在不在,但想根據它在不在來做不同的事」,一次 lookup 搞定,不必查兩遍。

Rust 用法:Option 的 chain

let bos = get_special(g, "tokenizer.ggml.bos_token_id").unwrap_or(1);

get_special 回傳 Option<u32>,如果 metadata 沒有這個 key 就是 None。unwrap_or(1) 給了一個 default value——Llama 的 BOS token id 通常都是 1,所以這算是個合理的 fallback。

這就是 Rust 對 null 的優雅處理——Option 逼著你當場決定「沒有的時候怎麼辦」,而不是放任 NullPointerException 在 runtime 才給你來個措手不及。

效能最佳化空間

1. encode 的 priority queue 優化

O(n^2) 改成 O(n log n),對長 prompt 有顯著加速。只是實作 priority queue 還要處理 invalidation(合併後相鄰 pair 都變了)就有點麻煩了。想偷懶的話,tokenizers crate(HuggingFace 的 Rust tokenizer)已經有現成的優化版可以抄。

2. 預先建立 pair lookup table

每次 format!("{}{}", a, b) 都得做字串拼接加上 hash lookup。如果改成預先建好一張 HashMap<(u32, u32), u32>

let mut pair_map: HashMap<(u32, u32), u32> = HashMap::new();
for (id, token) in tokens.iter().enumerate() {
    // 嘗試把 token 拆成兩個現有 token 的拼接
    for split in 1..token.len() {
        if let (Some(&a), Some(&b)) = (
            token_to_id.get(&token[..split]),
            token_to_id.get(&token[split..]),
        ) {
            pair_map.insert((a, b), id as u32);
        }
    }
}

這樣 encode 時就只是 (u32, u32) → u32 的查表,完全不必字串拼接。代價是建表本身是 $O(\sum_t |t|)$,得在啟動時先花一點時間,算是拿啟動換執行速度吧。

3. SIMD 字串搜尋(次要)

replace('▁', ' ') 用的是逐字元 scan。長字串理論上可以用 SIMD 加速,不過 tokenizer 處理的字串通常都很短,這個真的沒必要,純粹列出來給大家參考一下。

4. 避免 ids.remove(i+1) 的 O(n)

ids[i] = best_id;
ids.remove(i + 1);    // 每次都是 O(n)

Vec::remove 是 $O(n)$(後面的 element 全部要往前 shift)。理論上更好的資料結構是 linked list(doubly linked),合併只要 $O(1)$。不過 Rust 的 LinkedList 對 cache 很不友善,跑起來反而更慢——這就是經典的「理論最優和實際最優不一樣」的案例呢。

實務上 prompt 長度頂多幾百到幾千 tokens,$O(n^2)$ 配上 cache friendly 的 Vec,反而比 LinkedList 還快。所以別被 big-O 騙了。

5. 批次處理 byte fallback

對長的 unicode 文字,目前每個 char 都會做一次 hashmap lookup。理論上可以先把連續的 byte fallback 序列 group 起來、批次處理。不過收益其實有限,畢竟 hashmap lookup 本身就快得很。

一個我曾經踩過的實際雷

我第一次寫這個 tokenizer 的時候,有個 case 怎麼編碼都不對:英文 prompt 後面接中文。搞了半天才發現,問題是我忘了讓 byte fallback 的 token繼續參與後面的合併迴圈。當時我的邏輯把 byte fallback 當成「終態」——一旦變成 byte token 就不再動它了,但 SentencePiece 的訓練資料裡,其實可能有把連續 byte 合併成更高層 token 的規則啊。

還好現在的實作放對位置了——byte fallback 只是「初始化」,後面的合併迴圈會把它們跟其他 token 一視同仁地一起合併。不過這個雷讓我學到一件事:寫 tokenizer 一定要拿至少幾十個 prompt 去對拍 llama.cpp,光靠一兩個 happy path 測試,遲早出包。

總結:tokenizer.rs 的角色

  • 概念上:把字串和 token id list 互相轉換。
  • 演算法上:SentencePiece 的 greedy merge encode 加上給 OOV 用的 byte fallback。
  • 微妙細節:byte fallback 的存在性檢查、還有 decode 時的 UTF-8 重組。

200 多行的程式,看似簡單,魔鬼卻全藏在 byte fallback 跟 UTF-8 拼合那些細節裡——這大概就是 tokenizer 最有意思的地方吧。下一篇來看 sampler.rs,聊聊拿到 logits 之後到底要怎麼選下一個 token。

系列文章:


tiny-llm-runner 深入解讀 (6):runner.rs —— Forward Pass 與 KV Cache 的編排

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇我們把所有 Transformer 原語都拆過一遍了,這一篇終於要把它們編排成一次完整的 forward pass,主角就是核心檔案 runner.rs

老實說,整個專案我最推薦讀的就是這個檔案——大約 170 行的 Rust,把 Llama 的整個推論流程攤平在你眼前,每一步在做什麼都看得清清楚楚,實在是很過癮。

概念一:autoregressive 推論的本質

LLM 推論不是「一次處理整段文字」,而是「一次處理一個 token」:

prompt: "The capital of France"
        → token ids [464, 3139, 286, 4881]

prefill (處理 prompt):
  forward(464)  → logits     (我們不用,但要建 KV cache)
  forward(3139) → logits
  forward(286)  → logits
  forward(4881) → logits     ← 用這個 logits 抽下一個 token

decode (生成):
  sampler(logits) → " is"
  forward(" is")  → logits
  sampler(logits) → " Paris"
  forward(" Paris") → logits
  ... 一直到 EOS

每次 forward(token) 就是一個 step,step 之間靠 KV cache 一路把上下文累積起來。一個字一個字慢慢吐,model 其實沒有你想像中那麼「整段一起讀」。

概念二:KV Cache —— 為什麼 attention 要 cache 過去?

Attention 的數學是這樣的:

$$\text{attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

在 autoregressive 推論時,第 t 個 token 的 Q 只需要跟前 t 個 token 的 K/V 做運算(因為未來的 token 根本還沒出現嘛)。如果每次 forward 都把所有過去的 K/V 重新算一遍,prefill 階段就是 $O(n^2)$,decode 階段每次又是 $O(n)$ 的重做——這也太浪費了吧。

KV cache 的想法其實很單純:某個 position 的 K/V 算過一次就存起來,下次直接拿來用。這樣每次 forward 只要算當前這個 token 的 Q/K/V 就好,K 和 V 寫進 cache、Q 拿去跟整個 cache 點積。

概念三:Context 和 KV Cache 的關係

很多人一聽到「context length 2048」,直覺就想成「模型有個地方存著最近 2048 個 token」。其實不是這樣喔——模型存的是這 2048 個 token 被各層 attention 算出來的 K 和 V,token 本身早就丟掉了。所以 context 的容量上限,講白了就是 KV cache 的容量上限。

關係可以這樣理解:

context window = 「我能看見多遠的過去」的能力上限
KV cache       = 這個能力實際的儲存形式
n_ctx          = KV cache 在 position 維度上的大小

當你看到 n_ctx = 2048,背後真正配出來的記憶體就是:

kcache: [n_layer][n_ctx × kv_dim] f32
vcache: [n_layer][n_ctx × kv_dim] f32

對 TinyLlama (n_layer=22, kv_dim=256) 來說,這是 22 × 2 × 2048 × 256 × 4 = 92 MB。對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說,是 32 × 2 × 2048 × 4096 × 4 = 2 GB——KV cache 居然比模型本身的量化權重還大,是不是有點出乎意料?

還有一點要記住:token 一旦進了 cache,就再也撈不回原貌了。Context window 滿掉之後,要嘛截斷舊 token、要嘛丟掉新 token,沒有第三條路——因為原始 token id 早就在被 attention 投影成 K/V 的那一刻丟掉了。

概念四:為什麼不存 token 本身?為什麼不存 embedding?為什麼是 K/V?

我覺得把這個問題想清楚,就等於把 KV cache 的設計從頭到尾理解了一遍。我們從「最樸素」的方案開始,一步步往下推:

方案 A:只存 token id

「我把過去的 token id 都存起來,下次推論時整段重跑就好。」

問題在於:這樣每次 forward 都得從 token id 開始,重算 embedding、重跑 22 層 transformer。對第 t 個 token 來說是 $O(t \times \text{layers} \times \text{matmul})$ 的工作,整個 prompt 的 prefill 是 $O(n^2)$,decode 累積下來也是 $O(n^2)$。說穿了,這根本沒在 cache 嘛

方案 B:存 embedding(每個 token 對應一個 n_embd 向量)

「那我存 token embedding,這樣可以跳過 embedding lookup。」

問題是:embedding 只是 forward pass 的入口而已。從 embedding 走到「attention 真正吃的東西」,中間還隔著一整層的 RMSNorm + Q/K/V 投影。每層的 attention 看到的,是「自己這層的輸入經過 W_k 和 W_v 投影後的 K/V」,不是 token embedding。所以如果只存 token embedding:

  • 每次 forward 都要重新跑 22 層的 RMSNorm、重新算 K/V 投影。
  • 但這些工作的結果和 token 位置 t 無關——每次重算都會得到同一個 K_t 和 V_t。
  • 純粹的浪費。

更麻煩的是:第 1 層的輸入是 token embedding,可是第 2 層的輸入是「第 1 層的輸出」。所以你沒辦法拿「token embedding」這一個概念去代表「每一層 attention 需要的東西」。每層之間都隔著 attention + FFN 的非線性轉換,「embedding」這個詞其實只在最開頭那一層才講得通。

方案 C:存每層的 hidden state(每層都有自己的 [n_layer][n_ctx][n_embd]

「那我存每層的輸出 hidden state?」

問題是:你存了 hidden state,但 attention 真正要的是 K = W_k(hidden_state)V = W_v(hidden_state)。每次 forward 還是得重算 K/V 投影(一個 [n_embd, kv_dim] 的 matvec)。等於白存,因為投影那一步根本沒省到

而且這個方案吃的記憶體還比 K/V cache 大:hidden state 是 n_embd 寬,K/V 在 GQA 下只有 kv_dim = head_dim × n_head_kv 寬(TinyLlama 是 256,比 n_embd=2048 整整小了 8 倍)。划不來。

方案 D:直接存 K 和 V(最終答案)

繞了一圈,答案其實一直都在公式裡:K 和 V 才是 attention 真正消費的東西。$\text{attn}(Q, K, V)$ 裡頭沒有 hidden state、沒有 embedding、沒有 token id——就只有 Q、K、V。把過去的 K 和 V 存起來,attention 就齊了

而且 K/V 還有兩個我覺得頗迷人的性質:

  1. 它們和 Q 是解耦的:K_t 和 V_t 一旦算出來就和「未來會用什麼 Q 來查它」無關。所以可以放心 cache。
  2. 它們不需要被未來修改:因為 LLM 是 causal——位置 t 的 K/V 只被 position ≥ t 的 Q 查詢,但這些查詢不會回頭改 K_t/V_t。一旦寫進 cache 就是 immutable 的。

這就是為什麼我會說 KV cache 是 transformer 推論的「最佳化終點」——它剛剛好存下 attention 需要的最少資訊,再少一點就破壞語義,再多一點就是浪費。剛剛好,一點都不多。

從「資訊保留鏈」看這件事

換個角度想想:模型做 forward pass,其實是一條把「token id」逐步轉成「下一個 token 機率分佈」的資訊流。中途會冒出一大堆中間表示:

token id → embedding → layer 1 hidden → layer 1 K/V → ...
                                     → layer 2 hidden → layer 2 K/V → ...
                                     ...
                                     → final hidden → logits

每一個中間表示都裝著這個 token 的某種資訊。但對 attention 來說,這條鏈上唯一一個「跨 token 互動」的點,就是它把 K/V 拿來算內積那一步。其他每一步(RMSNorm、Q/K/V 投影、FFN…)都是 pure-position 的——只管當前這個 token 自己的資料,跟別人無關。

所以:

  • 那些「pure-position」的步驟不需要 cache,因為每個 token 的計算和其他 token 無關。
  • 那個「跨 token 互動」的步驟必須 cache,因為它要看到所有過去 token 的資訊。

而 attention 跨 token 看到的東西,就是 K 和 V。所以該被 cache 的,當然就是它們囉。

一個簡單的比喻

把 LLM 比喻成一個圖書館:

  • Token id 是書的索書號。
  • Embedding 是書的封面(提供入口)。
  • 每層 hidden state 是書頁的內容(給讀者看)。
  • K 和 V 是給其他書「互相引用」用的索引卡。

Attention 就是書與書之間的對話。書本身(hidden state)很厚,但對話只透過索引卡(K/V)來進行。把用過的卡片留在桌上,下次新書一進來就能直接跟它們對話——不必把整本書重新搬出來翻一遍。我自己覺得這個比喻還滿傳神的。

概念五:為什麼 KV cache 必須是 per-layer,而不只是 per-token?

注意 Runner 的 cache 是 Vec<Vec<f32>>第一個維度是 layer,第二個才是 token position

kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,

對 22 層的 TinyLlama 來說,這就是 22 份獨立的 cache。為什麼不能大家共用一份?為什麼一個 token 不是對應一組 K/V,而是硬要對應 n_layer 組?

因為「同一個 token」在每一層的 K/V 是完全不同的東西

關鍵就藏在 forward pass 的結構裡。我們回頭看看每一層 attention 到底在算什麼:

layer 1: x¹  = embedding(token)
         k¹  = W_k¹(rmsnorm(x¹))     ← 第 1 層的 K
         v¹  = W_v¹(rmsnorm(x¹))     ← 第 1 層的 V
         x²  = x¹ + attn(...) + ffn(...)

layer 2: k²  = W_k²(rmsnorm(x²))     ← 第 2 層的 K(input 不同、權重不同)
         v²  = W_v²(rmsnorm(x²))
         x³  = x² + attn(...) + ffn(...)

layer 3: k³  = W_k³(rmsnorm(x³))     ← 第 3 層的 K
         ...

每一層的 K/V 同時受兩件事決定:

  1. 這一層的 input hidden state —— 上一層 attention + FFN 完成後傳下來的東西,每層都不同。
  2. 這一層自己的權重 W_k^l / W_v^l —— 每層獨立訓練的不同矩陣。

換句話說,layer 1 的 K 是「token 剛被 embed 時的樣子,被 W_k¹ 看出來的特徵」;layer 5 的 K 則是「token 經過 4 層 attention + FFN 整合上下文之後的樣子,被 W_k⁵ 看出來的特徵」。這兩個 K 既不是同一個東西,也沒辦法互相代打

換個角度:每一層都有自己的 attention

Transformer 的設計本身就把每層 attention 當成獨立的計算單元。每層都在問「我這層的 query 跟我這層 cache 裡的 key 像不像」,答案決定了我這層怎麼把 V 加總起來。底下幾層通常學「文法、近距離 token 關係」,中層學「短語結構」,高層學「語意、長距離依賴」——這些功能,只有在它們各自的 K/V 空間裡才講得通。

所以你要是硬說「我只想存一份 K/V 給所有層共用」,那其實就等於在說「所有層都做同一件事」——這不就退化成一個超淺的模型了嗎?Transformer 的深度價值,正是每層做不同的轉換,而每層的 K/V 就是這個轉換的 fingerprint。

從層之間的依賴鏈看為什麼不能省

講得更技術一點:layer l 的 K/V 是 layer l-1 輸出的函式。整個 forward pass 其實是這樣一路串下來的:

        x¹              x²              x³
embed → ─→ attn¹+ffn¹ ─→ ─→ attn²+ffn² ─→ ─→ ...
        │                │                │
        └─→ k¹, v¹       └─→ k², v²       └─→ k³, v³

每一層的 K/V 都「只能在自己這層用」——下一層的 attention 才不會去看 k¹,因為它要看的是經過這一層 transform 過的世界。所以這 22 份 K/V,就是 22 塊獨立的記憶體,沒有什麼一份共用這回事。

一個 token 在 cache 裡到底佔多少?

把上面這些事全部串起來,一個 token 進到 cache 之後實際佔的記憶體就是:

single token → n_layer × 2 × kv_dim × sizeof(f32)
              ↑          ↑   ↑
              層數       K + V

對 TinyLlama (n_layer=22, kv_dim=256, f32) 來說:

22 × 2 × 256 × 4 = 45 KB per token

對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說:

32 × 2 × 4096 × 4 = 1 MB per token

這就是為什麼長 context 推論這麼燒記憶體——它不是 token 數乘上一個小數字而已,而是 token 數乘上 n_layer × 2 × kv_dim。一個 8K context 的 7B 模型,光 KV cache 就要 8 GB,比模型本身還大,想想真是有點誇張。

要是 KV cache 真的能「per-token only」(所有層共用),同樣的 8K context 只要 32 MB 就夠了——只是那樣的東西早就不是 Transformer 了。這個「per-layer」的代價,說到底就是 Transformer 之所以是 Transformer 的代價

「per-layer × per-token」 是 cache 的最小完備形式

我想可以這樣收尾:要讓 attention 在每一層都能正確算出來,你需要的最少資訊就是:

維度 為什麼需要
per-layer 每層 attention 看的是不同的轉換空間,不能共用
per-token 每個 token 在每層的特徵都不同(causal mask 之外都會被未來查到)
K + V attention 公式裡的兩個輸入(Q 不需要 cache,因為它只在當前 step 用一次)

把這三個維度切片乘起來,就是 [n_layer][n_ctx][2][kv_dim] 的 4D 張量——這就是 KV cache 的最小完備形狀。少存任何一個維度都會破壞 attention 的語義;多存任何一個維度都是純浪費(反正其他資訊都能重新算回來)。

好,理論講夠了,回到我的 Rust 程式碼吧:

kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,   // [n_layer][n_ctx × kv_dim]

外層 Vec 是 layer 維度,內層那個扁平陣列裡的 n_ctx × kv_dim 就是「token position × K/V 寬度」。這兩個 Vec<Vec<f32>> 加起來,剛好就是 attention 完整需要的最小狀態,一塊不多、一塊不少。

Runner struct 的記憶體佈局

pub struct Runner<'a, 'm> {
    model: &'m LlamaModel<'a>,
    rope_style: RopeStyle,
    kcache: Vec<Vec<f32>>,   // [n_layer][n_ctx * kv_dim]
    vcache: Vec<Vec<f32>>,
    pub pos: usize,

    // 預先配好的 scratch buffer
    x: Vec<f32>, xb: Vec<f32>, xb2: Vec<f32>,
    hb: Vec<f32>, hb2: Vec<f32>,
    q: Vec<f32>, att: Vec<f32>,
    logits: Vec<f32>,
}

幾個重要的設計決定:

兩個生命週期 'a'm

  • 'a 是 mmap 的生命週期。
  • 'mLlamaModel 的生命週期,且必須 'm: 'a(model 不能比 mmap 活得更久)。

Runner 借用 &'m LlamaModel<'a>,自己一份模型資料都沒持有,所以建構這個結構超便宜——裡頭所有大型 buffer 都只是 scratch 用途而已。

KV cache 是 Vec<Vec<f32>> 而不是 Vec<f32>

每層一個獨立的 Vec,而不是把所有層拼在同一個 Vec 裡。這是為了:

  1. 生命週期分離:不同層的 cache 可以獨立管理(雖然目前我沒這麼做)。
  2. 記憶體大小靈活:理論上不同層可以有不同的 KV dim(例如某些 MoE 變體),雖然 Llama 沒有這個需求。

代價就是多一層間接(access 得走 kcache[l][...])。不過對 LLM 推論來說,這點開銷根本可以無視啦。

一堆 scratch buffer 一次配好

x: vec![0.0; n_embd],     // 主 hidden state
xb: vec![0.0; n_embd],    // 暫存 a (attention/FFN 內部)
xb2: vec![0.0; n_embd],   // 暫存 b (殘差用)
hb: vec![0.0; n_ff],      // FFN gate 的中間結果
hb2: vec![0.0; n_ff],     // FFN up 的中間結果
q: vec![0.0; n_embd],     // 當前 token 的 Q
att: vec![0.0; n_head * n_ctx],  // attention scores
logits: vec![0.0; vocab],  // 輸出 logits

這些 buffer 全在 new() 裡一次配好,整個推論過程就再也不 allocate 了。每次 forward() 進來,都是直接覆寫這些 buffer。沒有 GC、沒有 alloc 壓力,乾淨俐落。

順便對比一下 PyTorch/Candle 的做法:每個 op 回傳一個新 Tensor,靠 reference counting 回收。這對訓練來說很合理(autograd 本來就得保留中間值),但對只做 inference 的 runner 而言,預配 + 覆寫顯然更精簡、也更可預測。

演算法核心:forward 的整體結構

forward(token) 的結構是這樣的:

flowchart TD A[token id] --> B[1. Embed → x] B --> C{2. 22 層 Transformer Block} C --> C1[2a. attn_norm: xb = norm·x] C1 --> C2[2b. Q/K/V projection
K/V 直接寫進 cache] C2 --> C3[2c. RoPE on Q & current K] C3 --> C4[2d. multi-head attention
Q · K^T / softmax / · V] C4 --> C5[2e. wo projection + 殘差] C5 --> C6[2f. ffn_norm + SwiGLU + 殘差] C6 --> C C --> D[3. final norm] D --> E[4. lm_head: logits]

把每個 block 內部攤開來看,會更有感覺:

// attention norm
rmsnorm(&mut self.xb, &self.x, &layer.attn_norm, cfg.rms_eps);

// qkv projections
matvec(&mut self.q, &layer.wq, &self.xb);
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);    // K 直接寫進 cache
matvec(vrow, &layer.wv, &self.xb);    // V 也是

// RoPE
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);

// attention
for h in 0..cfg.n_head {
    let kv_head = h / gqa;          // GQA 對應
    // 算 attention scores
    // softmax
    // 加權 V
}

// 輸出投影 + 殘差
matvec(&mut self.xb2, &layer.wo, &self.xb);
add_inplace(&mut self.x, &self.xb2);

// FFN: x = x + Wdown(silu(Wgate(norm)) * Wup(norm))
rmsnorm(&mut self.xb, &self.x, &layer.ffn_norm, cfg.rms_eps);
matvec(&mut self.hb, &layer.w_gate, &self.xb);
matvec(&mut self.hb2, &layer.w_up, &self.xb);
for i in 0..self.hb.len() {
    self.hb[i] = silu(self.hb[i]) * self.hb2[i];
}
matvec(&mut self.xb2, &layer.w_down, &self.hb);
add_inplace(&mut self.x, &self.xb2);

演算法核心一:K/V 直接寫進 cache

let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);
matvec(vrow, &layer.wv, &self.xb);

這短短四行藏著一個我蠻喜歡的設計選擇:K/V 計算的輸出 buffer 直接就是 cache 的某一行,而不是先算到 scratch buffer、再 copy 進 cache。這就省下了一次 n_embd × 4 bytes 的記憶體拷貝。

matvec 的簽章 fn matvec(out: &mut [f32], ...) 接受任何 mutable slice,cache 的 slice 自然也塞得進去。而 Rust 的借用檢查器會幫你確認 kcxbq 那些 buffer 不會 alias——這種事不用自己提心吊膽,編譯器顧得很周到。

演算法核心二:RoPE 的兩階段套用

apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);

注意喔,這裡 RoPE 是套在 Q 和當前的 K 上,不是套在 cache 裡所有 K 上。為什麼這樣就夠了?

因為 RoPE 帶的是「絕對位置」資訊(每個 K_t 套用 t 對應的 RoPE),而每個 K_t 只要在它被算出來的那次 forward call 裡套一次就完事了。等到後面要做 attention,cache 裡的 K_t 早就帶著 RoPE 了,根本不用再套一遍。

至於 V 為什麼不用 RoPE?因為 attention 對 V 的處理是線性 weighted sum,position 資訊在 Q·K 那一步就已經注入進去了,V 這邊不需要再湊一腳。

演算法核心三:GQA 的 head 對應

for h in 0..cfg.n_head {
    let kv_head = h / gqa;          // GQA 對應
    let q_off = h * head_dim;
    let q = &self.q[q_off..q_off + head_dim];

    let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];
    for (t, score) in att.iter_mut().enumerate() {
        let k_off = t * kv_dim + kv_head * head_dim;
        let k = &self.kcache[l][k_off..k_off + head_dim];
        let mut s = 0.0f32;
        for i in 0..head_dim {
            s += q[i] * k[i];
        }
        *score = s * scale;
    }
    softmax(att);
    // ... 加權 V
}

GQA 的小聰明:在 vanilla MHA(multi-head attention)裡,每個 query head 都配一個獨立的 K/V head。但 GQA 把 query head 分組,讓每組共用同一個 K/V head:

n_head = 32, n_head_kv = 4 → gqa = 8
query head 0..7   → K/V head 0
query head 8..15  → K/V head 1
query head 16..23 → K/V head 2
query head 24..31 → K/V head 3

kv_head = h / gqa 就是在做這個對應。這招把 KV cache 的大小從 n_head × head_dim × n_ctx 一口氣縮成 n_head_kv × head_dim × n_ctx,對長 context 推論實在是太關鍵了——畢竟前面也講過,KV cache 才是長 context 推論真正的記憶體大胃王。

演算法核心四:attention scores 的記憶體佈局

let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];

att buffer 是 [n_head, n_ctx] 展平後的結果。head h 的 attention scores 佔 att[h*n_ctx .. h*n_ctx + n_ctx],但每次 forward 我們其實只用得到前 pos+1 個(因為只有從過去到現在的 token 才有 score 嘛)。

這樣設計的好處是:buffer 固定大小(一開始就照 worst case n_ctx 配好),不必動態 resize。代價就是有些空間會閒著沒用到——對 TinyLlama 來說是 n_head * n_ctx * 4 = 32 * 2048 * 4 = 256 KB,這點浪費我覺得完全可以接受。

Rust 用法:借用檢查器與切片

這個 forward pass 裡有大量 &mut 切片操作:

let kc = &mut self.kcache[l];
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);

注意我這裡得先把 &mut self.kcache[l] 抓進局部變數 kc,再去對 kc 切片。為什麼要這樣?因為 Rust 的借用檢查器希望看到「我們對 self.kcache 的 mutable 借用」是乾乾淨淨、一目了然的。

這大概是 Rust 寫多 mutable buffer 程式碼最常踩的坑之一了。假如你偷懶寫成:

matvec(&mut self.kcache[l][pos*kv_dim..(pos+1)*kv_dim], &layer.wk, &self.xb);

而附近又還有 &self.x&self.xb 這類不可變借用,編譯器十之八九會跳出來抗議「不能同時 mutable 又 immutable 借用 self」。先把 mutable slice 抽出來、再做後續操作,就是縮小借用範圍的標準解法。踩過幾次就記住了。

split_at_mut 這種招式

更刁鑽的場合可能就得搬出 split_at_mut 了——比方說要同時拿到 K cache 和 V cache 的 mutable ref:

let (kc, vc) = (&mut self.kcache[l], &mut self.vcache[l]);
// 這個寫法借用檢查器會接受,因為 self.kcache 和 self.vcache 是不同的 field

kcache[l]vcache[l] 是不同的 field,所以可以同時 mutable borrow。如果是同一個 Vec 的兩段切片,就需要 split_at_mut

let (left, right) = my_vec.split_at_mut(mid);
// left 和 right 都是 &mut [T],但編譯器知道它們不重疊

效能最佳化空間

runner.rs 可以說是整個專案效能熱點的大本營,能再壓榨的空間還多得很:

1. Attention 的 head-parallel

我目前的 attention loop 是 sequential for h in 0..n_head

for h in 0..cfg.n_head {
    // 算 attention scores、softmax、加權 V → 寫進 self.xb 的某段
}

每個 head 寫進 self.xb 的不同 slice,是 disjoint 的。可以用 rayonchunks_exact_mut 平行:

self.xb.par_chunks_exact_mut(head_dim).enumerate().for_each(|(h, out)| {
    // 算這個 head 的 attention
});

不過 att buffer 是共享的(self.att),平行版得替每個 thread 準備獨立的 attention scratch。再加上 head 數量通常不大(32-64),fork-join 的 overhead 搞不好就把省下來的時間吃光了——這個值得 benchmark 一下,但我猜不一定划算。

2. FlashAttention 風格的 fused attention

我現在的 attention 是「先算完所有 score、再 softmax、再加權 V」三步走。FlashAttention 則把這三步 fuse 成一趟 streaming pass,記憶體占用從 O(n_ctx) 降到 O(head_dim),而且對 cache 超友善。只是實作複雜不少——這已經是 candle/ggml 那種等級的優化了,不是隨手就能塞進來的。

3. KV cache 的量化

KV cache 是長 context 的記憶體大頭,這點講到都快變口頭禪了。對 7B 模型、2k context、F32 cache 來說,每層 cache 是 2 * 2048 * 1024 * 4 = 16 MB,32 層就是 512 MB。要是把 cache 也量化成 Q8(1 byte/element),就能砍到 128 MB。llama.cpp 早就支援這招了(-ctk q8_0 -ctv q8_0)。

4. Prefill batching

目前 prefill 是 sequential 的——一個 token 跑一次 forward。但 prompt 的 token 其實全都已知,大可以拼成一個矩陣做 batched forward:

seq forward: O(L * (matvec for each layer))
batch forward: O(matmul of [L, n_embd] for each layer)

GEMM 的 cache locality 比 GEMV 好太多了,prefill 速度衝個 5-10× 不是夢。代價是 attention 要重新設計(causal mask 從可有可無變成非寫不可)、KV cache 的寫入方式也得改(一次寫 L 個 token)。算是中量級的工程,不算小但也沒有嚇人。

5. Speculative decoding

更進階的玩法:拿一個小小的 “draft model” 一次先猜 K 個 token,再用主模型一口氣 verify。要是 K 個全猜中,那等於只花一次 forward 的時間就吐出 K 個 token,賺翻了。這在 inference latency 上算是革命性的技術,不過代價是要兩個模型協同,沒那麼好伺候。

6. Continuous batching

如果哪天 runner 要同時 serve 一堆請求,就可以做 continuous batching——不同請求的 token 塞進同一個 batch,還能動態加進來、抽出去。但這要求 KV cache 變成「per-request」的,記憶體管理瞬間複雜好幾級。vLLM 就是這條路線的代表作。

一個有趣的權衡:為什麼不寫得更 functional?

其實我大可以把 forward 寫得更 functional 一點,比方把每個 layer 包成一個閉包、再用 fold 串起來,看起來會很「漂亮」。但說真的,現在這個命令式的寫法反而更好讀——你一眼就能看出每一步在改哪個 buffer、用哪個權重,毫不囉嗦。

LLM 推論的程式碼有個很特別的地方:90% 的時間都在做矩陣運算,剩下 10% 才是 control flow。寫得太 functional,反而會把 control flow 的成本擺到台前,把真正重要的計算給遮住了。我不建議盲目追求函式式的優雅,該樸實的時候就樸實吧。

總結:runner.rs 的角色

  • 概念上:它就是把 token id 映射到 logits 的那個單一函式,順手把 KV cache 和層層遞進都管起來。
  • 實作上:fixed-size scratch buffers + sequential layer loop + 手工 KV slice 操作,沒什麼花俏的。
  • 設計上:把所有複雜性都收進這一個檔案裡,讓其他檔案(opstensor)保持乾淨——我覺得這種「把髒活集中在一處」的取捨頗值得參考。

寫到這裡,整個推論流程的骨架其實已經攤在眼前了。下一篇我們往前再退一步,聊聊 tokenizer.rs——看看 byte 流是怎麼變成 token id 的,那又是另一個有趣的小宇宙。

系列文章:


tiny-llm-runner 深入解讀 (5):ops.rs —— Transformer 的六種樂高積木

featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 LlamaModel 怎麼把張量組起來,這一篇要看的是 Transformer 真正的「動詞」們:ops.rs

這個檔案實在是小,只有 100 行,但它就是整個 Transformer 的所有原語。我想最妙的地方就在這裡:一個現代 LLM 動輒幾十億參數,所有的計算到頭來都會被拆解成這六種運算的組合。就這麼六塊樂高,拼出整個 Transformer。

樂高積木一:RMSNorm

pub fn rmsnorm(out: &mut [f32], x: &[f32], w: &[f32], eps: f32) {
    let mut ss = 0.0f64;
    for &v in x {
        ss += v as f64 * v as f64;
    }
    ss /= x.len() as f64;
    ss += eps as f64;
    let scale = (1.0 / ss.sqrt()) as f32;
    for i in 0..x.len() {
        out[i] = w[i] * (x[i] * scale);
    }
}

算法

$$\text{out}_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}}$$

說穿了就是「先把 x 除以它自己的 RMS,再用 w 逐元素 scale」。RMSNorm 是 LayerNorm 的簡化版(沒有 mean-shift、沒有 bias),Llama 拿它取代 LayerNorm,是 2019-2020 年那波「能省就省」的設計趨勢——少算一點、效果卻不太掉,何樂而不為呢?

為什麼用 f64 累加?

注意 ssf64,不是 f32。這可不是手滑寫錯:累加一大堆 f32 值很容易踩到「精度吞噬」(catastrophic cancellation)的雷。一個 4096 維的 hidden state,每個 f32 大約 1e-1 量級,平方後是 1e-2,4096 個累加起來大概 40-100。在這個量級下,每個新加進來的 1e-2 增量在 f32 上會有相對誤差 ~1e-7 × 100 = 1e-5,4096 次累積下來就會放大。

llama.cpp 在這條路徑用的是 f64 累加器,所以我也得跟著對齊——不然 RMSNorm 的輸出會跟它差個千分之一。你可能會想,才千分之一有差嗎?有喔,每層只差千分之一沒錯,但 22 層累積下來,輸出 token 就整個不一樣了。這又是一個「對齊既有實作」的例子

樂高積木二:matvec —— 用 rayon 平行化

pub fn matvec(out: &mut [f32], w: &TensorView<'_>, x: &[f32]) {
    out.par_iter_mut().enumerate().for_each(|(i, o)| {
        *o = w.dot_row(i, x);
    });
}

這是整個專案的效能瓶頸所在,偏偏又是最簡潔的一段程式碼——耗時最久的,往往長得最無辜。來拆解一下這幾行到底做了什麼:

  1. par_iter_mut() 來自 rayon,把 &mut [f32] 變成一個平行 iterator。
  2. enumerate() 加上 row index。
  3. for_each(|(i, o)| ...) 對每個 row 平行執行 closure。

每個 row 都是獨立的(out[i] 只被當前 closure 寫入),所以完全不需要任何鎖rayon 的 work-stealing 排程器會自動把 row 切成 chunk 分給可用的執行緒,這部分我們什麼都不用操心,頗省事。

為什麼 row-parallel 而不是 element-parallel?

如果改成 element-parallel(每個輸出元素一個 task),task 切太細,省下來的時間反而會被 rayon 的 task overhead 吃光。row-parallel 的 granularity 就剛剛好——每個 task 是一個完整的 dot product(4096 個乘加),夠粗,粗到值得啟動執行緒;又夠細,細到能平均分攤負載。

這裡 Rust 的 par_iter_mut 幫你擋掉了一個很容易踩的雷:你不能讓兩個 thread 同時寫 out 的同一個位置。而 par_iter_mut 的型別簽章保證每個 closure 拿到的是 &mut f32(單一元素的 mutable ref),編譯期就把 race condition 排除掉了。這種事在 C++ 裡得靠自己小心,在 Rust 裡編譯器直接幫你顧好,真是頗安心。

樂高積木三:softmax —— max-shift 防止 overflow

pub fn softmax(x: &mut [f32]) {
    let mut max = f32::NEG_INFINITY;
    for &v in x.iter() {
        if v > max { max = v; }
    }
    let mut sum = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    let inv = 1.0 / sum;
    for v in x.iter_mut() {
        *v *= inv;
    }
}

數學上,softmax 是:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

e^{x} 對大 x 會 overflow(f32 的 exp 在 x > ~88 就溢位)。標準做法是先減 max:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

數學上完全等價(分子分母同乘 $e^{-\max(x)}$),但這麼一減,所有 exp 的輸入都落在 $(-\infty, 0]$,永遠不會 overflow。這一步是寫 softmax 絕對不能省的——少了它,哪天遇到極端的 logits,輸出就是一片 NaN,到時候才在那邊抓半天,何必呢?

樂高積木四:RoPE —— 相對位置的旋轉

六塊樂高裡,這塊大概是最微妙的一個。RoPE(Rotary Position Embedding)是 Llama 用來把位置資訊塞進 attention 的方法,數學上是把每個 head 的 Q/K 向量看成一堆「複數對」,再旋轉一個和位置相關的角度。聽起來很玄,但程式碼其實沒幾行。

pub fn apply_rope(
    vec: &mut [f32], pos: usize, head_dim: usize,
    rot_dim: usize, base: f32, style: RopeStyle,
) {
    let half = rot_dim / 2;
    let n_heads = vec.len() / head_dim;
    for h in 0..n_heads {
        let off = h * head_dim;
        for i in 0..half {
            let freq = base.powf(-((2 * i) as f32) / rot_dim as f32);
            let theta = pos as f32 * freq;
            let (sin, cos) = theta.sin_cos();
            let (i0, i1) = match style {
                RopeStyle::Llama => (2 * i, 2 * i + 1),
                RopeStyle::Neox  => (i, i + half),
            };
            let v0 = vec[off + i0];
            let v1 = vec[off + i1];
            vec[off + i0] = v0 * cos - v1 * sin;
            vec[off + i1] = v0 * sin + v1 * cos;
        }
    }
}

核心想法:每對 (v[i0], v[i1]) 是一個 2D 向量,被旋轉一個角度 $\theta = \text{pos} \cdot \text{freq}_i$。不同的 i 用不同的 freq——i 越小 freq 越大(位置感越強)、i 越大 freq 越小(接近不變)。

頻率公式:

$$\text{freq}_i = \text{base}^{-2i / \text{rot\_dim}}$$

base 通常是 10000,rot_dim 是 head_dim(或某個比例)。i = 0 時 freq = 1(每移動一個 position 旋轉 1 弧度),i = rot_dim/2 - 1 時 freq ≈ 1/10000(要 6000 個 position 才旋轉 1 弧度)。

這種多尺度的旋轉設計實在是巧妙,讓 attention 既能感知近距離的細微位置差異、也能掌握遠距離的大致位置,可說是一兼二顧。

為什麼有兩種 style?

let (i0, i1) = match style {
    RopeStyle::Llama => (2 * i, 2 * i + 1),         // 鄰接成對
    RopeStyle::Neox  => (i, i + half),              // 上下半成對
};

這兩個是完全不同的「複數對」配對方式

  • Llama:把 v[0], v[1] 當一對、v[2], v[3] 當一對……(鄰接 pair)
  • Neox:把 v[0], v[half] 當一對、v[1], v[half+1] 當一對……(上下 split)

兩者旋轉的數學其實一模一樣,差的只是配對方式。那為什麼要兩種並存呢?這就是歷史共業了。

當初 HuggingFace 的 transformers 採用 NeoX 風格。但 llama.cpp 早期的 convert.py 在轉檔時會把 Q/K 矩陣的每對 row permute 成鄰接格式,這樣就能改用 Llama 風格做 RoPE,記憶體 access 更線性。

問題是新版的 convert_hf_to_gguf.py 不做 permute 了,所以你得乖乖用 NeoX 風格。搞錯了不會 crash,但模型會開始講胡話——這個雷我可是親自踩過的,debug 半天才發現是配對配反了,欲哭無淚。

順帶一提,sin_cos() 用的是 f32::sin_cos()——同時算 sin 和 cos,硬體上比分開算快一倍,這種免費的便宜當然要佔。

樂高積木五:SiLU —— SwiGLU 的活化函數

#[inline]
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

數學定義:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

$\sigma$ 是 sigmoid。SiLU 也叫 Swish。Llama 的 FFN 用的是 SwiGLU(Swish-Gated Linear Unit):

$$\text{FFN}(x) = W_{down}(\text{SiLU}(W_{gate}(x)) \odot W_{up}(x))$$

⊙ 是逐元素相乘。白話講就是:先把 x 過 gate 和 up 兩個投影,gate 那條過完 SiLU 後逐元素乘上 up,再一起過 down 投影。比起傳統的 ReLU FFN,SwiGLU 多了一個 gate 投影,乍看是多花了計算,但實驗顯示效果好得多——多算這一點,值得。

#[inline] 是給編譯器的小提示。這個函式會在 inner loop 被呼叫成千上萬次,inline 進去能省掉函式呼叫的 overhead,積少成多嘛。

樂高積木六:add_inplace —— 殘差連接

pub fn add_inplace(a: &mut [f32], b: &[f32]) {
    for (x, y) in a.iter_mut().zip(b.iter()) {
        *x += *y;
    }
}

六塊樂高裡最樸素的一塊:把 b 加進 a,沒了。別小看它,Transformer 的每個 sublayer 都靠它:

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

那兩個不起眼的加號,就是 add_inplace。殘差連接這麼重要的東西,實作起來竟然就一行 for 迴圈,想想還挺有意思的。

這裡我刻意用 iter_mut().zip(iter()) 而不是用 index。這個寫法的好處有三個:

  1. 沒有越界檢查——iterator 自動知道何時停止。
  2. 編譯器更容易自動向量化——iterator 模式對 LLVM 友善。
  3. 零成本抽象——這個寫法和手寫 raw loop 編譯出來幾乎一樣的組合語言。

Rust 用法:#[derive] 的小幫手

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeStyle {
    Llama,
    Neox,
}

Copy 讓我可以把 RopeStyle 當值傳遞,不用一直 && 去。Eq 讓它可以在 match 裡比較。Debugeprintln!("{:?}", style) 能直接印出來除錯。敲不到一秒的字,省下 10 行 boilerplate——這就是 Rust derive 機制最日常、也最香的用法。

效能最佳化空間

ops.rs 是另一個重要的最佳化戰場。這六個函式,個個都還有改進空間,來盤點一下:

1. RMSNorm 的 SIMD

那個 for &v in x { ss += v * v; } 是個 reduction,可以用 SIMD 平行化:

use std::simd::f32x8;
let mut acc = f32x8::splat(0.0);
for chunk in x.chunks_exact(8) {
    let v = f32x8::from_slice(chunk);
    acc += v * v;
}
let ss = acc.reduce_sum() as f64;

不過 f64 vs f32 累加器的精度問題,在 SIMD 環境下就變得有點微妙了——AVX-512 有 f64 SIMD,AVX2 卻沒有。我想最務實的做法還是維持 f32 SIMD reduction,實作簡單、誤差又在可接受範圍內(畢竟 4096 個 ~1e-2 量級的數,誤差還不至於大到讓 token 跑掉)。

2. matvec 的 GEMV → GEMM 升級

目前我的 matvec 是 GEMV(matrix × vector),每次 forward pass 對每層都得做兩次(attention 的 Q + FFN 的 down)。但prefill 階段的 prompt 是一整串 token,明明就可以批次處理——把 N 個 token 的 hidden state 拼成 [seq_len, n_embd] 矩陣,一次 GEMM 解決。

GEMM 比 GEMV 快上不少(典型 5-10×),關鍵在它能 reuse 權重(load 一次 row 的權重,就能對 seq_len 個輸入一起算)。只是這得回頭重新設計 Runner——現在是單 token forward,要改成 batch forward 才行,這筆帳之後再算吧。

3. softmax 的 SIMD exp

(*v - max).exp() 是逐元素的,理論上可以 SIMD 化。不過 SIMD exp 麻煩了點——得用多項式逼近(典型是 Cephes 或 Padé approximant 的 SIMD 版本)。懶得自己刻的話,sleef crate 有現成的 SIMD math 函式可以接上,不必重造輪子。

4. RoPE 的 sin/cos 預計算

每次 forward pass 都呼叫 theta.sin_cos()。但 freq_i 只和 i 有關,theta = pos * freq_i 也可以分解。如果預計算所有可能的 (pos, i) 的 sin/cos 並存成 table,runtime 只需要查表:

let cos_table: Vec<Vec<f32>> = (0..n_ctx).map(|pos| {
    (0..rot_dim/2).map(|i| {
        let freq = base.powf(-((2*i) as f32) / rot_dim as f32);
        (pos as f32 * freq).cos()
    }).collect()
}).collect();

以 TinyLlama 來說,這張表是 n_ctx × rot_dim/2 × 4 bytes = 2048 × 32 × 4 = 256 KB,啟動時花個一次性的時間算好,之後 runtime 通通查表就好。256 KB,L2 cache 就裝得下,划算。

5. add_inplace 也可以 SIMD

雖然只是個簡單的逐元素加,但 LLVM 對 iter_mut().zip() 模式的 auto-vectorization 不一定每次都會發生——它高興才幫你做。想穩一點,就明確寫 f32x8::from_slice(...) + f32x8::from_slice(...),保證 SIMD 化。

6. 把 forward pass 內整層 fuse 起來

最大的最佳化潛力,其實是 kernel fusion。比方說 attention + softmax 可以 fuse(FlashAttention 走的就是這條路);rmsnorm + matvec 也能 fuse 成一個 pass。不過 fusion 會把程式碼搞得超複雜,跟 tiny-llm-runner 主打的「易讀」目標整個背道而馳——所以這一塊就留給 candle/ggml 那些大人去玩吧,不是我這裡該碰的。

一個 Rust 特有的優勢:debug 與 release 的 debug_assert

我每個函式裡都塞了 debug_assert_eq!(out.len(), x.len()) 之類的檢查。妙的是 Rust 的 debug_assert! 在 release 會被完全消除——dev 時幫你抓 bug,上線後一毛 runtime 成本都不花。對 hot path 來說,這實在是個好用的工具。

C/C++ 當然也有 assert + NDEBUG,但 Rust 把這個好習慣直接做成語言預設——你不必記得去 #define NDEBUGcargo build --release 就自動幫你處理好。一個小細節,卻看得出語言設計的用心。

總結:ops.rs 的角色

  • 概念上:Transformer 的所有原語都在這(norm、matmul、softmax、RoPE、activation、add)。
  • 實作上:純函式、輸出參數、明確簽章、零隱藏狀態,看了就懂。
  • 設計上:把派發(量化型別 match)藏進 TensorView::dot_row 裡,ops.rs 自己只看得到 &[f32],乾乾淨淨。

回頭看,這六塊樂高每一塊都簡單到有點不可思議,但拼起來就是當今最強的語言模型。我想這正是 Transformer 最迷人的地方——複雜的不是零件,而是組合。Simple bricks, complex castles.

下一篇來講 runner.rs,那是把所有 ops 編排起來、發號施令的指揮中心。

系列文章: