featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇看完了 LlamaModel 怎麼把張量組起來,這一篇要看的是 Transformer 真正的「動詞」們:ops.rs

這個檔案實在是小,只有 100 行,但它就是整個 Transformer 的所有原語。我想最妙的地方就在這裡:一個現代 LLM 動輒幾十億參數,所有的計算到頭來都會被拆解成這六種運算的組合。就這麼六塊樂高,拼出整個 Transformer。

樂高積木一:RMSNorm

pub fn rmsnorm(out: &mut [f32], x: &[f32], w: &[f32], eps: f32) {
    let mut ss = 0.0f64;
    for &v in x {
        ss += v as f64 * v as f64;
    }
    ss /= x.len() as f64;
    ss += eps as f64;
    let scale = (1.0 / ss.sqrt()) as f32;
    for i in 0..x.len() {
        out[i] = w[i] * (x[i] * scale);
    }
}

算法

$$\text{out}_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}}$$

說穿了就是「先把 x 除以它自己的 RMS,再用 w 逐元素 scale」。RMSNorm 是 LayerNorm 的簡化版(沒有 mean-shift、沒有 bias),Llama 拿它取代 LayerNorm,是 2019-2020 年那波「能省就省」的設計趨勢——少算一點、效果卻不太掉,何樂而不為呢?

為什麼用 f64 累加?

注意 ssf64,不是 f32。這可不是手滑寫錯:累加一大堆 f32 值很容易踩到「精度吞噬」(catastrophic cancellation)的雷。一個 4096 維的 hidden state,每個 f32 大約 1e-1 量級,平方後是 1e-2,4096 個累加起來大概 40-100。在這個量級下,每個新加進來的 1e-2 增量在 f32 上會有相對誤差 ~1e-7 × 100 = 1e-5,4096 次累積下來就會放大。

llama.cpp 在這條路徑用的是 f64 累加器,所以我也得跟著對齊——不然 RMSNorm 的輸出會跟它差個千分之一。你可能會想,才千分之一有差嗎?有喔,每層只差千分之一沒錯,但 22 層累積下來,輸出 token 就整個不一樣了。這又是一個「對齊既有實作」的例子

樂高積木二:matvec —— 用 rayon 平行化

pub fn matvec(out: &mut [f32], w: &TensorView<'_>, x: &[f32]) {
    out.par_iter_mut().enumerate().for_each(|(i, o)| {
        *o = w.dot_row(i, x);
    });
}

這是整個專案的效能瓶頸所在,偏偏又是最簡潔的一段程式碼——耗時最久的,往往長得最無辜。來拆解一下這幾行到底做了什麼:

  1. par_iter_mut() 來自 rayon,把 &mut [f32] 變成一個平行 iterator。
  2. enumerate() 加上 row index。
  3. for_each(|(i, o)| ...) 對每個 row 平行執行 closure。

每個 row 都是獨立的(out[i] 只被當前 closure 寫入),所以完全不需要任何鎖rayon 的 work-stealing 排程器會自動把 row 切成 chunk 分給可用的執行緒,這部分我們什麼都不用操心,頗省事。

為什麼 row-parallel 而不是 element-parallel?

如果改成 element-parallel(每個輸出元素一個 task),task 切太細,省下來的時間反而會被 rayon 的 task overhead 吃光。row-parallel 的 granularity 就剛剛好——每個 task 是一個完整的 dot product(4096 個乘加),夠粗,粗到值得啟動執行緒;又夠細,細到能平均分攤負載。

這裡 Rust 的 par_iter_mut 幫你擋掉了一個很容易踩的雷:你不能讓兩個 thread 同時寫 out 的同一個位置。而 par_iter_mut 的型別簽章保證每個 closure 拿到的是 &mut f32(單一元素的 mutable ref),編譯期就把 race condition 排除掉了。這種事在 C++ 裡得靠自己小心,在 Rust 裡編譯器直接幫你顧好,真是頗安心。

