Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions olmoearth_pretrain/evals/studio_ingest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ def main() -> int:
Returns:
Exit code (0 for success, non-zero for failure)
"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-5s %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

parser = argparse.ArgumentParser(
prog="studio_ingest",
description="Ingest Studio datasets into OlmoEarth eval system",
Expand Down
201 changes: 186 additions & 15 deletions olmoearth_pretrain/evals/studio_ingest/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import yaml
from rslearn.config import DatasetConfig
from rslearn.dataset.dataset import Dataset as RslearnDataset
from tqdm import tqdm
from upath import UPath

from olmoearth_pretrain.evals.datasets.rslearn_builder import parse_model_config
Expand Down Expand Up @@ -278,20 +279,31 @@ def _copy_from_gcs(
source_path: str,
dest_path: str,
source_groups: list[str] | None = None,
source_tags: dict[str, str] | None = None,
) -> str:
"""Copy dataset from GCS using gsutil with parallel transfers.

Uses gsutil -m for multi-threaded/multi-processing transfers.
Streams output directly to console for progress visibility.

Note: *source_tags* filtering is not supported for GCS sources.
If tags are specified a ``NotImplementedError`` is raised — download
the dataset locally first or use a local source.

Args:
source_path: GCS path (gs://bucket/path)
dest_path: Local destination path
source_groups: If specified, only copy these groups (subdirs under windows/)
source_tags: Not supported for GCS (raises NotImplementedError).

Returns:
Destination path
"""
if source_tags:
raise NotImplementedError(
"Tag-filtered copy is not supported for GCS sources. "
"Download the dataset locally first, then ingest from a local path."
)
logger.info(" Copy method: gsutil (parallel GCS transfer)")

# Create destination directory
Expand Down Expand Up @@ -375,10 +387,128 @@ def _tar_copy_cmd(src: str, dst: str, use_pv: bool) -> str:
return f"tar cf - -C {src} . | tar xf - -C {dst}"


def _window_matches_tags(
window_metadata_path: Path,
source_tags: dict[str, str],
) -> bool:
"""Check whether a window's metadata.json matches all required tags.

Args:
window_metadata_path: Path to the window's metadata.json
source_tags: Tags to match. Empty string value means "key exists".

Returns:
True if all tags match.
"""
try:
with open(window_metadata_path) as f:
meta = json.load(f)
except (json.JSONDecodeError, OSError):
return False

options = meta.get("options", {})
for key, value in source_tags.items():
if key not in options:
return False
if value and options[key] != value:
return False
return True


def _collect_matching_windows(
    source_path: str,
    source_groups: list[str] | None,
    source_tags: dict[str, str],
) -> list[tuple[str, str]]:
    """Scan source windows and return (group, window_name) pairs matching tags.

    Args:
        source_path: Path to rslearn dataset
        source_groups: If set, only scan these groups
        source_tags: Tags each window must have

    Returns:
        List of (group_name, window_name) tuples that match.
    """
    windows_root = Path(source_path) / "windows"
    if not windows_root.exists():
        return []

    # Fall back to every group subdirectory when no explicit list is given.
    if source_groups:
        group_names = source_groups
    else:
        group_names = [
            entry.name for entry in windows_root.iterdir() if entry.is_dir()
        ]
    logger.info(" Scanning groups: %s", group_names)

    candidates: list[tuple[str, Path]] = []
    for group_name in group_names:
        group_path = windows_root / group_name
        if not group_path.is_dir():
            continue
        candidates.extend(
            (group_name, child) for child in group_path.iterdir() if child.is_dir()
        )

    matched: list[tuple[str, str]] = []
    progress = tqdm(candidates, desc="Scanning windows for tags", unit="win")
    for group_name, window_path in progress:
        metadata_file = window_path / "metadata.json"
        if metadata_file.exists() and _window_matches_tags(metadata_file, source_tags):
            matched.append((group_name, window_path.name))
        progress.set_postfix(matched=len(matched))
    progress.close()

    logger.info(
        " Tag scan complete: %d/%d windows matched tags %s",
        len(matched),
        len(candidates),
        source_tags,
    )
    return matched


def _copy_filtered_windows(
    source_path: str,
    dest_path: str,
    matched_windows: list[tuple[str, str]],
) -> None:
    """Copy only the matched windows from source to destination.

    Uses shutil.copytree per window for simplicity and correctness on Weka.
    Copies run in a thread pool; worker count is tunable via the
    OLMOEARTH_INGEST_WORKERS environment variable (default 8).

    Args:
        source_path: Source dataset path
        dest_path: Destination dataset path
        matched_windows: List of (group, window_name) to copy

    Raises:
        Exception: Re-raises the first error from any per-window copy.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    num_workers = int(os.environ.get("OLMOEARTH_INGEST_WORKERS", "8"))
    total = len(matched_windows)
    logger.info(" Copying %d matched windows (workers=%d)...", total, num_workers)

    def _copy_one(group: str, wname: str) -> str:
        # Copy one window directory tree; parent dirs are created idempotently.
        src = Path(source_path) / "windows" / group / wname
        dst = Path(dest_path) / "windows" / group / wname
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(str(src), str(dst))
        return wname

    # Use context managers so the progress bar and the pool are always
    # cleaned up — previously a failed copy left the tqdm bar open because
    # pbar.close() was only reached on full success.
    with tqdm(total=total, desc="Copying windows", unit="win") as pbar:
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            futures = [
                pool.submit(_copy_one, group, wname)
                for group, wname in matched_windows
            ]
            for future in as_completed(futures):
                future.result()  # Propagate the first copy failure.
                pbar.update(1)

    logger.info(" Finished copying %d windows", total)


def _copy_local(
source_path: str,
dest_path: str,
source_groups: list[str] | None = None,
source_tags: dict[str, str] | None = None,
) -> str:
"""Copy dataset locally using streaming tar pipe.

Expand All @@ -388,10 +518,15 @@ def _copy_local(
is preserved because tar archives relative paths from the source and
recreates them at the destination.

When *source_tags* is provided the bulk tar copy is replaced by a
per-window copy that only transfers windows whose ``metadata.json``
matches the requested tags.

Args:
source_path: Local source path
dest_path: Local destination path
source_groups: If specified, only copy these groups (subdirs under windows/)
source_tags: If specified, only copy windows matching these tags.

Returns:
Destination path
Expand All @@ -409,17 +544,23 @@ def _copy_local(
# Create destination directory
Path(dest_path).mkdir(parents=True, exist_ok=True)

# TODO: remove pv progress bar once copy performance is validated
has_pv = shutil.which("pv") is not None

logger.info(
" Copy method: streaming tar pipe%s", " (with pv progress)" if has_pv else ""
)

_try_copy_config_json(source_path, dest_path)

if source_groups:
# Copy only specified groups under windows/
if source_tags:
logger.info(" Copy method: tag-filtered per-window copy")
matched = _collect_matching_windows(source_path, source_groups, source_tags)
if not matched:
raise ValueError(
f"No windows in {source_path} matched tags {source_tags}. "
"Check that the tag key/values are correct."
)
_copy_filtered_windows(source_path, dest_path, matched)
elif source_groups:
has_pv = shutil.which("pv") is not None
logger.info(
" Copy method: streaming tar pipe%s",
" (with pv progress)" if has_pv else "",
)
logger.info(f" Copying only groups: {source_groups}")
for group in source_groups:
group_src = f"{source_path}/windows/{group}"
Expand All @@ -431,7 +572,11 @@ def _copy_local(
subprocess.run(cmd, shell=True, check=True) # nosec B602
logger.info(f" Copied group '{group}'")
else:
# Copy entire directory using streaming tar
has_pv = shutil.which("pv") is not None
logger.info(
" Copy method: streaming tar pipe%s",
" (with pv progress)" if has_pv else "",
)
cmd = _tar_copy_cmd(source_path, dest_path, has_pv)
logger.info(f" Running: {cmd}")
subprocess.run(cmd, shell=True, check=True) # nosec B602
Expand All @@ -444,6 +589,7 @@ def _copy_generic(
source_path: str,
dest_path: str,
source_groups: list[str] | None = None,
source_tags: dict[str, str] | None = None,
) -> str:
"""Fallback copy using UPath for unknown storage backends.

Expand All @@ -453,6 +599,7 @@ def _copy_generic(
source_path: Source path (any UPath-compatible)
dest_path: Destination path
source_groups: If specified, only copy these groups (subdirs under windows/)
source_tags: If specified, only copy windows matching these tags.

Returns:
Destination path
Expand All @@ -466,6 +613,20 @@ def _copy_generic(

_try_copy_config_json(source_path, dest_path)

# Tag-filtered copy: only works when source is local-like (metadata readable)
if source_tags:
logger.info(" Using tag-filtered copy (generic)")
matched = _collect_matching_windows(source_path, source_groups, source_tags)
if not matched:
raise ValueError(f"No windows in {source_path} matched tags {source_tags}.")
for group, wname in matched:
_copy_directory_recursive(
source / "windows" / group / wname,
dest / "windows" / group / wname,
)
logger.info(" Copied %d matched windows", len(matched))
return dest_path

# Copy windows directory (filtered by groups if specified)
windows_src = source / "windows"
windows_dst = dest / "windows"
Expand Down Expand Up @@ -524,6 +685,7 @@ def copy_dataset(
source_path: str,
name: str,
source_groups: list[str] | None = None,
source_tags: dict[str, str] | None = None,
untar_source: bool = False,
) -> str:
"""Copy an rslearn dataset to our Weka location.
Expand All @@ -534,11 +696,17 @@ def copy_dataset(
- Local/Weka (/weka, /) -> find + xargs -P (parallel local copy)
- Other -> UPath generic copy (fallback)

When *source_tags* is provided, the copy is filtered so that only
windows whose ``metadata.json`` contains the requested tag key/values
are transferred. This avoids copying entire large datasets when only a
subset is needed for evaluation.

Args:
source_path: Path to source rslearn dataset
name: Name for the copied dataset
source_groups: If specified, only copy these groups (subdirs under windows/).
If None, copies everything.
source_tags: If specified, only copy windows matching these tags.
untar_source: If True, source_path is a .tar.gz archive on GCS that
will be streamed and extracted directly to the destination.

Expand All @@ -550,10 +718,12 @@ def copy_dataset(
logger.info("=== Dataset Copy ===")
logger.info(f" Source: {source_path}")
logger.info(f" Destination: {dest_path}")
if source_tags:
logger.info(f" Filtering to tags: {source_tags}")
if source_groups:
logger.info(f" Filtering to groups: {source_groups}")
else:
logger.info(" Copying all groups")
if not source_groups and not source_tags:
logger.info(" Copying all groups (no tag/group filter)")

# Check if destination already exists
if Path(dest_path).exists():
Expand All @@ -565,11 +735,11 @@ def copy_dataset(
if untar_source and source_path.startswith("gs://"):
actual_path = _copy_from_gcs_tar(source_path, dest_path)
elif source_path.startswith("gs://"):
actual_path = _copy_from_gcs(source_path, dest_path, source_groups)
actual_path = _copy_from_gcs(source_path, dest_path, source_groups, source_tags)
elif source_path.startswith("/weka") or source_path.startswith("/"):
actual_path = _copy_local(source_path, dest_path, source_groups)
actual_path = _copy_local(source_path, dest_path, source_groups, source_tags)
else:
actual_path = _copy_generic(source_path, dest_path, source_groups)
actual_path = _copy_generic(source_path, dest_path, source_groups, source_tags)

logger.info(f" Dataset copy complete: {actual_path}")
return actual_path
Expand Down Expand Up @@ -892,6 +1062,7 @@ def ingest_dataset(config: IngestConfig) -> EvalDatasetEntry:
config.source_path,
config.name,
config.source_groups,
config.source_tags,
config.untar_source,
)
logger.info(f"[Step 1/6] Copy complete: {weka_path}")
Expand Down
Loading