This design implicitly does something similar to something that I sometimes think conventional transformers should try: allowing later layers to query the KV data from earlier layers. As far as I can tell, with a conventional transformer, if a later (and presumably higher-level-thinking) layer wants to take input from earlier tokens from something lower down, it needs to get it from the output of the layers in between and “remember” it by itself instead of just reading it directly.
But suppose an extra attention head were added that queried the KV data from lower layers. At the very least, I imagine this might cleanly solve the STRAWBERRY problem: whatever layer has figured out that the prompt wants to count instances of R could attend to lower layers that actually perceive those Rs.
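To make the idea concrete, here is a minimal sketch of what such a head might look like, in plain Python. Everything in it is an assumption for illustration: `cross_layer_head`, the `kv_cache` layout (layer index mapped to per-position keys and values) and the toy "r"/"e" features are hypothetical, not taken from any real model or library.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, weights

def cross_layer_head(query, kv_cache, lower_layers):
    """Hypothetical extra head: pool the cached K/V of the chosen lower
    layers and attend over all of them with a query from a higher layer."""
    keys, values = [], []
    for layer in lower_layers:
        layer_keys, layer_values = kv_cache[layer]
        keys.extend(layer_keys)
        values.extend(layer_values)
    return attend(query, keys, values)

# Toy cache (assumed layout): layer 0 holds character-level features
# for three positions that spell "r", "e", "r".
kv_cache = {
    0: ([[4.0, 0.0], [0.0, 4.0], [4.0, 0.0]],   # keys
        [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]),  # values
}
query = [4.0, 0.0]  # a higher layer "asking" for r-like positions
out, weights = cross_layer_head(query, kv_cache, lower_layers=[0])
```

In this toy setup nearly all of the attention weight lands on the two "r" positions, which is the behaviour the STRAWBERRY argument hopes for. Whether a trained model would actually learn to use such a head this way is an open question.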
> extra attention head were added that queried the KV data from lower layers
Isn't this sort of similar to latent looping? E.g. [1]. But actually as [2] argues, even that wasn't a good experiment because it used the very last hidden state, which is too close to the logits and loses most of the rich embedding structure. Perhaps you don't even need access to the state of anything except the penultimate hidden layer, since based on my vague reading of [3] the residual stream doesn't "lose information" as it passes deeper down the attention layers, so each block maybe manipulates a different subspace of the residual stream.
[1] https://arxiv.org/abs/2412.06769
[2] https://snimu.github.io/2025/03/30/multi-layer-language-head...
[3] https://news.ycombinator.com/item?id=45758093
> Perhaps you don't even need access to the state of anything except the penultimate hidden layer, since based on my vague reading of [3] the residual stream doesn't "lose information" as it passes deeper down the attention layers, so each block maybe manipulates a different subspace of the residual stream.
I imagine that conventional transformers kind of force this. If you train a transformer such that it needs to learn the ability to do tasks like “Repeat the following words: apple banana cat” then the model is sort of forced to internally propagate the input far enough along to be able to perform the task. But maybe if you pre-trained from scratch with an architecture where later layers get direct access to earlier layers and/or the raw input, then the model wouldn’t need to propagate information.
Or maybe it would all fall apart and something would go wrong with the gradients.
Sounds like a further improvement in the spirit of HRM & TRM models.
Decent comment via X: https://x.com/r0ck3t23/status/2002383378566303745
I continue to be fascinated by these architectures that:
- Build in recurrence / inference scaling to transformers more natively.
- Don't use full recurrent gradient traces, and succeed not just despite, but because of that.
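For anyone unsure what "not using full recurrent gradient traces" means in practice, here is a toy sketch. It assumes a scalar recurrence h ← tanh(w·h + x) as a stand-in for the looped block, and computes dh/dw by forward accumulation for simplicity. The function name and all numbers are made up for illustration; this is not the training procedure of any particular paper.

```python
import math

def final_state_and_grad(w, x, h0, n_steps, grad_steps):
    """Run the recurrence n_steps times. Only the last `grad_steps`
    iterations contribute to dh/dw; earlier ones are treated as constants."""
    h, dh_dw = h0, 0.0
    for t in range(n_steps):
        new_h = math.tanh(w * h + x)
        local = 1.0 - new_h * new_h          # derivative of tanh at this step
        if t >= n_steps - grad_steps:
            dh_dw = local * (h + w * dh_dw)  # chain rule through this step
        else:
            dh_dw = 0.0                      # gradient trace cut off here
        h = new_h
    return h, dh_dw

w, x, h0, n_steps = 0.8, 0.3, 0.0, 8
h_full, grad_full = final_state_and_grad(w, x, h0, n_steps, grad_steps=n_steps)
h_trunc, grad_trunc = final_state_and_grad(w, x, h0, n_steps, grad_steps=2)
```

The forward pass is identical in both cases; only the gradient differs, and the truncated version never has to keep the early iterations around for backprop. Why the truncated signal would help rather than hurt is exactly the interesting part, and this sketch does not explain it.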
Interesting. Instead of running the model once (flash) or multiple times (thinking/pro) in its entirety, this approach seems to apply the same principle within one run, looping back internally.
Instead of big models that “brute force” the right answer by knowing a lot of possible outcomes, this model seems to come to results with less knowledge but more wisdom.
Kind of like having a database of most possible frames in a video game and blending between them instead of rendering the scene.
Isn’t this in a sense an RNN built out of a slice of an LLM? If so, it might have the same drawbacks, namely slowness to train, but also the same benefits, such as an endless context window (in theory).
It's sort of an RNN, but it's also basically a transformer with shared layer weights. Each step is equivalent to one transformer layer, so the computation for n steps is the same as the computation for a transformer with n layers.
The notion of a context window applies to the sequence, and the recurrence doesn't really affect it: each iteration sees and attends over the whole sequence.
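A quick sketch of that equivalence, in plain Python. The `layer` function is a made-up stand-in for a transformer block (a residual update where the mean over positions crudely plays the role of attention over the whole sequence), so treat it as an illustration of the weight-sharing argument only.

```python
import math

def layer(params, h):
    """Stand-in for one transformer block. Every position sees the whole
    sequence through the mean; a real block would use attention and an MLP."""
    w, b = params
    mean = sum(h) / len(h)
    return [x + math.tanh(w * x + b * mean) for x in h]

def looped_model(shared_params, h, n_steps):
    """Recurrent view: apply the same block n_steps times."""
    for _ in range(n_steps):
        h = layer(shared_params, h)
    return h

def stacked_model(per_layer_params, h):
    """Conventional view: a stack of layers, each with its own parameters."""
    for params in per_layer_params:
        h = layer(params, h)
    return h

shared = (0.5, -0.1)
h0 = [0.2, -1.0, 3.0]   # one value per sequence position
n = 4
looped = looped_model(shared, h0, n)
stacked = stacked_model([shared] * n, h0)  # n layers that happen to share weights
```

Both calls do the same arithmetic, so the compute matches an n-layer stack, while the looped version only stores one set of parameters. The sequence length never changes across iterations, which is why the recurrence is over depth and not over context.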
> Instead of running the model once (flash) or multiple times (thinking/pro) in its entirety
I'm not sure what you mean here, but there isn't a difference in the number of times a model runs during inference.
I'm surprised more attention isn't paid to this research direction, and that nobody has tried to generalize it for example by combining the recurrence concept with next token prediction. That said, despite the considerable gains, this seems to just be some hyperparameter tweaking rather than a foundational improvement.
> nobody has tried to generalize it for example by combining the recurrence concept with next token prediction
Here you go: https://arxiv.org/abs/2502.05171
Not just hyperparameter tweaking. Not foundational research either. But rather engineering improvements that compound with each other (ConvSwiGLU layers, Muon optimizer).
It should be noted that these are NOT the official scores on the private evaluation set.