Abstract
The Llama-3.1-8B-Instruct-w16a8-mxtw model is a Turkish legal instruction-tuned variant of Llama-3.1-8B-Instruct, trained using a custom Float8 configuration designed to push the limits of FP8 efficiency under FSDP2. Unlike the default tensorwise or rowwise FP8 recipes, this model uses a manually constructed cast configuration where all input, weight, and gradient-output tensors are quantized using TENSORWISE dynamic scaling with FP8-E4M3 and FP8-E5M2 formats.
During training, the master weights were kept in BF16 for update stability, while compute tensors were cast on the fly: inputs and weights to FP8-E4M3 and gradient outputs to FP8-E5M2. This configuration eliminates row-level scaling overhead and aggressively compresses the compute tensors, enabling significantly higher throughput on H100 GPUs.
Trained on the newmindai/EuroHPC-Legal dataset (multi-domain Q/A in Turkish law), this custom FP8 variant delivered the largest performance gain among all recipes, reaching a ~116.82% speed improvement over the BF16 baseline while maintaining stable loss behavior comparable to other FP8 runs. The result highlights how customized Float8 strategies can unlock extreme efficiency—beyond standard tensorwise or rowwise casting—without degrading convergence.
Experiment Context
This model was trained with a custom Float8 cast configuration: the `scaling_granularity` was set to TENSORWISE for each of the input, weight, and gradient-output cast configs, with the target dtype set to `float8_e4m3fn` for inputs and weights and `float8_e5m2` for gradient outputs. E4M3 denotes an 8-bit floating-point layout with 1 sign bit, 4 exponent bits, and 3 mantissa bits; E4M3FN is the same layout restricted to finite numbers (the FN suffix), which drops the infinity encodings in exchange for a slightly larger finite range and is the variant PyTorch exposes as `torch.float8_e4m3fn`. E5M2, used for the gradient outputs, allocates 5 bits to the exponent and 2 to the mantissa, trading precision for the wider dynamic range that gradients typically require.
```python
import torch
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    convert_to_float8_training,
)

# One dynamic (tensorwise) scale per tensor for inputs, weights, and grad outputs.
cast_config_input = CastConfig(
    scaling_granularity=ScalingGranularity.TENSORWISE,
    target_dtype=torch.float8_e4m3fn)
cast_config_weight = CastConfig(
    scaling_granularity=ScalingGranularity.TENSORWISE,
    target_dtype=torch.float8_e4m3fn)
cast_config_grad_output = CastConfig(
    scaling_granularity=ScalingGranularity.TENSORWISE,
    target_dtype=torch.float8_e5m2)

config = Float8LinearConfig(
    cast_config_input=cast_config_input,
    cast_config_weight=cast_config_weight,
    cast_config_grad_output=cast_config_grad_output,
    pad_inner_dim=True,
    enable_fsdp_float8_all_gather=True,
    # `use_axiswise` is a flag defined elsewhere in the training script (not shown here).
    round_scales_to_power_of_2=use_axiswise)

# `model` is the loaded base model; its nn.Linear layers are swapped for Float8Linear in place.
model = convert_to_float8_training(model, config=config)
```
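To make the dtype and granularity choices concrete, the following standalone sketch (illustrative only, not part of the training code) inspects the dynamic range of the two FP8 formats and mimics what a single tensorwise scale does; torchao performs the equivalent casts inside each converted `Float8Linear` layer.

```python
import torch

# Dynamic range of the two FP8 formats used in this recipe:
# float8_e4m3fn -> 4 exponent bits, 3 mantissa bits, finite numbers only (max 448)
# float8_e5m2   -> 5 exponent bits, 2 mantissa bits, wider range (max 57344)
print(torch.finfo(torch.float8_e4m3fn))
print(torch.finfo(torch.float8_e5m2))

# Tensorwise dynamic scaling in a nutshell: one scalar scale per tensor,
# chosen so the tensor's absolute maximum maps to the format's maximum.
x = torch.randn(4096, 4096, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale = fp8_max / x.abs().max().float()              # single scale for the whole tensor
x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)  # quantize
x_deq = x_fp8.float() / scale                        # dequantize
print("max abs quantization error:", (x.float() - x_deq).abs().max().item())
```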
Base Model Technical Specifications
- Parameters: 8 Billion
- Architecture Family: Llama 3.1
- Maximum Position Embeddings: 131,072
- Attention Heads: 32 (`num_attention_heads`)
- Key-Value Heads: 8 (`num_key_value_heads`)
- Hidden Layers: 32 (`num_hidden_layers`)
- Hidden Size: 4,096 (`hidden_size`)
- Intermediate Size: 14,336
- Vocabulary Size: 128,256
- Precision: bfloat16
- RoPE Scaling: type `llama3`, factor = 8.0
- RMS Norm Epsilon: 1e-05
- Activation: SiLU
Training Methodology
Training Configuration
- Model: `meta-llama/Llama-3.1-8B-Instruct`
- Sequence Length: 4,096 (`seq_len`)
- Epochs: 1
- Per-Device Micro Batch Size: 2
- Gradient Accumulation: 4
- GPUs: 4 (via `CUDA_VISIBLE_DEVICES=0,1,2,3`)
- dtype: `bf16` (`fp8=false`)
- Weights: bfloat16
- Activations: bfloat16
- Optimizer: AdamW
- Learning Rate: 2e-5
- Weight Decay: 0.01
- Betas: (0.9, 0.95)
- Epsilon: 1e-8
- LR Scheduler: cosine with 10% warmup (`warmup_ratio=0.1`, `warmup_steps=100`); see the optimizer/scheduler sketch after this list
- Max Grad Norm: 1.0
- Gradient Checkpointing: Enabled
- Evaluation: every 5 steps (`eval_steps=5`, `eval_samples=1000`)
- Checkpointing: every 10 steps; keep last 5; select best by `eval_loss`
- Logging: every step to file; Weights & Biases in offline mode
- Seed: 100
- Distributed Training: `torch.distributed.run` (single node, multi-GPU)
- FSDP2 (Fully Sharded Data Parallel, v2 API)
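A minimal sketch of how these optimizer and scheduler hyperparameters fit together, assuming `torch.optim.AdamW` and the `get_cosine_schedule_with_warmup` helper from `transformers` (the actual training loop is not reproduced here; `num_training_steps` is a placeholder):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

def build_optimizer_and_scheduler(model, num_training_steps: int):
    # Hyperparameters from the training configuration above.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01,
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * num_training_steps),  # warmup_ratio=0.1
        num_training_steps=num_training_steps,
    )
    return optimizer, scheduler

# During each step, gradients are clipped before the optimizer update:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```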
Setups
- Precision: half-precision bfloat16 used as the data type and for computation.
- Hardware: HPC (EuroHPC/BSC-class) node with 4 × NVIDIA H100 GPUs.
- Framework: PyTorch with `torchrun` for distributed training (see the FSDP2 sharding sketch after this list).
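A rough sketch of how the float8-converted model could be sharded with the FSDP2 `fully_shard` API and launched via `torchrun` (an illustrative outline, not the actual training script; `model.model.layers` assumes the Hugging Face Llama module layout and `train.py` is a placeholder name):

```python
# Launched under torchrun so a process group exists, e.g.:
#   CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

def shard_with_fsdp2(model):
    # Parameters stay in bfloat16; FP8 casts happen inside the converted
    # Float8Linear modules during forward/backward.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    # FSDP2 style: shard each transformer block, then the root module.
    for block in model.model.layers:  # Hugging Face Llama layout assumed
        fully_shard(block, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    return model
```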
Dependencies
| Package | Version |
|---|---|
| transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
Job Details
| Model | Job ID | Runtime (mins) | Nodes | GPUs | Node-hours | GPU-hours | Micro-batch | Batch size | Gradient accumulation | Total batch size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a8_rw | 31768103 | 115.75 | 1 | 4 | 1.929 | 7.716 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 31837629 | 109.00 | 1 | 4 | 1.816 | 7.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 31768031 | 64.00 | 4 | 4 | 1.066 | 4.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-tw | 31768074 | 138.75 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 31768093 | 123.75 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 8 | 8 | 1024 |
Training Time Analysis
| Model | Training Time (mins) | Memory Allocated (avg %) | GPU Utilization (avg %) | Speed vs bf16 |
|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a16 | 138.75267 | 74.4189 | 56.6059 | baseline |
| Llama-3.1-8B-Instruct-w16a8-tw | 123.75267 | 68.8982 | 97.5364 | 12.11% |
| Llama-3.1-8B-Instruct_w16a8_rw | 115.75364 | 69.6132 | 97.7689 | 19.87% |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 109.00364 | 69.4806 | 97.3312 | 27.33% |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 64.00328 | 68.8982 | 95.5661 | 116.82% |
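The "Speed vs bf16" column appears to report the relative reduction in wall-clock training time against the BF16 baseline: for the custom tensorwise (mxtw) run, 138.75 / 64.00 − 1 ≈ 1.17, i.e. roughly a 117% improvement, in line with the reported ~116.82%.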
Performance Evaluation
Two models trained on one node with FP8 recipes
| Loss metric results for w16a16 tensorwise & w16a8 custom recipe | Memory allocation for w16a16 tensorwise & w16a8 custom recipe | Utilization for w16a16 tensorwise & w16a8 custom recipe |
|---|---|---|
|  |  |  |

| Loss metric results for w16a8 recipes | Memory allocation for w16a8 recipes | Utilization for w16a8 recipes |
|---|---|---|
|  |  |  |
| Model | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) | ± Std (val) |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a8-rw | 8 | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 8 | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 8 | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 |
| Llama-3.1-8B-Instruct-w16a16-tw | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 64 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 64 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
Implementation
GPU & Memory Usage Profiling
Training was profiled with the PyTorch profiler (`torch.profiler`).
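A minimal sketch of how such profiling can be wired into the training loop with `torch.profiler` (illustrative only; `dataloader` and `training_step` stand in for the actual data loader and a single forward/backward/optimizer step):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Profile a handful of steps: skip 1, warm up 1, record 3.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),
    profile_memory=True,   # track allocator usage per op
    record_shapes=True,
) as prof:
    for step, batch in enumerate(dataloader):
        training_step(batch)   # forward + backward + optimizer step
        prof.step()
        if step >= 5:
            break

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```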
Usage
Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 or float16 as shown below:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-mxtw"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

# Turkish legal Q/A prompt ("Question: Under the Personal Data Protection Law,
# in which cases is explicit consent not required? Answer:")
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )
print(tok.decode(out[0], skip_special_tokens=True))
```
Ethical Considerations and Disclaimers
- Research & development purposes only; not a substitute for professional legal counsel.
- Users must ensure compliance with data protection and sector regulations.
- Potential biases may exist in domain data and model outputs.
Model & Data Card Metadata
- Total Parameters: 8,030,261,248
- Serialized Size (approx.): 16,060,522,496 bytes
- Config precision: bfloat16
- RoPE: llama3 scaling, factor 8.0
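The serialized size is consistent with bfloat16 storage at 2 bytes per parameter: 8,030,261,248 × 2 = 16,060,522,496 bytes.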
References and Citations
Base Model
```bibtex
@misc{meta_llama31_8b_instruct,
  title        = {Llama 3.1 8B Instruct},
  author       = {Meta AI},
  year         = {2024},
  howpublished = {\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```
Training Dataset
```bibtex
@misc{euro_hpc_legal,
  title        = {EuroHPC-Legal},
  author       = {newmindai},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```