
Triton-Metal

Triton-Metal is an enhanced version of Triton with optimized Metal backend support for Apple Silicon GPUs. This fork focuses on delivering high performance for ML workloads on M1, M2, and M3 chips, with special optimizations for the M3's advanced capabilities.

The foundations of this project are described in the following MAPL 2019 publication: Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. Please consider citing this work if you use Triton!

The official documentation contains installation instructions and tutorials specific to the Metal backend.

Quick Installation

You can install the latest stable release of Triton-Metal from PyPI:

pip install triton-metal

For full functionality with the Metal backend, install with the metal extra:

pip install "triton-metal[metal]"

Using the Package

After installation, you can import the package as triton (not triton_metal):

import triton
import triton.language as tl
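
To verify the install, a quick sanity check (this assumes only the standard version attribute on the package):

import triton
print(triton.__version__)  # e.g. "3.3.0rc2" for this release line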

Binary wheels are available for:

  • macOS 13.5+ on Apple Silicon (M1/M2/M3)
  • CPython 3.9-3.13

Metal Backend Features

The Metal backend enables running Triton kernels on Apple Silicon GPUs with the following optimizations:

  • Full MLX integration for efficient Metal execution
  • M3-specific optimizations leveraging 64KB shared memory (vs 32KB on M1/M2)
  • 8-wide vectorization support for M3 chips
  • Tensor core utilization for matrix operations on M3
  • Enhanced SIMD operations (32-wide vs 16-wide on M1/M2)
  • Dynamic register caching
  • Automatic hardware detection and optimization

Requirements

  • macOS 13.5 or higher
  • Apple Silicon Mac (M1/M2/M3)
  • MLX 0.3.0 or higher (installed automatically with pip install "triton-metal[metal]")
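
Before using the backend, you can confirm that MLX itself can reach the GPU. This is a hedged sketch against MLX's own API (mx.metal.is_available and mx.default_device, as of recent MLX releases), independent of triton-metal:

import mlx.core as mx

# Check that MLX sees the Metal device before routing Triton kernels to it.
print(mx.metal.is_available())  # expected: True on a supported Apple Silicon Mac
print(mx.default_device())      # expected: Device(gpu, 0) when Metal is the default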

Usage

To use the Metal backend:

import os
os.environ["TRITON_BACKEND"] = "metal"  # Set this before importing Triton

import triton
import triton.language as tl

# Your Triton code here
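
As a concrete starting point, below is a minimal vector-add kernel written against the standard Triton API. The kernel itself is backend-agnostic; only the environment variable above routes it to Metal. The commented-out launch is a sketch, since how you allocate the device tensors x, y, and out depends on the array library you pair with the backend:

import os
os.environ["TRITON_BACKEND"] = "metal"  # set before importing Triton

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the final, partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# Launch over a 1-D grid of ceil(n / BLOCK_SIZE) programs:
# grid = (triton.cdiv(n, 1024),)
# add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)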

Enabling Blackwell Support

The main branch now features support for NVIDIA Blackwell GPUs using 5th generation tensor cores. To enable this, you will need two additional steps:

  1. Build a pre-release PyTorch from source with CUDA 12.8
  2. Build triton from the latest source

First, to build PyTorch you need CUDA 12.8 installed locally. If you don't have it, follow the installation instructions for your platform.

# Clone and checkout pytorch 2.6 release candidate
git clone https://github.com/pytorch/pytorch
cd pytorch
git checkout v2.6.0-rc9
git submodule sync
git submodule update --init --recursive -j 8

# Install build dependencies (assumes you already have a system compiler)
pip install -r requirements.txt
pip install mkl-static mkl-include wheel

# Build PyTorch (will take a long time)
export CUDA_HOME=/usr/local/cuda-12.8
export CUDA_PATH=$CUDA_HOME
export TORCH_CUDA_ARCH_LIST=Blackwell
python setup.py develop

# Optionally, package the build into a wheel to install on other machines.
python setup.py bdist_wheel
ls dist  # Wheel should be output in this directory
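
Once the build finishes, a quick Python check confirms the wheel was built against CUDA 12.8 and can see the GPU. The expected compute capability below is an assumption (Blackwell parts are generally reported as 10.x):

import torch

print(torch.__version__)                    # should report a 2.6.0 dev/rc build
print(torch.version.cuda)                   # should report 12.8
print(torch.cuda.is_available())            # True with a working driver
print(torch.cuda.get_device_capability(0))  # e.g. (10, 0) on Blackwell (assumed)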

Note that if you use the domain libraries (torchvision, torchtext, torchaudio, etc.), these will need to be built from source as well; otherwise their custom PyTorch extensions will not work.

Finally, follow the instructions below to install triton from source.

Install from source

git clone https://github.com/triton-lang/triton.git
cd triton

pip install -r python/requirements.txt # build-time dependencies
pip install -e .

Or with a virtualenv:

git clone https://github.com/triton-lang/triton.git
cd triton

python -m venv .venv --prompt triton
source .venv/bin/activate

pip install -r python/requirements.txt # build-time dependencies
pip install -e .
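
With an editable install, Python imports Triton straight from the clone. A quick way to confirm the -e install took effect (assuming nothing beyond standard module attributes):

import triton
print(triton.__file__)  # should point inside your triton checkout, not site-packages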

Building with a custom LLVM

Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build downloads a prebuilt LLVM, but you can also build LLVM from source and use that.

LLVM does not have a stable API, so the Triton build will not work at an arbitrary LLVM version.

  1. Find the version of LLVM that Triton builds against. Check cmake/llvm-hash.txt to see the current version. For example, if it says: 49af6502c6dcb4a7f7520178bd14df396f78240c

    This means that the version of Triton you have builds against LLVM 49af6502.

  2. Check out LLVM at this revision with git checkout. Optionally, make additional modifications to LLVM.

  3. Build LLVM. For example, you might run

    $ cd $HOME/llvm-project  # your clone of LLVM.
    $ mkdir build
    $ cd build
    $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm;lld" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
    $ ninja
    
  4. Grab a snack, this will take a while.

  5. Build Triton as above, but set the following environment variables.

    # Modify as appropriate to point to your LLVM build.
    $ export LLVM_BUILD_DIR=$HOME/llvm-project/build
    
    $ cd <triton install>
    $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
      LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
      LLVM_SYSPATH=$LLVM_BUILD_DIR \
      pip install -e .
    

Tips for building

  • Set TRITON_BUILD_WITH_CLANG_LLD=true as an environment variable to use clang and lld. lld in particular results in faster builds.

  • Set TRITON_BUILD_WITH_CCACHE=true to build with ccache.

  • Set TRITON_HOME=/some/path to change the location of the .triton directory where Triton's cache is located and downloads are stored during the build. By default, this is the user's home directory. It can be changed anytime.

  • If you're running out of memory when building Triton, set the MAX_JOBS environment variable (for the pip install -e . command) to limit the number of parallel jobs.

  • Pass --no-build-isolation to pip install to make nop builds faster. Without this, every invocation of pip install uses a different symlink to cmake, and this forces ninja to rebuild most of the .a files.

  • vscode intellisense has some difficulty figuring out how to build Triton's C++ (probably because, in our build, users don't invoke cmake directly, but instead use setup.py). Teach vscode how to compile Triton as follows.

    • Do a local build. Run command pip install -e .
    • Get the full path to the compile_commands.json file produced by the build: find ./build -name 'compile_commands.json' | xargs readlink -f. You might get a full path similar to /Users/{username}/triton/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json
    • In vscode, install the C/C++ extension, then open the command palette (Shift + Command + P on Mac, or Shift + Ctrl + P on Windows/Linux) and open C/C++: Edit Configurations (UI).
    • Open "Advanced Settings" and paste the full path to compile_commands.json into the "Compile Commands" textbox.

Running tests

There currently isn't a turnkey way to run all the Triton tests, but you can use the following recipe.

# One-time setup.  Note this will reinstall local Triton because torch
# overwrites it with the public version.
$ make dev-install

# To run all tests (requires a GPU)
$ make test

# Or, to run tests without a gpu
$ make test-nogpu

Tips for hacking

For detailed instructions on how to debug Triton's frontend, please refer to this tutorial. The following includes additional tips for hacking on Triton's backend.

Configuration knobs

See python/triton/knobs.py for the full list of configuration knobs. You can set those knobs directly in Python or control them via environment variables. Below are some of the environment variables you can specify (see knobs.py for the full list); a short example of setting them from Python appears after the list:

  • MLIR_ENABLE_DUMP=1 dumps the IR before every MLIR pass Triton runs, for all kernels. Use MLIR_ENABLE_DUMP=kernelName to dump for a specific kernel only.

    • Triton cache can interfere with the dump. In cases where MLIR_ENABLE_DUMP=1 does not work, try cleaning your triton cache: rm -r ~/.triton/cache/*
  • MLIR_DUMP_PATH specifies where MLIR_ENABLE_DUMP dumps to. If unset, it dumps to stderr.

  • LLVM_IR_ENABLE_DUMP=1 dumps the IR before every pass run over the LLVM IR.

  • TRITON_REPRODUCER_PATH=<reproducer_path> will generate an MLIR reproducer file at <reproducer_path> before each MLIR compiler stage. If any of the stages fail, <reproducer_path> will be a local MLIR reproducer captured right before the failing pass.

  • TRITON_INTERPRET=1 uses the Triton interpreter instead of running on the GPU. You can insert Python breakpoints in your kernel code!

  • TRITON_ENABLE_LLVM_DEBUG=1 passes -debug to LLVM, printing a lot of debugging information to stdout. If this is too noisy, run with just TRITON_LLVM_DEBUG_ONLY instead to limit the output.

    An alternative way to reduce output noise is to run with LLVM_IR_ENABLE_DUMP=1, extract the IR before the LLVM pass of interest, and then run LLVM's opt standalone, perhaps passing -debug-only=foo on the command line.

  • TRITON_LLVM_DEBUG_ONLY=<comma-separated> is the equivalent of LLVM's -debug-only command-line option. This limits the LLVM debug output to specific pass or component names (which are specified using #define DEBUG_TYPE throughout LLVM and Triton), making the output less noisy. TRITON_LLVM_DEBUG_ONLY accepts one or more comma-separated values (e.g. TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions" or TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc").

  • TRITON_ENABLE_ASAN=1 invokes the LLVM address sanitizer for memory leak and out of bounds access detection. Currently only supported on the AMD backend. This must be run using the ASAN libraries documented here.

    When enabling the address sanitizer it is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch. This will give the address sanitizer the best chance at finding the memory fault where it originates. See this test for more details.

  • USE_IR_LOC={ttir,ttgir} reparses the IR such that the location information will be the line number of the IR file with that particular extension, instead of line number of the python file. This can provide a direct mapping from the IR to llir/ptx. When used with performance tools, it can provide a breakdown on IR instructions.

  • TRITON_PRINT_AUTOTUNING=1 prints out the best autotuning config and total time spent for each kernel after autotuning is complete.

  • DISABLE_LLVM_OPT disables LLVM optimizations for make_llir and make_ptx if its value parses as a true Bool. Otherwise, it is parsed as a list of LLVM optimization flags to disable. One use case is DISABLE_LLVM_OPT="disable-lsr": loop strength reduction is known to cause up to 10% performance changes for certain kernels with register pressure.

  • TRITON_ALWAYS_COMPILE=1 forces kernels to be compiled regardless of cache hits.

  • MLIR_ENABLE_TIMING dumps the timing information for each MLIR pass.

  • LLVM_ENABLE_TIMING dumps the timing information for each LLVM pass.

  • TRITON_DEFAULT_FP_FUSION overrides the default behavior of allowing fp fusion (mul+add->fma).

  • MLIR_ENABLE_DIAGNOSTICS=<comma-separated> controls diagnostic emission in MLIR. Options are: warnings, remarks, stacktraces, operations. Use comma-separated values to customize output. For example, MLIR_ENABLE_DIAGNOSTICS=remarks,operations enables remarks and IR operations, while MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces enables warnings with stacktraces. By default, only errors are shown. Setting warnings includes errors and warnings; remarks includes errors, warnings, and remarks.
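
Since all of the variables above are plain environment variables, they can also be set from Python via os.environ before importing Triton, which is sometimes handier than exporting them in the shell. A minimal sketch using two of the knobs above:

import os

# Equivalent to exporting these in your shell; set them before importing
# triton so they are visible when the compiler reads its knobs.
os.environ["TRITON_ALWAYS_COMPILE"] = "1"  # recompile regardless of cache hits
os.environ["MLIR_ENABLE_DUMP"] = "1"       # dump IR before every MLIR pass

import triton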
