If I’m reading this right, this is pretty wild. They turned a Qwen autoregressor into a diffuser by using a bunch of really clever techniques, and they vastly outperform any “native diffuser,” actually being competitive with the base model they were trained from. The obvious upside here is the massive speedup in generation.
And then through a LoRA adapter, you can ground the diffuser on the base model’s distribution (essentially have it “compare” its proposals against what the base model would’ve generated), which effectively means: exact same byte-for-byte output for the same seed, just roughly twice as fast (which should improve even more for batched tasks).
I’m not an expert, more of a “practicing enthusiast,” so I might be missing something, but at first glance, this reads super exciting to me.
I don't understand how you can compare against the base model output without generating with the base model, in which case what's the point?
Eh. There is nothing diffusion about this. Nothing to do with denoising. This setup is still purely causal, making it quite a dishonest framing IMO. There is no more introspection here than what happens in MTP + SD setups.
Let me explain what is going on here. This is basically a form of multi-token prediction, combined with speculative decoding at inference time. See my earlier post[1] to understand what that is. TL;DR: in multi-token prediction you train separate LM heads to predict the next token, the one after that, and so on, up to a chosen k-th next token. Training multiple LM heads is expensive and can be unnecessary, so what people typically do is share a common base across all k heads, explained further in [1]. These guys do another variant.
Here is what they do mechanically, given a sequence p consisting of five tokens, PE([p1, p2, p3, p4, p5]), where PE(.) adds relative position info to each token.
1. Create an augmented sequence PE([p1 MASK MASK MASK MASK]). Do a training pass on that, with the ground truth sequence p1..5. Here it is trained, for example, to predict p3 given p1+pos=-2 MASK+pos=-1 MASK+pos=0, loosely notating.
2. Then separately[2], train it as usual on PE([p1 p2 p3 p4 p5]).
Step (1) teaches it to do multi-token prediction: essentially, the single LM head will (very very loosely speaking) condition on the position `k` of the special MASK token and "route" it to the "implicit" k'th LM head.
Step (2) teaches it to be a usual LLM and predict the next token. No MASK tokens involved.
So far, you have trained a multi-token predictor.
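A minimal sketch of the two kinds of training examples from steps (1) and (2), in plain Python. The helper names (`masked_pass_examples`, `regular_pass_examples`) and the `<MASK>` string are made up for illustration, and the alignment of contexts to targets follows the loose "predict p3 given p1 MASK MASK" notation above; the paper's actual indexing may differ.

```python
MASK = "<MASK>"  # hypothetical placeholder for the special MASK token

def masked_pass_examples(tokens, keep=1):
    """Step (1): keep the first `keep` tokens, mask the rest,
    and list (context, target) pairs against the ground truth."""
    masked = tokens[:keep] + [MASK] * (len(tokens) - keep)
    return [(masked[: j + 1], tokens[j]) for j in range(keep, len(tokens))]

def regular_pass_examples(tokens):
    """Step (2): ordinary next-token prediction, no MASK tokens."""
    return [(tokens[:j], tokens[j]) for j in range(1, len(tokens))]

p = ["p1", "p2", "p3", "p4", "p5"]
for context, target in masked_pass_examples(p):
    print("masked :", context, "->", target)
for context, target in regular_pass_examples(p):
    print("regular:", context, "->", target)
```

Each pass contributes its own loss. Position information is left out here; in the real setup it is what lets the model tell the first MASK from the fourth.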
Now, during inference:
You use this for speculative decoding. You generate 5 tokens ahead at once with MASK tokens. And then you run that sequence through the LLM again. This has the same benefits as usual speculative decoding, namely that you can do matrix-matrix multiplication as opposed to matrix-vector. The former is more memory-bandwidth efficient due to higher arithmetic intensity.
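To put a rough number on the arithmetic-intensity point, here is a back-of-the-envelope sketch for a single weight matrix. The function name, the 4096 hidden size, and the 2-byte (fp16/bf16) element size are my assumptions for illustration, and it ignores caches, attention, and everything else in a real model.

```python
def arithmetic_intensity(d_in, d_out, n_tokens, bytes_per_elem=2):
    """FLOPs per byte moved for y = x @ W, with x of shape (n_tokens, d_in)
    and W of shape (d_in, d_out). Counts one read of W and x, one write of y."""
    flops = 2 * d_in * d_out * n_tokens  # one multiply + one add per weight per token
    bytes_moved = bytes_per_elem * (d_in * d_out + n_tokens * d_in + n_tokens * d_out)
    return flops / bytes_moved

one_token = arithmetic_intensity(4096, 4096, n_tokens=1)    # matrix-vector
five_tokens = arithmetic_intensity(4096, 4096, n_tokens=5)  # matrix-matrix
print(one_token, five_tokens, five_tokens / one_token)
```

The weights dominate the memory traffic, so processing 5 tokens in one pass does close to 5x the useful work for almost the same number of bytes read.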
Here is an example:
  query = ["what", "is", "2+2"]
  prompt = PE([...query, MASK*5])
  output = LLM(prompt)
Say output is ["what", "is", "2+2", "it", "is", "4"]. Note that the NN is trained to predict the kth next token when faced with positionally encoded MASK tokens, so you get all the masked positions in one go.
To be precise, it learns to predict "4" given ["what", "is", "2+2", MASK, MASK]. Since it does not need the "it" and "is" explicitly, you can do it in parallel with generating the "it" and the "is". "is" is predicted given ["what", "is", "2+2", MASK], for example, and that also doesn't depend on the explicit "it" being there, so it can also be done in parallel with generating "it", which is just normally generating the next token given the query.
And then you use this as a draft in your speculative decoding setup.
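For anyone who hasn't seen the draft-then-verify loop spelled out, here is a minimal sketch of generic greedy speculative decoding with toy stand-ins. `toy_draft` and `toy_target_next` are hypothetical functions I made up (a fixed reference answer and a drafter that gets its third guess wrong); this is not the paper's code, and a real system would check all draft positions in one batched forward pass rather than one call per token.

```python
PROMPT = ["what", "is", "2+2"]
REFERENCE = ["it", "is", "4", ".", "<eos>", "<eos>", "<eos>", "<eos>"]

def toy_target_next(seq):
    """Stand-in for the base model's greedy next token."""
    i = len(seq) - len(PROMPT)
    return REFERENCE[min(i, len(REFERENCE) - 1)]

def toy_draft(seq, k):
    """Stand-in for the k-token MASK draft; deliberately wrong at offset 2."""
    i = len(seq) - len(PROMPT)
    guess = [REFERENCE[min(i + j, len(REFERENCE) - 1)] for j in range(k)]
    if k > 2:
        guess[2] = "<wrong>"
    return guess

def verify_draft(prefix, draft, target_next):
    """Accept draft tokens while they match the target model's own choice.
    Always ends with one token taken from the target (a fix or a bonus)."""
    accepted = []
    for tok in draft:
        expected = target_next(prefix + accepted)
        if tok != expected:
            accepted.append(expected)  # replace the first mismatch and stop
            return accepted
        accepted.append(tok)
    accepted.append(target_next(prefix + accepted))  # whole draft accepted
    return accepted

def speculative_generate(prompt, draft_fn, target_next, k, max_new):
    out = list(prompt)
    while len(out) - len(prompt) < max_new:
        out.extend(verify_draft(out, draft_fn(out, k), target_next))
    return out[len(prompt):][:max_new]

print(speculative_generate(PROMPT, toy_draft, toy_target_next, k=4, max_new=4))
```

The point of the verify step is that the final text is exactly what the target model would have produced greedily on its own; the draft only changes how many tokens you get per verification pass.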
Their claim is that using a multi-token predictor this way as a draft model works really well. To be clear, this is still causal. The reason diffusion models have hype is that they are capable of global refinement; this is not. In the same thread as [1], I explain how increasing the number of MASK tokens, i.e. increasing `k`, the number of tokens you predict at once in your multi-token prediction setup, quickly leads to poor quality. This paper agrees with that: they try out k=2,3,4,8 and already see a drop in quality at 8. So finally, this is 4-token prediction with self-speculative decoding (sans LayerSkip or such), removing seemingly no existing limitation of such setups. It is definitely an interesting way to train MTP though.
[1] https://news.ycombinator.com/item?id=45221692
[2] Note that it is computationally a single forward pass. Attention masks help you fuse steps 1 and 2 into a single operation. However, you still have 2 separate loss values.
I think your excitement is justified. The paper is claiming a serious bridge between AR quality and parallel decoding, and the lossless LoRA-assisted mode is the wildest part.