PyTorch + CUDA vs. XLA + TPU: Two Execution Models for ML Systems
When people first move from the familiar PyTorch + CUDA world into XLA + TPU, the most confusing part is often not a single API. It is the execution model.
PyTorch + CUDA: usually eager and op-by-op
In the typical PyTorch + CUDA workflow:
- execution is eager by default
- operations are issued op-by-op at runtime
- dispatch is driven by the ATen dispatcher
- performance is often dominated by kernels and vendor libraries such as cuBLAS, cuDNN, and NCCL
- work is executed immediately on the GPU
A useful conceptual diagram is:

```
Python call -> ATen dispatcher -> CUDA kernel (cuBLAS / cuDNN / NCCL) -> executes immediately on the GPU
```
This is one reason PyTorch + CUDA feels so direct and interactive: the programming model is close to the runtime behavior.
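That directness can be caricatured in a few lines of plain Python. This is a toy sketch, not PyTorch's actual dispatcher: each call looks up a kernel and runs it the moment it is issued, with no graph built in between.

```python
# Toy op-by-op eager dispatcher: every operation is dispatched and
# executed immediately when called, mirroring the model described above.

KERNELS = {
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],
}

def dispatch(op, *args):
    # Look up the kernel and run it right away -- no deferred graph.
    return KERNELS[op](*args)

a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
c = dispatch("add", a, b)   # executes immediately
d = dispatch("mul", c, b)   # executes immediately
print(d)  # [20.0, 35.0, 54.0]
```

Each `dispatch` call here is analogous to one trip through the ATen dispatcher: the result exists as soon as the line returns, which is exactly what makes the eager model feel interactive.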
XLA + TPU: compilation over larger computation regions
The XLA + TPU style is different.
Instead of treating execution as a stream of isolated operations, it often treats execution as compilation over larger computation regions. Depending on the framework and configuration, this may be ahead-of-time or just-in-time.
A rough conceptual pipeline looks like this:

```
framework graph capture -> HLO -> XLA compiler -> compiled executable -> PJRT runtime -> TPU
```
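The contrast with the eager model can be made concrete with a toy trace-and-compile sketch in pure Python (illustrative only, nothing here is real XLA): operations are recorded rather than executed, the whole region is "compiled" once, and the compiled artifact is then run as a unit.

```python
# Toy graph capture + compile: ops are recorded, not executed,
# until the whole region is compiled into a single callable.

def trace(fn, num_inputs):
    """Record the ops fn performs on symbolic values into a graph."""
    graph = []  # each entry: (op, left_input_id, right_input_id)

    class Sym:
        def __init__(self, i):
            self.i = i
        def __add__(self, other):
            graph.append(("add", self.i, other.i))
            return Sym(num_inputs + len(graph) - 1)
        def __mul__(self, other):
            graph.append(("mul", self.i, other.i))
            return Sym(num_inputs + len(graph) - 1)

    out = fn(*[Sym(i) for i in range(num_inputs)])
    return graph, out.i

def compile_graph(graph, out_id):
    """'Compile' the captured graph into one callable for the region."""
    def compiled(*args):
        env = list(args)
        for op, a, b in graph:
            env.append(env[a] + env[b] if op == "add" else env[a] * env[b])
        return env[out_id]
    return compiled

def f(x, y):
    return (x + y) * y

graph, out_id = trace(f, 2)             # capture the region once
executable = compile_graph(graph, out_id)
print(executable(2, 3))                 # run the compiled region: 15
```

The key shift is visible in the shape of the code: `f` is never executed directly; it is captured into a graph (the stand-in for HLO), and only the compiled callable ever touches real data.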
HLO and PJRT: the two key boundaries
HLO is the compiler IR.
It is valuable because it is:
- relatively framework-neutral
- high-level enough for graph optimization
- low-level enough to target hardware
PJRT is the runtime.
It gives a standard way to provide:
- device discovery
- executable loading
- buffer management
- execution
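As a sketch of that boundary (the real PJRT C API is considerably richer, and every class and method name below is illustrative, not an actual PJRT symbol), the runtime can be pictured as an abstract interface with exactly those four responsibilities:

```python
from abc import ABC, abstractmethod

class Runtime(ABC):
    """Illustrative PJRT-style boundary: device discovery, executable
    loading, buffer management, and execution. Names are hypothetical."""

    @abstractmethod
    def devices(self): ...                       # device discovery

    @abstractmethod
    def load(self, compiled_artifact): ...       # executable loading

    @abstractmethod
    def to_device(self, host_data, device): ...  # buffer management

    @abstractmethod
    def execute(self, executable, buffers): ...  # execution

class CpuRuntime(Runtime):
    """A trivial host-only backend, just to show the interface in use."""
    def devices(self):
        return ["cpu:0"]
    def load(self, compiled_artifact):
        return compiled_artifact
    def to_device(self, host_data, device):
        return list(host_data)
    def execute(self, executable, buffers):
        return executable(*buffers)

rt = CpuRuntime()
exe = rt.load(lambda a, b: [x + y for x, y in zip(a, b)])
buf = rt.to_device([1, 2], rt.devices()[0])
result = rt.execute(exe, [buf, buf])
print(result)  # [2, 4]
```

The point of the abstraction is that a framework talking to this interface does not care which concrete backend sits behind it, which is exactly what makes the runtime boundary reusable across hardware.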
Why PyTorch creates some friction here
PyTorch is naturally:
- eager
- dynamic
- mutable
- object- and state-oriented
The XLA + TPU style tends to prefer:
- graph capture
- stable shapes
- stable execution paths
So PyTorch is not the most natural philosophical fit for this model.
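One concrete source of that friction: under graph compilation, each new input shape typically triggers a fresh compilation. A toy shape-keyed compile cache (pure Python, illustrative only) makes the cost of dynamic shapes visible:

```python
# Toy shape-keyed compilation cache: dynamic shapes defeat the cache,
# which is one reason the XLA style prefers stable shapes.
compile_count = 0
_cache = {}

def compile_for_shape(shape):
    global compile_count
    compile_count += 1                    # pretend this step is expensive
    return lambda xs: sum(xs) / shape[0]  # specialized to this length

def run(xs):
    shape = (len(xs),)
    if shape not in _cache:               # cache miss -> recompile
        _cache[shape] = compile_for_shape(shape)
    return _cache[shape](xs)

run([1.0, 2.0]); run([3.0, 4.0])   # same shape: compiled once
run([1.0, 2.0, 3.0])               # new shape: compiled again
print(compile_count)  # 2
```

A dynamic, mutable workload churns through shapes and execution paths, so a cache like this keeps missing; a stable workload compiles once and then only pays for execution.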
Why JAX is a more natural fit
JAX more naturally encourages:
- functional style
- immutability
- explicit compilation through jit
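The functional/immutable preference can be seen in miniature. Instead of mutating state in place, a JAX-style function returns new values, which is what makes it safe to capture and compile. A pure-Python contrast (illustrative, not actual JAX code):

```python
# In-place style: hidden mutation makes the function's behavior depend
# on object state, which is hard to capture faithfully as a graph.
def scale_inplace(xs, factor):
    for i in range(len(xs)):
        xs[i] *= factor   # mutates the caller's list

# Functional style: the same computation, but the input is left
# untouched and a new value is returned -- easy to trace and recompile.
def scale_pure(xs, factor):
    return [x * factor for x in xs]

data = [1.0, 2.0]
scale_inplace(data, 2.0)
print(data)                 # [2.0, 4.0] -- the original was destroyed

data = [1.0, 2.0]
out = scale_pure(data, 2.0)
print(data, out)            # [1.0, 2.0] [2.0, 4.0] -- input preserved
```

The pure version is a fixed mapping from inputs to outputs, so a compiler can capture it once and replay it; the mutating version entangles the computation with object state.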
So a short summary is:
- JAX is a philosophical fit for XLA + TPU
- PyTorch/XLA is a practical and commercially necessary compromise
That is a useful framing when trying to understand why XLA + TPU can feel elegant in JAX and more awkward in PyTorch/XLA.
Why AWS Neuron likely chose this design
A backend like AWS Neuron has strong reasons to adapt the XLA + TPU style:
- it fits an ecosystem already used by JAX and PyTorch/XLA
- it allows one compiler and runtime stack to support multiple frameworks
A compact mental model for Neuron is:

```
JAX / PyTorch/XLA -> HLO -> Neuron compiler -> Neuron runtime (libnrt.so) -> Neuron device
```
The AWS Neuron runtime environment
On a Neuron-enabled system, the host environment typically contains a runtime/tooling stack under /opt/aws/neuron.
Typical binaries include:
- neuron-ls — enumerate Neuron devices and topology
- neuron-monitor — collect live runtime metrics
- neuron-top — top-like live monitor built on neuron-monitor
- neuron-profile — capture and analyze workload or device profiles
- neuron-explorer — inspect and view profiling output
- neuron-bench — benchmarking and reporting
- neuron-dbg — low-level device debugging
- nccom-test — communication and collectives testing
- neuron-dump — support-bundle or dump helper
Typical libraries include:
- libnrt.so — Neuron Runtime
- libnccom.so — Neuron collectives and communication
- libndbg.so — debug support
- libncfw.so
- libnds.a
- libnrtucode_extisa.so
Runtime API headers under include/nrt/ typically include:
- nrt.h
- nrt_async.h
- nrt_profile.h
- nrt_version.h
- nrt_status.h
Driver and device headers under include/ndl/ typically include:
- ndl.h
- neuron_driver_shared.h
- neuron_driver_shared_tensor_batch_op.h