Find maximum batch size, documents, and timesteps for PyTorch models
🔍 Batch Finder
Find the maximum value for any dimension your PyTorch models can handle without crashing.

Batch Finder detects your model's inputs (types and shapes), keeps the sizes you choose fixed, and searches for the largest value your run can sustain without crashing.
✨ Features
- 🎯 One main function – You call `find_max_minibatch` for pretty much everything you need here.
- 🔍 Tell it what goes in – You pass names, data types, and rough shapes (`input_shapes` / `forward_params`). If your model has a `config`, that can help too.
- 📐 Shapes your way – Use a simple tuple or list, or a list of tuples when you have multiple tensors (for example `[(-1, 128, 512), (-1, 128, 512)]`). For several inputs, you can use a text string or a dict. Or skip that and use `axis_to_maximize` plus `fixed_axis` (handy for Hugging Face–style models).
- 🚀 Runs forward or full training – Turn backward on or off depending on what you want to test.
- ⚙️ Knobs you can turn – Change how fast it steps up or down (`factor_up`, `factor_down`), how many tries it gets (`n_attempts`), and where it starts (`initial_value`).
- ⏱️ Time cap (optional) – Set `time_limit_seconds` if you only want the search to run for so long. When time is up, you get the best size that worked so far (or `None` if nothing worked yet). Leave it unset to ignore the clock and rely on `n_attempts` and the normal stop rules.
- 🛡️ Fails without trashing your session – When things blow up, it cleans up. If even batch size `1` fails, you get `None` (an honest "no").
- 📊 See what it's doing – `tqdm` shows a progress bar with useful status text.
📦 Installation
```bash
pip install batch-finder
```
Or from source:
```bash
git clone https://github.com/yourusername/batch-finder.git
cd batch-finder
pip install -e .
```
🚀 Quick Start
One tensor: tuple or list
Use `-1` on each axis you want to maximize (the search tries a single trial size each step). In an all-integer tuple, every `-1` position shares that same trial size. Use positive integers for fixed dimensions.

Other negative integers `d < -1` (e.g. `-2`, `-3`) size that dimension as `round(|d| * trial)`, the same scaling role as the negative floats below (e.g. `-2` gives 2x the trial size tied to `-1`).

If the tuple mixes integers and negative floats, you are in compact numeric mode: there must be at least one integer `-1` (the searched axis). Any other dimension given as a negative float `-x` is sized as `round(x * trial)`, where `trial` is the current value on the `-1` axis, so `x` is the proportion you want between that axis and the searched axis (e.g. `-1.5` keeps that dim at about 1.5x the trial size).
```python
from batch_finder import find_max_minibatch

def get_model():
    return MyModel()

# Maximize axis 0; other dims fixed
max_val = find_max_minibatch(get_model, input_shapes=(-1, 64, 256))

# Maximize axis 2
max_val = find_max_minibatch(get_model, input_shapes=(4, 8, -1))

# One integer -1 (searched axis) + a negative float: that dim scales by |float| vs. the trial
max_val = find_max_minibatch(get_model, input_shapes=(-1, 4, -1.5, 16))
```
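The compact-numeric rules above can be illustrated with a small standalone helper (hypothetical, not part of `batch_finder`) that materializes a spec for a given trial size:

```python
# Illustrative only: resolve a compact numeric spec for one trial value.
# Integer -1 -> trial; any other negative d -> round(|d| * trial); positives stay fixed.
def resolve_shape(spec, trial):
    out = []
    for d in spec:
        if d == -1:                      # searched axis: takes the trial value
            out.append(trial)
        elif d < 0:                      # negative int/float: scale off the trial
            out.append(round(abs(d) * trial))
        else:                            # positive: fixed dimension
            out.append(d)
    return tuple(out)

print(resolve_shape((-1, 4, -1.5, 16), trial=32))  # (32, 4, 48, 16)
```

So a trial size of 32 on the `-1` axis sizes the `-1.5` axis to 48, while `4` and `16` stay fixed.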
Several tensors: list of shape tuples
Pass a list or tuple of per-tensor shapes (same order as `forward` / `forward_params`). Each entry follows the same rules as a single-tensor tuple (integer `-1`, integer multipliers `-2`, `-3`, …, or compact floats). One trial size is searched; every `-1` and every `d < -1` uses that trial value as above.
```python
# Example: two tensors, same free shape pattern
max_val = find_max_minibatch(
    get_model,
    input_shapes=[(-1, 128, 512), (-1, 128, 512)],
    forward_params=["x", "y"],
)
```
For symbolic names and constraints across tensors, use a string or dict below.
HuggingFace-style: `axis_to_maximize` + `fixed_axis`
```python
from transformers import AutoModelForCausalLM
from batch_finder import find_max_minibatch

def get_model():
    return AutoModelForCausalLM.from_pretrained("distilgpt2")

# `forward_params` defaults to GPT2-style causal LM names; pass your own list for other architectures.
max_batch = find_max_minibatch(
    get_model,
    axis_to_maximize="batch_size",
    fixed_axis={"seq_len": 32},
)
print(f"Max batch size: {max_batch}")
```
Several tensors: one string
When `forward` takes multiple tensors, pass `input_shapes` as text: one `(…)` group per argument (same order as `forward`), short names for axes that must match, optional equations between names, and at least one `name=-1` for the size you search.

Pattern: `(dims…), (dims…), …` then optional `name=value` rules.

- Dimensions: numbers, or names like `b`, `t` that repeat across tensors.
- One name must be set to `-1` (that is the searched size).
- Rules like `t=1.5b` tie sizes together (non-integers are rounded for tensor shapes).
```python
import torch
from batch_finder import find_max_minibatch

class MyModel(torch.nn.Module):
    def forward(self, x, y):
        # x: (23, b, t, 45), y: (b, t, 12)
        ...

def get_model():
    return MyModel()

max_b = find_max_minibatch(
    get_model,
    input_shapes="(23, b, t, 45),(b, t, 12), t=1.5b, b=-1",
)
```
Do not combine this with tuple/list single-tensor mode or with `axis_to_maximize`. Pass `input_shapes=` as a keyword.
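To make the rule syntax concrete, here is a hand-resolved sketch (a hypothetical helper, not library code) of how the string above materializes for one trial value of `b`:

```python
# Hand-resolving "(23, b, t, 45),(b, t, 12), t=1.5b, b=-1" for a trial b.
# The rule t=1.5b is rounded to an integer, since tensor shapes must be ints.
def resolve_for_trial(b):
    t = round(1.5 * b)
    return (23, b, t, 45), (b, t, 12)

print(resolve_for_trial(64))  # ((23, 64, 96, 45), (64, 96, 12))
```

Each trial thus builds both tensors from the single searched value `b`, with `t` kept at roughly 1.5x `b`.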
Several tensors: dict (named arguments)
Same idea as the string form, but keys match `forward` parameters. Put shared rules in `"#constraints"` (it must include exactly one `symbol=-1`). Values are shape text, optionally with `, int` or `, float` for the dtype.
```python
def get_model():
    return MyModel()

max_b = find_max_minibatch(
    get_model,
    input_shapes={
        "input_ids": "(b, t), int",
        "attention_mask": "(b, t), int",
        "input_ids_encoder": "(d, b, t), int",
        "attention_mask_encoder": "(d, b, t), int",
        "labels": "(b, t)",
        "#constraints": "t=2b, b=-1",
    },
)
```
Use the literal `"#constraints"` for the rules entry, or import `FINDER_CONSTRAINTS_KEY` from `batch_finder` to avoid typos (`CONSTRAINTS_KEY` is the same string, kept for compatibility).
Custom search parameters
```python
def get_model():
    return MyModel()

max_val = find_max_minibatch(
    get_model,
    input_shapes=(-1, 128, 512),
    initial_value=8,
    n_attempts=30,
    factor_down=3.0,           # divide by 3 on failure
    factor_up=2.0,             # multiply by 2 on success
    time_limit_seconds=120.0,  # optional: return best trial so far before timeout
)
```
📖 API Reference
`find_max_minibatch(...)`
Find the largest workable size for the free axis without OOM.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `get_model` | `Callable[[], nn.Module]` | (required) | First argument. Fresh module per attempt. The parent calls `get_model()` once first, keeps only `.config` (if any) for vocab/shape hints, then drops the weights. Must be picklable under `spawn` (module-level function, not a lambda). |
| `forward_params` | `Sequence[str]` | `None` | When `input_shapes` is not a dict: ordered tensor names for `forward`. With `axis_to_maximize` only, defaults to `DEFAULT_FORWARD_PARAMS_CAUSAL_LM` (GPT2-style); override or import that constant to tweak. |
| `input_shapes` | `str \| dict \| tuple/list` | `None` | String: several tensors, shared names and rules. Dict: named shapes + `"#constraints"`. Flat tuple/list: first tensor only (ints with `-1` or compact floats). Nested list/tuple of shapes: one tuple per tensor, e.g. `[(-1, 128, 512), (-1, 128, 512)]`. Not used together with `axis_to_maximize`. |
| `axis_to_maximize` | `str` | `None` | Name of the axis when `input_shapes` is omitted, e.g. `"batch_size"`. |
| `fixed_axis` | `Dict[str, int]` | `{}` | Fixed sizes, e.g. `{"seq_len": 128}`. |
| `device` | `torch.device` | auto | Device to run on. |
| `delay` | `float` | `None` → auto | Auto: 3.0 s on CUDA, 0.75 s on CPU/MPS (pass a number to override). |
| `initial_value` | `int` | `None` → auto | Auto: 512 CUDA, 64 MPS, 32 CPU (pass a number to override). |
| `n_attempts` | `int` | `50` | Maximum attempts. |
| `time_limit_seconds` | `float` | `None` | Wall-clock limit from the start of the search loop (after the probe). When time is up, returns the largest successful trial so far (or `None` if none). Shortens sleeps between attempts; with subprocess workers, terminates a run that would exceed the remaining time. In-process forward/backward is not interrupted mid-step. |
| `inference_only` | `bool` | `False` | If `True`, no backward. If `False`, forward + backward. |
| `factor_down` | `float` | `2.0` | After failure: `next = value / factor_down`. |
| `factor_up` | `float` | `2.0` | Used when `memory_guided=False` for success steps. |
| `memory_guided` | `bool` | `True` | Use peak GPU memory (and optional CPU via `psutil`) to pick the next size; capped by `max_growth_multiplier`. |
| `memory_target_fraction` | `float` | `0.88` | Target peak VRAM as a fraction of total GPU memory when extrapolating. |
| `max_growth_multiplier` | `float` | `6.0` | Max single-step increase after success. |
| `cuda_mem_devices` | `int`, list, or `"all"` | `None` | Which GPUs to read peak memory on; default = the search device's index. `"all"` or `[0, 1, …]` for a multi-GPU bottleneck. |
| `use_subprocess` | `bool` | `None` | Default: subprocess on Linux/macOS for CUDA, CPU, and MPS (the worker may be OOM-killed; the parent continues). Set `BATCH_FINDER_SUBPROCESS=0` for in-process on very tight hosts. Windows: in-process. |
Returns: a shape tuple, a tuple of shape tuples (multi-tensor list `input_shapes`), or an int (string/dict `input_shapes` or `axis_to_maximize`), or `None` if nothing worked. With `time_limit_seconds`, you may get the best value found before the deadline even if `n_attempts` was not reached.
Modes:
- Tuple/list: first `forward` argument, or a list/tuple of shape tuples for several arguments; `-1` marks the free axis, or add negative floats to scale other axes off that trial size.
- String: one shape group per tensor; names, one `=-1`, optional equations between names.
- Dict: like the string form, keys = parameter names, `"#constraints"` for the rules.
- `axis_to_maximize` + `fixed_axis`: when you skip `input_shapes`.
Example output (`axis_to_maximize` + `fixed_axis`):

```
--- Detected inputs (type, estimated shape) ---
input_ids: integer, (32, 64)
attention_mask: integer, (32, 64)
---
batch_size fixed={'seq_len': 32}: 100%|████████████████████| 22/50 [01:26<00:00, 3.9s/it, gpus=1, i=22/50, max_ok=1919, min_fail=1920, status=✅, value=1919]
✅ Max value that passed: 1919
```
🔧 How It Works
- Inputs – Uses `input_shapes` dict keys or `forward_params` (no live module in the parent during the search).
- Types – Integers for `*ids`, `*mask`, `labels`; floats otherwise.
- Shapes – From an optional `config`, or from `get_model().config` (one probe if you omit `config`) plus argument names.
- Search – On success, grow (memory-guided or `factor_up`); on failure without a bracket, shrink by `factor_down`. Once a success and a failure bracket the limit, the next trial is always the midpoint `(max_ok + min_fail) // 2` until the bracket tightens to one step. Stops on failure at size `1`, when `n_attempts` is reached, or when `time_limit_seconds` elapses (then returns the best successful trial so far, if any).
- Loss – Uses `output.loss` if present, otherwise sums output tensors.
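The search rule above can be sketched in a few lines of Python. This is an illustrative reimplementation of the described behavior, not the library's code; `fits` stands in for one forward/backward attempt:

```python
# Sketch of the grow/shrink/bisect search described above (assumed behavior).
def search(fits, initial_value=512, factor_up=2.0, factor_down=2.0, n_attempts=50):
    value, max_ok, min_fail = initial_value, None, None
    for _ in range(n_attempts):
        if fits(value):
            max_ok = value                      # largest success so far
        else:
            if value == 1:
                return None                     # even size 1 fails: honest "no"
            min_fail = value                    # smallest failure so far
        if max_ok is not None and min_fail is not None:
            if min_fail - max_ok <= 1:
                return max_ok                   # bracket tightened to one step
            value = (max_ok + min_fail) // 2    # bisect inside the bracket
        elif max_ok is not None:
            value = int(value * factor_up)      # no failure yet: grow
        else:
            value = max(1, int(value / factor_down))  # no success yet: shrink
    return max_ok

# Toy "memory model": sizes up to 1919 fit.
print(search(lambda v: v <= 1919))  # 1919
```

Memory-guided mode replaces the fixed `factor_up` growth step with an extrapolation from peak memory, but the bracketing and bisection work the same way.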
⚠️ Important Notes
- Memory: Lower `initial_value` on small GPUs.
- Speed: `inference_only=True` is faster.
- Training: `inference_only=False` runs backward too.
- Size 1: If size `1` fails, the function returns `None`.
- OOM: CUDA OOM and similar errors are caught and the search continues. On Linux/macOS, attempts default to a subprocess so the parent keeps running if the host kills the worker (SIGKILL). Set `BATCH_FINDER_SUBPROCESS=0` for in-process attempts if repeated `spawn` reloads are worse on your host. Each worker run ends with `del model`, `gc.collect()`, and CUDA/MPS cache clears where applicable. After failures, the same cleanup runs.
- Defaults: Omitting `initial_value` and `delay` picks smaller starts and shorter pauses on CPU/MPS than on CUDA.
- Login nodes: Prefer a real GPU job for meaningful limits. On CPU-only hosts, use a smaller `initial_value` and/or `inference_only=True` if RAM is tight.
- Memory-guided steps: After each good run, peak GPU memory vs. total guides the next step (capped by `max_growth_multiplier`). Install `psutil` for a rough CPU hint. Set `memory_guided=False` to use only `factor_up` / `factor_down`.
- Multi-GPU (one OS process): Use `cuda_mem_devices="all"` or a list so every visible GPU is measured; memory-guided growth uses the tightest GPU (minimum headroom).
- DDP / torchrun / Accelerate (several processes): Each rank runs the search on its assigned GPU (`cuda_mem_devices` defaults to that GPU only). When `WORLD_SIZE > 1`, `find_max_minibatch` takes the minimum successful trial size across ranks (JSON sync) before returning. Use `BATCH_FINDER_SYNC_DIR` or the default (`$WORK/.cache` or `$HOME/.cache`); do not pass a per-process random `output_dir` (e.g. each worker re-importing a script with a new random experiment id), since every rank must resolve the same sync directory. Layout: `<base>/find_max_minibatch_sync/<job_tag>/`. The first rank done waits (poll every 4 s, status log about every 30 s) until the others finish searching.
- Time limit: Pass `time_limit_seconds=…` to cap wall-clock time for the search loop (after the one-time probe). You get the best batch found so far when the limit hits. Subprocess workers are stopped if they would run past the remaining budget; an in-process forward/backward still runs to completion once started.
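As a rough illustration of the memory-guided growth step, the sketch below extrapolates the next trial size from peak usage. The formula is an assumption based on the parameter descriptions above (`memory_target_fraction`, `max_growth_multiplier`), not the library's exact code:

```python
# Assumed extrapolation: scale the trial by available headroom toward the
# target fraction of total memory, capped by max_growth_multiplier.
def next_trial(value, peak_bytes, total_bytes,
               memory_target_fraction=0.88, max_growth_multiplier=6.0):
    headroom = memory_target_fraction * total_bytes / peak_bytes
    growth = min(max(headroom, 1.0), max_growth_multiplier)  # never shrink after success
    return max(value + 1, int(value * growth))               # always advance at least one step

# e.g. a run at size 256 that peaked at 4 GiB on a 40 GiB GPU:
print(next_trial(256, 4 * 2**30, 40 * 2**30))  # 1536 (capped at 6x)
```

Once a failure brackets the limit, bisection takes over, so an overshoot here only costs one extra attempt.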
🤝 Contributing
Contributions are welcome. Please open a Pull Request.
📝 License
MIT — see the LICENSE file.
Made with ❤️ for the PyTorch community