樂高積木三:softmax —— max-shift 防止 overflow

pub fn softmax(x: &mut [f32]) {
    let mut max = f32::NEG_INFINITY;
    for &v in x.iter() {
        if v > max { max = v; }
    }
    let mut sum = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    let inv = 1.0 / sum;
    for v in x.iter_mut() {
        *v *= inv;
    }
}

數學上,softmax 是:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

e^{x} 對大 x 會 overflow(f32 的 exp 在 x > ~88 就溢位)。標準做法是先減 max:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

數學上完全等價(分子分母同乘 $e^{-\max(x)}$),但這麼一減,所有 exp 的輸入都落在 $(-\infty, 0]$,永遠不會 overflow。這一步是寫 softmax 絕對不能省的——少了它,哪天遇到極端的 logits,輸出就是一片 NaN,到時候才在那邊抓半天,何必呢?

樂高積木四:RoPE —— 相對位置的旋轉

六塊樂高裡,這塊大概是最微妙的一個。RoPE(Rotary Position Embedding)是 Llama 用來把位置資訊塞進 attention 的方法,數學上是把每個 head 的 Q/K 向量看成一堆「複數對」,再旋轉一個和位置相關的角度。聽起來很玄,但程式碼其實沒幾行。

pub fn apply_rope(
    vec: &mut [f32], pos: usize, head_dim: usize,
    rot_dim: usize, base: f32, style: RopeStyle,
) {
    let half = rot_dim / 2;
    let n_heads = vec.len() / head_dim;
    for h in 0..n_heads {
        let off = h * head_dim;
        for i in 0..half {
            let freq = base.powf(-((2 * i) as f32) / rot_dim as f32);
            let theta = pos as f32 * freq;
            let (sin, cos) = theta.sin_cos();
            let (i0, i1) = match style {
                RopeStyle::Llama => (2 * i, 2 * i + 1),
                RopeStyle::Neox  => (i, i + half),
            };
            let v0 = vec[off + i0];
            let v1 = vec[off + i1];
            vec[off + i0] = v0 * cos - v1 * sin;
            vec[off + i1] = v0 * sin + v1 * cos;
        }
    }
}

核心想法:每對 (v[i0], v[i1]) 是一個 2D 向量,被旋轉一個角度 $\theta = \text{pos} \cdot \text{freq}_i$。不同的 i 用不同的 freq——i 越小 freq 越大(位置感越強)、i 越大 freq 越小(接近不變)。

頻率公式:

$$\text{freq}_i = \text{base}^{-2i / \text{rot\_dim}}$$

base 通常是 10000,rot_dim 是 head_dim(或某個比例)。i = 0 時 freq = 1(每移動一個 position 旋轉 1 弧度),i = rot_dim/2 - 1 時 freq ≈ 1/10000(要 6000 個 position 才旋轉 1 弧度)。

這種多尺度的旋轉設計實在是巧妙,讓 attention 既能感知近距離的細微位置差異、也能掌握遠距離的大致位置,可說是一兼二顧。

為什麼有兩種 style?

let (i0, i1) = match style {
    RopeStyle::Llama => (2 * i, 2 * i + 1),         // 鄰接成對
    RopeStyle::Neox  => (i, i + half),              // 上下半成對
};

這兩個是完全不同的「複數對」配對方式

  • Llama:把 v[0], v[1] 當一對、v[2], v[3] 當一對……(鄰接 pair)
  • Neox:把 v[0], v[half] 當一對、v[1], v[half+1] 當一對……(上下 split)

兩者旋轉的數學其實一模一樣,差的只是配對方式。那為什麼要兩種並存呢?這就是歷史共業了。

當初 HuggingFace 的 transformers 採用 NeoX 風格。但 llama.cpp 早期的 convert.py 在轉檔時會把 Q/K 矩陣的每對 row permute 成鄰接格式,這樣就能改用 Llama 風格做 RoPE,記憶體 access 更線性。

