featured.svg

本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。

上一篇講完 dequant.rs 的量化核心,這一篇要看的是 tensor.rs——一個只有 150 多行的小檔案,但它展示了 Rust 在系統程式設計上最有特色的能力之一:用生命週期把「資料在誰手上」這件事編譯期就講清楚。


這個檔案要解決什麼問題?

dequant.rs 提供的所有函式都是「裸的」:你給它一段 &[u8] 和一個 &[f32],它做點積。但在 forward pass 裡,我們不想讓 runner.rs 直接看到 byte slice——那太底層、太容易出錯。我們想要一個更高階的抽象,例如:

「這是一個 [K, N] 的 Q4_0 矩陣,請對它的第 i row 和輸入向量 x 做點積。」

TensorView 就是這個抽象。


概念一:什麼是「view」?為什麼不擁有資料?

#[derive(Clone, Copy)]
pub struct TensorView<'a> {
    pub data: &'a [u8],
    pub ggml_type: GgmlType,
    pub dims: [u64; 4],
    pub n_dims: usize,
}

注意這個 'a 生命週期參數。TensorView 不是「擁有」這些 bytes,它只是「看著」這些 bytes。實際的 bytes 還在 mmap 裡。

這個設計的意義是:

  1. 零拷貝:建構一個 TensorView 只需要拷貝 32 bytes 的中繼資料(dims + type + slice header),沒有任何資料搬運。
  2. 生命週期安全:因為 'a 綁定到 mmap,編譯器會保證這個 view 不會比 mmap 活得更久。
  3. 可以是 Copy:32 bytes 的 struct(沒有 owned data)可以實作 Copy,意味著傳遞它就像傳一個整數那麼便宜。

對比一下,如果我用 Tensor { data: Vec<u8>, ... }(擁有資料),那建構一個 7B 模型的 LlamaModel 就會把整個 4 GB 從 mmap 複製到 heap 上——啟動時間和記憶體用量都會炸掉。


概念二:行優先佈局與 dim 順序

GGUF 的張量是 row-major 儲存,但 dim 順序和我們直覺可能相反:

GGUF 規定:dimensions = [K, N] 表示 K 個欄、N 個 row
也就是 dim[0] = "row 寬度"、dim[1] = "row 數量"

這個約定在 dim0()dim1() 兩個 getter 上反映出來:

pub fn dim0(&self) -> usize { self.dims[0] as usize }   // 每個 row 有幾個元素
pub fn dim1(&self) -> usize { self.dims[1] as usize }   // 有幾個 row

這個 GGUF convention 可能是受 BLAS 影響——很多線性代數函式庫的 [m, n] 矩陣 m 是 row count、n 是 col count,但實際上記憶體是 column-major。GGUF 採用了一個和 NumPy 直覺相反的約定,第一維是 row 寬度,第二維是 row 數量。每次寫程式都要小心這個。


演算法核心:每種量化的 row size 計算

fn bytes_per_row(t: GgmlType, cols: usize) -> usize {
    match t {
        GgmlType::F32  => cols * 4,
        GgmlType::F16  => cols * 2,
        GgmlType::Q8_0 => (cols / dequant::QK8_0) * dequant::Q8_0_BLOCK_SIZE,
        GgmlType::Q4_0 => (cols / dequant::QK4_0) * dequant::Q4_0_BLOCK_SIZE,
        GgmlType::Q6_K => (cols / dequant::QK_K)  * dequant::Q6_K_BLOCK_SIZE,
        other => panic!("unsupported tensor type: {other}"),
    }
}

對 fp 類型很直觀(每個元素 N bytes,乘起來就好)。但對量化類型,重點是「block 整除」:每個 block 是固定大小(32、32、256 個元素),所以一個 row 必須是 block 大小的整數倍。bytes_for 函式裡的 is_multiple_of 檢查就是在驗證這件事。

這也意味著 LLM 的 hidden dim 通常是 32、64、128、256 的倍數——不是巧合,是為了和量化 block 對齊。


演算法核心:dot_row 的派發策略

TensorView::dot_rowtensor.rs 暴露給上層的最重要 API:

pub fn dot_row(&self, i: usize, x: &[f32]) -> f32 {
    let row = self.row(i);
    match self.ggml_type {
        GgmlType::F32  => dequant::dot_f32(row, x),
        GgmlType::F16  => dequant::dot_f16(row, x),
        GgmlType::Q8_0 => dequant::dot_q8_0(row, x),
        GgmlType::Q4_0 => dequant::dot_q4_0(row, x),
        GgmlType::Q6_K => dequant::dot_q6_k(row, x),
        t => panic!("unsupported tensor type for matmul: {t}"),
    }
}

注意這個 match:派發只發生在 row 邊界,一旦進到 inner loop(dot_q4_0 內部),就沒有任何條件判斷了。這是個重要的微結構選擇——CPU 的分支預測器會討厭 inner loop 裡的條件分支,把派發拉到外面才能讓內部迴圈純粹是計算。

為什麼用 match 而不是 trait object?

如果你看過更 OOP 的設計,可能會用:

trait QuantOps {
    fn dot(&self, q: &[u8], x: &[f32]) -> f32;
}

然後 TensorView 持有一個 Box<dyn QuantOps>。這個設計的問題是:

  1. 每次 dot 都要走 vtable,無法 inline。
  2. 每個 row 多一次間接跳轉,CPU 預測會變差。
  3. Copy 寫不出來——Box 不是 Copy

直接 match 反而是最快的。Rust 編譯器看到 match 是窮舉的、且 arms 都會 inline,會把這段 match 編譯成一個跳躍表(jump table)或一連串的條件比較,比 trait object 快一個量級


