Contributing Custom Models for Training

This guide explains how to add custom model implementations to the optimum/neuron/models/training/ directory. Custom models are needed to support distributed training features like tensor parallelism, pipeline parallelism, and sequence parallelism on AWS Trainium devices.

Architecture Components

1. NeuronModelMixin

The NeuronModelMixin class provides the core functionality shared by all custom training models: checkpoint loading and saving that applies the weight transformation specs described below, plus the class-level hooks used for pipeline parallelism (SUPPORTS_PIPELINE_PARALLELISM, PIPELINE_TRANSFORMER_LAYER_CLS, PIPELINE_INPUT_NAMES).

2. Weight Transformation Specs

Transformation specs describe how to convert weights between the original Transformers checkpoint layout and the parallel layout used by the custom Neuron implementation. They are applied in both directions: when loading pretrained weights into the parallel model, and when consolidating sharded training checkpoints back into the original format.

Key transformation spec types:

- FusedLinearsSpec: maps several separate linear layers (for example gate_proj and up_proj) to a single fused parallel linear, and back
- GQAQKVColumnParallelLinearSpec: maps separate q_proj, k_proj, and v_proj weights to a GQAQKVColumnParallelLinear projection, and back

For complete API documentation of all transformation specs and utility functions, see the Model Weight Transformation Specs API Reference.
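
As a rough illustration of what FusedLinearsSpec describes, the fused parameter is just the original parameters concatenated along the fused axis; loading goes one way, consolidation goes the other. The snippet below is a conceptual sketch in plain PyTorch with made-up names and sizes; the actual conversion is performed by the spec classes, which also account for tensor parallel sharding:

import torch

hidden_size, intermediate_size = 16, 64

# The original checkpoint holds two separate linear weights...
gate_proj_weight = torch.randn(intermediate_size, hidden_size)
up_proj_weight = torch.randn(intermediate_size, hidden_size)

# ...while the custom model holds a single fused weight, concatenated along the
# output ("column") dimension.
gate_up_proj_weight = torch.cat([gate_proj_weight, up_proj_weight], dim=0)

# Consolidation direction: split the fused weight back into the original two.
recovered_gate, recovered_up = torch.split(gate_up_proj_weight, intermediate_size, dim=0)
assert torch.equal(recovered_gate, gate_proj_weight)
assert torch.equal(recovered_up, up_proj_weight)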

3. Parallel Layers

Use these parallel layers from neuronx_distributed:

- ParallelEmbedding for the token embeddings
- ColumnParallelLinear for projections whose output dimension is sharded across tensor parallel ranks
- RowParallelLinear for projections whose input dimension is sharded across tensor parallel ranks
- GQAQKVColumnParallelLinear for grouped-query attention QKV projections when the key/value heads cannot be evenly distributed across ranks
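
As a minimal sketch of the pairing convention used throughout the examples below, a column-parallel layer keeps its output sharded with gather_output=False so that the following row-parallel layer can consume it directly with input_is_parallel=True, avoiding an extra collective in between. This assumes torch.distributed and the neuronx_distributed model parallel state have already been initialized:

from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear

hidden_size, intermediate_size = 1024, 4096

# Output dimension is sharded across TP ranks; keep the activation sharded.
up_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False, gather_output=False)

# Input dimension is sharded across TP ranks; consume the sharded activation directly
# and all-reduce the partial results internally.
down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True)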

Implementation Steps

Step 1: Create Model Structure

Create a new directory: optimum/neuron/models/training/your_model/

__init__.py

from .modeling_your_model import YourModelForCausalLM, YourModel

__all__ = ["YourModelForCausalLM", "YourModel"]

Step 2: Implement the Model Building Blocks

modeling_your_model.py

Imports and Dependencies

import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    ParallelEmbedding,
)
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig

from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
    CustomModule,
    FusedLinearsSpec,
    GQAQKVColumnParallelLinearSpec,
    ModelWeightTransformationSpecs,
)

Embedding Layer

class YourModelEmbeddings(nn.Module):
    def __init__(self, config, trn_config):
        super().__init__()
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
        )

MLP Layer with Fused Linears

Important: Any module that has transformation specs must inherit from CustomModule to ensure proper handling of weight transformations, and the transformation specs must be defined in the self.specs attribute.

class YourModelMLP(nn.Module, CustomModule):
    def __init__(self, config, trn_config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        # Fused gate and up projections
        self.gate_up_proj = ColumnParallelLinear(
            self.hidden_size,
            2 * self.intermediate_size,
            stride=2,  # Important for proper sharding
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # Define transformation specs
        self.specs = ModelWeightTransformationSpecs()
        self.specs.add_spec(
            FusedLinearsSpec(
                fused_linear_name="gate_up_proj",
                linear_names=["gate_proj", "up_proj"],
                bias=False,
                fuse_axis="column",  # Fuse along output dimension
                original_dims=[self.intermediate_size, self.intermediate_size],
            )
        )
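
A forward pass for this MLP could look like the following sketch. It assumes a SwiGLU-style MLP (as in Llama-like models); because the two projections are fused along the output dimension, each rank's local output is split back into its gate and up halves before applying the activation:

    def forward(self, hidden_states):
        # The local output of gate_up_proj is [gate_local, up_local] concatenated
        # along the last dimension, so split it back into two halves.
        gate_states, up_states = torch.chunk(self.gate_up_proj(hidden_states), 2, dim=-1)
        # SwiGLU-style activation (assumed); adapt to your model's activation function.
        return self.down_proj(torch.nn.functional.silu(gate_states) * up_states)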

Attention Layer

The attention layer implementation depends on the model’s architecture and tensor parallel configuration. There are three main variants:

1. Separate Q, K, V Projections (Default)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        
        # Separate projections for Q, K, V
        self.q_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.k_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.v_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # No transformation specs needed - regular parallel layers
        self.specs = ModelWeightTransformationSpecs()

2. Fused QKV Projection (Multi-Head Attention)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Only use fused QKV when num_heads == num_key_value_heads (no GQA)
        if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
            self.qkv_proj = ColumnParallelLinear(
                config.hidden_size,
                3 * self.num_heads * self.head_dim,  # Q + K + V
                stride=3,  # Important for proper sharding
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for fused QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                FusedLinearsSpec(
                    fused_linear_name="qkv_proj",
                    linear_names=["q_proj", "k_proj", "v_proj"],
                    bias=False,
                    fuse_axis="column",
                    original_dims=[self.num_heads * self.head_dim] * 3,
                )
            )
            self.split_size = self.num_heads * self.head_dim // tp_size
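
In the forward pass, the fused projection's output can then be split back into query, key, and value states. A minimal sketch, assuming each rank's local output is laid out as [q_local, k_local, v_local] as set up above:

        # Inside forward(), when the fused projection is used (sketch):
        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.split_size, dim=-1)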

3. GQA QKV Projection (Required for Challenging TP Configurations)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Use GQA QKV when KV heads can't be evenly distributed across TP ranks
        # This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
        self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
        
        if self.qkv_linear:
            # Calculate KV size multiplier to ensure even distribution
            if trn_config.kv_size_multiplier is None:
                self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
            else:
                self.kv_size_multiplier = trn_config.kv_size_multiplier
                
            self.qkv_proj = GQAQKVColumnParallelLinear(
                config.hidden_size,
                [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                kv_size_multiplier=self.kv_size_multiplier,
                fuse_qkv=trn_config.fuse_qkv,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for GQA QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                GQAQKVColumnParallelLinearSpec(
                    gqa_qkv_projection_name="qkv_proj",
                    query_projection_name="q_proj",
                    key_projection_name="k_proj", 
                    value_projection_name="v_proj",
                    output_projection_name="o_proj",
                    num_attention_heads=self.num_heads,
                    num_key_value_heads=self.num_key_value_heads,
                    kv_size_multiplier=self.kv_size_multiplier,
                    q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
                    kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
                    fuse_qkv=trn_config.fuse_qkv,
                )
            )
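
In the forward pass, a minimal sketch assuming GQAQKVColumnParallelLinear returns the query, key, and value projections as separate tensors (check the neuronx_distributed documentation for the exact return signature):

        # Inside forward(), when self.qkv_linear is True (sketch):
        query_states, key_states, value_states = self.qkv_proj(hidden_states)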

When to Use Each Variant:

The choice is typically determined by:

tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv

Step 3: Implement Main Model Classes

Base Model

class YourPreTrainedModel(PreTrainedModel):
    config_class = YourModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YourModelDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True


class YourModel(NeuronModelMixin, YourPreTrainedModel):
    def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
        YourPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.trn_config = trn_config
        
        self.embed_tokens = ParallelEmbedding(...)
        self.layers = nn.ModuleList([
            YourModelDecoderLayer(config, trn_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = YourModelRMSNorm(...)
        
        self.post_init()

CausalLM Model

class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    
    # Pipeline parallelism support
    SUPPORTS_PIPELINE_PARALLELISM = True
    PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
    PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
    
    def __init__(self, config, trn_config):
        super().__init__(config)
        self.trn_config = trn_config
        self.model = YourModel(config, trn_config)
        self.vocab_size = config.vocab_size
        
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
            dtype=config.torch_dtype,
        )
        
        self.post_init()

Step 4: Register Model

Update optimum/neuron/models/training/__init__.py:

from .your_model import YourModelForCausalLM, YourModel

__all__ = [..., "YourModelForCausalLM", "YourModel"]

Update optimum/neuron/models/training/auto_models.py:

from .your_model.modeling_your_model import YourModelForCausalLM, YourModel

# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)

# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)

Here "your_model" is the corresponding to the model_type attribute of your model’s configuration class.

Best Practices

1. Parallel Layer Configuration

2. Weight Transformation Specs

3. Pipeline Parallelism

4. Flash Attention Support

Testing Your Implementation

The training tests in tests/training/ provide a comprehensive testing framework that validates numerical correctness, distributed training scenarios, and checkpoint compatibility. Most of the tests are not designed to be run on every custom modeling implementation, but rather to validate the core functionality of the Optimum Neuron training infrastructure. With that in mind, here’s what you need to implement for your custom modeling:

1. Custom Modeling Validation

The test_custom_modeling.py file validates that your custom implementation produces identical outputs to the original Transformers model:

Update tests/training/test_custom_modeling.py:

CUSTOM_MODELINGS_TO_TEST = [
    # ... existing models ...
    ("YourModelForCausalLM", "your-org/your-model-name"),
]

Important: For custom modeling validation tests, use small/tiny models to ensure CI efficiency. The test models should have:

- a small hidden size and only a few layers, so the tests run quickly on CI
- several key/value heads, so that grouped-query attention and the GQA QKV code paths are exercised

Examples of good test models for custom modeling validation (tiny, randomly initialized checkpoints of the kind already used in the test suite):

- michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random
- michaelbenayoun/granite-tiny-4kv-heads-4layers-random
- michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random

Key test your model must pass:

def test_custom_modeling_matches_original()  # Output matching

2. End-to-End Training Validation

The test_overfit.py file validates training convergence. To include your model in end-to-end training validation, you must add it to the parametrized test cases:

Update tests/training/test_overfit.py:

@pytest.mark.parametrize(
    "model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
    [
        # ... existing models ...
        [
            "YourModelForCausalLM",
            "your-org/your-model-name",
            1e-4,
            0.03,
            {},
            True,
            0.5,
            2048,
            50,
        ],
    ],
    ids=[
        # ... existing model IDs ...
        "your-org/your-model-name",
    ],
)

This validates that your model can actually learn: with the given hyperparameters, the training loss must drop below max_expected_loss within num_steps steps when overfitting a small dataset.

Your model will be tested in:

def test_overfit_custom_modeling_causal_lm()       # Basic training (your model included)

3. Auto Model Loading

The test_modeling_auto.py file validates that your model can be loaded using the NeuronModel and NeuronModelForCausalLM auto classes. To include your model in these tests, you must add it to the test cases:

Update tests/training/test_modeling_auto.py:

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", 
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random", 
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

This validates that your model can be instantiated and loaded through the NeuronModel and NeuronModelForCausalLM auto classes, both from a config and from pretrained weights.

4. Running Tests

Tests require AWS Trainium instances. Run specific test categories:

# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v

# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"

# Run end-to-end training validation
pytest tests/training/test_overfit.py -v

5. Test Requirements

Your implementation must:

  1. Pass numerical correctness tests against original Transformers implementation
  2. Support parallelization strategies (at minimum DP and TP; PP support recommended)
  3. Handle various QKV implementations (regular, fused, GQA)
  4. Support checkpoint consolidation for distributed training (see the consolidation sketch after this list)
  5. Support LoRA training if applicable
  6. Demonstrate convergence through overfitting tests
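
Checkpoint consolidation (requirement 4) merges the sharded distributed checkpoints back into a standard checkpoint; it can be exercised with the optimum-cli neuron consolidate command, sketched below (check optimum-cli neuron consolidate --help for the exact arguments):

# Merge sharded training checkpoints back into a single standard checkpoint (sketch)
optimum-cli neuron consolidate path/to/training_checkpoints path/to/consolidated_output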

The testing framework ensures your custom model maintains compatibility with the existing Optimum Neuron training infrastructure while delivering expected performance and correctness guarantees.

Common Issues

Additional Resources

This guide provides the foundation for implementing custom models. For complete examples and advanced patterns, reference the existing implementations under optimum/neuron/models/training/, such as the Llama, Granite, and Qwen3 models used throughout the test suite.

Key files to study:

- optimum/neuron/models/training/modeling_utils.py: the NeuronModelMixin implementation
- optimum/neuron/models/training/transformations_utils.py: the weight transformation spec classes
- optimum/neuron/models/training/config.py: TrainingNeuronConfig
- optimum/neuron/models/training/auto_models.py: model registration for the auto classes