tiny-llm-runner 深入解讀 (9):main.rs —— CLI、Prefill、Decode 與整體效能
本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
歷經八篇深入解讀,我們終於來到 tiny-llm-runner 的最後一塊——main.rs。前面拆了那麼多零件,總得有人把它們兜起來吧?這個檔案就 130 行,是把所有元件串起來的「指揮台」。
這也是整個系列的最後一篇了。除了講 main.rs,我想在最後對整個專案做一個全局的效能最佳化清單,把每個檔案散落的優化點串起來看——算是給這趟旅程一個交代。
概念一:CLI 參數設計
#[derive(Parser, Debug)]
#[command(version, about = "Pure-Rust llama-architecture inference over a GGUF model")]
struct Args {
#[arg(short, long)] model: PathBuf,
#[arg(short, long, default_value = "Once upon a time")] prompt: String,
#[arg(short, long, default_value_t = 64)] n_predict: usize,
#[arg(short, long, default_value_t = 0.8)] temperature: f32,
#[arg(long, default_value_t = 40)] top_k: usize,
#[arg(long, default_value_t = 42)] seed: u64,
#[arg(long)] no_bos: bool,
#[arg(long, default_value = "llama")] rope: String,
}
clap 的 derive macro 把 CLI parsing 變成宣告式:每個欄位加上 #[arg(...)] 就自動產生 --model、--prompt 之類的 flag。這比手寫 argument parser 短得多,而且 --help、type validation、default value 全都免費送你,實在是頗划算。
clap 的 zero-cost abstraction
clap 的 macro 在 compile-time 就生成好 parsing code,runtime 沒有任何 reflection。也就是說啟動時 Args::parse() 是純 native code,速度比 Python 的 argparse 快上好幾個量級。
對 LLM runner 來說 CLI 啟動時間其實不是什麼大問題,不過這個習慣很 Rust——把 metadata 處理推到 compile time,runtime 只留下純粹的計算。看多了你會發現整個語言都在貫徹這件事。
概念二:unsafe Mmap
let file = File::open(&args.model)?;
let mmap = unsafe { Mmap::map(&file)? };
整個專案唯一一個 unsafe,就這麼一行。為什麼 mmap 非得 unsafe 不可?
因為 mmap 違反了 Rust 的記憶體模型假設:Rust 假設一個 &[u8] 的內容在它的生命週期內不會被外部修改。但 mmap 對應的檔案如果被另一個 process 改掉(甚至 truncate),這個 &[u8] 就會看到變了樣的資料、最慘還會吃到 SIGBUS。
unsafe 說穿了就是程式設計師對編譯器的一句承諾:「我知道這違反一般規則,使用過程中檔案不會被外部動到,我自己負責」。對 LLM 模型檔來說這承諾其實很好守——模型檔通常就是 read-only 的,誰會去動它呢。
這也呼應了 Rust 的一個設計哲學:unsafe 不是禁忌,而是被精準框定的工具。整個 codebase 只有這一行 unsafe,但它被框得清清楚楚——出了問題,責任就在這一行,跑不掉。
演算法核心:Prefill / Decode 二段式
// 1. Prefill —— 處理 prompt
let prefill_start = Instant::now();
let mut last_logits: Option<Vec<f32>> = None;
for &tok in &prompt_ids {
let logits = runner.forward(tok);
last_logits = Some(logits.to_vec());
}
let prefill_elapsed = prefill_start.elapsed();
// 2. Decode —— 生成 token
let decode_start = Instant::now();
let mut generated: Vec<u32> = Vec::with_capacity(args.n_predict);
let mut logits = last_logits.expect("empty prompt");
for _ in 0..args.n_predict {
let next = sampler.sample(&mut logits);
if next == tokenizer.eos { break; }
generated.push(next);
let piece = tokenizer.decode(&[next]);
print!("{piece}");
std::io::stdout().flush().ok();
logits = runner.forward(next).to_vec();
}
為什麼分兩階段?
LLM 推論天然分兩個階段:
- Prefill:把使用者的 prompt 餵進去,建立 KV cache。logits 只有最後一個 token 的有用——前面的丟掉。
- Decode:每次 forward 一個 token、抽下一個。每個 logits 都會用到。
這兩個階段的特性差異很有意思:
- Prefill 的 token 全都是已知的,理論上可以批次處理(用 GEMM 取代 GEMV)。
- Decode 就只能乖乖 sequential(下一個 token 取決於上一個,沒得偷懶)。
不過我目前 prefill 也是 sequential(一個 token 一個 forward)的,這就是個明擺著的優化機會了——後面清單會再回來算這筆帳。
演算法核心:tok/s 的計算
eprintln!("[prefill] {} tok in {:.2}s ({:.1} tok/s)",
prompt_ids.len(),
prefill_elapsed.as_secs_f64(),
prompt_ids.len() as f64 / prefill_elapsed.as_secs_f64().max(1e-9),
);
max(1e-9) 是用來防止 0 秒(極短 prompt)導致除以零。f64::max(self, other) 回傳兩者較大者,所以 0.0.max(1e-9) = 1e-9,分母就保證不會是 0 了。
這種小細節很容易忘記寫喔——prompt 只有一個 token 時,prefill 可能是 0.001 秒,算出來還有意義;但要是快到變成 0 秒(測試環境有時就是這麼誇張),分母歸零你就會收到一個漂亮的 NaN。
Rust 用法:streaming output 的 flush
print!("{piece}");
std::io::stdout().flush().ok();
print! 寫進 stdout buffer,但不會馬上顯示——一般 stdout 是 line-buffered,要等到 \n 才 flush。LLM 串流輸出又沒有 \n,所以非得手動 flush 不可,不然你會傻等半天什麼都看不到,還以為當機了。
flush().ok() 把 Result<(), Error> 轉成 Option<()> 然後丟掉——白話講就是「這個 flush 失不失敗我才懶得管」。stdout 寫入失敗本來就極罕見(例如 pipe 被人砍掉),就算真的失敗我們也無能為力,silent ignore 反而是最合理的處理。
Rust 用法:anyhow 的錯誤處理
fn main() -> Result<()> {
// ... 整個 main 都是 Result-friendly 的,用 ? 早期返回
Ok(())
}
fn main() -> Result<()> 是 Rust 處理 CLI errors 最乾淨的寫法。任何 ? 失敗都會把錯誤往 main 外面丟,runtime 自動 print 出來再 exit 1,連 error handling 的 boilerplate 都省了。
anyhow::Result<T> 不過就是 Result<T, anyhow::Error> 的別名。anyhow::Error 可以從任何 std::error::Error 自動轉換——這就是為什麼我能把 std::io::Error、parser error、自定義的 bail! 全混在一起,通通用一個 ? 打發掉,實在是頗舒服。
Rust 用法:環境變數和 stderr
eprintln!("[loaded] n_layer={} ...", config.n_layer, ...);
eprintln! 寫到 stderr,println! 寫到 stdout。我刻意把 metadata 印在 stderr、生成內容印在 stdout——這樣你用 ./tiny-llm-runner > out.txt 時,out.txt 裡就只有乾淨的生成內容,那些 metadata 還是乖乖留在 console 上,不會污染你的檔案。
這是 Unix 工具的老慣例了。在 Rust 裡用兩個不同的 macro 就自然支援,不必特別費心。
整個專案的端到端 forward pass 流程
flowchart TD
A[CLI Args] --> B[File::open + Mmap]
B --> C[parse_gguf]
C --> D[LlamaConfig::from_gguf]
C --> E[LlamaModel::load
建 TensorView]
C --> F[Tokenizer::from_gguf]
F --> G[encode prompt]
D --> H[Runner::new
配 KV cache + scratch]
E --> H
G --> I[Prefill loop
forward each token]
H --> I
I --> J[Decode loop
sample → forward → repeat]
J --> K[print tokens]
從 CLI 進來到 token 吐出去,整條流水線就這樣。仔細看會發現,圖裡每一個方框幾乎都對應到前面九篇文章其中一篇的主題——拼到這裡,整張地圖才算完整。
全局效能最佳化清單
到這裡,所有檔案都翻過一遍了。我想趁記憶猶新,把整個專案的最佳化機會匯總成一張 prioritized list。要強調的是:我不建議盲目地照著順序硬幹,還是得看你自己最想練哪一塊。
Tier 1(最大效能槓桿,10× 級的改進)
-
SIMD 化 dot kernels(dequant.rs)
- Q4_0、Q8_0、Q6_K 的 inner loop 用 AVX2/AVX-512/NEON
- 預期:matvec 加速 8-16×
- 工作量:中—需要小心和
llama.cpp 對拍正確性
-
Prefill batching(GEMV → GEMM)(runner.rs、ops.rs)
- 把 prompt N 個 token 的 forward 拼成一個 batched 計算
- 預期:prefill 加速 5-10×(decode 不變)
- 工作量:大—涉及 attention 的 mask、KV cache 的 batched 寫入
-
支援 K-quants(Q4_K、Q5_K、Q4_K_M)(dequant.rs)
- 不是加速 per se,而是讓現代 GGUF 都能跑
- 工作量:中—實作複雜但有 ggml C 程式碼可參考
Tier 2(顯著改進,2-3× 級)
-
Multi-row matvec fusion(ops.rs)
- 一次處理多個 row,減少 x 的 cache miss
- 預期:matvec 加速 2-4×
-
KV cache 量化(runner.rs)
- 把 KV cache 從 F32 改成 Q8_0
- 預期:記憶體用量 4×、速度可能略有提升(cache miss 變少)
- 工作量:中
-
f16 / bf16 全程(多個檔案)
- 不要每次都 dequant 成 f32,scratch buffer 也用 f16
- 預期:記憶體頻寬減半
- 工作量:大—需要全程 f16 的 numerical stability 驗證
Tier 3(小但容易的改進)
-
RoPE sin/cos 表預計算(ops.rs)
- 不要每次 forward 都算 sin/cos
- 預期:每層省幾十 μs,整體可能 1-2%
-
Tokenizer 的 pair lookup table(tokenizer.rs)
- 避免
format! 字串拼接
- 預期:encode 加速 5-10×(但 encode 不在 hot path)
-
Top-P sampling(sampler.rs)
- 提升 sampling 品質(不是速度,是輸出品質)
-
Repetition penalty(sampler.rs)
Tier 4(架構級重構,可能不值得)
-
GPU backend
- 加
wgpu 或 CUDA 支援
- 工作量:極大—基本上是另一個專案
-
FlashAttention
- Fused attention with online softmax
- 工作量:大—但 candle/ggml 有現成實作可學
-
Speculative decoding
- 用小模型加速大模型推論
- 工作量:大—需要兩個模型協作
一個整體觀察:抽象與效能的權衡
寫完整個系列,我最強烈的感受是:tiny-llm-runner 的「易讀」,其實是拿「不可擴充」換來的。每個檔案都針對單一情境寫得直白到不行,但代價就是——要加新功能(新架構、新量化、新後端)往往得同時動好幾個檔案。
這跟 candle 的設計哲學根本是兩條路。candle 透過一層厚厚的抽象(Tensor、Module、VarBuilder)讓擴充變得很便宜,代價則是「想看懂一次 forward pass,得在好幾個 trait 之間跳來跳去」。
那到底哪個對?老實說,取決於你的目標。 你要做生產級框架,candle 那套抽象就是必要之惡;你要做一個「能跑、能讀、能 hack」的學習版,那 tiny-llm-runner 的扁平結構反而才是對的。沒有標準答案,只有適不適合。
把九個檔案的學習收穫匯總
回到我們最一開始問的那個問題:「寫一個會跑 LLM 的專案,到底需要哪些零件?」攤開來看,答案就是這張表:
| 檔案 |
你需要學會 |
config.rs |
metadata 解析 + 不變式檢查 |
dequant.rs |
block-wise 量化 + fused dot product |
tensor.rs |
lifetime + view 抽象 + Copy struct |
model.rs |
樹形權重組織 + tied embeddings |
ops.rs |
RMSNorm、softmax、RoPE、SwiGLU、rayon |
runner.rs |
KV cache + GQA + 殘差連接 + scratch buffer |
tokenizer.rs |
SentencePiece-BPE + byte fallback + UTF-8 重組 |
sampler.rs |
top-k partial sort + xorshift + CDF sampling |
main.rs |
CLI design + prefill/decode split + tok/s metric |
如果你真的把整個系列啃完了,那「LLM 推論引擎到底在幹嘛」這件事,你心裡應該已經有一張完整的 mental model 了。而且接下來能玩的還多著呢:自己加 SIMD、自己刻 GEMM、自己補 K-quants、自己接 GPU backend……每一條我都覺得夠格獨立寫成一篇工程旅程。
結語
回頭看,tiny-llm-runner 對我來說從來就不只是「一個專案」,而是「一次把 LLM 推論從頭到尾看透的旅程」。從一個 mmap 出來的 byte slice 開始,經過一連串型別、抽象、運算的層層組合,最後居然真的長成一個能跑、能跟 llama.cpp 對拍、又讀得懂的 Rust 程式——對我這種改不掉、就是愛把黑盒子拆開看裡面齒輪的工程師來說,這份滿足感,比把效能再快一倍都還過癮。 :-)
九篇走下來,最大的收穫其實不是哪個 kernel 怎麼寫,而是那種「啊,原來這裡面沒有魔法」的踏實感。LLM 聽起來玄,拆開來不過就是量化、矩陣、softmax、sampling 這些老朋友排排站而已。
謝謝你一路陪我把這九篇啃完。下次再見的時候,但願我已經把上面那張清單裡的優化,至少落地了幾個——不然這篇結語可就有點心虛了。我們,下個專案見。
系列文章:
tiny-llm-runner 深入解讀 (8):sampler.rs —— Greedy、Top-K 與 xorshift PRNG
本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇看完了 tokenizer,這一篇要看的是模型吐出結果之後最後一道工序:怎麼把一堆數字變成「下一個 token」。sampler.rs 不到 100 行,看起來很不起眼,但它實在是決定了 LLM 的「個性」——同一個模型換不同 sampler,生出來的內容可以差很多喔。
概念一:什麼是 logits?怎麼從 logits 拿 token?
forward pass 的最終輸出是 [vocab_size] 個浮點數,叫做 logits。每個位置代表「下一個 token 是 i 的 unnormalized log-probability」,講白話就是「模型覺得這個 token 有多順眼」的分數。
最直接的選法就是 greedy(argmax):
fn argmax(x: &[f32]) -> usize {
let mut best = 0usize;
let mut best_v = f32::NEG_INFINITY;
for (i, &v) in x.iter().enumerate() {
if v > best_v { best_v = v; best = i; }
}
best
}
就是無腦挑分數最高的那個。Greedy 的問題是輸出太 deterministic,模型每次都選同一個最高分 token,內容會變得很單調,還很容易卡在迴圈裡跳不出來。
概念二:Temperature —— 給機率分佈加溫
Temperature 是這樣運作的:
let inv_t = 1.0 / self.temperature;
for v in logits.iter_mut() {
*v *= inv_t;
}
把所有 logits 除以 temperature。然後過 softmax 拿到機率分佈:
$$P(i) = \frac{e^{\ell_i / T}}{\sum_j e^{\ell_j / T}}$$
T 的影響:
- T → 0:分佈會變成尖峰(最大值佔 1,其他 0)→ 等同於 greedy。
- T = 1:原始分佈。
- T → ∞:分佈會變平(每個 token 機率接近 1/vocab)→ 完全隨機。
實務上 T = 0.7 ~ 1.0 大概就是 LLM 寫作的甜蜜點吧——夠隨機讓內容有點變化,又不至於整個脫線講起夢話。
概念三:Top-K —— 只從前 K 個候選裡抽
純 temperature 抽樣有個惱人的地方:它不會幫你排除「明顯不該選」的 token。比方說分佈裡有個位置是「機率 0.001、但其實是個亂碼字元」,溫度一高,它還是有那 0.001 的機會被抽到——然後一個 token 就毀掉一整段文字,前功盡棄。
Top-K 的解法很直接:只保留機率最高的 K 個,其他全部歸零:
if self.top_k > 0 && self.top_k < logits.len() {
let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
indexed.select_nth_unstable_by(self.top_k, |a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
let cutoff = indexed[..self.top_k]
.iter()
.map(|(_, v)| *v)
.fold(f32::INFINITY, f32::min);
for v in logits.iter_mut() {
if *v < cutoff {
*v = f32::NEG_INFINITY;
}
}
}
select_nth_unstable_by:partial sort 的妙用
我這裡沒用 sort,而是用 select_nth_unstable_by。為什麼呢?
sort 是 $O(n \log n)$。但說穿了,我們只關心「前 K 個是哪些」、根本不在乎這 K 個之間誰排前誰排後。select_nth_unstable_by 就是 partial sort:把第 K 個元素放到對的位置,前 K 個都在它前面(內部順序不保證),後面的都在它後面。複雜度平均是 $O(n)$、worst case 才 $O(n \log n)$,比乖乖整個排序快得多。
拿 vocab_size = 32k、top_k = 40 來算,sort 要做 32k × log(32k) ≈ 480k 次比較;partial sort 平均只要 ~32k 次。差了 15 倍耶,這種白吃的午餐不拿白不拿。
找 cutoff 的小技巧
let cutoff = indexed[..self.top_k]
.iter()
.map(|(_, v)| *v)
.fold(f32::INFINITY, f32::min);
partial sort 之後 indexed[..K] 就是「最大的 K 個」,只是內部順序未知。我用 fold(f32::INFINITY, f32::min) 撈出它們之中最小的那個——這就是 cutoff。任何 logit 小於 cutoff 的,就請它出局。
順帶一提,我這裡用的是 f32::min(function)而不是 method 版的 min,當作 fold 的參數。f32::min 碰到 NaN 的處理方式是「忽略它」,比較不會出事。
NaN 處理
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
f32::partial_cmp 因為 NaN 沒辦法比大小,回傳的是 Option<Ordering>。我用 unwrap_or(Ordering::Equal) 把 NaN 當成「相等」來處理——不算漂亮啦,但至少不會 crash。理論上 logits 是不該冒出 NaN 的,不過量化模型碰到某些極端輸入還真有可能生出 NaN 來,所以這個 defensive 的小動作我覺得很值得。
演算法核心:累積分佈抽樣
Softmax + uniform random + cumulative sum:
softmax(logits);
let r = self.next_f32(); // [0, 1)
let mut acc = 0.0f32;
for (i, &p) in logits.iter().enumerate() {
acc += p;
if acc >= r { return i as u32; }
}
(logits.len() - 1) as u32
這是經典的 inverse CDF sampling:
- 把機率分佈累加成 CDF:
[p0, p0+p1, p0+p1+p2, ..., 1.0]。
- 隨機抽一個
r ∈ [0, 1)。
- 找第一個 CDF 大於
r 的位置就是抽中的 token。
這裡有個容易被忽略的小細節:理論上最後一個 cumulative sum 應該剛好是 1.0,但 floating point error 可能讓它差那麼一點點、略小於 1.0。萬一 r 不偏不倚就落在那個誤差縫隙裡,迴圈跑完卻沒人 return,那就尷尬了。所以最後補了一行 (logits.len() - 1) as u32 兜底,當作保險。
演算法核心:xorshift64 PRNG
fn next_u64(&mut self) -> u64 {
let mut s = self.rng_state;
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
self.rng_state = s;
s
}
這是 xorshift64——一個簡單到有點不可思議的偽隨機數產生器。三條 shift + xor 指令,就足以通過大部分隨機性測試了。週期是 $2^{64} - 1$,對 LLM 推論來說綽綽有餘。
為什麼不用標準函式庫的 rand?
rand crate 確實提供了 high-quality PRNG(mersenne twister、ChaCha20…),可是它畢竟是個外部依賴。對 tiny-llm-runner 來說,我是刻意把依賴壓到最小的,xorshift64 自己手寫 5 行就搞定,何必為了亂數多拖一個 crate 進來?
而且對 LLM sampling 而言,PRNG 的品質根本不是重點——你只需要每次抽樣有合理的 entropy 就好,又不是要拿來做密碼學。Xorshift 真的夠用了。
避免 all-zero state
let s = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
xorshift 有個很有名的小陷阱:state 一旦是 0 就會永遠卡在 0(0 ^ 0 還是 0 嘛)。所以我把 seed = 0 偷偷換成 0x9E3779B97F4A7C15——這是黃金比例的 fixed point,常被拿來當 hash seed。
這個 magic number 的來頭:黃金比例 $\phi = (\sqrt{5} + 1) / 2$,它的二進位部分是個無限不循環序列,被認為「隨機性最強」。hash crate 的 FxHasher 也是用這個數字,算是業界公認的好朋友了。
next_f32 的位元操作
fn next_f32(&mut self) -> f32 {
let bits = (self.next_u64() >> 40) as u32;
bits as f32 / (1u32 << 24) as f32
}
只取 64 bits 裡的 24 bits(» 40 之後留下 24 bits),轉成 f32 再除以 $2^{24}$。為什麼偏偏是 24?因為 f32 的 mantissa 就 23 bits(加上那個 implicit leading 1 才湊到 24 bits 精度),多塞 bits 進去也是白搭——反正精度後面就被截掉了,何必呢。
Rust 用法:mutable self 的 sampler
pub fn sample(&mut self, logits: &mut [f32]) -> u32 { ... }
注意那個 &mut self——sampler 內部藏了 mutable state(rng_state),每次 sample 都會更新它。同時 logits: &mut [f32] 也是 mutable,因為我們直接就地改 logits(套 temperature、top-k、softmax 一條龍)。
&mut self 在 Rust 裡其實是個蠻重的承諾——它等於宣告「這個 call 期間整個物件是我獨佔的」。這也順帶解釋了為什麼 PRNG 天生就是 thread-unsafe:你沒辦法多執行緒同時呼叫 sample。
想要 thread-safe 的話,就得套個 Arc<Mutex<Sampler>> 之類的——不過對單一 forward pass 來說實在沒必要,sampler 本來就是乖乖 sequential 跑的嘛。
Rust 用法:early return 的 idiom
pub fn sample(&mut self, logits: &mut [f32]) -> u32 {
if self.temperature <= 0.0 {
return argmax(logits) as u32;
}
// 完整 sampling 邏輯
}
把「greedy 短路」直接擺在最前面,就不用後面寫一堆 if-else 或巢狀結構,看起來清爽多了。Rust 對 early return 一向友善,這個 idiom 在處理 Result/Option 的時候特別常見(就是那個 ? 運算子)。
效能最佳化空間
1. softmax + sample 的 fusion
我現在的作法是「先 softmax → 再 sample」。但仔細想想,sample 其實只需要「累積分佈累到第一個超過 r 的位置」,根本不必先把整個分佈算完。如果你很早就抽中了,那後面的 softmax 不就白算了嗎?
不過呢,這個優化能省的有限——softmax 是 $O(\text{vocab})$,sample 平均也是 $O(\text{vocab})$,兩個加起來其實不會比 fuse 之後快多少。而且硬要 fusion 會犧牲程式碼的清晰度,我覺得不划算。
2. top-p 而不是 top-k
Top-K 其實有個罩門:每個位置的分佈陡峭程度都不一樣。有些位置可能前 5 個就吃掉了 99% 機率(超陡),有些卻要前 100 個才湊到 99%(很平緩)。固定一個 K 值,碰到前者太鬆、碰到後者又太緊,怎麼喬都不對。
Top-P(nucleus sampling) 的解法就聰明多了:保留累積機率剛好達到 P 的那一小撮 token 就好。實作上是先 softmax、排序、再累加到 P 為止。比 top-k 多了一次 sort,但效果穩定得多。
我目前還沒做 top-p,這算是個值得補上的功能吧。
3. Repetition penalty
LLM 很容易陷入「鬼打牆」的迴圈,一直重複自己(像那種「我是、我是、我是…」講不停的)。Repetition penalty 的招數就是把最近 N 個 token 的 logits 乘上一個懲罰因子(< 1),壓低它們再次被選中的機會。llama.cpp 預設用 1.1。
實作其實很簡單:
for &id in last_n_tokens {
if logits[id] > 0.0 { logits[id] /= penalty; }
else { logits[id] *= penalty; }
}
4. Mirostat
Mirostat 是個比較進階的 sampling 算法,會動態調整 cutoff,讓「驚奇度」(perplexity)維持在一個固定值附近。實作起來頗複雜,但對長文本生成的品質提升很有感。llama.cpp 也支援。
5. SIMD softmax
softmax 裡的 exp 是 element-wise 運算,理論上可以 SIMD 化。但對 vocab=32k 的單次 sample 而言,softmax 大概也才幾百微秒——這點時間早就被 forward pass 那幾十毫秒給吃乾抹淨了,這個優化的邊際效益實在低到可以忽略。
6. Speculative sampling
最有潛力的優化我想應該是 speculative decoding(在 sampler 這一層實作):用小模型一口氣猜 K 個 token,主模型同時 verify。等於「一次 forward 就算出 K 個 token」,聽起來真是頗誘人。只是這需要兩個模型搭配演出,已經超出 sampler 自己能管的範圍了。
一個哲學問題:抽樣的 reproducibility
我把這個 sampler 設計成 deterministic(給定同一個 seed,輸出就能重現)。為什麼要這樣搞?
因為要驗證 LLM runner 對不對,得拿來「對拍」。如果我用真隨機,每次跑出來都不一樣,那要怎麼跟 llama.cpp 的輸出比對?給定 seed 的 deterministic sampling 讓我可以:
- 設 temperature = 0:對拍 greedy 輸出(純 deterministic)。
- 設 temperature = 0.8、seed = 42:對拍 sampler 的隨機性(兩邊用同樣的 seed 應該產生同樣的 token 序列)。
這個道理其實對所有需要復現實驗的 ML 程式碼都成立——先確保 reproducibility,才有辦法好好 debug。少了這個前提,你連「到底是哪裡跑掉了」都搞不清楚。
總結:sampler.rs 的角色
- 概念上:把 logits 分佈轉成單一 token 抽樣。
- 演算法上:argmax / temperature / top-k / xorshift / inverse CDF。
- 設計上:mutable state 集中在
Sampler struct、deterministic by design、依賴最少。
短短不到 100 行的檔案,背後居然牽扯到機率、數值穩定、亂數品質、reproducibility 這麼多眉角,仔細想想還真是有點妙。下一篇就是這個系列的壓軸了——main.rs,把前面這一路拆解過的東西通通串起來,我們最後一篇見囉。
系列文章:
tiny-llm-runner 深入解讀 (7):tokenizer.rs —— SentencePiece-BPE 與 Byte Fallback
本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇看完了 forward pass 的編排,這一篇要看看「文字」是怎麼變成 LLM 看得懂的 token id 的:tokenizer.rs。
很多人以為 tokenizer 就是「把字串切成 word」,我以前也是這樣想的。不過現代 LLM 的 tokenizer 實在是比這精緻得多——它是個 learned algorithm,由訓練資料決定要怎麼切。實作起來其實也才 200 多行而已,但每一段都很有得講喔。
概念一:什麼是 BPE?為什麼不直接用 word?
最早的 NLP 用的是 word tokenizer:["I", "love", "Rust"]。看起來很直覺對吧?只是它有兩個麻煩:
- OOV(out of vocabulary):訓練時沒看過的 word(拼寫錯字、新詞)會變成
<unk>。
- 詞彙爆炸:英文單字大概 60 萬個,加上人名、專業詞彙,vocabulary 動輒幾百萬。
BPE(Byte Pair Encoding) 解這個問題的招數很妙——「從字元開始合併」:
- 初始化:每個字元是一個 token。
- 統計訓練資料裡哪一對相鄰 token 出現最多。
- 合併最常見的那一對成一個新 token。
- 重複直到達到目標 vocab size(典型 32k 或 128k)。
合併出來的 token 表會包含「字元」、「片段」、「常見字」、「常見片語」混合在一起,例如:
'a', 'b', ..., 'z',
'th', 'in', 'er', 're',
'the', 'and', 'ing', 'tion',
'▁the', '▁of', '▁to',
...
▁ 是 SentencePiece 用來標記 word boundary 的特殊字元(U+2581)。
概念二:SentencePiece 的 word boundary 處理
let prepared = format!("\u{2581}{}", text.replace(' ', "\u{2581}"));
SentencePiece 把空白統一替換成 ▁,並且在輸入最前面也加一個 ▁。為什麼要這樣搞?
因為這樣一來**「空格」就不再是特殊字元,而是 token 的一部分了**。例如 "hello world" 會變成 "▁hello▁world",tokenize 後可能就是 ["▁hello", "▁world"]。decode 時把 ▁ 換回空格,原始字串就還原了。
我覺得這個設計頗聰明的地方在於:tokenization 變成 reversible——不會像 word tokenizer 那樣「I love Rust → [I, love, Rust] → 咦空格到底在哪?」最後拼不回來。
演算法核心:encode 的 greedy merge loop
loop {
let mut best_score = f32::NEG_INFINITY;
let mut best_idx: Option<usize> = None;
let mut best_id: u32 = 0;
for i in 0..ids.len().saturating_sub(1) {
let merged = format!("{}{}",
&self.tokens[ids[i] as usize],
&self.tokens[ids[i + 1] as usize]
);
if let Some(&id) = self.token_to_id.get(&merged) {
let s = self.scores[id as usize];
if s > best_score {
best_score = s;
best_idx = Some(i);
best_id = id;
}
}
}
match best_idx {
Some(i) => {
ids[i] = best_id;
ids.remove(i + 1);
}
None => break,
}
}
這就是 SentencePiece 的「最高分鄰接合併」編碼演算法:
- 把輸入 split 成單字元 token 序列。
- 在所有相鄰 pair 中,找一個合併後是 vocab 裡的 token、且分數最高的。
- 合併它(用合併後的 id 取代第 i 個,刪掉第 i+1 個)。
- 重複直到沒有合併可以做。
這裡的分數是 GGUF metadata 裡 tokenizer.ggml.scores 提供的值——它是 SentencePiece 訓練時學到的「這個 token 到底多好用」的指標,分數越高就越偏好。
複雜度分析
每輪 loop 是 $O(n)$(掃一次相鄰 pair),總共最多 n 輪(每輪至少縮掉一個 token),所以整個 encode 是 $O(n^2)$。對 1000 字元的 prompt 來說是 100 萬次 hashmap lookup——聽起來嚇人,不過還好 hashmap 平均是 $O(1)$,沒事啦。
要更快的話可以用 priority queue(heap):每次取最高分的 pair $O(\log n)$,總共 $O(n \log n)$。只是實作起來複雜,而且 LLM 的 prompt 通常也不會超過幾千字元,$O(n^2)$ 其實完全夠用了。
注意喔,這個 format! 在每個 pair 都會 allocate 一個新 String,對長 prompt 來說就是上萬次 allocation。真要優化的話,可以:
- 預配一個 reusable buffer:
let mut merged = String::with_capacity(64); ...; merged.clear(); merged.push_str(...);
- 用
(u32, u32) → u32 的 lookup table(pair table),直接避開字串拼接。
不過對 LLM runner 來說 tokenization 通常不是瓶頸啦——一次 encode 才幾十毫秒,相對於 forward pass 的好幾秒,根本可以忽略。
演算法核心:byte fallback —— 處理 vocab 外的字元
for ch in prepared.chars() {
let s: String = ch.to_string();
if let Some(&id) = self.token_to_id.get(&s) {
ids.push(id);
} else if let Some(bf) = &self.byte_fallback {
for &byte in s.as_bytes() {
let id = bf[byte as usize];
ids.push(id);
}
}
}
如果某個字元不在 vocab 裡(像是中文、Emoji),就把它的 UTF-8 bytes 一個一個用 byte fallback token 編碼。Llama 的 vocab 裡準備了 256 個專門的 byte token,名字長這樣:
"<0x00>", "<0x01>", ..., "<0xFF>"
每個 byte token 對應一個 byte 值。例如 "我" 的 UTF-8 是 [0xE6, 0x88, 0x91],會被編碼成三個 byte token:<0xE6>, <0x88>, <0x91>。
偵測 byte fallback 是否完整
let mut byte_fallback = [u32::MAX; 256];
let mut have_all = true;
for b in 0..=255u32 {
let key = format!("<0x{:02X}>", b);
if let Some(&id) = token_to_id.get(&key) {
byte_fallback[b as usize] = id;
} else {
have_all = false;
}
}
let byte_fallback = if have_all { Some(byte_fallback) } else { None };
只有當 vocab 裡 256 個 byte token 全都在的時候才啟用 byte fallback。如果只缺了幾個,就整個 disable 掉——因為 partial 的支援會讓 encode 變得不可預測,那種半調子狀態最麻煩了。
Option<[u32; 256]> 我覺得是個頗漂亮的 Rust 表達:要嘛全有、要嘛全無,型別系統直接幫你把這個 invariant 釘死。
演算法核心:decode 的 byte 拼合
decode 比 encode 單純多了,不過有個微妙的小細節要注意——byte fallback token 必須在 byte-level 拼回 UTF-8 codepoint:
pub fn decode(&self, ids: &[u32]) -> String {
let mut bytes: Vec<u8> = Vec::new();
for &id in ids {
let s = match self.tokens.get(id as usize) {
Some(s) => s,
None => continue,
};
if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
if let Ok(b) = u8::from_str_radix(&s[3..5], 16) {
bytes.push(b);
continue;
}
}
bytes.extend_from_slice(s.replace('\u{2581}', " ").as_bytes());
}
String::from_utf8_lossy(&bytes).into_owned()
}
關鍵是:先收集到 Vec<u8>,最後一次性轉 UTF-8。
為什麼不能逐 token decode 呢?因為一個 UTF-8 codepoint 可能會跨好幾個 byte token。例如 "我" 是三個 byte token,你要是逐個 decode:
<0xE6> → 0xE6 不是合法 UTF-8 → 變成 ?
<0x88> → 0x88 不是合法 UTF-8 → 變成 ?
<0x91> → 0x91 不是合法 UTF-8 → 變成 ?
結果就是 ??? 而不是 我,慘。先把 byte 全收集起來、最後再一次 from_utf8_lossy,這三個 byte 才會被正確拼成一個中文字。
decode_piece 的 caveat
我也提供了一個 decode_piece(id) 做單 token decode,但它有個 limitation——遇到 byte fallback 時只能 lossy emit。所以LLM 串流輸出的時候一定要用 decode(&[id]) 或自己 buffer 起來,千萬別直接 decode_piece,不然就會看到一堆亂碼。
我的 main.rs 用的是 tokenizer.decode(&[next]),這樣才有正確處理——不過老實說這真的是個很容易忘記的雷。理想的 API 應該是個 Decoder struct,內部自己維護 byte buffer,每次餵 token 進來、適時 flush 出完整的 UTF-8 codepoint,這樣最省心。
Rust 用法:HashMap 的 entry pattern
這個 tokenizer 雖然沒用到,但這個 Rust 慣用法還是值得一提。我目前的初始化長這樣:
let mut token_to_id = HashMap::with_capacity(tokens.len());
for (i, t) in tokens.iter().enumerate() {
token_to_id.insert(t.clone(), i as u32);
}
如果 vocab 有重複 token(理論上不應該),後面的會覆蓋前面的。如果想驗證沒有重複:
for (i, t) in tokens.iter().enumerate() {
use std::collections::hash_map::Entry;
match token_to_id.entry(t.clone()) {
Entry::Vacant(v) => { v.insert(i as u32); }
Entry::Occupied(_) => bail!("duplicate token at {i}"),
}
}
entry API 是 Rust 處理 hash map 的標準慣用法——白話說就是「我不知道 key 在不在,但想根據它在不在來做不同的事」,一次 lookup 搞定,不必查兩遍。
Rust 用法:Option 的 chain
let bos = get_special(g, "tokenizer.ggml.bos_token_id").unwrap_or(1);
get_special 回傳 Option<u32>,如果 metadata 沒有這個 key 就是 None。unwrap_or(1) 給了一個 default value——Llama 的 BOS token id 通常都是 1,所以這算是個合理的 fallback。
這就是 Rust 對 null 的優雅處理——Option 逼著你當場決定「沒有的時候怎麼辦」,而不是放任 NullPointerException 在 runtime 才給你來個措手不及。
效能最佳化空間
1. encode 的 priority queue 優化
把 O(n^2) 改成 O(n log n),對長 prompt 有顯著加速。只是實作 priority queue 還要處理 invalidation(合併後相鄰 pair 都變了)就有點麻煩了。想偷懶的話,tokenizers crate(HuggingFace 的 Rust tokenizer)已經有現成的優化版可以抄。
2. 預先建立 pair lookup table
每次 format!("{}{}", a, b) 都得做字串拼接加上 hash lookup。如果改成預先建好一張 HashMap<(u32, u32), u32>:
let mut pair_map: HashMap<(u32, u32), u32> = HashMap::new();
for (id, token) in tokens.iter().enumerate() {
// 嘗試把 token 拆成兩個現有 token 的拼接
for split in 1..token.len() {
if let (Some(&a), Some(&b)) = (
token_to_id.get(&token[..split]),
token_to_id.get(&token[split..]),
) {
pair_map.insert((a, b), id as u32);
}
}
}
這樣 encode 時就只是 (u32, u32) → u32 的查表,完全不必字串拼接。代價是建表本身是 $O(\sum_t |t|)$,得在啟動時先花一點時間,算是拿啟動換執行速度吧。
3. SIMD 字串搜尋(次要)
replace('▁', ' ') 用的是逐字元 scan。長字串理論上可以用 SIMD 加速,不過 tokenizer 處理的字串通常都很短,這個真的沒必要,純粹列出來給大家參考一下。
4. 避免 ids.remove(i+1) 的 O(n)
ids[i] = best_id;
ids.remove(i + 1); // 每次都是 O(n)
Vec::remove 是 $O(n)$(後面的 element 全部要往前 shift)。理論上更好的資料結構是 linked list(doubly linked),合併只要 $O(1)$。不過 Rust 的 LinkedList 對 cache 很不友善,跑起來反而更慢——這就是經典的「理論最優和實際最優不一樣」的案例呢。
實務上 prompt 長度頂多幾百到幾千 tokens,$O(n^2)$ 配上 cache friendly 的 Vec,反而比 LinkedList 還快。所以別被 big-O 騙了。
5. 批次處理 byte fallback
對長的 unicode 文字,目前每個 char 都會做一次 hashmap lookup。理論上可以先把連續的 byte fallback 序列 group 起來、批次處理。不過收益其實有限,畢竟 hashmap lookup 本身就快得很。
一個我曾經踩過的實際雷
我第一次寫這個 tokenizer 的時候,有個 case 怎麼編碼都不對:英文 prompt 後面接中文。搞了半天才發現,問題是我忘了讓 byte fallback 的 token繼續參與後面的合併迴圈。當時我的邏輯把 byte fallback 當成「終態」——一旦變成 byte token 就不再動它了,但 SentencePiece 的訓練資料裡,其實可能有把連續 byte 合併成更高層 token 的規則啊。
還好現在的實作放對位置了——byte fallback 只是「初始化」,後面的合併迴圈會把它們跟其他 token 一視同仁地一起合併。不過這個雷讓我學到一件事:寫 tokenizer 一定要拿至少幾十個 prompt 去對拍 llama.cpp,光靠一兩個 happy path 測試,遲早出包。
總結:tokenizer.rs 的角色
- 概念上:把字串和 token id list 互相轉換。
- 演算法上:SentencePiece 的 greedy merge encode 加上給 OOV 用的 byte fallback。
- 微妙細節:byte fallback 的存在性檢查、還有 decode 時的 UTF-8 重組。
200 多行的程式,看似簡單,魔鬼卻全藏在 byte fallback 跟 UTF-8 拼合那些細節裡——這大概就是 tokenizer 最有意思的地方吧。下一篇來看 sampler.rs,聊聊拿到 logits 之後到底要怎麼選下一個 token。
系列文章:
tiny-llm-runner 深入解讀 (6):runner.rs —— Forward Pass 與 KV Cache 的編排
本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇我們把所有 Transformer 原語都拆過一遍了,這一篇終於要把它們編排成一次完整的 forward pass,主角就是核心檔案 runner.rs。
老實說,整個專案我最推薦讀的就是這個檔案——大約 170 行的 Rust,把 Llama 的整個推論流程攤平在你眼前,每一步在做什麼都看得清清楚楚,實在是很過癮。
概念一:autoregressive 推論的本質
LLM 推論不是「一次處理整段文字」,而是「一次處理一個 token」:
prompt: "The capital of France"
→ token ids [464, 3139, 286, 4881]
prefill (處理 prompt):
forward(464) → logits (我們不用,但要建 KV cache)
forward(3139) → logits
forward(286) → logits
forward(4881) → logits ← 用這個 logits 抽下一個 token
decode (生成):
sampler(logits) → " is"
forward(" is") → logits
sampler(logits) → " Paris"
forward(" Paris") → logits
... 一直到 EOS
每次 forward(token) 就是一個 step,step 之間靠 KV cache 一路把上下文累積起來。一個字一個字慢慢吐,model 其實沒有你想像中那麼「整段一起讀」。
概念二:KV Cache —— 為什麼 attention 要 cache 過去?
Attention 的數學是這樣的:
$$\text{attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$
在 autoregressive 推論時,第 t 個 token 的 Q 只需要跟前 t 個 token 的 K/V 做運算(因為未來的 token 根本還沒出現嘛)。如果每次 forward 都把所有過去的 K/V 重新算一遍,prefill 階段就是 $O(n^2)$,decode 階段每次又是 $O(n)$ 的重做——這也太浪費了吧。
KV cache 的想法其實很單純:某個 position 的 K/V 算過一次就存起來,下次直接拿來用。這樣每次 forward 只要算當前這個 token 的 Q/K/V 就好,K 和 V 寫進 cache、Q 拿去跟整個 cache 點積。
概念三:Context 和 KV Cache 的關係
很多人一聽到「context length 2048」,直覺就想成「模型有個地方存著最近 2048 個 token」。其實不是這樣喔——模型存的是這 2048 個 token 被各層 attention 算出來的 K 和 V,token 本身早就丟掉了。所以 context 的容量上限,講白了就是 KV cache 的容量上限。
關係可以這樣理解:
context window = 「我能看見多遠的過去」的能力上限
KV cache = 這個能力實際的儲存形式
n_ctx = KV cache 在 position 維度上的大小
當你看到 n_ctx = 2048,背後真正配出來的記憶體就是:
kcache: [n_layer][n_ctx × kv_dim] f32
vcache: [n_layer][n_ctx × kv_dim] f32
對 TinyLlama (n_layer=22, kv_dim=256) 來說,這是 22 × 2 × 2048 × 256 × 4 = 92 MB。對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說,是 32 × 2 × 2048 × 4096 × 4 = 2 GB——KV cache 居然比模型本身的量化權重還大,是不是有點出乎意料?
還有一點要記住:token 一旦進了 cache,就再也撈不回原貌了。Context window 滿掉之後,要嘛截斷舊 token、要嘛丟掉新 token,沒有第三條路——因為原始 token id 早就在被 attention 投影成 K/V 的那一刻丟掉了。
概念四:為什麼不存 token 本身?為什麼不存 embedding?為什麼是 K/V?
我覺得把這個問題想清楚,就等於把 KV cache 的設計從頭到尾理解了一遍。我們從「最樸素」的方案開始,一步步往下推:
方案 A:只存 token id
「我把過去的 token id 都存起來,下次推論時整段重跑就好。」
問題在於:這樣每次 forward 都得從 token id 開始,重算 embedding、重跑 22 層 transformer。對第 t 個 token 來說是 $O(t \times \text{layers} \times \text{matmul})$ 的工作,整個 prompt 的 prefill 是 $O(n^2)$,decode 累積下來也是 $O(n^2)$。說穿了,這根本沒在 cache 嘛。
方案 B:存 embedding(每個 token 對應一個 n_embd 向量)
「那我存 token embedding,這樣可以跳過 embedding lookup。」
問題是:embedding 只是 forward pass 的入口而已。從 embedding 走到「attention 真正吃的東西」,中間還隔著一整層的 RMSNorm + Q/K/V 投影。每層的 attention 看到的,是「自己這層的輸入經過 W_k 和 W_v 投影後的 K/V」,不是 token embedding。所以如果只存 token embedding:
- 每次 forward 都要重新跑 22 層的 RMSNorm、重新算 K/V 投影。
- 但這些工作的結果和 token 位置 t 無關——每次重算都會得到同一個 K_t 和 V_t。
- 純粹的浪費。
更麻煩的是:第 1 層的輸入是 token embedding,可是第 2 層的輸入是「第 1 層的輸出」。所以你沒辦法拿「token embedding」這一個概念去代表「每一層 attention 需要的東西」。每層之間都隔著 attention + FFN 的非線性轉換,「embedding」這個詞其實只在最開頭那一層才講得通。
方案 C:存每層的 hidden state(每層都有自己的 [n_layer][n_ctx][n_embd])
「那我存每層的輸出 hidden state?」
問題是:你存了 hidden state,但 attention 真正要的是 K = W_k(hidden_state) 和 V = W_v(hidden_state)。每次 forward 還是得重算 K/V 投影(一個 [n_embd, kv_dim] 的 matvec)。等於白存,因為投影那一步根本沒省到。
而且這個方案吃的記憶體還比 K/V cache 大:hidden state 是 n_embd 寬,K/V 在 GQA 下只有 kv_dim = head_dim × n_head_kv 寬(TinyLlama 是 256,比 n_embd=2048 整整小了 8 倍)。划不來。
方案 D:直接存 K 和 V(最終答案)
繞了一圈,答案其實一直都在公式裡:K 和 V 才是 attention 真正消費的東西。$\text{attn}(Q, K, V)$ 裡頭沒有 hidden state、沒有 embedding、沒有 token id——就只有 Q、K、V。把過去的 K 和 V 存起來,attention 就齊了。
而且 K/V 還有兩個我覺得頗迷人的性質:
- 它們和 Q 是解耦的:K_t 和 V_t 一旦算出來就和「未來會用什麼 Q 來查它」無關。所以可以放心 cache。
- 它們不需要被未來修改:因為 LLM 是 causal——位置 t 的 K/V 只被 position ≥ t 的 Q 查詢,但這些查詢不會回頭改 K_t/V_t。一旦寫進 cache 就是 immutable 的。
這就是為什麼我會說 KV cache 是 transformer 推論的「最佳化終點」——它剛剛好存下 attention 需要的最少資訊,再少一點就破壞語義,再多一點就是浪費。剛剛好,一點都不多。
從「資訊保留鏈」看這件事
換個角度想想:模型做 forward pass,其實是一條把「token id」逐步轉成「下一個 token 機率分佈」的資訊流。中途會冒出一大堆中間表示:
token id → embedding → layer 1 hidden → layer 1 K/V → ...
→ layer 2 hidden → layer 2 K/V → ...
...
→ final hidden → logits
每一個中間表示都裝著這個 token 的某種資訊。但對 attention 來說,這條鏈上唯一一個「跨 token 互動」的點,就是它把 K/V 拿來算內積那一步。其他每一步(RMSNorm、Q/K/V 投影、FFN…)都是 pure-position 的——只管當前這個 token 自己的資料,跟別人無關。
所以:
- 那些「pure-position」的步驟不需要 cache,因為每個 token 的計算和其他 token 無關。
- 那個「跨 token 互動」的步驟必須 cache,因為它要看到所有過去 token 的資訊。
而 attention 跨 token 看到的東西,就是 K 和 V。所以該被 cache 的,當然就是它們囉。
一個簡單的比喻
把 LLM 比喻成一個圖書館:
- Token id 是書的索書號。
- Embedding 是書的封面(提供入口)。
- 每層 hidden state 是書頁的內容(給讀者看)。
- K 和 V 是給其他書「互相引用」用的索引卡。
Attention 就是書與書之間的對話。書本身(hidden state)很厚,但對話只透過索引卡(K/V)來進行。把用過的卡片留在桌上,下次新書一進來就能直接跟它們對話——不必把整本書重新搬出來翻一遍。我自己覺得這個比喻還滿傳神的。
概念五:為什麼 KV cache 必須是 per-layer,而不只是 per-token?
注意 Runner 的 cache 是 Vec<Vec<f32>>,第一個維度是 layer,第二個才是 token position:
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,
對 22 層的 TinyLlama 來說,這就是 22 份獨立的 cache。為什麼不能大家共用一份?為什麼一個 token 不是對應一組 K/V,而是硬要對應 n_layer 組?
因為「同一個 token」在每一層的 K/V 是完全不同的東西
關鍵就藏在 forward pass 的結構裡。我們回頭看看每一層 attention 到底在算什麼:
layer 1: x¹ = embedding(token)
k¹ = W_k¹(rmsnorm(x¹)) ← 第 1 層的 K
v¹ = W_v¹(rmsnorm(x¹)) ← 第 1 層的 V
x² = x¹ + attn(...) + ffn(...)
layer 2: k² = W_k²(rmsnorm(x²)) ← 第 2 層的 K(input 不同、權重不同)
v² = W_v²(rmsnorm(x²))
x³ = x² + attn(...) + ffn(...)
layer 3: k³ = W_k³(rmsnorm(x³)) ← 第 3 層的 K
...
每一層的 K/V 同時受兩件事決定:
- 這一層的 input hidden state —— 上一層 attention + FFN 完成後傳下來的東西,每層都不同。
- 這一層自己的權重
W_k^l / W_v^l —— 每層獨立訓練的不同矩陣。
換句話說,layer 1 的 K 是「token 剛被 embed 時的樣子,被 W_k¹ 看出來的特徵」;layer 5 的 K 則是「token 經過 4 層 attention + FFN 整合上下文之後的樣子,被 W_k⁵ 看出來的特徵」。這兩個 K 既不是同一個東西,也沒辦法互相代打。
換個角度:每一層都有自己的 attention
Transformer 的設計本身就把每層 attention 當成獨立的計算單元。每層都在問「我這層的 query 跟我這層 cache 裡的 key 像不像」,答案決定了我這層怎麼把 V 加總起來。底下幾層通常學「文法、近距離 token 關係」,中層學「短語結構」,高層學「語意、長距離依賴」——這些功能,只有在它們各自的 K/V 空間裡才講得通。
所以你要是硬說「我只想存一份 K/V 給所有層共用」,那其實就等於在說「所有層都做同一件事」——這不就退化成一個超淺的模型了嗎?Transformer 的深度價值,正是每層做不同的轉換,而每層的 K/V 就是這個轉換的 fingerprint。
從層之間的依賴鏈看為什麼不能省
講得更技術一點:layer l 的 K/V 是 layer l-1 輸出的函式。整個 forward pass 其實是這樣一路串下來的:
x¹ x² x³
embed → ─→ attn¹+ffn¹ ─→ ─→ attn²+ffn² ─→ ─→ ...
│ │ │
└─→ k¹, v¹ └─→ k², v² └─→ k³, v³
每一層的 K/V 都「只能在自己這層用」——下一層的 attention 才不會去看 k¹,因為它要看的是經過這一層 transform 過的世界。所以這 22 份 K/V,就是 22 塊獨立的記憶體,沒有什麼一份共用這回事。
一個 token 在 cache 裡到底佔多少?
把上面這些事全部串起來,一個 token 進到 cache 之後實際佔的記憶體就是:
single token → n_layer × 2 × kv_dim × sizeof(f32)
↑ ↑ ↑
層數 K + V
對 TinyLlama (n_layer=22, kv_dim=256, f32) 來說:
22 × 2 × 256 × 4 = 45 KB per token
對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說:
32 × 2 × 4096 × 4 = 1 MB per token
這就是為什麼長 context 推論這麼燒記憶體——它不是 token 數乘上一個小數字而已,而是 token 數乘上 n_layer × 2 × kv_dim。一個 8K context 的 7B 模型,光 KV cache 就要 8 GB,比模型本身還大,想想真是有點誇張。
要是 KV cache 真的能「per-token only」(所有層共用),同樣的 8K context 只要 32 MB 就夠了——只是那樣的東西早就不是 Transformer 了。這個「per-layer」的代價,說到底就是 Transformer 之所以是 Transformer 的代價。
「per-layer × per-token」 是 cache 的最小完備形式
我想可以這樣收尾:要讓 attention 在每一層都能正確算出來,你需要的最少資訊就是:
| 維度 |
為什麼需要 |
| per-layer |
每層 attention 看的是不同的轉換空間,不能共用 |
| per-token |
每個 token 在每層的特徵都不同(causal mask 之外都會被未來查到) |
| K + V |
attention 公式裡的兩個輸入(Q 不需要 cache,因為它只在當前 step 用一次) |
把這三個維度切片乘起來,就是 [n_layer][n_ctx][2][kv_dim] 的 4D 張量——這就是 KV cache 的最小完備形狀。少存任何一個維度都會破壞 attention 的語義;多存任何一個維度都是純浪費(反正其他資訊都能重新算回來)。
好,理論講夠了,回到我的 Rust 程式碼吧:
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
外層 Vec 是 layer 維度,內層那個扁平陣列裡的 n_ctx × kv_dim 就是「token position × K/V 寬度」。這兩個 Vec<Vec<f32>> 加起來,剛好就是 attention 完整需要的最小狀態,一塊不多、一塊不少。
Runner struct 的記憶體佈局
pub struct Runner<'a, 'm> {
model: &'m LlamaModel<'a>,
rope_style: RopeStyle,
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx * kv_dim]
vcache: Vec<Vec<f32>>,
pub pos: usize,
// 預先配好的 scratch buffer
x: Vec<f32>, xb: Vec<f32>, xb2: Vec<f32>,
hb: Vec<f32>, hb2: Vec<f32>,
q: Vec<f32>, att: Vec<f32>,
logits: Vec<f32>,
}
幾個重要的設計決定:
兩個生命週期 'a 和 'm
'a 是 mmap 的生命週期。
'm 是 LlamaModel 的生命週期,且必須 'm: 'a(model 不能比 mmap 活得更久)。
Runner 借用 &'m LlamaModel<'a>,自己一份模型資料都沒持有,所以建構這個結構超便宜——裡頭所有大型 buffer 都只是 scratch 用途而已。
KV cache 是 Vec<Vec<f32>> 而不是 Vec<f32>
每層一個獨立的 Vec,而不是把所有層拼在同一個 Vec 裡。這是為了:
- 生命週期分離:不同層的 cache 可以獨立管理(雖然目前我沒這麼做)。
- 記憶體大小靈活:理論上不同層可以有不同的 KV dim(例如某些 MoE 變體),雖然 Llama 沒有這個需求。
代價就是多一層間接(access 得走 kcache[l][...])。不過對 LLM 推論來說,這點開銷根本可以無視啦。
一堆 scratch buffer 一次配好
x: vec![0.0; n_embd], // 主 hidden state
xb: vec![0.0; n_embd], // 暫存 a (attention/FFN 內部)
xb2: vec![0.0; n_embd], // 暫存 b (殘差用)
hb: vec![0.0; n_ff], // FFN gate 的中間結果
hb2: vec![0.0; n_ff], // FFN up 的中間結果
q: vec![0.0; n_embd], // 當前 token 的 Q
att: vec![0.0; n_head * n_ctx], // attention scores
logits: vec![0.0; vocab], // 輸出 logits
這些 buffer 全在 new() 裡一次配好,整個推論過程就再也不 allocate 了。每次 forward() 進來,都是直接覆寫這些 buffer。沒有 GC、沒有 alloc 壓力,乾淨俐落。
順便對比一下 PyTorch/Candle 的做法:每個 op 回傳一個新 Tensor,靠 reference counting 回收。這對訓練來說很合理(autograd 本來就得保留中間值),但對只做 inference 的 runner 而言,預配 + 覆寫顯然更精簡、也更可預測。
演算法核心:forward 的整體結構
forward(token) 的結構是這樣的:
flowchart TD
A[token id] --> B[1. Embed → x]
B --> C{2. 22 層 Transformer Block}
C --> C1[2a. attn_norm: xb = norm·x]
C1 --> C2[2b. Q/K/V projection
K/V 直接寫進 cache]
C2 --> C3[2c. RoPE on Q & current K]
C3 --> C4[2d. multi-head attention
Q · K^T / softmax / · V]
C4 --> C5[2e. wo projection + 殘差]
C5 --> C6[2f. ffn_norm + SwiGLU + 殘差]
C6 --> C
C --> D[3. final norm]
D --> E[4. lm_head: logits]
把每個 block 內部攤開來看,會更有感覺:
// attention norm
rmsnorm(&mut self.xb, &self.x, &layer.attn_norm, cfg.rms_eps);
// qkv projections
matvec(&mut self.q, &layer.wq, &self.xb);
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb); // K 直接寫進 cache
matvec(vrow, &layer.wv, &self.xb); // V 也是
// RoPE
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);
// attention
for h in 0..cfg.n_head {
let kv_head = h / gqa; // GQA 對應
// 算 attention scores
// softmax
// 加權 V
}
// 輸出投影 + 殘差
matvec(&mut self.xb2, &layer.wo, &self.xb);
add_inplace(&mut self.x, &self.xb2);
// FFN: x = x + Wdown(silu(Wgate(norm)) * Wup(norm))
rmsnorm(&mut self.xb, &self.x, &layer.ffn_norm, cfg.rms_eps);
matvec(&mut self.hb, &layer.w_gate, &self.xb);
matvec(&mut self.hb2, &layer.w_up, &self.xb);
for i in 0..self.hb.len() {
self.hb[i] = silu(self.hb[i]) * self.hb2[i];
}
matvec(&mut self.xb2, &layer.w_down, &self.hb);
add_inplace(&mut self.x, &self.xb2);
演算法核心一:K/V 直接寫進 cache
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);
matvec(vrow, &layer.wv, &self.xb);
這短短四行藏著一個我蠻喜歡的設計選擇:K/V 計算的輸出 buffer 直接就是 cache 的某一行,而不是先算到 scratch buffer、再 copy 進 cache。這就省下了一次 n_embd × 4 bytes 的記憶體拷貝。
matvec 的簽章 fn matvec(out: &mut [f32], ...) 接受任何 mutable slice,cache 的 slice 自然也塞得進去。而 Rust 的借用檢查器會幫你確認 kc 跟 xb、q 那些 buffer 不會 alias——這種事不用自己提心吊膽,編譯器顧得很周到。
演算法核心二:RoPE 的兩階段套用
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);
注意喔,這裡 RoPE 是套在 Q 和當前的 K 上,不是套在 cache 裡所有 K 上。為什麼這樣就夠了?
因為 RoPE 帶的是「絕對位置」資訊(每個 K_t 套用 t 對應的 RoPE),而每個 K_t 只要在它被算出來的那次 forward call 裡套一次就完事了。等到後面要做 attention,cache 裡的 K_t 早就帶著 RoPE 了,根本不用再套一遍。
至於 V 為什麼不用 RoPE?因為 attention 對 V 的處理是線性 weighted sum,position 資訊在 Q·K 那一步就已經注入進去了,V 這邊不需要再湊一腳。
演算法核心三:GQA 的 head 對應
for h in 0..cfg.n_head {
let kv_head = h / gqa; // GQA 對應
let q_off = h * head_dim;
let q = &self.q[q_off..q_off + head_dim];
let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];
for (t, score) in att.iter_mut().enumerate() {
let k_off = t * kv_dim + kv_head * head_dim;
let k = &self.kcache[l][k_off..k_off + head_dim];
let mut s = 0.0f32;
for i in 0..head_dim {
s += q[i] * k[i];
}
*score = s * scale;
}
softmax(att);
// ... 加權 V
}
GQA 的小聰明:在 vanilla MHA(multi-head attention)裡,每個 query head 都配一個獨立的 K/V head。但 GQA 把 query head 分組,讓每組共用同一個 K/V head:
n_head = 32, n_head_kv = 4 → gqa = 8
query head 0..7 → K/V head 0
query head 8..15 → K/V head 1
query head 16..23 → K/V head 2
query head 24..31 → K/V head 3
kv_head = h / gqa 就是在做這個對應。這招把 KV cache 的大小從 n_head × head_dim × n_ctx 一口氣縮成 n_head_kv × head_dim × n_ctx,對長 context 推論實在是太關鍵了——畢竟前面也講過,KV cache 才是長 context 推論真正的記憶體大胃王。
演算法核心四:attention scores 的記憶體佈局
let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];
att buffer 是 [n_head, n_ctx] 展平後的結果。head h 的 attention scores 佔 att[h*n_ctx .. h*n_ctx + n_ctx],但每次 forward 我們其實只用得到前 pos+1 個(因為只有從過去到現在的 token 才有 score 嘛)。
這樣設計的好處是:buffer 固定大小(一開始就照 worst case n_ctx 配好),不必動態 resize。代價就是有些空間會閒著沒用到——對 TinyLlama 來說是 n_head * n_ctx * 4 = 32 * 2048 * 4 = 256 KB,這點浪費我覺得完全可以接受。
Rust 用法:借用檢查器與切片
這個 forward pass 裡有大量 &mut 切片操作:
let kc = &mut self.kcache[l];
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);
注意我這裡得先把 &mut self.kcache[l] 抓進局部變數 kc,再去對 kc 切片。為什麼要這樣?因為 Rust 的借用檢查器希望看到「我們對 self.kcache 的 mutable 借用」是乾乾淨淨、一目了然的。
這大概是 Rust 寫多 mutable buffer 程式碼最常踩的坑之一了。假如你偷懶寫成:
matvec(&mut self.kcache[l][pos*kv_dim..(pos+1)*kv_dim], &layer.wk, &self.xb);
而附近又還有 &self.x、&self.xb 這類不可變借用,編譯器十之八九會跳出來抗議「不能同時 mutable 又 immutable 借用 self」。先把 mutable slice 抽出來、再做後續操作,就是縮小借用範圍的標準解法。踩過幾次就記住了。
split_at_mut 這種招式
更刁鑽的場合可能就得搬出 split_at_mut 了——比方說要同時拿到 K cache 和 V cache 的 mutable ref:
let (kc, vc) = (&mut self.kcache[l], &mut self.vcache[l]);
// 這個寫法借用檢查器會接受,因為 self.kcache 和 self.vcache 是不同的 field
但 kcache[l] 和 vcache[l] 是不同的 field,所以可以同時 mutable borrow。如果是同一個 Vec 的兩段切片,就需要 split_at_mut:
let (left, right) = my_vec.split_at_mut(mid);
// left 和 right 都是 &mut [T],但編譯器知道它們不重疊
效能最佳化空間
runner.rs 可以說是整個專案效能熱點的大本營,能再壓榨的空間還多得很:
1. Attention 的 head-parallel
我目前的 attention loop 是 sequential for h in 0..n_head:
for h in 0..cfg.n_head {
// 算 attention scores、softmax、加權 V → 寫進 self.xb 的某段
}
每個 head 寫進 self.xb 的不同 slice,是 disjoint 的。可以用 rayon 的 chunks_exact_mut 平行:
self.xb.par_chunks_exact_mut(head_dim).enumerate().for_each(|(h, out)| {
// 算這個 head 的 attention
});
不過 att buffer 是共享的(self.att),平行版得替每個 thread 準備獨立的 attention scratch。再加上 head 數量通常不大(32-64),fork-join 的 overhead 搞不好就把省下來的時間吃光了——這個值得 benchmark 一下,但我猜不一定划算。
2. FlashAttention 風格的 fused attention
我現在的 attention 是「先算完所有 score、再 softmax、再加權 V」三步走。FlashAttention 則把這三步 fuse 成一趟 streaming pass,記憶體占用從 O(n_ctx) 降到 O(head_dim),而且對 cache 超友善。只是實作複雜不少——這已經是 candle/ggml 那種等級的優化了,不是隨手就能塞進來的。
3. KV cache 的量化
KV cache 是長 context 的記憶體大頭,這點講到都快變口頭禪了。對 7B 模型、2k context、F32 cache 來說,每層 cache 是 2 * 2048 * 1024 * 4 = 16 MB,32 層就是 512 MB。要是把 cache 也量化成 Q8(1 byte/element),就能砍到 128 MB。llama.cpp 早就支援這招了(-ctk q8_0 -ctv q8_0)。
4. Prefill batching
目前 prefill 是 sequential 的——一個 token 跑一次 forward。但 prompt 的 token 其實全都已知,大可以拼成一個矩陣做 batched forward:
seq forward: O(L * (matvec for each layer))
batch forward: O(matmul of [L, n_embd] for each layer)
GEMM 的 cache locality 比 GEMV 好太多了,prefill 速度衝個 5-10× 不是夢。代價是 attention 要重新設計(causal mask 從可有可無變成非寫不可)、KV cache 的寫入方式也得改(一次寫 L 個 token)。算是中量級的工程,不算小但也沒有嚇人。
5. Speculative decoding
更進階的玩法:拿一個小小的 “draft model” 一次先猜 K 個 token,再用主模型一口氣 verify。要是 K 個全猜中,那等於只花一次 forward 的時間就吐出 K 個 token,賺翻了。這在 inference latency 上算是革命性的技術,不過代價是要兩個模型協同,沒那麼好伺候。
6. Continuous batching
如果哪天 runner 要同時 serve 一堆請求,就可以做 continuous batching——不同請求的 token 塞進同一個 batch,還能動態加進來、抽出去。但這要求 KV cache 變成「per-request」的,記憶體管理瞬間複雜好幾級。vLLM 就是這條路線的代表作。
一個有趣的權衡:為什麼不寫得更 functional?
其實我大可以把 forward 寫得更 functional 一點,比方把每個 layer 包成一個閉包、再用 fold 串起來,看起來會很「漂亮」。但說真的,現在這個命令式的寫法反而更好讀——你一眼就能看出每一步在改哪個 buffer、用哪個權重,毫不囉嗦。
LLM 推論的程式碼有個很特別的地方:90% 的時間都在做矩陣運算,剩下 10% 才是 control flow。寫得太 functional,反而會把 control flow 的成本擺到台前,把真正重要的計算給遮住了。我不建議盲目追求函式式的優雅,該樸實的時候就樸實吧。
總結:runner.rs 的角色
- 概念上:它就是把 token id 映射到 logits 的那個單一函式,順手把 KV cache 和層層遞進都管起來。
- 實作上:fixed-size scratch buffers + sequential layer loop + 手工 KV slice 操作,沒什麼花俏的。
- 設計上:把所有複雜性都收進這一個檔案裡,讓其他檔案(
ops、tensor)保持乾淨——我覺得這種「把髒活集中在一處」的取捨頗值得參考。
寫到這裡,整個推論流程的骨架其實已經攤在眼前了。下一篇我們往前再退一步,聊聊 tokenizer.rs——看看 byte 流是怎麼變成 token id 的,那又是另一個有趣的小宇宙。
系列文章:
tiny-llm-runner 深入解讀 (5):ops.rs —— Transformer 的六種樂高積木
本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇看完了 LlamaModel 怎麼把張量組起來,這一篇要看的是 Transformer 真正的「動詞」們:ops.rs。
這個檔案實在是小,只有 100 行,但它就是整個 Transformer 的所有原語。我想最妙的地方就在這裡:一個現代 LLM 動輒幾十億參數,所有的計算到頭來都會被拆解成這六種運算的組合。就這麼六塊樂高,拼出整個 Transformer。
樂高積木一:RMSNorm
pub fn rmsnorm(out: &mut [f32], x: &[f32], w: &[f32], eps: f32) {
let mut ss = 0.0f64;
for &v in x {
ss += v as f64 * v as f64;
}
ss /= x.len() as f64;
ss += eps as f64;
let scale = (1.0 / ss.sqrt()) as f32;
for i in 0..x.len() {
out[i] = w[i] * (x[i] * scale);
}
}
算法:
$$\text{out}_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}}$$
說穿了就是「先把 x 除以它自己的 RMS,再用 w 逐元素 scale」。RMSNorm 是 LayerNorm 的簡化版(沒有 mean-shift、沒有 bias),Llama 拿它取代 LayerNorm,是 2019-2020 年那波「能省就省」的設計趨勢——少算一點、效果卻不太掉,何樂而不為呢?
為什麼用 f64 累加?
注意 ss 是 f64,不是 f32。這可不是手滑寫錯:累加一大堆 f32 值很容易踩到「精度吞噬」(catastrophic cancellation)的雷。一個 4096 維的 hidden state,每個 f32 大約 1e-1 量級,平方後是 1e-2,4096 個累加起來大概 40-100。在這個量級下,每個新加進來的 1e-2 增量在 f32 上會有相對誤差 ~1e-7 × 100 = 1e-5,4096 次累積下來就會放大。
llama.cpp 在這條路徑用的是 f64 累加器,所以我也得跟著對齊——不然 RMSNorm 的輸出會跟它差個千分之一。你可能會想,才千分之一有差嗎?有喔,每層只差千分之一沒錯,但 22 層累積下來,輸出 token 就整個不一樣了。這又是一個「對齊既有實作」的例子。
樂高積木二:matvec —— 用 rayon 平行化
pub fn matvec(out: &mut [f32], w: &TensorView<'_>, x: &[f32]) {
out.par_iter_mut().enumerate().for_each(|(i, o)| {
*o = w.dot_row(i, x);
});
}
這是整個專案的效能瓶頸所在,偏偏又是最簡潔的一段程式碼——耗時最久的,往往長得最無辜。來拆解一下這幾行到底做了什麼:
par_iter_mut() 來自 rayon,把 &mut [f32] 變成一個平行 iterator。
enumerate() 加上 row index。
for_each(|(i, o)| ...) 對每個 row 平行執行 closure。
每個 row 都是獨立的(out[i] 只被當前 closure 寫入),所以完全不需要任何鎖。rayon 的 work-stealing 排程器會自動把 row 切成 chunk 分給可用的執行緒,這部分我們什麼都不用操心,頗省事。
為什麼 row-parallel 而不是 element-parallel?
如果改成 element-parallel(每個輸出元素一個 task),task 切太細,省下來的時間反而會被 rayon 的 task overhead 吃光。row-parallel 的 granularity 就剛剛好——每個 task 是一個完整的 dot product(4096 個乘加),夠粗,粗到值得啟動執行緒;又夠細,細到能平均分攤負載。
這裡 Rust 的 par_iter_mut 幫你擋掉了一個很容易踩的雷:你不能讓兩個 thread 同時寫 out 的同一個位置。而 par_iter_mut 的型別簽章保證每個 closure 拿到的是 &mut f32(單一元素的 mutable ref),編譯期就把 race condition 排除掉了。這種事在 C++ 裡得靠自己小心,在 Rust 裡編譯器直接幫你顧好,真是頗安心。
樂高積木三:softmax —— max-shift 防止 overflow
pub fn softmax(x: &mut [f32]) {
let mut max = f32::NEG_INFINITY;
for &v in x.iter() {
if v > max { max = v; }
}
let mut sum = 0.0f32;
for v in x.iter_mut() {
*v = (*v - max).exp();
sum += *v;
}
let inv = 1.0 / sum;
for v in x.iter_mut() {
*v *= inv;
}
}
數學上,softmax 是:
$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
但 e^{x} 對大 x 會 overflow(f32 的 exp 在 x > ~88 就溢位)。標準做法是先減 max:
$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$
數學上完全等價(分子分母同乘 $e^{-\max(x)}$),但這麼一減,所有 exp 的輸入都落在 $(-\infty, 0]$,永遠不會 overflow。這一步是寫 softmax 絕對不能省的——少了它,哪天遇到極端的 logits,輸出就是一片 NaN,到時候才在那邊抓半天,何必呢?
樂高積木四:RoPE —— 相對位置的旋轉
六塊樂高裡,這塊大概是最微妙的一個。RoPE(Rotary Position Embedding)是 Llama 用來把位置資訊塞進 attention 的方法,數學上是把每個 head 的 Q/K 向量看成一堆「複數對」,再旋轉一個和位置相關的角度。聽起來很玄,但程式碼其實沒幾行。
pub fn apply_rope(
vec: &mut [f32], pos: usize, head_dim: usize,
rot_dim: usize, base: f32, style: RopeStyle,
) {
let half = rot_dim / 2;
let n_heads = vec.len() / head_dim;
for h in 0..n_heads {
let off = h * head_dim;
for i in 0..half {
let freq = base.powf(-((2 * i) as f32) / rot_dim as f32);
let theta = pos as f32 * freq;
let (sin, cos) = theta.sin_cos();
let (i0, i1) = match style {
RopeStyle::Llama => (2 * i, 2 * i + 1),
RopeStyle::Neox => (i, i + half),
};
let v0 = vec[off + i0];
let v1 = vec[off + i1];
vec[off + i0] = v0 * cos - v1 * sin;
vec[off + i1] = v0 * sin + v1 * cos;
}
}
}
核心想法:每對 (v[i0], v[i1]) 是一個 2D 向量,被旋轉一個角度 $\theta = \text{pos} \cdot \text{freq}_i$。不同的 i 用不同的 freq——i 越小 freq 越大(位置感越強)、i 越大 freq 越小(接近不變)。
頻率公式:
$$\text{freq}_i = \text{base}^{-2i / \text{rot\_dim}}$$
base 通常是 10000,rot_dim 是 head_dim(或某個比例)。i = 0 時 freq = 1(每移動一個 position 旋轉 1 弧度),i = rot_dim/2 - 1 時 freq ≈ 1/10000(要 6000 個 position 才旋轉 1 弧度)。
這種多尺度的旋轉設計實在是巧妙,讓 attention 既能感知近距離的細微位置差異、也能掌握遠距離的大致位置,可說是一兼二顧。
為什麼有兩種 style?
let (i0, i1) = match style {
RopeStyle::Llama => (2 * i, 2 * i + 1), // 鄰接成對
RopeStyle::Neox => (i, i + half), // 上下半成對
};
這兩個是完全不同的「複數對」配對方式:
Llama:把 v[0], v[1] 當一對、v[2], v[3] 當一對……(鄰接 pair)
Neox:把 v[0], v[half] 當一對、v[1], v[half+1] 當一對……(上下 split)
兩者旋轉的數學其實一模一樣,差的只是配對方式。那為什麼要兩種並存呢?這就是歷史共業了。
當初 HuggingFace 的 transformers 採用 NeoX 風格。但 llama.cpp 早期的 convert.py 在轉檔時會把 Q/K 矩陣的每對 row permute 成鄰接格式,這樣就能改用 Llama 風格做 RoPE,記憶體 access 更線性。
問題是新版的 convert_hf_to_gguf.py 不做 permute 了,所以你得乖乖用 NeoX 風格。搞錯了不會 crash,但模型會開始講胡話——這個雷我可是親自踩過的,debug 半天才發現是配對配反了,欲哭無淚。
順帶一提,sin_cos() 用的是 f32::sin_cos()——同時算 sin 和 cos,硬體上比分開算快一倍,這種免費的便宜當然要佔。
樂高積木五:SiLU —— SwiGLU 的活化函數
#[inline]
pub fn silu(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
數學定義:
$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
$\sigma$ 是 sigmoid。SiLU 也叫 Swish。Llama 的 FFN 用的是 SwiGLU(Swish-Gated Linear Unit):
$$\text{FFN}(x) = W_{down}(\text{SiLU}(W_{gate}(x)) \odot W_{up}(x))$$
⊙ 是逐元素相乘。白話講就是:先把 x 過 gate 和 up 兩個投影,gate 那條過完 SiLU 後逐元素乘上 up,再一起過 down 投影。比起傳統的 ReLU FFN,SwiGLU 多了一個 gate 投影,乍看是多花了計算,但實驗顯示效果好得多——多算這一點,值得。
#[inline] 是給編譯器的小提示。這個函式會在 inner loop 被呼叫成千上萬次,inline 進去能省掉函式呼叫的 overhead,積少成多嘛。
樂高積木六:add_inplace —— 殘差連接
pub fn add_inplace(a: &mut [f32], b: &[f32]) {
for (x, y) in a.iter_mut().zip(b.iter()) {
*x += *y;
}
}
六塊樂高裡最樸素的一塊:把 b 加進 a,沒了。別小看它,Transformer 的每個 sublayer 都靠它:
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
那兩個不起眼的加號,就是 add_inplace。殘差連接這麼重要的東西,實作起來竟然就一行 for 迴圈,想想還挺有意思的。
這裡我刻意用 iter_mut().zip(iter()) 而不是用 index。這個寫法的好處有三個:
- 沒有越界檢查——iterator 自動知道何時停止。
- 編譯器更容易自動向量化——iterator 模式對 LLVM 友善。
- 零成本抽象——這個寫法和手寫 raw loop 編譯出來幾乎一樣的組合語言。
Rust 用法:#[derive] 的小幫手
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeStyle {
Llama,
Neox,
}
Copy 讓我可以把 RopeStyle 當值傳遞,不用一直 & 來 & 去。Eq 讓它可以在 match 裡比較。Debug 讓 eprintln!("{:?}", style) 能直接印出來除錯。敲不到一秒的字,省下 10 行 boilerplate——這就是 Rust derive 機制最日常、也最香的用法。
效能最佳化空間
ops.rs 是另一個重要的最佳化戰場。這六個函式,個個都還有改進空間,來盤點一下:
1. RMSNorm 的 SIMD
那個 for &v in x { ss += v * v; } 是個 reduction,可以用 SIMD 平行化:
use std::simd::f32x8;
let mut acc = f32x8::splat(0.0);
for chunk in x.chunks_exact(8) {
let v = f32x8::from_slice(chunk);
acc += v * v;
}
let ss = acc.reduce_sum() as f64;
不過 f64 vs f32 累加器的精度問題,在 SIMD 環境下就變得有點微妙了——AVX-512 有 f64 SIMD,AVX2 卻沒有。我想最務實的做法還是維持 f32 SIMD reduction,實作簡單、誤差又在可接受範圍內(畢竟 4096 個 ~1e-2 量級的數,誤差還不至於大到讓 token 跑掉)。
2. matvec 的 GEMV → GEMM 升級
目前我的 matvec 是 GEMV(matrix × vector),每次 forward pass 對每層都得做兩次(attention 的 Q + FFN 的 down)。但prefill 階段的 prompt 是一整串 token,明明就可以批次處理——把 N 個 token 的 hidden state 拼成 [seq_len, n_embd] 矩陣,一次 GEMM 解決。
GEMM 比 GEMV 快上不少(典型 5-10×),關鍵在它能 reuse 權重(load 一次 row 的權重,就能對 seq_len 個輸入一起算)。只是這得回頭重新設計 Runner——現在是單 token forward,要改成 batch forward 才行,這筆帳之後再算吧。
3. softmax 的 SIMD exp
(*v - max).exp() 是逐元素的,理論上可以 SIMD 化。不過 SIMD exp 麻煩了點——得用多項式逼近(典型是 Cephes 或 Padé approximant 的 SIMD 版本)。懶得自己刻的話,sleef crate 有現成的 SIMD math 函式可以接上,不必重造輪子。
4. RoPE 的 sin/cos 預計算
每次 forward pass 都呼叫 theta.sin_cos()。但 freq_i 只和 i 有關,theta = pos * freq_i 也可以分解。如果預計算所有可能的 (pos, i) 的 sin/cos 並存成 table,runtime 只需要查表:
let cos_table: Vec<Vec<f32>> = (0..n_ctx).map(|pos| {
(0..rot_dim/2).map(|i| {
let freq = base.powf(-((2*i) as f32) / rot_dim as f32);
(pos as f32 * freq).cos()
}).collect()
}).collect();
以 TinyLlama 來說,這張表是 n_ctx × rot_dim/2 × 4 bytes = 2048 × 32 × 4 = 256 KB,啟動時花個一次性的時間算好,之後 runtime 通通查表就好。256 KB,L2 cache 就裝得下,划算。
5. add_inplace 也可以 SIMD
雖然只是個簡單的逐元素加,但 LLVM 對 iter_mut().zip() 模式的 auto-vectorization 不一定每次都會發生——它高興才幫你做。想穩一點,就明確寫 f32x8::from_slice(...) + f32x8::from_slice(...),保證 SIMD 化。
6. 把 forward pass 內整層 fuse 起來
最大的最佳化潛力,其實是 kernel fusion。比方說 attention + softmax 可以 fuse(FlashAttention 走的就是這條路);rmsnorm + matvec 也能 fuse 成一個 pass。不過 fusion 會把程式碼搞得超複雜,跟 tiny-llm-runner 主打的「易讀」目標整個背道而馳——所以這一塊就留給 candle/ggml 那些大人去玩吧,不是我這裡該碰的。
一個 Rust 特有的優勢:debug 與 release 的 debug_assert
我每個函式裡都塞了 debug_assert_eq!(out.len(), x.len()) 之類的檢查。妙的是 Rust 的 debug_assert! 在 release 會被完全消除——dev 時幫你抓 bug,上線後一毛 runtime 成本都不花。對 hot path 來說,這實在是個好用的工具。
C/C++ 當然也有 assert + NDEBUG,但 Rust 把這個好習慣直接做成語言預設——你不必記得去 #define NDEBUG,cargo build --release 就自動幫你處理好。一個小細節,卻看得出語言設計的用心。
總結:ops.rs 的角色
- 概念上:Transformer 的所有原語都在這(norm、matmul、softmax、RoPE、activation、add)。
- 實作上:純函式、輸出參數、明確簽章、零隱藏狀態,看了就懂。
- 設計上:把派發(量化型別 match)藏進
TensorView::dot_row 裡,ops.rs 自己只看得到 &[f32],乾乾淨淨。
回頭看,這六塊樂高每一塊都簡單到有點不可思議,但拼起來就是當今最強的語言模型。我想這正是 Transformer 最迷人的地方——複雜的不是零件,而是組合。Simple bricks, complex castles.
下一篇來講 runner.rs,那是把所有 ops 編排起來、發號施令的指揮中心。
系列文章: