A high-level self-supervised learning library on top of JAX

Heron

Heron is a high-level self-supervised learning (SSL) library built on JAX, Flax, and Optax. Its abstractions simplify training and experimenting with SSL algorithms, so you can focus on the model and data rather than the boilerplate.

Heron aims to be modular and extensible, with clear separation between models, loss functions, data augmentations, and training strategies.

Core Features

  • High-Level Trainer API: A simple `Trainer` class that abstracts away the complexities of JAX's functional paradigm, including `jit`, `pmap`, and state management.
  • Modular Design: Easily mix and match backbones, heads, loss functions, and augmentation pipelines.
  • SSL Strategies: Pre-packaged implementations of popular SSL algorithms.
  • Performance: Built on JAX and Flax to leverage hardware acceleration on GPUs and TPUs.

Feature Roadmap

-   [x] Establish core abstractions (`Trainer`, `Backbone`, `ProjectionHead`).
-   [x] Implement **SimCLR** as the first end-to-end strategy.
-   [ ] Implement a robust data augmentation pipeline for contrastive learning.
-   [ ] Initial PyPI release.
-   [ ] Implement **BYOL** and **SimSiam** (non-contrastive methods).
-   [ ] Add logic for momentum encoders (teacher-student models).
-   [ ] Refine `TrainState` management and checkpointing.
-   [ ] Implement **DINO** and **MoCo v3**.
-   [ ] Add Vision Transformer (ViT) backbones.
-   [ ] Implement teacher-student centering and sharpening.
-   [ ] Masked Image Modeling strategies (e.g., **MAE**).
-   [ ] Integration with Hugging Face models and datasets.
-   [ ] Comprehensive documentation and tutorials.
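The momentum-encoder item on the roadmap refers to a teacher network whose weights track an exponential moving average (EMA) of the student's weights, as used by BYOL, MoCo, and DINO. The roadmap lists this as not yet implemented, so the sketch below is only an illustration of the general technique; `ema_update` and the parameter layout are hypothetical names, not part of Heron.

```python
import jax
import jax.numpy as jnp


def ema_update(teacher_params, student_params, momentum=0.99):
    # Blend each teacher leaf toward the matching student leaf.
    return jax.tree_util.tree_map(
        lambda t, s: momentum * t + (1.0 - momentum) * s,
        teacher_params,
        student_params,
    )


# Hypothetical parameter pytrees with identical structure.
student = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
teacher = jax.tree_util.tree_map(jnp.zeros_like, student)

# One EMA step moves the teacher a fraction (1 - momentum) toward the student.
teacher = ema_update(teacher, student, momentum=0.9)
```

Because the update is a pure function over pytrees, it can be called inside a jitted training step without special handling.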

Installation

pip install heron-ssl

Quick Start

Here is a conceptual example of how to use the `Trainer` API to pretrain a SimCLR model on CIFAR-10.

import heron_ssl as ssl
import tensorflow_datasets as tfds
import optax

# 1. Load a dataset
dataset = tfds.load('cifar10', split='train')

# 2. Define the model and SSL strategy
strategy = ssl.strategies.SimCLR(
    backbone=ssl.models.ResNet50(),
    projector=ssl.models.ProjectionHead(hidden_dims=[2048], output_dim=128),
)

# 3. Configure the trainer
trainer = ssl.Trainer(
    strategy=strategy,
    optimizer=optax.adam(1e-3),
)

# 4. Start training
trained_backbone_params = trainer.fit(dataset)
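For context on what a SimCLR strategy optimizes, the sketch below implements the NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss in plain JAX. It is written independently of Heron and makes no claim about how `ssl.strategies.SimCLR` computes its loss; the function name `nt_xent_loss`, the default temperature, and the toy inputs are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss for two batches of projections, each of shape (n, d)."""
    n = z1.shape[0]
    z = jnp.concatenate([z1, z2], axis=0)
    # L2-normalize so dot products are cosine similarities.
    z = z / jnp.linalg.norm(z, axis=1, keepdims=True)
    sim = (z @ z.T) / temperature
    # Exclude each sample's similarity with itself from the softmax.
    sim = jnp.where(jnp.eye(2 * n, dtype=bool), -1e9, sim)
    # The positive for row i is the other view of the same sample.
    targets = (jnp.arange(2 * n) + n) % (2 * n)
    log_probs = jax.nn.log_softmax(sim, axis=1)
    return -jnp.mean(log_probs[jnp.arange(2 * n), targets])


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
view_a = jax.random.normal(key_a, (8, 32))
unrelated = jax.random.normal(key_b, (8, 32))

# Perfectly aligned views should score better than unrelated pairs.
loss_aligned = nt_xent_loss(view_a, view_a)
loss_random = nt_xent_loss(view_a, unrelated)
```

In a real pipeline `z1` and `z2` would be projection-head outputs for two augmentations of the same images, rather than random vectors.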

Contributing

Contributions are welcome! Please see CONTRIBUTING.md for details on how to get started.

License

Heron is licensed under the MIT License; see the LICENSE file for details.
