Subsume part of System inside State; EDIT: Or add Options to reset #446
Comments
Wanted to add an example of another workaround: https://github.com/automl/CARL. In this library for meta-RL, instead of batching environments on the GPU (which Brax should support), the CARL-Brax environments create VectorizedWrappers from Gymnasium in order to run multiple environments in parallel.
Hi @joeryjoery, I believe we considered passing around `sys` as part of the env state, but IIRC we managed to squeeze out better performance using the current implementation (see brax/brax/envs/wrappers/training.py, line 199 at a893224).

Feel free to implement a version of the base env class and wrapper which passes the `sys` in a functional way (e.g. as part of the `State`).
Hey, thanks for the reply. A big obstacle right now in trying to implement something like this is that the `sys` is read from `self` inside the env methods. So I'm trying to work around this by doing dependency injection for `pipeline_init`.

@btaba Could the `reset` and `step` signatures accept an optional `options` argument? In principle, if these are `None` then the performance stays the same, and if I want to provide it with options then I can wrap the relevant methods. What do you think?
Hi @joeryjoery, I'm not quite following why you want to add extra args to `reset`.
Hey, yes this works. But it's not the problem. The issue is that I have no easy way to propagate options down to `pipeline_init`:

```python
def reset(self, rng: jax.Array) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  ...
  pipeline_state = self.pipeline_init(q, qd)
  obs = self._get_obs(pipeline_state)
  ...
```

Now suppose I want to wrap `pipeline_init`. A way to solve this is to allow options, for example:

```python
def reset(self, rng: jax.Array, *, options: dict | None = None) -> State:
  """Resets the environment to an initial state."""
  rng, rng1, rng2 = jax.random.split(rng, 3)
  ...
  pipeline_state = self.pipeline_init(q, qd, options=options)  # Pass along here
  obs = self._get_obs(pipeline_state)
  ...
```

In this way, I can wrap `pipeline_init`:

```python
def my_wrapped_init(self, q, qd, *, options: dict | None = None):
  sys = self.sys
  if options is not None:
    variations = some_sampling_function(options)  # returns dict
    sys = self.sys.replace(**variations)
    return jax.vmap(self._pipeline.init, in_axes=(0, None, None, None))(
        sys, q, qd, self._debug)
  return self._pipeline.init(self.sys, q, qd, self._debug)

my_env.pipeline_init = my_wrapped_init
```
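The `some_sampling_function` above is left unspecified in the thread; a minimal hypothetical stand-in could map each field name in `options` to a uniformly sampled value (illustrative only, using `random` in place of `jax.random`):

```python
import random

def some_sampling_function(options: dict) -> dict:
    """Sample a value uniformly from (lo, hi) for each field in `options`.

    `options` maps a System field name to a (lo, hi) range, e.g.
    {'viscosity': (0.0, 1.0)}; the returned dict can be fed to sys.replace().
    Hypothetical sketch: the thread's version would use jax.random keys.
    """
    return {field: random.uniform(lo, hi) for field, (lo, hi) in options.items()}

variations = some_sampling_function({'viscosity': (0.0, 1.0),
                                     'gravity': (-10.0, -9.0)})
```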
Comments and questions on the proposed changes:

[1] Subsume part of `System` inside `State`: as mentioned above, we found better performance with the current implementation.

[2] Add `options` to `reset`: Strong preference here to add your logic to a wrapper, and to split out the vmap case from the non-vmap case into distinct wrappers. It looks like your proposal is similar to the existing `DomainRandomizationVmapWrapper`.
Hey, thanks a lot for continuing the discussion. TL;DR: I was overthinking this, and the easy solution is indeed a slight modification of the existing wrapper.

In my implementation I also do not include the vmap logic, so it composes with the existing vmap wrappers. This is what I propose:

```python
class DomainRandomization(brax.envs.Wrapper):
  """Wrapper for Procedural Domain Randomization."""

  def __init__(
      self,
      env: Env,
      randomization_fn: Callable[[System, jax.Array], dict[str, Any]],
  ):
    super().__init__(env)
    self.randomization_fn = randomization_fn

  def env_fn(self, sys: System) -> Env:
    env = self.env
    env.unwrapped.sys = sys
    return env

  def reset(self, rng: jax.Array) -> State:
    key_reset, key_var = jax.random.split(rng)
    sys = self.env.unwrapped.sys
    variations = self.randomization_fn(sys, key_var)
    new_sys = sys.replace(**variations)
    new_env = self.env_fn(new_sys)
    state = new_env.reset(key_reset)
    state = state.replace(info=state.info | {'sys_var': variations})
    return state

  def step(self, state: State, action: jax.Array) -> State:
    variations = state.info['sys_var']
    sys = self.env.unwrapped.sys
    new_sys = sys.replace(**variations)
    new_env = self.env_fn(new_sys)
    state = new_env.step(state, action)
    state = state.replace(info=state.info | {'sys_var': variations})
    return state
```

Example usage:

```python
def viscosity_randomizer(system: System, key: jax.Array) -> dict[str, Any]:
  return {'viscosity': jax.random.uniform(key, system.viscosity.shape)}

env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,
)
wrap = DomainRandomization(env, viscosity_randomizer)
s0 = jax.jit(wrap.reset)(jax.random.key(0))
s1 = jax.jit(wrap.reset)(jax.random.key(321))

print(s0.info['sys_var'], s1.info['sys_var'])
>> {'viscosity': Array(0.10536897, dtype=float32)} {'viscosity': Array(0.3906865, dtype=float32)}

print(wrap.unwrapped.sys.viscosity)
>> Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

print(wrap.default_sys.viscosity)
>> 0.0
```

Or composing with the `VmapWrapper`:

```python
sbatch = jax.jit(brax.envs.wrappers.training.VmapWrapper(wrap).reset)(
    jax.random.split(jax.random.key(0), 5)
)
print(sbatch.info['sys_var'])
>> {'viscosity': Array([0.6306313 , 0.5778805 , 0.64515114, 0.95315635, 0.24741197], dtype=float32)}
```
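The load-bearing detail in the wrapper above is that the sampled variations ride along in `state.info`, so `step` can re-apply them to the default system on every call without any mutable state. A brax-free sketch of that round trip, with plain dataclasses standing in for `System` and `State` (all names here are hypothetical):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Sys:          # hypothetical stand-in for brax's System
    viscosity: float = 0.0

@dataclass(frozen=True)
class St:           # hypothetical stand-in for brax's State
    info: dict

default_sys = Sys()

def reset(sampled: float) -> St:
    # The sampled variations are stored in info, mirroring the
    # `state.info | {'sys_var': variations}` line in the wrapper.
    return St(info={'sys_var': {'viscosity': sampled}})

def step(state: St) -> Sys:
    # Rebuild the varied system from the default one, functionally:
    # the default system is never mutated.
    return replace(default_sys, **state.info['sys_var'])

s = reset(0.5)
varied = step(s)
```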
It's not really easy to show here that this implementation works, but if you visualize the results using the code shown in the Colab, you can see that it indeed randomizes the system. I also haven't tested performance for RL training, but it's guaranteed faster than using the current domain-randomization wrapper for my use case.
Hi @joeryjoery, I think we tried a version of this implementation. A few comments:

[1] Can you update your impl to make it work for nested fields in `System`?

[2] This implementation passes around more data with every `State`. FWIW, the impl at HEAD creates a static batch of systems; see brax/brax/training/agents/ppo/train.py, lines 418 to 431 at e91772b.
Hey! For 1) I was working on something like this, but didn't quite finish today; I will update it later. What do you mean with nested fields?

For 2), I don't think there is a way around this: we are passing around more data. If the variations are small (like just the viscosity or gravity), then I'd imagine this is really negligible, but yes, this can grow for something like Humanoid. I'm not suggesting that the other implementation should change. However, for me, I'm specifically looking at fulfilling my research assumptions as well as I can. This assumes random environments at every sampled trajectory, which also makes learning a good policy severely more difficult.

Also, in my experiments the data collection is rarely the bottleneck; I've found it's more often the learner (at least for my very specific use case: PPO with a recurrent network architecture that also does internal matrix inversions). If I find the time, I'll try running the default agent with the current domain randomization and the one I posted.
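For that comparison, a minimal timing harness could look like this (hypothetical sketch: `fn` would be a jitted reset/step or a full training step, and the warm-up call keeps JIT compilation out of the measurement):

```python
import time

def time_fn(fn, *args, n: int = 10) -> float:
    """Average wall-clock seconds per call of fn(*args)."""
    fn(*args)  # warm-up: excludes one-time costs such as jit compilation
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args)
    return (time.perf_counter() - t0) / n

avg = time_fn(lambda: sum(range(1000)))
```

Note that with JAX one would also call `.block_until_ready()` on the outputs inside the loop, since dispatch is asynchronous.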
Hi @joeryjoery, see line 114 in f9a4d73.

Thanks for the context on [2]. I recommend using your own wrapper (to ensure a new system is sampled for every trajectory); it looks like you're pretty close to a more general version with the implementation above! Let us know if you run into any trouble, and please feel free to share any findings (or open a PR).
For domain randomization it is not particularly easy to `vmap` over different `System` values, for example the `gravity` values or the `elasticity`. Preferably you should be able to do this in `env.reset`, but right now this is not possible as `self.sys` is a global variable in the `Env` namespace.

Right now my hacky workaround is to mock the Brax environment with my custom PyTree-like dataclass so I can modify the `env.sys` values in a functionally pure way inside the `reset` function.

It would be nice if brax could expose part of the `sys` dict/namespace as a pure argument to `env.reset` and `env.step` (e.g., as part of the state).
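To make the request concrete, here is a brax-free sketch of the difference (all names hypothetical; a list comprehension stands in for `jax.vmap`): once the system is an explicit argument instead of an attribute, mapping over different `System` values is trivial:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class System:                # hypothetical stand-in for brax's System
    gravity: float

class Env:
    def __init__(self, sys: System):
        self.sys = sys       # "global" in the Env namespace: hard to vary

    def reset_impure(self) -> float:
        return self.sys.gravity   # always reads the stored system

    def reset_pure(self, sys: System) -> float:
        return sys.gravity        # caller controls which system is used

env = Env(System(gravity=-9.81))
systems = [System(gravity=g) for g in (-9.0, -9.5, -10.0)]
# With the pure signature, "vmapping" over systems is just a map:
obs = [env.reset_pure(s) for s in systems]
```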