Skip to content

Adds serialization to observation and action managers #2234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions source/isaaclab/isaaclab/managers/action_manager.py
Original file line number Diff line number Diff line change
@@ -358,6 +358,14 @@ def get_term(self, name: str) -> ActionTerm:
"""
return self._terms[name]

def serialize(self) -> dict:
"""Serialize the action manager configuration.

Returns:
A dictionary of serialized action term configurations.
"""
return {term_name: term.serialize() for term_name, term in self._terms.items()}

"""
Helper functions.
"""
11 changes: 10 additions & 1 deletion source/isaaclab/isaaclab/managers/manager_base.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
import omni.timeline

import isaaclab.utils.string as string_utils
from isaaclab.utils import string_to_callable
from isaaclab.utils import class_to_dict, string_to_callable

from .manager_term_cfg import ManagerTermBaseCfg
from .scene_entity_cfg import SceneEntityCfg
@@ -79,6 +79,11 @@ def device(self) -> str:
"""Device on which to perform computations."""
return self._env.device

@property
def __name__(self) -> str:
"""Return the name of the class or subclass."""
return self.__class__.__name__

"""
Operations.
"""
@@ -92,6 +97,10 @@ def reset(self, env_ids: Sequence[int] | None = None) -> None:
"""
pass

def serialize(self) -> dict:
"""General serialization call. Includes the configuration dict."""
return {"cfg": class_to_dict(self.cfg)}

def __call__(self, *args) -> Any:
"""Returns the value of the term required by the manager.

25 changes: 24 additions & 1 deletion source/isaaclab/isaaclab/managers/observation_manager.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
from prettytable import PrettyTable
from typing import TYPE_CHECKING

from isaaclab.utils import modifiers
from isaaclab.utils import class_to_dict, modifiers
from isaaclab.utils.buffers import CircularBuffer

from .manager_base import ManagerBase, ManagerTermBase
@@ -334,6 +334,29 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
else:
return group_obs

def serialize(self) -> dict:
"""Serialize the observation term configurations for all active groups.

Returns:
A dictionary where each group name maps to its serialized observation term configurations.
"""
output = {
group_name: {
term_name: (
term_cfg.func.serialize()
if isinstance(term_cfg.func, ManagerTermBase)
else {"cfg": class_to_dict(term_cfg)}
)
for term_name, term_cfg in zip(
self._group_obs_term_names[group_name],
self._group_obs_term_cfgs[group_name],
)
}
for group_name in self.active_terms.keys()
}

return output

"""
Helper functions.
"""
48 changes: 47 additions & 1 deletion source/isaaclab/test/managers/test_observation_manager.py
Original file line number Diff line number Diff line change
@@ -18,11 +18,21 @@
import torch
import unittest
from collections import namedtuple
from typing import TYPE_CHECKING

import isaaclab.sim as sim_utils
from isaaclab.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg
from isaaclab.managers import (
ManagerTermBase,
ObservationGroupCfg,
ObservationManager,
ObservationTermCfg,
RewardTermCfg,
)
from isaaclab.utils import configclass, modifiers

if TYPE_CHECKING:
from isaaclab.envs import ManagerBasedEnv


def grilled_chicken(env):
return torch.ones(env.num_envs, 4, device=env.device)
@@ -662,6 +672,42 @@ class PolicyCfg(ObservationGroupCfg):
with self.assertRaises(ValueError):
self.obs_man = ObservationManager(cfg, self.env)

def test_serialize(self):
"""Test serialize call for ManagerTermBase terms."""

serialize_data = {"test": 0}

class test_serialize_term(ManagerTermBase):

def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
super().__init__(cfg, env)

def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
return grilled_chicken(env)

def serialize(self) -> dict:
return serialize_data

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

concatenate_terms = False
term_1 = ObservationTermCfg(func=test_serialize_term)

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)

# check expected output
self.assertEqual(self.obs_man.serialize(), {"policy": {"term_1": serialize_data}})


if __name__ == "__main__":
run_tests()