Rust 用法:Copy 的隱含好處

TensorViewCopy,意味著:

let view = TensorView::from_info(...)?;
some_function(view);    // 不需要 .clone()
let view2 = view;       // 不會 invalidate `view`

這讓 runner.rs 可以放心地把 view 當值傳。32 bytes 的拷貝在 x86 上是一條 SSE move 指令,比走參考還可能快——因為避免了 alias 分析的複雜度。

Copy 不是免費的——它要求所有欄位都是 Copy&[u8] 是(slice header 是個 fat pointer),GgmlType 是(plain enum),[u64; 4] 是(陣列),usize 是。所以 TensorView 自然就能 Copy

Lifetime elision 的小細節

TensorView<'a>'a 看起來很煩,但用起來其實大多數時候不用寫——Rust 的 lifetime elision 規則會自動幫你補。例如 from_info 的簽章:

pub fn from_info(info: &TensorInfo, blob: &'a [u8]) -> Result<Self> { ... }

這裡只有 blob'a(因為 SelfTensorView<'a>,必須繫結到某個來源)。info: &TensorInfo 用的是省略掉的另一個生命週期,編譯器會自己補。


Rust 用法:用 trait 加 ergonomics

pub trait TensorInfoExt {
    fn ggml_type_or_panic(&self) -> GgmlType;
}

impl TensorInfoExt for TensorInfo {
    fn ggml_type_or_panic(&self) -> GgmlType {
        self.tensor_type
    }
}

這是個小技巧:TensorInfo 是上游 crate (llm-gguf-parser) 定義的型別,我不能直接給它加方法。但我可以用 extension trait——在我的 crate 裡定義一個 trait,並為上游型別實作它。引入這個 trait 後,info.ggml_type_or_panic() 就能用了。

這比每次都寫 info.tensor_type 更具表達性——名字明確說明「我假設這個值已經被驗證過了」。雖然這個例子裡 trait 帶的方法非常薄,重點是表達 intent,不是省字數


效能最佳化空間

tensor.rs 本身沒有 hot path——所有 hot 工作都在 dequant.rs 裡。但有幾個值得想的方向:

1. 把 dim 從 [u64; 4] 改成 [u32; 4]

LLM 沒有單一維度超過 40 億的張量。u32 已經足夠,可以把 TensorView 從 64 bytes(slice + type + 4×u64 + n_dims)縮到 48 bytes,更友善 L1 cache。

2. 預先 cache row_bytes

每次呼叫 row(i) 都會算一次 row_bytes(),內部又是個 match:

pub fn row(&self, i: usize) -> &'a [u8] {
    let rb = self.row_bytes();   // match self.ggml_type
    &self.data[i * rb..(i + 1) * rb]
}

雖然編譯器可能會把這個算一次然後 hoist 出迴圈,但 matvecpar_iter_mut 是平行的、每個執行緒看到的是新的 closure scope,不一定會 hoist。如果在建構 TensorView 時就 cache 一個 row_bytes: u32,就能避免每次重算:

pub struct TensorView<'a> {
    pub data: &'a [u8],
    pub row_bytes: u32,        // 預先算好
    pub ggml_type: GgmlType,
    pub dims: [u32; 4],
    pub n_dims: u8,
}

不過這是 micro-optimization,在 Q4_0 inner loop 已經吃掉 99% 時間的情況下,這個改動量化效益很小。

3. 改用 &'a [QXBlock] 而不是 &'a [u8]

如果把 data 從 raw bytes 改成「具型 block 切片」(例如 &[Q40Block]),就能避免在 dot_q4_0 內部的指針算術,編譯器也更容易做 alias 分析。但這需要每種量化都定義一個 repr(C, packed) 的 struct,代價是更多 boilerplate。

4. 對齊:強制 16-byte 或 32-byte 對齊

mmap 給我們的 byte slice 是 page-aligned(4 KB),但每個張量的起始 offset 不一定對齊到 SIMD 邊界。如果未來要 SIMD 化,可能需要:

  • from_info 時驗 offset 是 32-byte 對齊。
  • 對沒對齊的張量做一次性 copy 到對齊 buffer。

GGUF 規格其實有 general.alignment metadata 來控制這個,目前我沒在驗。


一個更深入的設計問題:「view」和「ownership」的分界

tensor.rs 體現了 Rust 一個非常獨特的設計優勢:生命週期允許你寫「比 C++ 還靈活、比 Java/Go 還安全」的 view 型別

對比 C++:

  • 你可以寫 string_view,但生命週期只能靠註解和文件——編譯器不會幫你檢查。
  • 一旦底層 string 被 free 了,view 就是懸空指標,但編譯期看不出來。

對比 Java/Go:

  • 你可以「分享 reference」,但 GC 會強迫底層物件保持活著——你失去了「這個 view 隨時會死」的精確控制。

Rust 的 lifetime 是這兩者之間的中道:底層物件的生命週期由其他人控制,但編譯器會保證所有 view 都不會比它活得更久

這個能力對 mmap 場景特別重要——mmap 的「資料」是一個極端的例子(GB 等級、來自硬碟、隨時可以被 munmap),如果沒有靜態保證 view 不會懸空,bug 會非常難 debug。


總結:tensor.rs 的角色

  • 概念上:把 &[u8] 包成一個有型別、有形狀、知道怎麼點積的「視圖」。
  • 實作上:32 bytes 的 struct + 兩個方法 + 一個 row size 表。
  • 語言上:lifetime + Copy + match 派發 + extension trait,Rust 慣用法的小展示。

下一篇講 model.rs,看怎麼把「一堆 view」組合成一個有結構的 LlamaModel


系列文章: