---
license: apache-2.0
pipeline_tag: text-generation
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
- RxNN
- SparseQueryAttention
- SQA
- GroupedQueryAttention
- MultiQueryAttention
language:
- en
datasets:
- roneneldan/TinyStories
library_name: RxNN
---

# sSQAT-mm: symmetric Sparse Query Attention Transformer Micro-MoE

Research model for [**Sparse Query Attention (SQA)**](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md), an extension of **Grouped Query Attention (GQA)** that reduces the number of used query heads, instead of further reducing the key/value head count down to **Multi Query Attention (MQA)**. This approach substantially reduces the computational cost of attention and speeds up training, while performance stays between the **GQA** and **MQA** levels.

> The symmetric **SQA** variant uses exactly 50% of both the query and the key/value heads. It performs at the level of the reference GQA model, but
> trains noticeably faster. It is the best configuration for full-sequence processing cases such as encoders. [Check other variants](#compared-models)

##### Research paper
- [arxiv.org/abs/2510.01817](https://arxiv.org/abs/2510.01817)

### Architecture details:
- trainable params: ~8.62M
- dim: 128
- layers: 6
- self-attention: symmetric Sparse Query Attention (sSQA)
  - heads: 8 (for dimension split)
  - query groups: 4
  - key/value groups: 4
- Mixture-of-Experts Feed Forward
  - experts: 12
  - active experts: 2
- SwiGLU feed forward with 256 dim
- RoPE
- RMS Norm
- vocab: 5k (English only)
- context length: 256
- Library: RxNN

### Training details:
This microscale model was trained for 5 epochs on a simple synthetic dataset and is able to generate simple stories.
The main training goal is to compare it with the reference GQA/MQA models and other SQA variants.
- dataset: [roneneldan/TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
- 5 epochs
- 2.3B processed tokens
- learning rate: 2e-3, cosine annealing scheduler without warmup

### Compared models
- [GQA-Ref-Micro](https://huggingface.co/ReactiveAI/GQA-Ref-Micro): 8 query heads, 2/8 kv heads
- [MQA-Ref-Micro](https://huggingface.co/ReactiveAI/MQA-Ref-Micro): 8 query heads, 1/8 kv heads
- [SQAT-mm](https://huggingface.co/ReactiveAI/SQAT-mm): 4/8 query heads, 2/8 kv heads
- [sSQAT-mm](https://huggingface.co/ReactiveAI/sSQAT-mm): 4/8 query heads, 4/8 kv heads
- [xSQAT-mm](https://huggingface.co/ReactiveAI/xSQAT-mm): 2/8 query heads, 2/8 kv heads

### Results
Validation mean loss/accuracy:
- GQA: 1.139 / ~70.66%
- MQA: 1.158 / ~70.33%
- **SQA: 1.159 / ~70.32%**
- **sSQA: 1.142 / ~70.63%** <-
- **xSQA: 1.169 / ~70.12%**

Total training time:
- GQA: ~398 min
- MQA: ~399 min
- **SQA: ~387 min**
- **sSQA: ~390 min** <-
- **xSQA: ~383 min**

These results suggest that even with very short sequences (256 tokens), the computational benefits are noticeable (\~3%), while the performance differences are very small (\~1%). The sSQA configuration has only \~0.3% higher loss than GQA, while training \~2% faster. In bigger models with a 1024-token context, the computational differences were greater (\~10%), and most SQA variants performed closer to GQA than to MQA.

Even _the extreme version_ of **SQA**, with only 2/8 query heads (and also 2/8 key/value heads), seems to perform similarly to the reference MQA model, with even shorter training times. However, reducing the heads below this level (~25% of heads used) no longer reduces training time/cost and noticeably decreases performance, so there is a practical limit. These results suggest that **SQA** could be a viable alternative to spatially sparse attention. More info in [ReactiveAI/xSQAT-mm](https://huggingface.co/ReactiveAI/xSQAT-mm).
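Beyond training speed, the head counts listed under Compared models also change the size of each attention layer: the query and output projections scale with the number of used query heads, while the key and value projections scale with the number of key/value heads. The sketch below counts only the attention projection weights for each configuration. It is a minimal illustration that assumes bias-free Q/K/V/output projections and a per-head dimension of 128 / 8 = 16; neither detail is stated in this card, and `attention_params` is a hypothetical helper, not part of RxLM.

```python
# Hypothetical sketch: attention projection weights per configuration.
# Assumptions (not confirmed by the model card): no biases, head_dim = dim // heads.
DIM = 128          # model dimension (Architecture details)
TOTAL_HEADS = 8    # heads used for the dimension split
LAYERS = 6
HEAD_DIM = DIM // TOTAL_HEADS  # 16 under the even-split assumption


def attention_params(query_heads: int, kv_heads: int) -> int:
    """Weights in the Q, K, V and output projections of one attention layer."""
    q_proj = DIM * query_heads * HEAD_DIM        # input -> used query heads
    kv_proj = 2 * DIM * kv_heads * HEAD_DIM      # input -> key heads and value heads
    out_proj = query_heads * HEAD_DIM * DIM      # concatenated heads -> model dim
    return q_proj + kv_proj + out_proj


# (used query heads, key/value heads) out of 8, as listed under Compared models
VARIANTS = {
    "GQA": (8, 2),
    "MQA": (8, 1),
    "SQA": (4, 2),
    "sSQA": (4, 4),
    "xSQA": (2, 2),
}

if __name__ == "__main__":
    baseline = attention_params(*VARIANTS["GQA"]) * LAYERS
    for name, (hq, hkv) in VARIANTS.items():
        total = attention_params(hq, hkv) * LAYERS
        saved = baseline - total
        print(f"{name:5s} attention params: {total:>7,d}  saved vs GQA: {saved / 1e6:.3f}M")
```

Under these assumptions sSQA saves 49,152 weights (~0.05M) relative to GQA, and SQA and xSQA save ~0.10M and ~0.15M, which is consistent with the reported model sizes given their two-decimal rounding. Everything outside the attention projections (MoE feed-forward, embeddings) is identical across variants, so the absolute difference stays small.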
### Model size difference
SQA reduces the dimensions of the query projection and the output projection, which results in a slightly smaller model:
- GQA: 8.67M params
- MQA: 8.64M params
- **SQA: 8.57M params**
- **sSQA: 8.62M params** <-
- **xSQA: 8.52M params**

> In these models the size difference is small because most of the parameters are in the MoE feed-forward layers. In dense models the difference is more noticeable; check [ReactiveAI/SQAT-m](https://huggingface.co/ReactiveAI/SQAT-m).

### Usage
The model requires our [RxLM framework](https://github.com/RxAI-dev/rxlm) for training/inference. It's integrated with the HuggingFace Hub and libraries.
Components related to SQA and classic transformers are free even for commercial use, while Reactive Transformer components are free only for non-commercial use (Reactive AI Framework License v1.0).

#### Inference:
- Install RxLM, PyTorch and dependencies: `pip install rxlm torch transformers tokenizers`

```python
import torch
from rxlm.experimental.models import ExperimentalAttentionTransformer
from rxlm.transformers.sampler import Sampler, SampleDecoder
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

model = ExperimentalAttentionTransformer.from_pretrained('ReactiveAI/sSQAT-mm')
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/sSQAT-mm')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sampler = Sampler(model, device, end_token_id=3)
sample = SampleDecoder(sampler, tokenizer)

# 0.1 and 0.9 are the default values for temperature and top_p;
# max_seq_len matches the model's 256-token training context
generated = sample('Example model input for text generation...', temperature=0.1, top_p=0.9, max_seq_len=256)

# Print the response as it is streamed
sample('Example model input for text generation - print streamed response...', temperature=0.1, top_p=0.9, max_seq_len=256, print_stream=True)
```

#### Train:
- Install RxLM, PyTorch and dependencies: `pip install rxlm torch transformers tokenizers tensorboard` (`tensorboard` is optional)

```python
import os

import torch
from rxlm.experimental.models import ExperimentalAttentionTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub
from rxlm.llm_training.dataset import AutoregressiveLMDataset
from rxlm.llm_training.supervised import AutoregressiveTrainer
from rxlm.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback
from rxlm.training.scheduler import get_transformer_lr_scheduler

model = ExperimentalAttentionTransformer.from_pretrained('ReactiveAI/sSQAT-mm')
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/sSQAT-mm')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 256
epochs = 5
gradient_acc_steps = 1
seq_len = 256        # model context length
vocab_size = 5_000   # model vocabulary size (5k)
peak_lr = 2e-3 * gradient_acc_steps

# split is 'train' by default
train_dataset = AutoregressiveLMDataset.from_hf_hub('hf-dataset-id', 'subset', tokenizer=tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset.from_hf_hub('hf-dataset-id', split='validation', tokenizer=tokenizer, max_seq_len=seq_len)

dataset_len = len(train_dataset)
steps_per_epoch = int(dataset_len / batch_size - 1)
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = 0  # no warmup, as in the training details

logs_dir = './tensorboard_logs'  # requires tensorboard: `pip install tensorboard`

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback()
acc_cb = PrintAccuracyCallback()
save_cb = ModelSaveCallback(
    './path/to/save', push_to_hub=True, hub_model_id='your-model-id', private_repo=True,
    push_checkpoint_weights=True, final_commit_message='Final commit message',
    hf_token=os.environ.get('HF_TOKEN'),  # read the Hub token from the environment
)

trainer = AutoregressiveTrainer(
    model, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb],
    use_amp=True, dtype=torch.bfloat16, log_dir=logs_dir,
    gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=True, moe_aux_loss_scale=0.01,  # MoE load-balancing loss
)

optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
```

## Summary
According to the experiment results, **Sparse Query Attention** seems to be the most cost-effective variant of **Grouped Query Attention**, leading to a noticeable training time reduction (even for very short contexts), and it is a promising research direction. It should be tested on very long context models, but this was out of scope for the current research. We will certainly continue exploring SQA, but for now we are mostly focused on our reactive architectures.

Currently, for our **Reactive Transformer** architectures, which were initially designed with GQA for self-attention and MQA for memory attention, we are considering using SQA variants instead for all attention layer types. More info will be released soon.
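The case for testing SQA on long contexts can be made concrete with a rough FLOP count. In standard multi-head attention, the score product (Q times K transposed) and the weighted sum over values each cost about seq_len² × head_dim multiply-adds per query head. Sharing key/value heads, as GQA and MQA do, leaves that count unchanged, while removing query heads, as SQA does, shrinks it. The sketch below is an idealized estimate, assuming a head dimension of 16 (128 / 8) and ignoring causal masking, softmax and memory traffic; `score_flops` is a hypothetical helper, not an RxLM function.

```python
# Hypothetical sketch: idealized attention-score FLOPs per layer and sequence.
HEAD_DIM = 16  # assumption: model dim 128 split evenly across 8 heads


def score_flops(query_heads: int, seq_len: int, head_dim: int = HEAD_DIM) -> int:
    """FLOPs for Q @ K^T plus softmax(scores) @ V in one attention layer.

    Each of the two matmuls costs seq_len * seq_len * head_dim multiply-adds
    per query head, counted as 2 FLOPs each. The key/value head count does not
    appear: sharing K/V heads (GQA/MQA) does not shrink these matmuls, while
    dropping query heads (SQA) does. Masking, softmax and memory traffic are ignored.
    """
    return 2 * 2 * query_heads * seq_len * seq_len * head_dim


if __name__ == "__main__":
    for seq_len in (256, 1024, 32_768):
        gqa = score_flops(query_heads=8, seq_len=seq_len)   # GQA and MQA: 8 query heads
        ssqa = score_flops(query_heads=4, seq_len=seq_len)  # sSQA: 4 of 8 query heads
        print(f"seq_len={seq_len:>6}: GQA {gqa:.2e}  sSQA {ssqa:.2e}  "
              f"ratio {ssqa / gqa:.2f}  saved {gqa - ssqa:.2e} FLOPs/layer")
```

Under these assumptions sSQA halves the score FLOPs of GQA at every length, while the absolute saving grows quadratically (16× going from 256 to 1024 tokens). MQA keeps all 8 query heads, so its score FLOPs equal GQA's, which fits its near-identical training time in the Results. The end-to-end speedups measured here are smaller than this per-layer ratio because the projections, MoE feed-forward and embeddings are unaffected by the query head count.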