
libraryofbabel, today at 4:42 AM

Speculative decoding is an amazingly clever invention that almost seems too good to be true (faster inference with zero degradation in quality relative to the main model). The core idea is: if you can find a way to generate a short run of draft next tokens with a smaller model, and those drafts have a reasonable likelihood of being correct, it's fast to check that they are actually correct with the main model because you can run the checks in parallel. And if you think about it, a lot of next tokens are pretty obvious in certain situations (e.g. it doesn't take a frontier model to guess the likely next token in "United States of...", and a lot of code is boilerplate and easy to predict from previous code sections).
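To make the draft-then-verify loop concrete, here is a minimal sketch of the greedy variant only. Everything in it is an assumption for illustration: `target_model` and `draft_model` are hypothetical toy functions standing in for real networks, and the verification is written as a Python loop where a real system would use one batched forward pass of the main model.

```python
from typing import Callable, List

Token = int
# Hypothetical interface: a "model" maps a token prefix to its greedy next token.
Model = Callable[[List[Token]], Token]


def speculative_step(target: Model, draft: Model, prefix: List[Token], k: int) -> List[Token]:
    """One draft-then-verify round; returns between 1 and k + 1 new tokens."""
    # 1. Draft k tokens one at a time with the cheap model.
    drafted: List[Token] = []
    ctx = list(prefix)
    for _ in range(k):
        tok = draft(ctx)
        drafted.append(tok)
        ctx.append(tok)

    # 2. Ask the target model for its own choice at each of the k + 1 positions.
    #    These calls are independent of each other, which is what allows a real
    #    implementation to compute them in a single batched forward pass.
    target_choices = [target(prefix + drafted[:i]) for i in range(k + 1)]

    # 3. Keep the longest run of drafted tokens that the target agrees with.
    accepted: List[Token] = []
    for i in range(k):
        if drafted[i] != target_choices[i]:
            break
        accepted.append(drafted[i])

    # 4. Append the target's own token at the first mismatch (or a bonus token
    #    if every draft was accepted), so each round always makes progress.
    accepted.append(target_choices[len(accepted)])
    return accepted


def speculative_generate(target: Model, draft: Model, prompt: List[Token], n_new: int, k: int = 4) -> List[Token]:
    out = list(prompt)
    while len(out) - len(prompt) < n_new:
        out.extend(speculative_step(target, draft, out, k))
    return out[: len(prompt) + n_new]


def greedy_generate(target: Model, prompt: List[Token], n_new: int) -> List[Token]:
    """Plain autoregressive decoding with the target model, for comparison."""
    out = list(prompt)
    for _ in range(n_new):
        out.append(target(out))
    return out


# Toy stand-ins (assumptions, not real models): the draft agrees with the
# target except when the prefix length is a multiple of 3.
def target_model(ctx: List[Token]) -> Token:
    return (sum(ctx) * 31 + len(ctx)) % 7


def draft_model(ctx: List[Token]) -> Token:
    guess = target_model(ctx)
    return guess if len(ctx) % 3 else (guess + 1) % 7
```

Calling `speculative_generate(target_model, draft_model, [1, 2, 3], 20)` gives the same tokens as `greedy_generate(target_model, [1, 2, 3], 20)`, which is the "zero degradation" property in the greedy case. With sampling instead of greedy decoding, the accept step is replaced by a probabilistic accept/reject rule, which this sketch does not cover.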

I always encourage folks who are interested in LLM internals to read up on speculative decoding (both the basic version and the more advanced MTP), and if you have time, try and implement your own version of it (writing the core without a coding agent, to begin with!)


Replies

zmmmmm, today at 8:06 AM

> it's fast to check that they are actually correct with the main model because you can run the checks in parallel.

Can you give an intuition as to why it's faster? I would have thought that regardless of how many you run in parallel, the successful check still has to execute the full model to generate the full sequence, so you would need exactly the same time? Or is it a process of elimination, so it terminates early once it eliminates the non-viable choices? (In which case, how do you guarantee the correct output was speculatively generated at all to be the last survivor?)


m12k, today at 7:18 AM

So we've basically taken the concept of branch prediction from CPUs and applied it to LLMs?


mungoman2, today at 5:36 AM

Naively it seems odd that running multiple checks in parallel is faster than just running the autoregressive model multiple times in series. It's the same amount of compute, right?

But I think the key is that in the standard autoregressive case we are memory-bandwidth bound, so there are tons of idle compute resources. Checking multiple tokens is therefore cheap, because we can batch them and reuse the weights we have already read for every token in the batch.

The verification step is similar to a prefill with a small batch size. The difference is what we do with the generated logits.
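A back-of-the-envelope version of the bandwidth argument, as a rough sketch rather than a measurement. All hardware and model numbers below are hypothetical round figures, the function name `forward_pass_seconds` is made up for illustration, and the model ignores KV-cache reads, activations, and kernel overheads.

```python
def forward_pass_seconds(
    batch_tokens: int,
    weight_bytes: float,
    flops_per_token: float,
    bandwidth_bytes_per_s: float,
    peak_flops_per_s: float,
) -> float:
    """Crude roofline-style estimate: a pass takes as long as its slower resource."""
    # Weights are streamed from memory once per pass, however many tokens share it.
    memory_time = weight_bytes / bandwidth_bytes_per_s
    # Arithmetic grows linearly with the number of tokens processed in the pass.
    compute_time = batch_tokens * flops_per_token / peak_flops_per_s
    return max(memory_time, compute_time)


# Hypothetical setup: 7e9 parameters at 2 bytes each, about 2 FLOPs per
# parameter per token, 1e12 bytes/s of memory bandwidth, 1e14 FLOP/s peak.
PARAMS = 7e9
HW = dict(
    weight_bytes=2 * PARAMS,
    flops_per_token=2 * PARAMS,
    bandwidth_bytes_per_s=1e12,
    peak_flops_per_s=1e14,
)

one_token = forward_pass_seconds(1, **HW)    # a normal decode step
five_tokens = forward_pass_seconds(5, **HW)  # verifying 5 positions in one pass
serial_five = 5 * one_token                  # five ordinary decode steps
```

Under these assumed numbers both `one_token` and `five_tokens` are limited by the weight read, so verifying five positions in one pass costs about the same as a single decode step, while `serial_five` costs five times as much. The advantage shrinks once the batch is large enough for `compute_time` to dominate.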
