
JAX Scalify: end-to-end scaled arithmetic.


JAX Scalify is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works take ad-hoc approaches focused on scaling matrix multiplications (and sometimes reduction operations). Scalify adopts a more systematic approach based on end-to-end scale propagation, i.e. transforming the full computational graph into a ScaledArray graph in which every operation takes ScaledArray inputs and returns ScaledArray outputs:

from dataclasses import dataclass

from jax import Array


@dataclass
class ScaledArray:
    # Main data component, in low precision.
    data: Array
    # Scale, usually scalar, in FP32 or E8M0.
    scale: Array

    def __array__(self) -> Array:
        # Tensor represented by the `ScaledArray`: data * scale.
        return self.data * self.scale.astype(self.data.dtype)
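Because the represented tensor is data * scale, bilinear operations propagate scales cheaply: a product or matrix multiplication can run on the low-precision data alone while the scalar scales are simply multiplied. The NumPy sketch below illustrates this rule. ScaledArraySketch, scaled_mul and scaled_matmul are hypothetical names, not part of the jax_scalify API, and FP16 stands in for the low-precision format.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ScaledArraySketch:
    # Hypothetical NumPy stand-in for ScaledArray (not the jax_scalify class).
    data: np.ndarray   # low-precision payload
    scale: np.float32  # FP32 scalar scale

    def to_array(self) -> np.ndarray:
        # Represented tensor: data * scale, computed in FP32.
        return self.data.astype(np.float32) * np.float32(self.scale)


def scaled_mul(a: ScaledArraySketch, b: ScaledArraySketch) -> ScaledArraySketch:
    # (a.data * a.scale) * (b.data * b.scale) = (a.data * b.data) * (a.scale * b.scale)
    return ScaledArraySketch(a.data * b.data, np.float32(a.scale * b.scale))


def scaled_matmul(a: ScaledArraySketch, b: ScaledArraySketch) -> ScaledArraySketch:
    # Matmul is bilinear, so the scalar scales factor out of the product.
    return ScaledArraySketch(a.data @ b.data, np.float32(a.scale * b.scale))


# Usage: FP16 payloads with very different scales; only the scales are combined.
rng = np.random.default_rng(0)
a = ScaledArraySketch(rng.standard_normal((4, 8)).astype(np.float16), np.float32(2.0**-10))
b = ScaledArraySketch(rng.standard_normal((8, 3)).astype(np.float16), np.float32(2.0**12))
c = scaled_matmul(a, b)  # c.scale == 2**-10 * 2**12 == 4.0
```

The data components never see the extreme magnitudes 2**-10 or 2**12, which is what lets the low-precision arithmetic stay within its representable range.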

The main benefits of the Scalify approach are:

  • Agnostic to the neural-net model definition;
  • Decoupling of scaling from low precision, reducing the computational overhead of dynamic rescaling;
  • FP8 matrix multiplications and reductions as simple as a cast;
  • Out-of-the-box support for FP16 (scaled) master weights and optimizer state;
  • Composable with the JAX ecosystem: Flax, Optax, ...
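Decoupling the scale from the low-precision data can be illustrated in plain NumPy. This is a sketch, not the jax_scalify implementation; the input values and the power-of-two scale 2**-27 are arbitrary illustrative choices. Values that underflow when cast directly to FP16 survive when only the ratio x / scale is stored in FP16 and the scale itself is kept in FP32:

```python
import numpy as np

# Small FP32 values, e.g. typical of some gradients (illustrative data only).
x = np.full(4, 1e-8, dtype=np.float32)

# Direct cast: 1e-8 is below half the smallest FP16 subnormal (~6e-8), so it rounds to zero.
naive = x.astype(np.float16)

# Scaled cast: keep an FP32 power-of-two scale, store only x / scale in FP16.
scale = np.float32(2.0**-27)
data = (x / scale).astype(np.float16)  # ~1.34, comfortably inside the FP16 range

# Reconstruct the represented tensor in FP32.
recovered = data.astype(np.float32) * scale
```

Here naive is all zeros, while recovered matches x up to FP16 rounding (relative error below 2**-11).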

Scalify training loop example

A typical JAX training loop requires only a couple of modifications to take advantage of Scalify. More specifically:

  • Represent input and state as ScaledArray using the as_scaled_array method (or variations of it);
  • Propagate scales end-to-end in the training update function using the scalify decorator;
  • (Optionally) add dynamic_rescale calls to improve low-precision accuracy and stability.
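The first and third steps can be sketched in plain NumPy. The helpers below are hypothetical stand-ins, not the jax_scalify implementations of as_scaled_array or dynamic_rescale_l2: wrapping assumes a unit scale, and the rescaling assumes normalization by the root-mean-square of the data (the exact normalization used by Scalify may differ). In both cases the represented value data * scale is unchanged; only the split between data and scale moves.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Scaled:
    # Hypothetical stand-in for ScaledArray: represented value = data * scale.
    data: np.ndarray
    scale: np.float32


def as_scaled_sketch(x) -> Scaled:
    # Hypothetical: wrap a plain array with a unit scale (value unchanged).
    return Scaled(np.asarray(x, dtype=np.float32), np.float32(1.0))


def dynamic_rescale_rms_sketch(x: Scaled) -> Scaled:
    # Hypothetical L2-style rescaling: move the RMS of `data` into `scale`,
    # so the stored data has unit RMS while data * scale is preserved.
    rms = np.float32(np.sqrt(np.mean(np.square(x.data))))
    if rms == 0:
        # All-zero data: nothing to rescale.
        return x
    return Scaled(x.data / rms, np.float32(x.scale * rms))


# Usage: small weights end up with O(1) data and the magnitude in the scale.
w = as_scaled_sketch([3e-4, -4e-4, 0.0, 5e-4])
w_rescaled = dynamic_rescale_rms_sketch(w)
```

Moving the magnitude into the FP32 scale keeps the data component near unit range, where a subsequent low-precision cast loses the least information. A production version might also round the scale to a power of two so that rescaling is exact.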

The following (simplified) example shows how Scalify can be incorporated into a JAX training loop.

import jax
import jax_scalify as jsa

# `model`, `optimizer` and `dataset` are placeholders for user-defined objects.

# Scalify transform on FWD + BWD + optimizer.
# Propagating scale in the computational graph.
@jsa.scalify
def update(state, data, labels):
    # Forward and backward pass on the NN model (`model` returns the loss).
    loss, grads = jax.value_and_grad(model)(state, data, labels)
    # Optimizer applied on scaled state.
    state = optimizer.apply(state, grads)
    return loss, state

# Model + optimizer state.
state = (model.init(...), optimizer.init(...))
# Transform state to scaled array(s).
sc_state = jsa.as_scaled_array(state)

for (data, labels) in dataset:
    # If necessary (e.g. images), scale input data.
    data = jsa.as_scaled_array(data)
    # State update, with full scale propagation.
    loss, sc_state = update(sc_state, data, labels)
    # Optional dynamic rescaling of state.
    sc_state = jsa.ops.dynamic_rescale_l2(sc_state)

As presented in the code above, the model state is represented as a JAX PyTree of ScaledArray, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
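Because the state is a PyTree of scaled leaves, an optimizer update reduces to per-leaf scaled operations. The sketch below is a hypothetical plain-NumPy SGD step over a dict-based tree (a stand-in for jax.tree_util mapping). It assumes a subtraction rule that re-expresses both operands in the larger of the two scales; this is one simple choice, not necessarily the rule jax_scalify uses. Multiplying by the learning rate only touches the FP32 gradient scale, never the low-precision data.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Scaled:
    # Hypothetical stand-in for ScaledArray: represented value = data * scale.
    data: np.ndarray
    scale: np.float32

    def value(self) -> np.ndarray:
        return self.data.astype(np.float32) * self.scale


def scaled_sub(a: Scaled, b: Scaled) -> Scaled:
    # Hypothetical rule: express both operands in the larger scale, then subtract data.
    out_scale = np.float32(max(a.scale, b.scale))
    data = a.data * (a.scale / out_scale) - b.data * (b.scale / out_scale)
    return Scaled(data.astype(a.data.dtype), out_scale)


def scaled_sgd_step(params: dict, grads: dict, lr: float) -> dict:
    # Per-leaf update w - lr * g; lr is folded into the gradient scale only.
    return {
        k: scaled_sub(params[k], Scaled(grads[k].data, np.float32(grads[k].scale * lr)))
        for k in params
    }


# Usage: FP16 master weights and gradients, each with its own FP32 scale.
params = {"w": Scaled(np.array([1.0, 2.0], dtype=np.float16), np.float32(0.5))}
grads = {"w": Scaled(np.array([0.5, -1.0], dtype=np.float16), np.float32(2.0**-4))}
new_params = scaled_sgd_step(params, grads, lr=0.25)
```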

A full collection of examples is available:

Installation

JAX Scalify can be installed directly from the GitHub repository into a Python virtual environment:

pip install git+https://github.com/graphcore-research/jax-scalify.git@main

Alternatively, for a local development setup:

git clone git@github.com:graphcore-research/jax-scalify.git
cd jax-scalify
pip install -e ./

The main dependencies are the numpy, jax and chex libraries.

Development

Running pre-commit and pytest on the JAX Scalify repository:

pip install pre-commit
pre-commit run --all-files
pytest -v ./tests

A Python wheel can be built with the usual command: python -m build.

Graphcore IPU support

JAX Scalify v0.1 is compatible with experimental JAX on IPU, which can be installed in a Graphcore Poplar Python environment:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Common JAX libraries can be installed in versions compatible with IPU:

pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8


Download files


Source Distribution

jax_scalify-0.1.tar.gz (73.9 kB)

Uploaded Source

Built Distribution


jax_scalify-0.1-py3-none-any.whl (39.7 kB)

Uploaded Python 3

File details

Details for the file jax_scalify-0.1.tar.gz.

File metadata

  • Download URL: jax_scalify-0.1.tar.gz
  • Upload date:
  • Size: 73.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jax_scalify-0.1.tar.gz
Algorithm Hash digest
SHA256 09ee5427d8119b1472f59ec2e9a67bdf810b8543eb170e64c2c4edfc662e60cb
MD5 8fd54490d7f2b6496625ee04e2cf21b6
BLAKE2b-256 9f6fec21e5008c5c19257a3b5b386197f316fb527322e51d9fb6eb6117c7df91


File details

Details for the file jax_scalify-0.1-py3-none-any.whl.

File metadata

  • Download URL: jax_scalify-0.1-py3-none-any.whl
  • Upload date:
  • Size: 39.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jax_scalify-0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f3439626c317384230b1ff094d498bd93f5cb86a2d74fed938f93e7324ff81a3
MD5 8f95033e38aaa096018c2087a5c9e19a
BLAKE2b-256 329b9ee1be7cf43e23a3990b6a02f5926da7b9cbbc8de8b38c02ea5933802c9a

