From b5f28c73286aff3700ce89ee7930f85a40b2b834 Mon Sep 17 00:00:00 2001
From: zjowowen
Date: Tue, 30 Jul 2024 14:26:00 +0800
Subject: [PATCH] Add torchrl tensordict dataset and replay buffer.

---
 .../benchmark/gmpg/gvp/halfcheetah_medium.py | 2 +-
 .../gmpg/gvp/halfcheetah_medium_expert.py | 2 +-
 .../gmpg/gvp/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpg/gvp/hopper_medium.py | 2 +-
 .../gmpg/gvp/hopper_medium_expert.py | 2 +-
 .../gmpg/gvp/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpg/gvp/walker2d_medium.py | 2 +-
 .../gmpg/gvp/walker2d_medium_expert.py | 2 +-
 .../gmpg/gvp/walker2d_medium_replay.py | 2 +-
 .../benchmark/gmpg/icfm/halfcheetah_medium.py | 2 +-
 .../gmpg/icfm/halfcheetah_medium_expert.py | 2 +-
 .../gmpg/icfm/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpg/icfm/hopper_medium.py | 2 +-
 .../gmpg/icfm/hopper_medium_expert.py | 2 +-
 .../gmpg/icfm/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpg/icfm/walker2d_medium.py | 2 +-
 .../gmpg/icfm/walker2d_medium_expert.py | 2 +-
 .../gmpg/icfm/walker2d_medium_replay.py | 2 +-
 .../gmpg/vpsde/antmaze_large_diverse.py | 206 ++++++++++++++++++
 .../gmpg/vpsde/antmaze_large_play.py | 206 ++++++++++++++++++
 .../gmpg/vpsde/antmaze_medium_diverse.py | 206 ++++++++++++++++++
 .../gmpg/vpsde/antmaze_medium_play.py | 206 ++++++++++++++++++
 .../benchmark/gmpg/vpsde/antmaze_umaze.py | 206 ++++++++++++++++++
 .../gmpg/vpsde/antmaze_umaze_diverse.py | 206 ++++++++++++++++++
 .../gmpg/vpsde/halfcheetah_medium.py | 2 +-
 .../gmpg/vpsde/halfcheetah_medium_expert.py | 2 +-
 .../gmpg/vpsde/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpg/vpsde/hopper_medium.py | 2 +-
 .../gmpg/vpsde/hopper_medium_expert.py | 2 +-
 .../gmpg/vpsde/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpg/vpsde/walker2d_medium.py | 2 +-
 .../gmpg/vpsde/walker2d_medium_expert.py | 2 +-
 .../gmpg/vpsde/walker2d_medium_replay.py | 2 +-
 .../benchmark/gmpo/gvp/halfcheetah_medium.py | 2 +-
 .../gmpo/gvp/halfcheetah_medium_expert.py | 2 +-
 .../gmpo/gvp/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpo/gvp/hopper_medium.py | 2 +-
 .../gmpo/gvp/hopper_medium_expert.py | 2 +-
 .../gmpo/gvp/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpo/gvp/walker2d_medium.py | 2 +-
 .../gmpo/gvp/walker2d_medium_expert.py | 2 +-
 .../gmpo/gvp/walker2d_medium_replay.py | 2 +-
 .../benchmark/gmpo/icfm/halfcheetah_medium.py | 2 +-
 .../gmpo/icfm/halfcheetah_medium_expert.py | 2 +-
 .../gmpo/icfm/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpo/icfm/hopper_medium.py | 2 +-
 .../gmpo/icfm/hopper_medium_expert.py | 2 +-
 .../gmpo/icfm/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpo/icfm/walker2d_medium.py | 2 +-
 .../gmpo/icfm/walker2d_medium_expert.py | 2 +-
 .../gmpo/icfm/walker2d_medium_replay.py | 2 +-
 .../gmpo/vpsde/halfcheetah_medium.py | 2 +-
 .../gmpo/vpsde/halfcheetah_medium_expert.py | 2 +-
 .../gmpo/vpsde/halfcheetah_medium_replay.py | 2 +-
 .../benchmark/gmpo/vpsde/hopper_medium.py | 2 +-
 .../gmpo/vpsde/hopper_medium_expert.py | 2 +-
 .../gmpo/vpsde/hopper_medium_replay.py | 2 +-
 .../benchmark/gmpo/vpsde/walker2d_medium.py | 2 +-
 .../gmpo/vpsde/walker2d_medium_expert.py | 2 +-
 .../gmpo/vpsde/walker2d_medium_replay.py | 2 +-
 .../configurations/d4rl_halfcheetah_qgpo.py | 7 +-
 .../configurations/d4rl_halfcheetah_srpo.py | 2 +-
 .../configurations/d4rl_hopper_srpo.py | 2 +-
 .../configurations/d4rl_walker2d_qgpo.py | 6 +-
 .../configurations/d4rl_walker2d_srpo.py | 2 +-
 .../lunarlander_continuous_qgpo.py | 2 +-
 .../lunarlander_continuous_qgpo.py | 2 +-
 67 files changed, 1303 insertions(+), 64 deletions(-)
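Most of the benchmark configs touched below receive the same one-line edit: the dataset entry's type string is switched from the array-based "GPD4RLDataset" to the TensorDict-backed "GPD4RLTensorDictDataset". As a minimal sketch of the changed block (env_id is a placeholder here; each benchmark file defines its own value):

    # Recurring dataset-config change applied across the benchmark configs in this patch.
    # env_id is a placeholder; every config file sets its own environment id.
    env_id = "halfcheetah-medium-v2"

    dataset = dict(
        type="GPD4RLTensorDictDataset",  # previously "GPD4RLDataset"
        args=dict(
            env_id=env_id,
        ),
    )
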
create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_diverse.py create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_play.py create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_diverse.py create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_play.py create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze.py create mode 100755 grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze_diverse.py diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py index a09ec21..1801025 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py index 1a02c6e..0eee113 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py index eff781f..408a80d 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py index 80f8131..70ad45c 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py index 2e72e3c..a4020d1 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py index 745b57c..c089777 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py index 335b76e..3a950a1 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py index 2d5f35c..fbf15bd 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py +++ 
b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py index 788b962..853bc58 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py index 5694f89..83720ad 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py index 92564fa..6cbcd7a 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py index dd30364..29b13db 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py index 6547a62..cb72ec6 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py index a3f3ae4..536889f 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py index 9183de2..6810d06 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py index 02bffe6..9f40de7 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git 
a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py index a0035bf..d9a8a94 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py index 864e827..7928cc8 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_diverse.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_diverse.py new file mode 100755 index 0000000..424d8ef --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_diverse.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-large-diverse-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=4000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + 
batch_size=40960, + epochs=100, + learning_rate=5e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=0.5, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_play.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_play.py new file mode 100755 index 0000000..43838fc --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_play.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-large-play-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + 
model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=100, + learning_rate=5e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=0.25, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_diverse.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_diverse.py new file mode 100755 index 0000000..6635ec5 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_diverse.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-medium-diverse-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", 
+ ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=4000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=100, + learning_rate=5e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=0.5, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_play.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_play.py new file mode 100755 index 0000000..5a92186 --- /dev/null +++ 
b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_medium_play.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-medium-play-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=4000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=100, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=0.25, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] 
+ for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze.py new file mode 100755 index 0000000..7149cf0 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-umaze-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=4000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=100, + learning_rate=2e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=2, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + 
checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze_diverse.py b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze_diverse.py new file mode 100755 index 0000000..d6aba83 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/vpsde/antmaze_umaze_diverse.py @@ -0,0 +1,206 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-umaze-diverse-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPOD4RLDataset", + args=dict( + env_id=env_id, + device=device, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), 
+ ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=4000, + learning_rate=1e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=100, + learning_rate=5e-6, + copy_from_basemodel=True, + gradtime_step=1000, + eta=2, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=5, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpg import GPAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GPAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py index edaf922..34647c1 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py index 50edbad..49c1dfc 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py index 9a366e2..e6946e4 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git 
a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py index e14765b..961c6af 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py index 99f285a..7ebc8eb 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py index 5a3dbc8..f4ab02f 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py index e95b3dc..9eb9d33 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py index f1a0f97..71d5873 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py index ae9620b..4519afb 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py index 5c2c9ac..631b127 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py index 8298843..c4d17e2 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py index 0a4ea39..c397bad 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py +++ 
b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py index 2a3acc3..083dc06 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py index e478d83..adcb978 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py index 9e0ebad..d82fc50 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py index a87757a..6548b6e 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py index 3802db5..7c929b4 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py index ac11fa8..8853f38 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py index 1a7cc50..f28cc82 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py index 52e42c6..357b1a6 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py 
b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py index 73f363c..eea19f6 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py index cd1419c..e22821b 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py index 75da341..81e4905 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py index efe10cb..6b43605 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py index 562f665..2d1e543 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py index a807544..95015a1 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py index c738d23..1113d09 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py @@ -62,7 +62,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py index 20a15b8..8b88614 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py index aaf2b7b..fd5fe4f 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py @@ -66,7 +66,7 @@ ), 
), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py index 8c6d23d..19d03e1 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py index 2efab09..c2b2f25 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py index 99beb3e..c8aacfb 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py index 1d19e9b..6319778 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py index 55cc1e5..6d74626 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py index c31fece..a045cf2 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py index 2faca8f..c1c9258 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py @@ -66,7 +66,7 @@ ), ), dataset=dict( - type="GPD4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( env_id=env_id, ), diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py index 7818d10..153e742 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py @@ -13,7 +13,7 @@ ), ) solver_type = "DPMSolver" - +action_augment_num = 16 config = EasyDict( train=dict( project="d4rl-halfcheetah-v2-qgpo", @@ -24,9 +24,10 @@ ), ), dataset=dict( - 
type="QGPOD4RLDataset", + type="QGPOCustomizedTensorDictDataset", args=dict( env_id="halfcheetah-medium-expert-v2", + action_augment_num=action_augment_num, ), ), model=dict( @@ -127,7 +128,7 @@ learning_rate=1e-4, iterations=600000, ), - sample_per_state=16, + action_augment_num=action_augment_num, fake_data_t_span=None if solver_type == "DPMSolver" else 32, energy_guided_policy=dict( batch_size=256, diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py index 4765d35..dc74cc3 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py @@ -99,7 +99,7 @@ learning_rate=3e-4, iterations=600000, ), - sample_per_state=16, + action_augment_num=16, critic=dict( batch_size=256, iterations=600000, diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py index 828ea3a..ca61c52 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py @@ -99,7 +99,7 @@ learning_rate=3e-4, iterations=2000000, ), - sample_per_state=16, + action_augment_num=16, critic=dict( batch_size=256, iterations=2000000, diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py index cd67211..35f43c5 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py @@ -13,6 +13,7 @@ ), ) solver_type = "DPMSolver" +action_augment_num = 16 config = EasyDict( train=dict( project="d4rl-walker2d-v2-qgpo", @@ -23,9 +24,10 @@ ), ), dataset=dict( - type="QGPOD4RLDataset", + type="QGPOCustomizedTensorDictDataset", args=dict( env_id="walker2d-medium-expert-v2", + action_augment_num=action_augment_num, ), ), model=dict( @@ -126,7 +128,7 @@ learning_rate=1e-4, iterations=600000, ), - sample_per_state=16, + action_augment_num=action_augment_num, fake_data_t_span=None if solver_type == "DPMSolver" else 32, energy_guided_policy=dict( batch_size=256, diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py index b730ebb..045a557 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py @@ -99,7 +99,7 @@ learning_rate=3e-4, iterations=2000000, ), - sample_per_state=16, + action_augment_num=16, critic=dict( batch_size=256, iterations=2000000, diff --git a/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py b/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py index 9d26332..0053c85 100644 --- a/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py @@ -129,7 +129,7 @@ learning_rate=1e-4, iterations=100000, ), - sample_per_state=16, + action_augment_num=16, fake_data_t_span=None if solver_type == "DPMSolver" else 32, energy_guided_policy=dict( batch_size=256, diff --git a/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py b/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py index 17309d0..0181c54 100644 --- 
a/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py
+++ b/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py
@@ -13,7 +13,7 @@ def qgpo_pipeline(config):
     qgpo = QGPOAlgorithm(
         config,
         dataset=QGPOCustomizedDataset(
-            numpy_data_path="./data.npz",
+            numpy_data_path="./data.npz"
         ),
     )
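
The renamed *TensorDictDataset classes referenced in this patch build on tensordict/torchrl storage, per the commit subject. As a hedged, self-contained sketch (not the repository's implementation: the key names and tensor shapes below are illustrative assumptions, with state/action sizes borrowed from the antmaze configs), D4RL-style transitions can be packed into a TensorDict and served through a torchrl replay buffer like so:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

    # Illustrative transition batch; a real pipeline would fill these tensors
    # from d4rl (e.g. d4rl.qlearning_dataset(env)) instead of random values.
    num_transitions, state_size, action_size = 1000, 29, 8
    transitions = TensorDict(
        {
            "obs": torch.randn(num_transitions, state_size),
            "action": torch.randn(num_transitions, action_size),
            "reward": torch.randn(num_transitions, 1),
            "next_obs": torch.randn(num_transitions, state_size),
            "done": torch.zeros(num_transitions, 1, dtype=torch.bool),
        },
        batch_size=[num_transitions],
    )

    # TensorDict-backed replay buffer: extend() ingests the whole TensorDict and
    # sample() returns a TensorDict batch carrying the same keys.
    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(num_transitions),
        batch_size=256,
    )
    buffer.extend(transitions)
    batch = buffer.sample()  # TensorDict with batch_size=[256]
    print(batch["obs"].shape, batch["action"].shape)

With this layout, the training loop can index whole minibatches by key instead of unpacking parallel numpy arrays, which is the motivation for swapping the dataset types throughout these configs.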