
Transformer Metacontroller


metacontroller

Implementation of the MetaController proposed in "Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning" (Kobayashi et al., 2025).

Install

$ pip install metacontroller-pytorch

Appreciation

  • Pranoy for submitting a pull request that fixed the previous latent action not being included in the switching unit's inputs

  • Diego Calanzone for proposing testing on the BabyAI gridworld task, and for submitting the pull request that added behavioral cloning and discovery phase training for it!

Usage

import torch
from metacontroller import Transformer, MetaController

# 1. initialize model

model = Transformer(
    dim = 512,
    action_embed_readout = dict(num_discrete = 4),
    state_embed_readout = dict(num_continuous = 384),
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)

state = torch.randn(2, 128, 384)
actions = torch.randint(0, 4, (2, 128))

# 2. behavioral cloning (BC)

state_loss, action_loss = model(state, actions)
(state_loss + action_loss).backward()

# 3. discovery phase

meta_controller = MetaController(
    dim_model = 512,
    dim_meta_controller = 256,
    dim_latent = 128
)

action_recon_loss, kl_loss, switch_loss = model(
    state,
    actions,
    meta_controller = meta_controller,
    discovery_phase = True
)

(action_recon_loss + kl_loss + switch_loss).backward()

# 4. internal rl phase (GRPO)

# ... collect trajectories ...

logits, cache = model(
    one_state,
    past_action_id,
    meta_controller = meta_controller,
    return_cache = True
)

meta_output = cache.prev_hiddens.meta_controller
old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)

# ... calculate advantages ...

loss = meta_controller.policy_loss(
    group_states,
    group_old_log_probs,
    group_latent_actions,
    group_advantages,
    group_switch_betas
)

loss.backward()
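The usage above elides the advantage computation (`# ... calculate advantages ...`). As a rough guide to what GRPO usually does at that step, here is a minimal, framework-free sketch: rewards from a group of rollouts are standardized within the group, and the policy is trained with a PPO-style clipped probability ratio against the old log probabilities (in the README these come from `meta_controller.log_prob(...)` during rollout). This is an illustration under stated assumptions, not the code behind `meta_controller.policy_loss`, and it does not model how that method uses `group_switch_betas`. The function names, the use of the population standard deviation, and the `clip_eps = 0.2` default are all hypothetical choices.

```python
import math
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Hypothetical helper: standardize each rollout's reward against the rest of its group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std of the group (assumption; some variants use sample std)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Hypothetical helper: PPO-style clipped objective, negated and averaged so it can be minimized."""
    losses = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # keep the pessimistic (smaller) of the unclipped and clipped objectives
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return sum(losses) / len(losses)


# one group of four rollouts that share a starting state (toy rewards)
rewards = [1.0, 0.0, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])  # [1.414, -1.414, 0.0, 0.0]

# before any update the new and old log probs coincide, so every ratio is 1
old_log_probs = [-1.2, -0.7, -2.0, -0.9]
print(clipped_surrogate_loss(old_log_probs, old_log_probs, advantages))
```

Standardizing within the group means a rollout is only pushed up if it beats its siblings, so no learned value baseline is required; the clipping keeps each update close to the policy that collected the trajectories.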

Alternatively, evolutionary strategies (ES) can be used for the last portion, in place of the GRPO internal RL phase:

# 5. evolve (ES over GRPO)

model.meta_controller = meta_controller

def environment_callable(model):
    # return a fitness score
    return 1.0

model.evolve(
    num_generations = 10,
    environment = environment_callable
)
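`model.evolve` only asks for an `environment` callable that returns a fitness score; how it perturbs and updates the weights is not described here. For intuition, the following is a hedged sketch of a generic Gaussian-perturbation evolution strategy over a flat list of parameters: sample noise, score each perturbed candidate, standardize the scores, and step along the fitness-weighted average of the noise. The function `es_optimize`, its arguments (`fitness_fn`, `population_size`, `sigma`, `learning_rate`, `seed`) and the toy quadratic fitness are hypothetical stand-ins, not this library's API.

```python
import random


def es_optimize(params, fitness_fn, num_generations=200, population_size=50,
                sigma=0.1, learning_rate=0.02, seed=0):
    """Hypothetical Gaussian-perturbation ES that maximizes fitness_fn over a list of floats."""
    rng = random.Random(seed)
    params = list(params)
    for _ in range(num_generations):
        noises, scores = [], []
        for _ in range(population_size):
            # perturb every parameter with independent N(0, 1) noise scaled by sigma
            noise = [rng.gauss(0.0, 1.0) for _ in params]
            candidate = [p + sigma * n for p, n in zip(params, noise)]
            noises.append(noise)
            scores.append(fitness_fn(candidate))

        # standardize fitness within the generation so the step size ignores reward scale
        mean = sum(scores) / len(scores)
        std = (sum((s - mean) ** 2 for s in scores) / len(scores)) ** 0.5 + 1e-8
        weights = [(s - mean) / std for s in scores]

        # gradient estimate: fitness-weighted average of the noise, divided by sigma
        for i in range(len(params)):
            grad_i = sum(w * n[i] for w, n in zip(weights, noises)) / (population_size * sigma)
            params[i] += learning_rate * grad_i
    return params


# toy fitness (hypothetical): parameters closer to `target` score higher
target = [1.0, -2.0]


def fitness(p):
    return -sum((a - b) ** 2 for a, b in zip(p, target))


best = es_optimize([0.0, 0.0], fitness)
print(best, fitness(best))
```

Because only fitness values are needed, the environment can be non-differentiable, which is the usual reason to prefer ES over a gradient-based policy loss for this last portion.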

Citations

@misc{kobayashi2025emergenttemporalabstractionsautoregressive,
    title   = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning}, 
    author  = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
    year    = {2025},
    eprint  = {2512.20605},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2512.20605}, 
}
@article{Wagenmaker2025SteeringYD,
    title   = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
    author  = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
    journal = {ArXiv},
    year    = {2025},
    volume  = {abs/2506.15799},
    url     = {https://api.semanticscholar.org/CorpusID:279464702}
}
@misc{hwang2025dynamicchunkingendtoendhierarchical,
    title   = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
    author  = {Sukjun Hwang and Brandon Wang and Albert Gu},
    year    = {2025},
    eprint  = {2507.07955},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2507.07955},
}
@misc{fleuret2025freetransformer,
    title   = {The Free Transformer},
    author  = {François Fleuret},
    year    = {2025},
    eprint  = {2510.17558},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2510.17558},
}

Life can only be understood backwards; but it must be lived forwards - Søren Kierkegaard


Download files


Source Distribution

metacontroller_pytorch-0.0.48.tar.gz (341.9 kB)

Uploaded Source

Built Distribution


metacontroller_pytorch-0.0.48-py3-none-any.whl (15.2 kB)

Uploaded Python 3

File details

Details for the file metacontroller_pytorch-0.0.48.tar.gz.

File metadata

  • Download URL: metacontroller_pytorch-0.0.48.tar.gz
  • Upload date:
  • Size: 341.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for metacontroller_pytorch-0.0.48.tar.gz
Algorithm    Hash digest
SHA256       da84e03ac42e5a07d9ebb48dd82f0e61218bd658793a3b5ec05d9c64f1fba842
MD5          5ac14f3b62ac1747a9789ba26c09b8e0
BLAKE2b-256  484a9a6ad1b9d062d5cd961d9178699d19b5f326f44b2dbd5394cc822c504b01


File details

Details for the file metacontroller_pytorch-0.0.48-py3-none-any.whl.

File metadata

File hashes

Hashes for metacontroller_pytorch-0.0.48-py3-none-any.whl
Algorithm    Hash digest
SHA256       5b21a73023001038104573c73b3aadc64758abbdd963cc50c07f3b32d72b2308
MD5          7fa69f077ed2241165e0f3dc0c378761
BLAKE2b-256  e1c6beab7d659556e5796b3a49e8ba1f4b4f8f4995b87763f4415b67852a3d25