問題是新版的 convert_hf_to_gguf.py 不做 permute 了,所以你得乖乖用 NeoX 風格。搞錯了不會 crash,但模型會開始講胡話——這個雷我可是親自踩過的,debug 半天才發現是配對配反了,欲哭無淚。

順帶一提,sin_cos() 用的是 f32::sin_cos()——同時算 sin 和 cos,硬體上比分開算快一倍,這種免費的便宜當然要佔。

樂高積木五:SiLU —— SwiGLU 的活化函數

#[inline]
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

數學定義:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

$\sigma$ 是 sigmoid。SiLU 也叫 Swish。Llama 的 FFN 用的是 SwiGLU(Swish-Gated Linear Unit):

$$\text{FFN}(x) = W_{down}(\text{SiLU}(W_{gate}(x)) \odot W_{up}(x))$$

⊙ 是逐元素相乘。白話講就是:先把 x 過 gate 和 up 兩個投影,gate 那條過完 SiLU 後逐元素乘上 up,再一起過 down 投影。比起傳統的 ReLU FFN,SwiGLU 多了一個 gate 投影,乍看是多花了計算,但實驗顯示效果好得多——多算這一點,值得。

#[inline] 是給編譯器的小提示。這個函式會在 inner loop 被呼叫成千上萬次,inline 進去能省掉函式呼叫的 overhead,積少成多嘛。

樂高積木六:add_inplace —— 殘差連接

pub fn add_inplace(a: &mut [f32], b: &[f32]) {
    for (x, y) in a.iter_mut().zip(b.iter()) {
        *x += *y;
    }
}

六塊樂高裡最樸素的一塊:把 b 加進 a,沒了。別小看它,Transformer 的每個 sublayer 都靠它:

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

那兩個不起眼的加號,就是 add_inplace。殘差連接這麼重要的東西,實作起來竟然就一行 for 迴圈,想想還挺有意思的。

這裡我刻意用 iter_mut().zip(iter()) 而不是用 index。這個寫法的好處有三個:

  1. 沒有越界檢查——iterator 自動知道何時停止。
  2. 編譯器更容易自動向量化——iterator 模式對 LLVM 友善。
  3. 零成本抽象——這個寫法和手寫 raw loop 編譯出來幾乎一樣的組合語言。

Rust 用法:#[derive] 的小幫手

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeStyle {
    Llama,
    Neox,
}

Copy 讓我可以把 RopeStyle 當值傳遞,不用一直 && 去。Eq 讓它可以在 match 裡比較。Debugeprintln!("{:?}", style) 能直接印出來除錯。敲不到一秒的字,省下 10 行 boilerplate——這就是 Rust derive 機制最日常、也最香的用法。

效能最佳化空間

ops.rs 是另一個重要的最佳化戰場。這六個函式,個個都還有改進空間,來盤點一下:

1. RMSNorm 的 SIMD

那個 for &v in x { ss += v * v; } 是個 reduction,可以用 SIMD 平行化:

use std::simd::f32x8;
let mut acc = f32x8::splat(0.0);
for chunk in x.chunks_exact(8) {
    let v = f32x8::from_slice(chunk);
    acc += v * v;
}
let ss = acc.reduce_sum() as f64;

不過 f64 vs f32 累加器的精度問題,在 SIMD 環境下就變得有點微妙了——AVX-512 有 f64 SIMD,AVX2 卻沒有。我想最務實的做法還是維持 f32 SIMD reduction,實作簡單、誤差又在可接受範圍內(畢竟 4096 個 ~1e-2 量級的數,誤差還不至於大到讓 token 跑掉)。

2. matvec 的 GEMV → GEMM 升級

