Add torchrl tensordict dataset and replay buffer.
zjowowen committed Jul 30, 2024
1 parent b9d9118 commit b5f28c7
Showing 67 changed files with 1,303 additions and 64 deletions.
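
The change switches the benchmark configs from GPD4RLDataset to GPD4RLTensorDictDataset, i.e. the D4RL data is stored as a tensordict and served through a torchrl replay buffer. As a rough illustration of that idea (not this repository's actual implementation), the sketch below packs a D4RL dataset into a TensorDict and wraps it in a torchrl TensorDictReplayBuffer; the helper name and the tensordict keys are invented for the example.

import torch
import gym
import d4rl  # importing d4rl registers the offline environments with gym
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer


def build_d4rl_tensordict_buffer(env_id="antmaze-large-diverse-v0", batch_size=4096):
    # Load the offline dataset as flat numpy arrays keyed by
    # "observations", "actions", "rewards", "next_observations", "terminals".
    data = d4rl.qlearning_dataset(gym.make(env_id))
    n = data["observations"].shape[0]

    # Pack all transitions into a single TensorDict (key names are illustrative).
    transitions = TensorDict(
        {
            "obs": torch.as_tensor(data["observations"], dtype=torch.float32),
            "action": torch.as_tensor(data["actions"], dtype=torch.float32),
            "reward": torch.as_tensor(data["rewards"], dtype=torch.float32),
            "next_obs": torch.as_tensor(data["next_observations"], dtype=torch.float32),
            "done": torch.as_tensor(data["terminals"], dtype=torch.bool),
        },
        batch_size=[n],
    )

    # A TensorDictReplayBuffer samples whole TensorDicts, so training code can pull
    # batches by key (batch["obs"], batch["action"], ...) instead of unpacking tuples.
    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(n),
        batch_size=batch_size,
    )
    buffer.extend(transitions)
    return buffer


# Example usage:
#   rb = build_d4rl_tensordict_buffer()
#   batch = rb.sample()  # TensorDict with batch_size=[4096]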
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
2 changes: 1 addition & 1 deletion grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
@@ -62,7 +62,7 @@
             ),
         ),
         dataset=dict(
-            type="GPD4RLDataset",
+            type="GPD4RLTensorDictDataset",
             args=dict(
                 env_id=env_id,
             ),
206 changes: 206 additions & 0 deletions grl_pipelines/benchmark/gmpg/vpsde/antmaze_large_diverse.py
@@ -0,0 +1,206 @@
import torch
from easydict import EasyDict

env_id = "antmaze-large-diverse-v0"
action_size = 8
state_size = 29
algorithm_type = "GMPG"
solver_type = "ODESolver"
model_type = "DiffusionModel"
generative_model_type = "VPSDE"
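# "linear_vp_sde" denotes a variance-preserving SDE with a linear noise schedule,
# i.e. beta(t) = beta_0 + t * (beta_1 - beta_0) for t in [0, 1] (standard VP-SDE
# convention; the exact parameterization is defined by the library).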
path = dict(
type="linear_vp_sde",
beta_0=0.1,
beta_1=20.0,
)
model_loss_type = "flow_matching"
project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_embedding_dim = 32
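# Time embedding: the scalar diffusion time t is mapped to a t_embedding_dim-dimensional
# vector via random Fourier features (sin/cos projections with Gaussian-sampled
# frequencies controlled by `scale`).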
t_encoder = dict(
type="GaussianFourierProjectionTimeEncoder",
args=dict(
embed_dim=t_embedding_dim,
scale=30.0,
),
)
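# Generative policy model: a state-conditioned diffusion model over actions whose
# velocity field is parameterized by a temporal-spatial residual network and
# integrated with a torchdiffeq adjoint ODE solver.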
model = dict(
device=device,
x_size=action_size,
solver=dict(
type="ODESolver",
args=dict(
library="torchdiffeq_adjoint",
),
),
path=path,
reverse_path=path,
model=dict(
type="velocity_function",
args=dict(
t_encoder=t_encoder,
backbone=dict(
type="TemporalSpatialResidualNet",
args=dict(
hidden_sizes=[512, 256, 128],
output_dim=action_size,
t_dim=t_embedding_dim,
condition_dim=state_size,
condition_hidden_dim=32,
t_condition_hidden_dim=128,
),
),
),
),
)

config = EasyDict(
train=dict(
project=project_name,
device=device,
wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"),
simulator=dict(
type="GymEnvSimulator",
args=dict(
env_id=env_id,
),
),
dataset=dict(
type="GPOD4RLDataset",
args=dict(
env_id=env_id,
device=device,
),
),
model=dict(
GPPolicy=dict(
device=device,
model_type=model_type,
model_loss_type=model_loss_type,
model=model,
critic=dict(
device=device,
q_alpha=1.0,
DoubleQNetwork=dict(
backbone=dict(
type="ConcatenateMLP",
args=dict(
hidden_sizes=[action_size + state_size, 256, 256],
output_size=1,
activation="relu",
),
),
),
VNetwork=dict(
backbone=dict(
type="MultiLayerPerceptron",
args=dict(
hidden_sizes=[state_size, 256, 256],
output_size=1,
activation="relu",
),
),
),
),
),
GuidedPolicy=dict(
model_type=model_type,
model=model,
),
),
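        # Training hyperparameters for the three stages: behaviour-policy pretraining,
        # critic training (IQL-style, see method="iql"), and guided-policy fine-tuning.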
parameter=dict(
algorithm_type=algorithm_type,
behaviour_policy=dict(
batch_size=4096,
learning_rate=1e-4,
epochs=4000,
),
t_span=32,
critic=dict(
batch_size=4096,
epochs=4000,
learning_rate=1e-4,
discount_factor=0.99,
update_momentum=0.005,
tau=0.9,
method="iql",
),
guided_policy=dict(
batch_size=40960,
epochs=100,
learning_rate=5e-6,
copy_from_basemodel=True,
gradtime_step=1000,
eta=0.5,
),
evaluation=dict(
eval=True,
repeat=10,
interval=5,
),
checkpoint_path=f"./{project_name}/checkpoint",
checkpoint_freq=5,
),
),
deploy=dict(
device=device,
env=dict(
env_id=env_id,
seed=0,
),
t_span=32,
),
)


if __name__ == "__main__":

import gym
import d4rl
import numpy as np

from grl.algorithms.gmpg import GPAlgorithm
from grl.utils.log import log

def gp_pipeline(config):

gp = GPAlgorithm(config)

# ---------------------------------------
# Customized train code ↓
# ---------------------------------------
gp.train()
# ---------------------------------------
# Customized train code ↑
# ---------------------------------------

# ---------------------------------------
# Customized deploy code ↓
# ---------------------------------------

agent = gp.deploy()
env = gym.make(config.deploy.env.env_id)
total_reward_list = []
for i in range(100):
observation = env.reset()
total_reward = 0
while True:
# env.render()
observation, reward, done, _ = env.step(agent.act(observation))
total_reward += reward
if done:
observation = env.reset()
print(f"Episode {i}, total_reward: {total_reward}")
total_reward_list.append(total_reward)
break

print(
f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}"
)

# ---------------------------------------
# Customized deploy code ↑
# ---------------------------------------

log.info("config: \n{}".format(config))
gp_pipeline(config)