---
language:
- en
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
tags:
- mixture-of-attentions
- distance-attention
- metric-attention
- mqa
- hyperffn
- router-gating
datasets:
- nvidia/Nemotron-Math-HumanReasoning
- WeMake/Intelligent-Content-Understanding
---

# MoAMetricLM‑100M — Mixture of Attentions (MoA)

**A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.**

- **Parameters:** ~185 M (≈ 100 M effective due to the mixture)
- **Task:** Causal language modeling (decoder‑only)
- **Library:** 🤗 Transformers
- **KV cache:** Not yet implemented (generation recomputes the full context at every step)

---

## Model card

| **Model ID** | `reaperdoesntknow/MoA-100M` |
|--------------|-------------------------------------|
| **Architecture** | `moa_metric` (custom) |
| **Tokenizer** | GPT‑2 (`gpt2`) – `pad_token` set to `eos_token` |
| **Context length** | 2048 tokens |
| **Training data** | 2 × ≈ 256 k tokens from the datasets listed above |
| **Training compute** | CPU‑only (Intel), FP32 |
| **Training hyper‑parameters** | LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens |
| **Final loss** | ≈ 0.30 (train) |
| **License** | Apache‑2.0 |
| **Safety** | No alignment or safety fine‑tuning – outputs may be biased or inaccurate. |
| **Intended use** | Research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models. |
| **Limitations** | • No KV‑cache → slower generation.<br>• Small token budget → not a general‑purpose LM.<br>• No safety/alignment training. |
| **Out‑of‑scope** | High‑stakes applications (medical, legal, etc.) without further evaluation. |

---

## Overview

MoA replaces the classic dot‑product attention with **metric‑based attention** and blends **four** distinct heads per Transformer block:

| Head type | Description |
|-----------|-------------|
| **LocalConvHead** | Depthwise‑separable 1‑D convolution → captures short‑range context. |
| **Metric Multi‑Head Attention (MetricMHAttention)** | Soft‑min over **L2 / cosine / diagonal‑Mahalanobis** distances:<br>\(\displaystyle \text{attn}_{h}(i,j) \propto \exp\!\big(-\alpha_h\|q_i-k_j\|^2\big)\) (sketched below) |
| **Metric MQA** | Multi‑Query attention (shared K/V) in the same metric space – cheaper than full MHA. |
| **ChannelMixHead** | Per‑token MLP that mixes channel dimensions (no positional mixing). |
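For intuition, here is a minimal, self‑contained PyTorch sketch of the L2 variant of the rule above. It is illustrative only: the function name `metric_attention`, the explicit causal mask, and the single scalar `alpha` are assumptions made for the sketch, not the released implementation.

```python
import torch
import torch.nn.functional as F

def metric_attention(q, k, v, alpha=1.0, causal=True):
    """Distance-based attention: weights decay with the squared L2 distance
    between queries and keys instead of growing with their dot product."""
    # q, k, v: (batch, seq, dim)
    dist2 = torch.cdist(q, k, p=2).pow(2)               # (batch, seq_q, seq_k)
    scores = -alpha * dist2                              # closer keys -> larger score
    if causal:
        seq_q, seq_k = scores.shape[-2:]
        future = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    weights = F.softmax(scores, dim=-1)                  # softmax over -alpha*d^2 == soft-min over distances
    return weights @ v

q = k = v = torch.randn(1, 8, 16)
print(metric_attention(q, k, v).shape)                   # torch.Size([1, 8, 16])
```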
A **token‑wise router** decides, for each token, which head(s) to use, and applies **feature‑gates** (FiLM‑style) and **router‑bias gates** for up/down‑scaling.

The **FFN** is a **HyperFFN** – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a **branch router**. LayerScale and optional DropPath keep training stable.

### Regularisation (optional)

* **Triangle‑inequality (TI) penalty** on sampled triples to encourage true‑metric behaviour.
* **Ball pruning** – each head learns an **origin** \(o_h\) and **radius** \(r_h\); keys outside the ball are masked, giving structured sparsity (both regularisers are sketched below).
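A minimal sketch of both regularisers, assuming a generic distance function, uniform triple sampling, and hard masking. The names `ti_penalty` and `ball_prune_mask`, the penalty weight, and the tensor shapes are illustrative and not taken from the released code, which may combine ball pruning with the soft metric weights differently.

```python
import torch
import torch.nn.functional as F

def ti_penalty(x, dist_fn, num_triples=64):
    """Triangle-inequality penalty on random triples (a, b, c):
    penalise relu(d(a, c) - d(a, b) - d(b, c))."""
    idx = torch.randint(0, x.shape[0], (num_triples, 3))
    a, b, c = x[idx[:, 0]], x[idx[:, 1]], x[idx[:, 2]]
    return F.relu(dist_fn(a, c) - dist_fn(a, b) - dist_fn(b, c)).mean()

def ball_prune_mask(k, origin, radius):
    """True where a key lies inside the learned ball (origin o_h, radius r_h)."""
    return (k - origin).norm(dim=-1) <= radius            # (batch, seq_k) bool

# Usage sketch
tokens = torch.randn(128, 64)                              # flattened token states
l2 = lambda u, v: (u - v).norm(dim=-1)
aux_loss = 0.01 * ti_penalty(tokens, l2)                   # add to the LM loss

k = torch.randn(2, 16, 64)                                 # keys of one head
scores = torch.randn(2, 16, 16)                            # (batch, seq_q, seq_k)
keep = ball_prune_mask(k, origin=torch.zeros(64), radius=torch.tensor(3.0))
# In practice at least the query's own position is always kept.
scores = scores.masked_fill(~keep.unsqueeze(1), float("-inf"))
```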
---

## Architecture diagram (high‑level)

```
Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output
                                  │
                                  ├─ LocalConvHead
                                  ├─ MetricMHAttention
                                  ├─ MetricMQA
                                  └─ ChannelMixHead   (router decides per‑token)

Each Block also contains:
  → HyperFFN (SwiGLU | Conv | Low‑rank)  ← branch router
  → LayerScale + DropPath
```

---

## Configuration (example)

```json
{
  "model_type": "moa_metric",
  "vocab_size": 50257,
  "dim": 768,
  "num_layers": 12,
  "attn_heads": 8,
  "mqa_q_heads": 8,
  "mixer_hidden": 3072,
  "ffn_hidden": 3072,
  "metric": "l2",
  "alpha_init": 1.0,
  "learn_alpha": true,
  "use_balls": true,
  "radius_init": 3.0,
  "learn_radius": true,
  "origin_init_scale": 0.0,
  "maha_init": 1.0,
  "ti_reg_weight": 0.0,
  "ti_reg_samples": 0,
  "router_hidden": 128,
  "router_dropout": 0.1,
  "router_temperature": 1.0,
  "attn_drop": 0.1,
  "proj_drop": 0.1,
  "drop_path": 0.0,
  "max_position_embeddings": 2048,
  "pad_token_id": 50256,
  "bos_token_id": 50256,
  "eos_token_id": 50256
}
```

> **Tip:** `metric` accepts `"l2"`, `"cosine"`, or `"maha_diag"`. If you use the GPT‑2 tokenizer, set `pad_token = eos_token` and make sure `vocab_size` matches the tokenizer (50257).

---

## Quick‑start (inference)

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token  # needed for the GPT‑2 tokenizer
>>> model = AutoModelForCausalLM.from_pretrained(model_id)

>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=False,  # greedy decoding; see the sampled variant below
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

*Note:* The architecture is custom (`moa_metric`), so if loading fails you may need to pass `trust_remote_code=True` to both `from_pretrained` calls. Because the KV‑cache is not implemented, every generation step re‑runs the model over the entire context, so decoding gets progressively slower as the sequence grows.
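Continuing from the quick‑start snippet above, a sampled (non‑greedy) variant looks like this. The decoding values are generic starting points, not settings tuned for this checkpoint.

```python
>>> sampled_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=True,
...     temperature=0.8,
...     top_p=0.95,
...     repetition_penalty=1.1,
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(sampled_ids[0], skip_special_tokens=True))
```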
---

## Training (custom loop sketch)

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
# from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def collate_fn(examples):
    # Tokenise, pad/truncate to 512 tokens, and mask padding out of the loss.
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padded positions
    batch["labels"] = labels
    return batch

# dataset = load_dataset(..., split="train")  # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = AutoModelForCausalLM.from_pretrained("reaperdoesntknow/MoA-100M")
model.train()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)

for batch in loader:
    out = model(**batch)          # forward pass returns the LM loss via `labels`
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.2)
    optimizer.step()
    optimizer.zero_grad()
```

---

## Evaluation checklist

* **Perplexity** on a held‑out split of the two training datasets.
* **Ablation studies** (keep the total token budget constant):
  * L2 vs. cosine vs. diagonal‑Mahalanobis distance.
  * With / without ball pruning.
  * With / without the HyperFFN branch router.
  * With / without the TI regulariser.
* **Speed / memory** comparison against a vanilla GPT‑2‑size model (same `dim`/`layers`).

---

## Efficiency notes

| Feature | What it does |
|---------|--------------|
| **Ball pruning** | Masks keys that lie outside a learned radius → reduces the quadratic attention cost. |
| **Metric MQA** | Shares K/V across heads → fewer projection matrices, lower FLOPs. |
| **HyperFFN branch router** | Token‑wise top‑k routing means only the most useful branch is evaluated per token. |
| **CPU tips** | Set `OMP_NUM_THREADS` / `MKL_NUM_THREADS` to the number of physical cores; use `torch.set_num_threads()` if needed. |

Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference.

---

## Safety, Bias & Risks

* The model **has not been fine‑tuned for safety or alignment**.
* Outputs may contain **biases, profanity, or factual errors**.
* Do **not** deploy it in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning.

---

## License

Apache‑2.0 – see the `LICENSE` file in the repository.

---

## Citation

```bibtex
@misc{moametriclm185m,
  title  = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
  author = {Colca, Roy Shawn and collaborators},
  year   = {2025},
  url    = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}
```
---

## Changelog

| Version | Date | Notes |
|---------|------|-------|
| **v0.2** | 2025‑09‑20 | 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30. |
| **v0.1** | 2025‑09‑20 | Initial public release: metric heads, MQA, ball pruning, HyperFFN, router & gates; HF‑compatible; no KV cache. |

---

## Maintainers

* **Author:** reaper (Convergent Intelligence LLC)
* **Contact:** convergentintelligencenyc@gmail.com

---

## Special Remarks

- This model is still in an extremely experimental state, as most of mine are, but I'm working on stabilizing this one for general inference.
- I design, create, and train all of my models using my mathematical research and pure disgust for the dot product!
- For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!