
Commit d469782

make new droid idle filter default (#625)
2 parents: 255fe9b + 245048e

File tree: 5 files changed (+44, −32 lines)


README.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ This is an experiment: $\pi_0$ was developed for our own robots, which differ fr
 
 ## Updates
 
+- [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.
 - [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID.
 
examples/droid/README_train.md

Lines changed: 10 additions & 8 deletions
@@ -30,20 +30,14 @@ First, change the `rlds_data_dir` path in your `TrainConfig` to the directory th
 
 Then, compute normalization statistics (this will take ~10 minutes):
 ```bash
-uv run --group rlds scripts/compute_norm_stats.py --config-name pi0_fast_droid_finetune --max-frames 10_000_000
+uv run --group rlds scripts/compute_norm_stats.py --config-name pi0_fast_droid_finetune
 ```
 
 Run training:
 ```bash
-uv run --group rlds scripts/train.py pi0_fast_droid_finetune --exp-name=my_experiment --overwrite
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi0_fast_droid_finetune --exp-name=my_experiment --overwrite
 ```
 
-By default, training uses no filtering. Alternatively, you can use a custom filtering scheme by providing a JSON file that maps from episode keys to a list of time step ranges (denoted as tuples of start and end time step indices) in that episode you wish to keep. The episode key is a unique ID defined as `f"{recording_folderpath}--{file_path}"`. We choose this convention because both paths are easily accessible in the DROID RLDS episodes' metadata.
-
-We provide an example of such a filtering scheme in [filtering/compute_droid_nonidle_ranges.py](examples/droid/filtering/compute_droid_nonidle_ranges.py), which is significantly more aggressive than the default (and thus leads to policies that take significantly fewer idle actions). We recommend using the filter produced by this script, and have also provided a copy of the filter [here](https://huggingface.co/KarlP/droid#filtering-data) specifically for `droid/1.0.1`.
-
-The filter JSON you wish to use can be specified by modifying the line `filter_dict_path="<path_to_filter_dict>"` in [src/openpi/training/config.py](src/openpi/training/config.py).
-
 **Note**: The original pi0-FAST-DROID model was trained with joint velocity actions.
 Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
 Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
@@ -57,6 +51,14 @@ If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on
 We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
 
 
+## Data Filtering
+
+Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* time steps in which the robot does not move (in part an artifact of the VR teleoperation interface used during data collection; we will not go into detail here). Filtering out these idle transitions appropriately improves policy performance.
+
+By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training; see [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter out any time step for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies it. If you want to modify the idle filter or create custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
+
+**Note**: Our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
+
 ## RoboArena
 
 Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
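
For reference, the filter file introduced in the new Data Filtering section is a flat JSON object mapping episode keys to kept `[start, end)` ranges, matching the format documented in the `droid_rlds_dataset.py` comments later in this diff. A minimal sketch of writing one (the episode key shown is a placeholder, not a real DROID path):

```python
import json

# Placeholder episode key, following the f"{recording_folderpath}--{file_path}"
# convention; real keys come from the DROID RLDS episode metadata.
episode_key = "<recording_folderpath>--<file_path>"

# Keep time steps [0, 100) and [200, 300) of this episode; drop everything else.
filter_dict = {episode_key: [[0, 100], [200, 300]]}

with open("my_filter.json", "w") as f:
    json.dump(filter_dict, f)
```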

examples/droid/compute_droid_nonidle_ranges.py

Lines changed: 4 additions & 3 deletions
@@ -1,8 +1,9 @@
 """
-Iterates through the DROID dataset and a json mapping from episode unique IDs to ranges of time steps
-that should not be filtered out (all others are).
+Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
+that should be sampled during training (all others are filtered out).
 
-Specifically, we look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
+Filtering logic:
+We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
 (default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
 this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
 ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
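
To illustrate the logic this docstring describes, here is a rough, self-contained sketch of computing sample ranges from a per-step idle mask. The boolean `is_idle` input and the exact boundary handling are assumptions, not the repo's implementation, and the end-of-range trimming mentioned in the (truncated) docstring is omitted:

```python
import numpy as np

def compute_sample_ranges(
    is_idle: np.ndarray,  # bool mask, one entry per time step (assumed input)
    min_idle_len: int = 7,  # tolerate at most this many consecutive idle frames
    min_non_idle_len: int = 16,  # drop kept ranges shorter than this (~1 second)
) -> list[list[int]]:
    """Return [start, end) ranges whose idle runs never exceed min_idle_len."""
    ranges: list[list[int]] = []
    start, idle_run = 0, 0
    for t, idle in enumerate(is_idle):
        idle_run = idle_run + 1 if idle else 0
        if idle_run > min_idle_len:
            # The idle run got too long: close the range before it started.
            end = t - idle_run + 1
            if end - start >= min_non_idle_len:
                ranges.append([start, end])
            start = t + 1  # try to resume after the current idle frame
    if len(is_idle) - start >= min_non_idle_len:
        ranges.append([start, len(is_idle)])
    return ranges

# Example: a long idle stretch in the middle splits the episode into two ranges.
mask = np.array([False] * 40 + [True] * 20 + [False] * 40)
print(compute_sample_ranges(mask))  # [[0, 40], [60, 100]]
```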

src/openpi/training/config.py

Lines changed: 2 additions & 4 deletions
@@ -93,7 +93,7 @@ class DataConfig:
     rlds_data_dir: str | None = None
     # Action space for DROID dataset.
     action_space: droid_rlds_dataset.DroidActionSpace | None = None
-    # Path to the filter dictionary file for DROID dataset
+    # Path to the data filter file for DROID dataset
     filter_dict_path: str | None = None
 
 
@@ -350,7 +350,7 @@ class RLDSDroidDataConfig(DataConfigFactory):
     # to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with
     # f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
     # Path to the filter dictionary file.
-    filter_dict_path: str | None = None
+    filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"
 
     @override
     def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
@@ -693,8 +693,6 @@ def __post_init__(self) -> None:
             # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
             rlds_data_dir="<path_to_droid_rlds_dataset>",
             action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
-            # Set this to the path for whatever filtering json you wish to use (or None)
-            filter_dict_path="<path_to_filtering_json_or_None>",
         ),
         weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
         lr_schedule=_optimizer.CosineDecaySchedule(
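
With this change, DROID training picks up the hosted filter by default, so the per-config override is gone. A sketch of swapping in a locally generated filter instead; `get_config` and the `data` attribute are assumptions based on how the rest of the repo is typically used, not confirmed by this diff:

```python
import dataclasses

from openpi.training import config as _config

# Load the existing DROID fine-tuning config by name (assumed helper).
train_config = _config.get_config("pi0_fast_droid_finetune")

# Point its data config at a custom filter file, or None to disable filtering.
data_config = dataclasses.replace(
    train_config.data,  # assumed to hold the RLDSDroidDataConfig factory
    filter_dict_path="/path/to/my_filter.json",
)
```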

src/openpi/training/droid_rlds_dataset.py

Lines changed: 27 additions & 17 deletions
@@ -7,8 +7,14 @@
 
 from enum import Enum
 from enum import auto
+import json
+import logging
 from pathlib import Path
 
+import tqdm
+
+import openpi.shared.download as download
+
 
 class DroidActionSpace(Enum):
     """Action space for DROID dataset."""
@@ -32,7 +38,7 @@ def __init__(
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
-       filter_dict_path=None,
+       filter_dict_path=None,  # Path to json file with indices to sample during training
    ):
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
@@ -52,32 +58,27 @@
            )
        )
 
-       # Repeat dataset so we never run out of data.
-       dataset = dataset.repeat()
+       # # Repeat dataset so we never run out of data.
+       # dataset = dataset.repeat()
 
        # Load the filter dictionary if provided.
-       # The filter dictionary is a JSON file that maps episode keys to ranges of frames to keep
+       # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
        # (e.g.,
        # {
-       #     "keep_ranges": {
-       #         "<episode key>": [[0, 100], [200, 300]]
-       #     }
-       # }
-       # means keep frames 0-89 and 200-289).
+       #     "<episode key>": [[0, 100], [200, 300]]
+       # }
+       # means keep frames 0-99 and 200-299).
        if filter_dict_path is not None:
-           import json
-
-           from tqdm import tqdm
-
-           with Path(filter_dict_path).open("r") as f:
+           cached_filter_dict_path = download.maybe_download(filter_dict_path)
+           with Path(cached_filter_dict_path).open("r") as f:
                filter_dict = json.load(f)
 
-           print(f"Using filter dictionary with {len(filter_dict['keep_ranges'])} episodes")
+           logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")
 
            keys_tensor = []
            values_tensor = []
 
-           for episode_key, ranges in tqdm(filter_dict.items()):
+           for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
                for start, end in ranges:
                    for t in range(start, end):
                        frame_key = f"{episode_key}--{t}"
@@ -86,7 +87,7 @@ def __init__(
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
            )
-           print("Filter hash table initialized")
+           logging.info("Filter hash table initialized")
        else:
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
@@ -122,6 +123,7 @@ def restructure(traj):
            traj_len = tf.shape(traj["action"])[0]
            indices = tf.as_string(tf.range(traj_len))
 
+           # Data filtering:
            # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
            # and each step's time step index. This will index into the filter hash table, and if it returns true,
            # then the frame passes the filter.
@@ -175,11 +177,19 @@ def chunk_actions(traj):
        # Flatten: map from trajectory dataset to dataset of individual action chunks
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
 
+       # Filter data that doesn't pass the filter
        def filter_from_dict(frame):
            return frame["passes_filter"]
 
        dataset = dataset.filter(filter_from_dict)
 
+       # Remove "passes_filter" key from output
+       def remove_passes_filter(frame):
+           frame.pop("passes_filter")
+           return frame
+
+       dataset = dataset.map(remove_passes_filter)
+
        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
