Hacker News

Hybrid Attention

33 points by JohannaAlmeida today at 1:06 PM | 8 comments

TL;DR: I forked PyTorch and Triton internals and changed attention so that the first layer is linear, the middle layer is quadratic, and the last layer is linear. Inference got much faster, with only a small perplexity hit in tests.

Full attention O(n²): 17.96s / 5.6 tok/s

HybridAttention O(n·W + n·D): 0.35s / 286.6 tok/s
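To make the complexity labels a little more concrete, the sketch below counts query-key dot products for one causal pass over n tokens, comparing full attention with a sliding window. The window size W = 64 is an assumption (the post gives 64 only as the KV hot-window size, not the attention window), and the recurrent path's O(n·D) term is left out.

```python
# Count query-key dot products for one causal pass over n tokens (one head, one layer).
# Assumption: W = 64 is illustrative; the post does not state the local attention window.


def full_causal_scores(n: int) -> int:
    """Token i attends to every position 0..i, i.e. i + 1 scores."""
    return n * (n + 1) // 2


def windowed_causal_scores(n: int, window: int) -> int:
    """Token i attends to at most `window` most recent positions, itself included."""
    return sum(min(i + 1, window) for i in range(n))


n, w = 512, 64
full = full_causal_scores(n)
local = windowed_causal_scores(n, w)
print(full, local, round(full / local, 2))  # 131328 30752 4.27
```

At n = 512 the score count alone drops by only about 4x, and the gap grows linearly with n. Most of the measured 51x speedup therefore presumably comes from elsewhere, for example the KV cache and custom kernels.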

I have been building a small, Rust-focused language model from scratch in PyTorch. This is not a fine-tune: it is byte-level and trained from random initialization on a Rust-heavy corpus assembled here: https://codeberg.org/JohannaJuntos/Sisyphus

Model and training setup

The model has 25.6M parameters and a 512-token context length. It uses a byte-level vocabulary of 256, 8 layers, 8 heads, and 512-dimensional embeddings. Positional embeddings are learned, and the embedding and LM head weights are tied.
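As a sanity check on the 25.6M figure, here is a rough parameter count for these hyperparameters. This is a minimal sketch that assumes conventional GPT-2-style blocks (4·d² for the attention projections, 8·d² for a 4x MLP). It ignores biases and LayerNorm weights and does not model the HybridAttention block's recurrent path or gate, so it is only approximate.

```python
# Approximate parameter count for a GPT-2-style decoder.
# Assumptions: standard attention + 4x MLP per block; biases, LayerNorms,
# and any HybridAttention-specific weights (GRU-like path, gate) are ignored.


def approx_gpt_params(vocab: int, ctx: int, d_model: int, n_layers: int) -> int:
    token_emb = vocab * d_model           # tied with the LM head, so counted once
    pos_emb = ctx * d_model               # learned positional embeddings
    per_layer = 12 * d_model * d_model    # 4*d^2 attention + 8*d^2 MLP
    return token_emb + pos_emb + n_layers * per_layer


total = approx_gpt_params(vocab=256, ctx=512, d_model=512, n_layers=8)
print(f"{total:,}")  # 25,559,040
```

With a 256-entry byte vocabulary, the token embedding accounts for only about 0.5% of the parameters. Nearly all of the capacity sits in the transformer blocks.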

Training ran for 30k steps on a 173.5M-byte Rust corpus using a single RTX 4060 Ti (8 GB).

Final metrics were a train loss of 0.5834, a validation loss of 0.8217, and a perplexity of 2.15. The best validation loss occurred around step 18.5k, which suggests some late overfitting or a plateau.
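For reference, here are the usual conversions from a cross-entropy loss. This is a minimal sketch that assumes the reported losses are mean per-byte cross-entropy in nats, which is what PyTorch's `cross_entropy` returns by default. For byte-level models, bits per byte is often the easier number to compare across setups.

```python
import math


def perplexity(mean_ce_nats: float) -> float:
    """Perplexity for a mean per-token cross-entropy measured in nats."""
    return math.exp(mean_ce_nats)


def bits_per_byte(mean_ce_nats: float) -> float:
    """For a byte-level model each token is one byte, so convert nats to bits."""
    return mean_ce_nats / math.log(2)


val_loss = 0.8217  # validation loss reported above
print(f"ppl={perplexity(val_loss):.2f} bpb={bits_per_byte(val_loss):.3f}")  # ppl=2.27 bpb=1.185
```

Under this convention, a validation loss of 0.8217 corresponds to a perplexity of about 2.27. The post doesn't state how the reported 2.15 was computed.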

Architecture

The model is a GPT-style decoder, but it replaces standard full attention with a HybridAttention block in each layer. Each block combines local windowed causal attention with a GRU-like recurrent state path and a learned gate that mixes the two.

The local path handles short-range syntax, while the recurrent path carries compressed long-range state. The gate bias is initialized to favor the local path early in training.
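The mixing idea can be illustrated with a toy, pure-Python sketch. This is not the author's PyTorch/Triton implementation. It assumes a single head, no Q/K/V projections, a simplified GRU-like update with elementwise weights, and a scalar gate per token. The parameter names (`window`, `gate_bias`, `w_gate`, `w_z`, `u_z`, `w_h`) are hypothetical. The point is to show how a positive gate bias makes the block start out close to pure local attention.

```python
import math
from typing import List

Vec = List[float]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def local_causal_attention(xs: List[Vec], t: int, window: int) -> Vec:
    """Softmax attention from position t over positions max(0, t - window + 1)..t.

    Simplification: queries, keys and values are the raw inputs (no projections).
    """
    idx = range(max(0, t - window + 1), t + 1)
    scale = 1.0 / math.sqrt(len(xs[t]))
    scores = [scale * sum(a * b for a, b in zip(xs[t], xs[j])) for j in idx]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    d = len(xs[t])
    return [sum(w * xs[j][i] for w, j in zip(weights, idx)) / total for i in range(d)]


def hybrid_attention(xs: List[Vec], window: int = 4, gate_bias: float = 2.0,
                     w_gate: float = 0.0, w_z: float = 0.5, u_z: float = 0.5,
                     w_h: float = 1.0) -> List[Vec]:
    """Mix a local attention path and a recurrent state path with a gate."""
    d = len(xs[0])
    state = [0.0] * d
    outputs = []
    for t, x in enumerate(xs):
        local = local_causal_attention(xs, t, window)
        # GRU-like update (elementwise weights): z controls how much the state is overwritten.
        z = [sigmoid(w_z * xi + u_z * hi) for xi, hi in zip(x, state)]
        candidate = [math.tanh(w_h * xi) for xi in x]
        state = [(1 - zi) * hi + zi * ci for zi, hi, ci in zip(z, state, candidate)]
        # Scalar gate per token. With w_gate = 0 it equals sigmoid(gate_bias),
        # so a positive bias puts most of the weight on the local path.
        g = sigmoid(w_gate * sum(x) / d + gate_bias)
        outputs.append([g * lo + (1 - g) * hi for lo, hi in zip(local, state)])
    return outputs


seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out = hybrid_attention(seq)
print(len(out), round(sigmoid(2.0), 3))  # 4 0.881
```

Pinning the gate to 1 or 0 gives local-only and recurrent-only variants. That is one possible way to wire up the ablations listed under next steps.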

Inference uses Triton kernels and custom torch.library ops.

Corpus

The biggest gain came from corpus expansion.

The run started with about 31 MB from official Rust sources and major projects such as rustc, cargo, rust-analyzer, tokio, serde, ripgrep, clap, and axum. The corpus was then expanded to 173.5M bytes by cloning the top 500 crates, 461 of which cloned successfully.
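The next sketch shows one way to turn a directory of cloned crates into a single byte stream for byte-level training. It assumes that only `.rs` files are kept and that files are joined with a newline. The helpers `build_byte_corpus` and `encode` are hypothetical. The corpus actually used is the one assembled in the Sisyphus repository linked above, and its filtering may differ.

```python
import tempfile
from pathlib import Path
from typing import List


def build_byte_corpus(roots: List[Path], suffix: str = ".rs", sep: bytes = b"\n") -> bytes:
    """Concatenate every file ending in `suffix` under `roots` into one byte string."""
    chunks = []
    for root in roots:
        for path in sorted(root.rglob(f"*{suffix}")):
            if path.is_file():
                chunks.append(path.read_bytes())
    return sep.join(chunks)


def encode(data: bytes) -> List[int]:
    # Byte-level "tokenization": every byte is already an id in 0..255.
    return list(data)


# Usage on a throwaway directory standing in for a cloned crate.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "src").mkdir()
    (root / "src" / "lib.rs").write_bytes(b"pub fn one() -> u8 { 1 }\n")
    (root / "README.md").write_bytes(b"not code\n")
    corpus = build_byte_corpus([root])

ids = encode(corpus)
print(len(corpus), max(ids) < 256)
```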

This expansion had more impact than any architectural change.

Inference performance

Full attention runs at about 5.6 tokens per second, while HybridAttention with a KV cache reaches 286.6 tokens per second. That is roughly a 51x speedup, with no visible quality loss.
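The reported timings and throughputs agree with each other. Multiplying them out suggests each benchmark generated roughly 100 tokens. That run length is an inference from the arithmetic, not something the post states.

```python
# Numbers reported in the post.
full_s, full_tps = 17.96, 5.6
hybrid_s, hybrid_tps = 0.35, 286.6

print(round(full_s * full_tps), round(hybrid_s * hybrid_tps))        # 101 100  (tokens per run)
print(round(hybrid_tps / full_tps, 1), round(full_s / hybrid_s, 1))  # 51.2 51.3 (speedup two ways)
```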

The KV cache keeps a hot window of the 64 most recent tokens in VRAM. Older tokens are compressed to 8-bit magnitude and angle and can be selectively promoted back to full precision. For this setup, this changes the effective complexity from quadratic to near-linear.
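Below is a simplified sketch of the paging idea, with one vector standing in for a cached key or value. The post doesn't describe its exact encoding. Here "magnitude and angle" is read as polar quantization of consecutive coordinate pairs, with one float scale per vector. The names `HotWindowKVCache`, `compress`, and `decompress` are hypothetical. Promotion in this sketch only dequantizes: it restores a float vector but cannot recover the precision lost at eviction.

```python
import math
from collections import deque
from typing import Deque, List, Tuple

Vec = List[float]
Packed = Tuple[float, bytes, bytes]


def compress(vec: Vec) -> Packed:
    """Pack pairs (x, y) as 8-bit magnitude and 8-bit angle. Assumes len(vec) is even."""
    pairs = list(zip(vec[0::2], vec[1::2]))
    mags = [math.hypot(x, y) for x, y in pairs]
    r_max = max(mags) or 1.0  # per-vector scale; avoid dividing by zero
    q_mag = bytes(round(m / r_max * 255) for m in mags)
    # Map the angle in [-pi, pi] to 0..255; 256 wraps to 0 because -pi and pi coincide.
    q_ang = bytes(round((math.atan2(y, x) + math.pi) / (2 * math.pi) * 256) % 256
                  for x, y in pairs)
    return r_max, q_mag, q_ang


def decompress(packed: Packed) -> Vec:
    r_max, q_mag, q_ang = packed
    out: Vec = []
    for m, a in zip(q_mag, q_ang):
        r = m / 255 * r_max
        theta = a / 256 * 2 * math.pi - math.pi
        out.extend([r * math.cos(theta), r * math.sin(theta)])
    return out


class HotWindowKVCache:
    """Keep the last `hot` entries in full precision; compress older ones."""

    def __init__(self, hot: int = 64):
        self.hot_size = hot
        self.hot: Deque[Vec] = deque()
        self.cold: List[Packed] = []

    def append(self, vec: Vec) -> None:
        self.hot.append(vec)
        if len(self.hot) > self.hot_size:
            # Evict the oldest full-precision entry into 8-bit storage.
            self.cold.append(compress(self.hot.popleft()))

    def promote(self, index: int) -> Vec:
        """Return a dequantized float copy of cold entry `index`."""
        return decompress(self.cold[index])


cache = HotWindowKVCache(hot=64)
for t in range(100):
    cache.append([math.cos(t), math.sin(t), 0.1 * t, -0.05 * t])
print(len(cache.hot), len(cache.cold))  # 64 36
```

With this encoding, the angle error per pair is at most π/256 radians and the magnitude error is at most r_max/510. Nearby directions therefore stay close, but small components in a vector with one large pair lose most of their precision.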

Quality

Surface Rust syntax looks decent, and imports and function signatures are often plausible. Semantics are still weak, and repetition and recursive patterns are common. It looks like Rust but does not reason well yet.

What seems interesting

This project combines byte-level, Rust-only pretraining from scratch; a hybrid architecture mixing local attention with a recurrent path; large-scale corpus expansion across the Rust ecosystem; and a practical KV cache paging strategy that delivers large speedups on consumer GPUs.

Next steps

I plan to run ablations comparing hybrid attention against local-only and recurrent-only variants, compare checkpoints from around step 18.5k with the final model, and add syntax-level validation such as parsing and compiling generated code. I also want to explore scaling the context length from 256 up to 2048 and test whether switching from byte-level tokens to BPE becomes worthwhile now that the corpus is larger.
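A compile rate over sampled functions is one way to turn the parse-and-compile check into a single number. This minimal sketch assumes `rustc` is on PATH and compiles each sample as a standalone library crate with `--emit metadata`, which type-checks without producing a binary. Samples that depend on external crates will therefore count as failures. The helpers `rustc_accepts` and `compile_rate` are hypothetical names.

```python
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Callable, Iterable


def rustc_accepts(source: str) -> bool:
    """Return True if rustc type-checks `source` as a dependency-free library crate."""
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "sample.rs"
        src.write_text(source)
        result = subprocess.run(
            ["rustc", "--edition", "2021", "--crate-type", "lib",
             "--emit", "metadata", "--out-dir", tmp, str(src)],
            capture_output=True,
        )
        return result.returncode == 0


def compile_rate(samples: Iterable[str],
                 check: Callable[[str], bool] = rustc_accepts) -> float:
    """Fraction of generated samples that pass `check`."""
    samples = list(samples)
    if not samples:
        return 0.0
    return sum(check(s) for s in samples) / len(samples)


if shutil.which("rustc"):
    print(compile_rate(["pub fn add(a: i32, b: i32) -> i32 { a + b }", "fn broken( {"]))
```

Passing a different `check` callable, for example a parser-only check, would give a cheaper syntax-level variant of the same metric.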

Questions

For small code models, which evaluations have been most useful beyond perplexity?

Has anyone seen hybrid local plus recurrent attention work well for code generation?

Given this setup, would you prioritize more tokens, longer context, or clean ablations first?


Comments

bigbadfeline today at 5:38 PM

I've been interested in faster attention and smaller models for some time, but I haven't had the time to do serious research, so I can't answer your questions.

However, everything you're doing sounds very interesting, useful, and well thought out. Please keep doing it; I'd encourage others to work in the same direction too.

I hope more of us can find the time for more than best wishes in the near future.

hackerman70000 today at 4:32 PM

For the evaluation question: for small code models, the compile rate of generated functions is the simplest metric that actually correlates with usefulness. Perplexity tells you the model learned the distribution; compilation rate tells you it learned the structure. Beyond that, exact match on function body completion given a signature is more informative than open-ended generation benchmarks.

JohannaAlmeida today at 1:32 PM

Full attention O(n²): 17.96s / 5.6 tok/s

HybridAttention O(n·W + n·D): 0.35s / 286.6 tok/s

empath75 today at 2:07 PM

Is this just for autocomplete? Because you are not going to get anything very useful out of a code-only training set.

woodson today at 2:51 PM

Look into RWKV.