目前我的 matvec 是 GEMV(matrix × vector),每次 forward pass 對每層都得做兩次(attention 的 Q + FFN 的 down)。但prefill 階段的 prompt 是一整串 token,明明就可以批次處理——把 N 個 token 的 hidden state 拼成 [seq_len, n_embd] 矩陣,一次 GEMM 解決。

GEMM 比 GEMV 快上不少(典型 5-10×),關鍵在它能 reuse 權重(load 一次 row 的權重,就能對 seq_len 個輸入一起算)。只是這得回頭重新設計 Runner——現在是單 token forward,要改成 batch forward 才行,這筆帳之後再算吧。

3. softmax 的 SIMD exp

(*v - max).exp() 是逐元素的,理論上可以 SIMD 化。不過 SIMD exp 麻煩了點——得用多項式逼近(典型是 Cephes 或 Padé approximant 的 SIMD 版本)。懶得自己刻的話,sleef crate 有現成的 SIMD math 函式可以接上,不必重造輪子。

4. RoPE 的 sin/cos 預計算

每次 forward pass 都呼叫 theta.sin_cos()。但 freq_i 只和 i 有關,theta = pos * freq_i 也可以分解。如果預計算所有可能的 (pos, i) 的 sin/cos 並存成 table,runtime 只需要查表:

let cos_table: Vec<Vec<f32>> = (0..n_ctx).map(|pos| {
    (0..rot_dim/2).map(|i| {
        let freq = base.powf(-((2*i) as f32) / rot_dim as f32);
        (pos as f32 * freq).cos()
    }).collect()
}).collect();

以 TinyLlama 來說,這張表是 n_ctx × rot_dim/2 × 4 bytes = 2048 × 32 × 4 = 256 KB,啟動時花個一次性的時間算好,之後 runtime 通通查表就好。256 KB,L2 cache 就裝得下,划算。

5. add_inplace 也可以 SIMD

雖然只是個簡單的逐元素加,但 LLVM 對 iter_mut().zip() 模式的 auto-vectorization 不一定每次都會發生——它高興才幫你做。想穩一點,就明確寫 f32x8::from_slice(...) + f32x8::from_slice(...),保證 SIMD 化。

6. 把 forward pass 內整層 fuse 起來

最大的最佳化潛力,其實是 kernel fusion。比方說 attention + softmax 可以 fuse(FlashAttention 走的就是這條路);rmsnorm + matvec 也能 fuse 成一個 pass。不過 fusion 會把程式碼搞得超複雜,跟 tiny-llm-runner 主打的「易讀」目標整個背道而馳——所以這一塊就留給 candle/ggml 那些大人去玩吧,不是我這裡該碰的。

一個 Rust 特有的優勢:debug 與 release 的 debug_assert

我每個函式裡都塞了 debug_assert_eq!(out.len(), x.len()) 之類的檢查。妙的是 Rust 的 debug_assert! 在 release 會被完全消除——dev 時幫你抓 bug,上線後一毛 runtime 成本都不花。對 hot path 來說,這實在是個好用的工具。

C/C++ 當然也有 assert + NDEBUG,但 Rust 把這個好習慣直接做成語言預設——你不必記得去 #define NDEBUGcargo build --release 就自動幫你處理好。一個小細節,卻看得出語言設計的用心。

總結:ops.rs 的角色

  • 概念上:Transformer 的所有原語都在這(norm、matmul、softmax、RoPE、activation、add)。
  • 實作上:純函式、輸出參數、明確簽章、零隱藏狀態,看了就懂。
  • 設計上:把派發(量化型別 match)藏進 TensorView::dot_row 裡,ops.rs 自己只看得到 &[f32],乾乾淨淨。

回頭看,這六塊樂高每一塊都簡單到有點不可思議,但拼起來就是當今最強的語言模型。我想這正是 Transformer 最迷人的地方——複雜的不是零件,而是組合。Simple bricks, complex castles.

下一篇來講 runner.rs,那是把所有 ops 編排起來、發號施令的指揮中心。

系列文章: