PPO train code refactor for checkpointing and curriculum compatibility #211
The train function was partitioned into several parts:

- make_train_space: creates the training functions, e.g. training_epoch_with_timing, the evaluator, etc.
- init_training_state: initializes the training state
- init_env_state: initializes the environment state
- run_train: performs a training run when passed a training state, an environment state, and a train space

The benefit of creating a train space is that the training functions can be run multiple times with jit compilation occurring only once. Combined with separating the training loop from training state initialization, this improves compatibility with curriculum generation and environment variation, as the environment parameters can be changed without repeated jit compilation. A simple checkpointing version of the train function was created that makes use of the above partition. Finally, for backwards compatibility, the train function still exists and continues to function as before, by calling the functions detailed above.
Ensured that max_devices_per_host is used for all train functions.
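A rough sketch of the intended usage under this partition. Only the function names (make_train_space, init_training_state, init_env_state, run_train) come from this PR; their signatures, the module path, and the `difficulty` keyword are assumptions for illustration:

```python
# Hypothetical usage of the proposed partition; signatures and keyword
# arguments are assumed, not taken from the actual diff.
from brax import envs
from brax.training import ppo

env_fn = envs.create_fn('ant')

# Build the jitted training functions once; jit compilation happens here.
train_space = ppo.make_train_space(environment_fn=env_fn)

# Initialize training state and environment state separately.
training_state = ppo.init_training_state(train_space)
env_state = ppo.init_env_state(train_space)

# Curriculum loop: vary the environment between runs. Because the train
# space is built only once, the runs do not trigger recompilation.
for difficulty in (0.25, 0.5, 1.0):  # 'difficulty' is a hypothetical parameter
    env_state = ppo.init_env_state(train_space, difficulty=difficulty)
    training_state, env_state, metrics = ppo.run_train(
        train_space, training_state, env_state)
```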
Hi btnorman, I'm very glad to see you're showing interest in Brax! The idea of the agents directory is to show example implementations of popular algorithms with Brax. We know it doesn't cover all usages, so it's expected that people will fork those examples to get an algorithm to do what they want. We expect different users will have different opinions regarding checkpointing, for example, so we prefer to leave it unimplemented. We also want the interface for all algorithms to be as close as possible, so I don't believe introducing the abstractions you propose only for PPO can go through. For those reasons, I think it makes sense for this PR to stay out of the main branch. Have fun with Brax!
Hi! Thanks for looking it over. I want to make sure I have not misrepresented the intention of the proposed contribution! The contribution comes in two separate parts.
If the aim of making it easier to build on top of the existing Brax training code seems valuable, but the proposed implementation seems lacking, perhaps there is something else we can do, and I would be happy to help! If you decide the change does seem valuable, then I would be happy to modify the other algorithms so that they are all consistent. FYI, to allay a potential concern: using the checkpointing code with these abstractions adds very little overhead.
I refactored the PPO train code to increase its compatibility with checkpointing and environment variation, while keeping backwards compatibility.
This splits the train function into multiple functions:

- make_train_space
- init_training_state
- init_env_state
- run_train

This enables the training functions to be run multiple times with jit compilation occurring only once.
This also adds:

- a simple checkpointing version of the train function, built on the partition above
The general aim is that run_train should enable people to easily build their own checkpointing and environment variation / curriculum generation on top of the Brax PPO training code, without having to modify the Brax PPO code internally; a sketch of such a loop follows.
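As a rough illustration of the kind of loop this is meant to enable. Only the function names come from this PR; the signatures, the assumption that each run_train call covers one chunk of training, and the pickle-based persistence are all assumptions for illustration:

```python
# A minimal checkpointing sketch on top of the proposed run_train;
# signatures and chunking behaviour are assumed, not taken from the diff.
import pickle
import jax
from brax import envs
from brax.training import ppo

train_space = ppo.make_train_space(environment_fn=envs.create_fn('ant'))
training_state = ppo.init_training_state(train_space)
env_state = ppo.init_env_state(train_space)

for chunk in range(10):
    training_state, env_state, metrics = ppo.run_train(
        train_space, training_state, env_state)
    # Pull the state back to host memory and persist it, so a later run
    # can restore it and resume without touching the Brax PPO internals.
    with open(f'ppo_checkpoint_{chunk}.pkl', 'wb') as f:
        pickle.dump(jax.device_get(training_state), f)
```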
This is my first pull request, so let me know if I have made any rookie mistakes! And thanks for the great physics engine!
FYI, I have not been able to test in an environment with multiple processors, e.g. a TPU slice.