PPO train code refactor for checkpointing and curriculum compatibility #211

Open · wants to merge 3 commits into base: main

225 changes: 200 additions & 25 deletions brax/training/agents/ppo/train.py
@@ -35,6 +35,7 @@
from brax.training.types import Params
from brax.training.types import PRNGKey
import flax
import os
import pickle
from types import SimpleNamespace
import jax
import jax.numpy as jnp
import optax
@@ -57,8 +58,9 @@ class TrainingState:
def _unpmap(v):
return jax.tree_map(lambda x: x[0], v)


def make_train_space(environment: envs.Env,
num_timesteps: int,
episode_length: int,
action_repeat: int = 1,
@@ -82,9 +84,11 @@ def train(environment: envs.Env,
network_factory: types.NetworkFactory[
ppo_networks.PPONetworks] = ppo_networks.make_ppo_networks,
progress_fn: Callable[[int, Metrics], None] = lambda *args: None):
"""PPO training."""
"""Creates a PPO training (name) space

This contains the functions used to train the agent, they are tracked in a namespace,
so as to avoid recompilation when using checkpointing, and or environment variation"""
assert batch_size * num_minibatches % num_envs == 0

process_count = jax.process_count()
process_id = jax.process_index()
@@ -231,29 +235,71 @@ def training_epoch(training_state: TrainingState, state: envs.State,
training_walltime = 0  # updated via nonlocal in training_epoch_with_timing

def training_epoch_with_timing(
training_state: TrainingState, env_state: envs.State,
key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
nonlocal training_walltime
t = time.time()
(training_state, env_state,
metrics) = training_epoch(training_state, env_state, key)
metrics = jax.tree_map(jnp.mean, metrics)
jax.tree_map(lambda x: x.block_until_ready(), metrics)

epoch_training_time = time.time() - t
training_walltime += epoch_training_time
sps = (num_training_steps_per_epoch *
env_step_per_training_step) / epoch_training_time
metrics = {
'training/sps': sps,
'training/walltime': training_walltime,
**{f'training/{name}': value for name, value in metrics.items()}
}
return training_state, env_state, metrics

train_space = SimpleNamespace()
train_space.env = env
train_space.ppo_network = ppo_network
train_space.optimizer = optimizer
train_space.num_envs = num_envs
train_space.make_policy = make_policy
train_space.reset_fn = reset_fn
train_space.progress_fn = progress_fn
train_space.deterministic_eval = deterministic_eval
train_space.num_eval_envs = num_eval_envs
train_space.episode_length = episode_length
train_space.action_repeat = action_repeat
train_space.num_evals = num_evals
train_space.num_evals_after_init = num_evals_after_init
train_space.training_epoch_with_timing = training_epoch_with_timing
train_space.num_timesteps = num_timesteps
train_space.seed = seed
train_space.max_devices_per_host = max_devices_per_host

key = jax.random.PRNGKey(seed)
_, local_key = jax.random.split(key)
del key
local_key = jax.random.fold_in(local_key, process_id)
_, _, eval_key = jax.random.split(local_key, 3)

evaluator = acting.Evaluator(
env,
functools.partial(make_policy, deterministic=deterministic_eval),
num_eval_envs=num_eval_envs,
episode_length=episode_length,
action_repeat=action_repeat,
key=eval_key)
train_space.evaluator = evaluator

return train_space
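
A minimal usage sketch of the namespace (assuming brax's envs.create factory; the hyperparameters below are illustrative only):

# Sketch: build the namespace once, then drive training explicitly.
from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.create(env_name='ant')
train_space = ppo.make_train_space(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=2048,
    batch_size=1024,
    num_minibatches=4)
training_state = ppo.init_training_state(train_space)
env_state = ppo.init_env_state(train_space)
# train_run can be called repeatedly with the same train_space; the jitted
# training functions live in the namespace, so they are not recompiled.
make_policy, params, metrics, training_state, env_state = ppo.train_run(
    train_space, training_state, env_state)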


def init_training_state(train_space):
"""initializes the training state, based on the parameters and functions in train space"""
env = train_space.env
ppo_network = train_space.ppo_network
optimizer = train_space.optimizer
seed = train_space.seed

key = jax.random.PRNGKey(seed)
global_key, _ = jax.random.split(key)
del key
# key_networks should be global, so that networks are initialized the same
# way for different processes.
key_policy, key_value = jax.random.split(global_key)
@@ -268,33 +314,91 @@ def training_epoch_with_timing(
normalizer_params=running_statistics.init_state(
specs.Array((env.observation_size,), jnp.float32)),
env_steps=0)

return training_state


def init_env_state(train_space):
"""Initializes the per-device environment states from the parameters in train_space."""
num_envs = train_space.num_envs
reset_fn = train_space.reset_fn
seed = train_space.seed

process_count = jax.process_count()
process_id = jax.process_index()
local_device_count = jax.local_device_count()
local_devices_to_use = local_device_count

if train_space.max_devices_per_host:
local_devices_to_use = min(local_devices_to_use, train_space.max_devices_per_host)


key = jax.random.PRNGKey(seed)
_, local_key = jax.random.split(key)
del key
local_key = jax.random.fold_in(local_key, process_id)
_, key_env, _ = jax.random.split(local_key, 3)
key_envs = jax.random.split(key_env, num_envs // process_count)
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
return reset_fn(key_envs)


def train_run(train_space, training_state, env_state):
"""Train a PPO agent, with initial training state training_state using the train_space.

train_space provides the functions and parameters used during training, while train_state parameterizes
both the agent, its policy, value function and the optimizer state.

This partioning is useful if one wants to use curiculum generation or checkpointing"""
make_policy = train_space.make_policy
progress_fn = train_space.progress_fn
num_evals = train_space.num_evals
num_evals_after_init = train_space.num_evals_after_init
training_epoch_with_timing = train_space.training_epoch_with_timing
num_timesteps = train_space.num_timesteps
seed = train_space.seed
evaluator = train_space.evaluator

xt = time.time()

process_count = jax.process_count()
process_id = jax.process_index()
local_device_count = jax.local_device_count()
local_devices_to_use = local_device_count

if train_space.max_devices_per_host:
local_devices_to_use = min(local_devices_to_use, train_space.max_devices_per_host)
logging.info(
'Device count: %d, process count: %d (id %d), local device count: %d, '
'devices to be used count: %d', jax.device_count(), process_count,
process_id, local_device_count, local_devices_to_use)

key = jax.random.PRNGKey(seed)
_, local_key = jax.random.split(key)
del key
local_key = jax.random.fold_in(local_key, process_id)
local_key, _, _ = jax.random.split(local_key, 3)

training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])

# Run initial eval
if process_id == 0 and num_evals > 1:
current_step = int(_unpmap(training_state.env_steps))
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.params.policy)),
training_metrics={})
logging.info(metrics)
progress_fn(current_step, metrics)

for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)
@@ -303,7 +407,7 @@ def training_epoch_with_timing(
epoch_key, local_key = jax.random.split(local_key)
epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
(training_state, env_state,
training_metrics) = training_epoch_with_timing(training_state, env_state,
epoch_keys)
current_step = int(_unpmap(training_state.env_steps))

@@ -326,4 +430,75 @@
(training_state.normalizer_params, training_state.params.policy))
logging.info('total steps: %s', total_steps)
pmap.synchronize_hosts()
return (make_policy, params, metrics, _unpmap(training_state), env_state)
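
Because train_space and training_state are decoupled, a curriculum can rebuild the namespace per stage while carrying the agent forward. A rough sketch (easy_env and hard_env are hypothetical stages; they must share observation and action sizes, since the normalizer and network shapes are fixed at initialization):

# Sketch: a two-stage curriculum over hypothetical environments.
training_state = None
for stage_env in (easy_env, hard_env):
  train_space = make_train_space(
      environment=stage_env,
      num_timesteps=500_000,
      episode_length=1000)
  if training_state is None:
    training_state = init_training_state(train_space)
  env_state = init_env_state(train_space)  # fresh rollouts for each stage
  (make_policy, params, metrics,
   training_state, env_state) = train_run(train_space, training_state,
                                          env_state)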


def train(*args, **kwargs):
"""PPO training."""
train_space = make_train_space(*args, **kwargs)
training_state = init_training_state(train_space=train_space)
env_state = init_env_state(train_space=train_space)
return train_run(train_space=train_space,
training_state=training_state,
env_state=env_state)[:3]

def checkpoint_train(environment,
num_timesteps,
break_steps,
checkpoint_dir,
checkpoint_time,
**kwargs):
"""Performs checkpointed training.

This loads the most recent checkpoint when run,
and continues to run until at least num_timesteps of training have occured.

The checkpointing happens when checkpoint_time has passed,
but this is only checked every break_step training steps."""

train_space = make_train_space(environment=environment, num_timesteps=break_steps, **kwargs)
checkpoint_dir = os.path.join(checkpoint_dir, str(jax.process_index()))  # different processes may have different env_state
if os.path.exists(checkpoint_dir) and not os.path.isdir(checkpoint_dir):
raise NotADirectoryError(f'checkpoint path {checkpoint_dir} is not a directory')
os.makedirs(checkpoint_dir, exist_ok=True)

## load the most recent checkpoint, or initialize if none is found
checkpoint_steps = []
for filename in os.listdir(checkpoint_dir):
if filename.startswith('ppo_') and filename.endswith('.pkl'):
checkpoint_steps.append(int(filename[4:-4]))
if checkpoint_steps:
with open(os.path.join(checkpoint_dir,
f'ppo_{max(checkpoint_steps)}.pkl'),
'rb') as file:
training_state, env_state = pickle.load(file)
print(f"loaded ppo_{max(checkpoint_steps)}.pkl")
else:
print("no checkpoint found, initializing instead")
training_state = init_training_state(train_space)
env_state = init_env_state(train_space)
with open(os.path.join(checkpoint_dir, f'ppo_{training_state.env_steps}.pkl'), 'wb') as file:
pickle.dump((training_state, env_state), file=file)

## train, saving checkpoints periodically
t = time.time()
while True:
ans = train_run(train_space=train_space,
training_state=training_state,
env_state=env_state)
training_state, env_state = ans[-2:]
if time.time() - t > checkpoint_time:
t = time.time()
with open(os.path.join(checkpoint_dir, f'ppo_{training_state.env_steps}.pkl'), 'wb') as file:
pickle.dump((training_state, env_state), file=file)
logging.info(f'checkpoint ppo_{training_state.env_steps}.pkl saved')
if training_state.env_steps > num_timesteps:
## ensures the final state is always saved
if not os.path.exists(os.path.join(checkpoint_dir, f'ppo_{training_state.env_steps}.pkl')):
with open(os.path.join(checkpoint_dir, f'ppo_{training_state.env_steps}.pkl'), 'wb') as file:
pickle.dump((training_state, env_state), file=file)
break
return ans[:3]
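
A sketch of resumable training with this helper (environment and settings are illustrative; re-running the same call resumes from the newest ppo_*.pkl):

# Sketch: train to 10M steps, writing a checkpoint roughly every 10 minutes.
from brax import envs

make_policy, params, metrics = checkpoint_train(
    environment=envs.create(env_name='ant'),
    num_timesteps=10_000_000,
    break_steps=100_000,
    checkpoint_dir='/tmp/ppo_checkpoints',
    checkpoint_time=600,  # seconds between checkpoint writes
    episode_length=1000)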