---
license: mit
language:
- en
pipeline_tag: text-generation
tags:
- bitnet
- quantization
- early-exit
- layer-skipping
- efficient-transformers
datasets:
- roneneldan/TinyStories
---

# llama3-earlyexit

Llama3-style baseline with full-precision weights and activations.

## Model Description

This model implements a 24-layer transformer trained with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.

## Architecture Details

- **Layers**: 24
- **Hidden dimension**: 2048
- **Attention heads**: 32 (64-dimensional each)
- **Key-Value heads**: 8 (Grouped Query Attention, 4:1 query-to-KV ratio)
- **FFN intermediate size**: 4096
- **Position embeddings**: Rotary Position Embeddings (RoPE)
- **Normalization**: RMSNorm
- **Activation**: SwiGLU (in the MLP)
- **Parameters**: ~1.06B

### Quantization Scheme

- **Weights**: Full precision (FP32)
- **Activations**: Full precision (FP32)
- **Hadamard transform**: No

## Training Details

### Dataset

- **Source**: TinyStories (2.1M stories)
- **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
- **Sequence length**: 512 tokens

### Training Techniques

**Quadratic Layer Dropout:**
- Progressive dropout rate: p_l = 0.5 × (l/L)²
- Normalized so that Σ p_l = 1.0
- The final layer is never dropped
- Encourages earlier layers to produce accurate predictions on their own

**Early Exit Loss:**
- All layers share the same LM head
- Total loss = main_loss + 0.3 × early_exit_loss
- Layer-proportional weighting: w_i = (i+1)/L
- Enables flexible early exit at inference time

### Hyperparameters

- **Optimizer**: AdamW
- **Learning rate**: 6e-4
- **Warmup steps**: 1000
- **Batch size**: 16 (effective: 64)
- **Training steps**: 50,000
- **Gradient clipping**: 1.0

## Performance

### Perplexity (TinyStories validation)

| Exit Layer | Perplexity | Speed (tok/s) |
|------------|------------|---------------|
| All layers | TBD        | TBD           |
| Layer 18   | TBD        | TBD           |
| Layer 12   | TBD        | TBD           |
| Layer 6    | TBD        | TBD           |

### Training Stability

- **Gradient norms**: TBD
- **Final loss**: TBD

## Usage

### Installation

```bash
pip install transformers torch
```

### Basic Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model
model = AutoModelForCausalLM.from_pretrained("your-username/llama3-earlyexit")
tokenizer = AutoTokenizer.from_pretrained("your-username/llama3-earlyexit")

# Generate text
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```

### Early Exit Inference

```python
# Exit at layer 12 for faster inference
model.set_exit_layer(12)
outputs = model.generate(**inputs, max_length=100)
# 1.5-2x faster with minimal quality loss
```

### Benchmark Different Exit Layers

```python
for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    outputs = model.generate(**inputs, max_length=100)
    print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
```
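To fill in the perplexity/speed table above, a loop like the following can be used. This is a minimal sketch, assuming the `set_exit_layer` API shown earlier, the standard `transformers` causal-LM interface, and the `datasets` library; the 200-story validation slice, the 512-token truncation, and the `perplexity` / `tokens_per_second` helpers are illustrative choices, not part of the released code.

```python
import math
import time

import torch
from datasets import load_dataset

# Illustrative benchmarking sketch (not part of the released code).
# Reuses `model` and `tokenizer` from the examples above and the
# `set_exit_layer` API shown in "Early Exit Inference".
val_texts = load_dataset("roneneldan/TinyStories", split="validation[:200]")["text"]

@torch.no_grad()
def perplexity(model, tokenizer, texts, max_length=512):
    """Token-weighted perplexity over full (truncated) stories."""
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        out = model(**enc, labels=enc["input_ids"])
        n = enc["input_ids"].size(1) - 1  # number of predicted tokens
        total_nll += out.loss.item() * n
        total_tokens += n
    return math.exp(total_nll / total_tokens)

@torch.no_grad()
def tokens_per_second(model, tokenizer, prompt="Once upon a time", new_tokens=100):
    """Greedy generation speed for a single short prompt."""
    inputs = tokenizer(prompt, return_tensors="pt")
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=new_tokens, do_sample=False)
    return new_tokens / (time.perf_counter() - start)

model.eval()
for exit_layer in [6, 12, 18, 24]:
    model.set_exit_layer(exit_layer)
    print(f"Exit layer {exit_layer:2d}: "
          f"ppl={perplexity(model, tokenizer, val_texts):.2f}, "
          f"{tokens_per_second(model, tokenizer):.1f} tok/s")
```

Numbers reported in the table should be computed over the full validation split; the small slice here only keeps the example fast.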
## Limitations

- **Inference speed**: Quantized variants use fake quantization (QAT) without specialized kernels, resulting in slower inference than full precision despite the lower bit-width
- **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110), requiring careful hyperparameter tuning
- **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning

## Citation

If you use this model, please cite:

```bibtex
@article{bitnet,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}

@article{layerskip,
  title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
  author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
  journal={arXiv preprint arXiv:2404.16710},
  year={2024}
}
```

## License

MIT License

## Contact

For questions or issues, please open an issue on the model repository.