diff --git a/docs/source/index.rst b/docs/source/index.rst index 0d721408db7..98a6f5f51d4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -116,6 +116,7 @@ Intermediate tutorials/torchrl_envs tutorials/pretrained_models tutorials/dqn_with_rnn + tutorials/recurrent_sequence_training tutorials/mujoco_cube_bowl_macros tutorials/collector_trajectory_assembly tutorials/evaluator diff --git a/tutorials/sphinx-tutorials/recurrent_sequence_training.py b/tutorials/sphinx-tutorials/recurrent_sequence_training.py new file mode 100644 index 00000000000..a7727985434 --- /dev/null +++ b/tutorials/sphinx-tutorials/recurrent_sequence_training.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Recurrent training on sequence batches +====================================== + +**Author**: `Achintya Paningapalli `_ + +.. _recurrent_sequence_tuto: + +.. grid:: 2 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + + * How TorchRL auto-wires :class:`~torchrl.envs.InitTracker` and the + recurrent-state primer when a recurrent policy is detected + * How to sample multi-step slices from a replay buffer with + :class:`~torchrl.data.replay_buffers.SliceSampler` + * How to train a recurrent policy in *full-sequence* mode using + :class:`~torchrl.modules.set_recurrent_mode` + * Why ``is_init`` at slice starts is what keeps hidden state from + leaking across episode boundaries inside a replay batch + + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + + * PyTorch v2.0.0 + * gymnasium[classic_control] + * Familiarity with :ref:`Recurrent DQN ` is helpful but not + required — this tutorial is its multi-step complement +""" + +######################################################################### +# Overview +# -------- +# +# There are two ways to train a recurrent policy in TorchRL, and choosing +# between them is the first design decision in any recurrent project: +# +# 1. **Sequential mode** (one step per forward call). The policy reads the +# previous step's hidden state from the TensorDict, runs the LSTM for one +# step, and writes the new hidden state under the ``("next", ...)`` keys. +# This is the natural mode during collection and is what +# :ref:`Recurrent DQN ` covers in depth. +# 2. **Recurrent mode** (full ``[B, T]`` sequence per forward call). The LSTM +# processes a whole time dimension at once. This is used inside loss / +# advantage code over replayed trajectory slices, where you want to +# backprop through several timesteps in a single batched call. +# +# This tutorial focuses on (2): how to collect data, sample multi-step +# slices, and run a recurrent policy in full-sequence mode without leaking +# hidden state across episode boundaries. +# +# The key building blocks — most of which the collector now auto-wires for +# you — are: +# +# - :class:`~torchrl.modules.LSTMModule` (the recurrent policy core), +# - :class:`~torchrl.envs.InitTracker` (writes ``is_init=True`` at the start +# of every trajectory), +# - A :class:`~torchrl.envs.transforms.TensorDictPrimer` that seeds the +# initial recurrent state on reset, +# - :class:`~torchrl.data.replay_buffers.SliceSampler` for trajectory-aware +# replay, +# - :class:`~torchrl.modules.set_recurrent_mode` for switching the LSTM +# between single-step and full-sequence execution. +# +# See :ref:`ref_recurrent_state_lifecycle` for the full reference on how +# hidden state flows through the pipeline. + +###################################################################### +# If you are running this in Google Colab, install the dependencies first: +# +# .. code-block:: bash +# +# !pip3 install torchrl +# !pip3 install gymnasium + +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +if multiprocessing.get_start_method(allow_none=True) is None: + multiprocessing.set_start_method("fork") +# sphinx_gallery_end_ignore + +import torch +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from torch import nn + +from torchrl.collectors import Collector +from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SliceSampler +from torchrl.envs import GymEnv +from torchrl.modules import LSTMModule, set_recurrent_mode + +torch.manual_seed(0) +device = torch.device("cpu") + +###################################################################### +# Environment and policy +# ---------------------- +# +# We use ``CartPole-v1`` for a small, fast, fully-observed env. A recurrent +# policy is overkill for CartPole, but that is precisely the point: it lets +# us focus on the sequence-batching machinery without the env being the +# bottleneck. The same pattern scales unchanged to partially-observed +# environments where memory actually matters. +# +# The policy is a tiny :class:`~torchrl.modules.LSTMModule` followed by a +# linear head that maps the LSTM output to the action logits. We pick +# ``hidden_size=16`` so the tutorial finishes quickly on CPU. + +OBS_DIM = 4 # CartPole observation: cart pos, cart vel, pole angle, pole vel +N_ACTIONS = 2 # left or right +HIDDEN = 16 + + +def make_policy() -> Seq: + """Construct the recurrent policy: LSTMModule + linear logits head.""" + lstm = LSTMModule( + input_size=OBS_DIM, + hidden_size=HIDDEN, + in_keys=["observation", "rs_h", "rs_c"], + out_keys=["features", ("next", "rs_h"), ("next", "rs_c")], + python_based=True, # avoids cuDNN for vmap / torch.compile compatibility + ) + head = Mod( + nn.Linear(HIDDEN, N_ACTIONS), + in_keys=["features"], + out_keys=["logits"], + ) + # Deterministic argmax action selection — enough to demonstrate the + # collection / replay / sequence-training plumbing. + chooser = Mod( + lambda logits: logits.argmax(-1, keepdim=False), + in_keys=["logits"], + out_keys=["action"], + ) + return Seq(lstm, head, chooser) + + +policy = make_policy() +policy.eval() + +###################################################################### +# Auto-wiring recurrent env transforms +# ------------------------------------ +# +# A recurrent policy needs two env-side transforms to behave correctly: +# +# 1. :class:`~torchrl.envs.InitTracker`, which writes ``is_init=True`` at +# the first step after every reset. +# 2. A :class:`~torchrl.envs.transforms.TensorDictPrimer` that zero-fills +# the initial ``("rs_h", "rs_c")`` recurrent state slots on reset. +# +# Forgetting either of these is the most common source of silent +# hidden-state bugs. The collector now detects recurrent submodules in the +# policy and auto-appends both transforms when ``auto_register_policy_transforms=True`` +# is passed. As of v0.15 this becomes the default; for earlier versions you +# need to opt in explicitly. + +env = GymEnv("CartPole-v1", device=device) + +collector = Collector( + env, + policy, + frames_per_batch=64, + total_frames=512, + device=device, + storing_device=device, + auto_register_policy_transforms=True, + reset_at_each_iter=False, +) + +###################################################################### +# Inspecting a single batch +# ------------------------- +# +# Run one rollout and look at what comes back. Three things to notice: +# +# - The batch shape is ``[T]`` where ``T == frames_per_batch`` (single env). +# - ``is_init`` is ``True`` at the very first step and immediately after +# each ``done``. These mark trajectory boundaries. +# - ``("next", "rs_h")`` / ``("next", "rs_c")`` carry the LSTM's next-step +# hidden / cell state across timesteps. + +data = next(iter(collector)) + +print("Batch shape:", data.shape) +print("Available keys:", sorted(k for k in data.keys())) +print("is_init shape:", data["is_init"].shape) +print("# trajectory boundaries in batch:", int(data["is_init"].sum().item())) +print( + "Next-step hidden shape:", + data["next", "rs_h"].shape, + "(batch, num_layers, hidden_size)", +) + +###################################################################### +# Replay with SliceSampler +# ------------------------ +# +# We store the rollout in a replay buffer and sample fixed-length slices of +# trajectories from it. :class:`~torchrl.data.replay_buffers.SliceSampler` +# is trajectory-aware: it uses the boundary information already in the +# batch (``("next", "done")`` and the collector-written ``("collector", +# "traj_ids")``) to draw whole sub-trajectories. +# +# Two design choices in this section: +# +# 1. **Slices remain *ragged* by default**. ``SliceSampler`` returns a flat +# batch of concatenated variable-length slices, *not* a rectangular +# ``[B, T]`` tensor. The trajectory boundaries inside that flat batch +# are marked by ``is_init=True``. This avoids padding waste and matches +# how TorchRL's recurrent modules already consume input. +# 2. **``pad_output=True`` is available but discouraged**. It pads short +# slices to ``slice_len`` and writes a ``("collector", "mask")`` key for +# consumers (e.g. mask-aware losses), but the docstring on SliceSampler +# explicitly steers users toward the ragged path when possible. + +SLICE_LEN = 8 +NUM_SLICES = 4 + +rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(max_size=1024, device=device), + sampler=SliceSampler(slice_len=SLICE_LEN, strict_length=False), + batch_size=SLICE_LEN * NUM_SLICES, +) +rb.extend(data) + +sample = rb.sample() +print("Sampled batch shape:", sample.shape) +print( + "is_init True positions (slice starts):", + sample["is_init"].nonzero().squeeze(-1).tolist(), +) + +###################################################################### +# What ``is_init`` marks +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# Each ``True`` value in the sampled batch's ``is_init`` flags the first +# timestep of a slice (or a trajectory boundary that fell inside a slice +# because the underlying episode ended mid-slice). Downstream code — the +# recurrent module's full-sequence forward, GAE / advantage computation, +# mask-aware loss reduction — all read these markers to know where to +# reset hidden state and where to mask. + +###################################################################### +# Multi-step (full-sequence) recurrent forward +# -------------------------------------------- +# +# Now the payoff. By default, ``LSTMModule.forward`` runs in *sequential* +# mode — one step at a time, using the hidden state in the TensorDict. +# Inside :class:`~torchrl.modules.set_recurrent_mode(True)`, the same +# module runs in *recurrent* mode: it batches the time dimension into a +# single ``nn.LSTM`` call. +# +# The key invariant: if any ``is_init`` is True in the time positions +# ``1..T-1`` (i.e. mid-batch), the module's recurrent path *splits the +# batch along those boundaries*, runs the LSTM independently on each +# resulting chunk, and stitches the outputs back together. Hidden state +# from one trajectory never bleeds into the next. + +# Reshape the sampled ragged batch into rectangular [num_slices, slice_len]. +# When ``strict_length=False`` and there is no padding, slices are exactly +# ``slice_len`` long, so this reshape is safe. (With ``pad_output=True``, +# the same reshape works and a mask key is provided alongside.) +sample_seq = sample.reshape(NUM_SLICES, SLICE_LEN) + +with set_recurrent_mode(True): + out = policy(sample_seq.clone()) + +print("Sequence-mode output shape:", out["features"].shape) +print( + "Sequence-mode logits shape:", + out["logits"].shape, + "(num_slices, slice_len, n_actions)", +) + +###################################################################### +# Why the boundary handling matters: a controlled check +# ----------------------------------------------------- +# +# To prove that the recurrent path does not leak hidden state across +# trajectory boundaries inside a single batch, we build a small +# two-trajectory packed batch by hand, seed the *first* trajectory with +# non-zero noise in its incoming hidden, and check that the second +# trajectory's outputs match a standalone forward over just that second +# trajectory. +# +# If hidden state leaked across the boundary, the noisy first half would +# pollute the second half's outputs and the comparison would fail. + +# Two adjacent slices of length 4 each, packed end-to-end. is_init=True at +# index 0 (first traj start) and index 4 (second traj start). +T_A = T_B = 4 +T = T_A + T_B +is_init_packed = torch.zeros(1, T, dtype=torch.bool) +is_init_packed[0, 0] = True +is_init_packed[0, T_A] = True + +obs = torch.randn(1, T, OBS_DIM) +# Seed the incoming hidden with noise. The recurrent forward should +# *override* this at every is_init=True position. +noisy_h = torch.randn(1, T, 1, HIDDEN) +noisy_c = torch.randn(1, T, 1, HIDDEN) + +from tensordict import TensorDict + +packed = TensorDict( + { + "observation": obs, + "rs_h": noisy_h, + "rs_c": noisy_c, + "is_init": is_init_packed, + }, + batch_size=[1, T], +) + +# Isolated forward over only the second trajectory, with is_init=True at +# step 0 (its real start). If the packed run handles boundaries correctly, +# the packed batch's second-half output must match this isolated output. +is_init_b = torch.zeros(1, T_B, dtype=torch.bool) +is_init_b[0, 0] = True +b_only = TensorDict( + { + "observation": obs[:, T_A:].clone(), + "rs_h": noisy_h[:, T_A:].clone(), # same noise — must be overridden + "rs_c": noisy_c[:, T_A:].clone(), + "is_init": is_init_b, + }, + batch_size=[1, T_B], +) + +with set_recurrent_mode(True): + packed_out = policy(packed) + b_out = policy(b_only) + +# Trajectory B's features inside the packed batch == trajectory B alone. +torch.testing.assert_close( + packed_out["features"][:, T_A:], b_out["features"], rtol=1e-5, atol=1e-6 +) +print("Hidden-state isolation across is_init boundary: verified.") + +###################################################################### +# A tiny training loop +# -------------------- +# +# We close with a minimal supervised training loop on top of the same +# infrastructure. We use the simplest possible objective — match the +# action logits to a constant target — purely to exercise the +# ``set_recurrent_mode(True)`` + replay-buffer + LSTM gradient path. In a +# real recurrent training job, this is where you would plug in your +# advantage estimator and policy loss. + +trainable_policy = make_policy() +optimizer = torch.optim.Adam(trainable_policy.parameters(), lr=3e-4) +target_logits = torch.zeros(N_ACTIONS) +target_logits[0] = 1.0 # arbitrary supervised target + +losses = [] +for _step in range(4): + sample = rb.sample() + sample_seq = sample.reshape(NUM_SLICES, SLICE_LEN) + with set_recurrent_mode(True): + out = trainable_policy(sample_seq.clone()) + loss = (out["logits"] - target_logits.expand_as(out["logits"])).pow(2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses.append(loss.item()) + +collector.shutdown() + +print("Training loss trajectory:", [round(v, 4) for v in losses]) + +###################################################################### +# Conclusion +# ---------- +# +# You have built a recurrent training pipeline that: +# +# - Lets the collector auto-wire :class:`~torchrl.envs.InitTracker` and the +# recurrent-state primer for you, removing two manual steps that used to +# be a frequent source of silent bugs. +# - Samples multi-step trajectory slices from a replay buffer with +# :class:`~torchrl.data.replay_buffers.SliceSampler`. +# - Runs the LSTM in full-sequence mode under +# :class:`~torchrl.modules.set_recurrent_mode(True)` and verified by +# construction that hidden state does not leak across trajectory +# boundaries inside a sampled batch. +# +# This is the canonical TorchRL pattern for recurrent / sequence-based RL. +# It composes cleanly with every loss module, every advantage estimator, +# and every replay-buffer extension in the library. + +###################################################################### +# Further reading +# --------------- +# +# - :ref:`Recurrent DQN ` — the single-step / collection-time +# complement to this tutorial. +# - :ref:`ref_recurrent_state_lifecycle` — full reference on how hidden +# state flows from collection through replay to the loss, and what +# ``is_init`` controls along the way. +# - :ref:`ref_collectors_internals` — the per-step rollout flow, +# ``_carrier`` semantics, and how the device-cast flags interact with +# recurrent state. +# - :ref:`ref_glossary` — short definitions of ``is_init``, +# ``TensorDictPrimer``, ``recurrent mode``, ``set_keys``, and other +# shorthand that appears throughout the recurrent code paths.