Active Inference agents using message passing on factor graphs, implemented in JAX. Compares multiple belief propagation variants for goal-directed planning across four environments.
- `agents/`: agent implementations per environment
- `environments/`: tensor generation and environment wrappers
- `inference/`: planning algorithms (11 BP variants) and state inference
- `utils/`: index conversion utilities
- `tests/`: unit and integration tests
| Environment | Entry point | Config key | Description |
|---|---|---|---|
| MiniGrid DoorKey | `run_minigrid.py` | `experiment` | Partially observable gridworld with key-door puzzle |
| Frozen Lake | `run_frozen_lake.py` | `frozen_lake` | Slippery gridworld with holes; configurable layouts |
| Wumpus World | `run_wumpus_world.py` | `wumpus_world` | Gridworld with pits and wumpus; indirect observations |
| RockSample | `run_rocksample.py` | `rocksample` | Gridworld with rocks of unknown quality; distance-dependent observations |
Each environment also has a `run_*_diagnostics.py` script for single-episode inspection.
- `inference/planning.py`: standard BP (forward-backward on the temporal factor graph)
- `inference/loopy_bp.py`: loopy BP with θ as a variable node
- `inference/region_extended_loopy_bp.py`: adds observation factors to the planning graph
- `inference/nuijten_mp.py`: region beliefs without kernels
- `inference/state_inference.py`: loopy BP for Bayesian state estimation
- `inference/messages.py`: low-level message operations (log-space, `EPSILON=1e-8`, `LOG_ZERO=-1e12`)
- `agents/flat_tensor_agent.py`: MiniGrid agents (`FlatTensorAgent` + 10 variants)
- `agents/frozen_lake_agent.py`: Frozen Lake agents
- `agents/wumpus_agent.py`: Wumpus World agents
- `agents/rocksample_agent.py`: RockSample agents
- `environments/minigrid.py`: MiniGrid transition and observation tensor generation
- `environments/frozen_lake.py`: Frozen Lake environment
- `environments/wumpus_world.py`: Wumpus World environment
- `environments/rocksample.py`: RockSample environment
- `utils/tensors.py`: index flattening, coordinate conversion, one-hot creation
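The log-space conventions listed for `inference/messages.py` can be illustrated with a small sketch. The helper names (`safe_log`, `normalize_log`, `factor_to_variable`) are hypothetical, and the way `EPSILON` and `LOG_ZERO` are applied here (flooring tiny probabilities, replacing exact zeros with a large finite negative number instead of `-inf`) is an assumption rather than a reading of the module itself.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Constant values quoted from the description of inference/messages.py;
# how they are used below is an assumption.
EPSILON = 1e-8
LOG_ZERO = -1e12


def safe_log(p):
    """Hypothetical helper: exact zeros map to LOG_ZERO (a finite stand-in
    for -inf); positive values are floored at EPSILON before the log."""
    return jnp.where(p > 0, jnp.log(jnp.maximum(p, EPSILON)), LOG_ZERO)


def normalize_log(log_msg):
    """Subtract the log normalizer so that exp(result) sums to 1."""
    return log_msg - logsumexp(log_msg)


def factor_to_variable(log_factor, log_incoming):
    """Sum-product message through a pairwise factor f(x_in, x_out):
    m(x_out) = log sum_{x_in} exp(log f(x_in, x_out) + log m_in(x_in))."""
    return normalize_log(logsumexp(log_factor + log_incoming[:, None], axis=0))


# Toy example: T[s, s'] = p(s' | s), with the belief concentrated on state 0
T = jnp.array([[0.9, 0.1, 0.0],
               [0.0, 0.5, 0.5],
               [0.0, 0.0, 1.0]])
belief = jnp.array([1.0, 0.0, 0.0])
msg = factor_to_variable(safe_log(T), safe_log(belief))
print(jnp.exp(msg))
```

Keeping zeros finite means sums such as `LOG_ZERO + LOG_ZERO` stay finite, so downstream `logsumexp` calls never see an all-`-inf` row.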
Available via `--planning-method`:

| Method | Module | θ handling |
|---|---|---|
| `bp` | `planning.py` | Marginalized once |
| `vbp` | `vbp.py` | Value iteration (ε→0) |
| `loopy-vbp` | `loopy_vbp.py` | Variable node (VBP) |
| `loopy` | `loopy_bp.py` | Variable node |
| `region-extended` | `region_extended_loopy_bp.py` | Variable node + obs factors |
| `reduced-region-extended` | `reduced_region_extended.py` | Fixed + kernel reparam + obs factors |
| `dyn-channel` | `dyn_channel_loopy_bp.py` | Variable node + dynamic channels |
| `reduced-dyn-channel` | `reduced_dyn_channel.py` | Fixed + dynamic channels |
| `nuijten` | `nuijten_mp.py` | Variable node, no kernels |
| `reduced-nuijten` | `nuijten_mp.py` | Fixed, no kernels |
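The `vbp` entry describes value iteration as an ε→0 limit. One standard way to read this, assumed here rather than taken from `vbp.py`, is the zero-temperature limit of a tempered log-sum-exp: at ε = 1 it is the sum-product aggregation of ordinary BP, and as ε shrinks it approaches the max used by value iteration. The function name `tempered_logsumexp` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tempered_logsumexp(x, eps, axis=None):
    """eps * log(sum(exp(x / eps))).

    eps = 1 recovers the sum-product (log-partition) aggregation;
    as eps -> 0 the result approaches max(x), the max-product /
    value-iteration aggregation.
    """
    return eps * logsumexp(x / eps, axis=axis)


# e.g. log-scores of three candidate actions
q = jnp.array([-1.0, -0.5, -3.0])
for eps in (1.0, 0.1, 0.01):
    print(eps, float(tempered_logsumexp(q, eps)))
```

The result always lies between `max(x)` and `max(x) + ε·log(n)` for `n` entries, so the gap to the max-based backup closes linearly in ε.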
# MiniGrid experiment
uv run python run_minigrid.py --grid-size 3 --episodes 100 --planning-method region-extended
# Frozen Lake
uv run python run_frozen_lake.py --grid-size 5 --n-configs 10 --planning-method loopy
# Wumpus World
uv run python run_wumpus_world.py --grid-size 4 --n-configs 50 --planning-method dyn-channel
# RockSample
uv run python run_rocksample.py --grid-size 5 --n-rocks 3 --n-configs 8 --planning-method bp
# Single-episode diagnostics (any environment)
uv run python run_frozen_lake_diagnostics.py --planning-method region-extended
# Convergence analysis
uv run python run_minigrid_convergence.py --damping 0.25 --obs-alpha 0.01
# Tests
uv run python run_tests.py

Use `--help` on any script for all parameters. Per-method iteration counts and damping are configured in `params.yaml`.
Experiments are managed with DVC. Configuration lives in params.yaml, pipeline in dvc.yaml. Results go to data/results/, videos to data/videos/, convergence plots to data/convergence/.
uv run dvc repro # Run full pipeline
uv run dvc repro -s frozen_lake # Run single stage

MiniGrid agent representations:
- `IndexedTensorAgent` (default): stores indices rather than full tensors. Always use for production.
- `FlatTensorAgent`: stores full tensors; debugging only, `grid_size` ≤ 3.
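The difference between the two MiniGrid agent representations can be sketched for deterministic dynamics, which is an assumption here: a next-state index array of shape (A, S) carries the same information as a one-hot transition tensor of shape (A, S, S), and a backward message becomes a gather instead of a dense contraction. The toy corridor dynamics and array names are hypothetical and do not mirror `IndexedTensorAgent` internals.

```python
import jax
import jax.numpy as jnp

n_states = 5

# Hypothetical deterministic dynamics on a 1-D corridor:
# action 0 moves right (clipped at the end), action 1 stays in place.
next_state = jnp.stack([
    jnp.minimum(jnp.arange(n_states) + 1, n_states - 1),
    jnp.arange(n_states),
])  # shape (A, S): next_state[a, s] = s'

# Dense equivalent: T[a, s, s'] = 1 if s' == next_state[a, s] else 0
T_dense = jax.nn.one_hot(next_state, n_states)  # shape (A, S, S)

# A backward message over next states, in log-space
log_beta = jnp.log(jnp.array([0.1, 0.1, 0.1, 0.1, 0.6]))

# Indexed backup: a plain gather, O(A*S) memory
q_indexed = log_beta[next_state]  # shape (A, S)

# Dense backup: log sum_{s'} T[a, s, s'] * beta(s'), O(A*S^2) memory
q_dense = jnp.log(jnp.einsum("ask,k->as", T_dense, jnp.exp(log_beta)))
```

For S states the index array grows as O(A·S) while the dense tensor grows as O(A·S²), which is consistent with restricting `FlatTensorAgent` to small grids.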
JAX requirements:
- All inference functions are JIT-compiled (`@jax.jit`).
- The first run is slow (compilation); subsequent runs are fast.
- Horizon and iteration counts must be compile-time constants (static args).
- Use `jax.lax.fori_loop` inside JIT-compiled functions.
- Computation is done in log-space throughout to avoid underflow.
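A minimal sketch of the JAX conventions above, assuming a policy-marginalized transition matrix: `horizon` is declared static so the loop bound is a compile-time constant, the loop uses `jax.lax.fori_loop`, and the recursion stays in log-space via `logsumexp`. The function `backward_messages` and its arguments are hypothetical, not the repository's API.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


@partial(jax.jit, static_argnames=("horizon",))
def backward_messages(log_T, log_goal, horizon):
    """Hypothetical backward pass over a fixed horizon.

    log_T[s, s']: log p(s' | s) with the policy already marginalized.
    log_goal[s]:  log preference over final states.
    """

    def body(_, log_beta):
        # beta_t(s) = sum_{s'} p(s' | s) * beta_{t+1}(s'), computed in log-space
        return logsumexp(log_T + log_beta[None, :], axis=1)

    # horizon is a Python int at trace time, so the trip count is static
    return jax.lax.fori_loop(0, horizon, body, log_goal)


T = jnp.array([[0.6, 0.3, 0.1],
               [0.1, 0.6, 0.3],
               [0.1, 0.1, 0.8]])
goal = jnp.array([0.1, 0.2, 0.7])
log_beta0 = backward_messages(jnp.log(T), jnp.log(goal), horizon=2)
```

Calling it again with a different `horizon` triggers a recompilation, while new array values of the same shape and dtype reuse the compiled function.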
State representation (MiniGrid):
- Dynamic state: (location, orientation, door_key_state) → flat index
- Static state: (key_position, door_position) → flat index
- See `flatten_state_index()` in `utils/tensors.py`
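Flattening a factored state into a single index is a mixed-radix encoding. The sketch below assumes row-major (C-order) layout with the last factor varying fastest and uses made-up factor sizes; the actual ordering and sizes are defined by `flatten_state_index()` in `utils/tensors.py`, and `flatten_dynamic` / `unflatten_dynamic` are hypothetical names. The static state (key_position, door_position) would be flattened the same way.

```python
import numpy as np

# Hypothetical factor sizes for illustration only
N_LOCATIONS = 9          # e.g. a 3x3 interior grid
N_ORIENTATIONS = 4
N_DOOR_KEY_STATES = 3    # hypothetical number of door/key configurations
DYN_DIMS = (N_LOCATIONS, N_ORIENTATIONS, N_DOOR_KEY_STATES)


def flatten_dynamic(location, orientation, door_key_state):
    """Row-major mixed-radix encoding: door_key_state varies fastest."""
    return (location * N_ORIENTATIONS + orientation) * N_DOOR_KEY_STATES + door_key_state


def unflatten_dynamic(index):
    """Inverse of flatten_dynamic, using NumPy's C-order unravel."""
    return tuple(int(i) for i in np.unravel_index(index, DYN_DIMS))


n_dynamic = N_LOCATIONS * N_ORIENTATIONS * N_DOOR_KEY_STATES
print(n_dynamic, flatten_dynamic(8, 3, 2), unflatten_dynamic(107))
```

Because the encoding matches NumPy's C-order convention, `np.ravel_multi_index` and `np.unravel_index` can serve as a cross-check for a hand-written version.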
Coordinates: our (x, y) coordinates are y-flipped relative to MiniGrid's. Use the conversions in `tests/test_minigrid_groundtruth.py`.
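A y-flip between two grid conventions is usually `y_ours = height - 1 - y_minigrid`, with x unchanged. The sketch assumes `height` is the full MiniGrid grid height (including walls); the authoritative conversions are the ones in `tests/test_minigrid_groundtruth.py`, and the function names here are hypothetical.

```python
def minigrid_to_ours(x, y, height):
    """Hypothetical conversion: flip the vertical axis, keep x."""
    return x, height - 1 - y


def ours_to_minigrid(x, y, height):
    """The flip is its own inverse, so the same formula converts back."""
    return x, height - 1 - y


# MiniGrid's top row (y = 0) becomes our top row y = height - 1
print(minigrid_to_ours(2, 0, height=5))
```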
Run uv run python run_tests.py before committing. Tests cover tensor generation, inference algorithms, and agent integration for all four environments.
- Always run Python scripts with `uv run python` (e.g. `uv run python run_tests.py`).