# Kernels

## Docs

- [Kernels](https://huggingface.co/docs/kernels/main/index.md)
- [Locking kernel/layer versions](https://huggingface.co/docs/kernels/main/locking.md)
- [FAQ](https://huggingface.co/docs/kernels/main/faq.md)
- [Layers](https://huggingface.co/docs/kernels/main/layers.md)
- [Environment variables](https://huggingface.co/docs/kernels/main/env.md)
- [Installation](https://huggingface.co/docs/kernels/main/installation.md)
- [Basic Usage](https://huggingface.co/docs/kernels/main/basic-usage.md)
- [Kernels CLI Reference](https://huggingface.co/docs/kernels/main/cli.md)
- [Kernel requirements](https://huggingface.co/docs/kernels/main/kernel-requirements.md)
- [Layers API Reference](https://huggingface.co/docs/kernels/main/api/layers.md)
- [Kernels API Reference](https://huggingface.co/docs/kernels/main/api/kernels.md)

### Kernels
https://huggingface.co/docs/kernels/main/index.md

# Kernels

<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>

The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:

- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
  same Python process.
- **Compatible**: kernels must support all recent versions of Python and
  the different PyTorch build configurations (various CUDA versions
  and C++ ABIs). Furthermore, older C library versions must be supported.

You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.



### Locking kernel/layer versions
https://huggingface.co/docs/kernels/main/locking.md

# Locking kernel/layer versions

Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
sure that `kernels` is a build dependency:

```toml
[build-system]
requires = ["kernels", "setuptools"]
build-backend = "setuptools.build_meta"

[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.1"
```

Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:

```python
from kernels import get_locked_kernel

activation = get_locked_kernel("kernels-community/activation")
```

**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.

## Locked kernel layers

Locking is also supported for kernel layers. To use locked layers, register them
with the `LockedLayerRepository` class:

```python
from kernels import LockedLayerRepository, register_kernel_mapping

kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LockedLayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

register_kernel_mapping(kernel_layer_mapping)
```

## Pre-downloading locked kernels

Locked kernels can be pre-downloaded by running `kernels download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.

The pre-downloaded kernels are used by the `get_locked_kernel` function.
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
want kernel loading to error when a kernel is not pre-downloaded, you can use
the `load_kernel` function instead:

```python
from kernels import load_kernel

activation = load_kernel("kernels-community/activation")
```



### FAQ
https://huggingface.co/docs/kernels/main/faq.md

# FAQ

## Kernel layers

### Why is the kernelization step needed as a separate step?

In earlier versions of `kernels`, a layer's `forward` method was replaced
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
The new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.

To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
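
The difference can be sketched in plain Python (the names here are illustrative, not the actual `kernels` internals): runtime branching re-checks state on every call, while `kernelize`-style dispatch binds the chosen `forward` once, ahead of time.

```python
class Layer:
    def forward(self, x):
        return x + 1  # original implementation

def fast_forward(self, x):
    return x + 1  # stands in for an optimized Hub kernel

# Data-dependent branching (incompatible with torch.compile):
# every call inspects state such as the training flag.
def branching_forward(self, x, training):
    if training:
        return Layer.forward(self, x)
    return fast_forward(self, x)

# Ahead-of-time dispatch (what kernelize does, conceptually):
# decide once, then the bound forward is branch-free.
def kernelize_sketch(layer, training):
    chosen = Layer.forward if training else fast_forward
    layer.forward = chosen.__get__(layer, type(layer))
    return layer

layer = kernelize_sketch(Layer(), training=False)
```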

### Why does kernelization only replace `forward` methods?

There are some other possible approaches. The first is to completely
replace existing layers by kernel layers. However, since this would
permit free-form layer classes, it would be much harder to validate
that layers are fully compatible with the layers that they are
replacing. For instance, they could have completely different member
variables. Besides that, we would also need to hold on to the original
layers, in case we need to revert to the base layers when the model
is `kernelize`d again with different options.

A second approach would be to make an auxiliary layer that wraps the
original layer and the kernel layer and dispatches to the kernel layer.
This wouldn't have the issues of the first approach, because kernel layers
could remain as strict as they are now, and we would still have access
to the original layers when `kernelize`-ing the model again. However,
this would change the graph structure of the model and would break use
cases where programs access the model internals (e.g.
`model.layers[0].attention.query_weight`) or rely on the graph structure
in other ways.

The approach of `forward`-replacement is the least invasive, because
it preserves the original model graph. It is also reversible, since
even though the `forward` of a layer _instance_ might be replaced,
the corresponding class still has the original `forward`.

## Misc

### How can I disable kernel reporting in the user-agent?

By default, we collect telemetry when a call to `get_kernel()` is made.
This only includes the `kernels` version, `torch` version, and the build
information for the kernel being requested.

You can disable this by setting `export DISABLE_TELEMETRY=yes`.



### Layers
https://huggingface.co/docs/kernels/main/layers.md

# Layers

A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.

See [Kernel requirements](kernel-requirements.md) for more information on the
requirements of Hub layers.

## Making a layer extensible with kernels from the hub

### Using a decorator

A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from kernels import use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```

The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
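
For intuition, the same computation on a flat Python list, where silu(x) = x · sigmoid(x) (a reference sketch only; the decorated class above is what `kernels` actually dispatches on):

```python
import math

def silu_and_mul(x):
    # Split the last dimension in half: silu(first half) * (second half).
    d = len(x) // 2
    def silu(v):
        return v / (1.0 + math.exp(-v))  # v * sigmoid(v)
    return [silu(a) * b for a, b in zip(x[:d], x[d:])]

print(silu_and_mul([1.0, 2.0, 3.0, 4.0]))
```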

### External layers

An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:

```python
from kernels import replace_kernel_forward_from_hub
from somelibrary import SiluAndMul

replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
```

**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.

## Kernelizing a model

A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:

```python
from kernels import Mode, kernelize

model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```

The `kernelize` function modifies the model in place; the model itself is
returned as a convenience. The `mode` argument specifies that the model
will be used for inference. Similarly, you can ask `kernelize` to prepare
the model for training:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```

A model that is kernelized for training can also be used for inference, but
not the other way around. If you want to change the mode of the kernelized
model, you can just run `kernelize` on the model again with the new mode.

If you want to compile a model with `torch.compile`, this should be indicated
in the mode as well. You can do this by combining `Mode.INFERENCE` or
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:

```python
model = MyModel(...)

# Inference
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)

# Training
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```
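
The `|` combination works because modes are flag-like values. A minimal sketch of the idea with the standard library (the real `Mode` is defined in `kernels`; this local `Mode` is only illustrative):

```python
from enum import Flag, auto

class Mode(Flag):
    INFERENCE = auto()
    TRAINING = auto()
    TORCH_COMPILE = auto()

# Combine modes with set union; test membership with "in".
mode = Mode.TRAINING | Mode.TORCH_COMPILE
assert Mode.TORCH_COMPILE in mode
assert Mode.INFERENCE not in mode
```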

### Kernel device

Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):

```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```

### Fallback `forward`

If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized, layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```

This can be useful if you want to guarantee that Hub kernels are used.

### Inspecting which kernels are used

The kernels that are used are logged at the `INFO` level by `kernelize`.
See the [Python logging](https://docs.python.org/3/library/logging.html)
documentation for information on how to configure logging.

## Registering a hub kernel for a layer

`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:

```python
from kernels import LayerRepository

kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        "rocm": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}
```

You can register such a mapping using `register_kernel_mapping`:

```python
register_kernel_mapping(kernel_layer_mapping)
```

This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:

```python
with use_kernel_mapping(kernel_layer_mapping):
    # Use the layer for which the mapping is applied.
    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

This ensures that the mapping is not active anymore outside the
`with`-scope.

### Using version bounds

Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version of the kernel to download using Python version
specifiers:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            version=">=0.0.4,<0.1.0",
        ),
        "rocm": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            version=">=0.0.4,<0.1.0",
        )
    }
}
```

This will get the layer from the latest kernel tagged `v0.0.z` where `z` is
at least 4. It is strongly recommended to specify a version bound, since a
kernel author might push incompatible changes to the `main` branch.
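
As a sketch of how such a bound resolves against version tags (illustrative only; the real implementation uses full Python version-specifier semantics rather than this simplified tuple comparison):

```python
def resolve_tag(tags, lower, upper):
    """Pick the highest tag t with lower <= t < upper."""
    def parse(tag):  # "v0.0.4" -> (0, 0, 4)
        return tuple(int(p) for p in tag.lstrip("v").split("."))
    matching = [t for t in tags if parse(lower) <= parse(t) < parse(upper)]
    return max(matching, key=parse) if matching else None

tags = ["v0.0.3", "v0.0.4", "v0.0.5", "v0.1.0"]
# ">=0.0.4,<0.1.0" selects the newest 0.0.z with z >= 4:
assert resolve_tag(tags, "v0.0.4", "v0.1.0") == "v0.0.5"
```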

### Registering kernels for specific modes

You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```

The `kernelize` function will attempt to use the following registered
kernels for a given mode:

- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`

`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
is also used when a kernel is registered without a mode, as described in the
previous section.

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.FALLBACK: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```

In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
since the other kernels do not support `torch.compile`.
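
The preference chains above can be sketched as a plain lookup over a mode → repository mapping (illustrative only; the actual resolution happens inside `kernelize`):

```python
# Candidate chains from the list above, most-specific first.
_CHAINS = {
    "INFERENCE": ["INFERENCE", "INFERENCE|TORCH_COMPILE", "TRAINING",
                  "TRAINING|TORCH_COMPILE", "FALLBACK"],
    "INFERENCE|TORCH_COMPILE": ["INFERENCE|TORCH_COMPILE",
                                "TRAINING|TORCH_COMPILE", "FALLBACK"],
    "TRAINING": ["TRAINING", "TRAINING|TORCH_COMPILE", "FALLBACK"],
    "TRAINING|TORCH_COMPILE": ["TRAINING|TORCH_COMPILE", "FALLBACK"],
}

def resolve(mode, registered):
    """Return the first registered repo along the chain for `mode`, or None."""
    for candidate in _CHAINS[mode]:
        if candidate in registered:
            return registered[candidate]
    return None

repos = {"TRAINING|TORCH_COMPILE": "activation-training-optimized"}
assert resolve("INFERENCE", repos) == "activation-training-optimized"
assert resolve("TRAINING|TORCH_COMPILE", repos) == "activation-training-optimized"
```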

### Registering kernels for specific CUDA capabilities

Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
register the layer for a `Device` with type `cuda` and set the supported
range of CUDA capabilities using `CUDAProperties`:

```python
import sys

from kernels import CUDAProperties, Device, LayerRepository

kernel_layer_mapping = {
    "SiluAndMul": {
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=75, max_capability=89
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=90, max_capability=sys.maxsize
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation-hopper",
            layer_name="SiluAndMul",
        ),
    }
}
```

Capabilities behave as follows:

- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
  an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
  with the smaller capability interval will be used. E.g. given:
  - `KernelA` with `min_capability=80` and `max_capability=89`;
  - `KernelB` with `min_capability=75` and `max_capability=89`;
  - `kernelize` runs on a system with capability 8.6.

  Then `KernelA` will be used because the interval 80..89 is smaller
  than 75..89. The motivation is that kernels with smaller ranges
  tend to be more optimized for a specific set of GPUs. **This behavior
  might still change in the future.**
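
The selection rule can be sketched as: among the registered intervals that contain the current capability, pick the narrowest (the names and tuple layout here are illustrative):

```python
def pick_kernel(capability, kernels):
    """kernels: list of (name, min_capability, max_capability), inclusive bounds."""
    matches = [k for k in kernels if k[1] <= capability <= k[2]]
    if not matches:
        return None
    # Prefer the narrowest interval: likely the most specialized kernel.
    return min(matches, key=lambda k: k[2] - k[1])[0]

kernels = [("KernelA", 80, 89), ("KernelB", 75, 89)]
assert pick_kernel(86, kernels) == "KernelA"  # interval 80..89 is narrower
assert pick_kernel(75, kernels) == "KernelB"  # only KernelB covers 7.5
```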

### Registering kernels for specific ROCm capabilities

Registering kernels for the ROCm architecture follows the exact same
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
a kernel to a range of ROCm capabilities.

### Loading from a local repository for testing

The `LocalLayerRepository` class is provided to load a repository from
a local directory. For example:

```python
from kernels import LocalLayerRepository, Mode, kernelize, use_kernel_mapping

with use_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LocalLayerRepository(
                repo_path="/home/daniel/kernels/activation",
                package_name="activation",
                layer_name="SiluAndMul",
            )
        }
    },
    inherit_mapping=False,
):
    model = kernelize(model, mode=Mode.INFERENCE)
```



### Environment variables
https://huggingface.co/docs/kernels/main/env.md

# Environment variables

## `KERNELS_CACHE`

The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.

## `DISABLE_KERNEL_MAPPING`

Disables kernel mappings for [`layers`](layers.md).



### Installation
https://huggingface.co/docs/kernels/main/installation.md

# Installation

Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):

```bash
pip install kernels
```

# Using kernels in a Docker container

Build and run the reference `examples/basic.py` in a Docker container with the following commands:

```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```



### Basic Usage
https://huggingface.co/docs/kernels/main/basic-usage.md

# Basic Usage

## Loading Kernels

Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:

```python
import torch
from kernels import get_kernel

# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")

# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")

# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)

print(y)
```

### Using version bounds

Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version to download using Python version specifiers:

```python
from kernels import get_kernel

activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
```

This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
is strongly recommended to specify a version bound, since a kernel author
might push incompatible changes to the `main` branch.

## Checking Kernel Availability

You can check if a specific kernel is available for your environment:

```python
from kernels import has_kernel

# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```



### Kernels CLI Reference
https://huggingface.co/docs/kernels/main/cli.md

# Kernels CLI Reference

## Main Functions

### kernels check

You can use `kernels check` to test compliance of a kernel on the Hub.
This currently checks that the kernel:

- Supports the currently-required Python ABI version.
- Works on supported operating system versions.

For example:

```bash
$ kernels check kernels-community/flash-attn3
Checking variant: torch28-cxx11-cu128-aarch64-linux
  🐍 Python ABI 3.9 compatible
  🐧 manylinux_2_28 compatible
[...]
```

### kernels to-wheel

We strongly recommend downloading kernels from the Hub using the `kernels`
package, since this comes with large [benefits](index.md) over using Python
wheels. That said, some projects may require deployment of kernels as
wheels. The `kernels` utility provides a simple solution to this. You can
convert any Hub kernel into a set of wheels with the `to-wheel` command:

```bash
$ kernels to-wheel drbh/img2grey 1.1.2
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
```

### kernels upload

Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
your kernel builds to the Hub. To see the supported arguments, run `kernels upload -h`.

**Notes**:

- This will take care of creating a repository on the Hub with the `repo_id` provided.
- If a repo with the given `repo_id` already exists and contains a build for the
  build variant being uploaded, the upload will attempt to delete the existing
  files under that variant before uploading the new ones.
- Make sure you are authenticated (run `hf auth login` if needed) to be able to upload to the Hub.



### Kernel requirements
https://huggingface.co/docs/kernels/main/kernel-requirements.md

# Kernel requirements

Kernels on the Hub must fulfill the requirements outlined on this page. By
ensuring kernels are compliant, they can be used on a wide range of Linux
systems and Torch builds.

You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build compliant kernels.

## Directory layout

A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`.

Each variant directory must contain a single directory with the same name
as the repository (replacing `-` with `_`). For instance, kernels in the
`kernels-community/activation` repository have a directory like
`build/<variant>/activation`. This directory must be a Python package with
an `__init__.py` file.
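
A sketch of validating such a variant directory name against the template (the regex is an assumption inferred from the examples above, not part of `kernels` itself):

```python
import re

# <framework><version>-cxx<abiver>-cu<cudaver>-<arch>-<os>
VARIANT = re.compile(
    r"^(?P<framework>[a-z]+)(?P<version>\d+)"
    r"-cxx(?P<abi>\d+)"
    r"-cu(?P<cuda>\d+)"
    r"-(?P<arch>[a-z0-9_]+)"
    r"-(?P<os>[a-z]+)$"
)

m = VARIANT.match("torch26-cxx98-cu118-x86_64-linux")
assert m is not None
assert m.group("framework") == "torch" and m.group("cuda") == "118"
```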

## Build variants

A kernel can be compliant for a specific compute framework (e.g. CUDA) or
architecture (e.g. x86_64). For compliance with a compute framework and
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
must be available for that combination.

## Versioning

Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve the version constraints.

We recommend using [semver](https://semver.org/) to version kernels.

## Native Python module

Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:

- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.

## Compatibility with torch.compile

The Kernel Hub also encourages writing kernels in a `torch.compile`-compliant
way. This helps ensure that kernels are compatible with `torch.compile`
without introducing graph breaks or triggering recompilation, both of which
can limit the benefits of compilation.

[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and 
recompilation triggers during `torch.compile`.

### Linux

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
  This means that the extension **must not** use symbols versions higher than:
  - GLIBC 2.28
  - GLIBCXX 3.4.24
  - CXXABI 1.3.11
  - GCC 7.0.0

These requirements can be checked with the ABI checker (see below).

### macOS

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).

The ABI3 requirement can be checked with the ABI checker (see below).

### ABI checker

The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):

```bash
$ cargo install kernel-abi-check
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
✅ No compatibility issues found
```

## Torch extension

Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
in `torch.ops.<namespace>`. Since we allow loading of multiple versions of
a module in the same Python process, `namespace` must be unique for each
version of a kernel. Failing to do so will create clashes when different
versions of the same kernel are loaded. Two suggested ways of doing this
are:

- Appending a truncated SHA-1 hash of the git commit that the kernel was
  built from to the name of the extension.
- Appending random material to the name of the extension.

**Note:** we recommend against appending a version number or git tag.
Version numbers are typically not bumped on each commit, so users
might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
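
The first suggestion can be sketched as follows (`extension_namespace` is a hypothetical helper shown for illustration; [kernel-builder](https://github.com/huggingface/kernel-builder/) takes care of this in practice):

```python
def extension_namespace(kernel_name: str, commit_sha: str, n: int = 9) -> str:
    # Append a truncated commit hash so each built version of the kernel
    # registers under a unique torch.ops namespace.
    return f"{kernel_name}_{commit_sha[:n]}"

print(extension_namespace("relu", "e87e0ca1f2b34c56d78e90ab12cd34ef56ab78cd"))
```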

## Layers

A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers. See the [layers documentation](layers.md) for more information
on how to use layers.

### Writing layers

To make the extension of layers safe, the layers must fulfill the following
requirements:

- The layers are subclasses of `torch.nn.Module`.
- The layers are pure, meaning that they do not have their own state. This
  means that:
  - The layer must not define its own constructor.
  - The layer must not use class variables.
- No methods other than `forward` may be defined.
- The `forward` method has a signature that is compatible with the
  `forward` method that it is extending.

There are two exceptions to the _no class variables rule_:

1. The `has_backward` variable can be used to indicate whether the layer has
   a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
   supports `torch.compile` (`False` when absent).
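
Reading these flags with their documented defaults can be sketched as follows (`layer_capabilities` is a hypothetical helper, not part of the `kernels` API):

```python
def layer_capabilities(layer_cls):
    """Defaults per the rules above: backward assumed, compile not."""
    return {
        "has_backward": getattr(layer_cls, "has_backward", True),
        "can_torch_compile": getattr(layer_cls, "can_torch_compile", False),
    }

class SiluAndMul:          # stand-in for an nn.Module subclass
    has_backward = False   # no backward pass implemented

caps = layer_capabilities(SiluAndMul)
assert caps == {"has_backward": False, "can_torch_compile": False}
```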

This is an example of a pure layer:

```python
class SiluAndMul(nn.Module):
    # This layer does not implement backward.
    has_backward: bool = False

    def forward(self, x: torch.Tensor):
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out
```

For some layers, the `forward` method has to use state from the adopting class.
In these cases, we recommend using type annotations to indicate which member
variables are expected. For instance:

```python
class LlamaRMSNorm(nn.Module):
    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
```

This layer expects the adopting layer to have `weight` and `variance_epsilon`
member variables and uses them in the `forward` method.

### Exporting layers

To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:

```python
from . import layers

__all__ = [
  # ...
  "layers",
  # ...
]
```

## Python requirements

- Python code must be compatible with Python 3.9 and later.
- All Python code imports from the kernel itself must be relative. So,
  for instance, if `module_b` in the example kernel `example` needs a
  function from `module_a`, import it as:

  ```python
  from .module_a import foo
  ```

  **Never use:**

  ```python
  # DO NOT DO THIS!

  from example.module_a import foo
  ```

  The latter would import from the module `example` in Python's global
  module dict. However, since we allow loading multiple versions of a
  module, each loaded module gets a unique name, and absolute imports
  would not resolve to the loaded kernel.

- Only modules from the Python standard library, Torch, or the kernel itself
  can be imported.



### Layers API Reference
https://huggingface.co/docs/kernels/main/api/layers.md

# Layers API Reference

## Making layers kernel-aware

### use_kernel_forward_from_hub[[kernels.use_kernel_forward_from_hub]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.use_kernel_forward_from_hub</name><anchor>kernels.use_kernel_forward_from_hub</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L988</source><parameters>[{"name": "layer_name", "val": ": str"}]</parameters><paramsdesc>- **layer_name** (`str`) --
  The name of the layer to use for kernel lookup in registered mappings.</paramsdesc><paramgroups>0</paramgroups><rettype>`Callable`</rettype><retdesc>A decorator function that can be applied to layer classes.</retdesc></docstring>

Decorator factory that makes a layer extensible using the specified layer
name. It returns a decorator that prepares the layer class to use kernels
from the Hugging Face Hub.

<ExampleCodeBlock anchor="kernels.use_kernel_forward_from_hub.example">

Example:
```python
import torch
import torch.nn as nn

from kernels import use_kernel_forward_from_hub
from kernels import Mode, kernelize

@use_kernel_forward_from_hub("MyCustomLayer")
class MyCustomLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor):
        # original implementation
        return x

model = MyCustomLayer(768)

# The layer can now be kernelized:
# model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
```

</ExampleCodeBlock>


</div>

### replace_kernel_forward_from_hub[[kernels.replace_kernel_forward_from_hub]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.replace_kernel_forward_from_hub</name><anchor>kernels.replace_kernel_forward_from_hub</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L753</source><parameters>[{"name": "layer_name", "val": ": str"}]</parameters></docstring>

Function that prepares a layer class to use kernels from the Hugging Face Hub.

It is recommended to use the [use_kernel_forward_from_hub()](/docs/kernels/main/en/api/layers#kernels.use_kernel_forward_from_hub) decorator instead.
This function should only be used as a last resort to extend third-party layers;
it is inherently fragile, since the member variables and `forward` signature
of such a layer can change.

<ExampleCodeBlock anchor="kernels.replace_kernel_forward_from_hub.example">

Example:
```python
from kernels import replace_kernel_forward_from_hub
import torch.nn as nn

replace_kernel_forward_from_hub(nn.LayerNorm, "LayerNorm")
```

</ExampleCodeBlock>


</div>

## Registering kernel mappings

### use_kernel_mapping[[kernels.use_kernel_mapping]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.use_kernel_mapping</name><anchor>kernels.use_kernel_mapping</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L596</source><parameters>[{"name": "mapping", "val": ": Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]"}, {"name": "inherit_mapping", "val": ": bool = True"}]</parameters><paramsdesc>- **mapping** (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`) --
  The kernel mapping to apply. Maps layer names to device-specific kernel configurations.
- **inherit_mapping** (`bool`, *optional*, defaults to `True`) --
  When `True`, the current mapping will be extended by `mapping` inside the context. When `False`,
  only `mapping` is used inside the context.</paramsdesc><paramgroups>0</paramgroups><retdesc>Context manager that handles the temporary kernel mapping.</retdesc></docstring>

Context manager that sets a kernel mapping for the duration of the context.

This function allows temporary kernel mappings to be applied within a specific context, enabling different
kernel configurations for different parts of your code.





<ExampleCodeBlock anchor="kernels.use_kernel_mapping.example">

Example:
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import use_kernel_forward_from_hub
from kernels import use_kernel_mapping, LayerRepository, Device
from kernels import Mode, kernelize

# Define a mapping
mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

model = SiluAndMul()

# Use the mapping for the duration of the context.
with use_kernel_mapping(mapping):
    # kernelize uses the temporary mapping
    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")

# Outside the context, original mappings are restored
```

</ExampleCodeBlock>


</div>

### register_kernel_mapping[[kernels.register_kernel_mapping]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.register_kernel_mapping</name><anchor>kernels.register_kernel_mapping</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L675</source><parameters>[{"name": "mapping", "val": ": Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]"}, {"name": "inherit_mapping", "val": ": bool = True"}]</parameters><paramsdesc>- **mapping** (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`) --
  The kernel mapping to register globally. Maps layer names to device-specific kernels.
  The mapping can specify different kernels for different modes (training, inference, etc.).
- **inherit_mapping** (`bool`, *optional*, defaults to `True`) --
  When `True`, the current mapping will be extended by `mapping`. When `False`, the existing mappings
  are erased before adding `mapping`.</paramsdesc><paramgroups>0</paramgroups></docstring>

Register a global mapping between layer names and their corresponding kernel implementations.

This function allows you to register a mapping between a layer name and the corresponding kernel(s) to use,
depending on the device and mode. This should be used in conjunction with [kernelize()](/docs/kernels/main/en/api/layers#kernels.kernelize).



<ExampleCodeBlock anchor="kernels.register_kernel_mapping.example">

Example:
```python
from kernels import LayerRepository, register_kernel_mapping, Mode

# Simple mapping for a single kernel per device
kernel_layer_mapping = {
    "LlamaRMSNorm": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="RmsNorm",
            revision="layers",
        ),
    },
}
register_kernel_mapping(kernel_layer_mapping)

# Advanced mapping with mode-specific kernels
advanced_mapping = {
    "MultiHeadAttention": {
        "cuda": {
            Mode.TRAINING: LayerRepository(
                repo_id="username/training-kernels",
                layer_name="TrainingAttention"
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="username/inference-kernels",
                layer_name="FastAttention"
            ),
        }
    }
}
register_kernel_mapping(advanced_mapping)
```

</ExampleCodeBlock>


</div>

## Kernelizing a model

### kernelize[[kernels.kernelize]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.kernelize</name><anchor>kernels.kernelize</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L822</source><parameters>[{"name": "model", "val": ": 'nn.Module'"}, {"name": "mode", "val": ": Mode"}, {"name": "device", "val": ": Optional[Union[str, 'torch.device']] = None"}, {"name": "use_fallback", "val": ": bool = True"}]</parameters><paramsdesc>- **model** (`nn.Module`) --
  The PyTorch model to kernelize.
- **mode** ([Mode](/docs/kernels/main/en/api/layers#kernels.Mode)) -- The mode that the kernel is going to be used in. For example,
  `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
  `torch.compile`.
- **device** (`Union[str, torch.device]`, *optional*) --
  The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
  The device type will be inferred from the model parameters when not provided.
- **use_fallback** (`bool`, *optional*, defaults to `True`) --
  Whether to use the original forward method of modules when no compatible kernel could be found.
  If set to `False`, an exception will be raised in such cases.</paramsdesc><paramgroups>0</paramgroups><rettype>`nn.Module`</rettype><retdesc>The kernelized model with optimized kernel implementations.</retdesc></docstring>

Replace layer forward methods with optimized kernel implementations.

This function iterates over all modules in the model and replaces the `forward` method of extensible layers
for which kernels are registered using [register_kernel_mapping()](/docs/kernels/main/en/api/layers#kernels.register_kernel_mapping) or [use_kernel_mapping()](/docs/kernels/main/en/api/layers#kernels.use_kernel_mapping).







<ExampleCodeBlock anchor="kernels.kernelize.example">

Example:
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import kernelize, Mode, register_kernel_mapping, LayerRepository
from kernels import use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}
register_kernel_mapping(mapping)

# Create and kernelize a model
model = nn.Sequential(
    nn.Linear(1024, 2048, device="cuda"),
    SiluAndMul(),
)

# Kernelize for training with torch.compile
kernelized_model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

</ExampleCodeBlock>


</div>

## Classes

### Device[[kernels.Device]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class kernels.Device</name><anchor>kernels.Device</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L81</source><parameters>[{"name": "type", "val": ": str"}, {"name": "properties", "val": ": Optional[CUDAProperties] = None"}]</parameters><paramsdesc>- **type** (`str`) --
  The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
- **properties** (`CUDAProperties`, *optional*) --
  Device-specific properties. Currently only `CUDAProperties` is supported for CUDA devices.</paramsdesc><paramgroups>0</paramgroups></docstring>

Represents a compute device with optional properties.

This class encapsulates device information including device type and optional device-specific properties
like CUDA capabilities.



<ExampleCodeBlock anchor="kernels.Device.example">

Example:
```python
from kernels import Device, CUDAProperties

# Basic CUDA device
cuda_device = Device(type="cuda")

# CUDA device with specific capability requirements
cuda_device_with_props = Device(
    type="cuda",
    properties=CUDAProperties(min_capability=75, max_capability=90)
)

# MPS device for Apple Silicon
mps_device = Device(type="mps")

# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
xpu_device = Device(type="xpu")

# NPU device (Huawei Ascend)
npu_device = Device(type="npu")
```

</ExampleCodeBlock>



<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>create_repo</name><anchor>kernels.Device.create_repo</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L126</source><parameters>[]</parameters></docstring>
Create an appropriate repository set for this device type.

</div></div>

### Mode[[kernels.Mode]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class kernels.Mode</name><anchor>kernels.Mode</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L43</source><parameters>[{"name": "value", "val": ""}, {"name": "names", "val": " = None"}, {"name": "module", "val": " = None"}, {"name": "qualname", "val": " = None"}, {"name": "type", "val": " = None"}, {"name": "start", "val": " = 1"}]</parameters><paramsdesc>- **INFERENCE** -- The kernel is used for inference.
- **TRAINING** -- The kernel is used for training.
- **TORCH_COMPILE** -- The kernel is used with `torch.compile`.
- **FALLBACK** -- In a kernel mapping, this kernel is used when no other mode matches.</paramsdesc><paramgroups>0</paramgroups></docstring>

Kernelize mode

The `Mode` flag is used by [kernelize()](/docs/kernels/main/en/api/layers#kernels.kernelize) to select kernels for the given mode. Mappings can be registered for
specific modes.



Note:
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE` should be used for layers that
are used for inference *with* `torch.compile`.
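
Because `Mode` is a flag enum, combinations compose with `|` and can be tested with `in`. The sketch below illustrates these flag semantics with a stand-in `enum.Flag` class (a hypothetical mirror of the real `kernels.Mode`, which ships with the library):

```python
from enum import Flag, auto

# Illustrative stand-in for kernels.Mode to show the flag semantics;
# in real code, use `from kernels import Mode` instead.
class Mode(Flag):
    INFERENCE = auto()
    TRAINING = auto()
    TORCH_COMPILE = auto()
    FALLBACK = auto()

combined = Mode.INFERENCE | Mode.TORCH_COMPILE

# Membership checks work as with any Flag:
assert Mode.INFERENCE in combined
assert Mode.TRAINING not in combined
```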



</div>

### LayerRepository[[kernels.LayerRepository]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class kernels.LayerRepository</name><anchor>kernels.LayerRepository</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L247</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "layer_name", "val": ": str"}, {"name": "revision", "val": ": Optional[str] = None"}, {"name": "version", "val": ": Optional[str] = None"}]</parameters><paramsdesc>- **repo_id** (`str`) --
  The Hub repository containing the layer.
- **layer_name** (`str`) --
  The name of the layer within the kernel repository.
- **revision** (`str`, *optional*, defaults to `"main"`) --
  The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
- **version** (`str`, *optional*) --
  The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
  Cannot be used together with `revision`.</paramsdesc><paramgroups>0</paramgroups></docstring>

Repository and name of a layer for kernel mapping.



<ExampleCodeBlock anchor="kernels.LayerRepository.example">

Example:
```python
from kernels import LayerRepository

# Reference a layer (defaults to the main revision)
layer_repo = LayerRepository(
    repo_id="kernels-community/activation",
    layer_name="SiluAndMul",
)

# Reference a layer by version constraint
layer_repo_versioned = LayerRepository(
    repo_id="kernels-community/activation",
    layer_name="SiluAndMul",
    version=">=0.0.3,<0.1"
)
```

</ExampleCodeBlock>


</div>

### LocalLayerRepository[[kernels.LocalLayerRepository]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class kernels.LocalLayerRepository</name><anchor>kernels.LocalLayerRepository</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L327</source><parameters>[{"name": "repo_path", "val": ": Path"}, {"name": "package_name", "val": ": str"}, {"name": "layer_name", "val": ": str"}]</parameters><paramsdesc>- **repo_path** (`Path`) --
  The local repository containing the layer.
- **package_name** (`str`) --
  Package name of the kernel.
- **layer_name** (`str`) --
  The name of the layer within the kernel repository.</paramsdesc><paramgroups>0</paramgroups></docstring>

Repository from a local directory for kernel mapping.



<ExampleCodeBlock anchor="kernels.LocalLayerRepository.example">

Example:
```python
from pathlib import Path

from kernels import LocalLayerRepository

# Reference a layer in a local kernel repository
layer_repo = LocalLayerRepository(
    repo_path=Path("/home/daniel/kernels/activation"),
    package_name="activation",
    layer_name="SiluAndMul",
)
```

</ExampleCodeBlock>


</div>

### LockedLayerRepository[[kernels.LockedLayerRepository]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class kernels.LockedLayerRepository</name><anchor>kernels.LockedLayerRepository</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/layer.py#L383</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "lockfile", "val": ": Optional[Path] = None"}, {"name": "layer_name", "val": ": str"}]</parameters></docstring>

Repository and name of a layer.

In contrast to `LayerRepository`, this class uses repositories that
are locked inside a project.


</div>

<EditOnGithub source="https://github.com/huggingface/kernels/blob/main/docs/source/api/layers.md" />

### Kernels API Reference
https://huggingface.co/docs/kernels/main/api/kernels.md

# Kernels API Reference

## Main Functions

### get_kernel[[kernels.get_kernel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.get_kernel</name><anchor>kernels.get_kernel</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/utils.py#L222</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "revision", "val": ": typing.Optional[str] = None"}, {"name": "version", "val": ": typing.Optional[str] = None"}, {"name": "user_agent", "val": ": typing.Union[str, dict, NoneType] = None"}]</parameters><paramsdesc>- **repo_id** (`str`) --
  The Hub repository containing the kernel.
- **revision** (`str`, *optional*, defaults to `"main"`) --
  The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
- **version** (`str`, *optional*) --
  The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
  Cannot be used together with `revision`.
- **user_agent** (`Union[str, dict]`, *optional*) --
  The `user_agent` info to pass to `snapshot_download()` for internal telemetry.</paramsdesc><paramgroups>0</paramgroups><rettype>`ModuleType`</rettype><retdesc>The imported kernel module.</retdesc></docstring>

Load a kernel from the kernel hub.

This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
and then loads the kernel.







<ExampleCodeBlock anchor="kernels.get_kernel.example">

Example:
```python
import torch
from kernels import get_kernel

activation = get_kernel("kernels-community/activation")
x = torch.randn(10, 20, device="cuda")
out = torch.empty_like(x)
result = activation.silu_and_mul(out, x)
```

</ExampleCodeBlock>


</div>

### get_local_kernel[[kernels.get_local_kernel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.get_local_kernel</name><anchor>kernels.get_local_kernel</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/utils.py#L266</source><parameters>[{"name": "repo_path", "val": ": Path"}, {"name": "package_name", "val": ": str"}]</parameters><paramsdesc>- **repo_path** (`Path`) --
  The local path to the kernel repository.
- **package_name** (`str`) --
  The name of the package to import from the repository.</paramsdesc><paramgroups>0</paramgroups><rettype>`ModuleType`</rettype><retdesc>The imported kernel module.</retdesc></docstring>

Import a kernel from a local kernel repository path.








</div>

### has_kernel[[kernels.has_kernel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.has_kernel</name><anchor>kernels.has_kernel</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/utils.py#L299</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "revision", "val": ": typing.Optional[str] = None"}, {"name": "version", "val": ": typing.Optional[str] = None"}]</parameters><paramsdesc>- **repo_id** (`str`) --
  The Hub repository containing the kernel.
- **revision** (`str`, *optional*, defaults to `"main"`) --
  The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
- **version** (`str`, *optional*) --
  The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
  Cannot be used together with `revision`.</paramsdesc><paramgroups>0</paramgroups><rettype>`bool`</rettype><retdesc>`True` if a kernel is available for the current environment.</retdesc></docstring>

Check whether a kernel build exists for the current environment (Torch version and compute framework).








</div>

## Loading locked kernels

### load_kernel[[kernels.load_kernel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.load_kernel</name><anchor>kernels.load_kernel</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/utils.py#L337</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "lockfile", "val": ": typing.Optional[pathlib.Path] = None"}]</parameters><paramsdesc>- **repo_id** (`str`) --
  The Hub repository containing the kernel.
- **lockfile** (`Path`, *optional*) --
  Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.</paramsdesc><paramgroups>0</paramgroups><rettype>`ModuleType`</rettype><retdesc>The imported kernel module.</retdesc></docstring>

Get a pre-downloaded, locked kernel.

If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.








</div>

### get_locked_kernel[[kernels.get_locked_kernel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>kernels.get_locked_kernel</name><anchor>kernels.get_locked_kernel</anchor><source>https://github.com/huggingface/kernels/blob/main/src/kernels/utils.py#L394</source><parameters>[{"name": "repo_id", "val": ": str"}, {"name": "local_files_only", "val": ": bool = False"}]</parameters><paramsdesc>- **repo_id** (`str`) --
  The Hub repository containing the kernel.
- **local_files_only** (`bool`, *optional*, defaults to `False`) --
  Whether to only use local files and not download from the Hub.</paramsdesc><paramgroups>0</paramgroups><rettype>`ModuleType`</rettype><retdesc>The imported kernel module.</retdesc></docstring>

Get a kernel using a lock file.








</div>

<EditOnGithub source="https://github.com/huggingface/kernels/blob/main/docs/source/api/kernels.md" />
