
Mosaic

Multi-Axis Attention Sharding for PyTorch

Distribute attention computation across GPUs when your sequence is too long to fit on one device.

The Problem

Standard attention has O(n²) memory complexity. A 150,000-token sequence needs ~84 GiB just for the fp32 attention weights:

Memory = n² × 4 bytes = 150,000² × 4 = 9 × 10¹⁰ bytes ≈ 84 GiB

Mosaic splits the sequence across GPUs and coordinates communication so each GPU only holds a fraction.

How It Works

Attention Refresher

For queries Q, keys K, values V with sequence length n:

Attention(Q, K, V) = softmax(QKᵀ / √d) × V

The bottleneck is QKᵀ — an (n × n) matrix.
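For reference, the formula maps to a few lines of PyTorch. This is a plain, unsharded reference implementation (illustrative only, not Mosaic's internal code):

import torch

def attention(q, k, v):
    # q, k, v: (n, d)
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # the (n × n) QKᵀ bottleneck
    return torch.softmax(scores, dim=-1) @ v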

Sharding Strategies

Mosaic supports three backends:

1. Local (No Sharding)

Each GPU holds the full sequence. Use this when the attended axis is small enough to fit on a single device.

GPU 0: [Q₀, K₀, V₀] → Attention → [Out₀]

2. Ring Attention (1D Sharding)

Split sequence across GPUs. Each GPU holds Q_local but needs all K, V.

Solution: Pass K, V chunks around in a ring while accumulating partial attention.

Step 0:  GPU₀ has (Q₀, K₀, V₀)    GPU₁ has (Q₁, K₁, V₁)
         ↓ compute Q₀K₀ᵀ          ↓ compute Q₁K₁ᵀ
         
Step 1:  GPU₀ receives K₁, V₁    GPU₁ receives K₀, V₀
         ↓ compute Q₀K₁ᵀ          ↓ compute Q₁K₀ᵀ
         
Final:   Each GPU has full attention output for its Q chunk

Memory per GPU: O(n²/p) where p = number of GPUs

Communication: Each GPU sends/receives an (n/p × d) K and V chunk per step, over p-1 steps total
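The accumulation can be sketched in a single process: each loop iteration stands in for one ring step, merging one K/V chunk into a running softmax numerator and denominator so the full (n × n) score matrix is never materialized. This is illustrative only; the real backend overlaps these steps with NCCL send/recv.

import torch

def ring_attention_reference(q, k, v, num_chunks):
    # q: (n_q, d); k, v: (n, d); process K/V in chunks, one per "ring step"
    d = q.shape[-1]
    out = torch.zeros_like(q)
    denom = torch.zeros(q.shape[0], 1)
    running_max = torch.full((q.shape[0], 1), float("-inf"))
    for k_c, v_c in zip(k.chunk(num_chunks), v.chunk(num_chunks)):
        scores = q @ k_c.T / d ** 0.5                                  # partial (n_q × n/p) scores
        new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
        scale = torch.exp(running_max - new_max)                       # rescale old accumulators
        p = torch.exp(scores - new_max)
        out = out * scale + p @ v_c
        denom = denom * scale + p.sum(dim=-1, keepdim=True)
        running_max = new_max
    return out / denom

# Matches full attention on random inputs:
q, k, v = torch.randn(8, 16), torch.randn(32, 16), torch.randn(32, 16)
full = torch.softmax(q @ k.T / 16 ** 0.5, dim=-1) @ v
assert torch.allclose(ring_attention_reference(q, k, v, num_chunks=4), full, atol=1e-5)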

3. Mesh2D Attention (2D Sharding)

For very large sequences, shard both Q and K:

         K₀      K₁
       ┌──────┬──────┐
    Q₀ │GPU 0 │GPU 1 │  ← Each GPU computes one tile of QKᵀ
       ├──────┼──────┤
    Q₁ │GPU 2 │GPU 3 │
       └──────┴──────┘

Memory per GPU: O(n²/p²), where p is the number of GPUs along each mesh axis (p² GPUs total)

Trade-off: More communication (all-gather K, V along columns)
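A rough sketch of the tiling, assuming a p_r × p_c grid of ranks laid out row-major (this mirrors the diagram above, not Mosaic's actual internals; combining tiles into the final output still needs a softmax reduction along each row of the grid, which is where the extra communication comes from):

import torch

def tile_for_rank(rank, q, k, v, mesh=(2, 2)):
    p_r, p_c = mesh
    row, col = divmod(rank, p_c)              # position in the GPU grid
    q_shard = q.chunk(p_r, dim=0)[row]        # this rank's slice of Q
    k_shard = k.chunk(p_c, dim=0)[col]        # this rank's slice of K
    v_shard = v.chunk(p_c, dim=0)[col]
    # each rank holds only an (n/p_r × n/p_c) tile of QKᵀ
    scores = q_shard @ k_shard.T / q.shape[-1] ** 0.5
    return scores, v_shard

scores, v_shard = tile_for_rank(rank=1, q=torch.randn(64, 32), k=torch.randn(64, 32), v=torch.randn(64, 32))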

Installation

pip install mosaic-attention

# Or from source with ring attention
git clone https://github.com/stprnvsh/mosaic.git
cd mosaic
pip install -e ".[ring]"

Quick Start

import torch
import torch.nn as nn
import mosaic

# Initialize: 4 GPUs for sequence parallelism
ctx = mosaic.init(sp_size=4)

# Attention over axis 1, sharded across GPUs
attn = mosaic.MultiAxisAttention(
    embed_dim=128,
    num_heads=8,
    attention_axis=1,    # Which axis to attend over
    backend="ring"       # Use ring attention
)

# Input: (batch, sequence, features) where sequence is sharded
# Each GPU has (batch, seq_local, features) where seq_local = seq_total / 4
x_local = torch.randn(2, 37500, 128).cuda()  # 150k / 4 = 37.5k per GPU

out = attn(x_local)  # Ring communication happens automatically

Launch:

torchrun --nproc_per_node=4 train.py

Multi-Axis Models

Models like nanoTabPFN have tensors with shape (batch, rows, features, embed) and need attention over multiple axes:

Axis                Size        Strategy
Features (axis 2)   ~5          Local (small, fits on GPU)
Rows (axis 1)       ~150,000    Ring (too large, must shard)

class TabularTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Small axis: local attention
        self.feature_attn = mosaic.MultiAxisAttention(
            embed_dim=96, num_heads=4,
            attention_axis=2,  # features
            backend="local"
        )
        # Large axis: ring attention
        self.row_attn = mosaic.MultiAxisAttention(
            embed_dim=96, num_heads=4,
            attention_axis=1,  # rows
            backend="ring"
        )
    
    def forward(self, x):
        # x: (batch, rows_local, features, embed)
        x = self.feature_attn(x) + x  # No communication
        x = self.row_attn(x) + x      # Ring across GPUs
        return x

API Reference

Core

mosaic.init(sp_size=1) → MosaicContext

Initialize distributed context.

  • sp_size: Number of GPUs for sequence parallelism

mosaic.MultiAxisAttention(embed_dim, num_heads, attention_axis, backend, mesh_shape=None)

Attention over any single tensor axis.

Parameter        Description
embed_dim        Hidden dimension (must be divisible by num_heads)
num_heads        Number of attention heads
attention_axis   Which axis to attend over (supports negative indexing)
backend          "local", "ring", or "mesh2d"
mesh_shape       Required for mesh2d: (rows, cols) GPU grid

Advanced

mosaic.MultiAxisAttention2D(embed_dim, num_heads, axis1, axis2, mesh_shape)

Attend over two axes simultaneously (flattens them internally).

# Attention over both rows and columns
attn = mosaic.MultiAxisAttention2D(
    embed_dim=128, num_heads=8,
    axis1=1, axis2=2,      # rows × columns
    mesh_shape=(2, 2)      # 4 GPUs in 2×2 grid
)
# Input: (batch, rows, cols, embed)
# Internally: flatten to (batch, rows*cols, embed), run mesh2d attention

mosaic.ComposedAttention(mesh_shape, head_parallel, seq_parallel)

Combine head parallelism with sequence parallelism.

# 8 GPUs: 2-way head parallel × 4-way sequence parallel
composed = mosaic.ComposedAttention(
    mesh_shape=(2, 4),
    head_parallel=True,    # Split heads across dim 0 (2 ways)
    seq_parallel="ring"    # Ring attention across dim 1 (4 ways)
)

Memory: Heads sharded 2×, sequence sharded 4× → 8× memory reduction

mosaic.HierarchicalAttention(intra_node_size, inter_node_strategy, intra_node_strategy)

Two-level parallelism for multi-node clusters.

# 4 nodes × 8 GPUs = 32 GPUs
hier = mosaic.HierarchicalAttention(
    intra_node_size=8,           # GPUs per node
    intra_node_strategy="local", # Fast NVLink within node
    inter_node_strategy="ring"   # Slower network between nodes
)

mosaic.HierarchicalMetaAttention (NEW in v0.3)

Learned routing via chunk summaries. Instead of communicating with all chunks (O(p) in ring), meta-attention learns which chunks are relevant and only communicates with top-k.

attn = mosaic.HierarchicalMetaAttention(
    embed_dim=256,
    num_heads=8,
    summary_dim=64,   # Compress each chunk to this dim
    top_k=2,          # Only attend to top-k relevant chunks
    use_checkpoint=True
)

out, meta = attn(x_local)
print(meta['routing_weights'])  # Shows which chunks are relevant
# e.g., tensor([[0.05, 0.40, 0.15, 0.40]]) → chunks 1 and 3 are most relevant

How it works:

  1. Local attention within each GPU's chunk
  2. Compress chunks into summaries (content + attention pattern stats)
  3. all_gather summaries (small: only summary_dim per GPU)
  4. Meta-attention computes routing weights
  5. Soft-weighted cross-attention using routing weights
  6. Gated combination of local + cross outputs
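Steps 2-4 can be sketched roughly as follows (single process; the names, shapes, and the random projection standing in for the learned summary/routing layers are all illustrative, not Mosaic's internal API):

import torch

def route_chunks(local_x, all_summaries, top_k=2):
    # local_x: (seq_local, embed); all_summaries: (num_chunks, summary_dim), one per GPU
    query_summary = local_x.mean(dim=0)                                 # crude summary of the local chunk
    proj = torch.randn(all_summaries.shape[-1], local_x.shape[-1])      # stand-in for a learned projection
    scores = all_summaries @ (proj @ query_summary)                     # score every remote chunk
    weights = torch.softmax(scores, dim=0)                              # routing weights over chunks
    top_vals, top_idx = weights.topk(top_k)                             # only cross-attend to these chunks
    return weights, top_idx

x_local = torch.randn(1024, 256)
summaries = torch.randn(4, 64)    # gathered summaries, one per GPU chunk
weights, selected = route_chunks(x_local, summaries)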

Benefits:

  • Communication: O(1) gather + O(top_k) vs O(p) ring rounds
  • Interpretability: routing_weights tells you which regions interact
  • Sparse attention: Learns which long-range dependencies matter

Use cases: Genomics (gene regulation), long documents (topic routing), time series (seasonal patterns)

mosaic.MultiAxisHierarchicalAttention

Convenience wrapper with axis permutation (same API as MultiAxisAttention).

attn = mosaic.MultiAxisHierarchicalAttention(
    embed_dim=256, num_heads=8,
    attention_axis=1,  # Attend over this axis
    summary_dim=64,
    top_k=2
)
out, meta = attn(x)  # Works with any N-dimensional tensor

Performance

All backends use FlashAttention (F.scaled_dot_product_attention) for the local computation:

  • Fused GEMM + softmax + GEMM
  • O(n) memory instead of O(n²) for attention weights
  • 2-4× faster than a naive implementation
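For reference, the per-GPU local computation boils down to a single fused call. The sketch below assumes a CUDA device so the FlashAttention kernel is actually selected (it falls back to a slower path on CPU):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(2, 8, 1024, 64, device=device)   # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 1024, 64, device=device)
v = torch.randn(2, 8, 1024, 64, device=device)
out = F.scaled_dot_product_attention(q, k, v)    # fused GEMM + softmax + GEMM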

Communication uses NCCL collectives:

  • Ring: send/recv in ring topology
  • Mesh2D: all_gather along grid dimensions

When to Use What

Sequence Length   GPUs   Backend                  Memory per GPU
< 10k             1      local                    O(n²)
10k - 100k        2-8    ring                     O(n²/p)
100k - 1M         8-64   ring or mesh2d           O(n²/p) or O(n²/p²)
> 1M              64+    mesh2d + head_parallel   O(n²/(p²·h))
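As a rule of thumb, the table can be encoded in a small helper (hypothetical; pick_backend and the exact cutoffs are illustrative, not part of Mosaic's API):

def pick_backend(seq_len, num_gpus):
    # rough encoding of the table above
    if num_gpus == 1 or seq_len < 10_000:
        return "local"
    if seq_len <= 100_000:
        return "ring"
    if seq_len <= 1_000_000:
        return "ring" if num_gpus <= 8 else "mesh2d"
    return "mesh2d"   # at 64+ GPUs, combine with head parallelism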

Distributed Launch

Single Node

# 4 GPUs on one machine
torchrun --nproc_per_node=4 train.py

Multi-Node

# Node 0 (master) - replace IP with your master node
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
    --master_addr=192.168.1.100 --master_port=29500 train.py

# Node 1
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \
    --master_addr=192.168.1.100 --master_port=29500 train.py

SLURM

#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8

srun torchrun \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=8 \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1):29500 \
    train.py

Training Script Setup

import os
import torch
import torch.distributed as dist
import mosaic

def main():
    # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE automatically
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    
    # Initialize Mosaic with full world for sequence parallelism
    ctx = mosaic.init(sp_size=dist.get_world_size())
    
    model = MyModel().to(ctx.device)
    
    # Each GPU loads its shard of the sequence
    total_seq = 150000
    local_seq = total_seq // dist.get_world_size()
    x_local = load_my_shard(local_seq)  # Your data loading
    
    # Forward pass - ring communication automatic
    out = model(x_local)
    
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Multi-Node Mesh Configuration

# 2 nodes × 4 GPUs = 8 total
# Option 1: Head parallel across nodes (slow), seq parallel within (fast)
composed = mosaic.ComposedAttention(
    mesh_shape=(2, 4),      # (nodes, gpus_per_node)
    head_parallel=True,     # Across nodes
    seq_parallel="ring"     # Within node
)

# Option 2: Explicit hierarchical control
hier = mosaic.HierarchicalAttention(
    intra_node_size=4,           # GPUs per node
    intra_node_strategy="local", # No comm within node
    inter_node_strategy="ring"   # Ring between node leaders
)

Architecture

┌─────────────────────────────────────────┐
│              User Model                 │
│  (defines attention_axis per layer)     │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│               Mosaic                    │
│  • Axis routing (permute to seq dim)    │
│  • Backend selection                    │
│  • Tensor reshape for QKV projection    │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│             Backends                    │
│  • local: F.scaled_dot_product_attention│
│  • ring: ring_flash_attn_func           │
│  • mesh2d: all_gather + SDPA            │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│          PyTorch / NCCL                 │
└─────────────────────────────────────────┘

License

MIT
