
Transformer Metacontroller


metacontroller

Implementation of the MetaController proposed in "Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning" (Kobayashi et al., 2025), from the Paradigms of Intelligence team at Google.

Install

$ pip install metacontroller-pytorch

Appreciation

  • Pranoy for submitting a pull request that fixes the previous latent action not being included in the switching unit's inputs

  • Diego Calanzone for proposing testing on the BabyAI gridworld task, and for submitting the pull request for behavior cloning and discovery phase training on it!

  • Andrew Song for ongoing implementation of the PinPad environment!

  • Diego Calanzone for his experimental acumen and for bringing the project to an initial working state on the BabyAI environment!

  • Andrew Song for implementing linear probing and fixing an issue with the action space

  • Andrew Song for identifying a critical issue with past action embed handling and detaching gradients of target states

  • Diego Calanzone for identifying inconsistencies in the MetaController

  • Diego Calanzone for replicating interpretable temporal segmentation for the BabyAI gridworld task!
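Several of the fixes above concern the switching unit, which decides at each timestep whether the metacontroller keeps its current latent action or moves to a newly proposed one. The following is a minimal sketch in plain Python, assuming the unit outputs a scalar gate `beta` in [0, 1] that linearly interpolates between the previous latent and a fresh proposal. `switch_latent` and `roll_out_latents` are hypothetical helpers, not part of the `metacontroller` API, and here `beta` is simply given rather than computed from inputs such as the previous latent action.

```python
from typing import List, Sequence

Vector = List[float]


def switch_latent(z_prev: Sequence[float], z_new: Sequence[float], beta: float) -> Vector:
    """Blend the previous latent with a new proposal using a gate beta in [0, 1].

    beta = 1 switches fully to z_new, beta = 0 keeps z_prev unchanged.
    (Hypothetical helper; the linear interpolation form is an assumption.)
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    if len(z_prev) != len(z_new):
        raise ValueError("latents must have the same dimension")
    return [beta * n + (1.0 - beta) * p for p, n in zip(z_prev, z_new)]


def roll_out_latents(
    z0: Sequence[float],
    proposals: Sequence[Sequence[float]],
    betas: Sequence[float],
) -> List[Vector]:
    """Apply the switching rule step by step, carrying the latent forward."""
    latents: List[Vector] = []
    z = list(z0)
    for z_new, beta in zip(proposals, betas):
        z = switch_latent(z, z_new, beta)
        latents.append(z)
    return latents


if __name__ == "__main__":
    proposals = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
    betas = [1.0, 0.0, 0.5]  # switch, hold, half-switch
    for t, z in enumerate(roll_out_latents([0.0, 0.0], proposals, betas)):
        print(t, z)
```

A run of gates near zero holds the latent fixed, so the segment between two switches behaves like one temporally extended abstract action; in the example, steps 0 and 1 share the same latent even though step 1 proposed a different one.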

Usage

import torch
from metacontroller import Transformer, MetaController

# 1. initialize model

model = Transformer(
    dim = 512,
    action_embed_readout = dict(num_discrete = 4),
    state_embed_readout = dict(num_continuous = 384),
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)

state = torch.randn(2, 128, 384)
actions = torch.randint(0, 4, (2, 128))

# 2. behavioral cloning (BC)

state_loss, action_loss = model(state, actions)
(state_loss + action_loss).backward()

# 3. discovery phase

meta_controller = MetaController(
    dim_model = 512,
    dim_meta_controller = 256,
    dim_latent = 128
)

state_pred_loss, action_recon_loss, kl_loss, aux_ratio_loss = model(
    state,
    actions,
    meta_controller = meta_controller,
    discovery_phase = True
)

# the paper did not use the state prediction loss (equivalent to a weight of 0), but it is available here
# the ratio loss from the H-Net paper is also available but optional (enable it by setting ratio_loss_weight > 0)

(action_recon_loss + kl_loss * 0.1).backward()

# 4. internal rl phase (GRPO)

# ... collect trajectories (one_state, past_action_id and the group_* tensors below are placeholders for your rollout data) ...

logits, cache = model(
    one_state,
    past_action_id,
    meta_controller = meta_controller,
    return_cache = True
)

meta_output = cache.prev_hiddens.meta_controller
old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)

# ... calculate advantages ...

# for GRPO, the inputs to the policy loss should be of shape (batch, seq, dim_latent),
# where dim_latent is the dimension of the latent action space

loss = meta_controller.policy_loss(
    group_states,
    group_old_log_probs,
    group_latent_actions,
    group_advantages,
    group_switch_betas
)

loss.backward()
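The usage comments mention the optional ratio loss from the H-Net paper (Hwang et al., 2025), an auxiliary term that steers a boundary predictor toward a target average chunk length N. Following our reading of that paper, with F the fraction of positions hard-selected as boundaries and G the mean boundary probability, the loss is N/(N-1) * ((N-1) * F * G + (1 - F) * (1 - G)). The sketch below computes it in plain Python; treating the metacontroller's switching gates as the boundary probabilities, the 0.5 threshold, and the name `ratio_loss` are assumptions, not details taken from this library.

```python
from typing import Sequence


def ratio_loss(boundary_probs: Sequence[float], target_ratio: float, threshold: float = 0.5) -> float:
    """Auxiliary ratio loss in the style of H-Net's dynamic chunking.

    boundary_probs: per-position probabilities of starting a new segment
        (hypothetically, the metacontroller's switching gates).
    target_ratio: desired average segment length N (must be > 1).
    """
    if target_ratio <= 1.0:
        raise ValueError("target_ratio must be greater than 1")
    if len(boundary_probs) == 0:
        raise ValueError("need at least one position")
    n = target_ratio
    length = len(boundary_probs)
    # F: fraction of positions actually selected as boundaries (hard decision, no gradient)
    f = sum(1.0 for p in boundary_probs if p >= threshold) / length
    # G: mean boundary probability (the differentiable part in a real model)
    g = sum(boundary_probs) / length
    return n / (n - 1.0) * ((n - 1.0) * f * g + (1.0 - f) * (1.0 - g))


if __name__ == "__main__":
    on_target = [1.0, 0.0, 0.0, 0.0] * 4  # one boundary every 4 steps
    too_many = [1.0] * 16                  # a boundary at every step
    too_few = [0.0] * 16                   # no boundaries at all
    for probs in (on_target, too_many, too_few):
        print(ratio_loss(probs, target_ratio=4.0))
```

In a trained model F comes from hard decisions and carries no gradient, so only G is optimized: the derivative with respect to G is proportional to N * F - 1, which pushes the boundary probabilities down when more than one position in N is a boundary and up when fewer are. The loss equals 1 when F = G = 1/N.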
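The `# ... calculate advantages ...` step is left to the caller. GRPO replaces a learned value baseline with a group-relative one: each trajectory's return is standardized against the returns of the other trajectories in its group. The plain-Python sketch below shows that computation together with a PPO-style clipped surrogate that uses old and new log-probabilities. The function names, the `eps` and `clip_eps` defaults, and the clipped form of the objective are assumptions for illustration; `meta_controller.policy_loss` may compute its loss differently (it also receives the switch betas, which this sketch ignores).

```python
import math
from typing import List, Sequence


def group_relative_advantages(returns: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Standardize each return against its group: (r - mean) / (std + eps)."""
    n = len(returns)
    if n == 0:
        raise ValueError("need at least one return per group")
    mean = sum(returns) / n
    var = sum((r - mean) ** 2 for r in returns) / n  # population variance
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in returns]


def clipped_policy_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """PPO-style clipped surrogate, negated so that it can be minimized."""
    if not (len(new_log_probs) == len(old_log_probs) == len(advantages) > 0):
        raise ValueError("inputs must be non-empty and of equal length")
    terms = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new / pi_old
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        terms.append(min(ratio * adv, clipped * adv))  # pessimistic bound
    return -sum(terms) / len(terms)


if __name__ == "__main__":
    adv = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
    print([round(a, 3) for a in adv])
    # identical old and new log-probs give ratio 1, so the loss is -mean(adv)
    lp = [-1.2, -0.7, -2.0, -0.3]
    print(clipped_policy_loss(lp, lp, adv))
```

Standardizing within the group means only relative performance matters: if every trajectory in a group earns the same return, all advantages are zero and that group contributes no gradient.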

Alternatively, use evolution strategies (ES) instead of GRPO for the final phase

# 5. evolve (ES instead of GRPO)

model.meta_controller = meta_controller

def environment_callable(model):
    # return a fitness score
    return 1.0

model.evolve(
    num_generations = 10,
    environment = environment_callable
)
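`model.evolve` runs the evolution loop internally, and its exact algorithm is not documented here. As a rough illustration of the general technique, the sketch below implements basic antithetic Gaussian evolution strategies in plain Python: perturb the parameters with mirrored noise, score each perturbation with a fitness callable (the role `environment_callable` plays above), and step along the fitness-weighted noise. `evolve_params`, its hyperparameters, and the toy quadratic fitness are hypothetical.

```python
import random
from typing import Callable, List, Sequence


def evolve_params(
    params: Sequence[float],
    fitness: Callable[[List[float]], float],
    num_generations: int = 10,
    population_size: int = 16,
    sigma: float = 0.1,
    learning_rate: float = 0.05,
    seed: int = 0,
) -> List[float]:
    """Basic antithetic Gaussian evolution strategies (gradient-free fitness ascent)."""
    if population_size < 2:
        raise ValueError("population_size must be at least 2")
    rng = random.Random(seed)
    theta = list(params)
    num_pairs = population_size // 2
    for _ in range(num_generations):
        grad = [0.0] * len(theta)
        for _ in range(num_pairs):
            eps = [rng.gauss(0.0, 1.0) for _ in theta]
            # score the mirrored pair theta + sigma * eps and theta - sigma * eps
            f_pos = fitness([t + sigma * e for t, e in zip(theta, eps)])
            f_neg = fitness([t - sigma * e for t, e in zip(theta, eps)])
            for i, e in enumerate(eps):
                grad[i] += (f_pos - f_neg) * e
        # Monte Carlo estimate of the gradient of the Gaussian-smoothed fitness
        step = learning_rate / (2 * num_pairs * sigma)
        theta = [t + step * g for t, g in zip(theta, grad)]
    return theta


if __name__ == "__main__":
    target = [1.0, -2.0, 0.5]

    def toy_fitness(p: List[float]) -> float:
        # higher is better; the maximum (0.0) is at p == target
        return -sum((a - b) ** 2 for a, b in zip(p, target))

    start = [0.0, 0.0, 0.0]
    result = evolve_params(start, toy_fitness, num_generations=200)
    print(toy_fitness(start), toy_fitness(result))
```

Only fitness values are needed, never gradients through the environment, which is why a black-box callable returning a single score is enough. Evaluating mirrored pairs cancels the even-order terms of the fitness around theta (including the baseline F(theta)), which lowers the variance of the gradient estimate.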

Contributing

To install the dependencies for testing, run

$ uv sync --extra test

To run the tests with pytest, run

$ uv run pytest

Citations

@misc{kobayashi2025emergenttemporalabstractionsautoregressive,
    title   = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
    author  = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
    year    = {2025},
    eprint  = {2512.20605},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2512.20605},
}
@article{Wagenmaker2025SteeringYD,
    title   = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
    author  = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
    journal = {ArXiv},
    year    = {2025},
    volume  = {abs/2506.15799},
    url     = {https://api.semanticscholar.org/CorpusID:279464702}
}
@misc{hwang2025dynamicchunkingendtoendhierarchical,
    title   = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
    author  = {Sukjun Hwang and Brandon Wang and Albert Gu},
    year    = {2025},
    eprint  = {2507.07955},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2507.07955},
}
@misc{fleuret2025freetransformer,
    title     = {The Free Transformer},
    author    = {François Fleuret},
    year      = {2025},
    eprint    = {2510.17558},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url       = {https://arxiv.org/abs/2510.17558},
}
@misc{hafner2025trainingagentsinsidescalable,
    title   = {Training Agents Inside of Scalable World Models},
    author  = {Danijar Hafner and Wilson Yan and Timothy Lillicrap},
    year    = {2025},
    eprint  = {2509.24527},
    archivePrefix = {arXiv},
    primaryClass = {cs.AI},
    url     = {https://arxiv.org/abs/2509.24527},
}
@article{Pagnoni2024ByteLT,
    title   = {Byte Latent Transformer: Patches Scale Better Than Tokens},
    author  = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2412.09871},
    url     = {https://api.semanticscholar.org/CorpusID:274762821}
}

Life can only be understood backwards; but it must be lived forwards - Søren Kierkegaard
