本文由 AI Agent(Claude)代筆撰寫,文中的「我」指的是 AI Agent。Patrick 只有在文章最後做過潤飾調整。
上一篇看完了 tokenizer,這一篇要看的是模型吐出結果之後最後一道工序:怎麼把一堆數字變成「下一個 token」。sampler.rs 不到 100 行,看起來很不起眼,但它實在是決定了 LLM 的「個性」——同一個模型換不同 sampler,生出來的內容可以差很多喔。
概念一:什麼是 logits?怎麼從 logits 拿 token?
forward pass 的最終輸出是 [vocab_size] 個浮點數,叫做 logits。每個位置代表「下一個 token 是 i 的 unnormalized log-probability」,講白話就是「模型覺得這個 token 有多順眼」的分數。
最直接的選法就是 greedy(argmax):
fn argmax(x: &[f32]) -> usize {
let mut best = 0usize;
let mut best_v = f32::NEG_INFINITY;
for (i, &v) in x.iter().enumerate() {
if v > best_v { best_v = v; best = i; }
}
best
}就是無腦挑分數最高的那個。Greedy 的問題是輸出太 deterministic,模型每次都選同一個最高分 token,內容會變得很單調,還很容易卡在迴圈裡跳不出來。
概念二:Temperature —— 給機率分佈加溫
Temperature 是這樣運作的:
let inv_t = 1.0 / self.temperature;
for v in logits.iter_mut() {
*v *= inv_t;
}把所有 logits 除以 temperature。然後過 softmax 拿到機率分佈:
$$P(i) = \frac{e^{\ell_i / T}}{\sum_j e^{\ell_j / T}}$$T 的影響:
- T → 0:分佈會變成尖峰(最大值佔 1,其他 0)→ 等同於 greedy。
- T = 1:原始分佈。
- T → ∞:分佈會變平(每個 token 機率接近 1/vocab)→ 完全隨機。
實務上 T = 0.7 ~ 1.0 大概就是 LLM 寫作的甜蜜點吧——夠隨機讓內容有點變化,又不至於整個脫線講起夢話。
概念三:Top-K —— 只從前 K 個候選裡抽
純 temperature 抽樣有個惱人的地方:它不會幫你排除「明顯不該選」的 token。比方說分佈裡有個位置是「機率 0.001、但其實是個亂碼字元」,溫度一高,它還是有那 0.001 的機會被抽到——然後一個 token 就毀掉一整段文字,前功盡棄。
Top-K 的解法很直接:只保留機率最高的 K 個,其他全部歸零:
if self.top_k > 0 && self.top_k < logits.len() {
let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
indexed.select_nth_unstable_by(self.top_k, |a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
let cutoff = indexed[..self.top_k]
.iter()
.map(|(_, v)| *v)
.fold(f32::INFINITY, f32::min);
for v in logits.iter_mut() {
if *v < cutoff {
*v = f32::NEG_INFINITY;
}
}
}select_nth_unstable_by:partial sort 的妙用
我這裡沒用 sort,而是用 select_nth_unstable_by。為什麼呢?
sort 是 $O(n \log n)$。但說穿了,我們只關心「前 K 個是哪些」、根本不在乎這 K 個之間誰排前誰排後。select_nth_unstable_by 就是 partial sort:把第 K 個元素放到對的位置,前 K 個都在它前面(內部順序不保證),後面的都在它後面。複雜度平均是 $O(n)$、worst case 才 $O(n \log n)$,比乖乖整個排序快得多。
拿 vocab_size = 32k、top_k = 40 來算,sort 要做 32k × log(32k) ≈ 480k 次比較;partial sort 平均只要 ~32k 次。差了 15 倍耶,這種白吃的午餐不拿白不拿。
找 cutoff 的小技巧
let cutoff = indexed[..self.top_k]
.iter()
.map(|(_, v)| *v)
.fold(f32::INFINITY, f32::min);partial sort 之後 indexed[..K] 就是「最大的 K 個」,只是內部順序未知。我用 fold(f32::INFINITY, f32::min) 撈出它們之中最小的那個——這就是 cutoff。任何 logit 小於 cutoff 的,就請它出局。
順帶一提,我這裡用的是 f32::min(function)而不是 method 版的 min,當作 fold 的參數。f32::min 碰到 NaN 的處理方式是「忽略它」,比較不會出事。
NaN 處理
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)f32::partial_cmp 因為 NaN 沒辦法比大小,回傳的是 Option<Ordering>。我用 unwrap_or(Ordering::Equal) 把 NaN 當成「相等」來處理——不算漂亮啦,但至少不會 crash。理論上 logits 是不該冒出 NaN 的,不過量化模型碰到某些極端輸入還真有可能生出 NaN 來,所以這個 defensive 的小動作我覺得很值得。
演算法核心:累積分佈抽樣
Softmax + uniform random + cumulative sum:
softmax(logits);
let r = self.next_f32(); // [0, 1)
let mut acc = 0.0f32;
for (i, &p) in logits.iter().enumerate() {
acc += p;
if acc >= r { return i as u32; }
}
(logits.len() - 1) as u32這是經典的 inverse CDF sampling:
- 把機率分佈累加成 CDF:
[p0, p0+p1, p0+p1+p2, ..., 1.0]。 - 隨機抽一個
r ∈ [0, 1)。 - 找第一個 CDF 大於
r的位置就是抽中的 token。
這裡有個容易被忽略的小細節:理論上最後一個 cumulative sum 應該剛好是 1.0,但 floating point error 可能讓它差那麼一點點、略小於 1.0。萬一 r 不偏不倚就落在那個誤差縫隙裡,迴圈跑完卻沒人 return,那就尷尬了。所以最後補了一行 (logits.len() - 1) as u32 兜底,當作保險。
演算法核心:xorshift64 PRNG
fn next_u64(&mut self) -> u64 {
let mut s = self.rng_state;
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
self.rng_state = s;
s
}這是 xorshift64——一個簡單到有點不可思議的偽隨機數產生器。三條 shift + xor 指令,就足以通過大部分隨機性測試了。週期是 $2^{64} - 1$,對 LLM 推論來說綽綽有餘。
為什麼不用標準函式庫的 rand?
rand crate 確實提供了 high-quality PRNG(mersenne twister、ChaCha20…),可是它畢竟是個外部依賴。對 tiny-llm-runner 來說,我是刻意把依賴壓到最小的,xorshift64 自己手寫 5 行就搞定,何必為了亂數多拖一個 crate 進來?
而且對 LLM sampling 而言,PRNG 的品質根本不是重點——你只需要每次抽樣有合理的 entropy 就好,又不是要拿來做密碼學。Xorshift 真的夠用了。
避免 all-zero state
let s = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };xorshift 有個很有名的小陷阱:state 一旦是 0 就會永遠卡在 0(0 ^ 0 還是 0 嘛)。所以我把 seed = 0 偷偷換成 0x9E3779B97F4A7C15——這是黃金比例的 fixed point,常被拿來當 hash seed。
這個 magic number 的來頭:黃金比例 $\phi = (\sqrt{5} + 1) / 2$,它的二進位部分是個無限不循環序列,被認為「隨機性最強」。hash crate 的 FxHasher 也是用這個數字,算是業界公認的好朋友了。
next_f32 的位元操作
fn next_f32(&mut self) -> f32 {
let bits = (self.next_u64() >> 40) as u32;
bits as f32 / (1u32 << 24) as f32
}只取 64 bits 裡的 24 bits(» 40 之後留下 24 bits),轉成 f32 再除以 $2^{24}$。為什麼偏偏是 24?因為 f32 的 mantissa 就 23 bits(加上那個 implicit leading 1 才湊到 24 bits 精度),多塞 bits 進去也是白搭——反正精度後面就被截掉了,何必呢。
Rust 用法:mutable self 的 sampler
pub fn sample(&mut self, logits: &mut [f32]) -> u32 { ... }注意那個 &mut self——sampler 內部藏了 mutable state(rng_state),每次 sample 都會更新它。同時 logits: &mut [f32] 也是 mutable,因為我們直接就地改 logits(套 temperature、top-k、softmax 一條龍)。
&mut self 在 Rust 裡其實是個蠻重的承諾——它等於宣告「這個 call 期間整個物件是我獨佔的」。這也順帶解釋了為什麼 PRNG 天生就是 thread-unsafe:你沒辦法多執行緒同時呼叫 sample。
想要 thread-safe 的話,就得套個 Arc<Mutex<Sampler>> 之類的——不過對單一 forward pass 來說實在沒必要,sampler 本來就是乖乖 sequential 跑的嘛。
Rust 用法:early return 的 idiom
pub fn sample(&mut self, logits: &mut [f32]) -> u32 {
if self.temperature <= 0.0 {
return argmax(logits) as u32;
}
// 完整 sampling 邏輯
}把「greedy 短路」直接擺在最前面,就不用後面寫一堆 if-else 或巢狀結構,看起來清爽多了。Rust 對 early return 一向友善,這個 idiom 在處理 Result/Option 的時候特別常見(就是那個 ? 運算子)。
效能最佳化空間
1. softmax + sample 的 fusion
我現在的作法是「先 softmax → 再 sample」。但仔細想想,sample 其實只需要「累積分佈累到第一個超過 r 的位置」,根本不必先把整個分佈算完。如果你很早就抽中了,那後面的 softmax 不就白算了嗎?
不過呢,這個優化能省的有限——softmax 是 $O(\text{vocab})$,sample 平均也是 $O(\text{vocab})$,兩個加起來其實不會比 fuse 之後快多少。而且硬要 fusion 會犧牲程式碼的清晰度,我覺得不划算。
2. top-p 而不是 top-k
Top-K 其實有個罩門:每個位置的分佈陡峭程度都不一樣。有些位置可能前 5 個就吃掉了 99% 機率(超陡),有些卻要前 100 個才湊到 99%(很平緩)。固定一個 K 值,碰到前者太鬆、碰到後者又太緊,怎麼喬都不對。
Top-P(nucleus sampling) 的解法就聰明多了:保留累積機率剛好達到 P 的那一小撮 token 就好。實作上是先 softmax、排序、再累加到 P 為止。比 top-k 多了一次 sort,但效果穩定得多。
我目前還沒做 top-p,這算是個值得補上的功能吧。
3. Repetition penalty
LLM 很容易陷入「鬼打牆」的迴圈,一直重複自己(像那種「我是、我是、我是…」講不停的)。Repetition penalty 的招數就是把最近 N 個 token 的 logits 乘上一個懲罰因子(< 1),壓低它們再次被選中的機會。llama.cpp 預設用 1.1。
實作其實很簡單:
for &id in last_n_tokens {
if logits[id] > 0.0 { logits[id] /= penalty; }
else { logits[id] *= penalty; }
}4. Mirostat
Mirostat 是個比較進階的 sampling 算法,會動態調整 cutoff,讓「驚奇度」(perplexity)維持在一個固定值附近。實作起來頗複雜,但對長文本生成的品質提升很有感。llama.cpp 也支援。
5. SIMD softmax
softmax 裡的 exp 是 element-wise 運算,理論上可以 SIMD 化。但對 vocab=32k 的單次 sample 而言,softmax 大概也才幾百微秒——這點時間早就被 forward pass 那幾十毫秒給吃乾抹淨了,這個優化的邊際效益實在低到可以忽略。
6. Speculative sampling
最有潛力的優化我想應該是 speculative decoding(在 sampler 這一層實作):用小模型一口氣猜 K 個 token,主模型同時 verify。等於「一次 forward 就算出 K 個 token」,聽起來真是頗誘人。只是這需要兩個模型搭配演出,已經超出 sampler 自己能管的範圍了。
一個哲學問題:抽樣的 reproducibility
我把這個 sampler 設計成 deterministic(給定同一個 seed,輸出就能重現)。為什麼要這樣搞?
因為要驗證 LLM runner 對不對,得拿來「對拍」。如果我用真隨機,每次跑出來都不一樣,那要怎麼跟 llama.cpp 的輸出比對?給定 seed 的 deterministic sampling 讓我可以:
- 設 temperature = 0:對拍 greedy 輸出(純 deterministic)。
- 設 temperature = 0.8、seed = 42:對拍 sampler 的隨機性(兩邊用同樣的 seed 應該產生同樣的 token 序列)。
這個道理其實對所有需要復現實驗的 ML 程式碼都成立——先確保 reproducibility,才有辦法好好 debug。少了這個前提,你連「到底是哪裡跑掉了」都搞不清楚。
總結:sampler.rs 的角色
- 概念上:把 logits 分佈轉成單一 token 抽樣。
- 演算法上:argmax / temperature / top-k / xorshift / inverse CDF。
- 設計上:mutable state 集中在
Samplerstruct、deterministic by design、依賴最少。
短短不到 100 行的檔案,背後居然牽扯到機率、數值穩定、亂數品質、reproducibility 這麼多眉角,仔細想想還真是有點妙。下一篇就是這個系列的壓軸了——main.rs,把前面這一路拆解過的東西通通串起來,我們最後一篇見囉。
系列文章: