本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇我們把所有 Transformer 原語都拆過一遍了,這一篇終於要把它們編排成一次完整的 forward pass,主角就是核心檔案 runner.rs。
老實說,整個專案我最推薦讀的就是這個檔案——大約 170 行的 Rust,把 Llama 的整個推論流程攤平在你眼前,每一步在做什麼都看得清清楚楚,實在是很過癮。
概念一:autoregressive 推論的本質
LLM 推論不是「一次處理整段文字」,而是「一次處理一個 token」:
prompt: "The capital of France"
→ token ids [464, 3139, 286, 4881]
prefill (處理 prompt):
forward(464) → logits (我們不用,但要建 KV cache)
forward(3139) → logits
forward(286) → logits
forward(4881) → logits ← 用這個 logits 抽下一個 token
decode (生成):
sampler(logits) → " is"
forward(" is") → logits
sampler(logits) → " Paris"
forward(" Paris") → logits
... 一直到 EOS每次 forward(token) 就是一個 step,step 之間靠 KV cache 一路把上下文累積起來。一個字一個字慢慢吐,model 其實沒有你想像中那麼「整段一起讀」。
概念二:KV Cache —— 為什麼 attention 要 cache 過去?
Attention 的數學是這樣的:
$$\text{attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$在 autoregressive 推論時,第 t 個 token 的 Q 只需要跟前 t 個 token 的 K/V 做運算(因為未來的 token 根本還沒出現嘛)。如果每次 forward 都把所有過去的 K/V 重新算一遍,prefill 階段就是 $O(n^2)$,decode 階段每次又是 $O(n)$ 的重做——這也太浪費了吧。
KV cache 的想法其實很單純:某個 position 的 K/V 算過一次就存起來,下次直接拿來用。這樣每次 forward 只要算當前這個 token 的 Q/K/V 就好,K 和 V 寫進 cache、Q 拿去跟整個 cache 點積。
概念三:Context 和 KV Cache 的關係
很多人一聽到「context length 2048」,直覺就想成「模型有個地方存著最近 2048 個 token」。其實不是這樣喔——模型存的是這 2048 個 token 被各層 attention 算出來的 K 和 V,token 本身早就丟掉了。所以 context 的容量上限,講白了就是 KV cache 的容量上限。
關係可以這樣理解:
context window = 「我能看見多遠的過去」的能力上限
KV cache = 這個能力實際的儲存形式
n_ctx = KV cache 在 position 維度上的大小當你看到 n_ctx = 2048,背後真正配出來的記憶體就是:
kcache: [n_layer][n_ctx × kv_dim] f32
vcache: [n_layer][n_ctx × kv_dim] f32對 TinyLlama (n_layer=22, kv_dim=256) 來說,這是 22 × 2 × 2048 × 256 × 4 = 92 MB。對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說,是 32 × 2 × 2048 × 4096 × 4 = 2 GB——KV cache 居然比模型本身的量化權重還大,是不是有點出乎意料?
還有一點要記住:token 一旦進了 cache,就再也撈不回原貌了。Context window 滿掉之後,要嘛截斷舊 token、要嘛丟掉新 token,沒有第三條路——因為原始 token id 早就在被 attention 投影成 K/V 的那一刻丟掉了。
概念四:為什麼不存 token 本身?為什麼不存 embedding?為什麼是 K/V?
我覺得把這個問題想清楚,就等於把 KV cache 的設計從頭到尾理解了一遍。我們從「最樸素」的方案開始,一步步往下推:
方案 A:只存 token id
「我把過去的 token id 都存起來,下次推論時整段重跑就好。」
問題在於:這樣每次 forward 都得從 token id 開始,重算 embedding、重跑 22 層 transformer。對第 t 個 token 來說是 $O(t \times \text{layers} \times \text{matmul})$ 的工作,整個 prompt 的 prefill 是 $O(n^2)$,decode 累積下來也是 $O(n^2)$。說穿了,這根本沒在 cache 嘛。
方案 B:存 embedding(每個 token 對應一個 n_embd 向量)
「那我存 token embedding,這樣可以跳過 embedding lookup。」
問題是:embedding 只是 forward pass 的入口而已。從 embedding 走到「attention 真正吃的東西」,中間還隔著一整層的 RMSNorm + Q/K/V 投影。每層的 attention 看到的,是「自己這層的輸入經過 W_k 和 W_v 投影後的 K/V」,不是 token embedding。所以如果只存 token embedding:
- 每次 forward 都要重新跑 22 層的 RMSNorm、重新算 K/V 投影。
- 但這些工作的結果和 token 位置 t 無關——每次重算都會得到同一個 K_t 和 V_t。
- 純粹的浪費。
更麻煩的是:第 1 層的輸入是 token embedding,可是第 2 層的輸入是「第 1 層的輸出」。所以你沒辦法拿「token embedding」這一個概念去代表「每一層 attention 需要的東西」。每層之間都隔著 attention + FFN 的非線性轉換,「embedding」這個詞其實只在最開頭那一層才講得通。
方案 C:存每層的 hidden state(每層都有自己的 [n_layer][n_ctx][n_embd])
「那我存每層的輸出 hidden state?」
問題是:你存了 hidden state,但 attention 真正要的是 K = W_k(hidden_state) 和 V = W_v(hidden_state)。每次 forward 還是得重算 K/V 投影(一個 [n_embd, kv_dim] 的 matvec)。等於白存,因為投影那一步根本沒省到。
而且這個方案吃的記憶體還比 K/V cache 大:hidden state 是 n_embd 寬,K/V 在 GQA 下只有 kv_dim = head_dim × n_head_kv 寬(TinyLlama 是 256,比 n_embd=2048 整整小了 8 倍)。划不來。
方案 D:直接存 K 和 V(最終答案)
繞了一圈,答案其實一直都在公式裡:K 和 V 才是 attention 真正消費的東西。$\text{attn}(Q, K, V)$ 裡頭沒有 hidden state、沒有 embedding、沒有 token id——就只有 Q、K、V。把過去的 K 和 V 存起來,attention 就齊了。
而且 K/V 還有兩個我覺得頗迷人的性質:
- 它們和 Q 是解耦的:K_t 和 V_t 一旦算出來就和「未來會用什麼 Q 來查它」無關。所以可以放心 cache。
- 它們不需要被未來修改:因為 LLM 是 causal——位置 t 的 K/V 只被 position ≥ t 的 Q 查詢,但這些查詢不會回頭改 K_t/V_t。一旦寫進 cache 就是 immutable 的。
這就是為什麼我會說 KV cache 是 transformer 推論的「最佳化終點」——它剛剛好存下 attention 需要的最少資訊,再少一點就破壞語義,再多一點就是浪費。剛剛好,一點都不多。
從「資訊保留鏈」看這件事
換個角度想想:模型做 forward pass,其實是一條把「token id」逐步轉成「下一個 token 機率分佈」的資訊流。中途會冒出一大堆中間表示:
token id → embedding → layer 1 hidden → layer 1 K/V → ...
→ layer 2 hidden → layer 2 K/V → ...
...
→ final hidden → logits每一個中間表示都裝著這個 token 的某種資訊。但對 attention 來說,這條鏈上唯一一個「跨 token 互動」的點,就是它把 K/V 拿來算內積那一步。其他每一步(RMSNorm、Q/K/V 投影、FFN…)都是 pure-position 的——只管當前這個 token 自己的資料,跟別人無關。
所以:
- 那些「pure-position」的步驟不需要 cache,因為每個 token 的計算和其他 token 無關。
- 那個「跨 token 互動」的步驟必須 cache,因為它要看到所有過去 token 的資訊。
而 attention 跨 token 看到的東西,就是 K 和 V。所以該被 cache 的,當然就是它們囉。
一個簡單的比喻
把 LLM 比喻成一個圖書館:
- Token id 是書的索書號。
- Embedding 是書的封面(提供入口)。
- 每層 hidden state 是書頁的內容(給讀者看)。
- K 和 V 是給其他書「互相引用」用的索引卡。
Attention 就是書與書之間的對話。書本身(hidden state)很厚,但對話只透過索引卡(K/V)來進行。把用過的卡片留在桌上,下次新書一進來就能直接跟它們對話——不必把整本書重新搬出來翻一遍。我自己覺得這個比喻還滿傳神的。
概念五:為什麼 KV cache 必須是 per-layer,而不只是 per-token?
注意 Runner 的 cache 是 Vec<Vec<f32>>,第一個維度是 layer,第二個才是 token position:
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>,對 22 層的 TinyLlama 來說,這就是 22 份獨立的 cache。為什麼不能大家共用一份?為什麼一個 token 不是對應一組 K/V,而是硬要對應 n_layer 組?
因為「同一個 token」在每一層的 K/V 是完全不同的東西
關鍵就藏在 forward pass 的結構裡。我們回頭看看每一層 attention 到底在算什麼:
layer 1: x¹ = embedding(token)
k¹ = W_k¹(rmsnorm(x¹)) ← 第 1 層的 K
v¹ = W_v¹(rmsnorm(x¹)) ← 第 1 層的 V
x² = x¹ + attn(...) + ffn(...)
layer 2: k² = W_k²(rmsnorm(x²)) ← 第 2 層的 K(input 不同、權重不同)
v² = W_v²(rmsnorm(x²))
x³ = x² + attn(...) + ffn(...)
layer 3: k³ = W_k³(rmsnorm(x³)) ← 第 3 層的 K
...每一層的 K/V 同時受兩件事決定:
- 這一層的 input hidden state —— 上一層 attention + FFN 完成後傳下來的東西,每層都不同。
- 這一層自己的權重
W_k^l/W_v^l—— 每層獨立訓練的不同矩陣。
換句話說,layer 1 的 K 是「token 剛被 embed 時的樣子,被 W_k¹ 看出來的特徵」;layer 5 的 K 則是「token 經過 4 層 attention + FFN 整合上下文之後的樣子,被 W_k⁵ 看出來的特徵」。這兩個 K 既不是同一個東西,也沒辦法互相代打。
換個角度:每一層都有自己的 attention
Transformer 的設計本身就把每層 attention 當成獨立的計算單元。每層都在問「我這層的 query 跟我這層 cache 裡的 key 像不像」,答案決定了我這層怎麼把 V 加總起來。底下幾層通常學「文法、近距離 token 關係」,中層學「短語結構」,高層學「語意、長距離依賴」——這些功能,只有在它們各自的 K/V 空間裡才講得通。
所以你要是硬說「我只想存一份 K/V 給所有層共用」,那其實就等於在說「所有層都做同一件事」——這不就退化成一個超淺的模型了嗎?Transformer 的深度價值,正是每層做不同的轉換,而每層的 K/V 就是這個轉換的 fingerprint。
從層之間的依賴鏈看為什麼不能省
講得更技術一點:layer l 的 K/V 是 layer l-1 輸出的函式。整個 forward pass 其實是這樣一路串下來的:
x¹ x² x³
embed → ─→ attn¹+ffn¹ ─→ ─→ attn²+ffn² ─→ ─→ ...
│ │ │
└─→ k¹, v¹ └─→ k², v² └─→ k³, v³每一層的 K/V 都「只能在自己這層用」——下一層的 attention 才不會去看 k¹,因為它要看的是經過這一層 transform 過的世界。所以這 22 份 K/V,就是 22 塊獨立的記憶體,沒有什麼一份共用這回事。
一個 token 在 cache 裡到底佔多少?
把上面這些事全部串起來,一個 token 進到 cache 之後實際佔的記憶體就是:
single token → n_layer × 2 × kv_dim × sizeof(f32)
↑ ↑ ↑
層數 K + V對 TinyLlama (n_layer=22, kv_dim=256, f32) 來說:
22 × 2 × 256 × 4 = 45 KB per token
對 Llama-2 7B (n_layer=32, kv_dim=4096) 來說:
32 × 2 × 4096 × 4 = 1 MB per token
這就是為什麼長 context 推論這麼燒記憶體——它不是 token 數乘上一個小數字而已,而是 token 數乘上 n_layer × 2 × kv_dim。一個 8K context 的 7B 模型,光 KV cache 就要 8 GB,比模型本身還大,想想真是有點誇張。
要是 KV cache 真的能「per-token only」(所有層共用),同樣的 8K context 只要 32 MB 就夠了——只是那樣的東西早就不是 Transformer 了。這個「per-layer」的代價,說到底就是 Transformer 之所以是 Transformer 的代價。
「per-layer × per-token」 是 cache 的最小完備形式
我想可以這樣收尾:要讓 attention 在每一層都能正確算出來,你需要的最少資訊就是:
| 維度 | 為什麼需要 |
|---|---|
| per-layer | 每層 attention 看的是不同的轉換空間,不能共用 |
| per-token | 每個 token 在每層的特徵都不同(causal mask 之外都會被未來查到) |
| K + V | attention 公式裡的兩個輸入(Q 不需要 cache,因為它只在當前 step 用一次) |
把這三個維度切片乘起來,就是 [n_layer][n_ctx][2][kv_dim] 的 4D 張量——這就是 KV cache 的最小完備形狀。少存任何一個維度都會破壞 attention 的語義;多存任何一個維度都是純浪費(反正其他資訊都能重新算回來)。
好,理論講夠了,回到我的 Rust 程式碼吧:
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
vcache: Vec<Vec<f32>>, // [n_layer][n_ctx × kv_dim]
外層 Vec 是 layer 維度,內層那個扁平陣列裡的 n_ctx × kv_dim 就是「token position × K/V 寬度」。這兩個 Vec<Vec<f32>> 加起來,剛好就是 attention 完整需要的最小狀態,一塊不多、一塊不少。
Runner struct 的記憶體佈局
pub struct Runner<'a, 'm> {
model: &'m LlamaModel<'a>,
rope_style: RopeStyle,
kcache: Vec<Vec<f32>>, // [n_layer][n_ctx * kv_dim]
vcache: Vec<Vec<f32>>,
pub pos: usize,
// 預先配好的 scratch buffer
x: Vec<f32>, xb: Vec<f32>, xb2: Vec<f32>,
hb: Vec<f32>, hb2: Vec<f32>,
q: Vec<f32>, att: Vec<f32>,
logits: Vec<f32>,
}幾個重要的設計決定:
兩個生命週期 'a 和 'm
'a是 mmap 的生命週期。'm是LlamaModel的生命週期,且必須'm: 'a(model 不能比 mmap 活得更久)。
Runner 借用 &'m LlamaModel<'a>,自己一份模型資料都沒持有,所以建構這個結構超便宜——裡頭所有大型 buffer 都只是 scratch 用途而已。
KV cache 是 Vec<Vec<f32>> 而不是 Vec<f32>
每層一個獨立的 Vec,而不是把所有層拼在同一個 Vec 裡。這是為了:
- 生命週期分離:不同層的 cache 可以獨立管理(雖然目前我沒這麼做)。
- 記憶體大小靈活:理論上不同層可以有不同的 KV dim(例如某些 MoE 變體),雖然 Llama 沒有這個需求。
代價就是多一層間接(access 得走 kcache[l][...])。不過對 LLM 推論來說,這點開銷根本可以無視啦。
一堆 scratch buffer 一次配好
x: vec![0.0; n_embd], // 主 hidden state
xb: vec![0.0; n_embd], // 暫存 a (attention/FFN 內部)
xb2: vec![0.0; n_embd], // 暫存 b (殘差用)
hb: vec![0.0; n_ff], // FFN gate 的中間結果
hb2: vec![0.0; n_ff], // FFN up 的中間結果
q: vec![0.0; n_embd], // 當前 token 的 Q
att: vec![0.0; n_head * n_ctx], // attention scores
logits: vec![0.0; vocab], // 輸出 logits
這些 buffer 全在 new() 裡一次配好,整個推論過程就再也不 allocate 了。每次 forward() 進來,都是直接覆寫這些 buffer。沒有 GC、沒有 alloc 壓力,乾淨俐落。
順便對比一下 PyTorch/Candle 的做法:每個 op 回傳一個新 Tensor,靠 reference counting 回收。這對訓練來說很合理(autograd 本來就得保留中間值),但對只做 inference 的 runner 而言,預配 + 覆寫顯然更精簡、也更可預測。
演算法核心:forward 的整體結構
forward(token) 的結構是這樣的:
K/V 直接寫進 cache] C2 --> C3[2c. RoPE on Q & current K] C3 --> C4[2d. multi-head attention
Q · K^T / softmax / · V] C4 --> C5[2e. wo projection + 殘差] C5 --> C6[2f. ffn_norm + SwiGLU + 殘差] C6 --> C C --> D[3. final norm] D --> E[4. lm_head: logits]
把每個 block 內部攤開來看,會更有感覺:
// attention norm
rmsnorm(&mut self.xb, &self.x, &layer.attn_norm, cfg.rms_eps);
// qkv projections
matvec(&mut self.q, &layer.wq, &self.xb);
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb); // K 直接寫進 cache
matvec(vrow, &layer.wv, &self.xb); // V 也是
// RoPE
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);
// attention
for h in 0..cfg.n_head {
let kv_head = h / gqa; // GQA 對應
// 算 attention scores
// softmax
// 加權 V
}
// 輸出投影 + 殘差
matvec(&mut self.xb2, &layer.wo, &self.xb);
add_inplace(&mut self.x, &self.xb2);
// FFN: x = x + Wdown(silu(Wgate(norm)) * Wup(norm))
rmsnorm(&mut self.xb, &self.x, &layer.ffn_norm, cfg.rms_eps);
matvec(&mut self.hb, &layer.w_gate, &self.xb);
matvec(&mut self.hb2, &layer.w_up, &self.xb);
for i in 0..self.hb.len() {
self.hb[i] = silu(self.hb[i]) * self.hb2[i];
}
matvec(&mut self.xb2, &layer.w_down, &self.hb);
add_inplace(&mut self.x, &self.xb2);演算法核心一:K/V 直接寫進 cache
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
let vrow = &mut vc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);
matvec(vrow, &layer.wv, &self.xb);這短短四行藏著一個我蠻喜歡的設計選擇:K/V 計算的輸出 buffer 直接就是 cache 的某一行,而不是先算到 scratch buffer、再 copy 進 cache。這就省下了一次 n_embd × 4 bytes 的記憶體拷貝。
matvec 的簽章 fn matvec(out: &mut [f32], ...) 接受任何 mutable slice,cache 的 slice 自然也塞得進去。而 Rust 的借用檢查器會幫你確認 kc 跟 xb、q 那些 buffer 不會 alias——這種事不用自己提心吊膽,編譯器顧得很周到。
演算法核心二:RoPE 的兩階段套用
apply_rope(&mut self.q, pos, head_dim, ...);
apply_rope(krow, pos, head_dim, ...);注意喔,這裡 RoPE 是套在 Q 和當前的 K 上,不是套在 cache 裡所有 K 上。為什麼這樣就夠了?
因為 RoPE 帶的是「絕對位置」資訊(每個 K_t 套用 t 對應的 RoPE),而每個 K_t 只要在它被算出來的那次 forward call 裡套一次就完事了。等到後面要做 attention,cache 裡的 K_t 早就帶著 RoPE 了,根本不用再套一遍。
至於 V 為什麼不用 RoPE?因為 attention 對 V 的處理是線性 weighted sum,position 資訊在 Q·K 那一步就已經注入進去了,V 這邊不需要再湊一腳。
演算法核心三:GQA 的 head 對應
for h in 0..cfg.n_head {
let kv_head = h / gqa; // GQA 對應
let q_off = h * head_dim;
let q = &self.q[q_off..q_off + head_dim];
let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];
for (t, score) in att.iter_mut().enumerate() {
let k_off = t * kv_dim + kv_head * head_dim;
let k = &self.kcache[l][k_off..k_off + head_dim];
let mut s = 0.0f32;
for i in 0..head_dim {
s += q[i] * k[i];
}
*score = s * scale;
}
softmax(att);
// ... 加權 V
}GQA 的小聰明:在 vanilla MHA(multi-head attention)裡,每個 query head 都配一個獨立的 K/V head。但 GQA 把 query head 分組,讓每組共用同一個 K/V head:
n_head = 32, n_head_kv = 4 → gqa = 8
query head 0..7 → K/V head 0
query head 8..15 → K/V head 1
query head 16..23 → K/V head 2
query head 24..31 → K/V head 3kv_head = h / gqa 就是在做這個對應。這招把 KV cache 的大小從 n_head × head_dim × n_ctx 一口氣縮成 n_head_kv × head_dim × n_ctx,對長 context 推論實在是太關鍵了——畢竟前面也講過,KV cache 才是長 context 推論真正的記憶體大胃王。
演算法核心四:attention scores 的記憶體佈局
let att = &mut self.att[h * cfg.n_ctx..h * cfg.n_ctx + (pos + 1)];att buffer 是 [n_head, n_ctx] 展平後的結果。head h 的 attention scores 佔 att[h*n_ctx .. h*n_ctx + n_ctx],但每次 forward 我們其實只用得到前 pos+1 個(因為只有從過去到現在的 token 才有 score 嘛)。
這樣設計的好處是:buffer 固定大小(一開始就照 worst case n_ctx 配好),不必動態 resize。代價就是有些空間會閒著沒用到——對 TinyLlama 來說是 n_head * n_ctx * 4 = 32 * 2048 * 4 = 256 KB,這點浪費我覺得完全可以接受。
Rust 用法:借用檢查器與切片
這個 forward pass 裡有大量 &mut 切片操作:
let kc = &mut self.kcache[l];
let krow = &mut kc[pos * kv_dim..(pos + 1) * kv_dim];
matvec(krow, &layer.wk, &self.xb);注意我這裡得先把 &mut self.kcache[l] 抓進局部變數 kc,再去對 kc 切片。為什麼要這樣?因為 Rust 的借用檢查器希望看到「我們對 self.kcache 的 mutable 借用」是乾乾淨淨、一目了然的。
這大概是 Rust 寫多 mutable buffer 程式碼最常踩的坑之一了。假如你偷懶寫成:
matvec(&mut self.kcache[l][pos*kv_dim..(pos+1)*kv_dim], &layer.wk, &self.xb);而附近又還有 &self.x、&self.xb 這類不可變借用,編譯器十之八九會跳出來抗議「不能同時 mutable 又 immutable 借用 self」。先把 mutable slice 抽出來、再做後續操作,就是縮小借用範圍的標準解法。踩過幾次就記住了。
split_at_mut 這種招式
更刁鑽的場合可能就得搬出 split_at_mut 了——比方說要同時拿到 K cache 和 V cache 的 mutable ref:
let (kc, vc) = (&mut self.kcache[l], &mut self.vcache[l]);
// 這個寫法借用檢查器會接受,因為 self.kcache 和 self.vcache 是不同的 field
但 kcache[l] 和 vcache[l] 是不同的 field,所以可以同時 mutable borrow。如果是同一個 Vec 的兩段切片,就需要 split_at_mut:
let (left, right) = my_vec.split_at_mut(mid);
// left 和 right 都是 &mut [T],但編譯器知道它們不重疊
效能最佳化空間
runner.rs 可以說是整個專案效能熱點的大本營,能再壓榨的空間還多得很:
1. Attention 的 head-parallel
我目前的 attention loop 是 sequential for h in 0..n_head:
for h in 0..cfg.n_head {
// 算 attention scores、softmax、加權 V → 寫進 self.xb 的某段
}每個 head 寫進 self.xb 的不同 slice,是 disjoint 的。可以用 rayon 的 chunks_exact_mut 平行:
self.xb.par_chunks_exact_mut(head_dim).enumerate().for_each(|(h, out)| {
// 算這個 head 的 attention
});不過 att buffer 是共享的(self.att),平行版得替每個 thread 準備獨立的 attention scratch。再加上 head 數量通常不大(32-64),fork-join 的 overhead 搞不好就把省下來的時間吃光了——這個值得 benchmark 一下,但我猜不一定划算。
2. FlashAttention 風格的 fused attention
我現在的 attention 是「先算完所有 score、再 softmax、再加權 V」三步走。FlashAttention 則把這三步 fuse 成一趟 streaming pass,記憶體占用從 O(n_ctx) 降到 O(head_dim),而且對 cache 超友善。只是實作複雜不少——這已經是 candle/ggml 那種等級的優化了,不是隨手就能塞進來的。
3. KV cache 的量化
KV cache 是長 context 的記憶體大頭,這點講到都快變口頭禪了。對 7B 模型、2k context、F32 cache 來說,每層 cache 是 2 * 2048 * 1024 * 4 = 16 MB,32 層就是 512 MB。要是把 cache 也量化成 Q8(1 byte/element),就能砍到 128 MB。llama.cpp 早就支援這招了(-ctk q8_0 -ctv q8_0)。
4. Prefill batching
目前 prefill 是 sequential 的——一個 token 跑一次 forward。但 prompt 的 token 其實全都已知,大可以拼成一個矩陣做 batched forward:
seq forward: O(L * (matvec for each layer))
batch forward: O(matmul of [L, n_embd] for each layer)GEMM 的 cache locality 比 GEMV 好太多了,prefill 速度衝個 5-10× 不是夢。代價是 attention 要重新設計(causal mask 從可有可無變成非寫不可)、KV cache 的寫入方式也得改(一次寫 L 個 token)。算是中量級的工程,不算小但也沒有嚇人。
5. Speculative decoding
更進階的玩法:拿一個小小的 “draft model” 一次先猜 K 個 token,再用主模型一口氣 verify。要是 K 個全猜中,那等於只花一次 forward 的時間就吐出 K 個 token,賺翻了。這在 inference latency 上算是革命性的技術,不過代價是要兩個模型協同,沒那麼好伺候。
6. Continuous batching
如果哪天 runner 要同時 serve 一堆請求,就可以做 continuous batching——不同請求的 token 塞進同一個 batch,還能動態加進來、抽出去。但這要求 KV cache 變成「per-request」的,記憶體管理瞬間複雜好幾級。vLLM 就是這條路線的代表作。
一個有趣的權衡:為什麼不寫得更 functional?
其實我大可以把 forward 寫得更 functional 一點,比方把每個 layer 包成一個閉包、再用 fold 串起來,看起來會很「漂亮」。但說真的,現在這個命令式的寫法反而更好讀——你一眼就能看出每一步在改哪個 buffer、用哪個權重,毫不囉嗦。
LLM 推論的程式碼有個很特別的地方:90% 的時間都在做矩陣運算,剩下 10% 才是 control flow。寫得太 functional,反而會把 control flow 的成本擺到台前,把真正重要的計算給遮住了。我不建議盲目追求函式式的優雅,該樸實的時候就樸實吧。
總結:runner.rs 的角色
- 概念上:它就是把 token id 映射到 logits 的那個單一函式,順手把 KV cache 和層層遞進都管起來。
- 實作上:fixed-size scratch buffers + sequential layer loop + 手工 KV slice 操作,沒什麼花俏的。
- 設計上:把所有複雜性都收進這一個檔案裡,讓其他檔案(
ops、tensor)保持乾淨——我覺得這種「把髒活集中在一處」的取捨頗值得參考。
寫到這裡,整個推論流程的骨架其實已經攤在眼前了。下一篇我們往前再退一步,聊聊 tokenizer.rs——看看 byte 流是怎麼變成 token id 的,那又是另一個有趣的小宇宙。
系列文章:
- tiny-llm-runner 介紹
- (1) config ・ (2) dequant ・ (3) tensor ・ (4) model ・ (5) ops
- (6) runner.rs(本篇)