
Bayesian Deep Ensembles with MILE (JAX), scikit-learn style.


Bayesian Deep Ensembles for scikit-learn


👉 Start Here: Complete Online Documentation

Introduction

bde is a user-friendly, scikit-learn-compatible implementation of Bayesian Deep Ensembles, with a particular focus on tabular data. Its estimators plug into scikit-learn pipelines and use JAX for accelerator-backed training, sampling, and uncertainty quantification.

In particular, bde implements Microcanonical Langevin Ensembles (MILE) as introduced in Microcanonical Langevin Ensembles: Advancing the Sampling of Bayesian Neural Networks (ICLR 2025). A conceptual overview of MILE is shown below:

MILE Overview

Installation

To install the latest release from PyPI, run:

pip install sklearn-contrib-bde

To install the latest development version from GitHub, run:

pip install git+https://github.com/scikit-learn-contrib/bde.git

Developer environment

We recommend using pixi to create a deterministic development environment:

pixi install

# Then you can directly run examples like so:
pixi run python -m examples.example

Pixi ensures the correct JAX, CUDA (when needed), and scikit-learn versions are selected automatically. See pixi.lock for channel and platform details.

Example Usage

Minimal runnable scripts live in examples/, and the snippets below highlight the most common regression and classification workflows. When running outside those scripts, remember to set the XLA device count so that JAX exposes enough host devices (this must be done before JAX is imported):

export XLA_FLAGS="--xla_force_host_platform_device_count=<n_devices>"

Adjust the value to the number of devices you plan to use. This flag creates virtual CPU (host) devices. On GPU, JAX uses the physical GPUs that are present.
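You may need to set the flag from inside Python, for example in a notebook. In that case it is safer to merge it into any existing XLA_FLAGS value than to overwrite that value. The helper below is a hypothetical convenience and is not part of bde. It only edits the environment variable, so it must run before `jax` is first imported.

```python
import os

DEVICE_FLAG = "--xla_force_host_platform_device_count"


def set_host_device_count(n_devices: int) -> str:
    """Hypothetical helper (not part of bde): set the XLA host device count.

    Call this before importing jax; XLA reads XLA_FLAGS when the backend starts.
    """
    if n_devices < 1:
        raise ValueError("n_devices must be a positive integer")
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Keep unrelated XLA flags, but drop any earlier device-count setting.
    kept = [f for f in existing if not f.startswith(DEVICE_FLAG + "=")]
    kept.append(f"{DEVICE_FLAG}={n_devices}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]
```

Typical use: call `set_host_device_count(8)` at the very top of a script, then `import jax`.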

Regression Example

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax.numpy as jnp
from sklearn.datasets import fetch_openml
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split

from bde import BdeRegressor
from bde.loss import GaussianNLL

data = fetch_openml(name="airfoil_self_noise", as_frame=True)

X = data.data.values
y = data.target.values.reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

Xmu, Xstd = jnp.mean(X_train, 0), jnp.std(X_train, 0) + 1e-8
Ymu, Ystd = jnp.mean(y_train, 0), jnp.std(y_train, 0) + 1e-8

Xtr = (X_train - Xmu) / Xstd
Xte = (X_test - Xmu) / Xstd
ytr = (y_train - Ymu) / Ystd
yte = (y_test - Ymu) / Ystd

# Build the regressor
regressor = BdeRegressor(
    hidden_layers=[16, 16],
    n_members=8,
    seed=0,
    loss=GaussianNLL(),
    epochs=200,
    validation_split=0.15,
    lr=1e-3,
    weight_decay=1e-4,
    warmup_steps=5000,
    n_samples=2000,
    n_thinning=2,
    patience=10,
)

# Fit the regressor
regressor.fit(x=Xtr, y=ytr)

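# Get results from regressor (on the standardized target scale)
means, sigmas = regressor.predict(Xte, mean_and_std=True)
mean, intervals = regressor.predict(Xte, credible_intervals=[0.1, 0.9])
raw = regressor.predict(Xte, raw=True)  # (ensemble members, n_samples/n_thinning, n_test_data, (mu, sigma))

# Map the predicted means back to the original units and evaluate
rmse = root_mean_squared_error(y_test, means * Ystd + Ymu)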
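The `raw=True` output keeps every retained draw. With the settings above, the shape comment implies (8, 1000, n_test, 2), because each of the 8 members keeps 2000 / 2 = 1000 draws. The sketch below shows one way to collapse such draws into a predictive mean, a standard deviation, and quantile intervals. It treats each draw as a Gaussian N(mu, sigma²) and pools all draws into an equally weighted mixture. It assumes that the last axis holds (mu, sigma) and that sigma is a standard deviation. This is a generic mixture computation, not necessarily how bde's own `predict` is implemented.

```python
import numpy as np


def mixture_summary(raw, quantiles=(0.1, 0.9), draws_per_sample=20, seed=0):
    """Summarise Gaussian posterior draws of shape (E, T, N, 2).

    Assumption: raw[..., 0] is each draw's predicted mean and raw[..., 1]
    its predicted standard deviation.
    """
    raw = np.asarray(raw, dtype=float)
    n_test = raw.shape[-2]
    mu = raw[..., 0].reshape(-1, n_test)     # (E*T, N): pool members and draws
    sigma = raw[..., 1].reshape(-1, n_test)  # (E*T, N)

    mean = mu.mean(axis=0)
    # Law of total variance: mean aleatoric variance + variance of the means.
    std = np.sqrt((sigma**2).mean(axis=0) + mu.var(axis=0))

    # Monte Carlo quantiles of the mixture: draw y ~ N(mu, sigma^2) per sample.
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((draws_per_sample,) + mu.shape)
    y = (mu + sigma * eps).reshape(-1, n_test)
    intervals = np.quantile(y, quantiles, axis=0)  # (len(quantiles), N)
    return mean, std, intervals
```

The variance term shows why a deep ensemble widens its predictive intervals when members disagree. The spread of the member means is added on top of the noise that each draw predicts.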

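The regressor above and the classifier below both pass `validation_split` and `patience`, and the workflow notes that member training applies early stopping. The sketch below illustrates the usual patience rule. It assumes that bde follows the common convention: hold out a fraction of the training rows, stop once the validation loss has not improved for `patience` consecutive epochs, and keep the best parameters. The function names are hypothetical. The model is a plain linear regression trained by gradient descent in NumPy, not a bde network.

```python
import numpy as np


def split_validation(X, y, validation_split, seed=0):
    """Hypothetical helper: hold out a random fraction of rows for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_val = int(round(validation_split * len(X)))
    val, tr = idx[:n_val], idx[n_val:]
    return X[tr], y[tr], X[val], y[val]


def fit_with_patience(X, y, epochs=200, lr=1e-2, validation_split=0.15, patience=10):
    """Gradient descent on a linear model with patience-based early stopping.

    Returns (best weights, best validation loss, number of epochs run).
    """
    Xtr, ytr, Xva, yva = split_validation(X, y, validation_split)
    w = np.zeros(X.shape[1])
    best_w, best_loss, bad_epochs = w.copy(), np.inf, 0
    for epoch in range(epochs):
        grad = 2.0 * Xtr.T @ (Xtr @ w - ytr) / len(Xtr)  # gradient of the MSE
        w = w - lr * grad
        val_loss = np.mean((Xva @ w - yva) ** 2)
        if val_loss < best_loss:  # improvement: keep these weights, reset counter
            best_w, best_loss, bad_epochs = w.copy(), val_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:  # no improvement for `patience` epochs
                break
    return best_w, best_loss, epoch + 1
```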
Classification Example

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from bde import BdeClassifier
from bde.loss import CategoricalCrossEntropy

iris = load_iris()
X = iris.data.astype("float32")
y = iris.target.astype("int32").ravel()

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Build the classifier
classifier = BdeClassifier(
    n_members=4,
    hidden_layers=[16, 16],
    seed=0,
    loss=CategoricalCrossEntropy(),
    activation="relu",
    epochs=1000,
    validation_split=0.15,
    lr=1e-3,
    warmup_steps=400,  # very few steps required for this simple dataset
    n_samples=100,
    n_thinning=1,
    patience=10,
)

# Fit the classifier
classifier.fit(x=X_train, y=y_train)

# Get results from classifier
preds = classifier.predict(X_test)
probs = classifier.predict_proba(X_test)
score = classifier.score(X_test, y_test)  # evaluate on held-out data
raw = classifier.predict(X_test, raw=True)  # (ensemble members, n_samples/n_thinning, n_test_data, n_classes)
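For classification, the raw output has shape (ensemble members, n_samples/n_thinning, n_test_data, n_classes). A common way to turn such draws into class probabilities is Bayesian model averaging: average the per-draw probabilities over members and draws, then take the argmax. The sketch assumes that the last axis holds either unnormalised logits or probabilities, selected with the hypothetical `from_logits` switch. Which one bde's raw output contains is not stated here, so check before relying on it.

```python
import numpy as np


def average_class_probs(raw, from_logits=True):
    """Model-averaged class probabilities from draws of shape (E, T, N, C)."""
    raw = np.asarray(raw, dtype=float)
    if from_logits:
        # Numerically stable softmax over the class axis.
        e = np.exp(raw - raw.max(axis=-1, keepdims=True))
        probs = e / e.sum(axis=-1, keepdims=True)
    else:
        probs = raw
    mean_probs = probs.mean(axis=(0, 1))  # (N, C): average over members and draws
    labels = mean_probs.argmax(axis=-1)   # (N,): predicted class per test point
    # Predictive entropy as a simple per-point uncertainty score.
    entropy = -(mean_probs * np.log(np.clip(mean_probs, 1e-12, None))).sum(axis=-1)
    return mean_probs, labels, entropy
```

If two members confidently disagree, the averaged probabilities split between their classes and the entropy rises. A single deterministic network cannot express that signal.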

Workflow

The high-level estimators follow this flow during fit and evaluation:

  • BdeRegressor / BdeClassifier (bde/bde.py) delegate to the shared Bde base class.
  • Bde.fit validates data, resolves defaults, and calls _build_bde() to instantiate BdeBuilder.
  • BdeBuilder.fit_members (bde/bde_builder.py) trains each network, handles device padding, and applies early stopping.
  • _build_log_post constructs the ensemble log-posterior, then warmup_bde (bde/sampler/warmup.py) adapts step sizes before sampling.
  • Sampler utilities (bde/sampler/*) draw posterior samples and cache them for downstream prediction.
  • User-facing predict / predict_proba call the private _evaluate / _make_predictor (bde/bde_evaluator.py) to aggregate samples into means, intervals, probabilities, or raw outputs.
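BdeBuilder.fit_members is described as handling device padding. A likely reason is that `jax.pmap` maps over a leading axis whose size cannot exceed the number of devices. A common pattern is therefore to reshape the ensemble to (n_devices, members_per_device, ...) and vectorise over the inner axis. When the ensemble size is not a multiple of the device count, the member axis has to be padded first. The NumPy sketch below shows one way to do that under this assumption. The function names and the filler strategy (repeating the last member) are hypothetical and not taken from bde.

```python
import numpy as np


def pad_members_for_devices(params, n_devices):
    """Hypothetical helper: pad the member axis to a multiple of n_devices.

    Returns an array of shape (n_devices, members_per_device, ...) and a
    boolean mask of shape (n_devices, members_per_device) marking real members.
    """
    params = np.asarray(params)
    n_members = params.shape[0]
    per_device = -(-n_members // n_devices)  # ceiling division
    n_pad = per_device * n_devices - n_members
    # Repeat the last member as filler so padded slots hold valid values.
    filler = np.repeat(params[-1:], n_pad, axis=0)
    padded = np.concatenate([params, filler], axis=0)
    mask = np.arange(per_device * n_devices) < n_members
    shape = (n_devices, per_device) + params.shape[1:]
    return padded.reshape(shape), mask.reshape(n_devices, per_device)


def unpad_members(sharded, mask):
    """Undo the sharding and drop the padded members."""
    flat = sharded.reshape((-1,) + sharded.shape[2:])
    return flat[mask.reshape(-1)]
```

In JAX, the sharded array could then be passed to `jax.pmap(jax.vmap(train_one))`, where `train_one` is a placeholder for a per-member training step. The mask discards the duplicated members afterwards.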
flowchart TD
    subgraph User
        FitCall["Call BdeRegressor/BdeClassifier.fit(X, y)"]
        PredCall["Call predict(...)/predict_proba(...)"]
    end

    subgraph Bde
        Validate["validate_fit_data / _prepare_targets"]
        Build["_build_bde()"]
        Builder["BdeBuilder"]
        Train["fit_members(X, y, optimizer, loss)"]
        LogPost["_build_log_post(X, y)"]
        WarmSampler["_warmup_sampler(logpost)"]
        Keys["_generate_rng_keys + _normalize_tuned_parameters"]
        Draw["_draw_samples(...) via MileWrapper.sample_batched"]
        Cache["positions_eT_ stored in estimator"]
        Eval["_evaluate(... flags ...)"]
        MakePred["_make_predictor(Xte)"]
    end

    subgraph Warmup
        Warm["warmup_bde()"]
        Adapter["custom_mclmc_warmup adapter"]
        Adapt["per-member adaptation (pmap/vmap)"]
        Results["AdaptationResults: states_e, tuned params"]
    end

    subgraph Sampling
        Wrapper["MileWrapper"]
        Batch["sample_batched(...)"]
        Posterior["Posterior samples (E x T x ...)"]
    end

    subgraph Evaluation
        Predictor["BdePredictor"]
        Outputs["Predictions (mean, std, intervals, probs, raw)"]
    end

    FitCall --> Validate --> Build --> Builder
    Builder --> Train --> LogPost --> WarmSampler --> Keys --> Draw --> Cache
    WarmSampler --> Warm --> Adapter --> Adapt --> Results
    Draw --> Wrapper --> Batch --> Posterior
    Cache --> PredCall --> Eval --> MakePred --> Predictor --> Outputs
    Posterior --> Predictor
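The flowchart separates optimisation (`fit_members`) from sampling (warmup, then draws cached as `positions_eT_` with shape E x T x ...). The toy below reproduces that two-stage structure on a one-parameter conjugate Gaussian model. Because the exact posterior of that model is known, the pooled draws can be checked against it. It is only an illustration, and several pieces are stand-ins:

  • Plain gradient ascent stands in for the optimiser.
  • Unadjusted Langevin dynamics stands in for the microcanonical Langevin (MCLMC) sampler that MILE uses.
  • There is no step-size warmup.

All names are hypothetical.

```python
import numpy as np


def grad_log_post(theta, y):
    """Gradient of log p(theta | y) for y_i ~ N(theta, 1) and theta ~ N(0, 1).

    Works elementwise on an array of thetas (one per ensemble member).
    """
    return -theta + (np.sum(y) - len(y) * theta)


def optimise_then_sample(y, n_members=4, opt_steps=200, lr=1e-2,
                         n_samples=2000, n_thinning=2, step_size=1e-2, seed=0):
    """Toy two-stage ensemble: optimise each member, then sample from its mode.

    Returns draws of shape (n_members, n_samples // n_thinning), i.e. E x T.
    """
    rng = np.random.default_rng(seed)

    # Stage 1 (cf. fit_members): random inits, then gradient ascent on log p.
    theta = rng.normal(scale=3.0, size=n_members)
    for _ in range(opt_steps):
        theta = theta + lr * grad_log_post(theta, y)

    # Stage 2 (cf. _draw_samples): unadjusted Langevin dynamics started at each
    # member's mode (so no burn-in is discarded), keeping every n_thinning-th draw.
    draws = []
    for step in range(n_samples):
        noise = rng.standard_normal(n_members)
        theta = theta + step_size * grad_log_post(theta, y) + np.sqrt(2 * step_size) * noise
        if (step + 1) % n_thinning == 0:
            draws.append(theta.copy())
    return np.stack(draws, axis=1)
```

For 20 observations, the exact posterior is N(sum(y) / 21, 1 / 21), and `optimise_then_sample(y).shape` is (4, 1000). With `step_size=1e-2`, the Langevin chain's stationary variance is about 1 / (21 × 0.895) ≈ 0.053 rather than the exact 1/21 ≈ 0.048. This discretisation bias shrinks with the step size.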

Datasets included in the package for testing purposes

| Dataset | Source | Task |
|---|---|---|
| Airfoil | UCI Machine Learning Repository (Dua & Graff, 2017) | Regression |
| Concrete | UCI Machine Learning Repository (Yeh, 2006) | Regression |
| Iris | Fisher (1936); canonical modern version distributed via scikit-learn | Multiclass classification (setosa, versicolor, virginica) |



Download files


Source Distribution

sklearn_contrib_bde-1.0.1.tar.gz (424.0 kB)

Uploaded Source

Built Distribution


sklearn_contrib_bde-1.0.1-py3-none-any.whl (75.9 kB)

Uploaded Python 3

File details

Details for the file sklearn_contrib_bde-1.0.1.tar.gz.

File metadata

  • Download URL: sklearn_contrib_bde-1.0.1.tar.gz
  • Size: 424.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sklearn_contrib_bde-1.0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 27e248d6472dad4352f206b2719ef5bc2b9e3e9a4c9e614d658d2c96cd160661 |
| MD5 | f929ce574d75e3a9c1eba5b995caf1f1 |
| BLAKE2b-256 | 2261cd28042b999f3324148578b332d590042d87f448a8c235cc25b6eb3a8b57 |


Provenance

The following attestation bundles were made for sklearn_contrib_bde-1.0.1.tar.gz:

Publisher: publish.yml on scikit-learn-contrib/bde

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file sklearn_contrib_bde-1.0.1-py3-none-any.whl.


File hashes

Hashes for sklearn_contrib_bde-1.0.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e26df31467dc596bfe358a036fb22e1ee3270901cf34dfeb75da0cd3fe69b7d4 |
| MD5 | 4e598f0b1f942835050a1930af9a2730 |
| BLAKE2b-256 | c4278348f3baf0a2d5e293f183f681445ab9e74626e204018de9c6878eddd681 |


Provenance

The following attestation bundles were made for sklearn_contrib_bde-1.0.1-py3-none-any.whl:

Publisher: publish.yml on scikit-learn-contrib/bde

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
