AWS Neuron Kernel Interface (NKI) lets you write custom compute kernels for AWS Trainium and Inferentia. This post documents a minimal end-to-end pipeline: writing an NKI kernel in Python, extracting its HLO, compiling the HLO to a Neuron Executable File Format (NEFF) binary with neuronx-cc, and loading the result on-device with spike.
1. Write a kernel: matmul.py
The example below implements a tiled matrix multiplication in NKI. It reads tiles from HBM into SBUF, accumulates partial results in PSUM via nc_matmul, and writes the final tile back to HBM.
```python
import nki
import nki.language as nl
import nki.isa as nisa
import numpy as np
import torch
from torch_xla.core import xla_model as xm
from nki._torch_xla import PyTorchXLAKernel
```
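First, a minimal sketch of the tiled kernel itself, assuming 2048×2048 inputs whose dimensions divide the tile sizes evenly. The `nl.tile_size` constants and the `nc_matmul`/PSUM accumulation pattern follow the public NKI tutorials and may vary across releases, so treat this as illustrative rather than canonical:

```python
def matmul(lhsT, rhs):
    # lhsT: [K, M] transposed left operand; rhs: [K, N]; output: [M, N].
    K, M = lhsT.shape
    _, N = rhs.shape
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_K = nl.tile_size.pmax                  # 128 (partition dimension)
    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax      # 512

    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Accumulate partial products for this output tile in PSUM.
            acc = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)
            for k in nl.affine_range(K // TILE_K):
                # Load one tile of each operand from HBM into SBUF.
                lhsT_tile = nl.load(lhsT[k * TILE_K:(k + 1) * TILE_K,
                                         m * TILE_M:(m + 1) * TILE_M])
                rhs_tile = nl.load(rhs[k * TILE_K:(k + 1) * TILE_K,
                                       n * TILE_N:(n + 1) * TILE_N])
                # nc_matmul computes lhsT_tile.T @ rhs_tile on the tensor engine.
                acc += nisa.nc_matmul(lhsT_tile, rhs_tile)
            # Copy the accumulated tile out of PSUM and write it back to HBM.
            nl.store(result[m * TILE_M:(m + 1) * TILE_M,
                            n * TILE_N:(n + 1) * TILE_N],
                     value=nl.copy(acc, dtype=result.dtype))
    return result
```

The driver below wraps the kernel for PyTorch/XLA and triggers graph capture: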
```python
if __name__ == '__main__':
    kernel = PyTorchXLAKernel(func=matmul)
    device = xm.xla_device()

    a = torch.randn((2048, 2048)).to(device=device)
    b = torch.randn((2048, 2048)).to(device=device)
    c = kernel(a, b)

    # Force XLA graph compilation
    xm.mark_step()
```
The kernel accepts a transposed left-hand side (lhsT) and computes lhsT.T @ rhs, i.e., A.T @ B from the caller's perspective. The if __name__ == '__main__' block forces XLA graph capture; it is the entry point used by the extraction step below.
2. Extract HLO from the kernel
Neuron’s XLA-based compilation path lowers NKI kernels through PyTorch/XLA into HLO modules. The helper function below sets two environment variables so the runtime dumps those modules to disk without requiring device execution:
- NEURON_EXTRACT_GRAPHS_ONLY=1: skip execution; only extract graphs
- NEURON_COMPILE_CACHE_URL=<dir>: write HLO artifacts directly into <dir>
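A sketch of such a helper, assuming the step-1 script is saved as matmul.py. The function name and the subprocess-based approach are illustrative; the two environment variables are the load-bearing part:

```python
import os
import subprocess

def extract_hlo(script='matmul.py', cache_dir='./hlo_artifacts'):
    """Re-run the kernel script with the Neuron runtime in graph-extraction
    mode so it dumps lowered HLO modules into cache_dir (no device needed)."""
    env = dict(os.environ)
    env['NEURON_EXTRACT_GRAPHS_ONLY'] = '1'      # skip execution; only extract graphs
    env['NEURON_COMPILE_CACHE_URL'] = cache_dir  # write HLO artifacts here
    subprocess.run(['python', script], env=env, check=True)

extract_hlo()
```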
The resulting model.hlo_module.pb is a serialized HLO module containing the lowered kernel graph. Each run may produce multiple modules (runtime glue, parameter loading, etc.); the kernel's module is typically the largest.
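3. Compile the HLO to a NEFF with neuronx-cc

With the HLO module in hand, neuronx-cc turns it into a NEFF that the runtime can load. A minimal sketch, assuming the neuronx-cc compile CLI with the XLA front end and a trn1 target (pick the target that matches your device); the glob heuristic simply selects the largest dumped module, per the note above:

```python
import glob
import os
import subprocess

# Pick the kernel's HLO module: typically the largest one dumped.
candidates = glob.glob('./hlo_artifacts/**/model.hlo_module.pb', recursive=True)
kernel_hlo = max(candidates, key=os.path.getsize)

# Compile the HLO module into a NEFF alongside the .pb.
neff_path = kernel_hlo.replace('.pb', '.neff')
subprocess.run(
    ['neuronx-cc', 'compile', kernel_hlo,
     '--framework', 'XLA',   # the input is an XLA HLO module
     '--target', 'trn1',     # or inf2, etc., to match your device
     '--output', neff_path],
    check=True,
)
```

The result is a model.hlo_module.neff sitting next to the .pb inside the compile-cache directory, which is exactly what the loading step below expects.

4. Load and run the NEFF with spike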
The spike library provides a lightweight runtime for loading pre-compiled NEFF binaries and running them on Neuron devices. It does not invoke the compiler — it loads the NEFF directly.
```python
import numpy as np
import torch

from spike import SpikeModel, SpikeTensor
```
```python
# 1. Load the kernel from a NEFF file (use .neff, NOT .hlo_module.pb)
model = SpikeModel.load_from_neff(
    './hlo_artifacts/neuronxcc-2.22.12471.0+b4a00d10/'
    'MODULE_8749100665346044836+e30acd3a/model.hlo_module.neff',
    'matmul'
)

# 2. Create inputs on host (torch → numpy → SpikeTensor)
A = torch.rand(2048, 2048)
B = torch.rand(2048, 2048)
device_A = SpikeTensor.from_numpy(A.numpy(), name='input1')
device_B = SpikeTensor.from_numpy(B.numpy(), name='input0')

# 3. Let the model auto-allocate outputs (pass outputs=None)
outputs = model(inputs={'input1': device_A, 'input0': device_B})

# 4. Read result back (SpikeTensor → numpy → torch)
result = torch.from_numpy(outputs['output0'].numpy())

# 5. Verify against PyTorch reference
# This kernel computes input1.T @ input0 (A.T @ B)
assert torch.allclose(A.T @ B, result, atol=1e-3)
```
A few practical notes:
- Input and output tensors are identified by name (input1, input0, output0, etc.). The input_tensors_info / output_tensors_info dictionaries tell you the expected names, shapes, and dtypes (see the sketch after this list).
- SpikeTensor names must match the kernel's parameter names.
- The NEFF is an opaque binary: no compilation step occurs at load time, which makes this path useful when you want to ship a pre-compiled kernel without bundling the full compiler toolchain.
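For instance, you can dump the expected metadata before constructing inputs. A sketch, assuming the info attributes are plain dicts keyed by tensor name, as the first note describes; the exact record fields may differ:

```python
# Inspect expected I/O metadata before building SpikeTensors
# (assumed dict-of-records layout; exact fields may differ).
for name, info in model.input_tensors_info.items():
    print('input: ', name, info)
for name, info in model.output_tensors_info.items():
    print('output:', name, info)
```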