Commit 3767d00

Merge branch 'main' of github.com:utiasDSL/crazyflow into main
2 parents 713f8f4 + 4e54f13

6 files changed

Lines changed: 107 additions & 44 deletions

benchmark/plot.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def plot_fps_data(data_folder: Path):
     print(f"Plot saved to {output_path}")
 
 
-def format_log_axes(ax, dfs, prefix):
+def format_log_axes(ax: plt.Axes, dfs: dict[str, pd.DataFrame], prefix: str):
     """Format logarithmic axes with nice labels.
 
     Args:

crazyflow/gymnasium_envs/crazyflow.py

Lines changed: 59 additions & 37 deletions
@@ -157,23 +157,29 @@ def reset(
         if seed is not None:
             self.jax_key = jax.random.key(seed)
 
-        self.reset_masked(mask=jnp.ones((self.sim.n_worlds), dtype=bool, device=self.device))
+        self.reset_masked(
+            mask=jnp.ones((self.sim.n_worlds), dtype=bool, device=self.device), reset_params=options
+        )
         self.prev_done = jnp.zeros((self.sim.n_worlds), dtype=bool, device=self.device)
         return self._obs(), {}
 
     def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
-        default_reset_params = {
-            "pos_min": jnp.array([-1.0, -1.0, 1.0]),  # x,y,z
-            "pos_max": jnp.array([1.0, 1.0, 2.0]),  # x,y,z
-            "vel_min": -1.0,
-            "vel_max": 1.0,
+        if reset_params is None:
+            reset_params = {}
+
+        default_drone_reset_params = {
+            "pos_min": reset_params.pop("pos_min", jnp.array([-1.0, -1.0, 1.0])),  # x,y,z
+            "pos_max": reset_params.pop("pos_max", jnp.array([1.0, 1.0, 2.0])),  # x,y,z
+            "vel_min": reset_params.pop("vel_min", -1.0),
+            "vel_max": reset_params.pop("vel_max", 1.0),
         }
 
-        if reset_params is not None:
-            invalid_keys = set(reset_params.keys()) - set(default_reset_params.keys())
-            if invalid_keys:
-                raise ValueError(f"Invalid bounds keys: {invalid_keys}")
-            default_reset_params.update(reset_params)
+        # sanity check to see if all keys have been used
+        if len(reset_params) > 0:
+            warnings.warn(
+                f"Unused reset parameters: {reset_params.keys()}. "
+                "These will be ignored in the reset function. In case this parameter has already been used, please make sure to pop it from the dictionary."
+            )
 
         self.sim.reset(mask=mask)
         mask3d = mask[:, None, None]
@@ -183,8 +189,8 @@ def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
         init_pos = jax.random.uniform(
             key=subkey,
             shape=(self.sim.n_worlds, self.sim.n_drones, 3),
-            minval=default_reset_params["pos_min"],
-            maxval=default_reset_params["pos_max"],
+            minval=default_drone_reset_params["pos_min"],
+            maxval=default_drone_reset_params["pos_max"],
         )
         self.sim.data = self.sim.data.replace(
             states=self.sim.data.states.replace(
@@ -196,8 +202,8 @@ def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
         init_vel = jax.random.uniform(
             key=subkey,
             shape=(self.sim.n_worlds, self.sim.n_drones, 3),
-            minval=default_reset_params["vel_min"],
-            maxval=default_reset_params["vel_max"],
+            minval=default_drone_reset_params["vel_min"],
+            maxval=default_drone_reset_params["vel_max"],
        )
         self.sim.data = self.sim.data.replace(
             states=self.sim.data.states.replace(
@@ -242,7 +248,9 @@ def render(self):
     def _obs(self) -> dict[str, Array]:
         fields = self.obs_keys
         states = [getattr(self.sim.data.states, field) for field in fields]
-        return {k: v.squeeze() for k, v in zip(fields, states)}
+        return {
+            k: v[:, 0, :] for k, v in zip(fields, states)
+        }  # drop n_drones dimension, as it is always 1 for now
 
     def close(self):
         self.sim.close()
@@ -273,19 +281,22 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
         reward = jnp.where(prev_done.reshape(-1, 1), 0.0, reward)
         return reward
 
-    def reset_masked(self, mask: Array) -> None:
-        super().reset_masked(mask)
+    def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
+        if reset_params is None:
+            reset_params = {}
 
         # Generate new goals
         self.jax_key, subkey = jax.random.split(self.jax_key)
         new_goals = jax.random.uniform(
             key=subkey,
             shape=(self.sim.n_worlds, 3),
-            minval=jnp.array([-1.0, -1.0, 0.5]),  # x,y,z
-            maxval=jnp.array([1.0, 1.0, 1.5]),  # x,y,z
+            minval=reset_params.pop("goal_pos_min", jnp.array([-1.0, -1.0, 0.5])),  # x,y,z
+            maxval=reset_params.pop("goal_pos_max", jnp.array([1.0, 1.0, 1.5])),  # x,y,z
         )
         self.goal = self.goal.at[mask].set(new_goals[mask])
 
+        super().reset_masked(mask, reset_params)
+
     def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
         if self.render_goal_marker:
             for i in range(self.sim.n_worlds):
@@ -300,7 +311,9 @@ def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
 
     def _obs(self) -> dict[str, Array]:
         obs = super()._obs()
-        obs["difference_to_goal"] = [self.goal - self.sim.data.states.pos]
+        obs["difference_to_goal"] = (
+            self.goal - self.sim.data.states.pos[:, 0, :]
+        )  # drop n_drones dimension, as it is always 1 for now
         return obs
 
 
@@ -329,22 +342,27 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, target_vel: A
         reward = jnp.where(prev_done.reshape(-1, 1), 0.0, reward)
         return reward
 
-    def reset_masked(self, mask: Array) -> None:
-        super().reset_masked(mask)
+    def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
+        if reset_params is None:
+            reset_params = {}
 
         # Generate new target_vels
         self.jax_key, subkey = jax.random.split(self.jax_key)
         new_target_vel = jax.random.uniform(
             key=subkey,
             shape=(self.sim.n_worlds, 3),
-            minval=jnp.array([-1.0, -1.0, -1.0]),  # x,y,z
-            maxval=jnp.array([1.0, 1.0, 1.0]),  # x,y,z
+            minval=reset_params.pop("target_vel_min", jnp.array([-1.0, -1.0, -1.0])),  # x,y,z
+            maxval=reset_params.pop("target_vel_max", jnp.array([1.0, 1.0, 1.0])),  # x,y,z
         )
         self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])
 
+        super().reset_masked(mask)
+
     def _obs(self) -> dict[str, Array]:
         obs = super()._obs()
-        obs["difference_to_target_vel"] = [self.target_vel - self.sim.data.states.vel]
+        obs["difference_to_target_vel"] = (
+            self.target_vel - self.sim.data.states.vel[:, 0, :]
+        )  # drop n_drones dimension, as it is always 1 for now
         return obs
 
 
@@ -375,9 +393,6 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
         reward = jnp.where(prev_done.reshape(-1, 1), 0.0, reward)
         return reward
 
-    def reset_masked(self, mask: Array) -> None:
-        super().reset_masked(mask)
-
     def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
         if self.render_landing_target:
             for i in range(self.sim.n_worlds):
@@ -392,7 +407,9 @@ def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
 
     def _obs(self) -> dict[str, Array]:
         obs = super()._obs()
-        obs["difference_to_goal"] = [self.goal - self.sim.data.states.pos]
+        obs["difference_to_goal"] = (
+            self.goal - self.sim.data.states.pos[:, 0, :]
+        )  # drop n_drones dimension, as it is always 1 for now
         return obs
 
 
@@ -478,14 +495,19 @@ def _reward(prev_done: Array, terminated: Array, states: SimState, goal: Array)
         reward = jnp.where(prev_done.reshape(-1, 1), 0.0, reward)
         return reward
 
-    def reset_masked(self, mask: Array) -> None:
-        reset_params = {
-            "pos_min": jnp.array([-0.1, -0.1, 1.1]),  # x,y,z
-            "pos_max": jnp.array([0.1, 0.1, 1.3]),  # x,y,z
-            "vel_min": -0.5,
-            "vel_max": 0.5,
+    def reset_masked(self, mask: Array, reset_params: dict | None = None) -> None:
+        if reset_params is None:
+            reset_params = {}
+
+        # Different initial conditions than CrazyflowBaseEnv
+        default_drone_reset_params = {
+            "pos_min": reset_params.pop("pos_min", jnp.array([-0.1, -0.1, 1.1])),  # x,y,z
+            "pos_max": reset_params.pop("pos_max", jnp.array([0.1, 0.1, 1.3])),  # x,y,z
+            "vel_min": reset_params.pop("vel_min", -0.5),
+            "vel_max": reset_params.pop("vel_max", 0.5),
         }
-        super().reset_masked(mask, reset_params)
+
+        super().reset_masked(mask, default_drone_reset_params)
 
     def _obs(self) -> dict[str, Array]:
         obs = super()._obs()
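
The diff above changes how reset bounds are supplied: instead of validating a full override dict, each `reset_masked` pops the keys it understands (falling back to its defaults) and forwards the remainder up the class chain, warning about anything left unused. A minimal, self-contained sketch of that pattern (the helper name `consume_reset_params` is hypothetical and not part of the repository):

import warnings

import jax.numpy as jnp


def consume_reset_params(reset_params: dict | None = None) -> dict:
    # Hypothetical helper illustrating the pop-with-default pattern used above.
    if reset_params is None:
        reset_params = {}

    # Known keys are consumed with a default; subclasses would pop their own keys
    # (e.g. "goal_pos_min") before forwarding the rest to the parent class.
    params = {
        "pos_min": reset_params.pop("pos_min", jnp.array([-1.0, -1.0, 1.0])),
        "pos_max": reset_params.pop("pos_max", jnp.array([1.0, 1.0, 2.0])),
        "vel_min": reset_params.pop("vel_min", -1.0),
        "vel_max": reset_params.pop("vel_max", 1.0),
    }

    # Anything left over was not recognized by any reset_masked in the chain.
    if reset_params:
        warnings.warn(f"Unused reset parameters: {reset_params.keys()}.")
    return params


# Example: override one bound, misspell another -> the warning points at the typo.
consume_reset_params({"pos_min": jnp.zeros(3), "vel_mni": 0.0})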

examples/gymnasium_env.py

Lines changed: 15 additions & 5 deletions
@@ -9,17 +9,27 @@
 def main():
     enable_cache()
     SEED = 42
-    envs = gymnasium.make_vec("DroneLanding-v0", num_envs=20, freq=50, time_horizon_in_seconds=2)
+    envs = gymnasium.make_vec("DroneReachPos-v0", num_envs=20, freq=50, time_horizon_in_seconds=2)
 
-    # This wrapper makes it possible to interact with the environment using numpy arrays, if
-    # desired. JaxToTorch is available as well.
+    # This wrapper makes it possible to interact with the environment using numpy arrays, if desired. JaxToTorch is available as well.
     envs = JaxToNumpy(envs)
 
-    # dummy action for going up (in attitude control)
+    # Dummy action for going up (in attitude control)
     action = np.zeros((20, 4), dtype=np.float32)
     action[..., 0] = 0.4
 
-    obs, info = envs.reset(seed=SEED)
+    # Environments provide reset parameters that can be used to set the initial state of the environment.
+    obs, info = envs.reset(
+        seed=SEED,
+        options={
+            "pos_min": np.array([-1.0, 1.0, 1.0]),
+            "pos_max": np.array([-1.0, 1.0, 1.0]),
+            "vel_min": 0.0,
+            "vel_max": 0.0,
+            "goal_pos_min": np.array([-1.0, 1.0, 1.0]),
+            "goal_pos_max": np.array([-1.0, 1.0, 1.0]),
+        },
+    )
 
     # Step through the environment
     for _ in range(100):
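
Because `pos_min`/`pos_max` and `goal_pos_min`/`goal_pos_max` coincide in the options above, the reset is fully deterministic. A quick sanity check one could append right after the reset in this example (not part of the commit; it mirrors the assertions in the new integration test further below):

    # With identical min/max bounds, all 20 environments start at the same pose and goal.
    assert np.allclose(obs["pos"], np.array([-1.0, 1.0, 1.0]))
    assert np.allclose(obs["difference_to_goal"], 0.0)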

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "jax>=0.5.0,<0.5.3",  # 0.5.3 onwards has a bug that causes resets to not work properly. TODO: Investigate why this is happening
+    "jax>=0.5.0,!=0.5.3,!=0.6.0",  # 0.5.3 and 0.6.0 have a bug that causes resets to not work properly. This is related to the fusing of SimData.mjx_model into the _reset function.
     "mujoco>=3.3.0",
     "mujoco-mjx>=3.3.0",
     "gymnasium",
Lines changed: 29 additions & 0 deletions (new test file)
@@ -0,0 +1,29 @@
+import gymnasium
+import numpy as np
+import pytest
+from gymnasium.wrappers.vector import JaxToNumpy
+
+import crazyflow  # noqa: F401, register gymnasium envs
+
+
+@pytest.mark.integration
+def test_gymnasium_reset():
+    """Test reset behavior of the DroneReachPos-v0 environment."""
+    SEED = 42
+    envs = gymnasium.make_vec("DroneReachPos-v0", num_envs=1, freq=50, time_horizon_in_seconds=2)
+
+    envs = JaxToNumpy(envs)
+    obs, _ = envs.reset(
+        seed=SEED,
+        options={
+            "pos_min": np.array([-1.0, 1.0, 1.0]),
+            "pos_max": np.array([-1.0, 1.0, 1.0]),
+            "vel_min": 0.0,
+            "vel_max": 0.0,
+            "goal_pos_min": np.array([-1.0, 1.0, 1.0]),
+            "goal_pos_max": np.array([-1.0, 1.0, 1.0]),
+        },
+    )
+    assert np.all(obs["pos"] == np.array([[-1.0, 1.0, 1.0]]))
+    assert np.all(obs["difference_to_goal"] == np.array([[.0, .0, .0]]))
+    assert np.all(obs["vel"] == np.array([[0.0, 0.0, 0.0]]))
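
As a consequence of the `_obs` change above (indexing `[:, 0, :]` instead of `squeeze()`), observations keep the leading `num_envs` dimension, which is why the assertions compare against `(1, 3)`-shaped arrays. An equivalent, hypothetical shape check one could add to the test:

    # obs entries are (num_envs, 3) arrays; with num_envs=1 this is shape (1, 3).
    assert obs["pos"].shape == (1, 3)
    assert obs["vel"].shape == (1, 3)
    assert obs["difference_to_goal"].shape == (1, 3)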

tests/integration/test_reset.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import numpy as np
 import pytest
 
+import crazyflow  # noqa: F401, register gymnasium envs
 from crazyflow.control import Control
 from crazyflow.sim import Physics, Sim
 
@@ -62,3 +63,4 @@ def test_reset_multi_world(physics: Physics):
     sim.step(sim.freq // sim.control_freq)
     assert jnp.all(sim.data.states.pos == final_pos)
     assert jnp.all(sim.data.states.quat == final_quat)
+
