Online perception training #51
Description
Online Perception Training via Auxiliary Self-Supervised Loss
Problem
VisualPerception (the foveated + peripheral CNN pipeline) currently runs under torch.no_grad() during act(). Its output features are stored in the trajectory buffer and used by PPO, ICM, and the critic during training updates. However, no loss ever backpropagates through the perception module itself -- its parameters receive zero gradients and remain at their random initialization.
The downstream modules (affector, critic, forward/inverse dynamics) learn to work with whatever features perception happens to produce, but perception never learns to produce better features.
Storing raw frames for replay and recomputing features during training would solve this, but conflicts with two design principles:
- Real-time architecture. The environment runs asynchronously at 20 ticks/second. The agent's inference hot path must be as fast as possible. Training passes happen when the buffer fills, and their time budget is limited.
- Biological plausibility. The human brain does not store raw sensory input for replay. It stores compressed representations, and the compression itself is learned online from the structure of the sensory stream -- not from reward signals.
Proposed Solution: Online Auxiliary Loss
Train the perception module during the forward pass, before features are stored. The self-supervised loss operates on the current observation as it arrives, using the temporal structure of consecutive features as a training signal. No raw frame replay is needed.
Modified act() flow
obs arrives
│
▼
VisualPerception(obs, roi_obs) ← WITH gradients
│
▼
features_t ← live tensor, in computation graph
│
├──► compute auxiliary loss against features_{t-1}
│ │
│ ▼
│ loss.backward() ← updates perception parameters
│
▼
features_t.detach() ← fixed tensor, graph discarded
│
├──► affector(features_t) ← under no_grad(), no added latency
├──► critic(features_t)
└──► store in trajectory
The key insight: perception trains online from the temporal stream, then the detached features flow into the RL modules exactly as they do today. The RL training path is unchanged.
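The flow above can be sketched as a minimal, self-contained act() skeleton. All module and attribute names here (self.perception, self.predictor, self.affector, self.perception_optim, etc.) are illustrative stand-ins built from plain nn.Linear layers, not the repo's actual classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Agent:
    """Illustrative skeleton of the modified act() flow; not the repo's API."""
    def __init__(self, obs_dim=8, feat_dim=64, act_dim=4):
        self.perception = nn.Linear(obs_dim, feat_dim)      # stand-in for VisualPerception
        self.predictor = nn.Linear(feat_dim + act_dim, feat_dim)  # small aux MLP
        self.affector = nn.Linear(feat_dim, act_dim)
        self.critic = nn.Linear(feat_dim, 1)
        self.perception_optim = torch.optim.Adam(
            list(self.perception.parameters()) + list(self.predictor.parameters()),
            lr=1e-4)
        self.prev_features = None
        self.prev_action = None

    def act(self, obs):
        # 1. Perception runs WITH gradients (no torch.no_grad() here).
        features = self.perception(obs)

        # 2. Auxiliary update: predict current features from the previous
        #    step's detached features and action, then step perception.
        if self.prev_features is not None:
            pred = self.predictor(
                torch.cat([self.prev_features, self.prev_action], dim=-1))
            aux_loss = F.mse_loss(pred, features)
            self.perception_optim.zero_grad()
            aux_loss.backward()
            self.perception_optim.step()

        # 3. Detach: downstream RL modules see a fixed tensor; the graph
        #    is discarded, so their path stays as cheap as today.
        features = features.detach()
        with torch.no_grad():
            action = self.affector(features)
            _value = self.critic(features)

        self.prev_features = features   # already detached
        self.prev_action = action
        return action
```

Note the ordering: the auxiliary backward pass happens before detaching, so the stored trajectory features are exactly the ones the RL modules acted on.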
Candidate Auxiliary Losses
1. Temporal Predictive Coding (recommended starting point)
Given features_t and action_t, predict features_{t+1}. When the next observation arrives and produces features_{t+1}, compare against the prediction.
predicted_next = predictor(features_{t-1}.detach(), action_{t-1})
loss = F.mse_loss(predicted_next, features_t)

The predictor is a small MLP (not the existing ForwardDynamics model, which trains on stored features). The loss says: "good features should make the world predictable." If perception produces features from which you can't predict what happens next, it's discarding useful information.
This is conceptually similar to what the ICM forward dynamics model already does, but applied directly to train perception rather than a separate dynamics network.
2. Temporal Coherence with Action-Gated Surprise
Features from consecutive steps should be similar (the world doesn't change much in 50ms), unless the agent took a significant action.
feature_delta = F.mse_loss(features_t, features_{t-1}.detach(), reduction='none').mean(dim=-1)
action_magnitude = action_{t-1}.abs().sum()
loss = (feature_delta / (1 + action_magnitude)).mean()

This pushes perception to be stable and not noisy, while allowing large feature changes when the agent does something that should change the scene (e.g., turning the camera).
3. Contrastive (Prevent Feature Collapse)
Keep a small rolling buffer of the last K feature vectors (cheap -- these are 64-dim vectors, not raw frames). Current features should be distinguishable from features 10+ steps ago but similar to features 1-2 steps ago.
This prevents the degenerate solution where perception maps everything to the same constant vector (which would trivially minimize the temporal coherence loss).
Can be combined with option 1 or 2 as a regularizer.
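One concrete way to realize this, sketched as an InfoNCE-style contrast over a rolling deque. This is a specific choice the issue leaves open; the class name, buffer length, temperature, and the "positive = 1 step ago / negatives = 10+ steps ago" split are all assumptions:

```python
import torch
import torch.nn.functional as F
from collections import deque

class FeatureContrast:
    """Rolling buffer of past 64-dim feature vectors for an anti-collapse penalty.

    Hypothetical sketch: recent feature (1 step ago) is the positive,
    features 10+ steps ago are negatives.
    """
    def __init__(self, maxlen=32, temperature=0.1):
        self.buffer = deque(maxlen=maxlen)   # stores detached past features, shape (dim,)
        self.temperature = temperature

    def loss(self, features_t):
        # Need one recent positive and at least a couple of distant negatives.
        if len(self.buffer) < 12:
            return features_t.new_zeros(())
        positive = self.buffer[-1]                        # 1 step ago
        negatives = torch.stack(list(self.buffer)[:-10])  # 10+ steps ago
        f = F.normalize(features_t, dim=-1)               # (1, dim)
        pos_sim = (f * F.normalize(positive, dim=-1)).sum(-1) / self.temperature
        neg_sim = (F.normalize(negatives, dim=-1) @ f.squeeze(0)) / self.temperature
        # InfoNCE: pull toward the recent feature, push away from stale ones.
        logits = torch.cat([pos_sim, neg_sim]).unsqueeze(0)
        return F.cross_entropy(logits, torch.zeros(1, dtype=torch.long))

    def push(self, features_t):
        self.buffer.append(features_t.detach().squeeze(0))
```

A collapsed (constant) perception output makes the positive and negatives identical, so the cross-entropy term stays high, which is exactly the pressure against the degenerate solution.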
Implementation Considerations
Latency Impact
The auxiliary loss adds one backward pass through the perception CNN per tick. This roughly doubles the perception portion of the hot path. The affector, critic, and action selection remain under no_grad() and are unaffected.
Need to benchmark: if VisualPerception.forward() currently takes N ms, the auxiliary training step adds approximately N ms (backward is roughly equal to forward for CNNs). If the total per-tick budget is 50ms and perception currently uses a small fraction, this is feasible. If perception is already the bottleneck, this may need to be amortized (e.g., train perception every K steps instead of every step).
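The every-K-steps amortization could be gated like this sketch (class and method names are hypothetical). The point of structuring it this way is that on off-ticks the forward runs under no_grad(), so no graph is even built and the hot path pays nothing extra; predictions then span K ticks instead of one:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AmortizedPerception:
    """Sketch: run perception under no_grad on most ticks; every K-th tick,
    run with gradients and take one auxiliary predictive-coding step."""
    def __init__(self, perception, predictor, lr=1e-4, every_k=4):
        self.perception = perception
        self.predictor = predictor
        self.optim = torch.optim.Adam(
            list(perception.parameters()) + list(predictor.parameters()), lr=lr)
        self.every_k = every_k
        self.tick = 0
        self.prev = None  # (features, action) from the last *trained* tick

    def perceive(self, obs, prev_action):
        self.tick += 1
        if self.tick % self.every_k != 0:
            with torch.no_grad():           # cheap path: no graph built at all
                return self.perception(obs)
        features = self.perception(obs)     # expensive path: graph + backward
        if self.prev is not None:
            prev_feat, act = self.prev
            pred = self.predictor(torch.cat([prev_feat, act], dim=-1))
            loss = F.mse_loss(pred, features)  # prediction now spans K ticks
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
        features = features.detach()
        self.prev = (features, prev_action)
        return features
```

The K-tick prediction horizon is a real trade-off: it weakens the "next frame" signal but keeps off-tick latency identical to today's no_grad() path.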
State to Maintain
The agent needs to keep:
- prev_features (detached) from the previous step -- already effectively available via self.prev_visual_features
- prev_action from the previous step -- trivially stored
- A small auxiliary predictor network (for the temporal predictive coding approach)
- Optionally, a separate optimizer for the perception module with its own learning rate
Interaction with Existing Modules
- PPO / Critic / ICM: Unchanged. They continue to operate on stored (detached) features.
- Focus/ROI head: Unchanged. Still trained via separate REINFORCE loss on stored trajectories.
- ForwardDynamics: The existing forward dynamics model trains on stored features to predict next stored features (for ICM intrinsic reward). The auxiliary predictor is a separate, smaller network that trains on live features to update perception. They serve different purposes and don't conflict.