@@ -157,23 +157,29 @@ def reset(
157157 if seed is not None :
158158 self .jax_key = jax .random .key (seed )
159159
160- self .reset_masked (mask = jnp .ones ((self .sim .n_worlds ), dtype = bool , device = self .device ))
160+ self .reset_masked (
161+ mask = jnp .ones ((self .sim .n_worlds ), dtype = bool , device = self .device ), reset_params = options
162+ )
161163 self .prev_done = jnp .zeros ((self .sim .n_worlds ), dtype = bool , device = self .device )
162164 return self ._obs (), {}
163165
164166 def reset_masked (self , mask : Array , reset_params : dict | None = None ) -> None :
165- default_reset_params = {
166- "pos_min" : jnp .array ([- 1.0 , - 1.0 , 1.0 ]), # x,y,z
167- "pos_max" : jnp .array ([1.0 , 1.0 , 2.0 ]), # x,y,z
168- "vel_min" : - 1.0 ,
169- "vel_max" : 1.0 ,
167+ if reset_params is None :
168+ reset_params = {}
169+
170+ default_drone_reset_params = {
171+ "pos_min" : reset_params .pop ("pos_min" , jnp .array ([- 1.0 , - 1.0 , 1.0 ])), # x,y,z
172+ "pos_max" : reset_params .pop ("pos_max" , jnp .array ([1.0 , 1.0 , 2.0 ])), # x,y,z
173+ "vel_min" : reset_params .pop ("vel_min" , - 1.0 ),
174+ "vel_max" : reset_params .pop ("vel_max" , 1.0 ),
170175 }
171176
172- if reset_params is not None :
173- invalid_keys = set (reset_params .keys ()) - set (default_reset_params .keys ())
174- if invalid_keys :
175- raise ValueError (f"Invalid bounds keys: { invalid_keys } " )
176- default_reset_params .update (reset_params )
177+ # sanity check to see if all keys have been used
178+ if len (reset_params ) > 0 :
179+ warnings .warn (
180+ f"Unused reset parameters: { reset_params .keys ()} . "
181+ "These will be ignored in the reset function. In case this parameter has already been used, please make sure to pop it from the dictionary."
182+ )
177183
178184 self .sim .reset (mask = mask )
179185 mask3d = mask [:, None , None ]
@@ -183,8 +189,8 @@ def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
183189 init_pos = jax .random .uniform (
184190 key = subkey ,
185191 shape = (self .sim .n_worlds , self .sim .n_drones , 3 ),
186- minval = default_reset_params ["pos_min" ],
187- maxval = default_reset_params ["pos_max" ],
192+ minval = default_drone_reset_params ["pos_min" ],
193+ maxval = default_drone_reset_params ["pos_max" ],
188194 )
189195 self .sim .data = self .sim .data .replace (
190196 states = self .sim .data .states .replace (
@@ -196,8 +202,8 @@ def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
196202 init_vel = jax .random .uniform (
197203 key = subkey ,
198204 shape = (self .sim .n_worlds , self .sim .n_drones , 3 ),
199- minval = default_reset_params ["vel_min" ],
200- maxval = default_reset_params ["vel_max" ],
205+ minval = default_drone_reset_params ["vel_min" ],
206+ maxval = default_drone_reset_params ["vel_max" ],
201207 )
202208 self .sim .data = self .sim .data .replace (
203209 states = self .sim .data .states .replace (
@@ -242,7 +248,9 @@ def render(self):
242248 def _obs (self ) -> dict [str , Array ]:
243249 fields = self .obs_keys
244250 states = [getattr (self .sim .data .states , field ) for field in fields ]
245- return {k : v .squeeze () for k , v in zip (fields , states )}
251+ return {
252+ k : v [:, 0 , :] for k , v in zip (fields , states )
253+ } # drop n_drones dimension, as it is always 1 for now
246254
247255 def close (self ):
248256 self .sim .close ()
@@ -273,19 +281,22 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
273281 reward = jnp .where (prev_done .reshape (- 1 , 1 ), 0.0 , reward )
274282 return reward
275283
276- def reset_masked (self , mask : Array ) -> None :
277- super ().reset_masked (mask )
284+ def reset_masked (self , mask : Array , reset_params : dict | None = None ) -> None :
285+ if reset_params is None :
286+ reset_params = {}
278287
279288 # Generate new goals
280289 self .jax_key , subkey = jax .random .split (self .jax_key )
281290 new_goals = jax .random .uniform (
282291 key = subkey ,
283292 shape = (self .sim .n_worlds , 3 ),
284- minval = jnp .array ([- 1.0 , - 1.0 , 0.5 ]), # x,y,z
285- maxval = jnp .array ([1.0 , 1.0 , 1.5 ]), # x,y,z
293+ minval = reset_params . pop ( "goal_pos_min" , jnp .array ([- 1.0 , - 1.0 , 0.5 ]) ), # x,y,z
294+ maxval = reset_params . pop ( "goal_pos_max" , jnp .array ([1.0 , 1.0 , 1.5 ]) ), # x,y,z
286295 )
287296 self .goal = self .goal .at [mask ].set (new_goals [mask ])
288297
298+ super ().reset_masked (mask , reset_params )
299+
289300 def step (self , action : Array ) -> tuple [Array , Array , Array , Array , dict ]:
290301 if self .render_goal_marker :
291302 for i in range (self .sim .n_worlds ):
@@ -300,7 +311,9 @@ def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
300311
301312 def _obs (self ) -> dict [str , Array ]:
302313 obs = super ()._obs ()
303- obs ["difference_to_goal" ] = [self .goal - self .sim .data .states .pos ]
314+ obs ["difference_to_goal" ] = (
315+ self .goal - self .sim .data .states .pos [:, 0 , :]
316+ ) # drop n_drones dimension, as it is always 1 for now
304317 return obs
305318
306319
@@ -329,22 +342,27 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, target_vel: A
329342 reward = jnp .where (prev_done .reshape (- 1 , 1 ), 0.0 , reward )
330343 return reward
331344
332- def reset_masked (self , mask : Array ) -> None :
333- super ().reset_masked (mask )
345+ def reset_masked (self , mask : Array , reset_params : dict | None = None ) -> None :
346+ if reset_params is None :
347+ reset_params = {}
334348
335349 # Generate new target_vels
336350 self .jax_key , subkey = jax .random .split (self .jax_key )
337351 new_target_vel = jax .random .uniform (
338352 key = subkey ,
339353 shape = (self .sim .n_worlds , 3 ),
340- minval = jnp .array ([- 1.0 , - 1.0 , - 1.0 ]), # x,y,z
341- maxval = jnp .array ([1.0 , 1.0 , 1.0 ]), # x,y,z
354+ minval = reset_params . pop ( "target_vel_min" , jnp .array ([- 1.0 , - 1.0 , - 1.0 ]) ), # x,y,z
355+ maxval = reset_params . pop ( "target_vel_max" , jnp .array ([1.0 , 1.0 , 1.0 ]) ), # x,y,z
342356 )
343357 self .target_vel = self .target_vel .at [mask ].set (new_target_vel [mask ])
344358
359+ super ().reset_masked (mask )
360+
345361 def _obs (self ) -> dict [str , Array ]:
346362 obs = super ()._obs ()
347- obs ["difference_to_target_vel" ] = [self .target_vel - self .sim .data .states .vel ]
363+ obs ["difference_to_target_vel" ] = (
364+ self .target_vel - self .sim .data .states .vel [:, 0 , :]
365+ ) # drop n_drones dimension, as it is always 1 for now
348366 return obs
349367
350368
@@ -375,9 +393,6 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
375393 reward = jnp .where (prev_done .reshape (- 1 , 1 ), 0.0 , reward )
376394 return reward
377395
378- def reset_masked (self , mask : Array ) -> None :
379- super ().reset_masked (mask )
380-
381396 def step (self , action : Array ) -> tuple [Array , Array , Array , Array , dict ]:
382397 if self .render_landing_target :
383398 for i in range (self .sim .n_worlds ):
@@ -392,7 +407,9 @@ def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
392407
393408 def _obs (self ) -> dict [str , Array ]:
394409 obs = super ()._obs ()
395- obs ["difference_to_goal" ] = [self .goal - self .sim .data .states .pos ]
410+ obs ["difference_to_goal" ] = (
411+ self .goal - self .sim .data .states .pos [:, 0 , :]
412+ ) # drop n_drones dimension, as it is always 1 for now
396413 return obs
397414
398415
@@ -478,14 +495,19 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
478495 reward = jnp .where (prev_done .reshape (- 1 , 1 ), 0.0 , reward )
479496 return reward
480497
481- def reset_masked (self , mask : Array ) -> None :
482- reset_params = {
483- "pos_min" : jnp .array ([- 0.1 , - 0.1 , 1.1 ]), # x,y,z
484- "pos_max" : jnp .array ([0.1 , 0.1 , 1.3 ]), # x,y,z
485- "vel_min" : - 0.5 ,
486- "vel_max" : 0.5 ,
498+ def reset_masked (self , mask : Array , reset_params : dict | None = None ) -> None :
499+ if reset_params is None :
500+ reset_params = {}
501+
502+ # Different initial conditions than CrazyflowBaseEnv
503+ default_drone_reset_params = {
504+ "pos_min" : reset_params .pop ("pos_min" , jnp .array ([- 0.1 , - 0.1 , 1.1 ])), # x,y,z
505+ "pos_max" : reset_params .pop ("pos_max" , jnp .array ([0.1 , 0.1 , 1.3 ])), # x,y,z
506+ "vel_min" : reset_params .pop ("vel_min" , - 0.5 ),
507+ "vel_max" : reset_params .pop ("vel_max" , 0.5 ),
487508 }
488- super ().reset_masked (mask , reset_params )
509+
510+ super ().reset_masked (mask , default_drone_reset_params )
489511
490512 def _obs (self ) -> dict [str , Array ]:
491513 obs = super ()._obs ()
0 commit comments