Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685868876
Change-Id: Ied7d20aa50889b6b611fc355a56f084ab212b125
  • Loading branch information
Brax Team authored and btaba committed Oct 14, 2024
1 parent 865d974 commit 6a62109
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 133 deletions.
53 changes: 28 additions & 25 deletions brax/training/agents/apg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def train(
environment: Union[envs_v1.Env, envs.Env],
episode_length: int,
policy_updates: int,
wrap_env: bool = True,
horizon_length: int = 32,
num_envs: int = 1,
num_evals: int = 1,
Expand Down Expand Up @@ -102,29 +103,30 @@ def train(
updates_per_epoch = jnp.round(num_updates / (num_evals_after_init))

assert num_envs % device_count == 0
env = environment
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

key = jax.random.PRNGKey(seed)
global_key, local_key = jax.random.split(key)
rng, global_key = jax.random.split(global_key, 2)
local_key = jax.random.fold_in(local_key, process_id)
local_key, eval_key = jax.random.split(local_key)

v_randomiation_fn = None
if randomization_fn is not None:
v_randomiation_fn = functools.partial(
randomization_fn, rng=jax.random.split(rng, num_envs // process_count)
env = environment
if wrap_env:
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

v_randomization_fn = None
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(rng, num_envs // process_count)
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomiation_fn,
)

reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))
Expand Down Expand Up @@ -298,16 +300,17 @@ def training_epoch_with_timing(

if not eval_env:
eval_env = environment
if randomization_fn is not None:
v_randomiation_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
if wrap_env:
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomiation_fn,
)

evaluator = acting.Evaluator(
eval_env,
Expand Down
54 changes: 28 additions & 26 deletions brax/training/agents/ars/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class TrainingState:
# TODO: Pass the network as argument.
def train(
environment: Union[envs_v1.Env, envs.Env],
wrap_env: bool = True,
num_timesteps: int = 100,
episode_length: int = 1000,
action_repeat: int = 1,
Expand Down Expand Up @@ -98,23 +99,24 @@ def train(

assert num_envs % local_devices_to_use == 0
env = environment
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

v_randomization_fn = None
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn,
rng=jax.random.split(rng_key, num_envs // local_devices_to_use),
if wrap_env:
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

v_randomization_fn = None
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn,
rng=jax.random.split(rng_key, num_envs // local_devices_to_use),
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

obs_size = env.observation_size

Expand Down Expand Up @@ -273,17 +275,17 @@ def training_epoch_with_timing(training_state: TrainingState,

if not eval_env:
eval_env = environment

if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
if wrap_env:
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

# Evaluator function
evaluator = acting.Evaluator(
Expand Down
53 changes: 28 additions & 25 deletions brax/training/agents/es/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class FitnessShaping(enum.Enum):
# TODO: Pass the network as argument.
def train(
environment: Union[envs_v1.Env, envs.Env],
wrap_env: bool = True,
num_timesteps: int = 100,
episode_length: int = 1000,
action_repeat: int = 1,
Expand Down Expand Up @@ -125,23 +126,24 @@ def train(

assert num_envs % local_devices_to_use == 0
env = environment
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

v_randomization_fn = None
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn,
rng=jax.random.split(rng_key, num_envs // local_devices_to_use),
if wrap_env:
if isinstance(env, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

v_randomization_fn = None
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn,
rng=jax.random.split(rng_key, num_envs // local_devices_to_use),
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

obs_size = env.observation_size

Expand Down Expand Up @@ -325,16 +327,17 @@ def training_epoch_with_timing(training_state: TrainingState,

if not eval_env:
eval_env = environment
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
if wrap_env:
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

# Evaluator function
evaluator = acting.Evaluator(
Expand Down
60 changes: 32 additions & 28 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def train(
environment: Union[envs_v1.Env, envs.Env],
num_timesteps: int,
episode_length: int,
wrap_env: bool = True,
action_repeat: int = 1,
num_envs: int = 1,
max_devices_per_host: Optional[int] = None,
Expand Down Expand Up @@ -113,6 +114,8 @@ def train(
environment: the environment to train
num_timesteps: the total number of environment steps to use during training
episode_length: the length of an environment episode
wrap_env: If True, wrap the environment for training. Otherwise use the
environment as is.
action_repeat: the number of timesteps to repeat an action
num_envs: the number of parallel environments to use for rollouts
NOTE: `num_envs` must be divisible by the total number of chips since each
Expand Down Expand Up @@ -202,27 +205,27 @@ def train(

assert num_envs % device_count == 0

v_randomization_fn = None
if randomization_fn is not None:
randomization_batch_size = num_envs // local_device_count
# all devices get the same randomization rng
randomization_rng = jax.random.split(key_env, randomization_batch_size)
v_randomization_fn = functools.partial(
randomization_fn, rng=randomization_rng
env = environment
if wrap_env:
v_randomization_fn = None
if randomization_fn is not None:
randomization_batch_size = num_envs // local_device_count
# all devices get the same randomization rng
randomization_rng = jax.random.split(key_env, randomization_batch_size)
v_randomization_fn = functools.partial(
randomization_fn, rng=randomization_rng
)
if isinstance(environment, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training
env = wrap_for_training(
environment,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

if isinstance(environment, envs.Env):
wrap_for_training = envs.training.wrap
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

env = wrap_for_training(
environment,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

reset_fn = jax.jit(jax.vmap(env.reset))
key_envs = jax.random.split(key_env, num_envs // process_count)
key_envs = jnp.reshape(key_envs,
Expand Down Expand Up @@ -409,16 +412,17 @@ def training_epoch_with_timing(

if not eval_env:
eval_env = environment
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
if wrap_env:
if randomization_fn is not None:
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)

evaluator = acting.Evaluator(
eval_env,
Expand Down
Loading

0 comments on commit 6a62109

Please sign in to comment.