From 58661585e9a8fa62867ec552f3118da29b71e69e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2026 22:23:51 +0000 Subject: [PATCH] added verbose for task ingestion + tag filtering before window copying Made-with: Cursor --- olmoearth_pretrain/evals/studio_ingest/cli.py | 6 + .../evals/studio_ingest/ingest.py | 201 ++++++++++++++++-- 2 files changed, 192 insertions(+), 15 deletions(-) diff --git a/olmoearth_pretrain/evals/studio_ingest/cli.py b/olmoearth_pretrain/evals/studio_ingest/cli.py index 9bda36441..98ddf2c31 100644 --- a/olmoearth_pretrain/evals/studio_ingest/cli.py +++ b/olmoearth_pretrain/evals/studio_ingest/cli.py @@ -252,6 +252,12 @@ def main() -> int: Returns: Exit code (0 for success, non-zero for failure) """ + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-5s %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + parser = argparse.ArgumentParser( prog="studio_ingest", description="Ingest Studio datasets into OlmoEarth eval system", diff --git a/olmoearth_pretrain/evals/studio_ingest/ingest.py b/olmoearth_pretrain/evals/studio_ingest/ingest.py index dd2a1575d..c30c78b06 100644 --- a/olmoearth_pretrain/evals/studio_ingest/ingest.py +++ b/olmoearth_pretrain/evals/studio_ingest/ingest.py @@ -55,6 +55,7 @@ import yaml from rslearn.config import DatasetConfig from rslearn.dataset.dataset import Dataset as RslearnDataset +from tqdm import tqdm from upath import UPath from olmoearth_pretrain.evals.datasets.rslearn_builder import parse_model_config @@ -278,20 +279,31 @@ def _copy_from_gcs( source_path: str, dest_path: str, source_groups: list[str] | None = None, + source_tags: dict[str, str] | None = None, ) -> str: """Copy dataset from GCS using gsutil with parallel transfers. Uses gsutil -m for multi-threaded/multi-processing transfers. Streams output directly to console for progress visibility. + Note: *source_tags* filtering is not supported for GCS sources. + If tags are specified a ``NotImplementedError`` is raised — download + the dataset locally first or use a local source. + Args: source_path: GCS path (gs://bucket/path) dest_path: Local destination path source_groups: If specified, only copy these groups (subdirs under windows/) + source_tags: Not supported for GCS (raises NotImplementedError). Returns: Destination path """ + if source_tags: + raise NotImplementedError( + "Tag-filtered copy is not supported for GCS sources. " + "Download the dataset locally first, then ingest from a local path." + ) logger.info(" Copy method: gsutil (parallel GCS transfer)") # Create destination directory @@ -375,10 +387,128 @@ def _tar_copy_cmd(src: str, dst: str, use_pv: bool) -> str: return f"tar cf - -C {src} . | tar xf - -C {dst}" +def _window_matches_tags( + window_metadata_path: Path, + source_tags: dict[str, str], +) -> bool: + """Check whether a window's metadata.json matches all required tags. + + Args: + window_metadata_path: Path to the window's metadata.json + source_tags: Tags to match. Empty string value means "key exists". + + Returns: + True if all tags match. + """ + try: + with open(window_metadata_path) as f: + meta = json.load(f) + except (json.JSONDecodeError, OSError): + return False + + options = meta.get("options", {}) + for key, value in source_tags.items(): + if key not in options: + return False + if value and options[key] != value: + return False + return True + + +def _collect_matching_windows( + source_path: str, + source_groups: list[str] | None, + source_tags: dict[str, str], +) -> list[tuple[str, str]]: + """Scan source windows and return (group, window_name) pairs matching tags. + + Args: + source_path: Path to rslearn dataset + source_groups: If set, only scan these groups + source_tags: Tags each window must have + + Returns: + List of (group_name, window_name) tuples that match. + """ + windows_dir = Path(source_path) / "windows" + if not windows_dir.exists(): + return [] + + groups = source_groups or [d.name for d in windows_dir.iterdir() if d.is_dir()] + logger.info(" Scanning groups: %s", groups) + + all_window_dirs: list[tuple[str, Path]] = [] + for group in groups: + group_dir = windows_dir / group + if not group_dir.is_dir(): + continue + for window_dir in group_dir.iterdir(): + if window_dir.is_dir(): + all_window_dirs.append((group, window_dir)) + + matched: list[tuple[str, str]] = [] + pbar = tqdm(all_window_dirs, desc="Scanning windows for tags", unit="win") + for group, window_dir in pbar: + meta_path = window_dir / "metadata.json" + if meta_path.exists() and _window_matches_tags(meta_path, source_tags): + matched.append((group, window_dir.name)) + pbar.set_postfix(matched=len(matched)) + pbar.close() + + logger.info( + " Tag scan complete: %d/%d windows matched tags %s", + len(matched), + len(all_window_dirs), + source_tags, + ) + return matched + + +def _copy_filtered_windows( + source_path: str, + dest_path: str, + matched_windows: list[tuple[str, str]], +) -> None: + """Copy only the matched windows from source to destination. + + Uses shutil.copytree per window for simplicity and correctness on Weka. + + Args: + source_path: Source dataset path + dest_path: Destination dataset path + matched_windows: List of (group, window_name) to copy + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + + num_workers = int(os.environ.get("OLMOEARTH_INGEST_WORKERS", "8")) + total = len(matched_windows) + logger.info(" Copying %d matched windows (workers=%d)...", total, num_workers) + + def _copy_one(group: str, wname: str) -> str: + src = Path(source_path) / "windows" / group / wname + dst = Path(dest_path) / "windows" / group / wname + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(str(src), str(dst)) + return wname + + pbar = tqdm(total=total, desc="Copying windows", unit="win") + with ThreadPoolExecutor(max_workers=num_workers) as pool: + futures = [ + pool.submit(_copy_one, group, wname) for group, wname in matched_windows + ] + for future in as_completed(futures): + future.result() + pbar.update(1) + pbar.close() + + logger.info(" Finished copying %d windows", total) + + def _copy_local( source_path: str, dest_path: str, source_groups: list[str] | None = None, + source_tags: dict[str, str] | None = None, ) -> str: """Copy dataset locally using streaming tar pipe. @@ -388,10 +518,15 @@ def _copy_local( is preserved because tar archives relative paths from the source and recreates them at the destination. + When *source_tags* is provided the bulk tar copy is replaced by a + per-window copy that only transfers windows whose ``metadata.json`` + matches the requested tags. + Args: source_path: Local source path dest_path: Local destination path source_groups: If specified, only copy these groups (subdirs under windows/) + source_tags: If specified, only copy windows matching these tags. Returns: Destination path @@ -409,17 +544,23 @@ def _copy_local( # Create destination directory Path(dest_path).mkdir(parents=True, exist_ok=True) - # TODO: remove pv progress bar once copy performance is validated - has_pv = shutil.which("pv") is not None - - logger.info( - " Copy method: streaming tar pipe%s", " (with pv progress)" if has_pv else "" - ) - _try_copy_config_json(source_path, dest_path) - if source_groups: - # Copy only specified groups under windows/ + if source_tags: + logger.info(" Copy method: tag-filtered per-window copy") + matched = _collect_matching_windows(source_path, source_groups, source_tags) + if not matched: + raise ValueError( + f"No windows in {source_path} matched tags {source_tags}. " + "Check that the tag key/values are correct." + ) + _copy_filtered_windows(source_path, dest_path, matched) + elif source_groups: + has_pv = shutil.which("pv") is not None + logger.info( + " Copy method: streaming tar pipe%s", + " (with pv progress)" if has_pv else "", + ) logger.info(f" Copying only groups: {source_groups}") for group in source_groups: group_src = f"{source_path}/windows/{group}" @@ -431,7 +572,11 @@ def _copy_local( subprocess.run(cmd, shell=True, check=True) # nosec B602 logger.info(f" Copied group '{group}'") else: - # Copy entire directory using streaming tar + has_pv = shutil.which("pv") is not None + logger.info( + " Copy method: streaming tar pipe%s", + " (with pv progress)" if has_pv else "", + ) cmd = _tar_copy_cmd(source_path, dest_path, has_pv) logger.info(f" Running: {cmd}") subprocess.run(cmd, shell=True, check=True) # nosec B602 @@ -444,6 +589,7 @@ def _copy_generic( source_path: str, dest_path: str, source_groups: list[str] | None = None, + source_tags: dict[str, str] | None = None, ) -> str: """Fallback copy using UPath for unknown storage backends. @@ -453,6 +599,7 @@ def _copy_generic( source_path: Source path (any UPath-compatible) dest_path: Destination path source_groups: If specified, only copy these groups (subdirs under windows/) + source_tags: If specified, only copy windows matching these tags. Returns: Destination path @@ -466,6 +613,20 @@ def _copy_generic( _try_copy_config_json(source_path, dest_path) + # Tag-filtered copy: only works when source is local-like (metadata readable) + if source_tags: + logger.info(" Using tag-filtered copy (generic)") + matched = _collect_matching_windows(source_path, source_groups, source_tags) + if not matched: + raise ValueError(f"No windows in {source_path} matched tags {source_tags}.") + for group, wname in matched: + _copy_directory_recursive( + source / "windows" / group / wname, + dest / "windows" / group / wname, + ) + logger.info(" Copied %d matched windows", len(matched)) + return dest_path + # Copy windows directory (filtered by groups if specified) windows_src = source / "windows" windows_dst = dest / "windows" @@ -524,6 +685,7 @@ def copy_dataset( source_path: str, name: str, source_groups: list[str] | None = None, + source_tags: dict[str, str] | None = None, untar_source: bool = False, ) -> str: """Copy an rslearn dataset to our Weka location. @@ -534,11 +696,17 @@ def copy_dataset( - Local/Weka (/weka, /) -> find + xargs -P (parallel local copy) - Other -> UPath generic copy (fallback) + When *source_tags* is provided, the copy is filtered so that only + windows whose ``metadata.json`` contains the requested tag key/values + are transferred. This avoids copying entire large datasets when only a + subset is needed for evaluation. + Args: source_path: Path to source rslearn dataset name: Name for the copied dataset source_groups: If specified, only copy these groups (subdirs under windows/). If None, copies everything. + source_tags: If specified, only copy windows matching these tags. untar_source: If True, source_path is a .tar.gz archive on GCS that will be streamed and extracted directly to the destination. @@ -550,10 +718,12 @@ def copy_dataset( logger.info("=== Dataset Copy ===") logger.info(f" Source: {source_path}") logger.info(f" Destination: {dest_path}") + if source_tags: + logger.info(f" Filtering to tags: {source_tags}") if source_groups: logger.info(f" Filtering to groups: {source_groups}") - else: - logger.info(" Copying all groups") + if not source_groups and not source_tags: + logger.info(" Copying all groups (no tag/group filter)") # Check if destination already exists if Path(dest_path).exists(): @@ -565,11 +735,11 @@ def copy_dataset( if untar_source and source_path.startswith("gs://"): actual_path = _copy_from_gcs_tar(source_path, dest_path) elif source_path.startswith("gs://"): - actual_path = _copy_from_gcs(source_path, dest_path, source_groups) + actual_path = _copy_from_gcs(source_path, dest_path, source_groups, source_tags) elif source_path.startswith("/weka") or source_path.startswith("/"): - actual_path = _copy_local(source_path, dest_path, source_groups) + actual_path = _copy_local(source_path, dest_path, source_groups, source_tags) else: - actual_path = _copy_generic(source_path, dest_path, source_groups) + actual_path = _copy_generic(source_path, dest_path, source_groups, source_tags) logger.info(f" Dataset copy complete: {actual_path}") return actual_path @@ -892,6 +1062,7 @@ def ingest_dataset(config: IngestConfig) -> EvalDatasetEntry: config.source_path, config.name, config.source_groups, + config.source_tags, config.untar_source, ) logger.info(f"[Step 1/6] Copy complete: {weka_path}")