Mosaic
Multi-Axis Attention Sharding for PyTorch
Distribute attention computation across GPUs when your sequence is too long to fit on one device.
The Problem
Standard attention has O(n²) memory complexity. A single attention matrix for a 150,000-token sequence needs ~84 GiB in fp32:
Memory = n² × 4 bytes = 150,000² × 4 ≈ 90 GB (≈ 84 GiB)
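The same estimate as a quick sanity check in Python:

n = 150_000
bytes_per_element = 4                      # fp32
print(n * n * bytes_per_element / 2**30)   # ≈ 83.8 GiB for one attention matrix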
Mosaic splits the sequence across GPUs and coordinates communication so each GPU only holds a fraction.
How It Works
Attention Refresher
For queries Q, keys K, values V with sequence length n:
Attention(Q, K, V) = softmax(QKᵀ / √d) × V
The bottleneck is QKᵀ — an (n × n) matrix.
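For reference, a minimal unsharded PyTorch implementation makes the bottleneck explicit; the scores tensor below is the full (n × n) matrix that will not fit on one device at long sequence lengths:

import torch

def naive_attention(q, k, v):
    """q, k, v: (batch, heads, n, d). Materializes the full (n, n) score matrix."""
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # (batch, heads, n, n) ← the bottleneck
    weights = torch.softmax(scores, dim=-1)
    return weights @ v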
Sharding Strategies
Mosaic supports three backends:
1. Local (No Sharding)
Each GPU holds the full sequence. Use this for axes small enough to fit on a single device.
GPU 0: [Q₀, K₀, V₀] → Attention → [Out₀]
2. Ring Attention (1D Sharding)
Split sequence across GPUs. Each GPU holds Q_local but needs all K, V.
Solution: Pass K, V chunks around in a ring while accumulating partial attention.
Step 0:  GPU₀ has (Q₀, K₀, V₀)        GPU₁ has (Q₁, K₁, V₁)
         ↓ compute Q₀K₀ᵀ              ↓ compute Q₁K₁ᵀ
Step 1:  GPU₀ receives K₁, V₁         GPU₁ receives K₀, V₀
         ↓ compute Q₀K₁ᵀ              ↓ compute Q₁K₀ᵀ
Final:   Each GPU has the full attention output for its Q chunk
Memory per GPU: O(n²/p) where p = number of GPUs
Communication: Each GPU sends and receives an (n/p × d) K/V chunk per step, for p−1 steps total
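Conceptually, one ring pass looks like the sketch below. This is an illustrative re-implementation with an online-softmax accumulator, not Mosaic's internal kernel; ring_shift is a hypothetical helper, and a real implementation fuses these steps FlashAttention-style:

import torch
import torch.distributed as dist

def ring_shift(t, rank, world):
    """Send a tensor to rank+1 and receive the corresponding one from rank-1."""
    recv = torch.empty_like(t)
    ops = [dist.P2POp(dist.isend, t.contiguous(), (rank + 1) % world),
           dist.P2POp(dist.irecv, recv, (rank - 1) % world)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv

def ring_attention(q, k, v):
    """q, k, v: (batch, heads, n_local, d) — each rank's shard of the sequence."""
    rank, world = dist.get_rank(), dist.get_world_size()
    scale = q.shape[-1] ** -0.5
    num = torch.zeros_like(q)                                   # running numerator
    den = torch.zeros(*q.shape[:-1], 1, device=q.device)        # running denominator
    row_max = torch.full((*q.shape[:-1], 1), float("-inf"), device=q.device)
    for step in range(world):
        scores = (q @ k.transpose(-2, -1)) * scale              # (b, h, n_local, n_local)
        new_max = torch.maximum(row_max, scores.amax(-1, keepdim=True))
        correction = torch.exp(row_max - new_max)               # rescale earlier partials
        p = torch.exp(scores - new_max)
        num = num * correction + p @ v
        den = den * correction + p.sum(-1, keepdim=True)
        row_max = new_max
        if step < world - 1:
            k, v = ring_shift(k, rank, world), ring_shift(v, rank, world)
    return num / den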
3. Mesh2D Attention (2D Sharding)
For very large sequences, shard both Q and K:
K₀ K₁
┌──────┬──────┐
Q₀ │GPU 0 │GPU 1 │ ← Each GPU computes one tile of QKᵀ
├──────┼──────┤
Q₁ │GPU 2 │GPU 3 │
└──────┴──────┘
Memory per GPU: O(n²/p²), where p is the number of GPUs along each mesh axis (p² GPUs in total)
Trade-off: More communication (all-gather K, V along columns)
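A minimal sketch of how ranks map to QKᵀ tiles in a 2×2 mesh; mesh_coords is a hypothetical helper for illustration, and combining the partial softmax results for a given Q block is handled by the backend:

def mesh_coords(rank, mesh_shape=(2, 2)):
    """Map a flat rank to its (q_block, k_block) tile in the mesh."""
    rows, cols = mesh_shape
    return divmod(rank, cols)

# With 4 ranks in a 2×2 grid:
#   rank 0 → (0, 0): Q₀ vs K₀    rank 1 → (0, 1): Q₀ vs K₁
#   rank 2 → (1, 0): Q₁ vs K₀    rank 3 → (1, 1): Q₁ vs K₁
# Each rank computes only its tile of QKᵀ; results for the same Q block are
# then combined across the ranks in that row.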
Installation
pip install mosaic-attention
# Or from source with ring attention
git clone https://github.com/stprnvsh/mosaic.git
cd mosaic
pip install -e ".[ring]"
Quick Start
import torch
import torch.nn as nn
import mosaic

# Initialize: 4 GPUs for sequence parallelism
ctx = mosaic.init(sp_size=4)

# Attention over axis 1, sharded across GPUs
attn = mosaic.MultiAxisAttention(
    embed_dim=128,
    num_heads=8,
    attention_axis=1,  # Which axis to attend over
    backend="ring"     # Use ring attention
)

# Input: (batch, sequence, features) where the sequence axis is sharded
# Each GPU holds (batch, seq_local, features) with seq_local = seq_total / 4
x_local = torch.randn(2, 37500, 128).cuda()  # 150k / 4 = 37.5k per GPU
out = attn(x_local)  # Ring communication happens automatically
Launch:
torchrun --nproc_per_node=4 train.py
Multi-Axis Models
Models like nanoTabPFN have tensors with shape (batch, rows, features, embed) and need attention over multiple axes:
| Axis | Dimension | Strategy |
|---|---|---|
| Features (axis 2) | ~5 | Local (small, fits on GPU) |
| Rows (axis 1) | ~150,000 | Ring (too large, must shard) |
class TabularTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Small axis: local attention
        self.feature_attn = mosaic.MultiAxisAttention(
            embed_dim=96, num_heads=4,
            attention_axis=2,  # features
            backend="local"
        )
        # Large axis: ring attention
        self.row_attn = mosaic.MultiAxisAttention(
            embed_dim=96, num_heads=4,
            attention_axis=1,  # rows
            backend="ring"
        )

    def forward(self, x):
        # x: (batch, rows_local, features, embed)
        x = self.feature_attn(x) + x  # No communication
        x = self.row_attn(x) + x      # Ring across GPUs
        return x
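A usage sketch under the sharding in the table above (4-way sequence parallelism over the rows axis; the shapes below are illustrative):

ctx = mosaic.init(sp_size=4)
model = TabularTransformer().cuda()

# 150,000 total rows / 4 GPUs = 37,500 local rows, 5 features, embed dim 96
x_local = torch.randn(2, 37500, 5, 96).cuda()
out = model(x_local)  # same shape back on each GPU: (2, 37500, 5, 96)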
API Reference
Core
mosaic.init(sp_size=1) → MosaicContext
Initialize distributed context.
sp_size: Number of GPUs for sequence parallelism
mosaic.MultiAxisAttention(embed_dim, num_heads, attention_axis, backend, mesh_shape=None)
Attention over any single tensor axis.
| Parameter | Description |
|---|---|
| `embed_dim` | Hidden dimension (must be divisible by `num_heads`) |
| `num_heads` | Number of attention heads |
| `attention_axis` | Which axis to attend over (supports negative indexing) |
| `backend` | `"local"`, `"ring"`, or `"mesh2d"` |
| `mesh_shape` | Required for `mesh2d`: `(rows, cols)` GPU grid |
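For example, the `mesh2d` backend needs an explicit GPU grid (the values below are illustrative):

# 4 GPUs arranged as a 2×2 mesh over the sequence
attn = mosaic.MultiAxisAttention(
    embed_dim=128, num_heads=8,
    attention_axis=1,
    backend="mesh2d",
    mesh_shape=(2, 2)
)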
Advanced
mosaic.MultiAxisAttention2D(embed_dim, num_heads, axis1, axis2, mesh_shape)
Attend over two axes simultaneously (flattens them internally).
# Attention over both rows and columns
attn = mosaic.MultiAxisAttention2D(
    embed_dim=128, num_heads=8,
    axis1=1, axis2=2,   # rows × columns
    mesh_shape=(2, 2)   # 4 GPUs in 2×2 grid
)
# Input: (batch, rows, cols, embed)
# Internally: flatten to (batch, rows*cols, embed), run mesh2d attention
mosaic.ComposedAttention(mesh_shape, head_parallel, seq_parallel)
Combine head parallelism with sequence parallelism.
# 8 GPUs: 2-way head parallel × 4-way sequence parallel
composed = mosaic.ComposedAttention(
    mesh_shape=(2, 4),
    head_parallel=True,   # Split heads across dim 0 (2 ways)
    seq_parallel="ring"   # Ring attention across dim 1 (4 ways)
)
Memory: Heads sharded 2×, sequence sharded 4× → 8× memory reduction
mosaic.HierarchicalAttention(intra_node_size, inter_node_strategy, intra_node_strategy)
Two-level parallelism for multi-node clusters.
# 4 nodes × 8 GPUs = 32 GPUs
hier = mosaic.HierarchicalAttention(
    intra_node_size=8,            # GPUs per node
    intra_node_strategy="local",  # Fast NVLink within node
    inter_node_strategy="ring"    # Slower network between nodes
)
mosaic.HierarchicalMetaAttention (NEW in v0.3)
Learned routing via chunk summaries. Instead of communicating with all chunks (O(p) in ring), meta-attention learns which chunks are relevant and only communicates with top-k.
attn = mosaic.HierarchicalMetaAttention(
    embed_dim=256,
    num_heads=8,
    summary_dim=64,      # Compress each chunk to this dim
    top_k=2,             # Only attend to top-k relevant chunks
    use_checkpoint=True
)
out, meta = attn(x_local)
print(meta['routing_weights'])  # Shows which chunks are relevant
# e.g., tensor([[0.05, 0.40, 0.15, 0.40]]) → chunks 1 and 3 are most relevant
How it works:
- Local attention within each GPU's chunk
- Compress chunks into summaries (content + attention pattern stats)
- `all_gather` summaries (small: only `summary_dim` per GPU)
- Meta-attention computes routing weights
- Soft-weighted cross-attention using routing weights
- Gated combination of local + cross outputs
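A minimal sketch of the routing step alone (an illustration of the idea, not Mosaic's internals; `route` and its arguments are hypothetical):

import torch
import torch.distributed as dist

def route(chunk_summary, query_summary, top_k=2):
    """chunk_summary, query_summary: (batch, summary_dim) for this GPU's chunk."""
    world = dist.get_world_size()
    # Gather every chunk's summary — cheap, only summary_dim floats per GPU
    gathered = [torch.empty_like(chunk_summary) for _ in range(world)]
    dist.all_gather(gathered, chunk_summary)
    summaries = torch.stack(gathered, dim=1)        # (batch, p, summary_dim)
    # Score each chunk against the local query summary
    scores = torch.einsum("bd,bpd->bp", query_summary, summaries)
    weights = torch.softmax(scores, dim=-1)         # routing weights over chunks
    # Keep only the top-k most relevant chunks for cross-attention
    topk = torch.topk(weights, k=top_k, dim=-1)
    return topk.indices, topk.values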
Benefits:
- Communication: O(1) gather + O(top_k) vs O(p) ring rounds
- Interpretability: `routing_weights` tells you which regions interact
- Sparse attention: Learns which long-range dependencies matter
Use cases: Genomics (gene regulation), long documents (topic routing), time series (seasonal patterns)
mosaic.MultiAxisHierarchicalAttention
Convenience wrapper with axis permutation (same API as MultiAxisAttention).
attn = mosaic.MultiAxisHierarchicalAttention(
    embed_dim=256, num_heads=8,
    attention_axis=1,  # Attend over this axis
    summary_dim=64,
    top_k=2
)
out, meta = attn(x)  # Works with any N-dimensional tensor
Performance
All backends use FlashAttention (F.scaled_dot_product_attention) for the local computation:
- Fused GEMM + softmax + GEMM
- O(n) memory instead of O(n²) for attention weights
- 2-4× faster than naive implementation
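The local path boils down to PyTorch's fused SDPA kernel, roughly:

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) — never materializes the (seq × seq) attention matrix
q = k = v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v)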
Communication uses NCCL collectives:
- Ring: `send`/`recv` in a ring topology
- Mesh2D: `all_gather` along grid dimensions
When to Use What
| Sequence Length | GPUs | Backend | Memory per GPU |
|---|---|---|---|
| < 10k | 1 | `local` | O(n²) |
| 10k - 100k | 2-8 | `ring` | O(n²/p) |
| 100k - 1M | 8-64 | `ring` or `mesh2d` | O(n²/p) or O(n²/p²) |
| > 1M | 64+ | `mesh2d` + head_parallel | O(n²/(p²·h)) |
Distributed Launch
Single Node
# 4 GPUs on one machine
torchrun --nproc_per_node=4 train.py
Multi-Node
# Node 0 (master) - replace IP with your master node
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
--master_addr=192.168.1.100 --master_port=29500 train.py
# Node 1
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \
--master_addr=192.168.1.100 --master_port=29500 train.py
SLURM
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1   # one torchrun launcher per node; torchrun spawns the 8 workers
#SBATCH --gpus-per-node=8
srun torchrun \
--nnodes=$SLURM_NNODES \
--nproc_per_node=8 \
--rdzv_id=$SLURM_JOB_ID \
--rdzv_backend=c10d \
--rdzv_endpoint=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1):29500 \
train.py
Training Script Setup
import os
import torch
import torch.distributed as dist
import mosaic

def main():
    # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE automatically
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Initialize Mosaic with the full world size for sequence parallelism
    ctx = mosaic.init(sp_size=dist.get_world_size())
    model = MyModel().to(ctx.device)

    # Each GPU loads its shard of the sequence
    total_seq = 150000
    local_seq = total_seq // dist.get_world_size()
    x_local = load_my_shard(local_seq)  # Your data loading

    # Forward pass - ring communication happens automatically
    out = model(x_local)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Multi-Node Mesh Configuration
# 2 nodes × 4 GPUs = 8 total

# Option 1: Head parallel across nodes (slow), seq parallel within (fast)
composed = mosaic.ComposedAttention(
    mesh_shape=(2, 4),    # (nodes, gpus_per_node)
    head_parallel=True,   # Across nodes
    seq_parallel="ring"   # Within node
)

# Option 2: Explicit hierarchical control
hier = mosaic.HierarchicalAttention(
    intra_node_size=4,            # GPUs per node
    intra_node_strategy="local",  # No comm within node
    inter_node_strategy="ring"    # Ring between node leaders
)
Architecture
┌─────────────────────────────────────────┐
│ User Model │
│ (defines attention_axis per layer) │
└─────────────────┬───────────────────────┘
│
┌─────────────────▼───────────────────────┐
│ Mosaic │
│ • Axis routing (permute to seq dim) │
│ • Backend selection │
│ • Tensor reshape for QKV projection │
└─────────────────┬───────────────────────┘
│
┌─────────────────▼───────────────────────┐
│ Backends │
│ • local: F.scaled_dot_product_attention│
│ • ring: ring_flash_attn_func │
│ • mesh2d: all_gather + SDPA │
└─────────────────┬───────────────────────┘
│
┌─────────────────▼───────────────────────┐
│ PyTorch / NCCL │
└─────────────────────────────────────────┘
License
MIT