A high-level self-supervised learning library on top of JAX
Heron is a high-level self-supervised learning (SSL) library built on JAX, Flax, and Optax. It provides abstractions that simplify training and experimenting with SSL algorithms, so you can focus on the model and data rather than the boilerplate.
Heron aims to be modular and extensible, with clear separation between models, loss functions, data augmentations, and training strategies.
Core Features
- High-Level Trainer API: A simple `Trainer` class that abstracts away the complexities of JAX's functional paradigm, including `jit`, `pmap`, and state management.
- Modular Design: Easily mix and match backbones, heads, loss functions, and augmentation pipelines.
- SSL Strategies: Pre-packaged implementations of popular SSL algorithms, starting with SimCLR.
- Performance: Built on JAX and Flax to leverage hardware acceleration on GPUs and TPUs.
Feature Roadmap
- [x] Establish core abstractions (`Trainer`, `Backbone`, `ProjectionHead`).
- [x] Implement **SimCLR** as the first end-to-end strategy.
- [ ] Implement a robust data augmentation pipeline for contrastive learning.
- [ ] Initial PyPI release.
- [ ] Implement **BYOL** and **SimSiam** (non-contrastive methods).
- [ ] Add logic for momentum encoders (teacher-student models).
- [ ] Refine `TrainState` management and checkpointing.
- [ ] Implement **DINO** and **MoCo v3**.
- [ ] Add Vision Transformer (ViT) backbones.
- [ ] Implement teacher-student centering and sharpening.
- [ ] Masked Image Modeling strategies (e.g., **MAE**).
- [ ] Integration with Hugging Face models and datasets.
- [ ] Comprehensive documentation and tutorials.
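Several roadmap items (BYOL, DINO, MoCo v3) rely on a momentum encoder: a teacher network whose weights track an exponential moving average (EMA) of the student's weights. A minimal sketch of that update over arbitrary parameter pytrees, using `jax.tree_util.tree_map`; `ema_update` and its default momentum of 0.99 are hypothetical and not part of Heron.

```python
import jax
import jax.numpy as jnp


def ema_update(teacher_params, student_params, momentum=0.99):
    """Hypothetical helper: teacher <- momentum * teacher + (1 - momentum) * student."""
    return jax.tree_util.tree_map(
        lambda t, s: momentum * t + (1.0 - momentum) * s,
        teacher_params,
        student_params,
    )


# Toy parameter pytrees with matching structure.
student = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
teacher = {"dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.ones(2)}}

teacher = ema_update(teacher, student, momentum=0.9)
# kernel leaves are now approximately 0.1 and bias leaves approximately 0.9
```

In a training loop the teacher receives no gradients: only the student is optimized, and the EMA update is applied after each optimizer step.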
Installation
```bash
pip install heron-ssl
```
Quick Start
Here is a conceptual example of how to use the Trainer API.
```python
import heron_ssl as ssl
import tensorflow_datasets as tfds
import optax

# 1. Load a dataset
dataset = tfds.load('cifar10', split='train')

# 2. Define the model and SSL strategy
strategy = ssl.strategies.SimCLR(
    backbone=ssl.models.ResNet50(),
    projector=ssl.models.ProjectionHead(hidden_dims=[2048], output_dim=128),
)

# 3. Configure and run the trainer
trainer = ssl.Trainer(
    strategy=strategy,
    optimizer=optax.adam(1e-3),
)

# 4. Start training
trained_backbone_params = trainer.fit(dataset)
```
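The `SimCLR` strategy in this example is trained with the NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper. A minimal JAX sketch, assuming `z1[i]` and `z2[i]` are projections of two augmented views of the same image; `nt_xent_loss` and its default temperature of 0.5 are illustrative assumptions, not Heron's implementation.

```python
import jax
import jax.numpy as jnp


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent over a batch where z1[i] and z2[i] are two views of the same example."""
    n = z1.shape[0]
    z = jnp.concatenate([z1, z2], axis=0)                   # (2N, D)
    z = z / jnp.linalg.norm(z, axis=1, keepdims=True)       # L2-normalize each row
    logits = (z @ z.T) / temperature                        # scaled cosine similarities
    # An example is never compared with itself: mask the diagonal out of the softmax.
    logits = jnp.where(jnp.eye(2 * n, dtype=bool), -jnp.inf, logits)
    # Row i's positive is its other view: i + N for the first half, i - N for the second.
    positives = jnp.concatenate([jnp.arange(n, 2 * n), jnp.arange(n)])
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(log_probs[jnp.arange(2 * n), positives])


# Toy batch of two orthogonal unit vectors.
views = jnp.eye(2)
aligned = nt_xent_loss(views, views, temperature=1.0)        # views paired correctly
swapped = nt_xent_loss(views, views[::-1], temperature=1.0)  # pairs deliberately mismatched
# aligned == log(2 + e) - 1 (about 0.551), swapped == log(2 + e) (about 1.551)
```

Masking the diagonal with `-inf` removes self-similarity from the softmax denominator, so each row's denominator covers its single positive plus the 2N - 2 negatives.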
Contributing
Contributions are welcome! Please see CONTRIBUTING.md for details on how to get started.
License
Heron is licensed under the MIT License (LICENSE).
Acknowledgements
- The logo is from SVG Repo.
Download files
File details
Details for the file heron_ssl-0.1.0a1.tar.gz.
File metadata
- Download URL: heron_ssl-0.1.0a1.tar.gz
- Upload date:
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.7.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a13e40fcee3adac4deb3c6ea9d0a165651139ce7aaf17dc0fb0dd659337534bb |
| MD5 | 34d13993de7bcf10c0370597b7af497e |
| BLAKE2b-256 | 3376e720e70cc26bd1c91d47480ac1579ec090488476e5052f3d323123724861 |
File details
Details for the file heron_ssl-0.1.0a1-py3-none-any.whl.
File metadata
- Download URL: heron_ssl-0.1.0a1-py3-none-any.whl
- Upload date:
- Size: 7.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.7.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b540864bbcdabdae8a49db8528b2463f925061c7159ec935a816a73bbed122ae |
| MD5 | fc7e0833796e5aa3e4679c365652c760 |
| BLAKE2b-256 | f535427bcf1f467a1fff8679c504b633d1c97516044c85ee31399bf578cfb037 |