Saving intermediate policies during the training #550

Open · mazzamani opened this issue Nov 5, 2024 · 0 comments
mazzamani commented Nov 5, 2024

My goal is to save intermediate policies during training. Using the policy at the end of training works fine for me:

        # ... main training code
        self.make_inference_fn, self.params, _ = self.train_fn(environment=env, progress_fn=self.progress_callback,
                                                               policy_params_fn=self.policy_params_callback)
        self.visualize_trajectory()

However, when I call self.visualize_trajectory() from self.policy_params_callback,

        def policy_params_callback(self, step, make_policy, params):
            self.make_inference_fn = make_policy
            self.params = params  
            self.visualize_trajectory()

I run into the following error:

  File "/brax/training_code.py", line 29, in policy_params_callback
    self.visualize_trajectory()
  File "/brax/training_code.py", line 74, in visualize_trajectory
    act, _ = jit_inference_fn(state.obs, act_rng)
  File "/miniconda3/envs/brax/lib/python3.10/site-packages/brax/training/agents/ppo/networks.py", line 44, in policy
    logits = policy_network.apply(*params, observations)
  File "/miniconda3/envs/brax/lib/python3.10/site-packages/brax/training/networks.py", line 104, in apply
    return policy_module.apply(policy_params, obs)
TypeError: argument of type 'PPONetworkParams' is not iterable

This is my visualization method, which is adapted from the example training code:

    def visualize_trajectory(self):
        inference_fn = self.make_inference_fn(self.params)
        env = self.load_environment()
        jit_env_reset = jax.jit(env.reset)
        jit_env_step = jax.jit(env.step)
        jit_inference_fn = jax.jit(inference_fn)

        trajectory = []
        rng = jax.random.PRNGKey(seed=1)
        state = jit_env_reset(rng=rng)

        for _ in tqdm(range(1000)):
            trajectory.append(state.pipeline_state)
            act_rng, rng = jax.random.split(rng)
            act, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_env_step(state, act)

        rendered_html = html.render(env.sys.tree_replace({'opt.timestep': env.dt}), trajectory)
        with open("trajectory_visualization.html", "w") as file:
            file.write(rendered_html)

        print("trajectory visualization prepared.")

Any idea why it is not working?
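
From the traceback, my current guess is that the params passed to policy_params_fn do not have the same structure as the params tuple that train_fn returns at the end of training: the callback appears to receive the full PPONetworkParams (policy and value params together) in the second slot, while make_policy expects (normalizer_params, policy_params). If that is right, unpacking the policy params in the callback might work, roughly like this (untested sketch; the exact structure may differ between brax versions):

    def policy_params_callback(self, step, make_policy, params):
        # Assumption: params arrives here as (normalizer_params, PPONetworkParams),
        # whereas make_policy expects (normalizer_params, policy_params) only.
        normalizer_params, network_params = params
        self.make_inference_fn = make_policy
        self.params = (normalizer_params, network_params.policy)
        self.visualize_trajectory()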

Edit: here is the whole script:

import functools

import jax
from brax import envs
from brax.io import html
from brax.training.agents.ppo import train as ppo

from tqdm import tqdm


class RLTrainer:
    def __init__(self):
        self.env_name = 'ant'
        self.backend = 'positional'
        self.params = None
        self.make_inference_fn = None
        self.train_fn = None

    def load_environment(self):
        env = envs.get_environment(env_name=self.env_name, backend=self.backend)
        return env

    def policy_params_callback(self, step, make_policy, params):
        self.make_inference_fn = make_policy
        self.params = params
        self.visualize_trajectory()

    def progress_callback(self, num_steps, metrics):
        print(f"Training progress: steps={num_steps}, reward={metrics['eval/episode_reward']:.2f}")
        
    def start_training(self, env_name, backend):
        """Begins the training in a separate thread."""
        print("Training started ...")
        self.env_name = env_name
        self.backend = backend
        env = self.load_environment()
        self.train_fn = {
            'ant': functools.partial(ppo.train, num_timesteps=50_000, num_evals=2, reward_scaling=10,
                                     episode_length=1000, normalize_observations=True, action_repeat=1,
                                     unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
                                     discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
                                     batch_size=2048, seed=1),
            # Add other environment setups as needed...
        }[self.env_name]

        self.make_inference_fn, self.params, _ = self.train_fn(environment=env, progress_fn=self.progress_callback,
                                                               policy_params_fn=self.policy_params_callback)

        print("training finished. Visualizing the policy...")
        self.visualize_trajectory()

    def visualize_trajectory(self):
        inference_fn = self.make_inference_fn(self.params)
        env = self.load_environment()
        jit_env_reset = jax.jit(env.reset)
        jit_env_step = jax.jit(env.step)
        jit_inference_fn = jax.jit(inference_fn)

        trajectory = []
        rng = jax.random.PRNGKey(seed=1)
        state = jit_env_reset(rng=rng)

        for _ in tqdm(range(1000)):
            trajectory.append(state.pipeline_state)
            act_rng, rng = jax.random.split(rng)
            act, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_env_step(state, act)

        rendered_html = html.render(env.sys.tree_replace({'opt.timestep': env.dt}), trajectory)
        with open("trajectory_visualization.html", "w") as file:
            file.write(rendered_html)

        print("trajectory visualization prepared.")


# Run training and visualize the final policy
if __name__ == "__main__":
    rl_trainer = RLTrainer()
    rl_trainer.start_training(env_name='ant', backend='positional')
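
For the saving part itself (the original goal), something along the lines of the sketch below should do, assuming brax.io.model.save_params, which is what the brax training notebooks use for the final params; the path and file naming are just examples:

    from brax.io import model

    def policy_params_callback(self, step, make_policy, params):
        # Write a snapshot of the parameters passed in at this callback;
        # the path is only an example.
        model.save_params(f"/tmp/policy_step_{step}", params)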