Skip to content

Commit ebbce44

Browse files
author
tuzuolin
committed
fusion get data
1 parent 614a837 commit ebbce44

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

neorl2/envs/base.py

+7
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ def get_dataset(self, traj_num=None):
9292
val_traj_num = int(traj_num/4)
9393
val_dataset, val_samples = get_dataset_traj_num(val_dataset, val_traj_num)
9494

95+
if self.spec.id == "Fusion" and traj_num is None:
96+
traj_num = 20
97+
if traj_num != None:
98+
train_dataset, train_samples = get_dataset_traj_num(train_dataset, traj_num)
99+
val_traj_num = int(traj_num/4)
100+
val_dataset, val_samples = get_dataset_traj_num(val_dataset, val_traj_num)
101+
95102
return train_dataset, val_dataset
96103

97104
def set_reward_func(self, reward_func):

0 commit comments

Comments
 (0)