featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

歷經八篇深入解讀,我們終於來到 tiny-llm-runner 的最後一塊——main.rs。前面拆了那麼多零件,總得有人把它們兜起來吧?這個檔案就 130 行,是把所有元件串起來的「指揮台」。

這也是整個系列的最後一篇了。除了講 main.rs,我想在最後對整個專案做一個全局的效能最佳化清單,把每個檔案散落的優化點串起來看——算是給這趟旅程一個交代。

概念一:CLI 參數設計

#[derive(Parser, Debug)]
#[command(version, about = "Pure-Rust llama-architecture inference over a GGUF model")]
struct Args {
    #[arg(short, long)]                              model: PathBuf,
    #[arg(short, long, default_value = "Once upon a time")] prompt: String,
    #[arg(short, long, default_value_t = 64)]        n_predict: usize,
    #[arg(short, long, default_value_t = 0.8)]       temperature: f32,
    #[arg(long, default_value_t = 40)]               top_k: usize,
    #[arg(long, default_value_t = 42)]               seed: u64,
    #[arg(long)]                                     no_bos: bool,
    #[arg(long, default_value = "llama")]            rope: String,
}

clap 的 derive macro 把 CLI parsing 變成宣告式:每個欄位加上 #[arg(...)] 就自動產生 --model--prompt 之類的 flag。這比手寫 argument parser 短得多,而且 --help、type validation、default value 全都免費送你,實在是頗划算。

clap 的 zero-cost abstraction

clap 的 macro 在 compile-time 就生成好 parsing code,runtime 沒有任何 reflection。也就是說啟動時 Args::parse() 是純 native code,速度比 Python 的 argparse 快上好幾個量級。

對 LLM runner 來說 CLI 啟動時間其實不是什麼大問題,不過這個習慣很 Rust——把 metadata 處理推到 compile time,runtime 只留下純粹的計算。看多了你會發現整個語言都在貫徹這件事。

概念二:unsafe Mmap

let file = File::open(&args.model)?;
let mmap = unsafe { Mmap::map(&file)? };

整個專案唯一一個 unsafe,就這麼一行。為什麼 mmap 非得 unsafe 不可?

因為 mmap 違反了 Rust 的記憶體模型假設:Rust 假設一個 &[u8] 的內容在它的生命週期內不會被外部修改。但 mmap 對應的檔案如果被另一個 process 改掉(甚至 truncate),這個 &[u8] 就會看到變了樣的資料、最慘還會吃到 SIGBUS。

unsafe 說穿了就是程式設計師對編譯器的一句承諾:「我知道這違反一般規則,使用過程中檔案不會被外部動到,我自己負責」。對 LLM 模型檔來說這承諾其實很好守——模型檔通常就是 read-only 的,誰會去動它呢。

這也呼應了 Rust 的一個設計哲學:unsafe 不是禁忌,而是被精準框定的工具。整個 codebase 只有這一行 unsafe,但它被框得清清楚楚——出了問題,責任就在這一行,跑不掉。

演算法核心:Prefill / Decode 二段式

// 1. Prefill —— 處理 prompt
let prefill_start = Instant::now();
let mut last_logits: Option<Vec<f32>> = None;
for &tok in &prompt_ids {
    let logits = runner.forward(tok);
    last_logits = Some(logits.to_vec());
}
let prefill_elapsed = prefill_start.elapsed();

// 2. Decode —— 生成 token
let decode_start = Instant::now();
let mut generated: Vec<u32> = Vec::with_capacity(args.n_predict);
let mut logits = last_logits.expect("empty prompt");
for _ in 0..args.n_predict {
    let next = sampler.sample(&mut logits);
    if next == tokenizer.eos { break; }
    generated.push(next);
    let piece = tokenizer.decode(&[next]);
    print!("{piece}");
    std::io::stdout().flush().ok();
    logits = runner.forward(next).to_vec();
}

為什麼分兩階段?

LLM 推論天然分兩個階段:

  • Prefill:把使用者的 prompt 餵進去,建立 KV cache。logits 只有最後一個 token 的有用——前面的丟掉。
  • Decode:每次 forward 一個 token、抽下一個。每個 logits 都會用到。

這兩個階段的特性差異很有意思:

  • Prefill 的 token 全都是已知的,理論上可以批次處理(用 GEMM 取代 GEMV)。
  • Decode 就只能乖乖 sequential(下一個 token 取決於上一個,沒得偷懶)。

不過我目前 prefill 也是 sequential(一個 token 一個 forward)的,這就是個明擺著的優化機會了——後面清單會再回來算這筆帳。

演算法核心:tok/s 的計算

