Introspective Diffusion Language Models (introspective-diffusion.github.io)

by zagwdt 55 comments 281 points

[−] thepasch 31d ago
If I’m reading this right, this is pretty wild. They turned a Qwen autoregressor into a diffuser by using a bunch of really clever techniques, and they vastly outperform any “native diffuser,” actually being competitive with the base model they were trained from. The obvious upside here is the massive speedup in generation.

And then through a LoRA adapter, you can ground the diffuser on the base model’s distribution (essentially have it “compare” its proposals against what the base model would’ve generated), which effectively means: exact same byte-for-byte output for the same seed, just roughly twice as fast (which should improve even more for batched tasks).

I’m not an expert, more of a “practicing enthusiast,” so I might be missing something, but at first glance, this reads super exciting to me.

[−] oliver236 31d ago
I think your excitement is justified. The paper is claiming a serious bridge between AR quality and parallel decoding, and the lossless LoRA-assisted mode is the wildest part.
[−] awestroke 31d ago
I don't understand how you can compare against the base model output without generating with the base model, in which case what's the point?
[−] radarsat1 31d ago
Because the nature of transformers is that running a bunch of pregenerated tokens through them is a parallel operation, not autoregressive. That's how it works at training time, but speculative decoding uses it at inference time. So if you just want to check whether a set of known tokens is "likely" given the base model, you can run them all through and get probability distributions, no need to sample.

It's the same reason there's a difference in speed between "prompt processing" and "generation". The former is just taking the pre-generated prompt and building the KV cache, which is parallel, not autoregressive and therefore way faster.
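A minimal sketch of that point, using a hypothetical toy bigram table (`BIGRAM`) in place of a real transformer. The names and probabilities are made up for illustration. Scoring tokens that already exist only reads the known prefix at each position, so the positions can be visited in any order (or all at once), while generation has to wait for each chosen token.

```python
import math

# Hypothetical toy "model": next-token probabilities conditioned only on the
# previous token. A real transformer conditions on the whole prefix, but the
# dependency structure that matters here is the same.
BIGRAM = {
    "what": {"is": 0.9, "it": 0.1},
    "is": {"2+2": 0.6, "it": 0.4},
    "2+2": {"it": 0.8, "is": 0.2},
    "it": {"is": 1.0},
}

def next_token_probs(prefix):
    """Distribution over the next token given a prefix (toy: last token only)."""
    return BIGRAM[prefix[-1]]

def score_known(tokens, order):
    """Log-prob of each known token given its prefix, visited in `order`.

    Every position reads only tokens that already exist, so no position
    waits on another; a GPU would handle them in one batched pass.
    """
    scores = {}
    for i in order:
        scores[i] = math.log(next_token_probs(tokens[:i])[tokens[i]])
    return [scores[i] for i in sorted(scores)]

def generate_greedy(prompt, n):
    """Generation is sequential: step t needs the token chosen at step t-1."""
    out = list(prompt)
    for _ in range(n):
        probs = next_token_probs(out)
        out.append(max(probs, key=probs.get))
    return out

tokens = ["what", "is", "2+2", "it", "is"]
forward = score_known(tokens, order=[1, 2, 3, 4])
shuffled = score_known(tokens, order=[4, 2, 1, 3])
print(forward == shuffled)  # visiting order does not change the scores
```

The loop in `generate_greedy` cannot be reordered the same way, which is the gap between "prompt processing" speed and "generation" speed described above.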

[−] qeternity 31d ago
I haven't read TFA yet but a common technique is speculative decoding where a fast draft model will generate X tokens, which are then verified by the larger target model. The target model may accept some Y <= X tokens but the speedup comes from the fact that this can be done in parallel as a prefill operation due to the nature of transformers.

So let's say a draft model generates 5 tokens, all 5 of these can be verified in parallel with a single forward pass of the target model. The target model may only accept the first 4 tokens (or whatever) but as long as the 5 forward passes of the draft model + 1 prefill of the target model is faster than 4 forward passes of the target, you will have a speedup while maintaining the exact output distribution as the target.
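A rough sketch of that draft-then-verify loop, under simplifying assumptions: decoding is greedy (deterministic), and `target_next` / `draft_next` are hypothetical toy functions rather than real models. With sampling, implementations use an accept/reject rule on the two models' probabilities instead of the exact-match check used here.

```python
def target_next(tokens):
    """Toy stand-in for the large target model (hypothetical)."""
    return (3 * tokens[-1] + len(tokens)) % 10

def draft_next(tokens):
    """Toy draft model: agrees with the target except at every 4th position."""
    guess = target_next(tokens)
    return (guess + 1) % 10 if len(tokens) % 4 == 0 else guess

def plain_decode(target_next, prompt, n_new):
    """Baseline: one target pass per generated token."""
    out = list(prompt)
    for _ in range(n_new):
        out.append(target_next(out))
    return out

def speculative_decode(target_next, draft_next, prompt, n_new, k=5):
    """Greedy speculative decoding; returns (tokens, target_passes)."""
    out = list(prompt)
    goal = len(prompt) + n_new
    passes = 0
    while len(out) < goal:
        # 1) The draft model proposes k tokens autoregressively (cheap passes).
        draft = []
        for _ in range(k):
            draft.append(draft_next(out + draft))
        # 2) The target checks every drafted position. Each check reads only
        #    tokens that already exist, so a real model does all k+1 checks in
        #    one forward pass; the list comprehension only simulates that.
        passes += 1
        checks = [target_next(out + draft[:i]) for i in range(k + 1)]
        # 3) Keep the longest prefix the target agrees with, then append the
        #    target's own token at the first mismatch (or after all k matches).
        accepted = 0
        while accepted < k and draft[accepted] == checks[accepted]:
            accepted += 1
        out.extend(draft[:accepted])
        out.append(checks[accepted])
    return out[:goal], passes

spec, passes = speculative_decode(target_next, draft_next, [1], n_new=12, k=5)
plain = plain_decode(target_next, [1], n_new=12)
print(spec == plain, passes)
```

In this toy run the output matches plain decoding token for token, while the target is invoked far fewer times than the 12 passes the baseline needs. Whether that is a net win in practice depends on how cheap the draft passes are relative to a target pass.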

[−] nodja 30d ago
Same reason why prompt processing is faster than text generation.

When you already know the tokens ahead of time you can calculate the probabilities of all tokens batched together, incurring significant bandwidth savings. This won't work if you're already compute bound, so people with Macs etc. won't get as much benefit from this.
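A back-of-the-envelope sketch of the bandwidth argument, with loud assumptions: a single square weight matrix, 2-byte values, weights read from memory once per pass, activations read and written once, no caching effects. The function name and the `d_model` value are illustrative only.

```python
def arithmetic_intensity(d_model, n_tokens, bytes_per_value=2):
    """FLOPs per byte moved for one d_model x d_model weight matrix.

    Assumes the weights are streamed from memory once per pass regardless of
    how many tokens are in the batch; that is the whole source of the saving.
    """
    flops = 2 * n_tokens * d_model * d_model  # one multiply + one add per weight per token
    weight_bytes = bytes_per_value * d_model * d_model
    activation_bytes = bytes_per_value * 2 * n_tokens * d_model  # inputs + outputs
    return flops / (weight_bytes + activation_bytes)

for n in (1, 5, 64):
    print(n, round(arithmetic_intensity(4096, n), 2))
```

Under these assumptions, one token per pass does about 1 FLOP per byte moved and five tokens per pass about 5, so verifying five known tokens costs nearly the same memory traffic as generating one. The benefit flattens out once the intensity reaches the hardware's ratio of peak compute to memory bandwidth, which is the compute-bound case mentioned above.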

[−] Majromax 30d ago
Are Macs/etc compute bound with their 'it fits in unified memory' language models? Certainly by the time you're streaming weights from SSD you must be back in a bandwidth-bound regime.
[−] dd8601fn 30d ago
From what I understood, if we’re talking a single user on a mac (not batching) you’re rarely compute bound in the first place. More rows per pass is nearly free that way when cores were sitting idle anyway.

If that’s wrong I would certainly appreciate being corrected, though. But if it’s right, a 2.9x speed-up after rejected tokens, nearly for free, sounds amazing.

[−] nodja 29d ago
That will depend on the model, but they'll hit compute limits before a typical GPU in almost all cases. Macs will still get a speedup from this, just not one as big as the one reported.
[−] Balinares 31d ago
Isn't that exactly how draft models speed up inference, though? Validating a batch of tokens is significantly faster than generating them.
[−] anentropic 31d ago
presumably that happens at training time?

then once successfully trained you get faster inference from just the diffusion model

[−] a1j9o94 31d ago
You would only use the base model during training. This is a distillation technique
[−] porridgeraisin 30d ago
Eh. There is nothing diffusion about this. Nothing to do with denoising. This setup is still purely causal, making it quite a dishonest framing IMO. There is no more introspection here than what happens in MTP + SD setups.

Let me explain what is going on here. This is basically a form of multi-token prediction, plus speculative decoding at inference. See my earlier post[1] to understand what that is. TL;DR: in multi-token prediction you train separate LM heads to predict the next token, the next-to-next token, and so on, up to a chosen k-th next token. Training multiple LM heads is expensive and can be unnecessary, so what people typically do is share a common base across all k heads, explained further in [1]. These guys do another variant.

Here is what they do mechanically, given a sequence p consisting of five tokens, PE([p1, p2, p3, p4, p5]), where PE(.) adds relative position info to each token.

1. Create an augmented sequence PE([p1 MASK MASK MASK MASK]). Do a training pass on that, with the ground truth sequence p1..5. Here it is trained, for example, to predict p3 given p1+pos=-2 MASK+pos=-1 MASK+pos=0, loosely notating.

2. Then separately[2], train it as usual on PE([p1 p2 p3 p4 p5]).

Step (1) teaches it to do multi-token prediction, essentially the single LM head will (very very loosely speaking) condition on the position k of the special MASK token and "route" it to the "implicit" k'th LM head.

Step (2) teaches it to be a usual LLM and predict the next token. No MASK tokens involved.

So far, you have trained a multi-token predictor.

Now during inference

You use this for speculative decoding. You generate 5 tokens ahead at once with MASK tokens. And then you run that sequence through the LLM again. This has the same benefits as usual speculative decoding, namely that you can do matrix-matrix multiplication as opposed to matrix-vector. The former is more memory-bandwidth efficient due to higher arithmetic intensity.

Here is an example:

  query = ["what", "is", "2+2"]
  prompt = PE([...query, MASK*5])
  output = LLM(prompt)

Say output is ["what", "is", "2+2", "it", "is", "4"]. Note that the NN is trained to predict the kth next token when faced with positionally encoded MASK tokens, so you get all 5 in one go. To be precise, it learns to predict "4" given ["what", "is", "2+2", MASK, MASK]. Since it does not need the "it" and "is" explicitly, you can do it in parallel with generating the "it" and the "is". Likewise, "is" is predicted given ["what", "is", "2+2", MASK], which also doesn't depend on the explicit "it" being there, and thus can also be done in parallel with generating "it", which is just normal next-token generation given the query. You then use this as a draft in your speculative decoding setup.

Their claim is that using a multi-token predictor this way as a draft model works really well. To be clear, this is still causal. The reason diffusion models have hype is that they are capable of global refinement; this is not. In the same thread as [1], I explain how increasing the number of MASK tokens, i.e. increasing k, the number of tokens you predict at once in your multi-token prediction setup, quickly leads to poor quality. This paper agrees with that. They try out k=2,3,4,8 and already see a drop in quality at 8. So in the end, this is 4-token prediction with self-speculative decoding (sans LayerSkip or such), removing seemingly no existing limitation of such setups. It is definitely an interesting way to train MTP though.

[1] https://news.ycombinator.com/item?id=45221692

[2] Note that it is computationally a single forward pass. Attention masks help you fuse steps 1 and 2 into a single operation. However, you still have 2 separate loss values.

[−] Reubend 29d ago
After trying to understand their method, I think you're right. Doesn't seem like anything that I would personally call "diffusion". Much closer to MTP + speculative decoding.

Then again, their results with it are great. It would be interesting to benchmark it against standard SD on a model that already uses MTP.

[−] porridgeraisin 29d ago
Yeah, I think it's a super neat way to do MTP. Conceptually much more pleasing and simple than existing methods. Especially since this way scaling k as models get better will be easier. Wish it had been presented as such.
[−] radarsat1 30d ago
This reminds me a lot of the tricks to turn BERT into a generative model. I guess the causal masking that keeps it essentially autoregressive is an important difference though. Kind of best of both worlds.
[−] krackers 25d ago
Masked language modeling has been compared loosely to text diffusion [1], so the paper's title claim may be loosely true in some sense even if it's misleading.

[1] https://nathan.rs/posts/roberta-diffusion/

[−] andsoitis 31d ago
Is anyone here experimenting seriously with Diffusion for text generation? I’d love to learn about your experiences!
[−] recsv-heredoc 31d ago
https://www.inceptionlabs.ai/

This startup seems to have been at it a while.

From our look into it - amazing speed, but challenges remain around time-to-first-token user experience and overall answer quality.

Can absolutely see this working if we can get the speed and accuracy up to that “good enough” position for cheaper models - or non-user facing async work.

One other question I’ve had is whether it’s possible to actually set a huge amount of text to diffuse as the output - using a larger body to mechanically force greater levels of reasoning. I’m sure there’s some incredibly interesting research taking place in the big labs on this.

[−] IanCal 31d ago
The overall speed rather than TTFT might start to be more relevant as the caller moves from being a human to another model.

However quality is really important. I tried that site and clicked one of their examples, "create a javascript animation". Fast response, but while it starts like this

Below is a self‑contained HTML + CSS + JavaScript example that creates a simple, smooth animation: a colorful ball bounces around the browser window while leaving a fading trail behind it.

JavaScript Bounce Animation