# gist-select

**GIST: Greedy Independent Set Thresholding for Max-Min Diversification with Submodular Utility**

A production-grade Python implementation of the GIST algorithm from Fahrbach et al. (NeurIPS 2025). Select subsets that are both high-quality and diverse, with provable approximation guarantees.
## The Problem

You have a large pool of items (data points, images, documents, candidates) and need to select `k` of them. You want items that are individually valuable and collectively diverse: no redundancy, maximum coverage.

GIST solves this by maximizing:

```
f(S) = g(S) + λ · div(S)
```
| Term | Meaning |
|---|---|
| `g(S)` | Monotone submodular utility: how valuable the selected set is |
| `div(S)` | Max-min diversity: the minimum pairwise distance in the set |
| `λ` | Trade-off knob between utility and diversity |
Approximation guarantees:
- General submodular utility: 1/2 - ε
- Linear utility: 2/3 - ε (tight — matches the NP-hardness lower bound)
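As a concrete, hand-checkable illustration of the objective, here is `f(S)` computed directly with NumPy for a toy three-point set with a linear utility (this sketch only spells out the definition above; it does not use the package):

```python
import numpy as np

# Toy illustration of f(S) = g(S) + λ · div(S) with a linear utility g
# and Euclidean max-min diversity.
points = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
weights = np.array([1.0, 2.0, 3.0])
S = [0, 1, 2]   # selected indices
lam = 1.0

g = weights[S].sum()   # linear utility: 1 + 2 + 3 = 6
# div(S): minimum pairwise distance among the selected points
pairwise = [np.linalg.norm(points[i] - points[j])
            for i in S for j in S if i < j]
div = min(pairwise)    # distances are 3, 4, 5 -> min is 3
f = g + lam * div
print(f)               # -> 9.0
```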
## Features
- Provably good — constant-factor approximation guarantees from the paper
- Fast at scale — tested up to 2M points with high-dimensional embeddings
- CELF acceleration — lazy greedy evaluation reduces oracle calls by orders of magnitude
- Optimised numerics — BLAS-backed distance computation, precomputed norms, no large temporaries
- Flexible metrics — Euclidean, cosine, or bring your own distance function
- Flexible utilities — linear weights, set coverage, or bring your own submodular function
- Parallel threshold sweep — optional multi-threaded execution via joblib
- Deterministic — seed parameter for full reproducibility
## Installation

```bash
pip install gist-select
```

With optional parallel support:

```bash
pip install "gist-select[parallel]"
```

Requirements: Python ≥ 3.10, NumPy ≥ 1.24, SciPy ≥ 1.10
## Quick Start

```python
import numpy as np
from gist import gist, LinearUtility, EuclideanDistance

# 10,000 points in 64 dimensions
rng = np.random.default_rng(42)
points = rng.standard_normal((10_000, 64)).astype(np.float32)
weights = rng.random(10_000)

# Select the 50 best-and-diverse points
result = gist(
    points=points,
    utility=LinearUtility(weights),
    distance=EuclideanDistance(),
    k=50,
    lam=1.0,
    seed=42,
)

print(f"Selected {len(result.indices)} points")
print(f"Objective: {result.objective_value:.4f}")
print(f"Utility: {result.utility_value:.4f}")
print(f"Diversity: {result.diversity:.4f}")
```
## API Reference

### `gist()`

```python
gist(
    points,          # np.ndarray (n, d) — your data
    utility,         # SubmodularFunction — how to score subsets
    distance,        # DistanceMetric — how to measure spread
    k,               # int — how many points to select
    lam=1.0,         # float — diversity weight (λ ≥ 0)
    eps=0.05,        # float — approximation granularity (ε > 0)
    n_jobs=1,        # int — threads for threshold sweep
    seed=None,       # int — random seed for reproducibility
    diameter=None,   # tuple — precomputed (d_max, idx_u, idx_v)
) -> GISTResult
```
Returns a `GISTResult` with:

| Field | Type | Description |
|---|---|---|
| `indices` | `np.ndarray` | Indices of selected points |
| `objective_value` | `float` | `g(S) + λ · div(S)` |
| `utility_value` | `float` | `g(S)` |
| `diversity` | `float` | `div(S)`, the minimum pairwise distance |
Parameters in detail:

- `lam` — Controls the utility/diversity trade-off. Higher values favour more spread-out selections. Set to `0` for pure utility maximisation (standard greedy).
- `eps` — Controls the number of distance thresholds swept (~76 for `eps=0.05`, ~38 for `eps=0.1`). Smaller is more thorough but slower.
- `n_jobs` — Number of threads for the threshold sweep. Requires `joblib`. Uses the threading backend to share memory.
- `diameter` — Skip the automatic diameter estimation by providing `(d_max, idx_u, idx_v)`. Useful when calling `gist()` repeatedly on the same point set.
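For the `diameter` shortcut, a farthest-pair estimate can be computed once and reused across calls. The two-scan heuristic below is one common approximation and is our assumption for illustration; the package may estimate the diameter differently:

```python
import numpy as np

def approximate_diameter(points: np.ndarray) -> tuple[float, int, int]:
    """Two-scan farthest-pair heuristic (a sketch, not the package's
    exact method): find the point u farthest from point 0, then the
    point v farthest from u."""
    u = int(np.argmax(np.linalg.norm(points - points[0], axis=1)))
    dists = np.linalg.norm(points - points[u], axis=1)
    v = int(np.argmax(dists))
    return float(dists[v]), u, v

rng = np.random.default_rng(0)
pts = rng.standard_normal((1000, 8))
d_max, u, v = approximate_diameter(pts)

# Reuse the estimate across repeated calls on the same point set:
# gist(pts, utility, distance, k=50, diameter=(d_max, u, v))
```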
## Distance Metrics

```python
from gist import EuclideanDistance, CosineDistance, CallableDistance
```

| Class | Description | Hot Path |
|---|---|---|
| `EuclideanDistance()` | L2 distance with precomputed norms | Single BLAS GEMV |
| `CosineDistance()` | `1 - cos(a, b)`, auto-normalises | Single BLAS GEMV |
| `CallableDistance(fn)` | User-provided `fn(vec, matrix) → dists` | Your function |
Custom distance example:

```python
import numpy as np
from gist import CallableDistance

def manhattan(source_vec, target_matrix):
    """L1 distance — must be vectorised."""
    return np.abs(target_matrix - source_vec).sum(axis=1)

distance = CallableDistance(manhattan)
```

Note: The callable signature is `fn(source: ndarray shape (d,), targets: ndarray shape (m, d)) → ndarray shape (m,)`. It must be vectorised — a scalar `dist(a, b)` function will not work at scale.
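Before wiring a custom callable into `CallableDistance`, it is worth sanity-checking the vectorised version against a scalar loop. This check uses plain NumPy only, no package imports needed:

```python
import numpy as np

def manhattan(source_vec, target_matrix):
    """Vectorised L1 distance from one point to many."""
    return np.abs(target_matrix - source_vec).sum(axis=1)

rng = np.random.default_rng(1)
src = rng.standard_normal(4)
targets = rng.standard_normal((100, 4))

fast = manhattan(src, targets)                              # one vectorised call
slow = np.array([np.abs(t - src).sum() for t in targets])   # scalar reference loop

assert fast.shape == (100,)       # one distance per target row
assert np.allclose(fast, slow)    # same values as the scalar loop
```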
## Submodular Utilities

```python
from gist import LinearUtility, CoverageFunction, SubmodularFunction
```

### `LinearUtility(weights)`

Additive utility: `g(S) = Σ weights[i] for i ∈ S`.

This is the most common case and the fastest — marginal gains are just the individual weights. GIST achieves the tight 2/3-approximation for linear utilities.

```python
import numpy as np

weights = np.array([0.9, 0.1, 0.8, 0.3, 0.7])
utility = LinearUtility(weights)
```
### `CoverageFunction(coverage_matrix, element_weights=None)`

Set-coverage utility: `g(S) = |⋃_{i ∈ S} cover(i)|`.

Each point covers a set of elements. The utility is the total number (or weighted sum) of distinct elements covered by the selected set — the classic diminishing-returns structure.

```python
from scipy import sparse

# 1000 points, each covering some of 500 elements
coverage_matrix = sparse.random(1000, 500, density=0.05, format="csr")
coverage_matrix.data[:] = 1  # binary coverage
utility = CoverageFunction(coverage_matrix)
```
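The coverage semantics can be checked by hand. The helper below is a sketch of the definition above (not the package's internals), computing `g(S)` for a tiny binary coverage matrix:

```python
import numpy as np
from scipy import sparse

# Toy coverage: 4 points, 6 elements; each row marks the covered elements
rows = np.array([
    [1, 1, 0, 0, 0, 0],   # point 0 covers elements {0, 1}
    [0, 1, 1, 0, 0, 0],   # point 1 covers {1, 2}
    [0, 0, 0, 1, 1, 0],   # point 2 covers {3, 4}
    [0, 0, 0, 0, 1, 1],   # point 3 covers {4, 5}
])
M = sparse.csr_matrix(rows)

def coverage_value(M, S):
    """g(S) = number of distinct elements covered by the rows in S."""
    covered = np.asarray(M[S].max(axis=0).todense()).ravel()
    return int(covered.sum())

print(coverage_value(M, [0, 1]))     # {0,1} ∪ {1,2} = {0,1,2} -> 3
print(coverage_value(M, [0, 2, 3]))  # {0,1} ∪ {3,4} ∪ {4,5}   -> 5
```

Adding a point that overlaps heavily with the current selection contributes little new coverage — exactly the diminishing-returns property GIST exploits.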
## Custom Submodular Functions

Subclass `SubmodularFunction` and implement two methods:

```python
import numpy as np
from gist import SubmodularFunction

class MyUtility(SubmodularFunction):
    def marginal_gains(self, selected: list[int], candidates: np.ndarray) -> np.ndarray:
        """Return g(v | S) for each v in candidates."""
        gains = np.empty(len(candidates))
        for i, v in enumerate(candidates):
            gains[i] = self._compute_marginal(v, selected)
        return gains

    def value(self, selected: list[int]) -> float:
        """Return g(S)."""
        return self._compute_value(selected)
```

Performance tip: GIST uses CELF (lazy greedy) internally, so `marginal_gains` is called infrequently on small batches after the initial pass. The initial call, however, evaluates all points, so make sure it handles large `candidates` arrays efficiently.
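To make the performance note concrete, here is a stand-alone sketch of CELF: stale marginal gains sit in a max-heap and only the top candidate is re-evaluated, which is sound because submodularity guarantees gains never increase as `S` grows. This is not the package's internal implementation, and `gain` is a toy weighted-coverage oracle invented for the demo:

```python
import heapq

def celf_greedy(gain_fn, n, k):
    """Lazy greedy (CELF) sketch. gain_fn(v, S) is an assumed
    marginal-gain oracle, not the package's internal API."""
    S = []
    # Initial pass: evaluate every candidate once (max-heap via negation).
    heap = [(-gain_fn(v, S), v) for v in range(n)]
    heapq.heapify(heap)
    while len(S) < k and heap:
        neg_gain, v = heapq.heappop(heap)
        fresh = gain_fn(v, S)                    # re-evaluate against current S
        if heap and fresh < -heap[0][0]:
            heapq.heappush(heap, (-fresh, v))    # stale: push back with new gain
        else:
            S.append(v)                          # still the best: select it
    return S

# Toy coverage oracle: the gain of v is how many new elements it covers.
covers = [{0, 1}, {1, 2}, {2, 3}, {0, 3}]

def gain(v, S):
    covered = set().union(*(covers[u] for u in S)) if S else set()
    return len(covers[v] - covered)

print(celf_greedy(gain, 4, 2))   # -> [0, 2]
```

Note how item 2 is selected without re-evaluating items 3's gain at all — that skipped work is where CELF's order-of-magnitude savings come from on large pools.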
## Examples

### Data Sampling for Model Training

Select a representative training subset that balances uncertainty and diversity — inspired by the paper's ImageNet experiment:

```python
import numpy as np
from gist import gist, LinearUtility, CosineDistance

# embeddings: (n, 2048) from a pretrained model
# uncertainty: (n,) margin-based uncertainty scores
embeddings = np.load("embeddings.npy")
uncertainty = np.load("uncertainty_scores.npy")

# Select 50K diverse, uncertain examples
result = gist(
    points=embeddings,
    utility=LinearUtility(uncertainty),
    distance=CosineDistance(),
    k=50_000,
    lam=0.5,    # balance uncertainty with diversity
    eps=0.05,
    n_jobs=4,   # parallel threshold sweep
    seed=0,
)

train_indices = result.indices
print(f"Selected {len(train_indices)} training examples")
print(f"Min pairwise cosine distance: {result.diversity:.4f}")
```
### Feature Selection

Select a diverse subset of features that individually have high relevance:

```python
import numpy as np
from gist import gist, LinearUtility, EuclideanDistance

# features: (n_features, n_samples) — each row is a feature vector
features = np.load("feature_matrix.npy")
relevance_scores = np.load("relevance.npy")

result = gist(
    points=features,
    utility=LinearUtility(relevance_scores),
    distance=EuclideanDistance(),
    k=20,
    lam=2.0,    # strongly penalise redundant features
    seed=42,
)

selected_features = result.indices
```
### Tuning the Diversity Trade-off

Sweep over `lam` to find the right balance for your task:

```python
import numpy as np
from gist import gist, LinearUtility, EuclideanDistance

rng = np.random.default_rng(0)
points = rng.standard_normal((5000, 32))
weights = rng.random(5000)

for lam in [0.0, 0.5, 1.0, 2.0, 5.0]:
    result = gist(
        points, LinearUtility(weights), EuclideanDistance(),
        k=50, lam=lam, eps=0.1, seed=0,
    )
    print(f"λ={lam:<4} utility={result.utility_value:.2f} "
          f"diversity={result.diversity:.3f} "
          f"objective={result.objective_value:.2f}")
```

```
λ=0.0 utility=49.47 diversity=3.041 objective=49.47
λ=0.5 utility=48.91 diversity=3.255 objective=50.54
λ=1.0 utility=48.31 diversity=3.442 objective=51.75
λ=2.0 utility=47.06 diversity=3.819 objective=54.70
λ=5.0 utility=43.98 diversity=4.607 objective=67.02
```
### Pure Utility Maximisation

Set `lam=0` to recover standard greedy submodular maximisation (no diversity term):

```python
result = gist(points, utility, distance, k=100, lam=0.0)
# Equivalent to the classic (1 - 1/e)-approximation greedy
```
## Performance

Benchmarked on Apple M-series, single-threaded, `eps=0.1`:
| Points | Dimensions | k | Time |
|---|---|---|---|
| 10K | 64 | 50 | 0.3s |
| 100K | 128 | 100 | 6s |
| 500K | 128 | 100 | ~30s |
| 2M | 128 | 100 | ~2 min |
Scaling tips:

- Use `float32` points — 2x memory savings and faster BLAS
- Increase `eps` (e.g., `0.1` → `0.2`) to halve the number of thresholds
- Use `n_jobs=-1` with joblib for a parallel threshold sweep
- Precompute `diameter` when calling `gist()` repeatedly on the same data
## How It Works

GIST sweeps over a geometric sequence of distance thresholds. For each threshold `d`, it runs a greedy algorithm that builds a maximal independent set of the intersection graph (points within distance `d` are "neighbours") while maximising the submodular utility.

```
GIST(V, g, k, ε):
  1. S ← GreedyIndependentSet(V, g, d=0, k)    # pure utility baseline
  2. T ← diametrical pair with max distance    # pure diversity baseline
  3. For each threshold d in geometric sequence:
       T ← GreedyIndependentSet(V, g, d, k)    # utility + diversity
       Keep best f(T) = g(T) + λ · div(T)
  4. Return the best solution found
```

The `GreedyIndependentSet` subroutine uses CELF (lazy greedy) to minimise submodular oracle calls. Points within distance `d` of a selected point are eliminated, ensuring all selected points are pairwise at distance ≥ `d`.
For the full details, see the paper: arXiv:2405.18754
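The inner loop of step 3 can be sketched in a few lines of NumPy for the linear-utility case. This sketch omits CELF and shows only the elimination logic; `greedy_independent_set` is a name made up for the illustration, not a package function:

```python
import numpy as np

def greedy_independent_set(points, weights, d, k):
    """One threshold's greedy pass with a linear utility: repeatedly take
    the highest-weight surviving point, then eliminate everything within
    distance d of it."""
    alive = np.ones(len(points), dtype=bool)
    S = []
    while len(S) < k and alive.any():
        idx = np.flatnonzero(alive)
        v = idx[np.argmax(weights[idx])]        # best remaining point
        S.append(int(v))
        dists = np.linalg.norm(points - points[v], axis=1)
        alive &= dists >= d                     # drop the d-ball around v
        alive[v] = False
    return S

pts = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0], [3.2, 0.0]])
w = np.array([1.0, 0.9, 0.8, 0.95])
S = greedy_independent_set(pts, w, d=1.0, k=3)
print(S)   # picks 0 (eliminating 1), then 3 (eliminating 2) -> [0, 3]
```

By construction every pair of selected points is at distance at least `d`, so `div(S) ≥ d` for that threshold's candidate solution.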
## Citation

If you use this package in your research, please cite the original paper:

```bibtex
@inproceedings{fahrbach2025gist,
  title={GIST: Greedy Independent Set Thresholding for Max-Min Diversification with Submodular Utility},
  author={Fahrbach, Matthew and Ramalingam, Srikumar and Zadimoghaddam, Morteza and Ahmadian, Sara and Citovsky, Gui and DeSalvo, Giulia},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2025}
}
```
## License

MIT