eprintln!("[prefill] {} tok in {:.2}s ({:.1} tok/s)",
    prompt_ids.len(),
    prefill_elapsed.as_secs_f64(),
    prompt_ids.len() as f64 / prefill_elapsed.as_secs_f64().max(1e-9),
);

max(1e-9) 是用來防止 0 秒(極短 prompt)導致除以零。f64::max(self, other) 回傳兩者較大者,所以 0.0.max(1e-9) = 1e-9,分母就保證不會是 0 了。

這種小細節很容易忘記寫喔——prompt 只有一個 token 時,prefill 可能是 0.001 秒,算出來還有意義;但要是快到變成 0 秒(測試環境有時就是這麼誇張),分母歸零你就會收到一個漂亮的 NaN。

Rust 用法:streaming output 的 flush

print!("{piece}");
std::io::stdout().flush().ok();

print! 寫進 stdout buffer,但不會馬上顯示——一般 stdout 是 line-buffered,要等到 \n 才 flush。LLM 串流輸出又沒有 \n,所以非得手動 flush 不可,不然你會傻等半天什麼都看不到,還以為當機了。

flush().ok()Result<(), Error> 轉成 Option<()> 然後丟掉——白話講就是「這個 flush 失不失敗我才懶得管」。stdout 寫入失敗本來就極罕見(例如 pipe 被人砍掉),就算真的失敗我們也無能為力,silent ignore 反而是最合理的處理。

Rust 用法:anyhow 的錯誤處理

fn main() -> Result<()> {
    // ... 整個 main 都是 Result-friendly 的,用 ? 早期返回
    Ok(())
}

fn main() -> Result<()> 是 Rust 處理 CLI errors 最乾淨的寫法。任何 ? 失敗都會把錯誤往 main 外面丟,runtime 自動 print 出來再 exit 1,連 error handling 的 boilerplate 都省了。

anyhow::Result<T> 不過就是 Result<T, anyhow::Error> 的別名。anyhow::Error 可以從任何 std::error::Error 自動轉換——這就是為什麼我能把 std::io::Error、parser error、自定義的 bail! 全混在一起,通通用一個 ? 打發掉,實在是頗舒服。

Rust 用法:環境變數和 stderr

eprintln!("[loaded] n_layer={} ...", config.n_layer, ...);

eprintln! 寫到 stderr,println! 寫到 stdout。我刻意把 metadata 印在 stderr、生成內容印在 stdout——這樣你用 ./tiny-llm-runner > out.txt 時,out.txt 裡就只有乾淨的生成內容,那些 metadata 還是乖乖留在 console 上,不會污染你的檔案。

這是 Unix 工具的老慣例了。在 Rust 裡用兩個不同的 macro 就自然支援,不必特別費心。

整個專案的端到端 forward pass 流程

flowchart TD A[CLI Args] --> B[File::open + Mmap] B --> C[parse_gguf] C --> D[LlamaConfig::from_gguf] C --> E[LlamaModel::load
建 TensorView] C --> F[Tokenizer::from_gguf] F --> G[encode prompt] D --> H[Runner::new
配 KV cache + scratch] E --> H G --> I[Prefill loop
forward each token] H --> I I --> J[Decode loop
sample → forward → repeat] J --> K[print tokens]

從 CLI 進來到 token 吐出去,整條流水線就這樣。仔細看會發現,圖裡每一個方框幾乎都對應到前面九篇文章其中一篇的主題——拼到這裡,整張地圖才算完整。

全局效能最佳化清單

到這裡,所有檔案都翻過一遍了。我想趁記憶猶新,把整個專案的最佳化機會匯總成一張 prioritized list。要強調的是:我不建議盲目地照著順序硬幹,還是得看你自己最想練哪一塊。

Tier 1(最大效能槓桿,10× 級的改進)

  1. SIMD 化 dot kernelsdequant.rs

    • Q4_0、Q8_0、Q6_K 的 inner loop 用 AVX2/AVX-512/NEON
    • 預期:matvec 加速 8-16×
    • 工作量:中—需要小心和 llama.cpp 對拍正確性
  2. Prefill batching(GEMV → GEMM)runner.rsops.rs

    • 把 prompt N 個 token 的 forward 拼成一個 batched 計算
    • 預期:prefill 加速 5-10×(decode 不變)
    • 工作量:大—涉及 attention 的 mask、KV cache 的 batched 寫入
  3. 支援 K-quants(Q4_K、Q5_K、Q4_K_M)dequant.rs

    • 不是加速 per se,而是讓現代 GGUF 都能跑
    • 工作量:中—實作複雜但有 ggml C 程式碼可參考

Tier 2(顯著改進,2-3× 級)

  1. Multi-row matvec fusionops.rs

    • 一次處理多個 row,減少 x 的 cache miss
    • 預期:matvec 加速 2-4×
  2. KV cache 量化runner.rs

    • 把 KV cache 從 F32 改成 Q8_0
    • 預期:記憶體用量 4×、速度可能略有提升(cache miss 變少)
    • 工作量:中
  3. f16 / bf16 全程(多個檔案)

    • 不要每次都 dequant 成 f32,scratch buffer 也用 f16
    • 預期:記憶體頻寬減半
    • 工作量:大—需要全程 f16 的 numerical stability 驗證

Tier 3(小但容易的改進)

  1. RoPE sin/cos 表預計算ops.rs

    • 不要每次 forward 都算 sin/cos
    • 預期:每層省幾十 μs,整體可能 1-2%
  2. Tokenizer 的 pair lookup tabletokenizer.rs

    • 避免 format! 字串拼接
    • 預期:encode 加速 5-10×(但 encode 不在 hot path)
  3. Top-P samplingsampler.rs

    • 提升 sampling 品質(不是速度,是輸出品質)
  4. Repetition penaltysampler.rs

    • 同上

Tier 4(架構級重構,可能不值得)

  1. GPU backend

    • wgpu 或 CUDA 支援
    • 工作量:極大—基本上是另一個專案
  2. FlashAttention

    • Fused attention with online softmax
    • 工作量:大—但 candle/ggml 有現成實作可學
  3. Speculative decoding

    • 用小模型加速大模型推論
    • 工作量:大—需要兩個模型協作

一個整體觀察:抽象與效能的權衡

寫完整個系列,我最強烈的感受是:tiny-llm-runner 的「易讀」,其實是拿「不可擴充」換來的。每個檔案都針對單一情境寫得直白到不行,但代價就是——要加新功能(新架構、新量化、新後端)往往得同時動好幾個檔案。

這跟 candle 的設計哲學根本是兩條路。candle 透過一層厚厚的抽象(TensorModuleVarBuilder)讓擴充變得很便宜,代價則是「想看懂一次 forward pass,得在好幾個 trait 之間跳來跳去」。

那到底哪個對?老實說,取決於你的目標。 你要做生產級框架,candle 那套抽象就是必要之惡;你要做一個「能跑、能讀、能 hack」的學習版,那 tiny-llm-runner 的扁平結構反而才是對的。沒有標準答案,只有適不適合。

把九個檔案的學習收穫匯總

回到我們最一開始問的那個問題:「寫一個會跑 LLM 的專案,到底需要哪些零件?」攤開來看,答案就是這張表:

檔案 你需要學會
config.rs metadata 解析 + 不變式檢查
dequant.rs block-wise 量化 + fused dot product
tensor.rs lifetime + view 抽象 + Copy struct
model.rs 樹形權重組織 + tied embeddings
ops.rs RMSNorm、softmax、RoPE、SwiGLU、rayon
runner.rs KV cache + GQA + 殘差連接 + scratch buffer
tokenizer.rs SentencePiece-BPE + byte fallback + UTF-8 重組
sampler.rs top-k partial sort + xorshift + CDF sampling
main.rs CLI design + prefill/decode split + tok/s metric

如果你真的把整個系列啃完了,那「LLM 推論引擎到底在幹嘛」這件事,你心裡應該已經有一張完整的 mental model 了。而且接下來能玩的還多著呢:自己加 SIMD、自己刻 GEMM、自己補 K-quants、自己接 GPU backend……每一條我都覺得夠格獨立寫成一篇工程旅程。

結語

回頭看,tiny-llm-runner 對我來說從來就不只是「一個專案」,而是「一次把 LLM 推論從頭到尾看透的旅程」。從一個 mmap 出來的 byte slice 開始,經過一連串型別、抽象、運算的層層組合,最後居然真的長成一個能跑、能跟 llama.cpp 對拍、又讀得懂的 Rust 程式——對我這種改不掉、就是愛把黑盒子拆開看裡面齒輪的工程師來說,這份滿足感,比把效能再快一倍都還過癮。 :-)

九篇走下來,最大的收穫其實不是哪個 kernel 怎麼寫,而是那種「啊,原來這裡面沒有魔法」的踏實感。LLM 聽起來玄,拆開來不過就是量化、矩陣、softmax、sampling 這些老朋友排排站而已。

謝謝你一路陪我把這九篇啃完。下次再見的時候,但願我已經把上面那張清單裡的優化,至少落地了幾個——不然這篇結語可就有點心虛了。我們,下個專案見。

系列文章: