diff --git a/.gitignore b/.gitignore index 08eddd1397..6880c132bb 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,7 @@ perf_bench_data/ /data_juicer/ops/deduplicator/minhash.cpython-* /data_juicer/ops/deduplicator/tokenize.c /data_juicer/ops/deduplicator/tokenize.cpython-* + +# claude +.claude/ +CLAUDE.md diff --git a/data_juicer/config/__init__.py b/data_juicer/config/__init__.py index 8b62b0d832..02fc413268 100644 --- a/data_juicer/config/__init__.py +++ b/data_juicer/config/__init__.py @@ -6,7 +6,10 @@ merge_config, prepare_cfgs_for_export, prepare_side_configs, + resolve_job_directories, + resolve_job_id, update_op_attr, + validate_work_dir_config, ) __all__ = [ @@ -18,4 +21,7 @@ "get_default_cfg", "prepare_cfgs_for_export", "update_op_attr", + "validate_work_dir_config", + "resolve_job_id", + "resolve_job_directories", ] diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 449f4e0bba..5f57de4d30 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -7,8 +7,10 @@ import sys import tempfile import time +import uuid from argparse import ArgumentError from contextlib import contextmanager +from datetime import datetime from typing import Dict, List, Optional, Union import yaml @@ -174,8 +176,8 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l "--executor_type", type=str, default="default", - choices=["default", "ray"], - help='Type of executor, support "default" or "ray" for now.', + choices=["default", "ray", "ray_partitioned"], + help='Type of executor, support "default", "ray", or "ray_partitioned".', ) parser.add_argument( "--dataset_path", @@ -419,6 +421,72 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l "checkpoint are changed, all ops will be rerun from the " "beginning.", ) + # Enhanced checkpoint configuration for PartitionedRayExecutor + parser.add_argument( + "--checkpoint.enabled", + type=bool, + default=True, + help="Enable enhanced checkpointing for PartitionedRayExecutor", + ) + parser.add_argument( + "--checkpoint.strategy", + type=str, + default="every_op", + choices=["every_op", "every_partition", "every_n_ops", "manual", "disabled"], + help="Checkpoint strategy: every_op, every_partition, every_n_ops, manual, disabled", + ) + parser.add_argument( + "--checkpoint.n_ops", + type=int, + default=1, + help="Number of operations between checkpoints for every_n_ops strategy", + ) + parser.add_argument( + "--checkpoint.op_names", + type=List[str], + default=[], + help="List of operation names to checkpoint for manual strategy", + ) + # Event logging configuration + parser.add_argument( + "--event_logging.enabled", + type=bool, + default=True, + help="Enable event logging for job tracking and resumption", + ) + # Logging configuration + parser.add_argument( + "--max_log_size_mb", + type=int, + default=100, + help="Maximum log file size in MB before rotation", + ) + parser.add_argument( + "--backup_count", + type=int, + default=5, + help="Number of backup log files to keep", + ) + # Storage configuration + parser.add_argument( + "--event_log_dir", + type=str, + default=None, + help="Separate directory for event logs (fast storage)", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="Separate directory for checkpoints (large storage)", + ) + # Job management + parser.add_argument( + "--job_id", + type=str, + default=None, + help="Custom job ID for resumption and tracking. 
If not provided, a unique ID will be auto-generated.", + ) parser.add_argument( "--temp_dir", type=str, @@ -532,6 +600,123 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l help="Whether to save all stats to only one file. Only used in " "Analysis.", ) parser.add_argument("--ray_address", type=str, default="auto", help="The address of the Ray cluster.") + + # Partitioning configuration for PartitionedRayExecutor + # Support both flat and nested partition configuration + parser.add_argument( + "--partition_size", + type=int, + default=10000, + help="Number of samples per partition for PartitionedRayExecutor (legacy flat config)", + ) + parser.add_argument( + "--max_partition_size_mb", + type=int, + default=128, + help="Maximum partition size in MB for PartitionedRayExecutor (legacy flat config)", + ) + + parser.add_argument( + "--preserve_intermediate_data", + type=bool, + default=False, + help="Preserve intermediate data for debugging (legacy flat config)", + ) + + # partition configuration + parser.add_argument( + "--partition.mode", + type=str, + default="auto", + choices=["manual", "auto"], + help="Partition mode: manual (specify num_of_partitions) or auto (use partition size optimizer)", + ) + parser.add_argument( + "--partition.num_of_partitions", + type=int, + default=4, + help="Number of partitions for manual mode (ignored in auto mode)", + ) + parser.add_argument( + "--partition.target_size_mb", + type=int, + default=256, + help="Target partition size in MB for auto mode (128, 256, 512, or 1024). " + "Controls how large each partition should be. Smaller = more checkpoints & better recovery, " + "larger = less overhead. Default 256MB balances memory safety and efficiency.", + ) + + # Resource optimization configuration + parser.add_argument( + "--resource_optimization.auto_configure", + type=bool, + default=False, + help="Enable automatic optimization of partition size, worker count, and other resource-dependent settings (nested resource_optimization config)", + ) + + # Intermediate storage configuration + parser.add_argument( + "--intermediate_storage.preserve_intermediate_data", + type=bool, + default=False, + help="Preserve intermediate data for debugging (nested intermediate_storage config)", + ) + parser.add_argument( + "--intermediate_storage.cleanup_temp_files", + type=bool, + default=True, + help="Clean up temporary files after processing (nested intermediate_storage config)", + ) + parser.add_argument( + "--intermediate_storage.cleanup_on_success", + type=bool, + default=False, + help="Clean up intermediate files even on successful completion (nested intermediate_storage config)", + ) + parser.add_argument( + "--intermediate_storage.retention_policy", + type=str, + default="keep_all", + choices=["keep_all", "keep_failed_only", "cleanup_all"], + help="File retention policy (nested intermediate_storage config)", + ) + parser.add_argument( + "--intermediate_storage.max_retention_days", + type=int, + default=7, + help="Maximum retention days for files (nested intermediate_storage config)", + ) + + # Intermediate storage format configuration + parser.add_argument( + "--intermediate_storage.format", + type=str, + default="parquet", + choices=["parquet", "arrow", "jsonl"], + help="Storage format for checkpoints and intermediate data (nested intermediate_storage config)", + ) + parser.add_argument( + "--intermediate_storage.compression", + type=str, + default="snappy", + choices=["snappy", "gzip", "none"], + help="Compression format for storage 
files (nested intermediate_storage config)", + ) + + parser.add_argument( + "--intermediate_storage.write_partitions", + type=bool, + default=True, + help="Whether to write intermediate partition files to disk (nested intermediate_storage config). Set to false for better performance when intermediate files aren't needed.", + ) + + parser.add_argument( + "--partition_dir", + type=str, + default=None, + help="Directory to store partition files. Supports {work_dir} placeholder. If not set, defaults to {work_dir}/partitions.", + ) + parser.add_argument( "--custom-operator-paths", nargs="+", help="Paths to custom operator scripts or directories." ) @@ -607,6 +792,16 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l with timing_context("Updating operator process"): cfg = update_op_process(cfg, parser, used_ops) + # Validate config for resumption if job_id is provided + if not load_configs_only and hasattr(cfg, "job_id") and cfg.job_id: + # Check if this is a resumption attempt by looking for existing job directory + if cfg.work_dir and os.path.exists(cfg.work_dir): + logger.info(f"🔍 Checking for job resumption: {cfg.job_id}") + cfg._same_yaml_config = validate_config_for_resumption(cfg, cfg.work_dir, args) + else: + # New job, set flag to True + cfg._same_yaml_config = True + # copy the config file into the work directory if not load_configs_only: config_backup(cfg) @@ -619,7 +814,7 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l global_cfg = cfg global_parser = parser - if cfg.debug: + if cfg.get("debug", False): logger.debug("In DEBUG mode.") return cfg @@ -647,7 +842,7 @@ def init_setup_from_cfg(cfg: Namespace, load_configs_only=False): """ Do some extra setup tasks after parsing config file or command line. - 1. create working directory and a log directory + 1. create working directory and logs directory 2. update cache directory 3. 
update checkpoint and `temp_dir` of tempfile @@ -670,6 +865,14 @@ def init_setup_from_cfg(cfg: Namespace, load_configs_only=False): if cfg.work_dir is None: cfg.work_dir = os.path.dirname(cfg.export_path) + cfg.export_path = os.path.abspath(cfg.export_path) + if cfg.work_dir is None: + cfg.work_dir = os.path.dirname(cfg.export_path) + + # Call resolve_job_directories to finalize all job-related paths + cfg = resolve_job_id(cfg) + cfg = resolve_job_directories(cfg) + timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) if not load_configs_only: # For S3 paths, use a simplified export path for log filename @@ -679,12 +882,13 @@ def init_setup_from_cfg(cfg: Namespace, load_configs_only=False): export_rel_path = s3_path_parts[1] if len(s3_path_parts) > 1 else s3_path_parts[0] else: export_rel_path = os.path.relpath(cfg.export_path, start=cfg.work_dir) - log_dir = os.path.join(cfg.work_dir, "log") - if not os.path.exists(log_dir): - os.makedirs(log_dir, exist_ok=True) + + # Ensure event_log_dir (logs/) exists - this is where logs are actually saved + if not os.path.exists(cfg.event_log_dir): + os.makedirs(cfg.event_log_dir, exist_ok=True) logfile_name = f"export_{export_rel_path}_time_{timestamp}.txt" setup_logger( - save_dir=log_dir, + save_dir=cfg.event_log_dir, filename=logfile_name, level="DEBUG" if cfg.get("debug", False) else "INFO", redirect=cfg.get("executor_type", "default") == "default", @@ -1003,15 +1207,293 @@ def namespace_to_arg_list(namespace, prefix="", includes=None, excludes=None): return arg_list +def save_cli_arguments(cfg: Namespace): + """Save CLI arguments to cli.yaml in the work directory.""" + if not hasattr(cfg, "work_dir") or not cfg.work_dir: + return + + # Get the original CLI arguments if available + original_args = getattr(cfg, "_original_args", None) + if not original_args: + # Try to reconstruct from sys.argv if available + import sys + + original_args = sys.argv[1:] if len(sys.argv) > 1 else [] + + if not original_args: + logger.warning("No CLI arguments available to save") + return + + # Create cli.yaml in work directory + cli_path = os.path.join(cfg.work_dir, "cli.yaml") + + # Convert args to a simple format + cli_data = {"arguments": original_args} + + # Save as YAML + import yaml + + with open(cli_path, "w") as f: + yaml.dump(cli_data, f, default_flow_style=False, indent=2) + + logger.info(f"💾 Saved CLI arguments to: {cli_path}") + + +def validate_config_for_resumption(cfg: Namespace, work_dir: str, original_args: List[str] = None) -> bool: + """Validate that the current config matches the job's saved config for safe resumption. + + Does verbatim comparison between: + 1. Original config.yaml + cli.yaml (saved during job creation) + 2. Current config (from current command) + + Sets cfg._same_yaml_config = True/False for the executor to use. + """ + try: + from pathlib import Path + + # Find the original config file in the work directory + config_files = list(Path(work_dir).glob("*.yaml")) + list(Path(work_dir).glob("*.yml")) + if not config_files: + logger.warning(f"No config file found in work directory: {work_dir}") + cfg._same_yaml_config = False + return False + + # Find the original config.yaml (not cli.yaml) + original_config_file = None + for config_file in config_files: + if config_file.name != "cli.yaml": + original_config_file = config_file + break + + if not original_config_file: + logger.warning(f"No original config file found in work directory: {work_dir}") + cfg._same_yaml_config = False + return False + + # 1. 
Direct file comparison for config files + current_config_file = cfg.config[0] if hasattr(cfg, "config") and cfg.config else None + if not current_config_file: + logger.error("No current config file found") + cfg._same_yaml_config = False + return False + + with open(original_config_file, "r") as f: + original_config_content = f.read() + with open(current_config_file, "r") as f: + current_config_content = f.read() + + config_match = original_config_content.strip() == current_config_content.strip() + + # 2. Per-key comparison for CLI arguments + cli_file = Path(work_dir) / "cli.yaml" + cli_config = {} + if cli_file.exists(): + with open(cli_file, "r") as f: + cli_data = yaml.safe_load(f) + cli_config = _parse_cli_to_config(cli_data.get("arguments", [])) + + # Get current CLI arguments from the original args passed to init_configs + current_cli_args = original_args + if not current_cli_args: + # Fallback: try to get from sys.argv + import sys + + current_cli_args = sys.argv[1:] if len(sys.argv) > 1 else [] + + current_cli_config = _parse_cli_to_config(current_cli_args) + + # Compare CLI arguments per key + cli_differences = [] + all_cli_keys = set(cli_config.keys()) | set(current_cli_config.keys()) + excluded_keys = {"config", "_original_args", "backed_up_config_path", "_same_yaml_config", "job_id", "work_dir"} + + for key in all_cli_keys: + if key in excluded_keys: + continue + + original_value = cli_config.get(key) + current_value = current_cli_config.get(key) + + if original_value != current_value: + cli_differences.append({"key": key, "original": original_value, "current": current_value}) + + cli_match = len(cli_differences) == 0 + + if not config_match or not cli_match: + logger.error("❌ Config validation failed - configurations don't match:") + if not config_match: + logger.error(" [config] Config file content differs") + if not cli_match: + logger.error(" [cli] CLI arguments differ:") + for diff in cli_differences: + logger.error(f" {diff['key']}: {diff['original']} → {diff['current']}") + logger.error("💡 Use the same config file and CLI arguments for resumption") + cfg._same_yaml_config = False + return False + + logger.info("✅ Config validation passed - configurations match exactly") + cfg._same_yaml_config = True + return True + + except Exception as e: + logger.error(f"Error validating config for resumption: {e}") + cfg._same_yaml_config = False + return False + + +def _parse_cli_to_config(cli_args: list) -> dict: + """ + Parse CLI arguments into config dictionary format using the global parser. 
+ + This ensures proper handling of: + - --key=value syntax + - Arguments with spaces + - Multiple values (nargs='+') + - Complex type conversions + + Args: + cli_args: List of CLI arguments to parse + + Returns: + Dictionary of parsed configuration values + """ + global global_parser + + if not cli_args: + return {} + + # If global_parser is available, use it for robust parsing + if global_parser: + try: + # For comparison purposes, we only care about override arguments, not the config file + # Filter out --config and --auto since they're handled separately + filtered_args = [] + i = 0 + while i < len(cli_args): + arg = cli_args[i] + if arg == "--config" or arg == "--auto": + # Skip --config/--auto and its value (if any) + if i + 1 < len(cli_args) and not cli_args[i + 1].startswith("--"): + i += 2 + else: + i += 1 + elif arg.startswith("--"): + # Keep other flags + filtered_args.append(arg) + i += 1 + elif filtered_args: + # Keep values that follow flags + filtered_args.append(arg) + i += 1 + else: + # Skip positional arguments (e.g., pytest test names) + i += 1 + + # If no override args, return empty dict + if not filtered_args: + return {} + + # Add --auto to satisfy the required argument (we'll filter it out later) + temp_cli_args = ["--auto"] + filtered_args + + # Use parse_known_args to handle unrecognized arguments gracefully + parsed_cfg, unknown = global_parser.parse_known_args(temp_cli_args) + # Convert to dict for comparison + config_dict = namespace_to_dict(parsed_cfg) + + # Remove arguments we don't want to compare + config_dict.pop("config", None) + config_dict.pop("auto", None) + + return config_dict + except (Exception, SystemExit) as e: + logger.debug(f"Failed to parse CLI args with global_parser: {e}. Falling back to manual parsing.") + + # Fallback to improved manual parsing if parser not available + config = {} + i = 0 + + while i < len(cli_args): + arg = cli_args[i] + + if arg.startswith("--"): + # Handle --key=value syntax + if "=" in arg: + key, value = arg[2:].split("=", 1) + config[key] = _parse_value(value) + i += 1 + else: + key = arg[2:] + + # Collect all values until next flag + values = [] + j = i + 1 + while j < len(cli_args) and not cli_args[j].startswith("--"): + values.append(cli_args[j]) + j += 1 + + if values: + # If multiple values, keep as list; otherwise, single value + if len(values) == 1: + config[key] = _parse_value(values[0]) + else: + config[key] = [_parse_value(v) for v in values] + i = j + else: + # Boolean flag (no value) + config[key] = True + i += 1 + else: + i += 1 + + return config + + +def _parse_value(value: str): + """Parse a string value to its appropriate type.""" + # Try to parse as different types + if value.lower() in ["true", "false"]: + return value.lower() == "true" + + try: + # Try int first + if "." 
not in value and "e" not in value.lower(): + return int(value) + except ValueError: + pass + + try: + # Try float + return float(value) + except ValueError: + pass + + # Return as string + return value + + def config_backup(cfg: Namespace): if not cfg.get("config", None): return cfg_path = os.path.abspath(cfg.config[0]) - work_dir = cfg.work_dir - target_path = os.path.join(work_dir, os.path.basename(cfg_path)) - logger.info(f"Back up the input config file [{cfg_path}] into the " f"work_dir [{work_dir}]") + + # Use the backed_up_config_path which should be set by resolve_job_directories + if hasattr(cfg, "backed_up_config_path"): + target_path = cfg.backed_up_config_path + else: + # Fallback: use work_dir with original filename + work_dir = cfg.work_dir + original_config_name = os.path.basename(cfg_path) + target_path = os.path.join(work_dir, original_config_name) + if not os.path.exists(target_path): + logger.info(f"Back up the input config file [{cfg_path}] to [{target_path}]") shutil.copyfile(cfg_path, target_path) + else: + logger.info(f"Config file [{cfg_path}] already exists at [{target_path}]") + + # Also save CLI arguments + save_cli_arguments(cfg) def display_config(cfg: Namespace): @@ -1173,6 +1655,24 @@ def get_init_configs(cfg: Union[Namespace, Dict], load_configs_only: bool = True temp_file = os.path.join(temp_dir, "job_dj_config.json") if isinstance(cfg, Namespace): cfg = namespace_to_dict(cfg) + + # Remove internal attributes that are not part of the configuration schema + # to avoid validation errors when re-initializing the config + if isinstance(cfg, dict): + cfg = cfg.copy() + # Remove internal attributes that are added during config processing + internal_attrs = [ + "_user_provided_job_id", + "_same_yaml_config", + "metadata_dir", + "results_dir", + "event_log_file", + "job_summary_file", + "backed_up_config_path", + ] + for attr in internal_attrs: + cfg.pop(attr, None) + # create a temp config file with open(temp_file, "w") as f: json.dump(prepare_cfgs_for_export(cfg), f) @@ -1215,3 +1715,116 @@ def prepare_cfgs_for_export(cfg): if op in cfg: _ = cfg.pop(op) return cfg + + +def resolve_job_id(cfg): + """Resolve or auto-generate job_id and set it on cfg.""" + job_id = getattr(cfg, "job_id", None) + + # Track whether job_id was user-provided + if job_id is not None: + # User explicitly provided a job_id + setattr(cfg, "_user_provided_job_id", True) + else: + # No job_id provided by user + setattr(cfg, "_user_provided_job_id", False) + timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + short_hash = uuid.uuid4().hex[:6] + job_id = f"{timestamp}_{short_hash}" + setattr(cfg, "job_id", job_id) + return cfg + + +def validate_work_dir_config(work_dir: str) -> None: + """ + Validate work_dir configuration to ensure {job_id} placement rules are followed. + + Args: + work_dir: The work_dir string to validate + + Raises: + ValueError: If {job_id} is not at the end of the path + """ + if "{job_id}" in work_dir: + # Check if {job_id} is at the end of the path + if not work_dir.rstrip("/").endswith("{job_id}"): + raise ValueError( + f"Invalid work_dir configuration: '{{job_id}}' must be the last part of the path. " + f"Current: '{work_dir}'. " + f"Expected format: 'path/to/directory/{{job_id}}'" + ) + + +def resolve_job_directories(cfg): + """ + Centralize directory resolution and placeholder substitution. Assumes job_id is already set. 
+ + Job Directory Rules: + - If work_dir contains '{job_id}' placeholder, it MUST be the last part of the path + - Examples: + ✅ work_dir: "./outputs/my_project/{job_id}" # Valid + ✅ work_dir: "/data/experiments/{job_id}" # Valid + ❌ work_dir: "./outputs/{job_id}/results" # Invalid - {job_id} not at end + ❌ work_dir: "./{job_id}/outputs/data" # Invalid - {job_id} not at end + + - If work_dir does NOT contain '{job_id}', job_id will be appended automatically + - Examples: + work_dir: "./outputs/my_project" → work_dir: "./outputs/my_project/20250804_143022_abc123" + + After resolution, work_dir will always include job_id at the end. + """ + # 1. placeholder map + placeholder_map = {"work_dir": cfg.work_dir, "job_id": getattr(cfg, "job_id", "")} + + # 2. Validate {job_id} placement in work_dir before substitution + original_work_dir = cfg.work_dir + validate_work_dir_config(original_work_dir) + + # 3. substitute placeholders in all relevant paths (change-detection loop) + max_passes = 10 + for _ in range(max_passes): + changed = False + for key in ["work_dir", "event_log_dir", "checkpoint_dir", "export_path", "dataset_path", "partition_dir"]: + val = getattr(cfg, key, None) + if isinstance(val, str): + new_val = val.format(**placeholder_map) + if new_val != val: + setattr(cfg, key, new_val) + changed = True + # update placeholder_map in case work_dir or job_id changed + placeholder_map = {"work_dir": cfg.work_dir, "job_id": getattr(cfg, "job_id", "")} + if not changed: + break + else: + raise RuntimeError("Too many placeholder substitution passes (possible recursive placeholders?)") + + # 4. directory resolution + job_id = getattr(cfg, "job_id", None) + if not job_id: + raise ValueError("job_id must be set before resolving job directories.") + + # Ensure work_dir always includes job_id at the end + # If work_dir already ends with job_id (from placeholder substitution), keep it as-is + # Otherwise, append job_id automatically + if not (cfg.work_dir.endswith(job_id) or os.path.basename(cfg.work_dir) == job_id): + cfg.work_dir = os.path.join(cfg.work_dir, job_id) + + # All job-specific directories are under work_dir + if getattr(cfg, "event_log_dir", None) is None: + cfg.event_log_dir = os.path.join(cfg.work_dir, "logs") + if getattr(cfg, "checkpoint_dir", None) is None: + cfg.checkpoint_dir = os.path.join(cfg.work_dir, "checkpoints") + if getattr(cfg, "partition_dir", None) is None: + cfg.partition_dir = os.path.join(cfg.work_dir, "partitions") + cfg.metadata_dir = os.path.join(cfg.work_dir, "metadata") + cfg.results_dir = os.path.join(cfg.work_dir, "results") + cfg.event_log_file = os.path.join(cfg.work_dir, "events.jsonl") + cfg.job_summary_file = os.path.join(cfg.work_dir, "job_summary.json") + # Set backed_up_config_path using original config filename + if hasattr(cfg, "config") and cfg.config: + original_config_name = os.path.basename(cfg.config[0]) + cfg.backed_up_config_path = os.path.join(cfg.work_dir, original_config_name) + else: + cfg.backed_up_config_path = os.path.join(cfg.work_dir, "config.yaml") + + return cfg diff --git a/data_juicer/config/config_all.yaml b/data_juicer/config/config_all.yaml index a07e5bdecc..03fe725662 100644 --- a/data_juicer/config/config_all.yaml +++ b/data_juicer/config/config_all.yaml @@ -70,6 +70,12 @@ eoc_special_token: '<|__dj__eoc|>' # the special token executor_type: default # type of executor, support "default" or "ray" for now. ray_address: auto # the address of the Ray cluster. 
+# partition configuration (for ray_partitioned executor) +partition: + mode: auto # partition mode: "auto" (use optimizer) or "manual" (specify count) + num_of_partitions: 4 # number of partitions for manual mode + target_size_mb: 256 # target partition size in MB for auto mode (128, 256, 512, or 1024). 256MB balances memory safety and efficiency. + # only for data analysis percentiles: [0.25, 0.5, 0.75] # percentiles to analyze the dataset distribution export_original_dataset: false # whether to export the original dataset with stats. If you only need the stats of the dataset, setting it to false could speed up the exporting. diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py index 7261b3419c..8d1207ec72 100644 --- a/data_juicer/core/__init__.py +++ b/data_juicer/core/__init__.py @@ -1,7 +1,13 @@ from .adapter import Adapter from .analyzer import Analyzer from .data import NestedDataset -from .executor import DefaultExecutor, ExecutorBase, ExecutorFactory +from .executor import ( + DefaultExecutor, + ExecutorBase, + ExecutorFactory, + PartitionedRayExecutor, + RayExecutor, +) from .exporter import Exporter from .monitor import Monitor from .ray_exporter import RayExporter @@ -14,6 +20,8 @@ "ExecutorBase", "ExecutorFactory", "DefaultExecutor", + "RayExecutor", + "PartitionedRayExecutor", "Exporter", "RayExporter", "Monitor", diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 2e7cb55e24..2b3038ad6b 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -163,9 +163,29 @@ def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) - if self._auto_proc: calculate_ray_np(operators) + # Check if dataset is empty - Ray returns None for columns() on empty datasets + # with unknown schema. If empty, skip processing as there's nothing to process. 
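+        # Note: count() may trigger execution/materialization for sources without
+        # row-count metadata; that cost is assumed acceptable here since it only
+        # guards the empty-dataset edge case before any operators run.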
+ try: + row_count = self.data.count() + except Exception: + row_count = 0 + + if row_count == 0: + from loguru import logger + + logger.warning("Dataset is empty (0 rows), skipping operator processing") + return self + # Cache columns once at start to avoid breaking pipeline with repeated columns() calls # Ray's columns() internally does limit(1) which forces execution and breaks streaming - cached_columns = set(self.data.columns()) + columns_result = self.data.columns() + # Handle empty dataset case where columns() returns None + if columns_result is None: + from loguru import logger + + logger.warning("Dataset has unknown schema (likely empty), skipping operator processing") + return self + cached_columns = set(columns_result) for op in operators: cached_columns = self._run_single_op(op, cached_columns) diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index 501d421834..5073c6760f 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -1,5 +1,7 @@ from .base import ExecutorBase from .default_executor import DefaultExecutor from .factory import ExecutorFactory +from .ray_executor import RayExecutor +from .ray_executor_partitioned import PartitionedRayExecutor -__all__ = ["ExecutorBase", "ExecutorFactory", "DefaultExecutor"] +__all__ = ["ExecutorBase", "ExecutorFactory", "DefaultExecutor", "RayExecutor", "PartitionedRayExecutor"] diff --git a/data_juicer/core/executor/dag_execution_mixin.py b/data_juicer/core/executor/dag_execution_mixin.py new file mode 100644 index 0000000000..b4a5b8a698 --- /dev/null +++ b/data_juicer/core/executor/dag_execution_mixin.py @@ -0,0 +1,863 @@ +""" +DAG Execution Mixin for Data-Juicer Executors + +This mixin provides DAG execution planning and monitoring that can be integrated +into existing executors to provide intelligent pipeline analysis and execution tracking. +""" + +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +from loguru import logger + +from data_juicer.core.executor.dag_execution_strategies import ( + DAGExecutionStrategy, + NonPartitionedDAGStrategy, + PartitionedDAGStrategy, + is_global_operation, +) +from data_juicer.core.executor.event_logging_mixin import EventType +from data_juicer.core.pipeline_dag import DAGNodeStatus, PipelineDAG + + +class DAGExecutionMixin: + """ + Mixin that provides DAG-based execution planning and monitoring. + + This mixin can be integrated into any executor to provide: + - DAG execution planning + - Execution monitoring tied to DAG nodes + - Event logging with DAG context + """ + + def __init__(self): + """Initialize the DAG execution mixin.""" + self.pipeline_dag: Optional[PipelineDAG] = None + self.dag_initialized = False + self.current_dag_node: Optional[str] = None + self.dag_execution_start_time: Optional[float] = None + self.dag_execution_strategy: Optional[DAGExecutionStrategy] = None + + def _initialize_dag_execution(self, cfg) -> None: + """Initialize DAG execution planning with appropriate strategy. + + Note: For standalone mode (default executor), DAG execution can be disabled + by setting cfg.use_dag = False. DAG execution is primarily useful for + distributed/partitioned executors where execution planning and monitoring + provide significant value. 
+ """ + if self.dag_initialized: + return + + # Check if DAG execution is enabled (default: True for distributed executors, False for standalone) + use_dag = getattr(cfg, "use_dag", None) + if use_dag is None: + # Default: enable for partitioned executors, disable for standalone (default executor) + use_dag = self._is_partitioned_executor() or getattr(self, "executor_type", "default") != "default" + + if not use_dag: + logger.info("DAG execution disabled for standalone mode") + self.dag_initialized = True # Mark as initialized to skip future attempts + return + + logger.info("Initializing DAG execution planning...") + + # Determine execution strategy based on executor type + self.dag_execution_strategy = self._create_execution_strategy(cfg) + + # Generate DAG using strategy + self._generate_dag_with_strategy(cfg) + + self.dag_initialized = True + self.dag_execution_start_time = time.time() + + logger.info( + f"DAG execution planning initialized: {len(self.pipeline_dag.nodes)} nodes, {len(self.pipeline_dag.edges)} edges" + ) + + def _create_execution_strategy(self, cfg) -> DAGExecutionStrategy: + """Create the appropriate execution strategy based on executor type.""" + if self._is_partitioned_executor(): + return self._create_partitioned_strategy(cfg) + else: + return self._create_non_partitioned_strategy(cfg) + + def _is_partitioned_executor(self) -> bool: + """Determine if this is a partitioned executor.""" + return getattr(self, "executor_type", None) == "ray_partitioned" + + def _create_partitioned_strategy(self, cfg) -> DAGExecutionStrategy: + """Create partitioned execution strategy.""" + # Partition count should be determined by the executor, not the DAG mixin + # Get it from the executor's attribute if available, otherwise use a default + num_partitions = getattr(self, "num_partitions", None) + if num_partitions is None: + # Last resort: use a default (shouldn't happen in practice) + logger.error("Partition count not found in executor") + raise ValueError("Partition count not found in executor") + + return PartitionedDAGStrategy(num_partitions) + + def _create_non_partitioned_strategy(self, cfg) -> DAGExecutionStrategy: + """Create non-partitioned execution strategy.""" + return NonPartitionedDAGStrategy() + + def _generate_dag_with_strategy(self, cfg) -> None: + """Generate DAG using the selected strategy.""" + # Get operations directly from config + operations = self._get_operations_from_config(cfg) + + # Get strategy-specific parameters + strategy_kwargs = self._get_strategy_kwargs(cfg) + + # Generate nodes using strategy + nodes = self.dag_execution_strategy.generate_dag_nodes(operations, **strategy_kwargs) + + # Build dependencies using strategy + self.dag_execution_strategy.build_dependencies(nodes, operations, **strategy_kwargs) + + # Create PipelineDAG instance + self.pipeline_dag = PipelineDAG(cfg.work_dir) + self.pipeline_dag.nodes = nodes + + # Log DAG initialization + if log_method := getattr(self, "log_dag_build_start", None): + ast_info = { + "config_source": "process_config", + "build_start_time": time.time(), + "node_count": len(operations), + "depth": len(operations), # AST is linear, so depth equals number of operations + "operation_types": self._extract_operation_types_from_ops(operations), + } + log_method(ast_info) + + if log_method := getattr(self, "log_dag_build_complete", None): + dag_info = { + "node_count": len(self.pipeline_dag.nodes), + "edge_count": len(self.pipeline_dag.edges), + "parallel_groups_count": len(self.pipeline_dag.parallel_groups), + 
"execution_plan_length": len(self.pipeline_dag.execution_plan), + "build_duration": time.time() - (self.dag_execution_start_time or time.time()), + } + log_method(dag_info) + + # Save execution plan + if self.pipeline_dag: + plan_path = self.pipeline_dag.save_execution_plan() + if log_method := getattr(self, "log_dag_execution_plan_saved", None): + dag_info = { + "node_count": len(self.pipeline_dag.nodes), + "edge_count": len(self.pipeline_dag.edges), + "parallel_groups_count": len(self.pipeline_dag.parallel_groups), + } + log_method(plan_path, dag_info) + + def _get_operations_from_config(self, cfg) -> List: + """Get operations from configuration - can be overridden by executors.""" + # Default implementation - create operation instances + operations = [] + for op_config in cfg.process: + op_name = list(op_config.keys())[0] + op_args = op_config[op_name] or {} + + # Import and instantiate operation + from data_juicer.ops import OPERATORS + + try: + op_class = OPERATORS.modules[op_name] + operation = op_class(**op_args) + operations.append(operation) + except KeyError: + # If operation not found, create a mock operation for DAG planning + logger.warning(f"Operation {op_name} not found in OPERATORS registry, creating mock for DAG planning") + + class MockOperation: + def __init__(self, name, **kwargs): + self._name = name + self.config = kwargs + + operation = MockOperation(op_name, **op_args) + operations.append(operation) + + return operations + + def _get_strategy_kwargs(self, cfg) -> Dict[str, Any]: + """Get strategy-specific parameters - can be overridden by executors.""" + kwargs = {} + + if self._is_partitioned_executor(): + kwargs["convergence_points"] = self._detect_convergence_points(cfg) + + return kwargs + + def _detect_convergence_points(self, cfg) -> List[int]: + """Detect convergence points - can be overridden by executors.""" + operations = self._get_operations_from_config(cfg) + convergence_points = [] + + for op_idx, op in enumerate(operations): + # Detect global operations (deduplicators, etc.) 
+ if is_global_operation(op): + convergence_points.append(op_idx) + + # Detect manual convergence points + if getattr(op, "converge_after", False): + convergence_points.append(op_idx) + + return convergence_points + + def _get_dag_node_for_operation(self, op_name: str, op_idx: int, **kwargs) -> Optional[str]: + """Get the DAG node ID for a given operation using strategy.""" + if not self.dag_execution_strategy: + return None + + return self.dag_execution_strategy.get_dag_node_id(op_name, op_idx, **kwargs) + + def _mark_dag_node_started(self, node_id: str) -> None: + """Mark a DAG node as started.""" + if not self.pipeline_dag or node_id not in self.pipeline_dag.nodes: + return + + node = self.pipeline_dag.nodes[node_id] + self.pipeline_dag.mark_node_started(node_id) + self.current_dag_node = node_id + + # Log DAG node start + if log_method := getattr(self, "log_dag_node_start", None): + node_info = { + "op_name": node.get("op_name") or node.get("operation_name", ""), + "op_type": node.get("op_type") or node.get("node_type", "operation"), + "execution_order": node.get("execution_order", 0), + } + log_method(node_id, node_info) + + def _mark_dag_node_completed(self, node_id: str, duration: float = None) -> None: + """Mark a DAG node as completed.""" + if not self.pipeline_dag or node_id not in self.pipeline_dag.nodes: + return + + node = self.pipeline_dag.nodes[node_id] + self.pipeline_dag.mark_node_completed(node_id, duration) + + # Log DAG node completion + if log_method := getattr(self, "log_dag_node_complete", None): + node_info = { + "op_name": node.get("op_name") or node.get("operation_name", ""), + "op_type": node.get("op_type") or node.get("node_type", "operation"), + "execution_order": node.get("execution_order", 0), + } + log_method(node_id, node_info, duration or 0) + + self.current_dag_node = None + + def _mark_dag_node_failed(self, node_id: str, error_message: str, duration: float = 0) -> None: + """Mark a DAG node as failed.""" + if not self.pipeline_dag or node_id not in self.pipeline_dag.nodes: + return + + node = self.pipeline_dag.nodes[node_id] + self.pipeline_dag.mark_node_failed(node_id, error_message) + + # Log DAG node failure + if log_method := getattr(self, "log_dag_node_failed", None): + node_info = { + "op_name": node.get("op_name") or node.get("operation_name", ""), + "op_type": node.get("op_type") or node.get("node_type", "operation"), + "execution_order": node.get("execution_order", 0), + } + log_method(node_id, node_info, error_message, duration) + + self.current_dag_node = None + + def _log_operation_with_dag_context( + self, op_name: str, op_idx: int, event_type: str, partition_id: int = 0, **kwargs + ) -> None: + """Log an operation event with DAG context. 
+ + Args: + op_name: Operation name + op_idx: Operation index + event_type: Type of event ("op_start", "op_complete", "op_failed") + partition_id: Partition ID for partitioned executors (default: 0) + **kwargs: Additional arguments for logging + """ + # Get the corresponding DAG node + node_id = self._get_dag_node_for_operation(op_name, op_idx, partition_id=partition_id) + + # Add DAG node ID to metadata if found + if "metadata" not in kwargs: + kwargs["metadata"] = {} + + if node_id: + kwargs["metadata"]["dag_node_id"] = node_id + else: + # Log warning if DAG node not found + logger.warning(f"DAG node not found for operation {op_name} (idx {op_idx})") + + # Call the original logging method with correct parameters + if event_type == "op_start" and (log_method := getattr(self, "log_op_start", None)): + log_method(partition_id, op_name, op_idx, kwargs.get("metadata", {})) + elif event_type == "op_complete" and (log_method := getattr(self, "log_op_complete", None)): + log_method( + partition_id, + op_name, + op_idx, + kwargs.get("duration", 0), + kwargs.get("checkpoint_path"), + kwargs.get("input_rows", 0), + kwargs.get("output_rows", 0), + ) + elif event_type == "op_failed" and (log_method := getattr(self, "log_op_failed", None)): + log_method( + partition_id, op_name, op_idx, kwargs.get("error", "Unknown error"), kwargs.get("retry_count", 0) + ) + + def _pre_execute_operations_with_dag_monitoring(self, ops: List, partition_id: int = 0) -> None: + """Log operation start events with DAG monitoring before execution. + + This method should be called before dataset.process() to log operation start events. + Each executor can then call dataset.process() with its own specific parameters. + + Args: + ops: List of operations that will be executed + partition_id: Partition ID for partitioned executors (default: 0) + """ + if not self.pipeline_dag: + return + + # Log operation start events for all operations + for op_idx, op in enumerate(ops): + op_name = op._name + node_id = self._get_dag_node_for_operation(op_name, op_idx, partition_id=partition_id) + + if node_id: + # Mark DAG node as started + self._mark_dag_node_started(node_id) + + # Log operation start with DAG context + self._log_operation_with_dag_context(op_name, op_idx, "op_start", partition_id=partition_id) + else: + # Log operation start without DAG context + logger.warning(f"DAG node not found for operation {op_name}, logging without DAG context") + if log_method := getattr(self, "log_op_start", None): + log_method(partition_id, op_name, op_idx, {}) + + def _post_execute_operations_with_dag_monitoring( + self, ops: List, partition_id: int = 0, metrics: dict = None + ) -> None: + """Log operation completion events with DAG monitoring after execution. + + This method should be called after dataset.process() to log operation completion events. 
+ + Args: + ops: List of operations that were executed + partition_id: Partition ID for partitioned executors (default: 0) + metrics: Optional dict with real execution metrics: + { + 'duration': float, + 'input_rows': int, + 'output_rows': int, + 'per_op_metrics': List[dict] # Optional per-op breakdown + } + """ + if not self.pipeline_dag: + return + + # Default metrics if not provided + if metrics is None: + metrics = {"duration": 0.0, "input_rows": 0, "output_rows": 0} + + # Check if we have per-op metrics + per_op_metrics = metrics.get("per_op_metrics", []) + + # Log operation completion events for all operations + for op_idx, op in enumerate(ops): + op_name = op._name + node_id = self._get_dag_node_for_operation(op_name, op_idx, partition_id=partition_id) + + # Get metrics for this specific op if available + if per_op_metrics and op_idx < len(per_op_metrics): + op_metrics = per_op_metrics[op_idx] + else: + # We materialize per group, not per op, so we can't measure intermediate row counts + # Only show what we actually know: + # - First op: input to group + # - Last op: output from group + # - Middle ops: no row counts (unknown) + num_ops = len(ops) + op_metrics = { + "duration": metrics["duration"] / num_ops if num_ops > 0 else 0.0, + } + + # Only show input rows for first op in group + if op_idx == 0 and metrics.get("input_rows"): + op_metrics["input_rows"] = metrics["input_rows"] + + # Only show output rows for last op in group + if op_idx == len(ops) - 1 and metrics.get("output_rows"): + op_metrics["output_rows"] = metrics["output_rows"] + + if node_id: + # Mark DAG node as completed with real duration + self._mark_dag_node_completed(node_id, op_metrics["duration"]) + + # Log operation completion with DAG context + self._log_operation_with_dag_context( + op_name, + op_idx, + "op_complete", + partition_id=partition_id, + duration=op_metrics["duration"], + input_rows=op_metrics.get("input_rows"), + output_rows=op_metrics.get("output_rows"), + ) + else: + # Log operation completion without DAG context + if log_method := getattr(self, "log_op_complete", None): + log_method( + partition_id, + op_name, + op_idx, + op_metrics["duration"], + None, + op_metrics.get("input_rows"), + op_metrics.get("output_rows"), + ) + + def _extract_operation_types_from_ops(self, operations: List) -> List[str]: + """Extract operation types from operations list.""" + types = set() + for op in operations: + # Determine op type from operation name or class + op_name = getattr(op, "_name", "") + if op_name.endswith("_filter"): + types.add("filter") + elif op_name.endswith("_mapper"): + types.add("mapper") + elif op_name.endswith("_deduplicator"): + types.add("deduplicator") + elif op_name.endswith("_selector"): + types.add("selector") + elif op_name.endswith("_grouper"): + types.add("grouper") + elif op_name.endswith("_aggregator"): + types.add("aggregator") + else: + # Try to infer from class hierarchy + from data_juicer.ops.base_op import Filter, Mapper + + if isinstance(op, Filter): + types.add("filter") + elif isinstance(op, Mapper): + types.add("mapper") + return list(types) + + def get_dag_execution_status(self) -> Dict[str, Any]: + """Get DAG execution status.""" + if not self.pipeline_dag: + return {"status": "not_initialized"} + + summary = self.pipeline_dag.get_execution_summary() + + return { + "status": "running" if summary["pending_nodes"] > 0 else "completed", + "summary": summary, + "execution_plan_length": len(self.pipeline_dag.execution_plan), + "parallel_groups_count": 
len(self.pipeline_dag.parallel_groups), + "dag_execution_start_time": self.dag_execution_start_time, + } + + def visualize_dag_execution_plan(self) -> str: + """Get visualization of the DAG execution plan.""" + if not self.pipeline_dag: + return "Pipeline DAG not initialized" + + return self.pipeline_dag.visualize() + + def get_dag_execution_plan_path(self) -> str: + """Get the path to the saved DAG execution plan.""" + if not self.pipeline_dag: + # If pipeline_dag is not initialized, try to construct the path from work_dir + work_dir = getattr(getattr(self, "cfg", None), "work_dir", None) + if work_dir: + return str(Path(work_dir) / "dag_execution_plan.json") + return "" + + # DAG execution plan is now saved directly in the work directory + return str(self.pipeline_dag.dag_dir / "dag_execution_plan.json") + + def reconstruct_dag_state_from_events(self, job_id: str) -> Optional[Dict[str, Any]]: + """Reconstruct DAG execution state from event logs. + + This method has been decomposed into smaller, focused methods for better + maintainability and testability. + + Args: + job_id: The job ID to analyze + + Returns: + Dictionary containing reconstructed DAG state and resumption information + """ + # Step 1: Validate event logger availability + if not getattr(self, "event_logger", None): + logger.warning("Event logger not available for DAG state reconstruction") + return None + + # Step 2: Load DAG events and execution plan + dag_events = self._load_dag_events() + dag_plan = self._load_dag_execution_plan() + if not dag_plan: + return None + + # Step 3: Reconstruct node states from plan and events + node_states = self._initialize_node_states_from_plan(dag_plan) + self._update_node_states_from_events(node_states, dag_events) + + # Step 4: Calculate statistics + statistics = self._calculate_dag_statistics(node_states) + + # Step 5: Determine ready nodes + ready_nodes = self._find_ready_nodes(node_states) + + # Step 6: Determine resumption strategy + resumption_info = self._determine_resumption_strategy(node_states, ready_nodes, statistics) + + return { + "job_id": job_id, + "dag_plan_path": self.get_dag_execution_plan_path(), + "node_states": node_states, + "statistics": statistics, + "resumption": resumption_info, + "execution_plan": dag_plan.get("execution_plan", []), + "parallel_groups": dag_plan.get("parallel_groups", []), + } + + def _load_dag_events(self) -> List[Any]: + """Load DAG-related events from the event logger. + + Returns: + List of DAG-related events + """ + return self.event_logger.get_events( + event_type=[ + EventType.DAG_BUILD_START, + EventType.DAG_BUILD_COMPLETE, + EventType.DAG_NODE_START, + EventType.DAG_NODE_COMPLETE, + EventType.DAG_NODE_FAILED, + EventType.DAG_EXECUTION_PLAN_SAVED, + EventType.OP_START, + EventType.OP_COMPLETE, + EventType.OP_FAILED, + ] + ) + + def _load_dag_execution_plan(self) -> Optional[Dict[str, Any]]: + """Load the saved DAG execution plan. + + Returns: + DAG execution plan dictionary, or None if loading fails + """ + dag_plan_path = self.get_dag_execution_plan_path() + if not os.path.exists(dag_plan_path): + logger.warning(f"DAG execution plan not found: {dag_plan_path}") + return None + + try: + with open(dag_plan_path, "r") as f: + return json.load(f) + except Exception as e: + logger.error(f"Failed to load DAG execution plan: {e}") + return None + + def _initialize_node_states_from_plan(self, dag_plan: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: + """Initialize node states from the DAG execution plan. 
+ + Args: + dag_plan: The loaded DAG execution plan + + Returns: + Dictionary mapping node_id to initial node state + """ + node_states = {} + for node_id, node_data in dag_plan.get("nodes", {}).items(): + node_states[node_id] = { + "node_id": node_id, + "op_name": node_data.get("op_name"), + "op_type": node_data.get("op_type"), + "status": DAGNodeStatus.PENDING.value, + "execution_order": node_data.get("execution_order", -1), + "dependencies": node_data.get("dependencies", []), + "dependents": node_data.get("dependents", []), + "start_time": None, + "end_time": None, + "actual_duration": 0.0, + "error_message": None, + } + return node_states + + def _update_node_states_from_events(self, node_states: Dict[str, Dict[str, Any]], dag_events: List[Any]) -> None: + """Update node states based on events. + + Args: + node_states: Dictionary of node states to update (modified in-place) + dag_events: List of DAG-related events + """ + for event in dag_events: + event_data = getattr(event, "__dict__", event) + + # Handle DAG node events + if event_data.get("event_type") == EventType.DAG_NODE_START.value: + self._handle_dag_node_start_event(event_data, node_states) + elif event_data.get("event_type") == EventType.DAG_NODE_COMPLETE.value: + self._handle_dag_node_complete_event(event_data, node_states) + elif event_data.get("event_type") == EventType.DAG_NODE_FAILED.value: + self._handle_dag_node_failed_event(event_data, node_states) + # Handle operation events with DAG context + elif event_data.get("event_type") in [ + EventType.OP_START.value, + EventType.OP_COMPLETE.value, + EventType.OP_FAILED.value, + ]: + self._handle_operation_event(event_data, node_states) + + def _handle_dag_node_start_event(self, event_data: Dict[str, Any], node_states: Dict[str, Dict[str, Any]]) -> None: + """Handle DAG_NODE_START event.""" + node_id = event_data.get("metadata", {}).get("dag_node_id") + if node_id and node_id in node_states: + node_states[node_id]["status"] = DAGNodeStatus.RUNNING.value + node_states[node_id]["start_time"] = event_data.get("timestamp") + + def _handle_dag_node_complete_event( + self, event_data: Dict[str, Any], node_states: Dict[str, Dict[str, Any]] + ) -> None: + """Handle DAG_NODE_COMPLETE event.""" + node_id = event_data.get("metadata", {}).get("dag_node_id") + if node_id and node_id in node_states: + node_states[node_id]["status"] = DAGNodeStatus.COMPLETED.value + node_states[node_id]["end_time"] = event_data.get("timestamp") + node_states[node_id]["actual_duration"] = event_data.get("duration", 0.0) + + def _handle_dag_node_failed_event(self, event_data: Dict[str, Any], node_states: Dict[str, Dict[str, Any]]) -> None: + """Handle DAG_NODE_FAILED event.""" + node_id = event_data.get("metadata", {}).get("dag_node_id") + if node_id and node_id in node_states: + node_states[node_id]["status"] = DAGNodeStatus.FAILED.value + node_states[node_id]["end_time"] = event_data.get("timestamp") + node_states[node_id]["actual_duration"] = event_data.get("duration", 0.0) + node_states[node_id]["error_message"] = event_data.get("error_message") + + def _handle_operation_event(self, event_data: Dict[str, Any], node_states: Dict[str, Dict[str, Any]]) -> None: + """Handle operation events (OP_START, OP_COMPLETE, OP_FAILED) with DAG context.""" + dag_context = event_data.get("metadata", {}).get("dag_context", {}) + node_id = dag_context.get("dag_node_id") + if not node_id or node_id not in node_states: + return + + event_type = event_data.get("event_type") + if event_type == EventType.OP_START.value: + 
node_states[node_id]["status"] = DAGNodeStatus.RUNNING.value + node_states[node_id]["start_time"] = event_data.get("timestamp") + elif event_type == EventType.OP_COMPLETE.value: + node_states[node_id]["status"] = DAGNodeStatus.COMPLETED.value + node_states[node_id]["end_time"] = event_data.get("timestamp") + node_states[node_id]["actual_duration"] = event_data.get("duration", 0.0) + elif event_type == EventType.OP_FAILED.value: + node_states[node_id]["status"] = DAGNodeStatus.FAILED.value + node_states[node_id]["end_time"] = event_data.get("timestamp") + node_states[node_id]["actual_duration"] = event_data.get("duration", 0.0) + node_states[node_id]["error_message"] = event_data.get("error_message") + + def _calculate_dag_statistics(self, node_states: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: + """Calculate DAG execution statistics. + + Args: + node_states: Dictionary of node states + + Returns: + Dictionary with statistics + """ + total_nodes = len(node_states) + completed_nodes = sum(1 for node in node_states.values() if node["status"] == DAGNodeStatus.COMPLETED.value) + failed_nodes = sum(1 for node in node_states.values() if node["status"] == DAGNodeStatus.FAILED.value) + running_nodes = sum(1 for node in node_states.values() if node["status"] == DAGNodeStatus.RUNNING.value) + pending_nodes = sum(1 for node in node_states.values() if node["status"] == DAGNodeStatus.PENDING.value) + + return { + "total_nodes": total_nodes, + "completed_nodes": completed_nodes, + "failed_nodes": failed_nodes, + "running_nodes": running_nodes, + "pending_nodes": pending_nodes, + "ready_nodes": 0, # Will be set by caller + "completion_percentage": (completed_nodes / total_nodes * 100) if total_nodes > 0 else 0, + } + + def _find_ready_nodes(self, node_states: Dict[str, Dict[str, Any]]) -> List[str]: + """Find nodes that are ready to execute (all dependencies completed). + + Args: + node_states: Dictionary of node states + + Returns: + List of node IDs that are ready to execute + """ + ready_nodes = [] + for node_id, node_state in node_states.items(): + if node_state["status"] == DAGNodeStatus.PENDING.value: + # Check if all dependencies are completed + all_deps_completed = all( + node_states[dep_id]["status"] == DAGNodeStatus.COMPLETED.value + for dep_id in node_state["dependencies"] + if dep_id in node_states + ) + if all_deps_completed: + ready_nodes.append(node_id) + return ready_nodes + + def _determine_resumption_strategy( + self, node_states: Dict[str, Dict[str, Any]], ready_nodes: List[str], statistics: Dict[str, Any] + ) -> Dict[str, Any]: + """Determine the resumption strategy based on current DAG state. 
+ + Args: + node_states: Dictionary of node states + ready_nodes: List of ready node IDs + statistics: DAG statistics + + Returns: + Dictionary with resumption information + """ + can_resume = True + resume_from_node = None + + # Priority 1: Resume from failed nodes + if statistics["failed_nodes"] > 0: + failed_node_ids = [ + node_id for node_id, state in node_states.items() if state["status"] == DAGNodeStatus.FAILED.value + ] + if failed_node_ids: + failed_node_ids.sort(key=lambda x: node_states[x]["execution_order"]) + resume_from_node = failed_node_ids[0] + + # Priority 2: Resume from running nodes + elif statistics["running_nodes"] > 0: + running_node_ids = [ + node_id for node_id, state in node_states.items() if state["status"] == DAGNodeStatus.RUNNING.value + ] + if running_node_ids: + running_node_ids.sort(key=lambda x: node_states[x]["execution_order"]) + resume_from_node = running_node_ids[0] + + # Priority 3: Start from ready nodes + elif ready_nodes: + ready_nodes_sorted = sorted(ready_nodes, key=lambda x: node_states[x]["execution_order"]) + resume_from_node = ready_nodes_sorted[0] + + # All nodes completed - cannot resume + elif statistics["completed_nodes"] == statistics["total_nodes"]: + can_resume = False + + return { + "can_resume": can_resume, + "resume_from_node": resume_from_node, + "ready_nodes": ready_nodes, + "failed_nodes": [ + node_id for node_id, state in node_states.items() if state["status"] == DAGNodeStatus.FAILED.value + ], + "running_nodes": [ + node_id for node_id, state in node_states.items() if state["status"] == DAGNodeStatus.RUNNING.value + ], + } + + def resume_dag_execution(self, job_id: str, dataset, ops: List) -> bool: + """ + Resume DAG execution from the last known state. + + Args: + job_id: The job ID to resume + dataset: The dataset to process + ops: List of operations to execute + + Returns: + True if resumption was successful, False otherwise + """ + # Reconstruct DAG state from events + dag_state = self.reconstruct_dag_state_from_events(job_id) + if not dag_state: + logger.error("Failed to reconstruct DAG state for resumption") + return False + + if not dag_state["resumption"]["can_resume"]: + logger.info("No resumption needed - all nodes completed") + return True + + # Load the DAG execution plan + if not self.pipeline_dag: + logger.error("Pipeline DAG not initialized") + return False + + dag_plan_path = dag_state["dag_plan_path"] + if not self.pipeline_dag.load_execution_plan(dag_plan_path): + logger.error("Failed to load DAG execution plan for resumption") + return False + + # Restore node states + for node_id, node_state in dag_state["node_states"].items(): + if node_id in self.pipeline_dag.nodes: + node = self.pipeline_dag.nodes[node_id] + node.status = DAGNodeStatus(node_state["status"]) + node.start_time = node_state["start_time"] + node.end_time = node_state["end_time"] + node.actual_duration = node_state["actual_duration"] + node.error_message = node_state["error_message"] + + logger.info(f"Resuming DAG execution from node: {dag_state['resumption']['resume_from_node']}") + logger.info(f"Statistics: {dag_state['statistics']}") + + # Execute remaining operations + resume_from_node = dag_state["resumption"]["resume_from_node"] + if resume_from_node: + # Find the operation index for this node + node_state = dag_state["node_states"][resume_from_node] + execution_order = node_state["execution_order"] + + # Execute operations starting from the resume point + for op_idx, op in enumerate(ops): + if op_idx >= execution_order: + op_name = 
op._name + node_id = self._get_dag_node_for_operation(op_name, op_idx) + + if node_id: + # Check if this node was already completed + if node_id in dag_state["node_states"]: + node_status = dag_state["node_states"][node_id]["status"] + if node_status == DAGNodeStatus.COMPLETED.value: + logger.info(f"Skipping completed node: {node_id}") + continue + + # Execute the operation with DAG monitoring + self._mark_dag_node_started(node_id) + self._log_operation_with_dag_context(op_name, op_idx, "op_start") + + start_time = time.time() + try: + dataset.process([op]) + duration = time.time() - start_time + self._mark_dag_node_completed(node_id, duration) + self._log_operation_with_dag_context( + op_name, op_idx, "op_complete", duration=duration, input_rows=0, output_rows=0 + ) + except Exception as e: + duration = time.time() - start_time + error_message = str(e) + self._mark_dag_node_failed(node_id, error_message, duration) + self._log_operation_with_dag_context( + op_name, op_idx, "op_failed", error=error_message, duration=duration + ) + raise + + return True diff --git a/data_juicer/core/executor/dag_execution_strategies.py b/data_juicer/core/executor/dag_execution_strategies.py new file mode 100644 index 0000000000..42798dce97 --- /dev/null +++ b/data_juicer/core/executor/dag_execution_strategies.py @@ -0,0 +1,337 @@ +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class DAGNodeType(Enum): + """Types of DAG nodes.""" + + OPERATION = "operation" + PARTITION_OPERATION = "partition_operation" + SCATTER_GATHER = "scatter_gather" + + +@dataclass +class ScatterGatherNode: + """Represents a scatter-gather operation in partitioned execution. + + Encapsulates the complete scatter-gather pattern: + 1. Convergence: All partitions complete their work and converge + 2. Global Operation: A single operation runs on the gathered data + 3. Redistribution: Results are redistributed back to partitions + """ + + operation_index: int + operation_name: str + input_partitions: List[int] + output_partitions: List[int] + + @property + def node_id(self) -> str: + """Generate unique node ID for scatter-gather operation.""" + return f"sg_{self.operation_index:03d}_{self.operation_name}" + + +class NodeID: + """Utility for creating and parsing standardized node IDs. + + Node ID formats: + - Operation: "op_{idx:03d}_{name}" + - Partition Operation: "op_{idx:03d}_{name}_partition_{pid}" + - Scatter-Gather: "sg_{idx:03d}_{name}" + """ + + @staticmethod + def for_operation(op_idx: int, op_name: str) -> str: + """Create node ID for global operation. + + Args: + op_idx: Operation index (0-based) + op_name: Operation name + + Returns: + Standardized node ID + """ + return f"op_{op_idx+1:03d}_{op_name}" + + @staticmethod + def for_partition_operation(partition_id: int, op_idx: int, op_name: str) -> str: + """Create node ID for partition operation. + + Args: + partition_id: Partition ID + op_idx: Operation index (0-based) + op_name: Operation name + + Returns: + Standardized node ID + """ + return f"op_{op_idx+1:03d}_{op_name}_partition_{partition_id}" + + @staticmethod + def for_scatter_gather(op_idx: int, op_name: str) -> str: + """Create node ID for scatter-gather operation. 
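+
+        Unlike for_operation and for_partition_operation, the scatter-gather ID
+        keeps the raw 0-based index (no +1 shift), which matches
+        ScatterGatherNode.node_id and what NodeID.parse returns.
+
+        Example:
+            >>> NodeID.for_scatter_gather(2, "document_deduplicator")
+            'sg_002_document_deduplicator'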
+ + Args: + op_idx: Operation index (0-based) + op_name: Operation name + + Returns: + Standardized node ID + """ + return f"sg_{op_idx:03d}_{op_name}" + + @staticmethod + def parse(node_id: str) -> Optional[Dict[str, Any]]: + """Parse node ID into components. + + Args: + node_id: The node ID to parse + + Returns: + Dictionary with node type and components, or None if invalid format + + Example: + >>> NodeID.parse("op_001_mapper_partition_0") + {'type': DAGNodeType.PARTITION_OPERATION, 'partition_id': 0, + 'operation_index': 0, 'operation_name': 'mapper'} + + >>> NodeID.parse("sg_002_deduplicator") + {'type': DAGNodeType.SCATTER_GATHER, 'operation_index': 2, + 'operation_name': 'deduplicator'} + """ + # Partition operation: op_001_mapper_name_partition_0 + match = re.match(r"op_(\d+)_(.+)_partition_(\d+)", node_id) + if match: + return { + "type": DAGNodeType.PARTITION_OPERATION, + "operation_index": int(match.group(1)) - 1, # Convert back to 0-based + "operation_name": match.group(2), + "partition_id": int(match.group(3)), + } + + # Scatter-gather: sg_002_mapper_name + match = re.match(r"sg_(\d+)_(.+)", node_id) + if match: + return { + "type": DAGNodeType.SCATTER_GATHER, + "operation_index": int(match.group(1)), + "operation_name": match.group(2), + } + + # Regular operation: op_001_mapper_name + match = re.match(r"op_(\d+)_(.+)", node_id) + if match: + return { + "type": DAGNodeType.OPERATION, + "operation_index": int(match.group(1)) - 1, + "operation_name": match.group(2), + } + + return None + + +class DAGExecutionStrategy(ABC): + """Abstract base class for different DAG execution strategies.""" + + @abstractmethod + def generate_dag_nodes(self, operations: List, **kwargs) -> Dict[str, Any]: + """Generate DAG nodes based on execution strategy.""" + pass + + @abstractmethod + def get_dag_node_id(self, op_name: str, op_idx: int, **kwargs) -> str: + """Get DAG node ID for operation based on strategy.""" + pass + + @abstractmethod + def build_dependencies(self, nodes: Dict[str, Any], operations: List, **kwargs) -> None: + """Build dependencies between nodes based on strategy.""" + pass + + @abstractmethod + def can_execute_node(self, node_id: str, nodes: Dict[str, Any], completed_nodes: set) -> bool: + """Check if a node can be executed based on strategy.""" + pass + + +class NonPartitionedDAGStrategy(DAGExecutionStrategy): + """Strategy for non-partitioned executors (default, ray).""" + + def generate_dag_nodes(self, operations: List, **kwargs) -> Dict[str, Any]: + """Generate DAG nodes for non-partitioned execution.""" + nodes = {} + for op_idx, op in enumerate(operations): + node_id = self.get_dag_node_id(op._name, op_idx) + nodes[node_id] = { + "node_id": node_id, + "operation_name": op._name, + "execution_order": op_idx + 1, + "node_type": DAGNodeType.OPERATION.value, + "partition_id": None, + "dependencies": [], + "status": "pending", + "start_time": None, + "end_time": None, + "actual_duration": None, + "error_message": None, + } + return nodes + + def get_dag_node_id(self, op_name: str, op_idx: int, **kwargs) -> str: + """Get DAG node ID for non-partitioned operation.""" + return f"op_{op_idx+1:03d}_{op_name}" + + def build_dependencies(self, nodes: Dict[str, Any], operations: List, **kwargs) -> None: + """Build sequential dependencies for non-partitioned execution.""" + # Simple sequential dependencies + for i in range(1, len(operations)): + current_node = self.get_dag_node_id(operations[i]._name, i) + prev_node = self.get_dag_node_id(operations[i - 1]._name, i - 1) + if 
current_node in nodes and prev_node in nodes: + nodes[current_node]["dependencies"].append(prev_node) + + def can_execute_node(self, node_id: str, nodes: Dict[str, Any], completed_nodes: set) -> bool: + """Check if a node can be executed (all dependencies completed).""" + if node_id not in nodes: + return False + node = nodes[node_id] + return all(dep in completed_nodes for dep in node["dependencies"]) + + +class PartitionedDAGStrategy(DAGExecutionStrategy): + """Strategy for partitioned executors (ray_partitioned).""" + + def __init__(self, num_partitions: int): + self.num_partitions = num_partitions + + def generate_dag_nodes(self, operations: List, **kwargs) -> Dict[str, Any]: + """Generate DAG nodes for partitioned execution using scatter-gather pattern.""" + nodes = {} + convergence_points = kwargs.get("convergence_points", []) + + # Generate partition-specific nodes + for partition_id in range(self.num_partitions): + for op_idx, op in enumerate(operations): + node_id = self.get_dag_node_id(op._name, op_idx, partition_id=partition_id) + nodes[node_id] = { + "node_id": node_id, + "operation_name": op._name, + "execution_order": op_idx + 1, + "node_type": DAGNodeType.PARTITION_OPERATION.value, + "partition_id": partition_id, + "dependencies": [], + "status": "pending", + "start_time": None, + "end_time": None, + "actual_duration": None, + "error_message": None, + } + + # Generate scatter-gather nodes for global operations + for conv_idx, conv_point in enumerate(convergence_points): + if conv_point < len(operations): + op = operations[conv_point] + sg_node = ScatterGatherNode( + operation_index=conv_point, + operation_name=op._name, + input_partitions=list(range(self.num_partitions)), + output_partitions=list(range(self.num_partitions)), + ) + + nodes[sg_node.node_id] = { + "node_id": sg_node.node_id, + "operation_name": op._name, + "execution_order": conv_point + 1, + "node_type": DAGNodeType.SCATTER_GATHER.value, + "operation_index": conv_point, + "input_partitions": sg_node.input_partitions, + "output_partitions": sg_node.output_partitions, + "dependencies": [], + "status": "pending", + "start_time": None, + "end_time": None, + "actual_duration": None, + "error_message": None, + "scatter_gather_node": sg_node, + } + + return nodes + + def get_dag_node_id(self, op_name: str, op_idx: int, partition_id: int = None, **kwargs) -> str: + """Get DAG node ID for partitioned operation.""" + if partition_id is not None: + return f"op_{op_idx+1:03d}_{op_name}_partition_{partition_id}" + else: + return f"op_{op_idx+1:03d}_{op_name}" + + def build_dependencies(self, nodes: Dict[str, Any], operations: List, **kwargs) -> None: + """Build dependencies for partitioned execution using scatter-gather pattern. 
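+
+        Expects convergence_points in kwargs: a list of 0-based operation
+        indices that are global operations. In each partition's dependency
+        chain these indices are routed through the matching scatter-gather
+        node instead of a partition-local node.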
+ + - Partition operations depend on previous operation in same partition + - Scatter-gather nodes depend on ALL partitions from previous op + - Post-scatter-gather partition ops depend on the scatter-gather node + """ + convergence_points = kwargs.get("convergence_points", []) + + # Find all scatter-gather nodes + sg_nodes = { + node_id: node + for node_id, node in nodes.items() + if node.get("node_type") == DAGNodeType.SCATTER_GATHER.value + } + + # Build partition-specific dependencies + for partition_id in range(self.num_partitions): + prev_node_id = None + for op_idx, op in enumerate(operations): + # Skip operations that are scatter-gather points + if op_idx in convergence_points: + # Find the scatter-gather node for this operation + sg_node_id = None + for nid, node in sg_nodes.items(): + if node.get("operation_index") == op_idx: + sg_node_id = nid + break + + if sg_node_id: + # Scatter-gather node depends on all partitions from previous op + if prev_node_id: + for pid in range(self.num_partitions): + dep_node = self.get_dag_node_id(operations[op_idx]._name, op_idx, partition_id=pid) + if dep_node in nodes: + nodes[sg_node_id]["dependencies"].append(dep_node) + + # Update prev_node for next iteration + prev_node_id = sg_node_id + continue + + # Regular partition operation + node_id = self.get_dag_node_id(op._name, op_idx, partition_id=partition_id) + if node_id in nodes: + # Depends on previous node in this partition (could be partition op or scatter-gather) + if prev_node_id: + nodes[node_id]["dependencies"].append(prev_node_id) + prev_node_id = node_id + + def can_execute_node(self, node_id: str, nodes: Dict[str, Any], completed_nodes: set) -> bool: + """Check if a node can be executed (all dependencies completed).""" + if node_id not in nodes: + return False + node = nodes[node_id] + return all(dep in completed_nodes for dep in node["dependencies"]) + + +def is_global_operation(operation) -> bool: + """Check if an operation is a global operation that requires convergence.""" + # Deduplicators are typically global operations + if "deduplicator" in getattr(operation, "_name", ""): + return True + + # Check for explicit global operation flag + if getattr(operation, "is_global_operation", False): + return True + + return False diff --git a/data_juicer/core/executor/default_executor.py b/data_juicer/core/executor/default_executor.py index cefe914c0c..dd0f84edef 100644 --- a/data_juicer/core/executor/default_executor.py +++ b/data_juicer/core/executor/default_executor.py @@ -11,6 +11,8 @@ from data_juicer.core.data import NestedDataset from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.core.executor import ExecutorBase +from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin +from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin from data_juicer.core.exporter import Exporter from data_juicer.core.tracer import Tracer from data_juicer.ops import load_ops @@ -24,7 +26,7 @@ from data_juicer.utils.sample import random_sample -class DefaultExecutor(ExecutorBase): +class DefaultExecutor(ExecutorBase, DAGExecutionMixin, EventLoggingMixin): """ This Executor class is used to process a specific dataset. @@ -39,10 +41,17 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional jsonargparse Namespace. 
""" super().__init__(cfg) - self.executor_type = "default" + # If work_dir contains job_id, all outputs go under it self.work_dir = self.cfg.work_dir - self.tracer = None + # Initialize EventLoggingMixin for job management and event logging + EventLoggingMixin.__init__(self, cfg) + + # Initialize DAGExecutionMixin for AST/DAG functionality + DAGExecutionMixin.__init__(self) + # Set executor type for strategy selection + self.executor_type = "default" + self.ckpt_manager = None self.adapter = Adapter(self.cfg) @@ -150,6 +159,27 @@ def run( logger.info("Preparing process operators...") ops = load_ops(self.cfg.process) + # Initialize DAG execution planning + self._initialize_dag_execution(self.cfg) + + # Log job start with DAG context + # Handle both dataset_path (string) and dataset (dict) configurations + dataset_info = {} + if hasattr(self.cfg, "dataset_path") and self.cfg.dataset_path: + dataset_info["dataset_path"] = self.cfg.dataset_path + if hasattr(self.cfg, "dataset") and self.cfg.dataset: + dataset_info["dataset"] = self.cfg.dataset + + job_config = { + **dataset_info, + "work_dir": self.work_dir, + "executor_type": self.executor_type, + "dag_node_count": len(self.pipeline_dag.nodes) if self.pipeline_dag else 0, + "dag_edge_count": len(self.pipeline_dag.edges) if self.pipeline_dag else 0, + "parallel_groups_count": len(self.pipeline_dag.parallel_groups) if self.pipeline_dag else 0, + } + self.log_job_start(job_config, len(ops)) + # OP fusion if self.cfg.op_fusion: probe_res = None @@ -171,20 +201,31 @@ def run( if op.is_batched_op(): op.batch_size = bs_per_op[i] - # 3. data process + # 3. data process with DAG monitoring # - If tracer is open, trace each op after it's processed # - If checkpoint is open, clean the cache files after each process - logger.info("Processing data...") + logger.info("Processing data with DAG monitoring...") tstart = time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops) + + # Execute operations with executor-specific parameters dataset = dataset.process( ops, work_dir=self.work_dir, exporter=self.exporter, checkpointer=self.ckpt_manager, - tracer=self.tracer, + tracer=self.tracer if self.cfg.open_tracer else None, adapter=self.adapter, open_monitor=self.cfg.open_monitor, ) + + # Post-execute DAG monitoring (log operation completion events) + if self.pipeline_dag: + self._post_execute_operations_with_dag_monitoring(ops) + tend = time() logger.info(f"All OPs are done in {tend - tstart:.3f}s.") @@ -198,6 +239,10 @@ def run( compress(dataset) + # Log job completion with DAG context + job_duration = time() - tstart + self.log_job_complete(job_duration, self.cfg.export_path) + if not skip_return: return dataset diff --git a/data_juicer/core/executor/event_logging_mixin.py b/data_juicer/core/executor/event_logging_mixin.py new file mode 100644 index 0000000000..c994b455ad --- /dev/null +++ b/data_juicer/core/executor/event_logging_mixin.py @@ -0,0 +1,1237 @@ +#!/usr/bin/env python3 +""" +Event Logging Mixin for Data-Juicer Executors + +This module provides comprehensive event logging capabilities that can be used +by any executor (default, ray, partitioned, etc.) to track operations, +performance, and errors in real-time. + +Features: +1. Real-time event logging with configurable levels +2. Event filtering and querying +3. Performance metrics tracking +4. Error tracking with stack traces +5. Status reporting and monitoring +6. 
Log rotation and cleanup +""" + +import json +import os +import re +import threading +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional +from uuid import uuid4 + +from loguru import logger + + +class EventType(Enum): + """Types of events that can be logged.""" + + JOB_START = "job_start" + JOB_COMPLETE = "job_complete" + JOB_FAILED = "job_failed" + JOB_RESTART = "job_restart" # New: Job restart event + PARTITION_START = "partition_start" + PARTITION_COMPLETE = "partition_complete" + PARTITION_FAILED = "partition_failed" + PARTITION_RESUME = "partition_resume" # New: Partition resume event + OP_START = "op_start" + OP_COMPLETE = "op_complete" + OP_FAILED = "op_failed" + CHECKPOINT_SAVE = "checkpoint_save" + CHECKPOINT_LOAD = "checkpoint_load" + PROCESSING_START = "processing_start" + PROCESSING_COMPLETE = "processing_complete" + PROCESSING_ERROR = "processing_error" + # DAG-specific events + DAG_BUILD_START = "dag_build_start" + DAG_BUILD_COMPLETE = "dag_build_complete" + DAG_NODE_READY = "dag_node_ready" + DAG_NODE_START = "dag_node_start" + DAG_NODE_COMPLETE = "dag_node_complete" + DAG_NODE_FAILED = "dag_node_failed" + DAG_PARALLEL_GROUP_START = "dag_parallel_group_start" + DAG_PARALLEL_GROUP_COMPLETE = "dag_parallel_group_complete" + DAG_EXECUTION_PLAN_SAVED = "dag_execution_plan_saved" + DAG_EXECUTION_PLAN_LOADED = "dag_execution_plan_loaded" + + +@dataclass +class Event: + """Event data structure.""" + + event_type: EventType + timestamp: float + message: str + event_id: Optional[str] = None + job_id: Optional[str] = None + partition_id: Optional[int] = None + operation_name: Optional[str] = None + operation_idx: Optional[int] = None + status: Optional[str] = None + duration: Optional[float] = None + error_message: Optional[str] = None + stack_trace: Optional[str] = None + retry_count: Optional[int] = None + checkpoint_path: Optional[str] = None + op_args: Optional[Dict[str, Any]] = None + input_rows: Optional[int] = None + output_rows: Optional[int] = None + output_path: Optional[str] = None + partition_meta: Optional[Dict[str, Any]] = None + config: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None + total_partitions: Optional[int] = None + successful_partitions: Optional[int] = None + failed_partitions: Optional[int] = None + job_duration: Optional[float] = None + completion_time: Optional[float] = None + failure_time: Optional[float] = None + error_type: Optional[str] = None + # Process and thread tracking + process_id: Optional[int] = None + thread_id: Optional[int] = None + + +class EventLogger: + """Event logging system with real-time capabilities and JSONL event log for resumability.""" + + def __init__(self, log_dir: str, job_id: Optional[str] = None, work_dir: Optional[str] = None): + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + # Use provided job_id or generate a simple timestamp-based one + self.job_id = job_id or f"{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}-{uuid4().hex[:6]}" + self.events: deque = deque(maxlen=10000) + self.event_lock = threading.Lock() + + # Use work_dir for JSONL file if provided, otherwise use log_dir + self.jsonl_dir = Path(work_dir) if work_dir else self.log_dir + self.jsonl_dir.mkdir(parents=True, exist_ok=True) + + # Create timestamped events file + timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + 
self.jsonl_file = self.jsonl_dir / f"events_{timestamp}.jsonl" + + def log_event(self, event: Event): + """Log an event (to memory, loguru, and JSONL for resumability).""" + with self.event_lock: + event.job_id = self.job_id + self.events.append(event) + # Log to file (loguru) + log_message = self._format_event_for_logging(event) + logger.info(log_message) + # Write to JSONL for resumability + with open(self.jsonl_file, "a") as f: + f.write( + json.dumps( + {k: (v.value if isinstance(v, Enum) else v) for k, v in event.__dict__.items() if v is not None} + ) + + "\n" + ) + + def find_latest_events_file(self, work_dir: str) -> Optional[Path]: + """Find the latest events file in the work directory.""" + events_dir = Path(work_dir) + if not events_dir.exists(): + return None + + # Find all events files with timestamp pattern + events_files = list(events_dir.glob("events_*.jsonl")) + if not events_files: + return None + + # Sort by modification time and return the latest + latest_file = max(events_files, key=lambda f: f.stat().st_mtime) + return latest_file + + def check_job_completion(self, events_file: Path) -> bool: + """Check if job is already completed by looking for job_complete event.""" + if not events_file.exists(): + return False + + try: + with open(events_file, "r") as f: + for line in f: + if line.strip(): + event = json.loads(line.strip()) + if event.get("event_type") == "job_complete": + return True + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Error reading events file {events_file}: {e}") + + return False + + def _format_event_for_logging(self, event: Event) -> str: + """Format event for logging with enhanced details.""" + parts = [f"EVENT[{event.event_type.value}]", f"TIME[{datetime.fromtimestamp(event.timestamp).isoformat()}]"] + + if event.partition_id is not None: + parts.append(f"PARTITION[{event.partition_id}]") + + if event.operation_name: + parts.append(f"OP[{event.operation_name}]") + if event.operation_idx is not None: + parts.append(f"OP_IDX[{event.operation_idx}]") + + if event.duration is not None: + # Handle case where duration might be a string (due to parameter order issues) + try: + if isinstance(event.duration, (int, float)): + parts.append(f"DURATION[{event.duration:.3f}s]") + else: + parts.append(f"DURATION[{event.duration}]") + except (ValueError, TypeError): + parts.append(f"DURATION[{event.duration}]") + + parts.append(f"MSG[{event.message}]") + + if event.error_message: + parts.append(f"ERROR[{event.error_message}]") + + if event.checkpoint_path: + parts.append(f"CHECKPOINT[{os.path.basename(event.checkpoint_path)}]") + + if event.output_path: + parts.append(f"OUTPUT[{os.path.basename(event.output_path)}]") + + if event.metadata: + # Include key metadata in the log message + key_metadata = {} + for key in ["status", "retry_count", "error_type", "operation_class"]: + if key in event.metadata: + key_metadata[key] = event.metadata[key] + if key_metadata: + parts.append(f"META[{json.dumps(key_metadata)}]") + + return " | ".join(parts) + + def get_events( + self, + event_type: Optional[EventType] = None, + partition_id: Optional[int] = None, + operation_name: Optional[str] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + limit: Optional[int] = None, + ) -> List[Event]: + """Get events with optional filtering.""" + with self.event_lock: + filtered_events = [] + + for event in self.events: + # Apply filters + if event_type and event.event_type != event_type: + continue + if partition_id is not None and 
event.partition_id != partition_id: + continue + if operation_name and event.operation_name != operation_name: + continue + if start_time and event.timestamp < start_time: + continue + if end_time and event.timestamp > end_time: + continue + + filtered_events.append(event) + + # Apply limit + if limit: + filtered_events = filtered_events[-limit:] + + return filtered_events + + def generate_status_report(self) -> str: + """Generate a comprehensive status report.""" + with self.event_lock: + total_events = len(self.events) + if total_events == 0: + return "No events logged yet." + + # Count event types + event_counts = defaultdict(int) + error_count = 0 + warning_count = 0 + + for event in self.events: + event_counts[event.event_type.value] += 1 + + # Generate report + report_lines = [ + "=== EVENT LOGGING STATUS REPORT ===", + f"Total Events: {total_events}", + f"Errors: {error_count}", + f"Warnings: {warning_count}", + "", + "Event Type Distribution:", + ] + + for event_type, count in sorted(event_counts.items()): + percentage = (count / total_events) * 100 + report_lines.append(f" {event_type}: {count} ({percentage:.1f}%)") + + return "\n".join(report_lines) + + def monitor_events(self, event_type: Optional[EventType] = None) -> Generator[Event, None, None]: + """Monitor events in real-time.""" + last_event_count = len(self.events) + + while True: + with self.event_lock: + current_events = list(self.events) + + # Yield new events + for event in current_events[last_event_count:]: + if event_type is None or event.event_type == event_type: + yield event + + last_event_count = len(current_events) + time.sleep(0.1) # Check every 100ms + + @classmethod + def list_available_jobs(cls, work_dir: str) -> List[Dict[str, Any]]: + """List available jobs for resumption from a work directory.""" + available_jobs = [] + + if not os.path.exists(work_dir): + return available_jobs + + # Look for job directories (each job has its own directory) + for item in os.listdir(work_dir): + job_work_dir = os.path.join(work_dir, item) + if os.path.isdir(job_work_dir): + summary_file = os.path.join(job_work_dir, "job_summary.json") + if os.path.exists(summary_file): + try: + with open(summary_file, "r") as f: + job_summary = json.load(f) + job_summary["work_dir"] = job_work_dir + available_jobs.append(job_summary) + except Exception as e: + logger.warning(f"Failed to load job summary from {summary_file}: {e}") + + return available_jobs + + +class EventLoggingMixin: + """Mixin to add event logging capabilities to any executor.""" + + def __init__(self, *args, **kwargs): + """Initialize the mixin.""" + # Initialize event logging if not already done + if not hasattr(self, "event_logger"): + self._setup_event_logging() + + def _setup_event_logging(self): + """Setup event logging for the executor.""" + # Get event logging configuration + event_config = getattr(self.cfg, "event_logging", {}) + enabled = event_config.get("enabled", True) + + if not enabled: + self.event_logger = None + return + + # job_id and work_dir should already be resolved by resolve_job_directories() in config.py + job_id = getattr(self.cfg, "job_id", None) + if not job_id: + raise ValueError( + "job_id must be set before setting up event logging. 
" + "This should have been done by resolve_job_id() in config.py" + ) + + # work_dir already includes job_id after resolve_job_directories + # Create work directory and subdirectories + os.makedirs(self.work_dir, exist_ok=True) + + # Use logs directory instead of event_logs + logs_dir = os.path.join(self.work_dir, "logs") + os.makedirs(logs_dir, exist_ok=True) + + self.event_logger = EventLogger(logs_dir, job_id=job_id, work_dir=self.work_dir) + + logger.info(f"Event logging initialized for {self.executor_type} executor") + + def _update_job_summary(self, status: str, end_time: Optional[float] = None, error_message: Optional[str] = None): + """Update job summary with completion status.""" + # work_dir already includes job_id after resolve_job_directories + summary_file = os.path.join(self.work_dir, "job_summary.json") + + if not os.path.exists(summary_file): + return + + with open(summary_file, "r") as f: + job_summary = json.load(f) + + job_summary.update( + { + "status": status, + "end_time": end_time or time.time(), + "duration": (end_time or time.time()) - job_summary.get("start_time", time.time()), + "error_message": error_message, + } + ) + + with open(summary_file, "w") as f: + json.dump(job_summary, f, indent=2, default=str) + + # Display completion info + if status == "completed": + logger.info("=" * 60) + logger.info("DataJuicer Job Completed Successfully") + logger.info(f"Duration: {job_summary['duration']:.2f} seconds") + logger.info("=" * 60) + elif status == "failed": + logger.error("=" * 60) + logger.error("DataJuicer Job Failed") + logger.error(f"Error: {error_message}") + logger.error(f"Duration: {job_summary['duration']:.2f} seconds") + logger.error("=" * 60) + logger.error("To resume this job, use:") + logger.error(f" {job_summary['resumption_command']}") + logger.error("=" * 60) + + def _load_job_summary(self) -> Optional[Dict[str, Any]]: + """Load job summary if it exists.""" + # work_dir already includes job_id after resolve_job_directories + summary_file = os.path.join(self.work_dir, "job_summary.json") + + if os.path.exists(summary_file): + with open(summary_file, "r") as f: + return json.load(f) + return None + + def _get_config_name(self) -> str: + """Extract a meaningful name from config file or project name.""" + # Try to get config file name first + config_file = getattr(self.cfg, "config", None) + if config_file: + # Extract filename without extension and path + config_name = os.path.splitext(os.path.basename(config_file))[0] + # Clean up the name (remove special chars, limit length) + config_name = re.sub(r"[^a-zA-Z0-9_-]", "_", config_name) + config_name = config_name[:20] # Limit length + if config_name: + return config_name + + # Fall back to project name + project_name = getattr(self.cfg, "project_name", "dj") + # Clean up project name + project_name = re.sub(r"[^a-zA-Z0-9_-]", "_", project_name) + project_name = project_name[:15] # Limit length + + return project_name + + def _add_dag_context_to_metadata( + self, metadata: Dict[str, Any], operation_name: str, operation_idx: int, partition_id: int + ): + """Add DAG context to metadata if DAGExecutionMixin is available.""" + # Check if DAGExecutionMixin is available and has the method to get DAG node + if hasattr(self, "_get_dag_node_for_operation"): + try: + node_id = self._get_dag_node_for_operation(operation_name, operation_idx, partition_id=partition_id) + if node_id: + metadata["dag_node_id"] = node_id + else: + logger.debug(f"DAG node not found for operation {operation_name} (idx {operation_idx})") + 
except Exception as e: + logger.debug(f"Error getting DAG node for operation {operation_name}: {e}") + + def _log_event(self, event_type: EventType, message: str, **kwargs): + """Log an event if event logging is enabled.""" + if self.event_logger is None: + logger.warning(f"Event logger is None, cannot log event: {event_type.value}") + return + + # Automatically capture process and thread IDs + process_id = os.getpid() + thread_id = threading.get_ident() + + # Generate event ID if not provided + event_id = kwargs.pop("event_id", None) + if event_id is None: + timestamp = int(time.time()) + event_id = f"{event_type.value}_{timestamp}_{uuid4().hex[:8]}" + + logger.debug(f"Creating event: {event_type.value} - {message}") + event = Event( + event_type=event_type, + timestamp=time.time(), + message=message, + event_id=event_id, + process_id=process_id, + thread_id=thread_id, + **kwargs, + ) + logger.debug(f"Logging event to event logger: {event_type.value}") + self.event_logger.log_event(event) + logger.debug(f"Successfully logged event: {event_type.value}") + + # Add new logging methods for job, partition, and op events + def log_job_start(self, config, total_partitions): + """Log job start with detailed configuration.""" + # Handle both dataset_path (string) and dataset (dict) configurations + dataset_info = {} + if "dataset_path" in config: + dataset_info["dataset_path"] = config.get("dataset_path") + if "dataset" in config: + dataset_info["dataset"] = config.get("dataset") + + metadata = { + "total_partitions": total_partitions, + "config_summary": { + **dataset_info, + "executor_type": config.get("executor_type"), + "partition_size": config.get("partition_size"), + "checkpoint_strategy": config.get("checkpoint_strategy"), + "storage_format": config.get("storage_format"), + "compression": config.get("compression"), + }, + } + event_id = f"job_start_{int(time.time())}" + self._log_event( + EventType.JOB_START, + "Job started", + event_id=event_id, + config=config, + metadata=metadata, + total_partitions=total_partitions, + ) + + def log_job_complete(self, duration, output_path=None): + """Log job completion with performance metrics.""" + metadata = {"status": "completed", "duration_seconds": duration, "completion_time": time.time()} + if output_path: + metadata["output_path"] = output_path + + event_id = f"job_complete_{int(time.time())}" + self._log_event( + EventType.JOB_COMPLETE, + f"Job completed successfully in {duration:.2f}s", + event_id=event_id, + status="completed", + duration=duration, + metadata=metadata, + ) + self._update_job_summary("completed", error_message=None) + + def log_job_failed(self, error_message, duration): + """Log job failure with error details.""" + metadata = { + "status": "failed", + "duration_seconds": duration, + "failure_time": time.time(), + "error_type": type(error_message).__name__ if error_message else "Unknown", + } + event_id = f"job_failed_{int(time.time())}" + self._log_event( + EventType.JOB_FAILED, + f"Job failed: {error_message}", + event_id=event_id, + status="failed", + error_message=error_message, + duration=duration, + metadata=metadata, + ) + self._update_job_summary("failed", error_message=error_message) + + def log_partition_start(self, partition_id, partition_meta): + """Log partition start with detailed metadata.""" + metadata = { + "partition_path": partition_meta.get("partition_path"), + "start_time": partition_meta.get("start_time"), + "partition_size_bytes": partition_meta.get("file_size_bytes"), + "sample_count": 
partition_meta.get("sample_count"), + } + event_id = f"partition_start_{partition_id}_{int(time.time())}" + self._log_event( + EventType.PARTITION_START, + f"Partition {partition_id} started processing", + event_id=event_id, + partition_id=partition_id, + partition_meta=partition_meta, + metadata=metadata, + ) + + def log_partition_complete(self, partition_id, duration, output_path, success=True, error=None): + """Log partition completion with performance metrics.""" + metadata = { + "output_path": output_path, + "duration_seconds": duration, + "completion_time": time.time(), + "success": success, + "throughput_samples_per_second": None, # Will be calculated if sample_count is available + } + + if not success and error: + metadata["error"] = error + message = f"Partition {partition_id} completed with failure after {duration:.2f}s: {error}" + else: + message = f"Partition {partition_id} completed successfully after {duration:.2f}s" + + # Add debug logging to help diagnose issues + logger.debug(f"Creating partition_complete event for partition {partition_id}") + logger.debug(f" Duration: {duration:.2f}s") + logger.debug(f" Success: {success}") + logger.debug(f" Output path: {output_path}") + if error: + logger.debug(f" Error: {error}") + + # Use the _log_event method to ensure proper logging + event_id = f"partition_complete_{partition_id}_{int(time.time())}" + self._log_event( + EventType.PARTITION_COMPLETE, message, event_id=event_id, partition_id=partition_id, metadata=metadata + ) + + def log_partition_failed(self, partition_id, error_message, retry_count): + """Log partition failure with retry information.""" + metadata = { + "retry_count": retry_count, + "failure_time": time.time(), + "error_type": type(error_message).__name__ if error_message else "Unknown", + } + event_id = f"partition_failed_{partition_id}_{int(time.time())}" + self._log_event( + EventType.PARTITION_FAILED, + f"Partition {partition_id} failed after {retry_count} retries: {error_message}", + event_id=event_id, + partition_id=partition_id, + error_message=error_message, + retry_count=retry_count, + status="failed", + metadata=metadata, + ) + + def log_op_start(self, partition_id, operation_name, operation_idx, op_args, **kwargs): + """Log operation start with detailed arguments.""" + metadata = { + "operation_idx": operation_idx, + "operation_args": op_args, + "start_time": time.time(), + "operation_class": operation_name, + } + # Merge any additional metadata from kwargs + if "metadata" in kwargs: + metadata.update(kwargs["metadata"]) + + # Automatically add DAG context if DAGExecutionMixin is available + self._add_dag_context_to_metadata(metadata, operation_name, operation_idx, partition_id) + + event_id = f"op_start_{partition_id}_{operation_idx}_{int(time.time())}" + self._log_event( + EventType.OP_START, + f"Operation {operation_name} (idx {operation_idx}) started on partition {partition_id}", + event_id=event_id, + partition_id=partition_id, + operation_name=operation_name, + operation_idx=operation_idx, + op_args=op_args, + metadata=metadata, + ) + + def log_op_complete( + self, partition_id, operation_name, operation_idx, duration, checkpoint_path, input_rows, output_rows, **kwargs + ): + """Log operation completion with detailed performance metrics.""" + # Build metadata with only meaningful metrics + metadata = { + "duration_seconds": duration, + "checkpoint_path": checkpoint_path, + "completion_time": time.time(), + "operation_class": operation_name, + } + + # Only include row counts and derived metrics if 
they're meaningful (non-zero or explicitly set) + if input_rows is not None and input_rows > 0: + metadata["input_rows"] = input_rows + if output_rows is not None and output_rows > 0: + metadata["output_rows"] = output_rows + + # Calculate derived metrics only if we have valid row counts + if input_rows and output_rows is not None: + if duration > 0: + metadata["throughput_rows_per_second"] = input_rows / duration + if input_rows > 0: + metadata["reduction_ratio"] = (input_rows - output_rows) / input_rows + + # Merge any additional metadata from kwargs + if "metadata" in kwargs: + metadata.update(kwargs["metadata"]) + + # Automatically add DAG context if DAGExecutionMixin is available + self._add_dag_context_to_metadata(metadata, operation_name, operation_idx, partition_id) + + # Build message without row counts (they're in metadata if meaningful) + event_id = f"op_complete_{partition_id}_{operation_idx}_{int(time.time())}" + self._log_event( + EventType.OP_COMPLETE, + f"Operation {operation_name} (idx {operation_idx}) completed on partition {partition_id} in {duration:.3f}s", + event_id=event_id, + partition_id=partition_id, + operation_name=operation_name, + operation_idx=operation_idx, + status="success", + metadata=metadata, + ) + + def log_op_failed(self, partition_id, operation_name, operation_idx, error_message, retry_count, **kwargs): + """Log operation failure with error details.""" + metadata = { + "retry_count": retry_count, + "failure_time": time.time(), + "error_type": type(error_message).__name__ if error_message else "Unknown", + "operation_class": operation_name, + } + # Merge any additional metadata from kwargs + if "metadata" in kwargs: + metadata.update(kwargs["metadata"]) + + # Automatically add DAG context if DAGExecutionMixin is available + self._add_dag_context_to_metadata(metadata, operation_name, operation_idx, partition_id) + + event_id = f"op_failed_{partition_id}_{operation_idx}_{int(time.time())}" + self._log_event( + EventType.OP_FAILED, + f"Operation {operation_name} (idx {operation_idx}) failed on partition {partition_id}: {error_message}", + event_id=event_id, + partition_id=partition_id, + operation_name=operation_name, + operation_idx=operation_idx, + error_message=error_message, + retry_count=retry_count, + status="failed", + metadata=metadata, + ) + + def log_checkpoint_save(self, partition_id, operation_name, operation_idx, checkpoint_path): + """Log checkpoint save with file information.""" + metadata = { + "checkpoint_path": checkpoint_path, + "operation_idx": operation_idx, + "operation_class": operation_name, + "save_time": time.time(), + } + event_id = f"checkpoint_save_{partition_id}_{operation_idx}_{int(time.time())}" + self._log_event( + EventType.CHECKPOINT_SAVE, + f"Checkpoint saved for operation {operation_name} (idx {operation_idx}) on partition {partition_id}", + event_id=event_id, + partition_id=partition_id, + operation_name=operation_name, + operation_idx=operation_idx, + checkpoint_path=checkpoint_path, + metadata=metadata, + ) + + def log_checkpoint_load(self, partition_id, operation_name, operation_idx, checkpoint_path): + """Log checkpoint load with file information.""" + metadata = { + "checkpoint_path": checkpoint_path, + "operation_idx": operation_idx, + "operation_class": operation_name, + "load_time": time.time(), + } + event_id = f"checkpoint_load_{partition_id}_{operation_idx}_{int(time.time())}" + self._log_event( + EventType.CHECKPOINT_LOAD, + f"Checkpoint loaded for operation {operation_name} (idx {operation_idx}) on 
partition {partition_id}", + event_id=event_id, + partition_id=partition_id, + operation_name=operation_name, + operation_idx=operation_idx, + checkpoint_path=checkpoint_path, + metadata=metadata, + ) + + # DAG-specific event logging methods + def log_dag_build_start(self, ast_info: Dict[str, Any]): + """Log DAG build start with AST information.""" + metadata = { + "ast_node_count": ast_info.get("node_count", 0), + "ast_depth": ast_info.get("depth", 0), + "ast_operation_types": ast_info.get("operation_types", []), + "build_start_time": time.time(), + } + event_id = f"dag_build_start_{int(time.time())}" + self._log_event( + EventType.DAG_BUILD_START, + "DAG build started from pipeline AST", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_build_complete(self, dag_info: Dict[str, Any]): + """Log DAG build completion with execution plan information.""" + metadata = { + "dag_node_count": dag_info.get("node_count", 0), + "dag_edge_count": dag_info.get("edge_count", 0), + "parallel_groups_count": dag_info.get("parallel_groups_count", 0), + "execution_plan_length": dag_info.get("execution_plan_length", 0), + "build_duration": dag_info.get("build_duration", 0), + "build_complete_time": time.time(), + } + event_id = f"dag_build_complete_{int(time.time())}" + self._log_event( + EventType.DAG_BUILD_COMPLETE, + f"DAG build completed: {dag_info.get('node_count', 0)} nodes, {dag_info.get('edge_count', 0)} edges", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_node_ready(self, node_id: str, node_info: Dict[str, Any]): + """Log when a DAG node becomes ready for execution.""" + metadata = { + "node_id": node_id, + "op_name": node_info.get("op_name"), + "op_type": node_info.get("op_type"), + "dependencies_count": node_info.get("dependencies_count", 0), + "dependents_count": node_info.get("dependents_count", 0), + "execution_order": node_info.get("execution_order", -1), + "ready_time": time.time(), + } + event_id = f"dag_node_ready_{node_id}_{int(time.time())}" + self._log_event( + EventType.DAG_NODE_READY, + f"DAG node {node_id} ({node_info.get('op_name')}) ready for execution", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_node_start(self, node_id: str, node_info: Dict[str, Any]): + """Log when a DAG node starts execution.""" + metadata = { + "node_id": node_id, + "op_name": node_info.get("op_name"), + "op_type": node_info.get("op_type"), + "execution_order": node_info.get("execution_order", -1), + "start_time": time.time(), + } + event_id = f"dag_node_start_{node_id}_{int(time.time())}" + self._log_event( + EventType.DAG_NODE_START, + f"DAG node {node_id} ({node_info.get('op_name')}) started execution", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_node_complete(self, node_id: str, node_info: Dict[str, Any], duration: float): + """Log when a DAG node completes execution.""" + metadata = { + "node_id": node_id, + "op_name": node_info.get("op_name"), + "op_type": node_info.get("op_type"), + "execution_order": node_info.get("execution_order", -1), + "duration_seconds": duration, + "completion_time": time.time(), + } + event_id = f"dag_node_complete_{node_id}_{int(time.time())}" + self._log_event( + EventType.DAG_NODE_COMPLETE, + f"DAG node {node_id} ({node_info.get('op_name')}) completed in {duration:.3f}s", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_node_failed(self, node_id: str, node_info: Dict[str, Any], error_message: str, duration: float = 0): + """Log when a DAG node fails execution.""" + metadata = { + "node_id": node_id, 
+ "op_name": node_info.get("op_name"), + "op_type": node_info.get("op_type"), + "execution_order": node_info.get("execution_order", -1), + "duration_seconds": duration, + "error_message": error_message, + "failure_time": time.time(), + } + event_id = f"dag_node_failed_{node_id}_{int(time.time())}" + self._log_event( + EventType.DAG_NODE_FAILED, + f"DAG node {node_id} ({node_info.get('op_name')}) failed: {error_message}", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_parallel_group_start(self, group_id: str, group_info: Dict[str, Any]): + """Log when a parallel group starts execution.""" + metadata = { + "group_id": group_id, + "node_count": group_info.get("node_count", 0), + "node_ids": group_info.get("node_ids", []), + "op_types": group_info.get("op_types", []), + "start_time": time.time(), + } + event_id = f"dag_parallel_group_start_{group_id}_{int(time.time())}" + self._log_event( + EventType.DAG_PARALLEL_GROUP_START, + f"Parallel group {group_id} started with {group_info.get('node_count', 0)} nodes", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_parallel_group_complete(self, group_id: str, group_info: Dict[str, Any], duration: float): + """Log when a parallel group completes execution.""" + metadata = { + "group_id": group_id, + "node_count": group_info.get("node_count", 0), + "completed_nodes": group_info.get("completed_nodes", 0), + "failed_nodes": group_info.get("failed_nodes", 0), + "duration_seconds": duration, + "completion_time": time.time(), + } + event_id = f"dag_parallel_group_complete_{group_id}_{int(time.time())}" + self._log_event( + EventType.DAG_PARALLEL_GROUP_COMPLETE, + f"Parallel group {group_id} completed: {group_info.get('completed_nodes', 0)}/{group_info.get('node_count', 0)} nodes in {duration:.3f}s", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_execution_plan_saved(self, plan_path: str, plan_info: Dict[str, Any]): + """Log when DAG execution plan is saved.""" + metadata = { + "plan_path": plan_path, + "node_count": plan_info.get("node_count", 0), + "edge_count": plan_info.get("edge_count", 0), + "parallel_groups_count": plan_info.get("parallel_groups_count", 0), + "save_time": time.time(), + } + event_id = f"dag_execution_plan_saved_{int(time.time())}" + self._log_event( + EventType.DAG_EXECUTION_PLAN_SAVED, + f"DAG execution plan saved to {plan_path}", + event_id=event_id, + metadata=metadata, + ) + + def log_dag_execution_plan_loaded(self, plan_path: str, plan_info: Dict[str, Any]): + """Log when DAG execution plan is loaded.""" + metadata = { + "plan_path": plan_path, + "node_count": plan_info.get("node_count", 0), + "edge_count": plan_info.get("edge_count", 0), + "parallel_groups_count": plan_info.get("parallel_groups_count", 0), + "load_time": time.time(), + } + event_id = f"dag_execution_plan_loaded_{int(time.time())}" + self._log_event( + EventType.DAG_EXECUTION_PLAN_LOADED, + f"DAG execution plan loaded from {plan_path}", + event_id=event_id, + metadata=metadata, + ) + + def log_job_restart( + self, + restart_reason: str, + original_start_time: float, + resume_partitions: List[int], + resume_from_operation: int, + checkpoint_paths: List[str], + ): + """Log when a job is restarted after interruption.""" + metadata = { + "restart_reason": restart_reason, + "original_start_time": original_start_time, + "restart_time": time.time(), + "resume_partitions": resume_partitions, + "resume_from_operation": resume_from_operation, + "checkpoint_paths": checkpoint_paths, + } + event_id = f"job_restart_{int(time.time())}" 
+ self._log_event( + EventType.JOB_RESTART, + f"Job restarted after {restart_reason} interruption", + event_id=event_id, + metadata=metadata, + ) + + def log_partition_resume(self, partition_id: int, resume_operation: int, checkpoint_path: str, resume_reason: str): + """Log when a partition is resumed from a checkpoint.""" + metadata = { + "resume_operation": resume_operation, + "checkpoint_path": checkpoint_path, + "resume_reason": resume_reason, + } + event_id = f"partition_resume_{partition_id}_{int(time.time())}" + self._log_event( + EventType.PARTITION_RESUME, + f"Partition {partition_id} resumed from operation {resume_operation} checkpoint", + event_id=event_id, + partition_id=partition_id, + metadata=metadata, + ) + + def get_events(self, **kwargs) -> List[Event]: + """Get events with optional filtering.""" + if self.event_logger is None: + return [] + return self.event_logger.get_events(**kwargs) + + def generate_status_report(self) -> str: + """Generate status report.""" + if self.event_logger is None: + return "Event logging is disabled." + return self.event_logger.generate_status_report() + + def monitor_events(self, event_type: Optional[EventType] = None) -> Generator[Event, None, None]: + """Monitor events in real-time.""" + if self.event_logger is None: + return + yield from self.event_logger.monitor_events(event_type) + + def analyze_resumption_state(self, job_id: str) -> Dict[str, Any]: + """ + Analyze event history to determine resumption state and generate resumption plan. + + Args: + job_id: The job ID to analyze + + Returns: + Dictionary containing resumption analysis and plan + """ + if not self.event_logger: + return {"error": "Event logger not available"} + + events_file = self.event_logger.jsonl_file + if not os.path.exists(events_file): + return {"error": f"Events file not found: {events_file}"} + + # Parse all events + events = [] + with open(events_file, "r") as f: + for line in f: + try: + event = json.loads(line.strip()) + events.append(event) + except json.JSONDecodeError: + continue + + # Analyze events by type + partition_starts = [e for e in events if e.get("event_type") == "partition_start"] + partition_completes = [e for e in events if e.get("event_type") == "partition_complete"] + partition_failures = [e for e in events if e.get("event_type") == "partition_failed"] + op_starts = [e for e in events if e.get("event_type") == "op_start"] + op_completes = [e for e in events if e.get("event_type") == "op_complete"] + checkpoints = [e for e in events if e.get("event_type") == "checkpoint_saved"] + + # Determine job status + job_status = self._determine_job_status(events, partition_completes, partition_failures) + + # Analyze partition states + partition_states = self._analyze_partition_states( + partition_starts, partition_completes, partition_failures, op_starts, op_completes + ) + + # Generate resumption plan + resumption_plan = self._generate_resumption_plan(partition_states, checkpoints, job_status) + + # Calculate progress metrics + progress_metrics = self._calculate_progress_metrics(partition_states, events) + + return { + "job_id": job_id, + "job_status": job_status, + "total_events": len(events), + "partition_states": partition_states, + "resumption_plan": resumption_plan, + "progress_metrics": progress_metrics, + "analysis_timestamp": time.time(), + "can_resume": resumption_plan["can_resume"], + "resume_from_checkpoint": resumption_plan.get("resume_from_checkpoint"), + "partitions_to_retry": resumption_plan.get("partitions_to_retry", []), + 
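+            # partitions_to_skip: partitions whose latest completion event succeeded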
"partitions_to_skip": resumption_plan.get("partitions_to_skip", []), + } + + def _determine_job_status( + self, events: List[Dict], partition_completes: List[Dict], partition_failures: List[Dict] + ) -> str: + """Determine the current job status based on events.""" + # Check if job has any completion events + job_completes = [e for e in events if e.get("event_type") == "job_complete"] + job_failures = [e for e in events if e.get("event_type") == "job_failed"] + + if job_completes: + return "completed" + elif job_failures: + return "failed" + elif partition_completes: + # Check if all partitions are completed (success or failure) + all_partitions_completed = all( + pc.get("metadata", {}).get("success", False) or pc.get("metadata", {}).get("error") is not None + for pc in partition_completes + ) + if all_partitions_completed: + return "completed_with_failures" + else: + return "running" + else: + return "not_started" + + def _analyze_partition_states( + self, + partition_starts: List[Dict], + partition_completes: List[Dict], + partition_failures: List[Dict], + op_starts: List[Dict], + op_completes: List[Dict], + ) -> Dict[int, Dict]: + """Analyze the state of each partition based on events.""" + partition_states = {} + + # Group events by partition ID + for start_event in partition_starts: + partition_id = start_event.get("partition_id") + if partition_id is None: + continue + + # Find the latest start event for this partition + partition_starts_for_id = [e for e in partition_starts if e.get("partition_id") == partition_id] + latest_start = max(partition_starts_for_id, key=lambda x: x.get("timestamp", 0)) + + # Find completion events for this partition + partition_completes_for_id = [e for e in partition_completes if e.get("partition_id") == partition_id] + partition_failures_for_id = [e for e in partition_failures if e.get("partition_id") == partition_id] + + # Find operation events for this partition + ops_for_partition = [e for e in op_starts if e.get("partition_id") == partition_id] + op_completes_for_partition = [e for e in op_completes if e.get("partition_id") == partition_id] + + # Determine partition state + state = self._determine_partition_state( + partition_id, + latest_start, + partition_completes_for_id, + partition_failures_for_id, + ops_for_partition, + op_completes_for_partition, + ) + + partition_states[partition_id] = state + + return partition_states + + def _determine_partition_state( + self, + partition_id: int, + start_event: Dict, + completes: List[Dict], + failures: List[Dict], + op_starts: List[Dict], + op_completes: List[Dict], + ) -> Dict: + """Determine the detailed state of a specific partition.""" + # Find the latest completion event + latest_complete = max(completes, key=lambda x: x.get("timestamp", 0)) if completes else None + + # Determine if partition is completed successfully + is_completed = latest_complete and latest_complete.get("metadata", {}).get("success", False) + is_failed = latest_complete and not latest_complete.get("metadata", {}).get("success", False) + + # Find the last operation that was started + last_op_start = max(op_starts, key=lambda x: x.get("timestamp", 0)) if op_starts else None + last_op_complete = max(op_completes, key=lambda x: x.get("timestamp", 0)) if op_completes else None + + # Determine current operation + current_operation = None + if last_op_start: + current_operation = { + "name": last_op_start.get("operation_name"), + "idx": last_op_start.get("operation_idx"), + "started_at": last_op_start.get("timestamp"), + "completed": 
last_op_complete is not None + and last_op_complete.get("timestamp", 0) > last_op_start.get("timestamp", 0), + } + + return { + "partition_id": partition_id, + "status": "completed" if is_completed else "failed" if is_failed else "running", + "start_time": start_event.get("timestamp"), + "completion_time": latest_complete.get("timestamp") if latest_complete else None, + "duration": latest_complete.get("metadata", {}).get("duration_seconds") if latest_complete else None, + "success": is_completed, + "error": latest_complete.get("metadata", {}).get("error") if latest_complete and not is_completed else None, + "current_operation": current_operation, + "retry_count": len([f for f in failures if f.get("partition_id") == partition_id]), + "output_path": latest_complete.get("metadata", {}).get("output_path") if latest_complete else None, + } + + def _generate_resumption_plan( + self, partition_states: Dict[int, Dict], checkpoints: List[Dict], job_status: str + ) -> Dict: + """Generate a resumption plan based on partition states and checkpoints.""" + # Find partitions that need to be retried + partitions_to_retry = [] + partitions_to_skip = [] + + for partition_id, state in partition_states.items(): + if state["status"] == "failed": + partitions_to_retry.append(partition_id) + elif state["status"] == "completed": + partitions_to_skip.append(partition_id) + + # Find the latest checkpoint + latest_checkpoint = max(checkpoints, key=lambda x: x.get("timestamp", 0)) if checkpoints else None + + # Determine if we can resume based on job status and partition states + if job_status == "completed": + can_resume = False + reason = "Job already completed successfully" + elif job_status == "failed": + can_resume = True + reason = "Job failed, can resume from checkpoint or retry failed partitions" + elif len(partitions_to_retry) > 0: + can_resume = True + reason = f"Found {len(partitions_to_retry)} failed partitions to retry" + elif latest_checkpoint is not None: + can_resume = True + reason = "Found checkpoint to resume from" + else: + can_resume = False + reason = "No failed partitions or checkpoints found" + + return { + "can_resume": can_resume, + "reason": reason, + "resume_from_checkpoint": ( + latest_checkpoint.get("metadata", {}).get("checkpoint_path") if latest_checkpoint else None + ), + "partitions_to_retry": partitions_to_retry, + "partitions_to_skip": partitions_to_skip, + "total_partitions_to_process": len(partitions_to_retry), + "estimated_remaining_work": len(partitions_to_retry) / len(partition_states) if partition_states else 0, + } + + def _calculate_progress_metrics(self, partition_states: Dict[int, Dict], events: List[Dict]) -> Dict: + """Calculate progress metrics based on partition states.""" + total_partitions = len(partition_states) + completed_partitions = len([s for s in partition_states.values() if s["status"] == "completed"]) + failed_partitions = len([s for s in partition_states.values() if s["status"] == "failed"]) + running_partitions = len([s for s in partition_states.values() if s["status"] == "running"]) + + # Calculate overall progress + if total_partitions == 0: + progress_percentage = 0 + else: + progress_percentage = (completed_partitions / total_partitions) * 100 + + # Calculate timing metrics + job_start_events = [e for e in events if e.get("event_type") == "job_start"] + start_time = job_start_events[0].get("timestamp") if job_start_events else None + current_time = time.time() + elapsed_time = current_time - start_time if start_time else 0 + + return { + 
"total_partitions": total_partitions, + "completed_partitions": completed_partitions, + "failed_partitions": failed_partitions, + "running_partitions": running_partitions, + "progress_percentage": progress_percentage, + "elapsed_time_seconds": elapsed_time, + "start_time": start_time, + "current_time": current_time, + } diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py index 0f89a19723..d507b0efb1 100644 --- a/data_juicer/core/executor/factory.py +++ b/data_juicer/core/executor/factory.py @@ -1,19 +1,25 @@ +from .base import ExecutorBase +from .default_executor import DefaultExecutor + + class ExecutorFactory: @staticmethod - def create_executor(executor_type: str): + def create_executor(executor_type: str) -> ExecutorBase: if executor_type in ("local", "default"): - from .default_executor import DefaultExecutor - return DefaultExecutor elif executor_type == "ray": from .ray_executor import RayExecutor return RayExecutor + elif executor_type == "ray_partitioned": + from .ray_executor_partitioned import PartitionedRayExecutor + + return PartitionedRayExecutor # TODO: add nemo support # elif executor_type == "nemo": - # return NemoExecutor() + # return NemoExecutor # TODO: add dask support # elif executor_type == "dask": - # return DaskExecutor() + # return DaskExecutor else: raise ValueError("Unsupported executor type") diff --git a/data_juicer/core/executor/partition_size_optimizer.py b/data_juicer/core/executor/partition_size_optimizer.py new file mode 100644 index 0000000000..756701e69e --- /dev/null +++ b/data_juicer/core/executor/partition_size_optimizer.py @@ -0,0 +1,806 @@ +""" +Partition Size Optimizer for DataJuicer + +This module automatically configures optimal partition sizes based on: +1. Data modality (text, image, audio, video, multimodal) +2. Dataset characteristics (file sizes, complexity) +3. Available system resources (CPU, memory, GPU) +4. Processing pipeline complexity +5. 
Ray cluster configuration +""" + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import psutil +import ray +from loguru import logger + + +class ModalityType(Enum): + """Supported data modalities.""" + + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + MULTIMODAL = "multimodal" + + +@dataclass +class LocalResources: + """Local system resources.""" + + cpu_cores: int + available_memory_gb: float + total_memory_gb: float + gpu_count: int + gpu_memory_gb: Optional[float] = None + disk_space_gb: Optional[float] = None + + +@dataclass +class ClusterResources: + """Ray cluster resources.""" + + num_nodes: int + total_cpu_cores: int + total_memory_gb: float + available_cpu_cores: int + available_memory_gb: float + gpu_resources: Dict[str, int] + + +@dataclass +class DataCharacteristics: + """Data characteristics from sampling.""" + + primary_modality: ModalityType + modality_distribution: Dict[ModalityType, int] + avg_text_length: float + avg_images_per_sample: float + avg_audio_per_sample: float + avg_video_per_sample: float + total_samples: int + sample_size_analyzed: int + memory_per_sample_mb: float + processing_complexity_score: float + data_skew_factor: float # 0-1, higher means more variance + + +@dataclass +class ModalityConfig: + """Configuration for a specific modality.""" + + modality: ModalityType + default_partition_size: int + max_partition_size: int + max_partition_size_mb: int + memory_multiplier: float # Memory usage multiplier compared to text + complexity_multiplier: float # Processing complexity multiplier + description: str + + +class ResourceDetector: + """Detect available system and cluster resources.""" + + @staticmethod + def detect_local_resources() -> LocalResources: + """Detect local system resources.""" + # CPU + cpu_cores = psutil.cpu_count(logical=True) + + # Memory + memory = psutil.virtual_memory() + available_memory_gb = memory.available / (1024**3) + total_memory_gb = memory.total / (1024**3) + + # GPU (basic detection) + gpu_count = 0 + gpu_memory_gb = None + try: + import torch + + if torch.cuda.is_available(): + gpu_count = torch.cuda.device_count() + if gpu_count > 0: + gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) + except ImportError: + pass + + # Disk space + disk_space_gb = None + try: + disk_usage = psutil.disk_usage("/") + disk_space_gb = disk_usage.free / (1024**3) + except Exception as e: + logger.warning(f"Could not detect disk space: {e}") + pass + + return LocalResources( + cpu_cores=cpu_cores, + available_memory_gb=available_memory_gb, + total_memory_gb=total_memory_gb, + gpu_count=gpu_count, + gpu_memory_gb=gpu_memory_gb, + disk_space_gb=disk_space_gb, + ) + + @staticmethod + def detect_ray_cluster() -> Optional[ClusterResources]: + """Detect Ray cluster resources.""" + try: + if not ray.is_initialized(): + return None + + # Get cluster resources + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() + + # Parse resources + total_cpu = cluster_resources.get("CPU", 0) + total_memory = cluster_resources.get("memory", 0) / (1024**3) # Convert to GB + available_cpu = available_resources.get("CPU", 0) + available_memory = available_resources.get("memory", 0) / (1024**3) + + # Count nodes (approximate) + num_nodes = max(1, int(total_cpu / 8)) # Assume 8 cores per node + + # GPU resources + gpu_resources = {} + for key, value in cluster_resources.items(): + if key.startswith("GPU"): + 
gpu_resources[key] = value + + return ClusterResources( + num_nodes=num_nodes, + total_cpu_cores=int(total_cpu), + total_memory_gb=total_memory, + available_cpu_cores=int(available_cpu), + available_memory_gb=available_memory, + gpu_resources=gpu_resources, + ) + except Exception as e: + logger.warning(f"Could not detect Ray cluster resources: {e}") + return None + + @staticmethod + def calculate_optimal_worker_count( + local_resources: LocalResources, + cluster_resources: Optional[ClusterResources] = None, + partition_size: int = None, + total_samples: int = None, + ) -> int: + """ + Calculate optimal number of Ray workers based on available resources. + + Args: + local_resources: Local system resources + cluster_resources: Ray cluster resources (optional) + partition_size: Size of each partition (for workload estimation) + total_samples: Total number of samples (for workload estimation) + + Returns: + Optimal number of workers + """ + # Determine available CPU cores + if cluster_resources: + available_cores = min(local_resources.cpu_cores, cluster_resources.available_cpu_cores) + else: + available_cores = local_resources.cpu_cores + + # Base calculation: use 75% of available cores to leave room for system processes + base_workers = max(1, int(available_cores * 0.75)) + + # Adjust based on workload characteristics + if partition_size and total_samples: + estimated_partitions = total_samples / partition_size + + # We want enough workers to process partitions efficiently + # But not so many that we have too much overhead + if estimated_partitions < base_workers: + # Few partitions - reduce workers to avoid overhead + optimal_workers = max(1, int(estimated_partitions * 0.8)) + elif estimated_partitions > base_workers * 2: + # Many partitions - can use more workers + optimal_workers = min(available_cores, int(base_workers * 1.2)) + else: + # Balanced workload - use base calculation + optimal_workers = base_workers + else: + # No workload info - use base calculation + optimal_workers = base_workers + + # Ensure we don't exceed available cores + optimal_workers = min(optimal_workers, available_cores) + + # Minimum of 1 worker, maximum reasonable limit + optimal_workers = max(1, min(optimal_workers, 32)) # Cap at 32 workers + + logger.info(f"Worker count calculation:") + logger.info(f" Available CPU cores: {available_cores}") + logger.info(f" Base workers (75% of cores): {base_workers}") + if partition_size and total_samples: + logger.info(f" Estimated partitions: {total_samples / partition_size:.1f}") + logger.info(f" Optimal workers: {optimal_workers}") + + return optimal_workers + + +class PartitionSizeOptimizer: + """Automatically optimizes partition sizes based on data characteristics and available resources.""" + + def calculate_target_partition_mb(self, available_memory_gb: float) -> int: + """Calculate target partition size in MB based on available memory and config. + + Uses config.partition.target_size_mb if available, otherwise falls back to + dynamic sizing based on available memory (32MB - 256MB). 
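To make the worker sizing above concrete, here is an illustrative trace of calculate_optimal_worker_count. This is a sketch written against this patch, not part of it; the 16-core machine, 100k-sample dataset, and 5,000-sample partitions are invented inputs.

from data_juicer.core.executor.partition_size_optimizer import (
    LocalResources,
    ResourceDetector,
)

# Hypothetical box: 16 logical cores, 48 GB free of 64 GB, no GPU.
local = LocalResources(cpu_cores=16, available_memory_gb=48.0, total_memory_gb=64.0, gpu_count=0)

workers = ResourceDetector.calculate_optimal_worker_count(
    local, cluster_resources=None, partition_size=5000, total_samples=100_000
)
# base_workers = int(16 * 0.75) = 12; 100_000 / 5_000 = 20 estimated partitions,
# which is neither below 12 nor above 24, so the balanced branch keeps 12 workers.
assert workers == 12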
+ """ + # Use configured target if available + if hasattr(self.cfg, "partition") and hasattr(self.cfg.partition, "target_size_mb"): + configured_size = self.cfg.partition.target_size_mb + logger.info(f"Using configured target partition size: {configured_size} MB") + return configured_size + + # Fall back to dynamic calculation based on available memory + if available_memory_gb < 16: + return 32 + elif available_memory_gb < 64: + return 64 + elif available_memory_gb < 256: + return 128 + else: + return 256 + + # Default configurations for different modalities + MODALITY_CONFIGS = { + ModalityType.TEXT: ModalityConfig( + modality=ModalityType.TEXT, + default_partition_size=10000, # Increased for 256MB target + max_partition_size=50000, # Increased for larger partitions + max_partition_size_mb=256, # Default 256MB per partition (configurable) + memory_multiplier=1.0, + complexity_multiplier=1.0, + description="Text data - efficient processing, low memory usage, target 256MB partitions (configurable)", + ), + ModalityType.IMAGE: ModalityConfig( + modality=ModalityType.IMAGE, + default_partition_size=2000, # Increased for 256MB target + max_partition_size=10000, # Increased for larger partitions + max_partition_size_mb=256, # Default 256MB per partition (configurable) + memory_multiplier=5.0, + complexity_multiplier=3.0, + description="Image data - moderate memory usage, target 256MB partitions (configurable)", + ), + ModalityType.AUDIO: ModalityConfig( + modality=ModalityType.AUDIO, + default_partition_size=1000, # Increased for 256MB target + max_partition_size=4000, # Increased for larger partitions + max_partition_size_mb=256, # Default 256MB per partition (configurable) + memory_multiplier=8.0, + complexity_multiplier=5.0, + description="Audio data - high memory usage, target 256MB partitions (configurable)", + ), + ModalityType.VIDEO: ModalityConfig( + modality=ModalityType.VIDEO, + default_partition_size=400, # Increased for 256MB target + max_partition_size=2000, # Increased for larger partitions + max_partition_size_mb=256, # Default 256MB per partition (configurable) + memory_multiplier=20.0, + complexity_multiplier=15.0, + description="Video data - very high memory usage, target 256MB partitions (configurable)", + ), + ModalityType.MULTIMODAL: ModalityConfig( + modality=ModalityType.MULTIMODAL, + default_partition_size=1600, # Increased for 256MB target + max_partition_size=6000, # Increased for larger partitions + max_partition_size_mb=256, # Default 256MB per partition (configurable) + memory_multiplier=10.0, + complexity_multiplier=8.0, + description="Multimodal data - combination of multiple modalities, target 256MB partitions (configurable)", + ), + } + + def __init__(self, cfg): + """Initialize the optimizer with configuration.""" + self.cfg = cfg + self.text_key = getattr(cfg, "text_key", "text") + self.image_key = getattr(cfg, "image_key", "images") + self.audio_key = getattr(cfg, "audio_key", "audios") + self.video_key = getattr(cfg, "video_key", "videos") + self.resource_detector = ResourceDetector() + + def detect_modality(self, sample: Dict) -> ModalityType: + """Detect the primary modality of a sample.""" + modalities = [] + + # Check for text + if self.text_key in sample and sample[self.text_key]: + modalities.append(ModalityType.TEXT) + + # Check for images + if sample.get(self.image_key): + modalities.append(ModalityType.IMAGE) + + # Check for audio + if sample.get(self.audio_key): + modalities.append(ModalityType.AUDIO) + + # Check for video + if 
sample.get(self.video_key): + modalities.append(ModalityType.VIDEO) + + # Determine primary modality + if len(modalities) > 1: + return ModalityType.MULTIMODAL + elif len(modalities) == 1: + return modalities[0] + else: + # Default to text if no modality detected + return ModalityType.TEXT + + def analyze_dataset_characteristics(self, dataset) -> DataCharacteristics: + """Analyze dataset characteristics to inform partition sizing.""" + logger.info("Analyzing dataset characteristics for partition optimization...") + + # Get dataset size + try: + if hasattr(dataset, "count"): + total_samples = dataset.count() + elif hasattr(dataset, "__len__"): + total_samples = len(dataset) + else: + total_samples = 1000 + logger.warning("Could not determine dataset size, using estimate of 1000 samples") + except Exception as e: + logger.warning(f"Could not determine dataset size: {e}, using estimate of 1000 samples") + total_samples = 1000 + + # Adaptive sampling: minimum 0.1% for large datasets + if total_samples < 1000: + sample_size = total_samples + elif total_samples < 100000: + sample_size = min(1000, total_samples // 100) # 1% + else: + sample_size = min(10000, total_samples // 1000) # 0.1%, cap at 10k + + try: + # Sample dataset for analysis + if hasattr(dataset, "get"): + # RayDataset with get() method + samples = dataset.get(sample_size) + logger.info(f"Successfully sampled {len(samples)} samples using get()") + elif hasattr(dataset, "take"): + # Datasets with take() method + samples = list(dataset.take(sample_size)) + logger.info(f"Successfully sampled {len(samples)} samples using take()") + elif hasattr(dataset, "__getitem__"): + # Handle list-like datasets + samples = list(dataset[:sample_size]) + logger.info(f"Successfully sampled {len(samples)} samples from list-like dataset") + else: + # Fallback: try to iterate + samples = [] + for i, sample in enumerate(dataset): + if i >= sample_size: + break + samples.append(sample) + logger.info(f"Successfully sampled {len(samples)} samples by iteration") + except Exception as e: + logger.warning(f"Could not sample dataset: {e}, using default analysis") + import traceback + + logger.debug(f"Sampling error traceback: {traceback.format_exc()}") + return DataCharacteristics( + primary_modality=ModalityType.TEXT, + modality_distribution={ModalityType.TEXT: 1}, + avg_text_length=500, + avg_images_per_sample=0, + avg_audio_per_sample=0, + avg_video_per_sample=0, + total_samples=total_samples, + sample_size_analyzed=0, + memory_per_sample_mb=0.002, + processing_complexity_score=1.0, + data_skew_factor=0.5, + ) + + # Analyze samples + modality_counts = {modality: 0 for modality in ModalityType} + text_lengths = [] + image_counts = [] + audio_counts = [] + video_counts = [] + sample_sizes = [] + + for sample in samples: + # Detect modality + modality = self.detect_modality(sample) + modality_counts[modality] += 1 + + # Analyze text + text_length = 0 + if self.text_key in sample and sample[self.text_key]: + if isinstance(sample[self.text_key], str): + text_length = len(sample[self.text_key]) + elif isinstance(sample[self.text_key], list): + text_length = sum(len(t) for t in sample[self.text_key]) + text_lengths.append(text_length) + + # Count media files + image_count = len(sample.get(self.image_key, [])) + audio_count = len(sample.get(self.audio_key, [])) + video_count = len(sample.get(self.video_key, [])) + + image_counts.append(image_count) + audio_counts.append(audio_count) + video_counts.append(video_count) + + # Estimate sample size in MB + sample_size_mb 
= self.estimate_sample_size_mb(sample) + sample_sizes.append(sample_size_mb) + + # Calculate statistics + avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0 + avg_images_per_sample = sum(image_counts) / len(image_counts) if image_counts else 0 + avg_audio_per_sample = sum(audio_counts) / len(audio_counts) if audio_counts else 0 + avg_video_per_sample = sum(video_counts) / len(video_counts) if video_counts else 0 + + # Calculate percentile-based memory estimates (p90 is more robust than mean) + if sample_sizes and len(sample_sizes) > 1: + sorted_sizes = sorted(sample_sizes) + p90_idx = int(len(sorted_sizes) * 0.9) + p90_memory = sorted_sizes[p90_idx] + mean_size = sum(sample_sizes) / len(sample_sizes) + variance = sum((x - mean_size) ** 2 for x in sample_sizes) / (len(sample_sizes) - 1) + std_dev = variance**0.5 + data_skew_factor = min(1.0, std_dev / mean_size if mean_size > 0 else 0) + # Use p90 for conservative sizing + avg_memory_per_sample_mb = p90_memory + else: + avg_memory_per_sample_mb = sample_sizes[0] if sample_sizes else 0.002 + data_skew_factor = 0.5 + + # Determine primary modality + primary_modality = max(modality_counts.items(), key=lambda x: x[1])[0] + + characteristics = DataCharacteristics( + primary_modality=primary_modality, + modality_distribution=modality_counts, + avg_text_length=avg_text_length, + avg_images_per_sample=avg_images_per_sample, + avg_audio_per_sample=avg_audio_per_sample, + avg_video_per_sample=avg_video_per_sample, + total_samples=total_samples, + sample_size_analyzed=len(samples), + memory_per_sample_mb=avg_memory_per_sample_mb, + processing_complexity_score=1.0, # Will be calculated later + data_skew_factor=data_skew_factor, + ) + + logger.info(f"Dataset analysis complete:") + logger.info(f" Primary modality: {primary_modality.value}") + logger.info(f" Modality distribution: {modality_counts}") + logger.info(f" Avg text length: {avg_text_length:.0f} chars") + logger.info(f" Avg images per sample: {avg_images_per_sample:.1f}") + logger.info(f" Avg audio per sample: {avg_audio_per_sample:.1f}") + logger.info(f" Avg video per sample: {avg_video_per_sample:.1f}") + logger.info(f" Avg memory per sample: {avg_memory_per_sample_mb:.3f} MB") + logger.info(f" Data skew factor: {data_skew_factor:.2f}") + + return characteristics + + def estimate_sample_size_mb(self, sample: Dict) -> float: + """Measure actual memory size of a sample in MB using sys.getsizeof.""" + import sys + + return sys.getsizeof(sample) / (1024 * 1024) + + def analyze_processing_complexity(self, process_pipeline: List) -> float: + """Analyze the complexity of the processing pipeline using linear scoring.""" + COMPLEXITY_WEIGHTS = { + "high": 0.3, # embedding, model, neural + "medium": 0.2, # filter, deduplicator + "low": 0.1, # text cleaning + } + + # Count operations by complexity level + high_ops = medium_ops = low_ops = 0 + for op in process_pipeline: + if isinstance(op, dict): + op_name = list(op.keys())[0].lower() + if any(kw in op_name for kw in ["embedding", "similarity", "model", "neural", "vision", "audio"]): + high_ops += 1 + elif any(kw in op_name for kw in ["filter", "deduplicator", "mapper"]): + medium_ops += 1 + else: + low_ops += 1 + + # Linear complexity scoring + complexity_score = 1.0 + ( + high_ops * COMPLEXITY_WEIGHTS["high"] + + medium_ops * COMPLEXITY_WEIGHTS["medium"] + + low_ops * COMPLEXITY_WEIGHTS["low"] + ) + + logger.info(f"Processing complexity: {high_ops} high, {medium_ops} med, {low_ops} low = {complexity_score:.2f}") + return 
complexity_score + + def get_optimal_partition_size(self, dataset, process_pipeline: List) -> Tuple[int, int]: + """Get optimal partition size and max size based on data characteristics and available resources.""" + + # Analyze dataset + characteristics = self.analyze_dataset_characteristics(dataset) + + # Analyze processing complexity + complexity_multiplier = self.analyze_processing_complexity(process_pipeline) + characteristics.processing_complexity_score = complexity_multiplier + + # Detect available resources + local_resources = self.resource_detector.detect_local_resources() + cluster_resources = self.resource_detector.detect_ray_cluster() + + logger.info(f"Resource analysis:") + logger.info(f" Local CPU cores: {local_resources.cpu_cores}") + logger.info(f" Local available memory: {local_resources.available_memory_gb:.1f} GB") + if cluster_resources: + logger.info(f" Cluster CPU cores: {cluster_resources.total_cpu_cores}") + logger.info(f" Cluster available memory: {cluster_resources.available_memory_gb:.1f} GB") + + # Calculate optimal partition size + optimal_size = self.calculate_resource_aware_partition_size( + characteristics, local_resources, cluster_resources, complexity_multiplier + ) + + # Calculate optimal max size in MB + optimal_max_size_mb = self.calculate_optimal_max_size_mb( + characteristics, local_resources, cluster_resources, complexity_multiplier + ) + + logger.info(f"Optimal partition configuration:") + logger.info(f" Size: {optimal_size} samples") + logger.info(f" Max size: {optimal_max_size_mb} MB") + logger.info(f" Based on: {characteristics.primary_modality.value} modality") + logger.info(f" Complexity multiplier: {complexity_multiplier:.2f}") + logger.info(f" Data skew factor: {characteristics.data_skew_factor:.2f}") + + return optimal_size, optimal_max_size_mb + + def calculate_resource_aware_partition_size( + self, + characteristics: DataCharacteristics, + local_resources: LocalResources, + cluster_resources: Optional[ClusterResources], + complexity_multiplier: float, + ) -> int: + """ + Calculate partition size based on data characteristics and available resources. + + Primary goal: Target 64MB per partition for optimal memory usage. + Secondary goals: Ensure sufficient parallelism and respect resource constraints. 
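For illustration, a hypothetical four-op pipeline scored with the weights above (the op names are made up and only chosen to hit the keyword buckets):

pipeline = [
    {"image_embedding_mapper": {}},   # contains "embedding"     -> high   (+0.3)
    {"language_score_filter": {}},    # contains "filter"        -> medium (+0.2)
    {"document_deduplicator": {}},    # contains "deduplicator"  -> medium (+0.2)
    {"strip_whitespace": {}},         # no keyword hit           -> low    (+0.1)
]
# complexity_score = 1.0 + (0.3 + 0.2 + 0.2 + 0.1) = 1.8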
+ """ + + # Get base configuration for the modality + base_config = self.MODALITY_CONFIGS[characteristics.primary_modality] + + # Step 1: Calculate dynamic target based on available memory + available_memory_gb = self._get_available_memory(local_resources, cluster_resources) + target_memory_mb = self.calculate_target_partition_mb(available_memory_gb) + + if characteristics.primary_modality == ModalityType.TEXT: + target_size = self.calculate_text_partition_size_simple( + characteristics.avg_text_length, complexity_multiplier, target_memory_mb + ) + else: + # For media, use memory-per-sample to calculate target + if characteristics.memory_per_sample_mb > 0: + target_size = int(target_memory_mb / (characteristics.memory_per_sample_mb * complexity_multiplier)) + else: + target_size = base_config.default_partition_size + target_size = max(10, min(target_size, base_config.max_partition_size)) + + # Step 2: Check if this fits in available memory + max_partition_memory_mb = (available_memory_gb * 1024 * 0.8) / 4 # Allow 4 concurrent partitions + + if target_size * characteristics.memory_per_sample_mb * 2 > max_partition_memory_mb: + # Doesn't fit - scale down + safe_size = int(max_partition_memory_mb / (characteristics.memory_per_sample_mb * 2)) + logger.warning(f"Memory constraint: reducing partition size from {target_size} to {safe_size} samples") + target_size = max(10, safe_size) + + # Step 3: Ensure sufficient parallelism for large datasets + min_partitions_needed = self._calculate_min_partitions( + characteristics.total_samples, local_resources, cluster_resources + ) + + if characteristics.total_samples / target_size < min_partitions_needed: + # Too few partitions - reduce size for better parallelism + parallelism_size = int(characteristics.total_samples / min_partitions_needed) + logger.info( + f"Parallelism optimization: reducing partition size from {target_size} to {parallelism_size} " + f"to create {min_partitions_needed} partitions" + ) + target_size = max(10, parallelism_size) + + # Step 4: Adjust for data skew + if characteristics.data_skew_factor > 0.7: + # High variance - use smaller partitions for better load balancing + skew_adjusted_size = int(target_size * 0.8) + logger.info(f"Data skew adjustment: reducing partition size from {target_size} to {skew_adjusted_size}") + target_size = skew_adjusted_size + + # Step 5: Apply final bounds + final_size = max(10, min(target_size, base_config.max_partition_size)) + + logger.info(f"Final partition size: {final_size} samples") + logger.info(f" Estimated memory per partition: {final_size * characteristics.memory_per_sample_mb:.1f} MB") + logger.info(f" Estimated total partitions: {characteristics.total_samples / final_size:.0f}") + + return final_size + + def _get_available_memory( + self, local_resources: LocalResources, cluster_resources: Optional[ClusterResources] + ) -> float: + """Get available memory in GB.""" + if cluster_resources: + return min(local_resources.available_memory_gb, cluster_resources.available_memory_gb) + return local_resources.available_memory_gb + + def _calculate_min_partitions( + self, + total_samples: int, + local_resources: LocalResources, + cluster_resources: Optional[ClusterResources], + ) -> int: + """Calculate minimum number of partitions needed for good parallelism.""" + # Only enforce minimum partitions for large datasets (>10k samples) + if total_samples <= 10000: + return 1 # Small datasets - prioritize 64MB target over parallelism + + # For large datasets, aim for at least 1.5x CPU cores in partitions 
+ available_cores = local_resources.cpu_cores + if cluster_resources: + available_cores = min(available_cores, cluster_resources.available_cpu_cores) + + return max(1, int(available_cores * 1.5)) + + def calculate_text_partition_size_simple( + self, avg_text_length: float, complexity_score: float, target_memory_mb: float + ) -> int: + """Calculate text partition size targeting specified memory size.""" + # Estimate bytes per sample (conservative: 2 bytes per char + overhead) + bytes_per_sample = avg_text_length * 2.0 + mb_per_sample = bytes_per_sample / (1024 * 1024) + + # Calculate samples for target, adjusted for complexity + if mb_per_sample > 0: + target_samples = int(target_memory_mb / (mb_per_sample * complexity_score)) + else: + target_samples = 5000 + + # Apply reasonable bounds + target_samples = max(1000, min(target_samples, 20000)) + + logger.info(f"Text partition calculation:") + logger.info(f" Target: {target_memory_mb}MB, Avg text: {avg_text_length:.0f} chars") + logger.info(f" Estimated: {mb_per_sample:.3f} MB/sample") + logger.info(f" Result: {target_samples} samples (~{target_samples * mb_per_sample:.1f} MB)") + + return target_samples + + def calculate_optimal_max_size_mb( + self, + characteristics: DataCharacteristics, + local_resources: LocalResources, + cluster_resources: Optional[ClusterResources], + complexity_multiplier: float, + ) -> int: + """Calculate optimal max partition size in MB based on available memory.""" + # Calculate dynamic target based on available memory + available_memory_gb = local_resources.available_memory_gb + if cluster_resources: + available_memory_gb = min(available_memory_gb, cluster_resources.available_memory_gb) + + target_max_size_mb = self.calculate_target_partition_mb(available_memory_gb) + + # Adjust for processing complexity + complexity_adjusted_size = int(target_max_size_mb / complexity_multiplier) + + # Don't exceed 25% of available memory per partition + max_size_by_memory = int(available_memory_gb * 1024 * 0.25) + + # Apply bounds + optimal_max_size_mb = min(complexity_adjusted_size, max_size_by_memory) + optimal_max_size_mb = max(32, optimal_max_size_mb) + optimal_max_size_mb = min(512, optimal_max_size_mb) # Increased max from 128MB + + logger.info(f"Max partition size calculation:") + logger.info(f" Target size: {target_max_size_mb} MB (dynamic based on {available_memory_gb:.1f} GB)") + logger.info(f" Complexity adjusted: {complexity_adjusted_size} MB") + logger.info(f" Max by memory (25%): {max_size_by_memory} MB") + logger.info(f" Optimal max size: {optimal_max_size_mb} MB") + + return optimal_max_size_mb + + def get_partition_recommendations(self, dataset, process_pipeline: List) -> Dict: + """Get comprehensive partition recommendations.""" + optimal_size, optimal_max_size_mb = self.get_optimal_partition_size(dataset, process_pipeline) + characteristics = self.analyze_dataset_characteristics(dataset) + + # Detect resources + local_resources = self.resource_detector.detect_local_resources() + cluster_resources = self.resource_detector.detect_ray_cluster() + + # Calculate optimal worker count + optimal_workers = self.resource_detector.calculate_optimal_worker_count( + local_resources, cluster_resources, optimal_size, characteristics.total_samples + ) + + recommendations = { + "recommended_partition_size": optimal_size, + "recommended_max_size_mb": optimal_max_size_mb, + "recommended_worker_count": optimal_workers, + "primary_modality": characteristics.primary_modality.value, + "data_characteristics": { + "avg_text_length": 
characteristics.avg_text_length, + "avg_images_per_sample": characteristics.avg_images_per_sample, + "avg_audio_per_sample": characteristics.avg_audio_per_sample, + "avg_video_per_sample": characteristics.avg_video_per_sample, + "memory_per_sample_mb": characteristics.memory_per_sample_mb, + "data_skew_factor": characteristics.data_skew_factor, + "total_samples": characteristics.total_samples, + }, + "resource_analysis": { + "local_cpu_cores": local_resources.cpu_cores, + "local_available_memory_gb": local_resources.available_memory_gb, + "cluster_available_cpu_cores": cluster_resources.available_cpu_cores if cluster_resources else None, + "cluster_available_memory_gb": cluster_resources.available_memory_gb if cluster_resources else None, + }, + "reasoning": { + "modality": f"Based on {characteristics.primary_modality.value} modality", + "complexity": f"Processing complexity factor: {characteristics.processing_complexity_score:.2f}", + "dataset_size": f"Dataset size: {characteristics.total_samples} samples", + "text_length": f"Average text length: {characteristics.avg_text_length:.0f} characters", + "data_skew": f"Data skew factor: {characteristics.data_skew_factor:.2f}", + "memory_constraints": f"Memory per sample: {characteristics.memory_per_sample_mb:.3f} MB", + "worker_count": f"Optimal workers: {optimal_workers} (based on {local_resources.cpu_cores} available cores)", + }, + "modality_configs": { + modality.value: { + "default_size": config.default_partition_size, + "max_size": config.max_partition_size, + "max_size_mb": config.max_partition_size_mb, + "description": config.description, + } + for modality, config in self.MODALITY_CONFIGS.items() + }, + } + + return recommendations + + +def auto_configure_resources(cfg, dataset, process_pipeline: List) -> Dict: + """ + Analyze dataset and return resource configuration recommendations. + + Does NOT mutate cfg - caller should apply recommendations as needed. 
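A usage sketch, not part of the patch: since the helper leaves cfg untouched, the caller applies whichever recommendations it wants to honour, for example:

from data_juicer.core.executor.partition_size_optimizer import auto_configure_resources

# cfg and dataset are assumed to already exist in the caller's scope.
recs = auto_configure_resources(cfg, dataset, cfg.process)
cfg.np = recs["recommended_worker_count"]
cfg.partition.target_size_mb = recs["recommended_max_size_mb"]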
+ + Args: + cfg: Configuration object (read-only) + dataset: Dataset to analyze + process_pipeline: List of processing operations + + Returns: + Dict with recommended resource configuration + """ + logger.info("Starting resource optimization...") + optimizer = PartitionSizeOptimizer(cfg) + recommendations = optimizer.get_partition_recommendations(dataset, process_pipeline) + + logger.info("Resource optimization completed:") + logger.info(f" Recommended partition.size: {recommendations['recommended_partition_size']}") + logger.info(f" Recommended partition.max_size_mb: {recommendations['recommended_max_size_mb']}") + logger.info(f" Recommended worker count: {recommendations['recommended_worker_count']}") + + return recommendations diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py index 2313b18aef..9bc4112211 100644 --- a/data_juicer/core/executor/ray_executor.py +++ b/data_juicer/core/executor/ray_executor.py @@ -9,6 +9,8 @@ from data_juicer.core.data.dataset_builder import DatasetBuilder from data_juicer.core.executor import ExecutorBase +from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin +from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin from data_juicer.core.ray_exporter import RayExporter from data_juicer.ops import load_ops from data_juicer.ops.op_fusion import fuse_operators @@ -31,7 +33,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): shutil.rmtree(self.tmp_dir) -class RayExecutor(ExecutorBase): +class RayExecutor(ExecutorBase, DAGExecutionMixin, EventLoggingMixin): """ Executor based on Ray. @@ -50,10 +52,15 @@ def __init__(self, cfg: Optional[Namespace] = None): :param cfg: optional config dict. """ super().__init__(cfg) + self.executor_type = "ray" self.work_dir = self.cfg.work_dir - # TODO: support ray - # self.adapter = Adapter(self.cfg) + + # Initialize EventLoggingMixin for job management and event logging + EventLoggingMixin.__init__(self, cfg) + + # Initialize DAGExecutionMixin for AST/DAG functionality + DAGExecutionMixin.__init__(self) # init ray logger.info("Initializing Ray ...") @@ -120,15 +127,59 @@ def run(self, load_data_np: Optional[PositiveInt] = None, skip_export: bool = Fa logger.info("Preparing process operators...") ops = load_ops(self.cfg.process) + # Initialize DAG execution planning + self._initialize_dag_execution(self.cfg) + + # Log job start with DAG context + # Handle both dataset_path (string) and dataset (dict) configurations + dataset_info = {} + if hasattr(self.cfg, "dataset_path") and self.cfg.dataset_path: + dataset_info["dataset_path"] = self.cfg.dataset_path + if hasattr(self.cfg, "dataset") and self.cfg.dataset: + dataset_info["dataset"] = self.cfg.dataset + + job_config = { + **dataset_info, + "work_dir": self.work_dir, + "executor_type": self.executor_type, + "dag_node_count": len(self.pipeline_dag.nodes) if self.pipeline_dag else 0, + "dag_edge_count": len(self.pipeline_dag.edges) if self.pipeline_dag else 0, + "parallel_groups_count": len(self.pipeline_dag.parallel_groups) if self.pipeline_dag else 0, + } + self.log_job_start(job_config, len(ops)) + if self.cfg.op_fusion: logger.info(f"Start OP fusion and reordering with strategy " f"[{self.cfg.fusion_strategy}]...") ops = fuse_operators(ops) with TempDirManager(self.tmp_dir): - # 3. data process - logger.info("Processing data...") + # 3. 
data process with DAG monitoring + logger.info("Processing data with DAG monitoring...") tstart = time.time() - dataset.process(ops) + + # Get input row count before processing + input_rows = dataset.data.count() + start_time = time.time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops) + + # Execute operations (Ray executor uses simple dataset.process) + dataset = dataset.process(ops) + + # Force materialization to get real execution + logger.info("Materializing dataset to collect real metrics...") + dataset.data = dataset.data.materialize() + + # Get metrics after execution + duration = time.time() - start_time + output_rows = dataset.data.count() + + # Post-execute DAG monitoring (log operation completion events with real metrics) + if self.pipeline_dag: + metrics = {"duration": duration, "input_rows": input_rows, "output_rows": output_rows} + self._post_execute_operations_with_dag_monitoring(ops, metrics=metrics) # 4. data export if not skip_export: @@ -137,5 +188,9 @@ def run(self, load_data_np: Optional[PositiveInt] = None, skip_export: bool = Fa tend = time.time() logger.info(f"All Ops are done in {tend - tstart:.3f}s.") + # Log job completion with DAG context + job_duration = time.time() - tstart + self.log_job_complete(job_duration, self.cfg.export_path) + if not skip_return: return dataset diff --git a/data_juicer/core/executor/ray_executor_partitioned.py b/data_juicer/core/executor/ray_executor_partitioned.py new file mode 100644 index 0000000000..0f21e7c2ed --- /dev/null +++ b/data_juicer/core/executor/ray_executor_partitioned.py @@ -0,0 +1,748 @@ +""" +Simplified Partitioned Ray Executor for Large Dataset Processing + +This module implements a streamlined partitioned execution strategy for Ray mode that: +2. Splits the dataset into manageable partitions using Ray's .split() method +3. Processes each partition independently with Ray tasks +4. Merges results back into a single dataset for export +5. 
Supports convergence points for global operations (like deduplicators) +""" + +import os +import shutil +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional + +from jsonargparse import Namespace +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.core.data.ray_dataset import RayDataset +from data_juicer.core.executor import ExecutorBase +from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin +from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin, EventType +from data_juicer.core.ray_exporter import RayExporter +from data_juicer.ops import load_ops +from data_juicer.ops.op_fusion import fuse_operators +from data_juicer.utils.ckpt_utils import CheckpointStrategy, RayCheckpointManager +from data_juicer.utils.config_utils import ConfigAccessor +from data_juicer.utils.lazy_loader import LazyLoader + +ray = LazyLoader("ray") + + +class TempDirManager: + """Context manager for temporary directory cleanup.""" + + def __init__(self, tmp_dir): + self.tmp_dir = tmp_dir + + def __enter__(self): + os.makedirs(self.tmp_dir, exist_ok=True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if os.path.exists(self.tmp_dir): + logger.info(f"Removing tmp dir {self.tmp_dir} ...") + shutil.rmtree(self.tmp_dir) + + +# Note: Using Ray Data's built-in map_batches for parallel processing instead of custom remote functions + + +# Simplified classes for basic functionality +@dataclass +class PartitionResult: + """Simple result container for partition processing.""" + + partition_id: int + dataset: Optional[Any] = None + success: bool = False + error: Optional[str] = None + + +class PartitionedRayExecutor(ExecutorBase, DAGExecutionMixin, EventLoggingMixin): + """ + Simplified Ray executor with dataset partitioning using .split(). 
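For context, a sketch of how this executor is reached from user code; this is not part of the patch, the YAML path is hypothetical, and init_configs is assumed to be the usual config entry point:

from data_juicer.config import init_configs
from data_juicer.core.executor.factory import ExecutorFactory

cfg = init_configs(["--config", "configs/demo.yaml", "--executor_type", "ray_partitioned"])
executor_cls = ExecutorFactory.create_executor(cfg.executor_type)   # PartitionedRayExecutor
executor_cls(cfg).run()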
+ + Features: + - Single DatasetBuilder loads the full dataset + - Uses Ray's .split() method for partitioning + - Processes partitions in parallel with Ray tasks + - Supports convergence points for global operations + - Merges results back into a single dataset + """ + + def __init__(self, cfg: Optional[Namespace] = None): + """Initialize the partitioned Ray executor.""" + super().__init__(cfg) + + self.executor_type = "ray_partitioned" + self.work_dir = self.cfg.work_dir + self.job_id = self.cfg.get("job_id", None) + + # Initialize temporary directory for Ray operations + self.tmp_dir = os.path.join(self.work_dir, ".tmp", ray.get_runtime_context().get_job_id()) + + # Initialize EventLoggingMixin for job management and event logging + EventLoggingMixin.__init__(self, cfg) + + # Initialize DAGExecutionMixin for AST/DAG functionality + DAGExecutionMixin.__init__(self) + + # Override strategy methods for partitioned execution + self._override_strategy_methods() + + self.datasetbuilder = DatasetBuilder(self.cfg, executor_type="ray") + + # Partition configuration + self._configure_partitioning() + + # Checkpoint configuration and manager initialization + checkpoint_cfg = getattr(self.cfg, "checkpoint", None) + checkpoint_dir = getattr(self.cfg, "checkpoint_dir", os.path.join(self.work_dir, "checkpoints")) + + if checkpoint_cfg: + # Use ConfigAccessor to handle both dict and object configurations + checkpoint_enabled = ConfigAccessor.get(checkpoint_cfg, "enabled", True) + strategy_str = ConfigAccessor.get(checkpoint_cfg, "strategy", "every_op") + checkpoint_n_ops = ConfigAccessor.get(checkpoint_cfg, "n_ops", 1) + checkpoint_op_names = ConfigAccessor.get(checkpoint_cfg, "op_names", []) + + # Parse checkpoint strategy with validation + try: + checkpoint_strategy = CheckpointStrategy(strategy_str) + except ValueError: + logger.warning(f"Unknown checkpoint strategy: {strategy_str}, defaulting to EVERY_OP") + checkpoint_strategy = CheckpointStrategy.EVERY_OP + else: + checkpoint_enabled = False + checkpoint_strategy = CheckpointStrategy.DISABLED + checkpoint_n_ops = 1 + checkpoint_op_names = [] + + # Initialize Ray checkpoint manager + self.ckpt_manager = RayCheckpointManager( + ckpt_dir=checkpoint_dir, + checkpoint_enabled=checkpoint_enabled, + checkpoint_strategy=checkpoint_strategy, + checkpoint_n_ops=checkpoint_n_ops, + checkpoint_op_names=checkpoint_op_names, + event_logger=self, + ) + + logger.info(f"Checkpointing: {'enabled' if self.ckpt_manager.checkpoint_enabled else 'disabled'}") + if self.ckpt_manager.checkpoint_enabled: + logger.info(f"Checkpoint strategy: {self.ckpt_manager.checkpoint_strategy.value}") + logger.info(f"Checkpoint directory: {self.ckpt_manager.ckpt_dir}") + + # Initialize RayExporter for final output + logger.info("Preparing exporter...") + # Prepare export extra args, including S3 credentials if export_path is S3 + export_extra_args = dict(self.cfg.export_extra_args) if hasattr(self.cfg, "export_extra_args") else {} + + # If export_path is S3, extract AWS credentials with priority: + # 1. export_aws_credentials (export-specific) + # 2. dataset config (for backward compatibility) + # 3. environment variables (handled by exporter) + if self.cfg.export_path.startswith("s3://"): + # Pass export-specific credentials if provided. + # The RayExporter will handle falling back to environment variables or other credential mechanisms. 
+ if hasattr(self.cfg, "export_aws_credentials") and self.cfg.export_aws_credentials: + export_aws_creds = self.cfg.export_aws_credentials + if hasattr(export_aws_creds, "aws_access_key_id"): + export_extra_args["aws_access_key_id"] = export_aws_creds.aws_access_key_id + if hasattr(export_aws_creds, "aws_secret_access_key"): + export_extra_args["aws_secret_access_key"] = export_aws_creds.aws_secret_access_key + if hasattr(export_aws_creds, "aws_session_token"): + export_extra_args["aws_session_token"] = export_aws_creds.aws_session_token + if hasattr(export_aws_creds, "aws_region"): + export_extra_args["aws_region"] = export_aws_creds.aws_region + if hasattr(export_aws_creds, "endpoint_url"): + export_extra_args["endpoint_url"] = export_aws_creds.endpoint_url + + self.exporter = RayExporter( + self.cfg.export_path, + getattr(self.cfg, "export_type", None), + getattr(self.cfg, "export_shard_size", 0), + keep_stats_in_res_ds=getattr(self.cfg, "keep_stats_in_res_ds", True), + keep_hashes_in_res_ds=getattr(self.cfg, "keep_hashes_in_res_ds", False), + **export_extra_args, + ) + + def _configure_partitioning(self): + """Configure partitioning based on manual or auto mode.""" + # Get partition configuration + partition_cfg = getattr(self.cfg, "partition", {}) + + # Use ConfigAccessor to handle both dict and object configurations + mode = ConfigAccessor.get(partition_cfg, "mode", "auto") + num_of_partitions = ConfigAccessor.get(partition_cfg, "num_of_partitions", 4) + partition_size = ConfigAccessor.get(partition_cfg, "size", 5000) + max_size_mb = ConfigAccessor.get(partition_cfg, "max_size_mb", 64) + + # Fallback to legacy configuration if partition config is not available + # or if legacy num_partitions is explicitly set + if ( + not partition_cfg + or hasattr(self.cfg, "num_partitions") + and getattr(self.cfg, "num_partitions", None) is not None + ): + mode = "manual" + num_of_partitions = getattr(self.cfg, "num_partitions", 4) + if not partition_cfg: + logger.warning("No partition configuration found, using legacy num_partitions") + else: + logger.warning("Legacy num_partitions detected, overriding partition configuration") + + self.partition_mode = mode + self.num_partitions = num_of_partitions + self.partition_size = partition_size + self.max_size_mb = max_size_mb + + if mode == "manual": + logger.info(f"Manual partition mode: using {self.num_partitions} partitions") + else: # auto mode + logger.info(f"Auto partition mode: will determine optimal partitioning based on data characteristics") + logger.info(f"Fallback partition size: {self.partition_size} samples, max {self.max_size_mb} MB") + + def _configure_auto_partitioning(self, dataset, ops): + """Configure partitioning using the partition size optimizer for auto mode.""" + try: + from data_juicer.core.executor.partition_size_optimizer import ( + auto_configure_resources, + ) + + logger.info("🔧 Auto-configuring partition settings based on data characteristics...") + + # Use the partition size optimizer to determine optimal settings + recommendations = auto_configure_resources(self.cfg, dataset, ops) + + # Update partition configuration based on recommendations + recommended_size = ConfigAccessor.get(recommendations, "recommended_partition_size", self.partition_size) + recommended_max_size_mb = ConfigAccessor.get(recommendations, "recommended_max_size_mb", self.max_size_mb) + recommended_workers = ConfigAccessor.get( + recommendations, "recommended_worker_count", getattr(self.cfg, "np", 4) + ) + + # Calculate optimal number of partitions 
based on dataset size and recommended partition size + try: + if hasattr(dataset, "count"): + total_samples = dataset.count() + elif hasattr(dataset, "__len__"): + total_samples = len(dataset) + else: + total_samples = 10000 # Fallback estimate + + # Calculate number of partitions needed + self.num_partitions = max(1, int(total_samples / recommended_size)) + + # Ensure we don't create too many partitions (max 32 for efficiency) + self.num_partitions = min(self.num_partitions, 32) + + logger.info(f"📊 Dataset analysis complete:") + logger.info(f" Total samples: {total_samples}") + logger.info(f" Recommended partition size: {recommended_size} samples") + logger.info(f" Calculated partitions: {self.num_partitions}") + logger.info(f" Recommended max size: {recommended_max_size_mb} MB") + logger.info(f" Recommended workers: {recommended_workers}") + + # Update worker count if not already set + if not hasattr(self.cfg, "np") or self.cfg.np is None: + self.cfg.np = recommended_workers + logger.info(f" Updated worker count to: {recommended_workers}") + + except Exception as e: + logger.warning(f"Could not determine dataset size for partition calculation: {e}") + logger.info(f"Using fallback partition count: {self.num_partitions}") + + except ImportError as e: + logger.warning(f"Could not import partition size optimizer: {e}") + logger.info("Falling back to manual partition configuration") + except Exception as e: + logger.warning(f"Auto partition configuration failed: {e}") + logger.info("Falling back to manual partition configuration") + + def run(self, load_data_np: Optional[PositiveInt] = None, skip_return=False): + """ + Run the simplified partitioned dataset processing pipeline. + + Args: + load_data_np: Number of workers for loading dataset + skip_return: Whether to skip returning the dataset + job_id: Optional job ID to resume from checkpoints + + Returns: + Processed dataset + """ + # Use TempDirManager to ensure cleanup of temporary files + with TempDirManager(self.tmp_dir): + return self._run_impl(load_data_np, skip_return) + + def _run_impl(self, load_data_np: Optional[PositiveInt] = None, skip_return=False): + """ + Internal implementation of the run method. 
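For reference, the nested settings consumed by _configure_partitioning() and the checkpoint block can be written as a plain mapping. This is an illustrative shape only, not part of the patch; ConfigAccessor accepts both dict-style and attribute-style configs.

partitioned_cfg = {
    "executor_type": "ray_partitioned",
    "partition": {
        "mode": "auto",           # "manual" uses num_of_partitions directly
        "num_of_partitions": 4,
        "size": 5000,             # fallback samples per partition
        "max_size_mb": 64,        # fallback per-partition size cap
    },
    "checkpoint": {
        "enabled": True,
        "strategy": "every_op",   # parsed into CheckpointStrategy; unknown values fall back to EVERY_OP
        "n_ops": 1,
        "op_names": [],
    },
}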
+ """ + job_start_time = time.time() + + # Check if user provided a job_id (indicating resumption attempt) + user_provided_job_id = getattr(self.cfg, "_user_provided_job_id", False) + + if user_provided_job_id and self.job_id: + logger.info(f"🔄 User provided job_id: {self.job_id} - attempting to resume job") + resume_result = self._resume_job(self.job_id) + if resume_result == "completed": + logger.info("✅ Job is already completed - nothing to do") + return None # Exit gracefully + elif resume_result == "resuming": + logger.info("✅ Job resumption successful - will use existing checkpoints") + is_resuming = True + else: # resume_result == "failed" + logger.info("❌ Job resumption failed - starting fresh") + is_resuming = False + else: + if self.job_id: + logger.info(f"🚀 Starting new job with auto-generated job_id: {self.job_id}") + else: + logger.info("🚀 Starting new job") + is_resuming = False + + if not is_resuming: + logger.info("🚀 Starting simplified partitioned processing...") + else: + logger.info("🔄 Resuming partitioned processing from checkpoints...") + + # Log job start event + self._log_event( + event_type=EventType.JOB_START, + message=( + "Starting partitioned dataset processing" + if not is_resuming + else "Resuming partitioned dataset processing" + ), + metadata={ + "num_partitions": self.num_partitions, + "checkpoint_enabled": self.ckpt_manager.checkpoint_enabled, + "is_resuming": is_resuming, + "job_id": self.job_id, + "user_provided_job_id": user_provided_job_id, + }, + ) + + # Note: Config validation is handled in _resume_job() if resuming + + # Load the full dataset using a single DatasetBuilder + logger.info("Loading dataset with single DatasetBuilder...") + + dataset = self.datasetbuilder.load_dataset(num_proc=load_data_np) + columns = dataset.schema().columns + + # Prepare operations + logger.info("Preparing operations...") + ops = self._prepare_operators() + + # Handle auto partition mode BEFORE initializing DAG + # (DAG needs final partition count) + if self.partition_mode == "auto": + self._configure_auto_partitioning(dataset, ops) + + # Initialize DAG execution planning with final partition count + self._initialize_dag_execution(self.cfg) + + # Log job start with DAG context + # Handle both dataset_path (string) and dataset (dict) configurations + dataset_info = {} + if hasattr(self.cfg, "dataset_path") and self.cfg.dataset_path: + dataset_info["dataset_path"] = self.cfg.dataset_path + if hasattr(self.cfg, "dataset") and self.cfg.dataset: + dataset_info["dataset"] = self.cfg.dataset + + job_config = { + **dataset_info, + "work_dir": self.work_dir, + "executor_type": self.executor_type, + "dag_node_count": len(self.pipeline_dag.nodes) if self.pipeline_dag else 0, + "dag_edge_count": len(self.pipeline_dag.edges) if self.pipeline_dag else 0, + "parallel_groups_count": len(self.pipeline_dag.parallel_groups) if self.pipeline_dag else 0, + } + self.log_job_start(job_config, len(ops)) + + # Detect convergence points for global operations + convergence_points = self._detect_convergence_points(self.cfg) + + if convergence_points: + logger.info(f"Found convergence points at operations: {convergence_points}") + final_dataset = self._process_with_convergence(dataset, ops, convergence_points) + else: + logger.info("No convergence points found, processing with simple partitioning") + final_dataset = self._process_with_simple_partitioning(dataset, ops) + + # Export final dataset + logger.info("Exporting final dataset...") + self.exporter.export(final_dataset.data, columns=columns) 
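Resumption in practice, as a sketch (not part of the patch; the config path and job id are hypothetical): re-running with the same job_id makes _resume_job() consult the recorded events and checkpoints.

from data_juicer.config import init_configs
from data_juicer.core.executor.ray_executor_partitioned import PartitionedRayExecutor

cfg = init_configs(["--config", "configs/demo.yaml",
                    "--executor_type", "ray_partitioned",
                    "--job_id", "my_job_001"])          # same job_id as the earlier run
PartitionedRayExecutor(cfg).run()
# _resume_job("my_job_001") then returns "completed" (nothing left to do),
# "resuming" (reuse existing checkpoints), or "failed" (start fresh).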
+ + job_duration = time.time() - job_start_time + logger.info(f"✅ Job completed successfully in {job_duration:.2f}s") + logger.info(f"📁 Output saved to: {self.cfg.export_path}") + + # Log job completion with DAG context + self.log_job_complete(job_duration, self.cfg.export_path) + + if skip_return: + return None + + return final_dataset + + def cleanup_temp_files(self): + """Manually clean up temporary files from previous runs.""" + tmp_base_dir = os.path.join(self.work_dir, ".tmp") + if os.path.exists(tmp_base_dir): + logger.info(f"Cleaning up temporary files in {tmp_base_dir}") + shutil.rmtree(tmp_base_dir) + logger.info("✅ Temporary files cleaned up successfully") + else: + logger.info("No temporary files found to clean up") + + def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): + """ + Process dataset with real partitioning using Ray Data's split and union. + """ + logger.info("Processing with real partitioning using Ray Data's split and union...") + + # Split the dataset into partitions + logger.info(f"Splitting dataset into {self.num_partitions} partitions...") + partitions = dataset.data.split(self.num_partitions) + logger.info(f"Created {len(partitions)} partitions") + + # Process each partition separately with checkpointing + logger.info("Processing partitions with checkpointing support...") + processed_partitions = [] + + for i, partition in enumerate(partitions): + logger.info(f"Processing partition {i+1}/{len(partitions)}") + + # Log partition start event + self._log_event( + event_type=EventType.PARTITION_START, + message=f"Starting processing of partition {i+1}/{len(partitions)}", + partition_id=i, + ) + + # Create a RayDataset wrapper for this partition + partition_dataset = RayDataset(partition, cfg=self.cfg) + + # Apply operations with checkpointing support and DAG monitoring + processed_partition = self._process_with_checkpointing(partition_dataset, i, ops) + + # Store the processed partition's data + processed_partitions.append(processed_partition.data) + + # Log partition completion event + self._log_event( + event_type=EventType.PARTITION_COMPLETE, + message=f"Completed processing of partition {i+1}/{len(partitions)}", + partition_id=i, + ) + + # Merge all processed partitions back into a single dataset + logger.info("Merging processed partitions...") + if len(processed_partitions) == 1: + merged_dataset = processed_partitions[0] + else: + # Union all partitions + merged_dataset = processed_partitions[0] + for partition in processed_partitions[1:]: + merged_dataset = merged_dataset.union(partition) + + # Return as RayDataset wrapper + return RayDataset(merged_dataset, cfg=self.cfg) + + def _process_with_convergence(self, dataset: RayDataset, ops: List, convergence_points: List[int]): + """ + Process dataset with convergence support for global operations. 
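The split/process/union pattern used above, reduced to a minimal Ray Data sketch; this is not part of the patch, requires a Ray installation, and mirrors the pairwise union loop:

import ray

ds = ray.data.range(10_000)          # stand-in for the loaded dataset
parts = ds.split(4)                  # list of 4 smaller datasets

merged = parts[0]
for part in parts[1:]:
    merged = merged.union(part)      # stitch the partitions back together

assert merged.count() == 10_000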
+ """ + logger.info("Processing with convergence support for global operations...") + + # Find the first convergence point + first_convergence = min(convergence_points) + logger.info(f"First convergence point at operation {first_convergence}") + + # Split operations into pre-convergence and post-convergence + pre_convergence_ops = ops[:first_convergence] + post_convergence_ops = ops[first_convergence:] + + logger.info(f"Pre-convergence operations: {len(pre_convergence_ops)}") + logger.info(f"Post-convergence operations: {len(post_convergence_ops)}") + + # Process partitions up to convergence point + if pre_convergence_ops: + logger.info("Processing partitions up to convergence point...") + processed_dataset = self._process_with_simple_partitioning(dataset, pre_convergence_ops) + else: + logger.info("No pre-convergence operations, using original dataset...") + processed_dataset = dataset + + # Merge partitions for global operations + logger.info("Merging partitions for global operations...") + merged_dataset = processed_dataset.data + + # Process merged dataset with post-convergence operations + if post_convergence_ops: + logger.info("Processing merged dataset with global operations...") + merged_ray_dataset = RayDataset(merged_dataset, cfg=self.cfg) + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(post_convergence_ops, partition_id=0) + + # Execute operations + final_dataset = merged_ray_dataset.process(post_convergence_ops) + + # Post-execute DAG monitoring (log operation completion events) + if self.pipeline_dag: + self._post_execute_operations_with_dag_monitoring(post_convergence_ops, partition_id=0) + + logger.info("Global operations completed. Final dataset ready for export") + return final_dataset + else: + # No post-convergence operations, just return the merged result + return RayDataset(merged_dataset, cfg=self.cfg) + + def _process_with_checkpointing(self, dataset: RayDataset, partition_id: int, ops: List) -> RayDataset: + """ + Process dataset with checkpointing support. + Groups operations and checkpoints between groups based on strategy. 
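To show the shape this method consumes: group_operations_for_checkpointing() (defined in ckpt_utils, not shown in this diff) yields (start_idx, end_idx, group_ops) tuples. Under an every_n_ops-style strategy with n_ops=2, a five-op pipeline might be grouped roughly as below; this is an assumption for illustration and the real grouping may differ.

ops = ["op_a", "op_b", "op_c", "op_d", "op_e"]   # hypothetical op names
op_groups = [
    (0, 2, ops[0:2]),   # checkpoint after op index 1
    (2, 4, ops[2:4]),   # checkpoint after op index 3
    (4, 5, ops[4:5]),   # checkpoint after op index 4
]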
+ """ + logger.info(f"Processing partition {partition_id} with checkpointing support...") + + if not self.ckpt_manager.checkpoint_enabled: + logger.info(f"Checkpointing disabled, processing all operations at once for partition {partition_id}") + + # Get input row count before processing + input_rows = dataset.data.count() + start_time = time.time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops, partition_id=partition_id) + + # Execute operations (lazy) + processed_dataset = dataset.process(ops) + + # Force materialization to get real execution (required for union anyway) + processed_dataset.data = processed_dataset.data.materialize() + + # Get metrics after execution + duration = time.time() - start_time + output_rows = processed_dataset.data.count() + + logger.info(f"Partition {partition_id}: Processed {input_rows}→{output_rows} rows in {duration:.2f}s") + + # Post-execute DAG monitoring with real metrics + if self.pipeline_dag: + metrics = {"duration": duration, "input_rows": input_rows, "output_rows": output_rows} + self._post_execute_operations_with_dag_monitoring(ops, partition_id=partition_id, metrics=metrics) + + return processed_dataset + + # check the latest checkpoint for the partition + latest_checkpoint = self.ckpt_manager.find_latest_checkpoint(partition_id) + + # Group operations based on checkpoint strategy + op_groups = self.ckpt_manager.group_operations_for_checkpointing(ops) + logger.info(f"Grouped {len(ops)} operations into {len(op_groups)} groups for checkpointing") + logger.info(f"Detailed op gruops: {op_groups}") + + current_dataset = dataset + + for group_idx, (start_idx, end_idx, group_ops) in enumerate(op_groups): + logger.info( + f"Processing partition {partition_id}, group {group_idx + 1}/{len(op_groups)}: operations {start_idx}-{end_idx-1}" + ) + + if latest_checkpoint and latest_checkpoint[0] >= end_idx: + logger.info( + f"Partition {partition_id}: All operations in group {group_idx + 1} already processed (checkpoint at op {latest_checkpoint[0]}, group ends at {end_idx-1}), skipping" + ) + continue + + if latest_checkpoint and latest_checkpoint[0] >= start_idx: + logger.info(f"Partition {partition_id}: Resuming from checkpoint at operation {latest_checkpoint[0]}") + current_dataset = self.ckpt_manager.load_checkpoint( + latest_checkpoint[0], latest_checkpoint[1], partition_id, cfg=self.cfg + ) + if current_dataset is None: + logger.warning(f"Partition {partition_id}: Failed to load checkpoint, starting from beginning") + current_dataset = dataset + group_ops = ops[start_idx:end_idx] # Start from beginning of group + logger.info( + f"Partition {partition_id}: Will process {len(group_ops)} operations from beginning of group" + ) + else: + logger.info( + f"Partition {partition_id}: Successfully loaded checkpoint, resuming from operation {latest_checkpoint[0] + 1}" + ) + group_ops = ops[latest_checkpoint[0] + 1 : end_idx] # Resume from checkpoint + if not group_ops: + logger.info( + f"Partition {partition_id}: All operations in this group already processed, skipping" + ) + continue + else: + logger.info( + f"Partition {partition_id}: Will process {len(group_ops)} remaining operations from checkpoint" + ) + + # Process the group of operations + if group_ops: + logger.info( + f"Partition {partition_id}: Processing {len(group_ops)} operations in group {group_idx + 1}" + ) + + # Get input row count before processing + input_rows = current_dataset.data.count() + start_time = 
time.time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(group_ops, partition_id=partition_id) + + # Execute operations (lazy) + current_dataset = current_dataset.process(group_ops) + + # Force materialization (required for checkpointing anyway) + current_dataset.data = current_dataset.data.materialize() + + # Get metrics after execution + duration = time.time() - start_time + output_rows = current_dataset.data.count() + + logger.info( + f"Partition {partition_id}, group {group_idx + 1}: Processed {input_rows}→{output_rows} rows in {duration:.2f}s" + ) + + # Post-execute DAG monitoring with real metrics + if self.pipeline_dag: + metrics = {"duration": duration, "input_rows": input_rows, "output_rows": output_rows} + self._post_execute_operations_with_dag_monitoring( + group_ops, partition_id=partition_id, metrics=metrics + ) + + # Checkpoint after the last operation in the group + if group_ops: + last_op_idx = end_idx - 1 + last_op_name = ops[last_op_idx]._name + if self.ckpt_manager.should_checkpoint(last_op_idx, last_op_name): + logger.info( + f"Partition {partition_id}: Creating checkpoint after operation {last_op_idx}: {last_op_name}" + ) + # Data already materialized above, safe to checkpoint + self.ckpt_manager.save_checkpoint( + current_dataset, last_op_idx, last_op_name, partition_id, cfg=self.cfg + ) + + return current_dataset + + def _find_work_directory(self, job_id: str) -> Optional[str]: + """Find the work directory based on job_id.""" + # Check if the current work_dir already contains the job_id + current_work_dir = Path(self.work_dir) + logger.info(f"Checking if current work_dir contains job_id: {current_work_dir}") + + if job_id in str(current_work_dir): + # Current work_dir already contains job_id, check if it's a valid work directory + logger.info(f"Current work_dir contains job_id '{job_id}', checking if it's a valid work directory") + + # Check if this directory has events files (indicating it's a work directory) + latest_events_file = self.event_logger.find_latest_events_file(str(current_work_dir)) + if latest_events_file: + logger.info(f"Found events file in current work_dir: {latest_events_file}") + return str(current_work_dir) + + logger.warning(f"No events file found in current work_dir: {current_work_dir}") + + logger.warning(f"No directory found containing job_id '{job_id}' with events files") + return None + + def _check_job_completion(self, work_dir: str, job_id: str) -> bool: + """Check if the job is already completed.""" + latest_events_file = self.event_logger.find_latest_events_file(work_dir) + if not latest_events_file: + logger.info(f"No events file found in work directory: {work_dir}") + return False + + is_completed = self.event_logger.check_job_completion(latest_events_file) + if is_completed: + logger.info(f"Job {job_id} is already completed - no need to resume") + else: + logger.info(f"Job {job_id} is not completed - resumption possible") + + return is_completed + + def _resume_job(self, job_id: str) -> str: + """Resume a job from checkpoints. 
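+
+        Callers are expected to branch on the returned string; a hypothetical flow
+        is: skip processing on "completed", reuse existing checkpoints on
+        "resuming", and fall back to a fresh run on "failed".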
+ + Returns: + "completed": Job is already completed + "resuming": Job can be resumed + "failed": Job resumption failed + """ + logger.info(f"Attempting to resume job: {job_id}") + + # Find work directory + work_dir = self._find_work_directory(job_id) + if not work_dir: + logger.error(f"Work directory not found for job_id: {job_id}") + return "failed" + + logger.info(f"Found work directory: {work_dir}") + + # Check if config validation passed (done during config initialization) + if not getattr(self.cfg, "_same_yaml_config", False): + logger.error("Config validation failed - configurations don't match") + return "failed" + + # Check if job is already completed + if self._check_job_completion(work_dir, job_id): + return "completed" # Job already completed + + # Update checkpoint directory to use the work directory's checkpoint directory + work_checkpoint_dir = os.path.join(work_dir, "checkpoints") + if os.path.exists(work_checkpoint_dir): + self.ckpt_manager.ckpt_dir = work_checkpoint_dir + logger.info(f"Using checkpoint directory from work directory: {self.ckpt_manager.ckpt_dir}") + else: + logger.warning(f"No checkpoint directory found in work directory: {work_checkpoint_dir}") + + return "resuming" + + def _prepare_operators(self): + """Prepare process operators.""" + ops = load_ops(self.cfg.process) + + # Check for op_fusion configuration with safe attribute access + if hasattr(self.cfg, "op_fusion") and self.cfg.op_fusion: + logger.info(f"Start OP fusion and reordering with strategy [{self.cfg.fusion_strategy}]...") + ops = fuse_operators(ops) + + return ops + + def _override_strategy_methods(self): + """Override strategy methods for partitioned execution.""" + # Override DAG-related methods for partitioned execution + # Note: Partition count is determined by the executor (self.num_partitions), + # not by the DAG mixin, so we don't override _determine_partition_count here + # Note: _detect_convergence_points is reused from DAGExecutionMixin (no override needed) + self._get_dag_node_for_operation = self._get_dag_node_for_operation_partitioned + + def _get_dag_node_for_operation_partitioned( + self, op_name: str, op_idx: int, partition_id: int = 0, **kwargs + ) -> Optional[str]: + """Get DAG node ID for partitioned operation.""" + if not self.dag_execution_strategy: + return None + + return self.dag_execution_strategy.get_dag_node_id(op_name, op_idx, partition_id=partition_id, **kwargs) diff --git a/data_juicer/core/pipeline_dag.py b/data_juicer/core/pipeline_dag.py new file mode 100644 index 0000000000..38c21ece9c --- /dev/null +++ b/data_juicer/core/pipeline_dag.py @@ -0,0 +1,453 @@ +""" +Pipeline DAG Representation for Data-Juicer Pipelines + +This module provides Pipeline DAG (Directed Acyclic Graph) representation and planning +capabilities for tracking execution state, dependencies, and monitoring. 
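+
+A rough usage sketch (hypothetical node data; execution strategies normally
+populate the nodes):
+
+    dag = PipelineDAG(work_dir="./outputs/job_123")
+    dag.nodes["op_0000"] = {"node_id": "op_0000", "operation_name": "clean_mapper",
+                            "node_type": "operation", "dependencies": []}
+    dag.mark_node_started("op_0000")
+    dag.mark_node_completed("op_0000", duration=1.2)
+    dag.save_execution_plan()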
+""" + +import json +import time +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +from loguru import logger + + +class DAGNodeStatus(Enum): + """Status of a DAG node during execution.""" + + PENDING = "pending" + READY = "ready" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + + +class DAGEdgeType(Enum): + """Types of edges in the DAG.""" + + SEQUENTIAL = "sequential" # Standard sequential dependency + PARALLEL = "parallel" # Can run in parallel + CONDITIONAL = "conditional" # Conditional dependency + + +@dataclass +class DAGNode: + """Node in the execution DAG. + + Note: This is kept for backward compatibility, but strategies typically use dict nodes. + """ + + node_id: str + op_name: str + node_type: str # Changed from op_type: OpType to node_type: str + config: Dict[str, Any] + status: DAGNodeStatus = DAGNodeStatus.PENDING + dependencies: Set[str] = field(default_factory=set) + dependents: Set[str] = field(default_factory=set) + execution_order: int = -1 + estimated_duration: float = 0.0 + actual_duration: float = 0.0 + start_time: Optional[float] = None + end_time: Optional[float] = None + error_message: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "node_id": self.node_id, + "op_name": self.op_name, + "node_type": self.node_type, + "config": self.config, + "status": self.status.value, + "dependencies": list(self.dependencies), + "dependents": list(self.dependents), + "execution_order": self.execution_order, + "estimated_duration": self.estimated_duration, + "actual_duration": self.actual_duration, + "start_time": self.start_time, + "end_time": self.end_time, + "error_message": self.error_message, + "metadata": self.metadata, + } + + +@dataclass +class DAGEdge: + """Edge in the execution DAG.""" + + source_id: str + target_id: str + edge_type: DAGEdgeType = DAGEdgeType.SEQUENTIAL + condition: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "source_id": self.source_id, + "target_id": self.target_id, + "edge_type": self.edge_type.value, + "condition": self.condition, + "metadata": self.metadata, + } + + +class PipelineDAG: + """Pipeline DAG representation and execution planner.""" + + def __init__(self, work_dir: str): + """Initialize the Pipeline DAG. + + Args: + work_dir: Working directory for storing DAG execution plans and logs + """ + self.work_dir = Path(work_dir) + # Remove the separate dag_execution subdirectory - save directly in work_dir + # self.dag_dir = self.work_dir / "dag_execution" + # self.dag_dir.mkdir(parents=True, exist_ok=True) + self.dag_dir = self.work_dir # Use work_dir directly + + # DAG structure - support both DAGNode objects and dict nodes from strategies + self.nodes: Dict[str, Any] = {} + self.edges: List[DAGEdge] = [] # Not currently populated by strategies + self.execution_plan: List[str] = [] # Not currently populated by strategies + self.parallel_groups: List[List[str]] = [] # Not currently populated by strategies + + def save_execution_plan(self, filename: str = "dag_execution_plan.json") -> str: + """Save the execution plan to file. 
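+
+        Only the static DAG structure is persisted; runtime fields such as node
+        status, start/end times, and actual durations are not saved.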
+
+        Args:
+            filename: Name of the file to save the plan
+
+        Returns:
+            Path to the saved file
+        """
+        # Save only static DAG structure, not execution state
+        static_nodes = {}
+        for node_id, node in self.nodes.items():
+            # Handle both DAGNode objects and dict nodes from strategies
+            if hasattr(node, "to_dict"):
+                # DAGNode object (node_type replaced the old op_type field)
+                static_node_data = {
+                    "node_id": node.node_id,
+                    "op_name": node.op_name,
+                    "op_type": node.node_type,
+                    "config": node.config,
+                    "dependencies": list(node.dependencies),
+                    "dependents": list(node.dependents),
+                    "execution_order": node.execution_order,
+                    "estimated_duration": node.estimated_duration,
+                    "metadata": node.metadata,
+                }
+            else:
+                # Dict node from strategy
+                static_node_data = {
+                    "node_id": node["node_id"],
+                    "op_name": node.get("operation_name", ""),
+                    "op_type": node.get("node_type", "operation"),
+                    "config": node.get("config", {}),
+                    "dependencies": node.get("dependencies", []),
+                    "dependents": node.get("dependents", []),
+                    "execution_order": node.get("execution_order", 0),
+                    "estimated_duration": node.get("estimated_duration", 0.0),
+                    "metadata": node.get("metadata", {}),
+                }
+            static_nodes[node_id] = static_node_data
+
+        plan_data = {
+            "nodes": static_nodes,
+            "edges": [edge.to_dict() for edge in self.edges],
+            "execution_plan": self.execution_plan,
+            "parallel_groups": self.parallel_groups,
+            "metadata": {
+                "created_at": time.time(),
+                "total_nodes": len(self.nodes),
+                "total_edges": len(self.edges),
+                "parallel_groups_count": len(self.parallel_groups),
+            },
+        }
+
+        plan_path = self.dag_dir / filename
+        with open(plan_path, "w") as f:
+            json.dump(plan_data, f, indent=2, default=str)
+
+        logger.info(f"Execution plan saved to: {plan_path}")
+        return str(plan_path)
+
+    def load_execution_plan(self, filename: str = "dag_execution_plan.json") -> bool:
+        """Load execution plan from file.
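+
+        Nodes are reconstructed as plain dicts (matching the strategy format) with
+        status reset to "pending" and runtime metrics cleared.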
+ + Args: + filename: Name of the file to load the plan from + + Returns: + True if loaded successfully, False otherwise + """ + plan_path = self.dag_dir / filename + if not plan_path.exists(): + logger.warning(f"Execution plan file not found: {plan_path}") + return False + + try: + with open(plan_path, "r") as f: + plan_data = json.load(f) + + # Reconstruct nodes (static structure only) + self.nodes.clear() + for node_id, node_data in plan_data["nodes"].items(): + # Keep as dict to match strategy format + self.nodes[node_id] = { + "node_id": node_data["node_id"], + "operation_name": node_data.get("op_name", node_data.get("operation_name", "")), + "node_type": node_data.get("node_type", node_data.get("op_type", "operation")), + "config": node_data.get("config", {}), + "status": "pending", # Always start with pending status + "dependencies": node_data.get("dependencies", []), + "dependents": node_data.get("dependents", []), + "execution_order": node_data.get("execution_order", 0), + "estimated_duration": node_data.get("estimated_duration", 0.0), + "actual_duration": 0.0, # Reset execution state + "start_time": None, # Reset execution state + "end_time": None, # Reset execution state + "error_message": None, # Reset execution state + "metadata": node_data.get("metadata", {}), + } + + # Reconstruct edges + self.edges.clear() + for edge_data in plan_data["edges"]: + edge = DAGEdge( + source_id=edge_data["source_id"], + target_id=edge_data["target_id"], + edge_type=DAGEdgeType(edge_data["edge_type"]), + condition=edge_data["condition"], + metadata=edge_data["metadata"], + ) + self.edges.append(edge) + + # Load execution plan and parallel groups + self.execution_plan = plan_data["execution_plan"] + self.parallel_groups = plan_data["parallel_groups"] + + logger.info(f"Execution plan loaded from: {plan_path}") + return True + + except Exception as e: + logger.error(f"Failed to load execution plan: {e}") + return False + + def visualize(self) -> str: + """Generate a string representation of the DAG for visualization.""" + if not self.nodes: + return "Empty DAG" + + lines = ["DAG Execution Plan:"] + lines.append("=" * 50) + + # Show execution order (if available, otherwise show all nodes) + if self.execution_plan: + lines.append("Execution Order:") + for i, node_id in enumerate(self.execution_plan): + node = self.nodes[node_id] + status = DAGNodeStatus(node.get("status", "pending")) if isinstance(node, dict) else node.status + op_name = node.get("operation_name", "unknown") if isinstance(node, dict) else node.op_name + op_type = ( + node.get("node_type", "operation") + if isinstance(node, dict) + else getattr(node, "node_type", "operation") + ) + + status_icon = { + DAGNodeStatus.PENDING: "⏳", + DAGNodeStatus.READY: "✅", + DAGNodeStatus.RUNNING: "🔄", + DAGNodeStatus.COMPLETED: "✅", + DAGNodeStatus.FAILED: "❌", + DAGNodeStatus.SKIPPED: "⏭️", + }.get(status, "❓") + + lines.append(f" {i+1:2d}. 
{status_icon} {op_name} ({op_type})") + else: + # No execution plan, show all nodes + lines.append("Nodes:") + for i, (node_id, node) in enumerate(self.nodes.items()): + status = DAGNodeStatus(node.get("status", "pending")) if isinstance(node, dict) else node.status + op_name = node.get("operation_name", "unknown") if isinstance(node, dict) else node.op_name + op_type = ( + node.get("node_type", "operation") + if isinstance(node, dict) + else getattr(node, "node_type", "operation") + ) + + status_icon = { + DAGNodeStatus.PENDING: "⏳", + DAGNodeStatus.READY: "✅", + DAGNodeStatus.RUNNING: "🔄", + DAGNodeStatus.COMPLETED: "✅", + DAGNodeStatus.FAILED: "❌", + DAGNodeStatus.SKIPPED: "⏭️", + }.get(status, "❓") + + lines.append(f" {i+1:2d}. {status_icon} {op_name} ({op_type})") + + # Show parallel groups + if self.parallel_groups: + lines.append("\nParallel Groups:") + for i, group in enumerate(self.parallel_groups): + group_names = [] + for node_id in group: + node = self.nodes[node_id] + if hasattr(node, "op_name"): + group_names.append(node.op_name) + else: + group_names.append(node.get("operation_name", "unknown")) + lines.append(f" Group {i+1}: {', '.join(group_names)}") + + # Show dependencies + lines.append("\nDependencies:") + for node_id, node in self.nodes.items(): + dependencies = node.get("dependencies", []) if isinstance(node, dict) else getattr(node, "dependencies", []) + op_name = ( + node.get("operation_name", "unknown") if isinstance(node, dict) else getattr(node, "op_name", "unknown") + ) + + if dependencies: + dep_names = [] + for dep_id in dependencies: + dep_node = self.nodes.get(dep_id, {}) + dep_name = ( + dep_node.get("operation_name", "unknown") + if isinstance(dep_node, dict) + else getattr(dep_node, "op_name", "unknown") + ) + dep_names.append(dep_name) + lines.append(f" {op_name} depends on: {', '.join(dep_names)}") + + return "\n".join(lines) + + def get_ready_nodes(self) -> List[str]: + """Get list of nodes that are ready to execute (all dependencies completed).""" + ready_nodes = [] + for node_id, node in self.nodes.items(): + # Handle both DAGNode objects and dict nodes + if hasattr(node, "status"): + status = node.status + dependencies = node.dependencies + else: + status = DAGNodeStatus(node.get("status", "pending")) + dependencies = node.get("dependencies", []) + + if status == DAGNodeStatus.PENDING: + # Check if all dependencies are completed + all_deps_completed = all( + self._get_node_status(dep_id) == DAGNodeStatus.COMPLETED for dep_id in dependencies + ) + if all_deps_completed: + ready_nodes.append(node_id) + return ready_nodes + + def _get_node_status(self, node: Any) -> DAGNodeStatus: + """Get status of a node, handling both DAGNode objects and dict nodes. + + Args: + node: Can be a node_id (str), DAGNode object, or dict representation of a node. + + Returns: + DAGNodeStatus of the node. 
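+
+        Example (illustrative):
+            self._get_node_status("op_0001")               # look up by node_id
+            self._get_node_status(self.nodes["op_0001"])   # pass the node directly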
+ """ + if isinstance(node, str): + # Argument `node` is node_id + node = self.nodes[node] + if hasattr(node, "status"): + return node.status + elif isinstance(node, dict): + return DAGNodeStatus(node.get("status", "pending")) + else: + return DAGNodeStatus.PENDING + + def mark_node_started(self, node_id: str) -> None: + """Mark a node as started.""" + if node_id in self.nodes: + node = self.nodes[node_id] + current_time = time.time() + if hasattr(node, "status"): + node.status = DAGNodeStatus.RUNNING + node.start_time = current_time + elif isinstance(node, dict): + node["status"] = DAGNodeStatus.RUNNING.value + node["start_time"] = current_time + + def mark_node_completed(self, node_id: str, duration: float = None) -> None: + """Mark a node as completed.""" + if node_id in self.nodes: + node = self.nodes[node_id] + current_time = time.time() + if hasattr(node, "status"): + node.status = DAGNodeStatus.COMPLETED + node.end_time = current_time + if duration is not None: + node.actual_duration = duration + else: + node.actual_duration = current_time - (node.start_time or current_time) + elif isinstance(node, dict): + node["status"] = DAGNodeStatus.COMPLETED.value + node["end_time"] = current_time + if duration is not None: + node["actual_duration"] = duration + else: + node["actual_duration"] = current_time - (node.get("start_time", current_time)) + + def mark_node_failed(self, node_id: str, error_message: str) -> None: + """Mark a node as failed.""" + if node_id in self.nodes: + node = self.nodes[node_id] + current_time = time.time() + if hasattr(node, "status"): + node.status = DAGNodeStatus.FAILED + node.end_time = current_time + node.error_message = error_message + node.actual_duration = current_time - (node.start_time or current_time) + elif isinstance(node, dict): + node["status"] = DAGNodeStatus.FAILED.value + node["end_time"] = current_time + node["error_message"] = error_message + node["actual_duration"] = current_time - (node.get("start_time", current_time)) + + def get_execution_summary(self) -> Dict[str, Any]: + """Get execution summary statistics.""" + total_nodes = len(self.nodes) + + def get_node_duration(node): + if hasattr(node, "actual_duration"): + duration = node.actual_duration + return duration if duration is not None else 0 + elif isinstance(node, dict): + duration = node.get("actual_duration") + return duration if duration is not None else 0 + else: + return 0 + + completed_nodes = sum( + 1 for node in self.nodes.values() if self._get_node_status(node) == DAGNodeStatus.COMPLETED + ) + failed_nodes = sum(1 for node in self.nodes.values() if self._get_node_status(node) == DAGNodeStatus.FAILED) + running_nodes = sum(1 for node in self.nodes.values() if self._get_node_status(node) == DAGNodeStatus.RUNNING) + pending_nodes = sum(1 for node in self.nodes.values() if self._get_node_status(node) == DAGNodeStatus.PENDING) + + total_duration = sum(get_node_duration(node) for node in self.nodes.values()) + + return { + "total_nodes": total_nodes, + "completed_nodes": completed_nodes, + "failed_nodes": failed_nodes, + "running_nodes": running_nodes, + "pending_nodes": pending_nodes, + "completion_percentage": (completed_nodes / total_nodes * 100) if total_nodes > 0 else 0, + "total_duration": total_duration, + "parallel_groups_count": len(self.parallel_groups), + } diff --git a/data_juicer/core/ray_exporter.py b/data_juicer/core/ray_exporter.py index f0a231b0e8..a77b432183 100644 --- a/data_juicer/core/ray_exporter.py +++ b/data_juicer/core/ray_exporter.py @@ -131,7 +131,23 @@ def 
_export_impl(self, dataset, export_path, columns=None): :param columns: the columns to export. :return: """ - feature_fields = dataset.columns() if not columns else columns + # Handle empty dataset case - Ray returns None for columns() on empty datasets + # Check if dataset is empty by calling columns() regardless of columns parameter + cols = dataset.columns() + if cols is None: + # Empty dataset with unknown schema - create an empty file + from loguru import logger + + logger.warning(f"Dataset is empty, creating empty export file at {export_path}") + import os + + os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True) + with open(export_path, "w"): + pass # Create empty file + return + + # Use provided columns or infer from dataset + feature_fields = columns if columns else cols removed_fields = [] if not self.keep_stats_in_res_ds: extra_fields = {Fields.stats, Fields.meta} diff --git a/data_juicer/ops/mapper/image_sam_3d_body_mapper.py b/data_juicer/ops/mapper/image_sam_3d_body_mapper.py index a3b3ebcb24..aeda1e5091 100644 --- a/data_juicer/ops/mapper/image_sam_3d_body_mapper.py +++ b/data_juicer/ops/mapper/image_sam_3d_body_mapper.py @@ -215,7 +215,18 @@ def process_single(self, sample=None, rank=None): os.makedirs(self.visualization_dir, exist_ok=True) vis_path = os.path.join(self.visualization_dir, os.path.splitext(img_name)[0] + "_vis.jpg") img = cv2.imread(image_path) - rend_img = vis_utils.visualize_sample_together(img, output, estimator.faces) + try: + rend_img = vis_utils.visualize_sample_together(img, output, estimator.faces) + except (ImportError, OSError) as e: + if "EGL" in str(e): + raise RuntimeError( + "Visualization requires EGL for offscreen rendering, but EGL " + "library was not found. To fix this:\n" + " - On Ubuntu/Debian: apt-get install libegl1-mesa libegl1-mesa-dev\n" + " - On headless servers: also install libgl1-mesa-dri\n" + " - Or disable visualization by not setting visualization_dir" + ) from e + raise cv2.imwrite( vis_path, rend_img.astype(np.uint8), diff --git a/data_juicer/ops/mapper/s3_download_file_mapper.py b/data_juicer/ops/mapper/s3_download_file_mapper.py index 994af5a165..8c62957698 100644 --- a/data_juicer/ops/mapper/s3_download_file_mapper.py +++ b/data_juicer/ops/mapper/s3_download_file_mapper.py @@ -4,13 +4,15 @@ import os.path as osp from typing import List, Union -import boto3 -from botocore.exceptions import ClientError from loguru import logger from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.s3_utils import get_aws_credentials +boto3 = LazyLoader("boto3") +botocore_exceptions = LazyLoader("botocore.exceptions") + OP_NAME = "s3_download_file_mapper" @@ -193,7 +195,7 @@ def _download_from_s3(self, s3_url: str, save_path: str = None, return_content: else: return "success", None, None, None - except ClientError as e: + except botocore_exceptions.ClientError as e: error_msg = f"S3 download failed: {e}" logger.error(error_msg) return "failed", error_msg, None, None diff --git a/data_juicer/ops/mapper/s3_upload_file_mapper.py b/data_juicer/ops/mapper/s3_upload_file_mapper.py index 7942b6e78d..a887416352 100644 --- a/data_juicer/ops/mapper/s3_upload_file_mapper.py +++ b/data_juicer/ops/mapper/s3_upload_file_mapper.py @@ -2,13 +2,15 @@ import os from typing import List, Union -import boto3 -from botocore.exceptions import ClientError from loguru import logger from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.lazy_loader import 
LazyLoader from data_juicer.utils.s3_utils import get_aws_credentials +boto3 = LazyLoader("boto3") +botocore_exceptions = LazyLoader("botocore.exceptions") + OP_NAME = "s3_upload_file_mapper" @@ -137,7 +139,7 @@ def _check_s3_exists(self, s3_key: str) -> bool: try: self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) return True - except ClientError: + except botocore_exceptions.ClientError: return False def _upload_to_s3(self, local_path: str) -> tuple: @@ -191,7 +193,7 @@ def _upload_to_s3(self, local_path: str) -> tuple: return "success", s3_url, None - except ClientError as e: + except botocore_exceptions.ClientError as e: error_msg = f"S3 upload failed: {e}" logger.error(error_msg) return "failed", local_path, error_msg diff --git a/data_juicer/utils/ckpt_utils.py b/data_juicer/utils/ckpt_utils.py index f779a58eec..9ae609d512 100644 --- a/data_juicer/utils/ckpt_utils.py +++ b/data_juicer/utils/ckpt_utils.py @@ -1,10 +1,62 @@ import json import os +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, List, Optional, Tuple from loguru import logger -class CheckpointManager: +class CheckpointManagerBase(ABC): + """ + Base class for checkpoint managers. + + Provides common functionality for managing checkpoint directories and + defines the interface that checkpoint managers should implement. + """ + + def __init__(self, ckpt_dir: str): + """ + Initialize base checkpoint manager. + + :param ckpt_dir: Directory to save and load checkpoints + """ + self.ckpt_dir = ckpt_dir + # Ensure checkpoint directory exists + os.makedirs(self.ckpt_dir, exist_ok=True) + + @abstractmethod + def save_checkpoint(self, dataset: Any, **kwargs) -> str: + """ + Save a dataset checkpoint. + + :param dataset: Dataset to save + :param kwargs: Additional arguments specific to the implementation + :return: Path to saved checkpoint + """ + pass + + @abstractmethod + def load_checkpoint(self, **kwargs) -> Optional[Any]: + """ + Load a dataset checkpoint. + + :param kwargs: Arguments specific to the implementation (e.g., op_idx, partition_id) + :return: Loaded dataset or None if checkpoint doesn't exist + """ + pass + + def checkpoint_exists(self, checkpoint_path: str) -> bool: + """ + Check if a checkpoint file/directory exists. + + :param checkpoint_path: Path to checkpoint + :return: True if checkpoint exists, False otherwise + """ + return os.path.exists(checkpoint_path) + + +class CheckpointManager(CheckpointManagerBase): """ This class is used to save the latest version of dataset to checkpoint directory or load it from checkpoint directory, a bit like cache management @@ -22,7 +74,7 @@ def __init__(self, ckpt_dir, original_process_list, num_proc=1): :param original_process_list: process list in config :param num_proc: number of process workers when saving dataset """ - self.ckpt_dir = ckpt_dir + super().__init__(ckpt_dir) self.ckpt_ds_dir = os.path.join(self.ckpt_dir, "latest") self.ckpt_op_record = os.path.join(self.ckpt_dir, "ckpt_op.json") self.process_list = original_process_list @@ -123,8 +175,19 @@ def check_ops_to_skip(self): def save_ckpt(self, ds): """ Save dataset to checkpoint directory and dump processed ops list. + Alias for save_checkpoint for backward compatibility. + + :param ds: input dataset to save + """ + return self.save_checkpoint(ds) + + def save_checkpoint(self, ds, **kwargs): + """ + Save dataset to checkpoint directory and dump processed ops list. 
:param ds: input dataset to save + :param kwargs: Additional arguments (not used, kept for interface compatibility) + :return: Path to checkpoint directory """ left_sample_num = len(ds) ds.save_to_disk(self.ckpt_ds_dir, num_proc=min(self.num_proc, left_sample_num)) @@ -132,13 +195,251 @@ def save_ckpt(self, ds): with open(self.ckpt_op_record, "w") as fout: json.dump(self.op_record, fout) + return self.ckpt_ds_dir + def load_ckpt(self): """ Load dataset from a checkpoint file. + Alias for load_checkpoint for backward compatibility. + + :return: a dataset stored in checkpoint file. + """ + return self.load_checkpoint() + + def load_checkpoint(self, **kwargs): + """ + Load dataset from a checkpoint file. + :param kwargs: Additional arguments (not used, kept for interface compatibility) :return: a dataset stored in checkpoint file. """ from data_juicer.core.data import NestedDataset ds = NestedDataset.load_from_disk(self.ckpt_ds_dir) return ds + + +class CheckpointStrategy(Enum): + """Checkpoint strategies for controlling when to create checkpoints.""" + + EVERY_OP = "every_op" # Checkpoint after every operation + EVERY_N_OPS = "every_n_ops" # Checkpoint after every N operations + MANUAL = "manual" # Checkpoint only after specified operations + DISABLED = "disabled" # Disable checkpointing entirely + + +class RayCheckpointManager(CheckpointManagerBase): + """ + Checkpoint manager for Ray Data with per-partition checkpointing support. + + This class manages checkpoints for Ray Data datasets using Parquet format, + supporting per-partition checkpointing and various checkpoint strategies. + """ + + def __init__( + self, + ckpt_dir: str, + checkpoint_enabled: bool = True, + checkpoint_strategy: CheckpointStrategy = CheckpointStrategy.EVERY_OP, + checkpoint_n_ops: int = 1, + checkpoint_op_names: Optional[List[str]] = None, + event_logger=None, + ): + """ + Initialize Ray checkpoint manager. 
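+
+        A minimal construction sketch (hypothetical directory, assuming the
+        every_n_ops strategy):
+
+            mgr = RayCheckpointManager(
+                "/tmp/dj_checkpoints",
+                checkpoint_strategy=CheckpointStrategy.EVERY_N_OPS,
+                checkpoint_n_ops=2,
+            )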
+ + :param ckpt_dir: Directory to save and load checkpoints + :param checkpoint_enabled: Whether checkpointing is enabled + :param checkpoint_strategy: Strategy for when to create checkpoints + :param checkpoint_n_ops: Number of operations between checkpoints (for EVERY_N_OPS strategy) + :param checkpoint_op_names: List of operation names to checkpoint (for MANUAL strategy) + :param event_logger: Optional event logger for checkpoint events + """ + super().__init__(ckpt_dir) + self.checkpoint_enabled = checkpoint_enabled + self.checkpoint_strategy = checkpoint_strategy + self.checkpoint_n_ops = checkpoint_n_ops + self.checkpoint_op_names = set(checkpoint_op_names or []) + self.event_logger = event_logger + + # If strategy is DISABLED, disable checkpointing regardless of enabled flag + if self.checkpoint_strategy == CheckpointStrategy.DISABLED: + self.checkpoint_enabled = False + + def resolve_checkpoint_filename(self, op_idx: int, partition_id: int) -> str: + """Resolve checkpoint filename using consistent format.""" + return f"checkpoint_op_{op_idx:04d}_partition_{partition_id:04d}.parquet" + + def should_checkpoint(self, op_idx: int, op_name: str) -> bool: + """Determine if checkpoint should be created based on configuration strategy.""" + if not self.checkpoint_enabled: + return False + + if self.checkpoint_strategy == CheckpointStrategy.EVERY_OP: + return True + elif self.checkpoint_strategy == CheckpointStrategy.EVERY_N_OPS: + return (op_idx + 1) % self.checkpoint_n_ops == 0 + elif self.checkpoint_strategy == CheckpointStrategy.MANUAL: + return op_name in self.checkpoint_op_names + elif self.checkpoint_strategy == CheckpointStrategy.DISABLED: + return False + else: + logger.warning(f"Unknown checkpoint strategy: {self.checkpoint_strategy}, defaulting to every_op") + return True + + def save_checkpoint( + self, + dataset: Any, # RayDataset or ray.data.Dataset + op_idx: int, + op_name: Optional[str] = None, + partition_id: int = 0, + cfg: Optional[Any] = None, + ) -> str: + """ + Save dataset checkpoint to parquet format. + + :param dataset: RayDataset or ray.data.Dataset to save + :param op_idx: Operation index + :param op_name: Operation name (optional) + :param partition_id: Partition ID + :param cfg: Optional config for RayDataset wrapper + :return: Path to saved checkpoint + """ + checkpoint_filename = self.resolve_checkpoint_filename(op_idx, partition_id) + checkpoint_path = os.path.join(self.ckpt_dir, checkpoint_filename) + + # Ensure directory exists + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + + # Extract ray.data.Dataset if it's wrapped in RayDataset + ray_data = dataset.data if hasattr(dataset, "data") else dataset + + # Save as parquet + ray_data.write_parquet(checkpoint_path) + + # Log checkpoint save event if event logger is available + if self.event_logger and hasattr(self.event_logger, "_log_event"): + from data_juicer.core.executor.event_logging_mixin import EventType + + self.event_logger._log_event( + event_type=EventType.CHECKPOINT_SAVE, + message=f"Saved checkpoint after operation {op_idx}: {op_name}", + partition_id=partition_id, + operation_name=op_name, + operation_idx=op_idx, + metadata={"checkpoint_path": checkpoint_path}, + ) + + logger.info(f"Saved checkpoint: {checkpoint_path}") + return checkpoint_path + + def load_checkpoint( + self, + op_idx: int, + op_name: Optional[str] = None, + partition_id: int = 0, + cfg: Optional[Any] = None, + ) -> Optional[Any]: # Returns RayDataset or None + """ + Load dataset checkpoint from parquet format. 
+ + :param op_idx: Operation index + :param op_name: Operation name (optional) + :param partition_id: Partition ID + :param cfg: Optional config for RayDataset wrapper + :return: RayDataset or None if checkpoint doesn't exist + """ + checkpoint_filename = self.resolve_checkpoint_filename(op_idx, partition_id) + checkpoint_path = os.path.join(self.ckpt_dir, checkpoint_filename) + + if not os.path.exists(checkpoint_path): + return None + + try: + # Lazy import ray to avoid dependency if not using Ray + from data_juicer.utils.lazy_loader import LazyLoader + + ray = LazyLoader("ray") + + # Load from parquet + ray_dataset = ray.data.read_parquet(checkpoint_path) + + # Log checkpoint load event if event logger is available + if self.event_logger and hasattr(self.event_logger, "_log_event"): + from data_juicer.core.executor.event_logging_mixin import EventType + + self.event_logger._log_event( + event_type=EventType.CHECKPOINT_LOAD, + message=f"Loaded checkpoint from operation {op_idx}", + partition_id=partition_id, + operation_name=op_name or f"op_{op_idx:04d}", + operation_idx=op_idx, + metadata={"checkpoint_path": checkpoint_path}, + ) + + # Wrap in RayDataset if cfg is provided + if cfg is not None: + from data_juicer.core.data.ray_dataset import RayDataset + + return RayDataset(ray_dataset, cfg=cfg) + else: + return ray_dataset + + except Exception as e: + logger.warning(f"Failed to load checkpoint {checkpoint_path}: {e}") + return None + + def find_latest_checkpoint(self, partition_id: int = 0) -> Optional[Tuple[int, str, str]]: + """ + Find the latest checkpoint for a partition. + + :param partition_id: Partition ID + :return: Tuple of (op_idx, op_name, checkpoint_path) or None if no checkpoint found + """ + checkpoint_files = [] + + if not os.path.exists(self.ckpt_dir): + return None + + for filename in os.listdir(self.ckpt_dir): + if filename.startswith("checkpoint_op_") and filename.endswith(f"_partition_{partition_id:04d}.parquet"): + try: + # Parse filename: checkpoint_op_XXXX_partition_YYYY.parquet + parts = filename.replace(".parquet", "").split("_") + if len(parts) >= 4: + op_idx = int(parts[2]) + # For backward compatibility, we'll use a generic op_name + op_name = f"op_{op_idx:04d}" + checkpoint_files.append((op_idx, op_name, os.path.join(self.ckpt_dir, filename))) + except (ValueError, IndexError): + continue + + if not checkpoint_files: + return None + + # Return the latest checkpoint (highest op_idx) + latest = max(checkpoint_files, key=lambda x: x[0]) + return latest + + def group_operations_for_checkpointing(self, ops: List[Any]) -> List[Tuple[int, int, List[Any]]]: + """ + Group operations based on checkpoint strategy. + + :param ops: List of operations + :return: List of (start_idx, end_idx, group_ops) tuples + """ + groups = [] + current_start = 0 + + for i, op in enumerate(ops): + op_name = getattr(op, "_name", f"op_{i}") + if self.should_checkpoint(i, op_name): + # This operation should trigger a checkpoint + groups.append((current_start, i + 1, ops[current_start : i + 1])) + current_start = i + 1 + + # Add remaining operations as the last group + if current_start < len(ops): + groups.append((current_start, len(ops), ops[current_start:])) + + return groups diff --git a/data_juicer/utils/config_utils.py b/data_juicer/utils/config_utils.py new file mode 100644 index 0000000000..f1e727dc72 --- /dev/null +++ b/data_juicer/utils/config_utils.py @@ -0,0 +1,51 @@ +""" +Configuration utilities for handling both dict and object-style configs. 
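+
+A rough usage sketch (hypothetical config values):
+
+    cfg = {"partition": {"mode": "auto", "target_size_mb": 256}}
+    ConfigAccessor.get(cfg, "partition")                          # the nested dict
+    ConfigAccessor.get_nested(cfg, "partition", "mode", default="auto")  # "auto"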
+""" + +from typing import Any + + +class ConfigAccessor: + """Utility for accessing configuration values that may be dicts or objects.""" + + @staticmethod + def get(config: Any, key: str, default: Any = None) -> Any: + """ + Get a configuration value from either a dict or object. + + Args: + config: Configuration object (dict or object with attributes) + key: Key/attribute name to retrieve + default: Default value if key not found + + Returns: + Configuration value or default + """ + if config is None: + return default + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + @staticmethod + def get_nested(config: Any, *keys: str, default: Any = None) -> Any: + """ + Get a nested configuration value. + + Example: + get_nested(cfg, 'partition', 'mode', default='auto') + + Args: + config: Configuration object + keys: Series of keys to traverse + default: Default value if path not found + + Returns: + Configuration value or default + """ + current = config + for key in keys: + if current is None: + return default + current = ConfigAccessor.get(current, key) + return current if current is not None else default diff --git a/data_juicer/utils/job/__init__.py b/data_juicer/utils/job/__init__.py new file mode 100644 index 0000000000..9809b5d2d5 --- /dev/null +++ b/data_juicer/utils/job/__init__.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +""" +Job utilities for DataJuicer. + +This module provides utilities for job management, monitoring, and analysis. +""" + +from .common import JobUtils, list_running_jobs +from .snapshot import ( + JobSnapshot, + OperationStatus, + PartitionStatus, + ProcessingSnapshotAnalyzer, + ProcessingStatus, + create_snapshot, +) + +__all__ = [ + "JobUtils", + "list_running_jobs", + "ProcessingSnapshotAnalyzer", + "create_snapshot", + "JobSnapshot", + "ProcessingStatus", + "OperationStatus", + "PartitionStatus", +] diff --git a/data_juicer/utils/job/common.py b/data_juicer/utils/job/common.py new file mode 100644 index 0000000000..f023e1cb4b --- /dev/null +++ b/data_juicer/utils/job/common.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 +""" +DataJuicer Job Utilities - Common Functions + +Shared utilities for job stopping and monitoring operations. +""" + +import json +import os +import sys +import threading +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import psutil +from loguru import logger + + +class JobUtils: + """Common utilities for DataJuicer job operations.""" + + def __init__(self, job_id: str, work_dir: str = None, base_dir: str = None): + """ + Initialize job utilities. 
+ + Args: + job_id: The job ID to work with + work_dir: Work directory that already includes job_id (preferred) + base_dir: Base directory containing job outputs (deprecated, use work_dir instead) + """ + self.job_id = job_id + if work_dir: + # work_dir already includes job_id + self.work_dir = Path(work_dir) + elif base_dir: + # Legacy: construct work_dir from base_dir + job_id + self.work_dir = Path(base_dir) / job_id + else: + # Default fallback + self.work_dir = Path("outputs/partition-checkpoint-eventlog") / job_id + + # Set up logging + logger.remove() + logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {name}:{function}:{line} - {message}") + + if not self.work_dir.exists(): + raise FileNotFoundError(f"Job directory not found: {self.work_dir}") + + def load_job_summary(self) -> Optional[Dict[str, Any]]: + """Load job summary from the work directory.""" + job_summary_file = self.work_dir / "job_summary.json" + if not job_summary_file.exists(): + logger.error(f"Job summary not found: {job_summary_file}") + return None + + try: + with open(job_summary_file, "r") as f: + return json.load(f) + except Exception as e: + logger.error(f"Failed to load job summary: {e}") + return None + + def load_dataset_mapping(self) -> Dict[str, Any]: + """Load dataset mapping information.""" + mapping_file = self.work_dir / "metadata" / "dataset_mapping.json" + if mapping_file.exists(): + try: + with open(mapping_file, "r") as f: + return json.load(f) + except Exception as e: + logger.warning(f"Failed to load dataset mapping: {e}") + return {} + + def _find_latest_events_file(self) -> Optional[Path]: + """Find the latest events file in the work directory.""" + # Look for events files with timestamp pattern (events_*.jsonl) + events_files = list(self.work_dir.glob("events_*.jsonl")) + if events_files: + # Sort by modification time and return the latest + return max(events_files, key=lambda f: f.stat().st_mtime) + + # Fallback to old naming convention for backward compatibility + fallback_file = self.work_dir / "events.jsonl" + return fallback_file if fallback_file.exists() else None + + def load_event_logs(self) -> List[Dict[str, Any]]: + """Load and parse event logs.""" + events_file = self._find_latest_events_file() + events = [] + + if events_file and events_file.exists(): + try: + with open(events_file, "r") as f: + for line in f: + try: + events.append(json.loads(line.strip())) + except json.JSONDecodeError: + continue + except Exception as e: + logger.error(f"Failed to read events file: {e}") + else: + logger.warning(f"Events file not found: {events_file}") + + return events + + def extract_process_thread_ids(self) -> Dict[str, Set[int]]: + """ + Extract process and thread IDs from event logs. + Returns a dict with 'process_ids' and 'thread_ids' sets. 
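+
+        Illustrative return value (hypothetical IDs):
+            {"process_ids": {41023, 41057}, "thread_ids": {139912, 139944}}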
+ """ + events = self.load_event_logs() + process_ids = set() + thread_ids = set() + + for event in events: + # Extract process ID + if "process_id" in event and event["process_id"] is not None: + process_ids.add(event["process_id"]) + + # Extract thread ID + if "thread_id" in event and event["thread_id"] is not None: + thread_ids.add(event["thread_id"]) + + logger.info(f"Found {len(process_ids)} unique process IDs and {len(thread_ids)} unique thread IDs") + return {"process_ids": process_ids, "thread_ids": thread_ids} + + def find_processes_by_ids(self, process_ids: Set[int]) -> List[psutil.Process]: + """Find running processes by their PIDs.""" + processes = [] + current_pid = os.getpid() + + for pid in process_ids: + if pid == current_pid: + logger.debug(f"Skipping current process PID {pid}") + continue + + try: + proc = psutil.Process(pid) + if proc.is_running(): + processes.append(proc) + logger.debug(f"Found running process PID {pid}") + else: + logger.debug(f"Process PID {pid} is not running") + except psutil.NoSuchProcess: + logger.debug(f"Process PID {pid} no longer exists") + except psutil.AccessDenied: + logger.warning(f"Access denied to process PID {pid}") + except Exception as e: + logger.warning(f"Error checking process PID {pid}: {e}") + + return processes + + def find_threads_by_ids(self, thread_ids: Set[int]) -> List[threading.Thread]: + """Find running threads by their IDs (if possible).""" + # Note: Python doesn't provide a direct way to enumerate all threads + # This is more of a placeholder for future implementation + logger.info(f"Thread termination not implemented yet. Found {len(thread_ids)} thread IDs") + return [] + + def get_partition_status(self) -> Dict[int, Dict[str, Any]]: + """Get current status of all partitions.""" + dataset_mapping = self.load_dataset_mapping() + events = self.load_event_logs() + + partition_status = {} + + # Initialize from dataset mapping + if "partitions" in dataset_mapping: + for partition_info in dataset_mapping["partitions"]: + partition_id = partition_info["partition_id"] + partition_status[partition_id] = { + "status": partition_info.get("processing_status", "unknown"), + "sample_count": partition_info.get("sample_count", 0), + "start_time": partition_info.get("processing_start_time"), + "end_time": partition_info.get("processing_end_time"), + "error_message": partition_info.get("error_message"), + "current_op": None, + "completed_ops": [], + "checkpoints": [], + } + + # Update from event logs + for event in events: + if "partition_id" in event: + partition_id = event["partition_id"] + if partition_id not in partition_status: + partition_status[partition_id] = { + "status": "unknown", + "sample_count": 0, + "start_time": None, + "end_time": None, + "error_message": None, + "current_op": None, + "completed_ops": [], + "checkpoints": [], + } + + # Track partition start/complete + if event["event_type"] == "partition_start": + partition_status[partition_id]["start_time"] = event["timestamp"] + partition_status[partition_id]["status"] = "processing" + + elif event["event_type"] == "partition_complete": + partition_status[partition_id]["end_time"] = event["timestamp"] + partition_status[partition_id]["status"] = "completed" + + # Track operations + elif event["event_type"] == "op_start": + partition_status[partition_id]["current_op"] = { + "name": event.get("operation_name", "Unknown"), + "idx": event.get("operation_idx", 0), + "start_time": event["timestamp"], + } + + elif event["event_type"] == "op_complete": + op_info = { + "name": 
event.get("operation_name", "Unknown"), + "idx": event.get("operation_idx", 0), + "duration": event.get("duration", 0), + "input_rows": event.get("input_rows", 0), + "output_rows": event.get("output_rows", 0), + "throughput": event.get("performance_metrics", {}).get("throughput", 0), + "reduction_ratio": event.get("performance_metrics", {}).get("reduction_ratio", 0), + } + partition_status[partition_id]["completed_ops"].append(op_info) + partition_status[partition_id]["current_op"] = None + + # Track checkpoints + elif event["event_type"] == "checkpoint_save": + checkpoint_info = { + "operation_name": event.get("operation_name", "Unknown"), + "operation_idx": event.get("operation_idx", 0), + "checkpoint_path": event.get("checkpoint_path", ""), + "timestamp": event["timestamp"], + } + partition_status[partition_id]["checkpoints"].append(checkpoint_info) + + return partition_status + + def calculate_overall_progress(self) -> Dict[str, Any]: + """Calculate overall job progress.""" + partition_status = self.get_partition_status() + job_summary = self.load_job_summary() + + total_partitions = len(partition_status) + completed_partitions = sum(1 for p in partition_status.values() if p["status"] == "completed") + processing_partitions = sum(1 for p in partition_status.values() if p["status"] == "processing") + failed_partitions = sum(1 for p in partition_status.values() if p["status"] == "failed") + + # Calculate total samples + total_samples = sum(p.get("sample_count", 0) for p in partition_status.values()) + processed_samples = sum( + p.get("sample_count", 0) for p in partition_status.values() if p["status"] == "completed" + ) + + # Calculate progress percentage + progress_percentage = (completed_partitions / total_partitions * 100) if total_partitions > 0 else 0 + + # Calculate estimated time remaining + estimated_remaining = None + if job_summary and "start_time" in job_summary and completed_partitions > 0: + elapsed_time = time.time() - job_summary["start_time"] + if completed_partitions > 0: + avg_time_per_partition = elapsed_time / completed_partitions + remaining_partitions = total_partitions - completed_partitions + estimated_remaining = avg_time_per_partition * remaining_partitions + + return { + "total_partitions": total_partitions, + "completed_partitions": completed_partitions, + "processing_partitions": processing_partitions, + "failed_partitions": failed_partitions, + "progress_percentage": progress_percentage, + "total_samples": total_samples, + "processed_samples": processed_samples, + "estimated_remaining_seconds": estimated_remaining, + "job_status": job_summary.get("status", "unknown") if job_summary else "unknown", + } + + def get_operation_pipeline(self) -> List[Dict[str, Any]]: + """Get the operation pipeline from config.""" + config_file = self.work_dir / "partition-checkpoint-eventlog.yaml" + if not config_file.exists(): + return [] + + # Try to find process section in config + try: + with open(config_file, "r") as f: + content = f.read() + + # Simple parsing for process section + operations = [] + lines = content.split("\n") + in_process = False + + for line in lines: + if line.strip().startswith("process:"): + in_process = True + continue + elif in_process and line.strip().startswith("-"): + # Extract operation name + op_line = line.strip() + if ":" in op_line: + op_name = op_line.split(":")[0].replace("- ", "").strip() + operations.append({"name": op_name, "config": {}}) + + return operations + except Exception as e: + logger.warning(f"Failed to parse operation 
pipeline: {e}") + return [] + + +def _find_latest_events_file_in_dir(job_dir: Path) -> Optional[Path]: + """Helper function to find the latest events file in a directory.""" + # Look for events files with timestamp pattern (events_*.jsonl) + events_files = list(job_dir.glob("events_*.jsonl")) + if events_files: + # Sort by modification time and return the latest + return max(events_files, key=lambda f: f.stat().st_mtime) + + # Fallback to old naming convention for backward compatibility + fallback_file = job_dir / "events.jsonl" + return fallback_file if fallback_file.exists() else None + + +def list_running_jobs(base_dir: str = "outputs/partition-checkpoint-eventlog") -> List[Dict[str, Any]]: + """List all DataJuicer jobs and their status.""" + base_path = Path(base_dir) + if not base_path.exists(): + return [] + + jobs = [] + for job_dir in base_path.iterdir(): + if job_dir.is_dir(): + job_summary_file = job_dir / "job_summary.json" + if job_summary_file.exists(): + try: + with open(job_summary_file, "r") as f: + job_summary = json.load(f) + + # Check if processes are still running + events_file = _find_latest_events_file_in_dir(job_dir) + process_ids = set() + if events_file and events_file.exists(): + try: + with open(events_file, "r") as f: + for line in f: + try: + event_data = json.loads(line.strip()) + if "process_id" in event_data and event_data["process_id"] is not None: + process_ids.add(event_data["process_id"]) + except json.JSONDecodeError: + continue + except Exception: + pass + + # Count running processes + running_processes = 0 + for pid in process_ids: + try: + if psutil.Process(pid).is_running(): + running_processes += 1 + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + jobs.append( + { + "job_id": job_dir.name, + "status": job_summary.get("status", "unknown"), + "start_time": job_summary.get("start_time"), + "processes": running_processes, + "work_dir": str(job_dir), + } + ) + except Exception as e: + logger.warning(f"Failed to read job summary for {job_dir.name}: {e}") + + return sorted(jobs, key=lambda x: x.get("start_time", 0) or 0, reverse=True) diff --git a/data_juicer/utils/job/monitor.py b/data_juicer/utils/job/monitor.py new file mode 100644 index 0000000000..10032e2bb2 --- /dev/null +++ b/data_juicer/utils/job/monitor.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +""" +DataJuicer Job Progress Monitor + +A utility to monitor and display progress information for DataJuicer jobs. +Shows partition status, operation progress, checkpoints, and overall job metrics. +""" + +import os +import sys +import time +from datetime import datetime +from typing import Any, Dict + +from data_juicer.utils.job.common import JobUtils + + +class JobProgressMonitor: + """Monitor and display progress for DataJuicer jobs.""" + + def __init__(self, job_id: str, base_dir: str = "outputs/partition-checkpoint-eventlog"): + """ + Initialize the job progress monitor. 
+ + Args: + job_id: The job ID to monitor + base_dir: Base directory containing job outputs + """ + self.job_utils = JobUtils(job_id, base_dir=base_dir) + self.job_id = job_id + self.work_dir = self.job_utils.work_dir + + def display_progress(self, detailed: bool = False): + """Display job progress information.""" + print(f"\n{'='*80}") + print(f"DataJuicer Job Progress Monitor") + print(f"Job ID: {self.job_id}") + print(f"{'='*80}") + + # Load data + job_summary = self.job_utils.load_job_summary() + dataset_mapping = self.job_utils.load_dataset_mapping() + partition_status = self.job_utils.get_partition_status() + overall_progress = self.job_utils.calculate_overall_progress() + + # Job overview + print(f"\n📊 JOB OVERVIEW") + print(f" Status: {overall_progress['job_status'].upper()}") + print(f" Dataset: {dataset_mapping.get('original_dataset_path', 'Unknown')}") + print(f" Total Samples: {dataset_mapping.get('original_dataset_size', 0):,}") + print(f" Partition Size: {dataset_mapping.get('partition_size', 0):,} samples") + + if job_summary and job_summary.get("start_time"): + start_time = datetime.fromtimestamp(job_summary["start_time"]) + print(f" Start Time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}") + + if job_summary and job_summary.get("duration"): + print(f" Duration: {job_summary['duration']:.1f} seconds") + + # Overall progress + print(f"\n🎯 OVERALL PROGRESS") + print( + f" Progress: {overall_progress['progress_percentage']:.1f}% " + f"({overall_progress['completed_partitions']}/{overall_progress['total_partitions']} partitions)" + ) + print( + f" Status: {overall_progress['completed_partitions']} completed, " + f"{overall_progress['processing_partitions']} processing, " + f"{overall_progress['failed_partitions']} failed" + ) + print(f" Samples: {overall_progress['processed_samples']:,}/{overall_progress['total_samples']:,}") + + if overall_progress["estimated_remaining_seconds"]: + remaining_minutes = overall_progress["estimated_remaining_seconds"] / 60 + print(f" Estimated Time Remaining: {remaining_minutes:.1f} minutes") + + # Partition status + print(f"\n📦 PARTITION STATUS") + for partition_id in sorted(partition_status.keys()): + partition = partition_status[partition_id] + status_icon = {"completed": "✅", "processing": "🔄", "failed": "❌", "unknown": "❓"}.get( + partition["status"], "❓" + ) + + print(f" Partition {partition_id:2d}: {status_icon} {partition['status'].upper()}") + print(f" Samples: {partition['sample_count']:,}") + + if partition["current_op"]: + print(f" Current: {partition['current_op']['name']} (op {partition['current_op']['idx']})") + + if partition["completed_ops"]: + print(f" Completed: {len(partition['completed_ops'])} operations") + + if partition["checkpoints"]: + print(f" Checkpoints: {len(partition['checkpoints'])} saved") + + if detailed: + # Detailed operation information + print(f"\n🔧 OPERATION DETAILS") + for partition_id in sorted(partition_status.keys()): + partition = partition_status[partition_id] + if partition["completed_ops"]: + print(f"\n Partition {partition_id}:") + for op in partition["completed_ops"]: + reduction = op.get("reduction_ratio", 0) * 100 + print( + f" {op['name']:25s} | " + f"Duration: {op['duration']:6.1f}s | " + f"Throughput: {op['throughput']:6.0f} rows/s | " + f"Reduction: {reduction:5.2f}%" + ) + + # Checkpoint information + print(f"\n💾 CHECKPOINT SUMMARY") + total_checkpoints = sum(len(p["checkpoints"]) for p in partition_status.values()) + print(f" Total Checkpoints: {total_checkpoints}") + + if detailed: + for 
partition_id in sorted(partition_status.keys()): + partition = partition_status[partition_id] + if partition["checkpoints"]: + print(f"\n Partition {partition_id} checkpoints:") + for checkpoint in partition["checkpoints"]: + checkpoint_time = datetime.fromtimestamp(checkpoint["timestamp"]) + print( + f" {checkpoint['operation_name']} (op {checkpoint['operation_idx']}) - " + f"{checkpoint_time.strftime('%H:%M:%S')}" + ) + + # Add helpful hint for stopping the job + print(f"\n💡 To stop this job: from data_juicer.utils.job_stopper import stop_job; stop_job('{self.job_id}')") + print(f"{'='*80}") + + def get_progress_data(self) -> Dict[str, Any]: + """Get progress data as a dictionary for programmatic use.""" + job_summary = self.job_utils.load_job_summary() + dataset_mapping = self.job_utils.load_dataset_mapping() + partition_status = self.job_utils.get_partition_status() + overall_progress = self.job_utils.calculate_overall_progress() + + return { + "job_id": self.job_id, + "job_summary": job_summary, + "dataset_mapping": dataset_mapping, + "partition_status": partition_status, + "overall_progress": overall_progress, + } + + +def show_job_progress( + job_id: str, base_dir: str = "outputs/partition-checkpoint-eventlog", detailed: bool = False +) -> Dict[str, Any]: + """ + Utility function to show job progress. + + Args: + job_id: The job ID to monitor + base_dir: Base directory containing job outputs + detailed: Whether to show detailed operation information + + Returns: + Dictionary containing all progress data + + Example: + >>> show_job_progress("20250728_233517_510abf") + >>> show_job_progress("20250728_233517_510abf", detailed=True) + """ + monitor = JobProgressMonitor(job_id, base_dir) + monitor.display_progress(detailed) + return monitor.get_progress_data() + + +def main(): + """Main entry point for the job progress monitor.""" + import argparse + + parser = argparse.ArgumentParser(description="Monitor DataJuicer job progress") + parser.add_argument("job_id", help="Job ID to monitor") + parser.add_argument( + "--base-dir", default="outputs/partition-checkpoint-eventlog", help="Base directory containing job outputs" + ) + parser.add_argument("--detailed", action="store_true", help="Show detailed operation information") + parser.add_argument("--watch", action="store_true", help="Watch mode - continuously update progress") + parser.add_argument("--interval", type=int, default=10, help="Update interval in seconds for watch mode") + + args = parser.parse_args() + + try: + monitor = JobProgressMonitor(args.job_id, args.base_dir) + + if args.watch: + print(f"Watching job {args.job_id} (press Ctrl+C to stop)...") + try: + while True: + os.system("clear" if os.name == "posix" else "cls") + monitor.display_progress(args.detailed) + time.sleep(args.interval) + except KeyboardInterrupt: + print("\nStopped watching.") + else: + monitor.display_progress(args.detailed) + + except FileNotFoundError as e: + print(f"Error: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/data_juicer/utils/job/snapshot.py b/data_juicer/utils/job/snapshot.py new file mode 100644 index 0000000000..dbd5e12a41 --- /dev/null +++ b/data_juicer/utils/job/snapshot.py @@ -0,0 +1,734 @@ +""" +Processing Snapshot Utility for DataJuicer + +This module analyzes the current state of processing based on events.jsonl and DAG structure +to provide a comprehensive snapshot of what's done, what's not, and checkpointing status. 
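+
+A rough usage sketch (hypothetical work directory):
+
+    analyzer = ProcessingSnapshotAnalyzer("outputs/partition-checkpoint-eventlog/<job_id>")
+    events = analyzer.load_events()
+    partition_statuses, operation_statuses = analyzer.analyze_events(events)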
+""" + +import json +import os +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from loguru import logger + + +class ProcessingStatus(Enum): + """Processing status enumeration.""" + + NOT_STARTED = "not_started" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + CHECKPOINTED = "checkpointed" + + +@dataclass +class OperationStatus: + """Status of a single operation.""" + + operation_name: str + operation_idx: int + status: ProcessingStatus + start_time: Optional[float] = None + end_time: Optional[float] = None + duration: Optional[float] = None + input_rows: Optional[int] = None + output_rows: Optional[int] = None + checkpoint_time: Optional[float] = None + error_message: Optional[str] = None + + +@dataclass +class PartitionStatus: + """Status of a single partition.""" + + partition_id: int + status: ProcessingStatus + sample_count: Optional[int] = None + creation_start_time: Optional[float] = None + creation_end_time: Optional[float] = None + processing_start_time: Optional[float] = None + processing_end_time: Optional[float] = None + current_operation: Optional[str] = None + completed_operations: List[str] = None + failed_operations: List[str] = None + checkpointed_operations: List[str] = None + error_message: Optional[str] = None + + def __post_init__(self): + """Initialize mutable fields after dataclass creation.""" + if self.completed_operations is None: + self.completed_operations = [] + if self.failed_operations is None: + self.failed_operations = [] + if self.checkpointed_operations is None: + self.checkpointed_operations = [] + + +@dataclass +class JobSnapshot: + """Complete snapshot of job processing status.""" + + job_id: str + job_start_time: Optional[float] = None + job_end_time: Optional[float] = None + total_duration: Optional[float] = None + total_partitions: int = 0 + completed_partitions: int = 0 + failed_partitions: int = 0 + in_progress_partitions: int = 0 + total_operations: int = 0 + completed_operations: int = 0 + failed_operations: int = 0 + checkpointed_operations: int = 0 + partition_statuses: Dict[int, PartitionStatus] = None + operation_statuses: Dict[str, OperationStatus] = None + dag_structure: Dict = None + checkpoint_strategy: Optional[str] = None + checkpoint_frequency: Optional[str] = None + last_checkpoint_time: Optional[float] = None + resumable: bool = False + overall_status: ProcessingStatus = ProcessingStatus.NOT_STARTED + + +class ProcessingSnapshotAnalyzer: + """Analyzer for processing snapshots.""" + + def __init__(self, work_dir: str): + """Initialize the analyzer with work directory.""" + self.work_dir = Path(work_dir) + self.events_file = self._find_latest_events_file() + self.dag_file = self.work_dir / "dag_execution_plan.json" + self.job_summary_file = self.work_dir / "job_summary.json" + + def _find_latest_events_file(self) -> Path: + """Find the latest events file in the work directory.""" + # Look for events files with timestamp pattern (events_*.jsonl) + events_files = list(self.work_dir.glob("events_*.jsonl")) + if events_files: + # Sort by modification time and return the latest + return max(events_files, key=lambda f: f.stat().st_mtime) + + # Fallback to old naming convention for backward compatibility + return self.work_dir / "events.jsonl" + + def load_events(self) -> List[Dict]: + """Load events from events.jsonl file.""" + events = [] + if self.events_file.exists(): + try: + with 
open(self.events_file, "r") as f: + for line in f: + events.append(json.loads(line.strip())) + logger.info(f"Loaded {len(events)} events from {self.events_file}") + except Exception as e: + logger.error(f"Failed to load events: {e}") + else: + logger.warning(f"Events file not found: {self.events_file}") + return events + + def load_dag_plan(self) -> Dict: + """Load DAG execution plan.""" + dag_plan = {} + if self.dag_file.exists(): + try: + with open(self.dag_file, "r") as f: + dag_plan = json.load(f) + logger.info(f"Loaded DAG plan from {self.dag_file}") + except Exception as e: + logger.error(f"Failed to load DAG plan: {e}") + else: + logger.warning(f"DAG file not found: {self.dag_file}") + return dag_plan + + def load_job_summary(self) -> Dict: + """Load job summary if available.""" + summary = {} + if self.job_summary_file.exists(): + try: + with open(self.job_summary_file, "r") as f: + summary = json.load(f) + logger.info(f"Loaded job summary from {self.job_summary_file}") + except Exception as e: + logger.error(f"Failed to load job summary: {e}") + return summary + + def analyze_events(self, events: List[Dict]) -> Tuple[Dict[int, PartitionStatus], Dict[str, OperationStatus]]: + """Analyze events to determine processing status.""" + partition_statuses = {} + operation_statuses = {} + + # Track job-level events + for event in events: + event_type = event.get("event_type") + timestamp = event.get("timestamp") + + if event_type == "job_start": + # Extract checkpoint strategy from metadata + metadata = event.get("metadata", {}) + # Note: checkpoint_strategy is extracted but not used in this method + # It's used in generate_snapshot method + pass + + elif event_type == "job_complete": + # Note: job_end_time is extracted but not used in this method + # It's used in generate_snapshot method + pass + + elif event_type == "partition_creation_start": + partition_id = event.get("partition_id") + if partition_id not in partition_statuses: + partition_statuses[partition_id] = PartitionStatus( + partition_id=partition_id, status=ProcessingStatus.NOT_STARTED + ) + partition_statuses[partition_id].creation_start_time = timestamp + + elif event_type == "partition_creation_complete": + partition_id = event.get("partition_id") + if partition_id in partition_statuses: + partition_statuses[partition_id].creation_end_time = timestamp + metadata = event.get("metadata", {}) + partition_statuses[partition_id].sample_count = metadata.get("sample_count") + + elif event_type == "partition_start": + partition_id = event.get("partition_id") + if partition_id in partition_statuses: + partition_statuses[partition_id].processing_start_time = timestamp + partition_statuses[partition_id].status = ProcessingStatus.IN_PROGRESS + + elif event_type == "partition_complete": + partition_id = event.get("partition_id") + if partition_id in partition_statuses: + partition_statuses[partition_id].processing_end_time = timestamp + partition_statuses[partition_id].status = ProcessingStatus.COMPLETED + + elif event_type == "partition_failed": + partition_id = event.get("partition_id") + if partition_id in partition_statuses: + partition_statuses[partition_id].status = ProcessingStatus.FAILED + partition_statuses[partition_id].error_message = event.get("error_message") + + elif event_type == "op_start": + partition_id = event.get("partition_id") + op_idx = event.get("operation_idx") + op_name = event.get("operation_name") + key = f"p{partition_id}_op{op_idx}_{op_name}" + + operation_statuses[key] = OperationStatus( + 
operation_name=op_name, + operation_idx=op_idx, + status=ProcessingStatus.IN_PROGRESS, + start_time=timestamp, + ) + + # Update partition status + if partition_id in partition_statuses: + partition_statuses[partition_id].current_operation = op_name + + elif event_type == "op_complete": + partition_id = event.get("partition_id") + op_idx = event.get("operation_idx") + op_name = event.get("operation_name") + key = f"p{partition_id}_op{op_idx}_{op_name}" + + if key in operation_statuses: + operation_statuses[key].end_time = timestamp + operation_statuses[key].status = ProcessingStatus.COMPLETED + if operation_statuses[key].start_time: + operation_statuses[key].duration = timestamp - operation_statuses[key].start_time + + metadata = event.get("metadata", {}) + operation_statuses[key].input_rows = metadata.get("input_rows") + operation_statuses[key].output_rows = metadata.get("output_rows") + + # Update partition status + if partition_id in partition_statuses: + partition_statuses[partition_id].completed_operations.append(op_name) + + elif event_type == "op_failed": + partition_id = event.get("partition_id") + op_idx = event.get("operation_idx") + op_name = event.get("operation_name") + key = f"p{partition_id}_op{op_idx}_{op_name}" + + if key in operation_statuses: + operation_statuses[key].status = ProcessingStatus.FAILED + operation_statuses[key].error_message = event.get("error_message") + + # Update partition status + if partition_id in partition_statuses: + partition_statuses[partition_id].failed_operations.append(op_name) + partition_statuses[partition_id].status = ProcessingStatus.FAILED + + elif event_type == "checkpoint_save": + partition_id = event.get("partition_id") + op_idx = event.get("operation_idx") + op_name = event.get("operation_name") + key = f"p{partition_id}_op{op_idx}_{op_name}" + + if key in operation_statuses: + operation_statuses[key].checkpoint_time = timestamp + operation_statuses[key].status = ProcessingStatus.CHECKPOINTED + + # Update partition status + if partition_id in partition_statuses: + partition_statuses[partition_id].checkpointed_operations.append(op_name) + + return partition_statuses, operation_statuses + + def determine_overall_status( + self, partition_statuses: Dict[int, PartitionStatus], operation_statuses: Dict[str, OperationStatus] + ) -> ProcessingStatus: + """Determine overall job status.""" + if not partition_statuses: + return ProcessingStatus.NOT_STARTED + + completed = sum(1 for p in partition_statuses.values() if p.status == ProcessingStatus.COMPLETED) + failed = sum(1 for p in partition_statuses.values() if p.status == ProcessingStatus.FAILED) + in_progress = sum(1 for p in partition_statuses.values() if p.status == ProcessingStatus.IN_PROGRESS) + + if failed > 0 and completed == 0: + return ProcessingStatus.FAILED + elif completed == len(partition_statuses): + return ProcessingStatus.COMPLETED + elif in_progress > 0 or completed > 0: + return ProcessingStatus.IN_PROGRESS + else: + return ProcessingStatus.NOT_STARTED + + def calculate_statistics( + self, partition_statuses: Dict[int, PartitionStatus], operation_statuses: Dict[str, OperationStatus] + ) -> Dict: + """Calculate processing statistics.""" + total_partitions = len(partition_statuses) + completed_partitions = sum(1 for p in partition_statuses.values() if p.status == ProcessingStatus.COMPLETED) + failed_partitions = sum(1 for p in partition_statuses.values() if p.status == ProcessingStatus.FAILED) + in_progress_partitions = sum(1 for p in partition_statuses.values() if p.status == 
ProcessingStatus.IN_PROGRESS) + + total_operations = len(operation_statuses) + completed_operations = sum(1 for op in operation_statuses.values() if op.status == ProcessingStatus.COMPLETED) + failed_operations = sum(1 for op in operation_statuses.values() if op.status == ProcessingStatus.FAILED) + checkpointed_operations = sum( + 1 for op in operation_statuses.values() if op.status == ProcessingStatus.CHECKPOINTED + ) + + return { + "total_partitions": total_partitions, + "completed_partitions": completed_partitions, + "failed_partitions": failed_partitions, + "in_progress_partitions": in_progress_partitions, + "total_operations": total_operations, + "completed_operations": completed_operations, + "failed_operations": failed_operations, + "checkpointed_operations": checkpointed_operations, + } + + def generate_snapshot(self) -> JobSnapshot: + """Generate a complete processing snapshot.""" + logger.info(f"Generating processing snapshot for work directory: {self.work_dir}") + + # Load data + events = self.load_events() + dag_plan = self.load_dag_plan() + job_summary = self.load_job_summary() + + # Extract job ID from directory name + job_id = self.work_dir.name + + # Analyze events + partition_statuses, operation_statuses = self.analyze_events(events) + + # Calculate statistics + stats = self.calculate_statistics(partition_statuses, operation_statuses) + + # Determine overall status + overall_status = self.determine_overall_status(partition_statuses, operation_statuses) + + # Extract timing information from job summary first, then fall back to events + job_start_time = None + job_end_time = None + total_duration = None + + if job_summary: + # Use job summary timing if available (more accurate) + job_start_time = job_summary.get("start_time") + job_end_time = job_summary.get("end_time") + total_duration = job_summary.get("duration") + else: + # Fall back to event-based timing + for event in events: + if event.get("event_type") == "job_start": + job_start_time = event.get("timestamp") + elif event.get("event_type") == "job_complete": + job_end_time = event.get("timestamp") + + if job_start_time and job_end_time: + total_duration = job_end_time - job_start_time + + # Determine resumability + resumable = any(op.status == ProcessingStatus.CHECKPOINTED for op in operation_statuses.values()) + + # Extract checkpoint information + checkpoint_strategy = None + last_checkpoint_time = None + for event in events: + if event.get("event_type") == "job_start": + metadata = event.get("metadata", {}) + checkpoint_strategy = metadata.get("checkpoint_strategy") + elif event.get("event_type") == "checkpoint_save": + last_checkpoint_time = event.get("timestamp") + + return JobSnapshot( + job_id=job_id, + job_start_time=job_start_time, + job_end_time=job_end_time, + total_duration=total_duration, + partition_statuses=partition_statuses, + operation_statuses=operation_statuses, + dag_structure=dag_plan, + checkpoint_strategy=checkpoint_strategy, + last_checkpoint_time=last_checkpoint_time, + resumable=resumable, + overall_status=overall_status, + **stats, + ) + + def to_json_dict(self, snapshot: JobSnapshot) -> Dict: + """Convert snapshot to JSON-serializable dictionary with comprehensive progress tracking.""" + # Load job summary for additional metadata + job_summary = self.load_job_summary() + + # Convert partition statuses to JSON format + partition_progress = {} + for partition_id, partition in snapshot.partition_statuses.items(): + partition_progress[str(partition_id)] = { + "status": partition.status.value, 
+ "sample_count": partition.sample_count, + "creation_start_time": partition.creation_start_time, + "creation_end_time": partition.creation_end_time, + "processing_start_time": partition.processing_start_time, + "processing_end_time": partition.processing_end_time, + "current_operation": partition.current_operation, + "completed_operations": partition.completed_operations, + "failed_operations": partition.failed_operations, + "checkpointed_operations": partition.checkpointed_operations, + "error_message": partition.error_message, + "progress_percentage": self._calculate_partition_progress(partition), + } + + # Convert operation statuses to JSON format + operation_progress = {} + for op_key, operation in snapshot.operation_statuses.items(): + operation_progress[op_key] = { + "operation_name": operation.operation_name, + "operation_idx": operation.operation_idx, + "status": operation.status.value, + "start_time": operation.start_time, + "end_time": operation.end_time, + "duration": operation.duration, + "input_rows": operation.input_rows, + "output_rows": operation.output_rows, + "checkpoint_time": operation.checkpoint_time, + "error_message": operation.error_message, + "progress_percentage": self._calculate_operation_progress(operation), + } + + # Extract DAG structure information + dag_info = {} + if snapshot.dag_structure: + dag_info = { + "total_nodes": len(snapshot.dag_structure.get("nodes", [])), + "total_edges": len(snapshot.dag_structure.get("edges", [])), + "parallel_groups": len(snapshot.dag_structure.get("parallel_groups", [])), + "execution_plan": snapshot.dag_structure.get("execution_plan", []), + "metadata": snapshot.dag_structure.get("metadata", {}), + } + + # Calculate overall progress percentages + overall_progress = self._calculate_overall_progress(snapshot) + + # Build job information from job summary + job_info = { + "job_id": snapshot.job_id, + "executor_type": job_summary.get("executor_type") if job_summary else None, + "status": job_summary.get("status") if job_summary else snapshot.overall_status.value, + "config_file": job_summary.get("config_file") if job_summary else None, + "work_dir": job_summary.get("work_dir") if job_summary else None, + "resumption_command": job_summary.get("resumption_command") if job_summary else None, + "error_message": job_summary.get("error_message") if job_summary else None, + } + + return { + "job_info": job_info, + "overall_status": snapshot.overall_status.value, + "overall_progress": overall_progress, + "job_start_time": snapshot.job_start_time, + "job_end_time": snapshot.job_end_time, + "total_duration": snapshot.total_duration, + "timing": { + "start_time": snapshot.job_start_time, + "end_time": snapshot.job_end_time, + "duration_seconds": snapshot.total_duration, + "duration_formatted": ( + self._format_duration(snapshot.total_duration) if snapshot.total_duration else None + ), + "job_summary_duration": job_summary.get("duration") if job_summary else None, + "timing_source": "job_summary" if job_summary else "events", + }, + "progress_summary": { + "total_partitions": snapshot.total_partitions, + "completed_partitions": snapshot.completed_partitions, + "failed_partitions": snapshot.failed_partitions, + "in_progress_partitions": snapshot.in_progress_partitions, + "partition_progress_percentage": self._calculate_partition_progress_percentage(snapshot), + "total_operations": snapshot.total_operations, + "completed_operations": snapshot.completed_operations, + "failed_operations": snapshot.failed_operations, + 
"checkpointed_operations": snapshot.checkpointed_operations, + "operation_progress_percentage": self._calculate_operation_progress_percentage(snapshot), + }, + "checkpointing": { + "strategy": snapshot.checkpoint_strategy, + "last_checkpoint_time": snapshot.last_checkpoint_time, + "checkpointed_operations_count": snapshot.checkpointed_operations, + "resumable": snapshot.resumable, + "checkpoint_progress": self._calculate_checkpoint_progress(snapshot), + "checkpoint_dir": job_summary.get("checkpoint_dir") if job_summary else None, + }, + "partition_progress": partition_progress, + "operation_progress": operation_progress, + "dag_structure": dag_info, + "file_paths": { + "event_log_file": job_summary.get("event_log_file") if job_summary else None, + "event_log_dir": job_summary.get("event_log_dir") if job_summary else None, + "checkpoint_dir": job_summary.get("checkpoint_dir") if job_summary else None, + "metadata_dir": job_summary.get("metadata_dir") if job_summary else None, + "backed_up_config_path": job_summary.get("backed_up_config_path") if job_summary else None, + }, + "metadata": { + "snapshot_generated_at": datetime.now().isoformat(), + "events_analyzed": len(self.load_events()), + "dag_plan_loaded": bool(snapshot.dag_structure), + "job_summary_loaded": bool(job_summary), + "job_summary_used": bool(job_summary), + }, + } + + def _calculate_partition_progress(self, partition: PartitionStatus) -> float: + """Calculate progress percentage for a partition.""" + if partition.status == ProcessingStatus.COMPLETED: + return 100.0 + elif partition.status == ProcessingStatus.FAILED: + return 0.0 + elif partition.status == ProcessingStatus.IN_PROGRESS: + # Estimate progress based on completed operations + total_ops = ( + len(partition.completed_operations) + + len(partition.failed_operations) + + len(partition.checkpointed_operations) + ) + if total_ops > 0: + return min(90.0, (total_ops / 8) * 100) # Assume 8 operations per partition + else: + return 10.0 # Just started + else: + return 0.0 + + def _calculate_operation_progress(self, operation: OperationStatus) -> float: + """Calculate progress percentage for an operation.""" + if operation.status == ProcessingStatus.COMPLETED: + return 100.0 + elif operation.status == ProcessingStatus.FAILED: + return 0.0 + elif operation.status == ProcessingStatus.CHECKPOINTED: + return 100.0 # Checkpointed operations are considered complete + elif operation.status == ProcessingStatus.IN_PROGRESS: + if operation.start_time: + # Estimate progress based on time elapsed + current_time = datetime.now().timestamp() + elapsed = current_time - operation.start_time + # Assume average operation takes 1 second + estimated_duration = 1.0 + progress = min(90.0, (elapsed / estimated_duration) * 100) + return max(10.0, progress) + else: + return 10.0 + else: + return 0.0 + + def _calculate_overall_progress(self, snapshot: JobSnapshot) -> Dict[str, float]: + """Calculate overall progress percentages.""" + total_partitions = snapshot.total_partitions or 1 + total_operations = snapshot.total_operations or 1 + + partition_progress = (snapshot.completed_partitions / total_partitions) * 100 + operation_progress = (snapshot.completed_operations / total_operations) * 100 + + # Weighted overall progress (partitions and operations equally weighted) + overall_progress = (partition_progress + operation_progress) / 2 + + return { + "overall_percentage": overall_progress, + "partition_percentage": partition_progress, + "operation_percentage": operation_progress, + } + + def 
_calculate_partition_progress_percentage(self, snapshot: JobSnapshot) -> float: + """Calculate partition progress percentage.""" + if snapshot.total_partitions == 0: + return 100.0 + return (snapshot.completed_partitions / snapshot.total_partitions) * 100 + + def _calculate_operation_progress_percentage(self, snapshot: JobSnapshot) -> float: + """Calculate operation progress percentage.""" + if snapshot.total_operations == 0: + return 100.0 + return (snapshot.completed_operations / snapshot.total_operations) * 100 + + def _calculate_checkpoint_progress(self, snapshot: JobSnapshot) -> Dict[str, any]: + """Calculate checkpoint progress information.""" + if snapshot.total_operations == 0: + return {"percentage": 0.0, "checkpointed_operations": [], "checkpoint_coverage": 0.0} + + checkpoint_percentage = (snapshot.checkpointed_operations / snapshot.total_operations) * 100 + + # Get list of checkpointed operations + checkpointed_ops = [] + for op_key, operation in snapshot.operation_statuses.items(): + if operation.status == ProcessingStatus.CHECKPOINTED: + checkpointed_ops.append( + { + "operation_key": op_key, + "operation_name": operation.operation_name, + "checkpoint_time": operation.checkpoint_time, + } + ) + + return { + "percentage": checkpoint_percentage, + "checkpointed_operations": checkpointed_ops, + "checkpoint_coverage": checkpoint_percentage / 100.0, + } + + def _format_duration(self, duration_seconds: float) -> str: + """Format duration in human-readable format.""" + if duration_seconds is None: + return None + + hours = int(duration_seconds // 3600) + minutes = int((duration_seconds % 3600) // 60) + seconds = int(duration_seconds % 60) + + if hours > 0: + return f"{hours}h {minutes}m {seconds}s" + elif minutes > 0: + return f"{minutes}m {seconds}s" + else: + return f"{seconds}s" + + +def create_snapshot(work_dir: str, detailed: bool = False) -> JobSnapshot: + """Create and display a processing snapshot for a work directory.""" + analyzer = ProcessingSnapshotAnalyzer(work_dir) + snapshot = analyzer.generate_snapshot() + return snapshot + + +def main(): + """Main function for command-line usage.""" + import argparse + + parser = argparse.ArgumentParser( + description="Generate DataJuicer processing snapshot", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m data_juicer.utils.job.snapshot outputs/partition-checkpoint-eventlog/20250808_230030_501c9d + python -m data_juicer.utils.job.snapshot /path/to/job/directory --human-readable + """, + ) + parser.add_argument("work_dir", help="Path to the DataJuicer work directory") + parser.add_argument("--human-readable", action="store_true", help="Output in human-readable format instead of JSON") + + args = parser.parse_args() + + if not os.path.exists(args.work_dir): + print(f"Error: Work directory '{args.work_dir}' does not exist") + return 1 + + try: + snapshot = create_snapshot(args.work_dir) + analyzer = ProcessingSnapshotAnalyzer(args.work_dir) + + if args.human_readable: + # Human-readable output + print("\n" + "=" * 80) + print(f"DataJuicer Processing Snapshot - Job: {snapshot.job_id}") + print("=" * 80) + + # Overall status + status_emoji = { + ProcessingStatus.NOT_STARTED: "⏳", + ProcessingStatus.IN_PROGRESS: "🔄", + ProcessingStatus.COMPLETED: "✅", + ProcessingStatus.FAILED: "❌", + ProcessingStatus.CHECKPOINTED: "💾", + } + + print( + f"\n📊 Overall Status: {status_emoji[snapshot.overall_status]} {snapshot.overall_status.value.upper()}" + ) + + # Timing information + if 
snapshot.job_start_time: + start_time = datetime.fromtimestamp(snapshot.job_start_time) + print(f"🕐 Started: {start_time.strftime('%Y-%m-%d %H:%M:%S')}") + + if snapshot.total_duration: + print(f"⏱️ Duration: {snapshot.total_duration:.2f} seconds") + + # Progress summary + print(f"\n📈 Progress Summary:") + print(f" Partitions: {snapshot.completed_partitions}/{snapshot.total_partitions} completed") + print(f" Operations: {snapshot.completed_operations}/{snapshot.total_operations} completed") + + if snapshot.failed_partitions > 0: + print(f" ❌ Failed partitions: {snapshot.failed_partitions}") + if snapshot.failed_operations > 0: + print(f" ❌ Failed operations: {snapshot.failed_operations}") + if snapshot.checkpointed_operations > 0: + print(f" 💾 Checkpointed operations: {snapshot.checkpointed_operations}") + + # Checkpointing information + if snapshot.checkpoint_strategy: + print(f"\n💾 Checkpointing:") + print(f" Strategy: {snapshot.checkpoint_strategy}") + if snapshot.last_checkpoint_time: + checkpoint_time = datetime.fromtimestamp(snapshot.last_checkpoint_time) + print(f" Last checkpoint: {checkpoint_time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f" Resumable: {'Yes' if snapshot.resumable else 'No'}") + + print("\n" + "=" * 80) + else: + # JSON output (default) + json_dict = analyzer.to_json_dict(snapshot) + print(json.dumps(json_dict, indent=2)) + + return 0 + + except Exception as e: + print(f"Error generating snapshot: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/data_juicer/utils/job/stopper.py b/data_juicer/utils/job/stopper.py new file mode 100644 index 0000000000..685cf77c8e --- /dev/null +++ b/data_juicer/utils/job/stopper.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +DataJuicer Job Stopper + +A utility to stop DataJuicer jobs by reading event logs to find process and thread IDs, +then terminating those specific processes and threads. 
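+
+Typical usage: call stop_job(job_id) from Python, or run this module as a script
+with a job ID (use --list to show known jobs, and --timeout to control how long to
+wait for graceful termination before force-killing).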
+""" + +import json +import sys +import time +from typing import Any, Dict + +import psutil +from loguru import logger + +from data_juicer.utils.job.common import JobUtils, list_running_jobs + + +class JobStopper: + """Stop DataJuicer jobs using event log-based process discovery.""" + + def __init__(self, job_id: str, base_dir: str = "outputs/partition-checkpoint-eventlog"): + self.job_utils = JobUtils(job_id, base_dir=base_dir) + self.job_id = job_id + self.work_dir = self.job_utils.work_dir + + def terminate_process_gracefully(self, proc, timeout: int = 10) -> bool: + """Terminate a process gracefully with timeout.""" + try: + logger.info(f"Terminating process {proc.pid} gracefully...") + proc.terminate() + + # Wait for the process to terminate + try: + proc.wait(timeout=timeout) + logger.info(f"Process {proc.pid} terminated gracefully") + return True + except psutil.TimeoutExpired: + logger.warning(f"Process {proc.pid} did not terminate within {timeout}s, force killing...") + proc.kill() + proc.wait() + logger.info(f"Process {proc.pid} force killed") + return True + + except psutil.NoSuchProcess: + logger.info(f"Process {proc.pid} already terminated") + return True + except psutil.AccessDenied: + logger.error(f"Access denied when terminating process {proc.pid}") + return False + except Exception as e: + logger.error(f"Error terminating process {proc.pid}: {e}") + return False + + def cleanup_job_resources(self) -> None: + """Clean up job resources and update job summary.""" + job_summary = self.job_utils.load_job_summary() + if job_summary: + job_summary["status"] = "stopped" + job_summary["stop_time"] = time.time() + job_summary["stop_reason"] = "manual_stop" + + try: + with open(self.work_dir / "job_summary.json", "w") as f: + json.dump(job_summary, f, indent=2, default=str) + logger.info(f"Updated job summary: {self.work_dir / 'job_summary.json'}") + except Exception as e: + logger.error(f"Failed to update job summary: {e}") + + def stop_job(self, force: bool = False, timeout: int = 30) -> Dict[str, Any]: + """Stop the DataJuicer job using event log-based process discovery.""" + results = { + "job_id": self.job_id, + "success": False, + "processes_found": 0, + "processes_terminated": 0, + "threads_found": 0, + "threads_terminated": 0, + "errors": [], + } + + logger.info(f"🛑 Stopping DataJuicer job: {self.job_id}") + logger.info(f"Work directory: {self.work_dir}") + + # Load job summary + job_summary = self.job_utils.load_job_summary() + if job_summary: + logger.info(f"Job status: {job_summary.get('status', 'unknown')}") + logger.info(f"Job started: {job_summary.get('start_time', 'unknown')}") + + # Extract process and thread IDs from event logs + logger.info("🔍 Extracting process and thread IDs from event logs...") + ids = self.job_utils.extract_process_thread_ids() + + results["processes_found"] = len(ids["process_ids"]) + results["threads_found"] = len(ids["thread_ids"]) + + if not ids["process_ids"] and not ids["thread_ids"]: + logger.warning("No process or thread IDs found in event logs") + results["errors"].append("No process or thread IDs found in event logs") + self.cleanup_job_resources() + return results + + # Find and terminate processes + logger.info(f"🔍 Finding {len(ids['process_ids'])} processes...") + processes = self.job_utils.find_processes_by_ids(ids["process_ids"]) + + if processes: + logger.info(f"Found {len(processes)} running processes to terminate") + for proc in processes: + if self.terminate_process_gracefully(proc, timeout): + results["processes_terminated"] 
+= 1 + else: + results["errors"].append(f"Failed to terminate process {proc.pid}") + else: + logger.info("No running processes found") + + # Find and terminate threads (placeholder for future implementation) + logger.info(f"🔍 Finding {len(ids['thread_ids'])} threads...") + threads = self.job_utils.find_threads_by_ids(ids["thread_ids"]) + results["threads_terminated"] = len(threads) + + # Clean up job resources + logger.info("🧹 Cleaning up job resources...") + self.cleanup_job_resources() + + # Determine success + results["success"] = results["processes_terminated"] > 0 or results["threads_terminated"] > 0 + + if results["success"]: + logger.info(f"✅ Job {self.job_id} stopped successfully") + logger.info(f" Terminated {results['processes_terminated']} processes") + logger.info(f" Terminated {results['threads_terminated']} threads") + else: + logger.warning(f"⚠️ Job {self.job_id} may not have been fully stopped") + if results["errors"]: + logger.error(f" Errors: {results['errors']}") + + return results + + +def stop_job( + job_id: str, base_dir: str = "outputs/partition-checkpoint-eventlog", force: bool = False, timeout: int = 30 +) -> Dict[str, Any]: + """Stop a DataJuicer job using event log-based process discovery.""" + stopper = JobStopper(job_id, base_dir) + return stopper.stop_job(force=force, timeout=timeout) + + +def main(): + """Main function for command-line usage.""" + import argparse + + parser = argparse.ArgumentParser(description="Stop DataJuicer jobs using event log-based process discovery") + parser.add_argument("job_id", nargs="?", help="Job ID to stop") + parser.add_argument( + "--base-dir", default="outputs/partition-checkpoint-eventlog", help="Base directory for job outputs" + ) + parser.add_argument("--force", action="store_true", help="Force termination") + parser.add_argument("--timeout", type=int, default=30, help="Termination timeout in seconds") + parser.add_argument("--list", action="store_true", help="List all jobs") + parser.add_argument("--verbose", action="store_true", help="Verbose output") + + args = parser.parse_args() + + if args.verbose: + logger.remove() + logger.add(sys.stderr, level="DEBUG") + + if args.list: + jobs = list_running_jobs(args.base_dir) + if jobs: + print("📋 DataJuicer Jobs:") + print("=" * 80) + for job in jobs: + status_icon = "🟢" if job["status"] == "completed" else "🟡" if job["status"] == "running" else "🔴" + print(f"{status_icon} {job['job_id']} | Status: {job['status']} | Processes: {job['processes']}") + else: + print("No DataJuicer jobs found") + return + + if not args.job_id: + parser.error("Job ID is required unless using --list") + + result = stop_job(args.job_id, args.base_dir, force=args.force, timeout=args.timeout) + + if result["success"]: + print(f"✅ Job {args.job_id} stopped successfully") + print(f" Terminated {result['processes_terminated']} processes") + print(f" Terminated {result['threads_terminated']} threads") + else: + print(f"⚠️ Job {args.job_id} may not have been fully stopped") + if result["errors"]: + print(f" Errors: {result['errors']}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/data_juicer/utils/logger_utils.py b/data_juicer/utils/logger_utils.py index 1f33785210..d89c6204ef 100644 --- a/data_juicer/utils/logger_utils.py +++ b/data_juicer/utils/logger_utils.py @@ -167,7 +167,13 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="o", lev level=level, enqueue=not is_notebook(), ) - logger.add(save_file) + logger.add( + save_file, + format=loguru_format, + 
level=level, + compression="gz", + enqueue=True, + ) # for interest of levels: debug, error, warning logger.add( @@ -175,6 +181,7 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="o", lev level="DEBUG", filter=lambda x: "DEBUG" == x["level"].name, format=loguru_format, + compression="gz", enqueue=True, serialize=True, ) @@ -183,6 +190,7 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="o", lev level="ERROR", filter=lambda x: "ERROR" == x["level"].name, format=loguru_format, + compression="gz", enqueue=True, serialize=True, ) @@ -191,6 +199,7 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="o", lev level="WARNING", filter=lambda x: "WARNING" == x["level"].name, format=loguru_format, + compression="gz", enqueue=True, serialize=True, ) diff --git a/demos/README.md b/demos/README.md index 000f782469..eaac2c9fa4 100644 --- a/demos/README.md +++ b/demos/README.md @@ -48,3 +48,6 @@ streamlit run app.py - Data mixture (`data_mixture`) - This demo selects and mixes samples from multiple datasets and exports them into a new dataset. + +- Partition and checkpoint (`partition_and_checkpoint`) + - This demo showcases distributed processing with partitioning, checkpointing, and event logging. It demonstrates the new job management features including resource-aware partitioning, comprehensive event logging, and the processing snapshot utility for monitoring job progress. diff --git a/demos/README_ZH.md b/demos/README_ZH.md index 218fe1e649..e783cbadfe 100644 --- a/demos/README_ZH.md +++ b/demos/README_ZH.md @@ -48,3 +48,6 @@ streamlit run app.py - 数据混合 (`data_mixture`) - 该示例从多份数据集中进行采样并混合为一个新的数据集。 + +- 分区和检查点 (`partition_and_checkpoint`) + - 该演示展示了带分区、检查点和事件日志的分布式处理。它演示了新的作业管理功能,包括资源感知分区、全面的事件日志记录和处理快照工具,用于监控作业进度。 diff --git a/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog-control.yaml b/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog-control.yaml new file mode 100644 index 0000000000..809a983f71 --- /dev/null +++ b/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog-control.yaml @@ -0,0 +1,89 @@ +# ============================================================================= +# CONTROL CONFIG FOR partition-checkpoint-eventlog.yaml +# ============================================================================= +# This is a control configuration file for partition-checkpoint-eventlog.yaml +# that uses the non-partitioned Ray executor (executor_type: "ray") instead of +# the partitioned executor (executor_type: "ray_partitioned"). +# +# This config is useful for: +# 1. Comparing performance between partitioned and non-partitioned executors +# 2. Testing DAG execution without partitioning +# 3. 
Simpler execution flow without partition management +# +# Key differences from partition-checkpoint-eventlog.yaml: +# - executor_type: "ray" (instead of "ray_partitioned") +# - No partition configuration needed +# - Simpler execution model (no partition splitting/merging) +# ============================================================================= + +dataset_path: './demos/data/demo-dataset.jsonl' + +work_dir: "./outputs/partition-checkpoint-eventlog/{job_id}" +export_path: '{work_dir}/processed.jsonl' +np: 8 + +executor_type: "ray" # Non-partitioned Ray executor (control config) +ray_address: "auto" + +# Process pipeline with real DataJuicer operations +process: + # Text cleaning operations + - clean_links_mapper: + text_key: "text" + min_links: 0 + max_links: 10 + + - clean_email_mapper: + text_key: "text" + min_emails: 0 + max_emails: 5 + + - whitespace_normalization_mapper: + text_key: "text" + + - fix_unicode_mapper: + text_key: "text" + + # Text filtering operations + - text_length_filter: + text_key: "text" + min_len: 5 + max_len: 10000 + + - alphanumeric_filter: + text_key: "text" + min_ratio: 0.1 + + # Quality filtering + - character_repetition_filter: + text_key: "text" + min_ratio: 0.0 + max_ratio: 0.5 + + - word_repetition_filter: + text_key: "text" + min_ratio: 0.0 + max_ratio: 0.5 + + - ray_bts_minhash_deduplicator: + tokenization: 'character' + lowercase: true + union_find_parallel_num: 2 + +# Export configuration +export_in_parallel: true +keep_stats_in_res_ds: true +keep_hashes_in_res_ds: true + +# ============================================================================= +# USAGE: +# ============================================================================= +# This control config uses the non-partitioned Ray executor for comparison. +# To use this config: +# +# dj-process --config configs/demo/partition-checkpoint-eventlog-control.yaml +# +# For the partitioned executor version, use: +# dj-process --config configs/demo/partition-checkpoint-eventlog.yaml +# +# ============================================================================= diff --git a/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog.yaml b/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog.yaml new file mode 100644 index 0000000000..9158f2f1b3 --- /dev/null +++ b/demos/partition_and_checkpoint/configs/partition-checkpoint-eventlog.yaml @@ -0,0 +1,153 @@ +# ============================================================================= +# COMPREHENSIVE DATAJUICER DEMO: Checkpointing, Event Logging & Job Management +# ============================================================================= +# This demo showcases: +# 1. Configurable checkpointing strategies +# 2. Event logging with job-specific directories +# 3. Flexible storage architecture +# 4. Job resumption capabilities +# 5. 
Real DataJuicer operations +# ============================================================================= + +# Data location configuration (Mandatory) +dataset_path: './demos/data/demo-dataset.jsonl' + +# Work directory configuration +# IMPORTANT: If using {job_id} placeholder, it MUST be the last part of the path +# Examples: +# ✅ work_dir: "./outputs/my_project/{job_id}" # Valid +# ✅ work_dir: "/data/experiments/{job_id}" # Valid +# ❌ work_dir: "./outputs/{job_id}/results" # Invalid - {job_id} not at end +# ❌ work_dir: "./{job_id}/outputs/data" # Invalid - {job_id} not at end +# +# If no {job_id} is specified, job_id will be automatically appended: +# work_dir: "./outputs/my_project" → job_dir: "./outputs/my_project/20250804_143022_abc123" +work_dir: "./outputs/partition-checkpoint-eventlog/{job_id}" +export_path: '{work_dir}/processed.jsonl' + +# Executor configuration +executor_type: "ray_partitioned" # Use our enhanced partitioned executor +ray_address: "auto" +# np will be auto-configured based on available cluster resources when partition.auto_configure: true +# np: 2 # Number of Ray workers (auto-configured when partition.auto_configure: true) + +# Separate storage configuration +# Partition directory (Optional) is used to store the partitions of the dataset if using ray_partitioned executor +partition_dir: "{work_dir}/partitions" + +# Event logs: Fast storage (SSD, local disk) - small files, frequent writes (Optional) +event_log_dir: "{work_dir}/event_logs" # Optional: separate fast storage for event logs + +# Checkpoints: Large storage (HDD, network storage) - large files, infrequent writes (Optional) +checkpoint_dir: "{work_dir}/checkpoints" # Optional: separate large storage for checkpoints + + +# Partition configuration +partition: + mode: "manual" # Partition mode: "auto" (use optimizer) or "manual" (specify count) + num_of_partitions: 4 # Number of partitions (for manual mode) + target_size_mb: 256 # Target partition size in MB (for auto mode) + # Options: 128 (memory-constrained), 256 (default, balanced), + # 512 (high-memory systems), 1024 (very large files) + # Smaller = more checkpoints & better recovery, larger = less overhead + +# Checkpoint configuration +checkpoint: + enabled: false + strategy: "every_n_ops" + n_ops: 3 + # strategy: "every_op" # every_op, every_partition, every_n_ops, manual, disabled + # n_ops: 1 # Number of operations between checkpoints (for every_n_ops strategy) + # op_names: [] # Specific operation names to checkpoint after (for manual strategy) + +# Intermediate storage configuration (includes file lifecycle management) +intermediate_storage: + format: "parquet" # parquet, arrow, jsonl; defaults to parquet + write_partitions: false + +# Event logging configuration +event_logging: + enabled: true + +# Process pipeline with real DataJuicer operations +process: + # Text cleaning operations + - clean_links_mapper: + text_key: "text" + min_links: 0 + max_links: 10 + + - clean_email_mapper: + text_key: "text" + min_emails: 0 + max_emails: 5 + + - whitespace_normalization_mapper: + text_key: "text" + + - fix_unicode_mapper: + text_key: "text" + + # Text filtering operations + - text_length_filter: + text_key: "text" + min_len: 5 + max_len: 10000 + + - alphanumeric_filter: + text_key: "text" + min_ratio: 0.1 + + # Quality filtering + - character_repetition_filter: + text_key: "text" + min_ratio: 0.0 + max_ratio: 0.5 + + - word_repetition_filter: + text_key: "text" + min_ratio: 0.0 + max_ratio: 0.5 + +# Export configuration +export_in_parallel: true 
+keep_stats_in_res_ds: true +keep_hashes_in_res_ds: true + + +# ============================================================================= +# COMPLETE USER EXPERIENCE: +# ============================================================================= +# 1. Start job: +# dj-process --config configs/demo/partition-checkpoint-eventlog.yaml +# # Output shows: Job ID (timestamp_configname_suffix), job directory, resumption command +# # Example: 20241201_143022_partition-checkpoint-eventlog_abc123 +# +# 2. If job fails, resume with: +# dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --job_id +# # System validates job_id and shows previous status +# +# 3. Directory structure (flexible storage): +# outputs/partition-checkpoint-eventlog/{job_id}/ +# ├── partitions/ # Dataset partitions (large files) +# ├── checkpoints/ # Operation checkpoints (large files) +# ├── event_logs/ # Event logs (small files, frequent writes) +# ├── metadata/ # Job metadata and mapping +# ├── results/ # Final processed dataset +# └── processed.jsonl # Final output file +# +# 4. Resource Optimization: +# - partition.mode: "auto" automatically optimizes: +# * Partition size based on data characteristics and available memory +# * Number of partitions based on dataset size and optimal partition size +# * Worker count (np) based on available CPU cores +# * Processing efficiency based on data modality (text, image, audio, video) +# - No manual tuning required - system adapts to your hardware and data +# +# 5. Monitoring and Debugging: +# - Real-time event logs in event_logs/ directory +# - Processing summary with statistics and timing +# - Checkpoint recovery for fault tolerance +# - Detailed resource utilization analysis +# +# ============================================================================= diff --git a/demos/partition_and_checkpoint/example_event_log.jsonl b/demos/partition_and_checkpoint/example_event_log.jsonl new file mode 100644 index 0000000000..7652fde892 --- /dev/null +++ b/demos/partition_and_checkpoint/example_event_log.jsonl @@ -0,0 +1,26 @@ +{"event_id": "evt_001", "event_type": "processing_start", "timestamp": 1703123456.789, "partition_id": null, "operation_name": null, "operation_idx": null, "message": "Starting partitioned processing pipeline", "metadata": {"executor_type": "ray_partitioned", "dataset_path": "data/large-dataset.jsonl", "total_samples": 50000, "partition_size": 10000}, "error_details": null} +{"event_id": "evt_002", "event_type": "partition_start", "timestamp": 1703123457.123, "partition_id": 0, "operation_name": null, "operation_idx": null, "message": "Starting processing of partition 0", "metadata": {"partition_path": "work_dir/partitions/partition_000000.parquet", "sample_count": 10000, "file_size_bytes": 2048576}, "error_details": null} +{"event_id": "evt_003", "event_type": "operation_start", "timestamp": 1703123457.456, "partition_id": 0, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Starting whitespace normalization on partition 0", "metadata": {"operation_config": {"text_key": "text"}}, "error_details": null} +{"event_id": "evt_004", "event_type": "operation_complete", "timestamp": 1703123458.789, "partition_id": 0, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Completed whitespace normalization on partition 0", "metadata": {"duration_seconds": 1.333, "samples_processed": 10000, "samples_filtered": 0}, "error_details": null} +{"event_id": "evt_005", "event_type": 
"operation_checkpoint", "timestamp": 1703123458.890, "partition_id": 0, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Saved checkpoint after whitespace normalization", "metadata": {"checkpoint_path": "work_dir/checkpoints/partition_000000/op_000_whitespace_normalization_mapper.parquet", "checkpoint_size_bytes": 1536000}, "error_details": null} +{"event_id": "evt_006", "event_type": "operation_start", "timestamp": 1703123459.123, "partition_id": 0, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Starting text length filtering on partition 0", "metadata": {"operation_config": {"min_len": 50, "max_len": 2000, "text_key": "text"}}, "error_details": null} +{"event_id": "evt_007", "event_type": "operation_complete", "timestamp": 1703123460.456, "partition_id": 0, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Completed text length filtering on partition 0", "metadata": {"duration_seconds": 1.333, "samples_processed": 10000, "samples_filtered": 1250}, "error_details": null} +{"event_id": "evt_008", "event_type": "operation_checkpoint", "timestamp": 1703123460.567, "partition_id": 0, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Saved checkpoint after text length filtering", "metadata": {"checkpoint_path": "work_dir/checkpoints/partition_000000/op_001_text_length_filter.parquet", "checkpoint_size_bytes": 1280000}, "error_details": null} +{"event_id": "evt_009", "event_type": "operation_start", "timestamp": 1703123461.123, "partition_id": 0, "operation_name": "language_id_score_filter", "operation_idx": 2, "message": "Starting language filtering on partition 0", "metadata": {"operation_config": {"lang": "en", "min_score": 0.8, "text_key": "text"}}, "error_details": null} +{"event_id": "evt_010", "event_type": "operation_complete", "timestamp": 1703123462.789, "partition_id": 0, "operation_name": "language_id_score_filter", "operation_idx": 2, "message": "Completed language filtering on partition 0", "metadata": {"duration_seconds": 1.666, "samples_processed": 8750, "samples_filtered": 875}, "error_details": null} +{"event_id": "evt_011", "event_type": "partition_complete", "timestamp": 1703123462.890, "partition_id": 0, "operation_name": null, "operation_idx": null, "message": "Completed processing of partition 0", "metadata": {"total_duration_seconds": 5.767, "final_sample_count": 7875, "operations_completed": 3, "checkpoints_created": 3}, "error_details": null} +{"event_id": "evt_012", "event_type": "partition_start", "timestamp": 1703123463.123, "partition_id": 1, "operation_name": null, "operation_idx": null, "message": "Starting processing of partition 1", "metadata": {"partition_path": "work_dir/partitions/partition_000001.parquet", "sample_count": 10000, "file_size_bytes": 2150400}, "error_details": null} +{"event_id": "evt_013", "event_type": "operation_start", "timestamp": 1703123463.456, "partition_id": 1, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Starting whitespace normalization on partition 1", "metadata": {"operation_config": {"text_key": "text"}}, "error_details": null} +{"event_id": "evt_014", "event_type": "operation_error", "timestamp": 1703123464.123, "partition_id": 1, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Error during whitespace normalization on partition 1", "metadata": {"duration_seconds": 0.667, "samples_processed": 2500}, "error_details": "ValueError: Invalid text format in 
sample 2501: expected string, got None"} +{"event_id": "evt_015", "event_type": "partition_failed", "timestamp": 1703123464.234, "partition_id": 1, "operation_name": null, "operation_idx": null, "message": "Failed processing of partition 1 due to operation error", "metadata": {"total_duration_seconds": 1.111, "operations_completed": 0, "retry_count": 0}, "error_details": "ValueError: Invalid text format in sample 2501: expected string, got None"} +{"event_id": "evt_016", "event_type": "partition_start", "timestamp": 1703123465.123, "partition_id": 2, "operation_name": null, "operation_idx": null, "message": "Starting processing of partition 2", "metadata": {"partition_path": "work_dir/partitions/partition_000002.parquet", "sample_count": 10000, "file_size_bytes": 1984512}, "error_details": null} +{"event_id": "evt_017", "event_type": "operation_start", "timestamp": 1703123465.456, "partition_id": 2, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Starting whitespace normalization on partition 2", "metadata": {"operation_config": {"text_key": "text"}}, "error_details": null} +{"event_id": "evt_018", "event_type": "operation_complete", "timestamp": 1703123466.789, "partition_id": 2, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Completed whitespace normalization on partition 2", "metadata": {"duration_seconds": 1.333, "samples_processed": 10000, "samples_filtered": 0}, "error_details": null} +{"event_id": "evt_019", "event_type": "operation_checkpoint", "timestamp": 1703123466.890, "partition_id": 2, "operation_name": "whitespace_normalization_mapper", "operation_idx": 0, "message": "Saved checkpoint after whitespace normalization", "metadata": {"checkpoint_path": "work_dir/checkpoints/partition_000002/op_000_whitespace_normalization_mapper.parquet", "checkpoint_size_bytes": 1472000}, "error_details": null} +{"event_id": "evt_020", "event_type": "operation_start", "timestamp": 1703123467.123, "partition_id": 2, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Starting text length filtering on partition 2", "metadata": {"operation_config": {"min_len": 50, "max_len": 2000, "text_key": "text"}}, "error_details": null} +{"event_id": "evt_021", "event_type": "operation_complete", "timestamp": 1703123468.456, "partition_id": 2, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Completed text length filtering on partition 2", "metadata": {"duration_seconds": 1.333, "samples_processed": 10000, "samples_filtered": 1100}, "error_details": null} +{"event_id": "evt_022", "event_type": "operation_checkpoint", "timestamp": 1703123468.567, "partition_id": 2, "operation_name": "text_length_filter", "operation_idx": 1, "message": "Saved checkpoint after text length filtering", "metadata": {"checkpoint_path": "work_dir/checkpoints/partition_000002/op_001_text_length_filter.parquet", "checkpoint_size_bytes": 1216000}, "error_details": null} +{"event_id": "evt_023", "event_type": "operation_start", "timestamp": 1703123469.123, "partition_id": 2, "operation_name": "language_id_score_filter", "operation_idx": 2, "message": "Starting language filtering on partition 2", "metadata": {"operation_config": {"lang": "en", "min_score": 0.8, "text_key": "text"}}, "error_details": null} +{"event_id": "evt_024", "event_type": "operation_complete", "timestamp": 1703123470.789, "partition_id": 2, "operation_name": "language_id_score_filter", "operation_idx": 2, "message": "Completed language filtering on 
partition 2", "metadata": {"duration_seconds": 1.666, "samples_processed": 8900, "samples_filtered": 890}, "error_details": null} +{"event_id": "evt_025", "event_type": "partition_complete", "timestamp": 1703123470.890, "partition_id": 2, "operation_name": null, "operation_idx": null, "message": "Completed processing of partition 2", "metadata": {"total_duration_seconds": 5.767, "final_sample_count": 8010, "operations_completed": 3, "checkpoints_created": 3}, "error_details": null} +{"event_id": "evt_026", "event_type": "processing_complete", "timestamp": 1703123471.123, "partition_id": null, "operation_name": null, "operation_idx": null, "message": "Completed partitioned processing pipeline", "metadata": {"total_duration_seconds": 14.334, "total_partitions": 3, "completed_partitions": 2, "failed_partitions": 1, "total_samples_processed": 30000, "total_samples_output": 15885, "success_rate": 0.667, "checkpoints_created": 6}, "error_details": null} \ No newline at end of file diff --git a/demos/partition_and_checkpoint/example_processing_summary.json b/demos/partition_and_checkpoint/example_processing_summary.json new file mode 100644 index 0000000000..3b511f1820 --- /dev/null +++ b/demos/partition_and_checkpoint/example_processing_summary.json @@ -0,0 +1,102 @@ +{ + "start_time": 1703123456.789, + "end_time": 1703123471.123, + "total_processing_time": 14.334, + "total_partitions": 3, + "completed_partitions": 2, + "failed_partitions": 1, + "total_operations": 9, + "completed_operations": 8, + "failed_operations": 1, + "checkpoints_created": 6, + "total_samples_processed": 30000, + "total_samples_output": 15885, + "success_rate": 0.667, + "errors": [ + { + "timestamp": 1703123464.123, + "message": "Error during whitespace normalization on partition 1", + "partition_id": 1, + "operation_name": "whitespace_normalization_mapper", + "error_details": "ValueError: Invalid text format in sample 2501: expected string, got None" + } + ], + "partition_details": [ + { + "partition_id": 0, + "status": "completed", + "start_time": 1703123457.123, + "end_time": 1703123462.890, + "processing_time": 5.767, + "operations_completed": 3, + "checkpoints_created": 3, + "initial_sample_count": 10000, + "final_sample_count": 7875, + "samples_filtered": 2125 + }, + { + "partition_id": 1, + "status": "failed", + "start_time": 1703123463.123, + "end_time": 1703123464.234, + "processing_time": 1.111, + "operations_completed": 0, + "checkpoints_created": 0, + "initial_sample_count": 10000, + "final_sample_count": 0, + "samples_filtered": 0, + "error_message": "ValueError: Invalid text format in sample 2501: expected string, got None" + }, + { + "partition_id": 2, + "status": "completed", + "start_time": 1703123465.123, + "end_time": 1703123470.890, + "processing_time": 5.767, + "operations_completed": 3, + "checkpoints_created": 3, + "initial_sample_count": 10000, + "final_sample_count": 8010, + "samples_filtered": 1990 + } + ], + "operation_performance": { + "whitespace_normalization_mapper": { + "total_executions": 3, + "successful_executions": 2, + "failed_executions": 1, + "average_duration": 1.333, + "total_samples_processed": 22500, + "total_samples_filtered": 0 + }, + "text_length_filter": { + "total_executions": 2, + "successful_executions": 2, + "failed_executions": 0, + "average_duration": 1.333, + "total_samples_processed": 18900, + "total_samples_filtered": 2350 + }, + "language_id_score_filter": { + "total_executions": 2, + "successful_executions": 2, + "failed_executions": 0, + "average_duration": 1.666, 
+ "total_samples_processed": 17650, + "total_samples_filtered": 1765 + } + }, + "resource_usage": { + "peak_memory_mb": 2048, + "average_cpu_percent": 75.5, + "total_disk_io_mb": 15.2, + "checkpoint_storage_mb": 8.5 + }, + "configuration": { + "executor_type": "ray_partitioned", + "partition_size": 10000, + "max_partition_size_mb": 128, + "storage_format": "parquet", + "preserve_intermediate_data": true + } +} \ No newline at end of file diff --git a/demos/partition_and_checkpoint/run_demo.py b/demos/partition_and_checkpoint/run_demo.py new file mode 100755 index 0000000000..c23edff409 --- /dev/null +++ b/demos/partition_and_checkpoint/run_demo.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +""" +Comprehensive Demo for DataJuicer Job Management & Monitoring + +This script demonstrates all the implemented job management features: +1. Processing Snapshot Utility - Comprehensive job status analysis with JSON output +2. Job Management Tools - Monitor and manage DataJuicer processing jobs +3. Resource-Aware Partitioning - Automatic resource optimization for distributed processing +4. Job-specific directory isolation +5. Flexible storage paths for event logs and checkpoints +6. Configurable checkpointing strategies +7. Event logging with JSONL format (events_{timestamp}.jsonl) +8. Job resumption capabilities +9. Comprehensive job management + +Important Notes: +- Event logs (events_{timestamp}.jsonl) are created immediately when a job starts +- Job summary (job_summary.json) is only created when a job completes successfully +- For running/incomplete jobs, use event logs and the monitor tool to track progress + +Usage: + # IMPORTANT: This script must be run from the Data-Juicer root directory + cd /path/to/data-juicer + python demos/partition_and_checkpoint/run_demo.py +""" + +import os +import subprocess +import time +import json +from pathlib import Path +import re + + +def run_data_juicer_command(config_file, job_id=None, extra_args=None): + """Run a DataJuicer command and return the result.""" + cmd = ["dj-process", "--config", config_file] + if job_id: + cmd.extend(["--job_id", job_id]) + if extra_args: + cmd.extend(extra_args) + + print(f"Running: {' '.join(cmd)}") + print("-" * 80) + + start_time = time.time() + result = subprocess.run(cmd, capture_output=True, text=True) + end_time = time.time() + + print(f"Exit code: {result.returncode}") + print(f"Duration: {end_time - start_time:.2f} seconds") + print("-" * 80) + + if result.stdout: + print("STDOUT:") + print(result.stdout) + + if result.stderr: + print("STDERR:") + print(result.stderr) + + return result + + +def run_snapshot_analysis(job_id, work_dir="./outputs/partition-checkpoint-eventlog"): + """Run the processing snapshot utility to analyze job status.""" + print(f"\n📊 Processing Snapshot Analysis for {job_id}:") + print("=" * 60) + + # Check if job directory exists and has events + job_dir = os.path.join(work_dir, job_id) + from pathlib import Path + job_path = Path(job_dir) + + if not job_path.exists(): + print(f"❌ Job directory not found: {job_dir}") + print("=" * 60) + return + + event_files = list(job_path.glob("events_*.jsonl")) + if not event_files and not (job_path / "events.jsonl").exists(): + print(f"ℹ️ No event logs found for this job yet.") + print(f" The job may still be initializing.") + print("=" * 60) + return + + # Run the snapshot utility + cmd = ["python", "-m", "data_juicer.utils.job.snapshot", job_dir] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode == 
0: + snapshot_data = json.loads(result.stdout) + print("✅ Snapshot Analysis Results:") + print(f" Job Status: {snapshot_data.get('overall_status', 'unknown')}") + print(f" Progress: {snapshot_data.get('overall_progress', {}).get('overall_percentage', 0):.1f}%") + print(f" Duration: {snapshot_data.get('timing', {}).get('duration_formatted', 'unknown')}") + print(f" Partitions: {snapshot_data.get('progress_summary', {}).get('completed_partitions', 0)}/{snapshot_data.get('progress_summary', {}).get('total_partitions', 0)}") + print(f" Operations: {snapshot_data.get('progress_summary', {}).get('completed_operations', 0)}/{snapshot_data.get('progress_summary', {}).get('total_operations', 0)}") + print(f" Resumable: {snapshot_data.get('checkpointing', {}).get('resumable', False)}") + else: + print(f"⚠️ Snapshot analysis completed with warnings:") + if result.stderr: + # Only show first few lines of error + error_lines = result.stderr.strip().split('\n')[:3] + for line in error_lines: + if line.strip(): + print(f" {line}") + print(f" Tip: This is normal for jobs that haven't completed yet.") + except subprocess.TimeoutExpired: + print(f"⚠️ Snapshot analysis timed out (job may be too large)") + except json.JSONDecodeError: + print(f"⚠️ Could not parse snapshot output (job may be incomplete)") + except Exception as e: + print(f"⚠️ Error running snapshot analysis: {e}") + + print("=" * 60) + + +def check_directory_structure(job_id, work_dir="./outputs/partition-checkpoint-eventlog"): + """Check and display the job-specific directory structure.""" + job_dir = os.path.join(work_dir, job_id) + + print(f"\n📁 Job Directory Structure for {job_id}:") + print("=" * 60) + + if os.path.exists(job_dir): + for root, dirs, files in os.walk(job_dir): + level = root.replace(job_dir, '').count(os.sep) + indent = ' ' * 2 * level + print(f"{indent}{os.path.basename(root)}/") + subindent = ' ' * 2 * (level + 1) + for file in files: + print(f"{subindent}{file}") + else: + print(f"Job directory {job_dir} does not exist") + + print("=" * 60) + + +def check_flexible_storage(job_id): + """Check job storage directories.""" + print(f"\n💾 Job Storage for {job_id}:") + print("=" * 60) + + # Check event logs in job directory (find latest events file with timestamp) + from pathlib import Path + job_dir = Path(f"./outputs/partition-checkpoint-eventlog/{job_id}") + event_files = list(job_dir.glob("events_*.jsonl")) + + if event_files: + # Find the latest events file + event_log_file = max(event_files, key=lambda f: f.stat().st_mtime) + size = os.path.getsize(event_log_file) + print(f"✅ Event Logs: {event_log_file} ({size} bytes)") + else: + # Try old naming convention for backward compatibility + event_log_file = job_dir / "events.jsonl" + if event_log_file.exists(): + size = os.path.getsize(event_log_file) + print(f"✅ Event Logs: {event_log_file} ({size} bytes)") + else: + print(f"❌ Event Logs: No events files found in {job_dir}") + + # Check logs directory + logs_dir = f"./outputs/partition-checkpoint-eventlog/{job_id}/logs" + if os.path.exists(logs_dir): + print(f"✅ Logs Directory: {logs_dir}") + for file in os.listdir(logs_dir): + file_path = os.path.join(logs_dir, file) + size = os.path.getsize(file_path) + print(f" 📄 {file} ({size} bytes)") + else: + print(f"❌ Logs Directory: {logs_dir} not found") + + # Check checkpoints in job directory + checkpoint_dir = f"./outputs/partition-checkpoint-eventlog/{job_id}/checkpoints" + if os.path.exists(checkpoint_dir): + print(f"✅ Checkpoints: {checkpoint_dir}") + for file in 
os.listdir(checkpoint_dir): + file_path = os.path.join(checkpoint_dir, file) + if os.path.isfile(file_path): + size = os.path.getsize(file_path) + print(f" 💾 {file} ({size} bytes)") + else: + print(f" 📁 {file}/") + else: + print(f"❌ Checkpoints: {checkpoint_dir} not found") + + print("=" * 60) + + +def check_job_summary(job_id, work_dir="./outputs/partition-checkpoint-eventlog"): + """Check and display job summary.""" + job_dir = os.path.join(work_dir, job_id) + summary_file = os.path.join(job_dir, "job_summary.json") + + print(f"\n📋 Job Summary for {job_id}:") + print("=" * 60) + + if os.path.exists(summary_file): + with open(summary_file, 'r') as f: + summary = json.load(f) + + print(f"✅ Job Summary Available (job completed)") + print(f" Job ID: {summary.get('job_id')}") + print(f" Status: {summary.get('status')}") + print(f" Start Time: {summary.get('start_time')}") + print(f" Job Directory: {summary.get('job_dir')}") + print(f" Event Log File: {summary.get('event_log_file')}") + print(f" Checkpoint Directory: {summary.get('checkpoint_dir')}") + print(f" Resumption Command: {summary.get('resumption_command')}") + else: + print(f"ℹ️ Job summary not yet available") + print(f" Note: job_summary.json is created when the job completes.") + print(f" For running jobs, use the snapshot analysis or monitor tools instead.") + + # Try to get basic info from event logs + from pathlib import Path + job_path = Path(job_dir) + event_files = list(job_path.glob("events_*.jsonl")) + if event_files: + latest_event_file = max(event_files, key=lambda f: f.stat().st_mtime) + print(f" Event logs available: {latest_event_file.name}") + print(f" Use: python -m data_juicer.utils.job.monitor {job_id}") + + print("=" * 60) + + +def check_resource_optimization(): + """Check resource-aware partitioning configuration.""" + print(f"\n⚙️ Resource-Aware Partitioning Check:") + print("=" * 60) + + # Check if resource optimization is enabled in config + config_file = "configs/demo/partition-checkpoint-eventlog.yaml" + if os.path.exists(config_file): + with open(config_file, 'r') as f: + config_content = f.read() + + if "resource_optimization:" in config_content and "auto_configure: true" in config_content: + print("✅ Resource optimization is enabled") + print(" - Automatic partition size optimization") + print(" - Worker count optimization") + print(" - 64MB partition targeting") + else: + print("ℹ️ Resource optimization not enabled (using manual configuration)") + else: + print(f"❌ Config file {config_file} not found") + + print("=" * 60) + + +def get_latest_job_id(work_dir): + """Get the most recently created job_id directory in work_dir.""" + if not os.path.exists(work_dir): + return None + job_dirs = [d for d in os.listdir(work_dir) if os.path.isdir(os.path.join(work_dir, d))] + if not job_dirs: + return None + # Sort by creation time (descending) + job_dirs = sorted(job_dirs, key=lambda d: os.path.getctime(os.path.join(work_dir, d)), reverse=True) + return job_dirs[0] + + +def main(): + """Run the comprehensive demo.""" + print("🚀 DataJuicer Job Management & Monitoring Demo") + print("=" * 80) + + # IMPORTANT: This script must be run from the Data-Juicer root directory + # Check if we're in the root directory by looking for key files/directories + if not os.path.exists("configs") or not os.path.exists("data_juicer"): + print("❌ Error: This script must be run from the Data-Juicer root directory!") + print(" Current directory:", os.getcwd()) + print(" Expected to find: configs/ and data_juicer/ directories") + 
print("\n Please run:") + print(" cd /path/to/data-juicer") + print(" python demos/partition_and_checkpoint/run_demo.py") + return + + config_file = "configs/demo/partition-checkpoint-eventlog.yaml" + work_dir = "./outputs/partition-checkpoint-eventlog" + + # Ensure the config file exists + if not os.path.exists(config_file): + print(f"❌ Config file {config_file} not found!") + print("Please run this script from the DataJuicer root directory.") + return + + # Check resource optimization configuration + check_resource_optimization() + + # Demo 1: First run with new job (auto-generated job_id) + print("\n🎯 Demo 1: First Run (New Job, Auto-generated job_id)") + print("=" * 80) + result1 = run_data_juicer_command(config_file) + job_id_1 = get_latest_job_id(work_dir) + if result1.returncode == 0 and job_id_1: + print(f"✅ First run completed successfully! (job_id: {job_id_1})") + check_directory_structure(job_id_1, work_dir) + check_flexible_storage(job_id_1) + check_job_summary(job_id_1, work_dir) + run_snapshot_analysis(job_id_1, work_dir) + else: + print("❌ First run failed!") + return + + # Demo 2: Resume the same job + print("\n🎯 Demo 2: Resume Job") + print("=" * 80) + result2 = run_data_juicer_command(config_file, job_id_1) + if result2.returncode == 0: + print("✅ Job resumption completed successfully!") + print("Note: This should be much faster than the first run due to checkpoint resumption.") + check_job_summary(job_id_1, work_dir) + run_snapshot_analysis(job_id_1, work_dir) + else: + print("❌ Job resumption failed!") + + # Demo 3: New job with different checkpoint strategy (auto-generated job_id) + print("\n🎯 Demo 3: Different Checkpoint Strategy") + print("=" * 80) + extra_args = ["--checkpoint.strategy", "every_partition"] + result3 = run_data_juicer_command(config_file, None, extra_args) + job_id_2 = get_latest_job_id(work_dir) + if result3.returncode == 0 and job_id_2: + print(f"✅ Different checkpoint strategy completed successfully! 
(job_id: {job_id_2})") + check_directory_structure(job_id_2, work_dir) + check_flexible_storage(job_id_2) + check_job_summary(job_id_2, work_dir) + run_snapshot_analysis(job_id_2, work_dir) + else: + print("❌ Different checkpoint strategy failed!") + + # Demo 4: List available jobs + print("\n🎯 Demo 4: List Available Jobs") + print("=" * 80) + if os.path.exists(work_dir): + print("Available job directories:") + from pathlib import Path + for item in os.listdir(work_dir): + item_path = os.path.join(work_dir, item) + if os.path.isdir(item_path): + # Check for event logs or job summary to confirm it's a job directory + job_path = Path(item_path) + has_events = list(job_path.glob("events_*.jsonl")) or (job_path / "events.jsonl").exists() + has_summary = (job_path / "job_summary.json").exists() + + if has_events or has_summary: + status_indicator = "✅" if has_summary else "🔄" + status_text = "Completed" if has_summary else "Running/Incomplete" + print(f" {status_indicator} {item} ({status_text})") + else: + print(f"Work directory {work_dir} not found") + + print("\n🎉 Demo completed!") + print("=" * 80) + print("Key Features Demonstrated:") + print("✅ Processing Snapshot Utility - Comprehensive job status analysis with JSON output") + print("✅ Job Management Tools - Monitor and manage DataJuicer processing jobs") + print("✅ Resource-Aware Partitioning - Automatic resource optimization for distributed processing") + print("✅ Job-specific directory isolation") + print("✅ Event logging with JSONL format") + print("✅ Human-readable logs with multiple levels") + print("✅ Configurable checkpointing strategies") + print("✅ Job resumption capabilities") + print("✅ Comprehensive job management with job_summary.json") + print("✅ Fast resumption from checkpoints") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/docs/JobManagement.md b/docs/JobManagement.md new file mode 100644 index 0000000000..e04091b059 --- /dev/null +++ b/docs/JobManagement.md @@ -0,0 +1,417 @@ +# Job Management & Monitoring + +Data-Juicer provides comprehensive job management and monitoring capabilities to help you track, analyze, and optimize your data processing workflows. + +## Overview + +The job management system includes: + +- **Processing Snapshot Utility**: Detailed analysis of job status and progress +- **Resource-Aware Partitioning**: Automatic optimization of distributed processing +- **Enhanced Logging**: Centralized logging with rotation and retention +- **Job Monitoring Tools**: Real-time tracking of processing jobs + +## Processing Snapshot Utility + +The Processing Snapshot Utility provides comprehensive analysis of Data-Juicer job processing status based on `events_{timestamp}.jsonl` (timestamped event logs) and DAG structure. 
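+
+The utility builds its analysis from the individual event records in the timestamped log. As a minimal, illustrative sketch (not the snapshot utility itself), the snippet below only tallies `event_type` values from the newest event log of a job directory; the field names follow the example event logs shipped with the demo.
+
+```python
+import json
+from pathlib import Path
+
+
+def tally_events(job_dir: str) -> dict:
+    """Count event types in the newest events_{timestamp}.jsonl of a job directory."""
+    events_file = max(Path(job_dir).glob("events_*.jsonl"), key=lambda f: f.stat().st_mtime)
+    counts = {}
+    with open(events_file) as f:
+        for line in f:
+            event = json.loads(line)
+            counts[event["event_type"]] = counts.get(event["event_type"], 0) + 1
+    return counts
+
+
+# Example (job directory as in the usage section below):
+# tally_events("outputs/partition-checkpoint-eventlog/20250809_040053_a001de")
+# -> e.g. {"job_start": 1, "partition_complete": 18, "op_complete": 144, ...}
+```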
+ +### Features + +- **JSON Output**: Machine-readable format for automation and integration +- **Progress Tracking**: Detailed partition and operation progress +- **Checkpointing Analysis**: Checkpoint status and resumability information +- **Timing Analysis**: Precise timing from job summary or events +- **Resource Utilization**: Partition and operation-level statistics + +### Usage + +#### Basic Snapshot +```bash +python -m data_juicer.utils.job.snapshot outputs/partition-checkpoint-eventlog/20250809_040053_a001de +``` + +#### Human-Readable Output +```bash +python -m data_juicer.utils.job.snapshot outputs/partition-checkpoint-eventlog/20250809_040053_a001de --human-readable +``` + +### JSON Output Structure + +```json +{ + "job_info": { + "job_id": "20250809_040053_a001de", + "executor_type": "ray_partitioned", + "status": "completed", + "config_file": ["configs/demo/partition-checkpoint-eventlog.yaml"], + "work_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de", + "resumption_command": "dj-process --config [Path_fr(...)] --job_id 20250809_040053_a001de", + "error_message": null + }, + "overall_status": "completed", + "overall_progress": { + "overall_percentage": 100.0, + "partition_percentage": 100.0, + "operation_percentage": 100.0 + }, + "timing": { + "start_time": 1754712053.496651, + "end_time": 1754712325.323669, + "duration_seconds": 271.82701802253723, + "duration_formatted": "4m 31s", + "job_summary_duration": 271.82701802253723, + "timing_source": "job_summary" + }, + "progress_summary": { + "total_partitions": 18, + "completed_partitions": 18, + "failed_partitions": 0, + "in_progress_partitions": 0, + "partition_progress_percentage": 100.0, + "total_operations": 144, + "completed_operations": 144, + "failed_operations": 0, + "checkpointed_operations": 0, + "operation_progress_percentage": 100.0 + }, + "checkpointing": { + "strategy": "every_op", + "last_checkpoint_time": 1754712320.123456, + "checkpointed_operations_count": 72, + "resumable": true, + "checkpoint_progress": { + "percentage": 50.0, + "checkpointed_operations": [...], + "checkpoint_coverage": 0.5 + }, + "checkpoint_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/checkpoints" + }, + "partition_progress": { + "0": { + "status": "completed", + "sample_count": 20000, + "creation_start_time": 1754712074.356004, + "creation_end_time": 1754712074.366004, + "processing_start_time": 1754712074.366004, + "processing_end_time": 1754712074.456004, + "current_operation": null, + "completed_operations": ["clean_links_mapper", "clean_email_mapper", ...], + "failed_operations": [], + "checkpointed_operations": [], + "error_message": null, + "progress_percentage": 100.0 + } + }, + "operation_progress": { + "p0_op0_clean_links_mapper": { + "operation_name": "clean_links_mapper", + "operation_idx": 0, + "status": "completed", + "start_time": 1754712074.356004, + "end_time": 1754712074.366004, + "duration": 0.01, + "input_rows": 20000, + "output_rows": 19363, + "checkpoint_time": null, + "error_message": null, + "progress_percentage": 100.0 + } + }, + "file_paths": { + "event_log_file": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/events.jsonl", + "event_log_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/logs", + "checkpoint_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/checkpoints", + "metadata_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/metadata", + "backed_up_config_path": 
"./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/partition-checkpoint-eventlog.yaml" + }, + "metadata": { + "snapshot_generated_at": "2025-08-09T13:33:54.770298", + "events_analyzed": 367, + "dag_plan_loaded": true, + "job_summary_loaded": true, + "job_summary_used": true + } +} +``` + +## Resource-Aware Partitioning + +The Resource-Aware Partitioning system automatically optimizes partition sizes and worker counts based on available cluster resources and data characteristics. + +### Features + +- **Automatic Resource Detection**: Analyzes local and cluster resources +- **Data-Driven Optimization**: Samples data to determine optimal partition sizes +- **Modality-Aware**: Different optimization strategies for text, image, audio, video, and multimodal data +- **64MB Target**: Optimizes partitions to target 64MB per partition +- **Worker Count Optimization**: Automatically determines optimal number of Ray workers + +### Configuration + +Enable resource optimization in your config: + +```yaml +# Resource optimization configuration +resource_optimization: + auto_configure: true # Enable automatic optimization + +# Manual configuration (used when auto_configure: false) +# partition: +# size: 10000 # Number of samples per partition +# max_size_mb: 128 # Maximum partition size in MB +# np: 2 # Number of Ray workers +``` + +### Optimization Process + +1. **Resource Detection**: Analyzes CPU, memory, GPU, and cluster resources +2. **Data Sampling**: Samples dataset to understand data characteristics +3. **Modality Analysis**: Determines data modality and applies appropriate optimizations +4. **Partition Calculation**: Calculates optimal partition size targeting 64MB +5. **Worker Optimization**: Determines optimal number of Ray workers +6. **Application**: Applies optimizations to the processing pipeline + +## Enhanced Logging System + +The enhanced logging system provides centralized logging with rotation and retention policies. 
+ +### Features + +- **Centralized Logging**: All logs managed through `logger_utils.py` +- **Log Rotation**: Automatic rotation based on file size +- **Retention Policies**: Configurable retention and cleanup +- **Compression**: Automatic compression of rotated logs +- **Multiple Levels**: Separate log files for different log levels + +### Configuration + +```python +from data_juicer.utils.logger_utils import setup_logger + +# Setup logger with rotation and retention +setup_logger( + save_dir="./outputs", + filename="log.txt", + max_log_size_mb=100, # Rotate at 100MB + backup_count=5 # Keep 5 backup files +) +``` + +### Log Structure + +``` +outputs/ +├── job_20250809_040053_a001de/ +│ ├── events_{timestamp}.jsonl # Event log (JSONL format with timestamp) +│ ├── logs/ # Log directory +│ │ ├── events.log # Event log (human-readable) +│ │ ├── log.txt # Main log file +│ │ ├── log_DEBUG.txt # Debug level logs +│ │ ├── log_ERROR.txt # Error level logs +│ │ └── log_WARNING.txt # Warning level logs +│ ├── checkpoints/ # Checkpoint directory +│ ├── partitions/ # Partition directory +│ └── job_summary.json # Job summary (created on job completion) +``` + +## Job Management Tools + +### Job Utilities + +```python +from data_juicer.utils.job import JobUtils, create_snapshot + +# Create job utilities +job_utils = JobUtils("./outputs") + +# List running jobs +running_jobs = job_utils.list_running_jobs() + +# Load event logs +events = job_utils.load_event_logs() + +# Create processing snapshot +snapshot = create_snapshot("./outputs/job_20250809_040053_a001de") +``` + +### Event Analysis + +The system tracks various event types: + +- **Job Events**: `job_start`, `job_complete` +- **Partition Events**: `partition_creation_start`, `partition_creation_complete`, `partition_start`, `partition_complete`, `partition_failed` +- **Operation Events**: `op_start`, `op_complete`, `op_failed` +- **Checkpoint Events**: `checkpoint_save` +- **DAG Events**: `dag_build_start`, `dag_build_complete`, `dag_execution_plan_saved` + +## Best Practices + +### 1. Enable Resource Optimization + +Always enable resource optimization for production workloads: + +```yaml +resource_optimization: + auto_configure: true +``` + +### 2. Monitor Job Progress + +Use the snapshot utility to monitor long-running jobs: + +```bash +# Check job status +python -m data_juicer.utils.job.snapshot /path/to/job/directory + +# Get detailed analysis +python -m data_juicer.utils.job.snapshot /path/to/job/directory --human-readable +``` + +### 3. Configure Logging + +Set appropriate log rotation and retention: + +```python +setup_logger( + save_dir="./outputs", + max_log_size_mb=100, + backup_count=5 +) +``` + +### 4. Use Checkpointing + +Enable checkpointing for long-running jobs: + +```yaml +checkpoint: + enabled: true + strategy: "every_op" # or "every_partition", "every_n_ops" +``` + +### 5. 
Monitor Resource Usage + +The snapshot utility provides detailed resource utilization information: + +- Partition-level progress and timing +- Operation-level performance metrics +- Checkpoint coverage and resumability +- Overall job efficiency statistics + +## Integration Examples + +### Automation Script + +```python +import json +import subprocess +from pathlib import Path + +def monitor_job(job_dir: str): + """Monitor a Data-Juicer job and return status.""" + result = subprocess.run([ + "python", "-m", "data_juicer.utils.job.snapshot", job_dir + ], capture_output=True, text=True) + + if result.returncode == 0: + snapshot = json.loads(result.stdout) + return { + "status": snapshot["overall_status"], + "progress": snapshot["overall_progress"]["overall_percentage"], + "duration": snapshot["timing"]["duration_formatted"], + "resumable": snapshot["checkpointing"]["resumable"] + } + else: + return {"error": result.stderr} + +# Usage +status = monitor_job("./outputs/job_20250809_040053_a001de") +print(f"Job Status: {status['status']}, Progress: {status['progress']:.1f}%") +``` + +### Dashboard Integration + +The JSON output format makes it easy to integrate with monitoring dashboards: + +```python +def get_job_metrics(job_dir: str): + """Extract key metrics for dashboard display.""" + snapshot = create_snapshot(job_dir) + + return { + "job_id": snapshot.job_id, + "status": snapshot.overall_status.value, + "progress": { + "partitions": f"{snapshot.completed_partitions}/{snapshot.total_partitions}", + "operations": f"{snapshot.completed_operations}/{snapshot.total_operations}" + }, + "timing": { + "duration": snapshot.total_duration, + "start_time": snapshot.job_start_time + }, + "checkpointing": { + "resumable": snapshot.resumable, + "strategy": snapshot.checkpoint_strategy + } + } +``` + +## Troubleshooting + +### Common Issues + +1. **Job Not Starting**: Check resource availability and configuration +2. **Slow Performance**: Enable resource optimization and check partition sizes +3. **Memory Issues**: Reduce partition size or enable checkpointing +4. 
**Log File Growth**: Configure log rotation and retention policies + +### Debug Commands + +```bash +# Check job status +python -m data_juicer.utils.job.snapshot /path/to/job + +# Analyze events (finds the latest timestamped event file) +python -c "from pathlib import Path; import json; job_dir = Path('/path/to/job'); events_file = max(job_dir.glob('events_*.jsonl'), key=lambda f: f.stat().st_mtime); events = [json.loads(line) for line in open(events_file)]; print(f'Total events: {len(events)}')" + +# Check resource usage +python -c "from data_juicer.core.executor.partition_size_optimizer import ResourceDetector; print(ResourceDetector.detect_local_resources())" +``` + +## API Reference + +### ProcessingSnapshotAnalyzer + +```python +from data_juicer.utils.job.snapshot import ProcessingSnapshotAnalyzer + +analyzer = ProcessingSnapshotAnalyzer(job_dir) +snapshot = analyzer.generate_snapshot() +json_data = analyzer.to_json_dict(snapshot) +``` + +### ResourceDetector + +```python +from data_juicer.core.executor.partition_size_optimizer import ResourceDetector + +# Detect local resources +local_resources = ResourceDetector.detect_local_resources() + +# Detect Ray cluster +cluster_resources = ResourceDetector.detect_ray_cluster() + +# Calculate optimal worker count +optimal_workers = ResourceDetector.calculate_optimal_worker_count() +``` + +### PartitionSizeOptimizer + +```python +from data_juicer.core.executor.partition_size_optimizer import PartitionSizeOptimizer + +optimizer = PartitionSizeOptimizer() +recommendations = optimizer.get_partition_recommendations(dataset, modality) +``` + +This comprehensive job management system provides the tools you need to monitor, optimize, and troubleshoot Data-Juicer processing jobs effectively. diff --git a/docs/JobManagement_ZH.md b/docs/JobManagement_ZH.md new file mode 100644 index 0000000000..3dea89b80d --- /dev/null +++ b/docs/JobManagement_ZH.md @@ -0,0 +1,417 @@ +# 作业管理与监控 + +Data-Juicer 提供全面的作业管理和监控功能,帮助您跟踪、分析和优化数据处理工作流。 + +## 概述 + +作业管理系统包括: + +- **处理快照工具**:详细的作业状态和进度分析 +- **资源感知分区**:分布式处理的自动优化 +- **增强日志系统**:集中化日志管理,支持轮转和保留 +- **作业监控工具**:处理作业的实时跟踪 + +## 处理快照工具 + +处理快照工具基于 `events_{timestamp}.jsonl`(带时间戳的事件日志)和 DAG 结构提供 Data-Juicer 作业处理状态的全面分析。 + +### 功能特性 + +- **JSON 输出**:机器可读格式,便于自动化和集成 +- **进度跟踪**:详细的分区和操作进度 +- **检查点分析**:检查点状态和可恢复性信息 +- **时间分析**:从作业摘要或事件中获取精确时间 +- **资源利用**:分区和操作级别的统计信息 + +### 使用方法 + +#### 基本快照 +```bash +python -m data_juicer.utils.job.snapshot outputs/partition-checkpoint-eventlog/20250809_040053_a001de +``` + +#### 人类可读输出 +```bash +python -m data_juicer.utils.job.snapshot outputs/partition-checkpoint-eventlog/20250809_040053_a001de --human-readable +``` + +### JSON 输出结构 + +```json +{ + "job_info": { + "job_id": "20250809_040053_a001de", + "executor_type": "ray_partitioned", + "status": "completed", + "config_file": ["configs/demo/partition-checkpoint-eventlog.yaml"], + "work_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de", + "resumption_command": "dj-process --config [Path_fr(...)] --job_id 20250809_040053_a001de", + "error_message": null + }, + "overall_status": "completed", + "overall_progress": { + "overall_percentage": 100.0, + "partition_percentage": 100.0, + "operation_percentage": 100.0 + }, + "timing": { + "start_time": 1754712053.496651, + "end_time": 1754712325.323669, + "duration_seconds": 271.82701802253723, + "duration_formatted": "4m 31s", + "job_summary_duration": 271.82701802253723, + "timing_source": "job_summary" + }, + "progress_summary": { + "total_partitions": 18, + 
"completed_partitions": 18, + "failed_partitions": 0, + "in_progress_partitions": 0, + "partition_progress_percentage": 100.0, + "total_operations": 144, + "completed_operations": 144, + "failed_operations": 0, + "checkpointed_operations": 0, + "operation_progress_percentage": 100.0 + }, + "checkpointing": { + "strategy": "every_op", + "last_checkpoint_time": 1754712320.123456, + "checkpointed_operations_count": 72, + "resumable": true, + "checkpoint_progress": { + "percentage": 50.0, + "checkpointed_operations": [...], + "checkpoint_coverage": 0.5 + }, + "checkpoint_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/checkpoints" + }, + "partition_progress": { + "0": { + "status": "completed", + "sample_count": 20000, + "creation_start_time": 1754712074.356004, + "creation_end_time": 1754712074.366004, + "processing_start_time": 1754712074.366004, + "processing_end_time": 1754712074.456004, + "current_operation": null, + "completed_operations": ["clean_links_mapper", "clean_email_mapper", ...], + "failed_operations": [], + "checkpointed_operations": [], + "error_message": null, + "progress_percentage": 100.0 + } + }, + "operation_progress": { + "p0_op0_clean_links_mapper": { + "operation_name": "clean_links_mapper", + "operation_idx": 0, + "status": "completed", + "start_time": 1754712074.356004, + "end_time": 1754712074.366004, + "duration": 0.01, + "input_rows": 20000, + "output_rows": 19363, + "checkpoint_time": null, + "error_message": null, + "progress_percentage": 100.0 + } + }, + "file_paths": { + "event_log_file": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/events.jsonl", + "event_log_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/logs", + "checkpoint_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/checkpoints", + "metadata_dir": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/metadata", + "backed_up_config_path": "./outputs/partition-checkpoint-eventlog/20250809_040053_a001de/partition-checkpoint-eventlog.yaml" + }, + "metadata": { + "snapshot_generated_at": "2025-08-09T13:33:54.770298", + "events_analyzed": 367, + "dag_plan_loaded": true, + "job_summary_loaded": true, + "job_summary_used": true + } +} +``` + +## 资源感知分区 + +资源感知分区系统根据可用的集群资源和数据特征自动优化分区大小和工作节点数量。 + +### 功能特性 + +- **自动资源检测**:分析本地和集群资源 +- **数据驱动优化**:采样数据以确定最佳分区大小 +- **模态感知**:针对文本、图像、音频、视频和多模态数据的不同优化策略 +- **64MB 目标**:优化分区以目标 64MB 每个分区 +- **工作节点数量优化**:自动确定最佳 Ray 工作节点数量 + +### 配置 + +在配置中启用资源优化: + +```yaml +# 资源优化配置 +resource_optimization: + auto_configure: true # 启用自动优化 + +# 手动配置(当 auto_configure: false 时使用) +# partition: +# size: 10000 # 每个分区的样本数 +# max_size_mb: 128 # 最大分区大小(MB) +# np: 2 # Ray 工作节点数量 +``` + +### 优化过程 + +1. **资源检测**:分析 CPU、内存、GPU 和集群资源 +2. **数据采样**:采样数据集以了解数据特征 +3. **模态分析**:确定数据模态并应用适当的优化 +4. **分区计算**:计算最佳分区大小,目标 64MB +5. **工作节点优化**:确定最佳 Ray 工作节点数量 +6. 
**应用**:将优化应用到处理管道 + +## 增强日志系统 + +增强日志系统提供集中化日志管理,支持轮转和保留策略。 + +### 功能特性 + +- **集中化日志**:所有日志通过 `logger_utils.py` 管理 +- **日志轮转**:基于文件大小的自动轮转 +- **保留策略**:可配置的保留和清理 +- **压缩**:轮转日志的自动压缩 +- **多级别**:不同日志级别的单独日志文件 + +### 配置 + +```python +from data_juicer.utils.logger_utils import setup_logger + +# 设置带轮转和保留的日志记录器 +setup_logger( + save_dir="./outputs", + filename="log.txt", + max_log_size_mb=100, # 100MB 时轮转 + backup_count=5 # 保留 5 个备份文件 +) +``` + +### 日志结构 + +``` +outputs/ +├── job_20250809_040053_a001de/ +│ ├── events_{timestamp}.jsonl # 事件日志(带时间戳的 JSONL 格式) +│ ├── logs/ # 日志目录 +│ │ ├── events.log # 事件日志(人类可读) +│ │ ├── log.txt # 主日志文件 +│ │ ├── log_DEBUG.txt # 调试级别日志 +│ │ ├── log_ERROR.txt # 错误级别日志 +│ │ └── log_WARNING.txt # 警告级别日志 +│ ├── checkpoints/ # 检查点目录 +│ ├── partitions/ # 分区目录 +│ └── job_summary.json # 作业摘要(作业完成时创建) +``` + +## 作业管理工具 + +### 作业工具 + +```python +from data_juicer.utils.job import JobUtils, create_snapshot + +# 创建作业工具 +job_utils = JobUtils("./outputs") + +# 列出运行中的作业 +running_jobs = job_utils.list_running_jobs() + +# 加载事件日志 +events = job_utils.load_event_logs() + +# 创建处理快照 +snapshot = create_snapshot("./outputs/job_20250809_040053_a001de") +``` + +### 事件分析 + +系统跟踪各种事件类型: + +- **作业事件**:`job_start`、`job_complete` +- **分区事件**:`partition_creation_start`、`partition_creation_complete`、`partition_start`、`partition_complete`、`partition_failed` +- **操作事件**:`op_start`、`op_complete`、`op_failed` +- **检查点事件**:`checkpoint_save` +- **DAG 事件**:`dag_build_start`、`dag_build_complete`、`dag_execution_plan_saved` + +## 最佳实践 + +### 1. 启用资源优化 + +对于生产工作负载,始终启用资源优化: + +```yaml +resource_optimization: + auto_configure: true +``` + +### 2. 监控作业进度 + +使用快照工具监控长时间运行的作业: + +```bash +# 检查作业状态 +python -m data_juicer.utils.job.snapshot /path/to/job/directory + +# 获取详细分析 +python -m data_juicer.utils.job.snapshot /path/to/job/directory --human-readable +``` + +### 3. 配置日志 + +设置适当的日志轮转和保留: + +```python +setup_logger( + save_dir="./outputs", + max_log_size_mb=100, + backup_count=5 +) +``` + +### 4. 使用检查点 + +为长时间运行的作业启用检查点: + +```yaml +checkpoint: + enabled: true + strategy: "every_op" # 或 "every_partition"、"every_n_ops" +``` + +### 5. 
监控资源使用 + +快照工具提供详细的资源利用信息: + +- 分区级别的进度和时间 +- 操作级别的性能指标 +- 检查点覆盖率和可恢复性 +- 整体作业效率统计 + +## 集成示例 + +### 自动化脚本 + +```python +import json +import subprocess +from pathlib import Path + +def monitor_job(job_dir: str): + """监控 Data-Juicer 作业并返回状态。""" + result = subprocess.run([ + "python", "-m", "data_juicer.utils.job.snapshot", job_dir + ], capture_output=True, text=True) + + if result.returncode == 0: + snapshot = json.loads(result.stdout) + return { + "status": snapshot["overall_status"], + "progress": snapshot["overall_progress"]["overall_percentage"], + "duration": snapshot["timing"]["duration_formatted"], + "resumable": snapshot["checkpointing"]["resumable"] + } + else: + return {"error": result.stderr} + +# 使用 +status = monitor_job("./outputs/job_20250809_040053_a001de") +print(f"作业状态: {status['status']}, 进度: {status['progress']:.1f}%") +``` + +### 仪表板集成 + +JSON 输出格式便于与监控仪表板集成: + +```python +def get_job_metrics(job_dir: str): + """提取仪表板显示的关键指标。""" + snapshot = create_snapshot(job_dir) + + return { + "job_id": snapshot.job_id, + "status": snapshot.overall_status.value, + "progress": { + "partitions": f"{snapshot.completed_partitions}/{snapshot.total_partitions}", + "operations": f"{snapshot.completed_operations}/{snapshot.total_operations}" + }, + "timing": { + "duration": snapshot.total_duration, + "start_time": snapshot.job_start_time + }, + "checkpointing": { + "resumable": snapshot.resumable, + "strategy": snapshot.checkpoint_strategy + } + } +``` + +## 故障排除 + +### 常见问题 + +1. **作业无法启动**:检查资源可用性和配置 +2. **性能缓慢**:启用资源优化并检查分区大小 +3. **内存问题**:减少分区大小或启用检查点 +4. **日志文件增长**:配置日志轮转和保留策略 + +### 调试命令 + +```bash +# 检查作业状态 +python -m data_juicer.utils.job.snapshot /path/to/job + +# 分析事件(查找最新的带时间戳的事件文件) +python -c "from pathlib import Path; import json; job_dir = Path('/path/to/job'); events_file = max(job_dir.glob('events_*.jsonl'), key=lambda f: f.stat().st_mtime); events = [json.loads(line) for line in open(events_file)]; print(f'总事件数: {len(events)}')" + +# 检查资源使用 +python -c "from data_juicer.core.executor.partition_size_optimizer import ResourceDetector; print(ResourceDetector.detect_local_resources())" +``` + +## API 参考 + +### ProcessingSnapshotAnalyzer + +```python +from data_juicer.utils.job.snapshot import ProcessingSnapshotAnalyzer + +analyzer = ProcessingSnapshotAnalyzer(job_dir) +snapshot = analyzer.generate_snapshot() +json_data = analyzer.to_json_dict(snapshot) +``` + +### ResourceDetector + +```python +from data_juicer.core.executor.partition_size_optimizer import ResourceDetector + +# 检测本地资源 +local_resources = ResourceDetector.detect_local_resources() + +# 检测 Ray 集群 +cluster_resources = ResourceDetector.detect_ray_cluster() + +# 计算最佳工作节点数量 +optimal_workers = ResourceDetector.calculate_optimal_worker_count() +``` + +### PartitionSizeOptimizer + +```python +from data_juicer.core.executor.partition_size_optimizer import PartitionSizeOptimizer + +optimizer = PartitionSizeOptimizer() +recommendations = optimizer.get_partition_recommendations(dataset, modality) +``` + +这个全面的作业管理系统提供了您有效监控、优化和故障排除 Data-Juicer 处理作业所需的工具。 diff --git a/docs/PartitionAndCheckpoint.md b/docs/PartitionAndCheckpoint.md new file mode 100644 index 0000000000..a53d829362 --- /dev/null +++ b/docs/PartitionAndCheckpoint.md @@ -0,0 +1,864 @@ +# DataJuicer Fault-Tolerant Processing with Checkpointing and Event Logging + +This directory contains the implementation of fault-tolerant, resumable DataJuicer processing with comprehensive checkpointing, partitioning, and event logging capabilities. 
+ +## 🚀 Features Implemented + +### ✅ Core Features +- **Job-Specific Directory Isolation**: Each job gets its own dedicated directory structure +- **Configurable Checkpointing Strategies**: Multiple checkpointing frequencies and strategies +- **Spark-Style Event Logging**: Comprehensive event tracking in JSONL format for resumability +- **Job Resumption Capabilities**: Resume failed or interrupted jobs from the last checkpoint +- **Comprehensive Job Management**: Job summaries, metadata tracking, and resumption commands + +### ✅ Checkpointing Strategies +- `EVERY_OP`: Checkpoint after every operation (most resilient, slower) +- `EVERY_N_OPS`: Checkpoint after every N operations (configurable) +- `MANUAL`: Checkpoint only after specified operations +- `DISABLED`: Disable checkpointing entirely + +### ✅ Event Logging +- **Human-readable logs**: Loguru-based logging for debugging and monitoring +- **Machine-readable logs**: JSONL format for programmatic analysis and resumption +- **Comprehensive event types**: Job start/complete/failed, partition events, operation events, checkpoint events +- **Real-time monitoring**: Live event streaming and status reporting + +### ✅ Job Management +- **Meaningful Job IDs**: Format: `{YYYYMMDD}_{HHMMSS}_{config_name}_{unique_suffix}` +- **Job Summary Files**: Comprehensive metadata for each job run +- **Resumption Commands**: Automatic generation of exact commands to resume jobs +- **Job Validation**: Validation of job resumption parameters and existing state + +## 📁 Directory Structure + +``` +{work_dir}/ +├── {job_id}/ # Job-specific directory +│ ├── job_summary.json # Job metadata and resumption info (created on job completion) +│ ├── events_{timestamp}.jsonl # Machine-readable events (JSONL format with timestamp) +│ ├── dag_execution_plan.json # DAG execution plan +│ ├── partition-checkpoint-eventlog.yaml # Backed up config file +│ ├── metadata/ # Job metadata files +│ │ ├── dataset_mapping.json +│ │ └── final_mapping_report.json +│ ├── logs/ # Human-readable logs +│ │ ├── export_processed.jsonl_time_*.txt # Main log file +│ │ ├── export_processed.jsonl_time_*_DEBUG.txt # Debug level logs +│ │ ├── export_processed.jsonl_time_*_WARNING.txt # Warning level logs +│ │ └── export_processed.jsonl_time_*_ERROR.txt # Error level logs +│ ├── checkpoints/ # Checkpoint data +│ │ ├── checkpoint_*.json # Checkpoint metadata +│ │ └── partition_*/ # Partition checkpoint data +│ ├── partitions/ # Input data partitions +│ ├── processed.jsonl/ # Intermediate processing results +│ └── results/ # Final processing results +``` + +## 🛠️ Configuration + +### Configuration Structure + +The configuration uses a **logical nested structure** that groups related settings by concern: + +#### New Logical Structure (Recommended) +```yaml +# Partitioning configuration +partition: + size: 1000 # Number of samples per partition + max_size_mb: 64 # Maximum partition size in MB + + + +# Intermediate storage configuration for partition and checkpoint data (format, compression, and lifecycle management) +intermediate_storage: + # File format and compression + format: "parquet" # parquet, arrow, jsonl + compression: "snappy" # snappy, gzip, none + use_arrow_batches: true + arrow_batch_size: 500 + arrow_memory_mapping: false + + # File lifecycle management + preserve_intermediate_data: true # Keep temporary files for debugging/resumption + cleanup_temp_files: true + cleanup_on_success: false + retention_policy: "keep_all" # keep_all, keep_failed_only, cleanup_all + max_retention_days: 7 +``` 
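+
+The nested and flat forms express the same settings; conceptually, a reader prefers the nested `partition`/`intermediate_storage` sections and falls back to the legacy flat keys shown in the next subsection. The sketch below is illustrative only, assuming the parsed config behaves like a plain dict; `resolve_partition_settings` is a hypothetical helper, not Data-Juicer API.
+
+```python
+# Illustrative sketch: prefer nested config sections, fall back to legacy flat keys.
+def resolve_partition_settings(cfg: dict) -> dict:
+    nested = cfg.get("partition") or {}
+    storage = cfg.get("intermediate_storage") or {}
+    return {
+        "size": nested.get("size", cfg.get("partition_size", 1000)),
+        "max_size_mb": nested.get("max_size_mb", cfg.get("max_partition_size_mb", 64)),
+        "format": storage.get("format", cfg.get("storage_format", "parquet")),
+    }
+
+
+# The nested section wins when present; flat keys are used otherwise.
+print(resolve_partition_settings({"partition": {"size": 1000, "max_size_mb": 64}}))
+print(resolve_partition_settings({"partition_size": 1000, "max_partition_size_mb": 64}))
+```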
+ +#### Legacy Flat Structure (Still Supported) +```yaml +# Legacy flat configuration (still works) +partition_size: 1000 +max_partition_size_mb: 64 +preserve_intermediate_data: true +storage_format: "parquet" +use_arrow_batches: true +arrow_batch_size: 500 +arrow_memory_mapping: false +``` + +**Note**: The system reads from the new nested sections first, then falls back to the legacy flat configuration if not found. + +### Configuration Sections Explained + +#### `partition` - Partitioning and Resilience +Controls how the dataset is split and how failures are handled: + +**Two Partition Modes:** + +1. **Auto Mode** (Recommended - `mode: "auto"`): + - Automatically analyzes your data characteristics and system resources + - Calculates optimal partition size targeting ~64MB per partition + - Determines optimal number of partitions based on dataset size + - Configures optimal worker count based on available CPU cores + - No manual tuning required - adapts to your hardware and data + - Configuration: + - `mode`: `"auto"` + - `size`: Fallback partition size (samples) - used if auto-analysis fails + - `max_size_mb`: Fallback max partition size (MB) - used if auto-analysis fails + +2. **Manual Mode** (`mode: "manual"`): + - You specify the exact number of partitions to create + - Useful when you know your optimal partitioning strategy + - Configuration: + - `mode`: `"manual"` + - `num_of_partitions`: Exact number of partitions to create + - `size` and `max_size_mb` are ignored in manual mode + + +#### `intermediate_storage` - Intermediate Data Management +Controls file formats, compression, and lifecycle management for intermediate data: +- **File Format & Compression**: + - `format`: Storage format (`parquet`, `arrow`, `jsonl`) + - `compression`: Compression algorithm (`snappy`, `gzip`, `none`) + - `use_arrow_batches`: Use Arrow batch processing + - `arrow_batch_size`: Arrow batch size + - `arrow_memory_mapping`: Enable memory mapping +- **File Lifecycle Management**: + - `preserve_intermediate_data`: Keep temporary files for debugging + - `cleanup_temp_files`: Enable automatic cleanup + - `cleanup_on_success`: Clean up even on successful completion + - `retention_policy`: File retention strategy (`keep_all`, `keep_failed_only`, `cleanup_all`) + - `max_retention_days`: Auto-cleanup after X days + +### Basic Configuration +```yaml +# Enable fault-tolerant processing +executor_type: ray_partitioned + +# Job management +job_id: my_experiment_001 # Optional: auto-generated if not provided + +# Checkpointing configuration +checkpoint: + enabled: true + strategy: every_op # every_op, every_n_ops, manual, disabled + n_ops: 2 # For every_n_ops strategy + op_names: # For manual strategy + - clean_links_mapper + - whitespace_normalization_mapper + +# Event logging configuration +event_logging: + enabled: true + max_log_size_mb: 100 + backup_count: 5 + +# Partitioning configuration +partition: + mode: "auto" # Auto mode - optimal partitioning based on data analysis + size: 5000 # Fallback partition size (samples) - used if auto-analysis fails + max_size_mb: 64 # Fallback max partition size (MB) - used if auto-analysis fails + # Note: num_of_partitions is calculated automatically in auto mode + +# Alternative: Manual partition mode +# partition: +# mode: "manual" # Manual mode - specify exact number of partitions +# num_of_partitions: 8 # Split dataset into exactly 8 partitions +# # Note: size and max_size_mb are ignored in manual mode + + + +# Intermediate storage configuration for partition and 
checkpoint data (format, compression, and lifecycle management) +intermediate_storage: + # File format and compression + format: "parquet" # parquet, arrow, jsonl + compression: "snappy" # snappy, gzip, none + use_arrow_batches: true + arrow_batch_size: 500 + arrow_memory_mapping: false + + # File lifecycle management + preserve_intermediate_data: true # Keep temporary files for debugging/resumption + cleanup_temp_files: true + cleanup_on_success: false + retention_policy: "keep_all" # keep_all, keep_failed_only, cleanup_all + max_retention_days: 7 +``` + +## 📊 Partition Modes Explained + +### Auto Mode (Recommended) +**When to use:** Most use cases, especially when you want optimal performance without manual tuning. + +**Benefits:** +- ✅ Automatically adapts to your data characteristics (text length, modality, etc.) +- ✅ Optimizes for your system resources (CPU, memory, GPU) +- ✅ Targets ~64MB per partition for optimal memory usage +- ✅ Calculates optimal number of partitions based on dataset size +- ✅ No manual tuning required + +**Example output:** +``` +🔧 Auto-configuring partition settings based on data characteristics... +📊 Dataset analysis complete: + Total samples: 10000 + Recommended partition size: 5000 samples + Calculated partitions: 2 + Recommended max size: 64 MB + Recommended workers: 4 +``` + +### Manual Mode +**When to use:** When you have specific requirements or know your optimal partitioning strategy. + +**Benefits:** +- ✅ Full control over partition count +- ✅ Predictable resource usage +- ✅ Useful for debugging or specific workflows +- ✅ Can be more efficient for known dataset patterns + +**Example:** +```yaml +partition: + mode: "manual" + num_of_partitions: 8 # Always creates exactly 8 partitions +``` + +## 🚀 Quick Start + +### 1. Basic Usage + +#### Auto Partition Mode (Recommended) +```bash +# Run with auto-generated job ID and auto partition optimization +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --partition.mode auto + +# Run with custom job ID +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --partition.mode auto --job_id my_experiment_001 +``` + +#### Manual Partition Mode +```bash +# Run with manual partition configuration (4 partitions) +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --partition.mode manual --partition.num_of_partitions 4 + +# Run with custom job ID +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --partition.mode manual --partition.num_of_partitions 4 --job_id my_experiment_001 +``` + +### 2. Resume a Job +```bash +# Resume using the job ID +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --job_id my_experiment_001 +``` + +### 3. Different Checkpoint Strategies +```bash +# Checkpoint every operation (most resilient) +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --job_id every_op_test --checkpoint.strategy every_op + +# Checkpoint every 3 operations +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --job_id n_ops_test --checkpoint.strategy every_n_ops --checkpoint.n_ops 3 + +# Manual checkpointing +dj-process --config configs/demo/partition-checkpoint-eventlog.yaml --job_id manual_test --checkpoint.strategy manual --checkpoint.op_names clean_links_mapper,whitespace_normalization_mapper +``` + +### 4. 
Run Comprehensive Demo +```bash +# Run the full demo showcasing all features +python demos/partition_and_checkpoint/run_demo.py +``` + +## 📊 Monitoring and Debugging + +### View Job Information +```bash +# Check job summary (created on job completion) +cat ./outputs/partition-checkpoint-eventlog/{job_id}/job_summary.json + +# View event logs (use the latest events file with timestamp) +cat ./outputs/partition-checkpoint-eventlog/{job_id}/events_*.jsonl + +# View human-readable logs +cat ./outputs/partition-checkpoint-eventlog/{job_id}/logs/export_processed.jsonl_time_*.txt + +# View DAG execution plan +cat ./outputs/partition-checkpoint-eventlog/{job_id}/dag_execution_plan.json +``` + +### List Available Jobs +```bash +# List all job directories +ls -la ./outputs/partition-checkpoint-eventlog/ +``` + +### Check Job Structure +```bash +# Check job directory structure +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/ + +# Check logs directory +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/logs/ + +# Check checkpoints directory +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/checkpoints/ +``` + +## 📈 Job Management Utilities + +DataJuicer provides comprehensive job management utilities for monitoring progress and stopping running jobs. These utilities are located in `data_juicer/utils/job/` and provide both command-line and programmatic interfaces. + +### 📊 Job Progress Monitor + +A comprehensive utility to monitor and display progress information for DataJuicer jobs. Shows partition status, operation progress, checkpoints, and overall job metrics. + +#### Features + +- **Real-time Progress Tracking**: Monitor job progress with partition-level details +- **Operation Performance**: View detailed operation metrics including throughput and data reduction +- **Checkpoint Monitoring**: Track checkpoint saves and recovery points +- **Watch Mode**: Continuously monitor jobs with automatic updates +- **Programmatic Access**: Use as a Python function for integration into other tools + +#### Command Line Usage + +##### Basic Usage +```bash +# Show basic progress for a job +python -m data_juicer.utils.job.monitor 20250728_233517_510abf + +# Show detailed progress with operation metrics +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --detailed + +# Watch mode - continuously update progress every 10 seconds +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --watch + +# Watch mode with custom update interval (30 seconds) +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --watch --interval 30 + +# Use custom base directory +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --base-dir /custom/path +``` + +##### Command Line Options +- `job_id`: The job ID to monitor (required) +- `--base-dir`: Base directory containing job outputs (default: `outputs/partition-checkpoint-eventlog`) +- `--detailed`: Show detailed operation information +- `--watch`: Watch mode - continuously update progress +- `--interval`: Update interval in seconds for watch mode (default: 10) + +#### Python API + +##### Basic Function Usage +```python +from data_juicer.utils.job.monitor import show_job_progress + +# Show progress and get data +data = show_job_progress("20250728_233517_510abf") + +# Show detailed progress +data = show_job_progress("20250728_233517_510abf", detailed=True) + +# Use custom base directory +data = show_job_progress("20250728_233517_510abf", base_dir="/custom/path") +``` + +##### Class-based Usage +```python +from 
data_juicer.utils.job.monitor import JobProgressMonitor + +# Create monitor instance +monitor = JobProgressMonitor("20250728_233517_510abf") + +# Display progress +monitor.display_progress(detailed=True) + +# Get progress data as dictionary +data = monitor.get_progress_data() + +# Access specific information +job_status = data['overall_progress']['job_status'] +progress_percentage = data['overall_progress']['progress_percentage'] +partition_status = data['partition_status'] +``` + +### 🛑 Job Stopper + +A utility to stop running DataJuicer jobs by reading event logs to find process and thread IDs, then terminating those specific processes and threads. + +#### Features + +- **Precise Process Termination**: Uses event logs to identify exact processes and threads to terminate +- **Graceful Shutdown**: Sends SIGTERM first for graceful shutdown, then SIGKILL if needed +- **Safety Checks**: Validates job existence and running status before stopping +- **Comprehensive Logging**: Detailed logging of termination process +- **Programmatic Access**: Can be used as a Python function or command-line tool + +#### Command Line Usage + +##### Basic Usage +```bash +# Stop a job gracefully (SIGTERM) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf + +# Force stop a job (SIGKILL) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --force + +# Stop with custom timeout (60 seconds) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --timeout 60 + +# Use custom base directory +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --base-dir /custom/path + +# List all running jobs +python -m data_juicer.utils.job.stopper --list +``` + +##### Command Line Options +- `job_id`: The job ID to stop (required, unless using --list) +- `--base-dir`: Base directory containing job outputs (default: `outputs/partition-checkpoint-eventlog`) +- `--force`: Force kill with SIGKILL instead of graceful SIGTERM +- `--timeout`: Timeout in seconds for graceful shutdown (default: 30) +- `--list`: List all running jobs instead of stopping one + +#### Python API + +##### Basic Function Usage +```python +from data_juicer.utils.job.stopper import stop_job + +# Stop a job gracefully +result = stop_job("20250728_233517_510abf") + +# Force stop a job +result = stop_job("20250728_233517_510abf", force=True) + +# Stop with custom timeout +result = stop_job("20250728_233517_510abf", timeout=60) + +# Use custom base directory +result = stop_job("20250728_233517_510abf", base_dir="/custom/path") +``` + +##### Class-based Usage +```python +from data_juicer.utils.job.stopper import JobStopper + +# Create stopper instance +stopper = JobStopper("20250728_233517_510abf") + +# Stop the job +result = stopper.stop_job(force=False, timeout=30) + +# Check if job is running +is_running = stopper.is_job_running() + +# Get job summary +summary = stopper.get_job_summary() +``` + +### 🔧 Common Utilities + +Both the monitor and stopper utilities share common functionality through `data_juicer.utils.job.common`: + +```python +from data_juicer.utils.job.common import JobUtils, list_running_jobs + +# List all running jobs +running_jobs = list_running_jobs() + +# Create job utilities instance +job_utils = JobUtils("20250728_233517_510abf") + +# Load job summary +summary = job_utils.load_job_summary() + +# Load event logs +events = job_utils.load_event_logs() + +# Get partition status +partition_status = job_utils.get_partition_status() +``` + +### Output Information + +#### Job Overview +- Job status (completed, 
processing, failed, etc.) +- Dataset path and size +- Partition configuration +- Start time and duration + +#### Overall Progress +- Progress percentage +- Partition completion status +- Sample processing counts +- Estimated time remaining (for running jobs) + +#### Partition Status +- Individual partition status with visual indicators +- Sample counts per partition +- Current operation (if processing) +- Number of completed operations +- Number of saved checkpoints + +#### Operation Details (with --detailed flag) +- Per-partition operation performance +- Duration, throughput, and data reduction metrics +- Operation completion order + +#### Checkpoint Summary +- Total number of checkpoints saved +- Checkpoint details by partition and operation +- Timestamp information + +### Example Output + +``` +================================================================================ +DataJuicer Job Progress Monitor +Job ID: 20250728_233517_510abf +================================================================================ + +📊 JOB OVERVIEW + Status: COMPLETED + Dataset: /Users/yilei.z/Downloads/c4-train.00000-of-01024.jsonl + Total Samples: 356,317 + Partition Size: 50,000 samples + Start Time: 2025-07-28 16:35:18 + Duration: 441.1 seconds + +🎯 OVERALL PROGRESS + Progress: 100.0% (8/8 partitions) + Status: 8 completed, 0 processing, 0 failed + Samples: 356,317/356,317 + +📦 PARTITION STATUS + Partition 0: ✅ COMPLETED + Samples: 44,539 + Completed: 8 operations + Checkpoints: 2 saved + Partition 1: ✅ COMPLETED + Samples: 44,540 + Completed: 8 operations + Checkpoints: 2 saved + ... + +💾 CHECKPOINT SUMMARY + Total Checkpoints: 16 +``` + +### Integration Examples + +#### Monitoring Multiple Jobs +```python +from data_juicer.utils.job.monitor import show_job_progress + +job_ids = ["job1", "job2", "job3"] +for job_id in job_ids: + try: + data = show_job_progress(job_id) + print(f"Job {job_id}: {data['overall_progress']['progress_percentage']:.1f}%") + except FileNotFoundError: + print(f"Job {job_id}: Not found") +``` + +#### Custom Monitoring Script +```python +from data_juicer.utils.job.monitor import JobProgressMonitor +import time + +def monitor_job_until_completion(job_id, check_interval=30): + monitor = JobProgressMonitor(job_id) + + while True: + data = monitor.get_progress_data() + status = data['overall_progress']['job_status'] + + if status == 'completed': + print(f"Job {job_id} completed!") + break + elif status == 'failed': + print(f"Job {job_id} failed!") + break + + print(f"Job {job_id} still running... 
{data['overall_progress']['progress_percentage']:.1f}%") + time.sleep(check_interval) +``` + +#### Job Management Workflow +```python +from data_juicer.utils.job.monitor import show_job_progress +from data_juicer.utils.job.stopper import stop_job +from data_juicer.utils.job.common import list_running_jobs + +# List all running jobs +running_jobs = list_running_jobs() +print(f"Found {len(running_jobs)} running jobs") + +# Monitor and potentially stop jobs +for job_info in running_jobs: + job_id = job_info['job_id'] + + # Check progress + try: + data = show_job_progress(job_id) + progress = data['overall_progress']['progress_percentage'] + + # Stop jobs that are stuck (less than 10% progress after 1 hour) + if progress < 10 and data['overall_progress']['elapsed_time_seconds'] > 3600: + print(f"Stopping stuck job {job_id} (progress: {progress:.1f}%)") + stop_job(job_id, force=True) + else: + print(f"Job {job_id}: {progress:.1f}% complete") + + except Exception as e: + print(f"Error monitoring job {job_id}: {e}") +``` + +## 🤖 Auto-Configuration System + +### **Smart Partition Sizing by Modality** + +DataJuicer now includes an intelligent auto-configuration system that automatically determines optimal partition sizes based on your data characteristics: + +#### **How It Works** + +1. **Modality Detection**: Analyzes your dataset to detect the primary modality (text, image, audio, video, multimodal) +2. **Dataset Analysis**: Examines sample characteristics (text length, media counts, file sizes) +3. **Pipeline Complexity**: Considers the complexity of your processing operations +4. **Resource Optimization**: Adjusts partition sizes for optimal memory usage and fault tolerance + +#### **Modality-Specific Optimizations** + +| Modality | Default Size | Max Size | Memory Multiplier | Use Case | +|----------|--------------|----------|-------------------|----------| +| **Text** | 5000 samples | 20000 | 1.0x | Efficient processing, low memory, target 64MB partitions | +| **Image** | 1000 samples | 5000 | 5.0x | Moderate memory, image processing, target 64MB partitions | +| **Audio** | 500 samples | 2000 | 8.0x | High memory, audio processing, target 64MB partitions | +| **Video** | 200 samples | 1000 | 20.0x | Very high memory, complex processing, target 64MB partitions | +| **Multimodal** | 800 samples | 3000 | 10.0x | Multiple modalities, moderate complexity, target 64MB partitions | + +#### **Enable Auto-Configuration** + +```yaml +partition: + mode: "auto" # Enable automatic optimization + # Fallback values used if auto-analysis fails + size: 5000 + max_size_mb: 64 +``` + +#### **Manual Override** + +```yaml +partition: + mode: "manual" # Use manual partition configuration + num_of_partitions: 8 # Specify exact number of partitions + # size and max_size_mb are ignored in manual mode +``` + +## 📊 Partition Sizing Guidelines + +### **Why Smaller Partitions Are Better** + +**Fault Tolerance**: Smaller partitions mean smaller units of failure. If a partition fails, you lose less work. + +**Recovery Speed**: Failed partitions can be retried faster, reducing overall job time. + +**Progress Visibility**: More granular progress tracking and faster feedback. + +**Memory Efficiency**: Lower memory usage per partition, better for resource-constrained environments. + +**Debugging**: Easier to isolate and debug issues in smaller chunks. 
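+
+To make the size/count trade-off above concrete, here is a minimal, hypothetical helper (not part of the DataJuicer API): it simply divides the dataset size by a target partition size in MB, in the spirit of the ~64 MB partition targets mentioned in the modality table above. Real auto-configuration also weighs modality, sample characteristics, and pipeline complexity, so treat this as a rough sketch only.
+
+```python
+import math
+
+def estimate_num_partitions(dataset_size_mb: float, target_partition_mb: int = 64) -> int:
+    """Rough heuristic: aim for partitions of roughly target_partition_mb each."""
+    return max(1, math.ceil(dataset_size_mb / target_partition_mb))
+
+# A 2 GB dataset with a 64 MB target -> 32 partitions.
+print(estimate_num_partitions(2048))        # 32
+# A larger target -> fewer partitions, less overhead, coarser recovery units.
+print(estimate_num_partitions(2048, 256))   # 8
+```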
+ +### **Partition Size Recommendations** + +| Use Case | Partition Size | When to Use | +|----------|---------------|-------------| +| **Debugging** | 50-100 samples | Quick iterations, testing, small datasets | +| **Production** ⭐ | 100-300 samples | Most use cases, good balance | +| **Large Datasets** | 300-500 samples | Stable processing, large datasets | +| **Very Large** | 500+ samples | Only when failure risk is minimal | + +### **Factors to Consider** + +- **Dataset Size**: Larger datasets can use larger partitions +- **Processing Complexity**: Complex operations benefit from smaller partitions +- **Failure Rate**: Higher failure rates need smaller partitions +- **Memory Constraints**: Limited memory requires smaller partitions +- **Time Sensitivity**: Faster feedback needs smaller partitions + +## 🔧 Implementation Details + +### Core Components + +1. **`EventLoggingMixin`** (`data_juicer/core/executor/event_logging_mixin.py`) + - Provides event logging capabilities to executors + - Manages job-specific directories and flexible storage + - Handles job summary creation and validation + - Implements Spark-style event logging schema + +2. **`PartitionedRayExecutor`** (`data_juicer/core/executor/ray_executor_partitioned.py`) + - Extends Ray executor with partitioning and fault tolerance + - Implements configurable checkpointing strategies + - Integrates with EventLoggingMixin for comprehensive logging + - Handles job resumption from checkpoints + +3. **Configuration Integration** (`data_juicer/config/config.py`) + - Added command-line arguments for job management + - Added checkpointing configuration options + - Added flexible storage path configuration + +### Event Types +- `JOB_START`, `JOB_COMPLETE`, `JOB_FAILED` +- `PARTITION_START`, `PARTITION_COMPLETE`, `PARTITION_FAILED` +- `OP_START`, `OP_COMPLETE`, `OP_FAILED` +- `CHECKPOINT_SAVE`, `CHECKPOINT_LOAD` +- `PROCESSING_START`, `PROCESSING_COMPLETE`, `PROCESSING_ERROR` +- `RESOURCE_USAGE`, `PERFORMANCE_METRIC` +- `WARNING`, `INFO`, `DEBUG` + +## 🎯 Use Cases + +### 1. Large Dataset Processing +- Process datasets that are too large for memory +- Automatic partitioning with fault tolerance +- Resume processing after failures + +### 2. Experimental Workflows +- Track different experiments with meaningful job IDs +- Compare results across different configurations +- Maintain experiment history and reproducibility + +### 3. Production Pipelines +- Robust error handling and recovery +- Comprehensive monitoring and logging +- Flexible storage for different performance requirements + +### 4. Research and Development +- Iterative development with checkpoint resumption +- Detailed event logging for analysis +- Configurable checkpointing for different scenarios + +## 🔍 Troubleshooting + +### Common Issues + +1. **Job resumption fails** + - Check if job summary exists: `ls -la ./outputs/{work_dir}/{job_id}/job_summary.json` + - Verify checkpoint files exist: `ls -la /tmp/large_checkpoints/{job_id}/` + +2. **Event logs not found** + - Check flexible storage paths: `ls -la /tmp/fast_event_logs/{job_id}/` + - Verify event logging is enabled in config + +3. **Checkpointing not working** + - Verify checkpoint strategy in config + - Check if checkpoint directory is writable + - Ensure checkpoint.enabled is true + +4. 
**Performance issues** + - Adjust partition size based on available memory + - Consider different checkpoint strategies + - Use appropriate storage formats (parquet for large datasets) + +### Debug Commands +```bash +# Check Ray cluster status +ray status + +# View Ray dashboard +open http://localhost:8265 + +# Check DataJuicer logs +tail -f /tmp/fast_event_logs/{job_id}/event_logs/events.log +``` + +## 📊 Understanding Intermediate Data + +### What is Intermediate Data? + +Intermediate data refers to temporary results generated during the processing pipeline that exist between operations and before the final output. In DataJuicer's partitioned processing, this includes: + +1. **Partition-level intermediate data**: Results after each operation within a partition +2. **Operation-level intermediate data**: Data that exists between operations (e.g., after `clean_links_mapper` but before `whitespace_normalization_mapper`) +3. **Checkpoint intermediate data**: Temporary files created during checkpointing + +### When to Preserve Intermediate Data + +**Enable `preserve_intermediate_data: true` when you need:** +- **Debugging**: Inspect what the data looks like after each operation +- **Resumption**: If a job fails, see exactly where it failed and what the data looked like +- **Analysis**: Understand how each operation transforms the data +- **Development**: Iterate on processing pipelines with detailed inspection + +**Disable `preserve_intermediate_data: false` when you want:** +- **Performance**: Faster processing with less disk I/O +- **Storage efficiency**: Reduced disk space usage +- **Production**: Clean processing without temporary file accumulation + +### Example Directory Structure with Intermediate Data + +``` +{job_dir}/intermediate/ +├── partition_000000/ +│ ├── op_000_clean_links_mapper.parquet # After clean_links_mapper +│ ├── op_001_clean_email_mapper.parquet # After clean_email_mapper +│ ├── op_002_whitespace_normalization_mapper.parquet +│ └── op_003_fix_unicode_mapper.parquet # After fix_unicode_mapper +└── partition_000001/ + ├── op_000_clean_links_mapper.parquet + └── ... +``` + +## 📈 Performance Considerations + +### Checkpointing Overhead +- `EVERY_OP`: Highest overhead, maximum resilience +- `EVERY_N_OPS`: Configurable overhead (balance between resilience and performance) +- `MANUAL`: Minimal overhead, requires careful planning +- `DISABLED`: No overhead, no resilience + +### Storage Recommendations +- **Event logs**: Use fast storage (SSD) for real-time monitoring +- **Checkpoints**: Use large capacity storage (HDD/network storage) for cost efficiency +- **Partitions**: Use local storage for processing speed + +### Memory Management +- Adjust `partition_size` based on available memory +- Use `max_partition_size_mb` to limit partition size +- Consider `preserve_intermediate_data` for debugging vs. 
performance + +## 🎉 Success Metrics + +The implementation successfully demonstrates: +- ✅ **Fault Tolerance**: Jobs can resume after failures +- ✅ **Scalability**: Handles large datasets through partitioning +- ✅ **Observability**: Comprehensive logging and monitoring +- ✅ **Flexibility**: Configurable checkpointing strategies +- ✅ **Usability**: Simple command-line interface with meaningful job IDs +- ✅ **Performance**: Fast resumption from checkpoints +- ✅ **Reliability**: Robust error handling and validation + +## 🔮 Future Enhancements + +Potential areas for future development: +- **Distributed checkpointing**: Multi-node checkpoint coordination +- **Incremental checkpointing**: Only save changed data +- **Checkpoint compression**: Reduce storage requirements +- **Advanced monitoring**: Web-based dashboard for job monitoring +- **Checkpoint versioning**: Support for multiple checkpoint versions +- **Integration with external systems**: Cloud storage, monitoring systems \ No newline at end of file diff --git a/docs/PartitionAndCheckpoint_ZH.md b/docs/PartitionAndCheckpoint_ZH.md new file mode 100644 index 0000000000..8c42953d3a --- /dev/null +++ b/docs/PartitionAndCheckpoint_ZH.md @@ -0,0 +1,801 @@ +# DataJuicer 容错处理与检查点和事件日志记录 + +本目录包含具有全面检查点、分区和事件日志记录功能的容错、可恢复 DataJuicer 处理的实现。 + +## 🚀 已实现功能 + +### ✅ 核心功能 +- **作业特定目录隔离**: 每个作业都有自己专用的目录结构 +- **可配置检查点策略**: 多种检查点频率和策略 +- **Spark 风格事件日志记录**: 用于可恢复性的 JSONL 格式全面事件跟踪 +- **作业恢复功能**: 从最后一个检查点恢复失败或中断的作业 +- **全面作业管理**: 作业摘要、元数据跟踪和恢复命令 + +### ✅ 检查点策略 +- `EVERY_OP`: 每个操作后检查点(最容错,较慢) +- `EVERY_PARTITION`: 仅在分区完成时检查点(平衡) +- `EVERY_N_OPS`: 每 N 个操作后检查点(可配置) +- `MANUAL`: 仅在指定操作后检查点 +- `DISABLED`: 完全禁用检查点 + +### ✅ 事件日志记录 +- **人类可读日志**: 基于 Loguru 的日志记录,用于调试和监控 +- **机器可读日志**: JSONL 格式,用于程序化分析和恢复 +- **全面事件类型**: 作业开始/完成/失败、分区事件、操作事件、检查点事件 +- **实时监控**: 实时事件流和状态报告 + +### ✅ 作业管理 +- **有意义的作业 ID**: 格式:`{YYYYMMDD}_{HHMMSS}_{config_name}_{unique_suffix}` +- **作业摘要文件**: 每个作业运行的全面元数据 +- **恢复命令**: 自动生成恢复作业的确切命令 +- **作业验证**: 验证作业恢复参数和现有状态 + +## 📁 目录结构 + +``` +{work_dir}/ +├── {job_id}/ # 作业特定目录 +│ ├── job_summary.json # 作业元数据和恢复信息(作业完成时创建) +│ ├── events_{timestamp}.jsonl # 机器可读事件(带时间戳的 JSONL 格式) +│ ├── dag_execution_plan.json # DAG 执行计划 +│ ├── partition-checkpoint-eventlog.yaml # 备份的配置文件 +│ ├── metadata/ # 作业元数据文件 +│ │ ├── dataset_mapping.json +│ │ └── final_mapping_report.json +│ ├── logs/ # 人类可读日志 +│ │ ├── export_processed.jsonl_time_*.txt # 主日志文件 +│ │ ├── export_processed.jsonl_time_*_DEBUG.txt # 调试级别日志 +│ │ ├── export_processed.jsonl_time_*_WARNING.txt # 警告级别日志 +│ │ └── export_processed.jsonl_time_*_ERROR.txt # 错误级别日志 +│ ├── checkpoints/ # 检查点数据 +│ │ ├── checkpoint_*.json # 检查点元数据 +│ │ └── partition_*/ # 分区检查点数据 +│ ├── partitions/ # 输入数据分区 +│ ├── processed.jsonl/ # 中间处理结果 +│ └── results/ # 最终处理结果 +``` + +## 🛠️ 配置 + +### 配置结构 + +配置使用**逻辑嵌套结构**,按关注点分组相关设置: + +#### 新的逻辑结构(推荐) +```yaml +# 分区配置 +partition: + size: 1000 # 每个分区的样本数 + max_size_mb: 64 # 分区最大大小(MB) + + + +# 中间存储配置(格式、压缩和生命周期管理) +intermediate_storage: + # 文件格式和压缩 + format: "parquet" # parquet, arrow, jsonl + compression: "snappy" # snappy, gzip, none + use_arrow_batches: true + arrow_batch_size: 500 + arrow_memory_mapping: false + + # 文件生命周期管理 + preserve_intermediate_data: true # 保留临时文件用于调试/恢复 + cleanup_temp_files: true + cleanup_on_success: false + retention_policy: "keep_all" # keep_all, keep_failed_only, cleanup_all + max_retention_days: 7 +``` + +#### 传统扁平结构(仍支持) +```yaml +# 传统扁平配置(仍有效) +partition_size: 1000 +max_partition_size_mb: 64 +preserve_intermediate_data: true +storage_format: "parquet" 
+use_arrow_batches: true +arrow_batch_size: 500 +arrow_memory_mapping: false +``` + +**注意**: 系统首先从新的嵌套部分读取,如果未找到则回退到传统扁平配置。 + +### 配置部分说明 + +#### `partition` - 分区和容错 +控制数据集如何分割以及如何处理故障: +- **自动配置**(推荐): + - `auto_configure`: 根据数据模态启用自动分区大小优化 +- **手动分区**(当 `auto_configure: false` 时): + - `size`: 每个分区的样本数 + - **50-100**: 调试、快速迭代、小数据集 + - **100-300**: 生产、容错和效率的良好平衡 ⭐ + - **300-500**: 具有稳定处理的大数据集 + - **500+**: 仅适用于故障风险最小的大数据集 + - `max_size_mb`: 分区最大大小(MB) + + +#### `intermediate_storage` - 中间数据管理 +控制中间数据的文件格式、压缩和生命周期管理: +- **文件格式和压缩**: + - `format`: 存储格式(`parquet`、`arrow`、`jsonl`) + - `compression`: 压缩算法(`snappy`、`gzip`、`none`) + - `use_arrow_batches`: 使用 Arrow 批处理 + - `arrow_batch_size`: Arrow 批大小 + - `arrow_memory_mapping`: 启用内存映射 +- **文件生命周期管理**: + - `preserve_intermediate_data`: 保留临时文件用于调试 + - `cleanup_temp_files`: 启用自动清理 + - `cleanup_on_success`: 即使成功完成也清理 + - `retention_policy`: 文件保留策略(`keep_all`、`keep_failed_only`、`cleanup_all`) + - `max_retention_days`: X 天后自动清理 + +### 基本配置 +```yaml +# 启用容错处理 +executor_type: ray_partitioned + +# 作业管理 +job_id: my_experiment_001 # 可选:如果未提供则自动生成 + +# 检查点配置 +checkpoint: + enabled: true + strategy: every_op # every_op, every_partition, every_n_ops, manual, disabled + n_ops: 2 # 用于 every_n_ops 策略 + op_names: # 用于 manual 策略 + - clean_links_mapper + - whitespace_normalization_mapper + +# 事件日志记录配置 +event_logging: + enabled: true + max_log_size_mb: 100 + backup_count: 5 + +# 分区配置 +partition: + # 基本分区设置 + # 推荐分区大小: + # - 50-100: 用于调试、快速迭代、小数据集 + # - 100-300: 用于生产、容错和效率的良好平衡 + # - 300-500: 用于具有稳定处理的大数据集 + # - 500+: 仅适用于故障风险最小的大数据集 + size: 200 # 每个分区的样本数(较小以获得更好的容错性) + max_size_mb: 32 # 分区最大大小(MB)(减少以加快处理速度) + + + +# 中间存储配置(格式、压缩和生命周期管理) +intermediate_storage: + # 文件格式和压缩 + format: "parquet" # parquet, arrow, jsonl + compression: "snappy" # snappy, gzip, none + use_arrow_batches: true + arrow_batch_size: 500 + arrow_memory_mapping: false + + # 文件生命周期管理 + preserve_intermediate_data: true # 保留临时文件用于调试/恢复 + cleanup_temp_files: true + cleanup_on_success: false + retention_policy: "keep_all" # keep_all, keep_failed_only, cleanup_all + max_retention_days: 7 +``` + +## 🚀 快速开始 + +### 1. 基本用法 +```bash +# 使用自动生成的作业 ID 运行 +dj-process --config configs/demo/checkpoint_config_example.yaml + +# 使用自定义作业 ID 运行 +dj-process --config configs/demo/checkpoint_config_example.yaml --job_id my_experiment_001 +``` + +### 2. 恢复作业 +```bash +# 使用作业 ID 恢复 +dj-process --config configs/demo/checkpoint_config_example.yaml --job_id my_experiment_001 +``` + +### 3. 不同的检查点策略 +```bash +# 每个分区检查点 +dj-process --config configs/demo/checkpoint_config_example.yaml --job_id partition_test --checkpoint.strategy every_partition + +# 每 3 个操作检查点 +dj-process --config configs/demo/checkpoint_config_example.yaml --job_id n_ops_test --checkpoint.strategy every_n_ops --checkpoint.n_ops 3 + +# 手动检查点 +dj-process --config configs/demo/checkpoint_config_example.yaml --job_id manual_test --checkpoint.strategy manual --checkpoint.op_names clean_links_mapper,whitespace_normalization_mapper +``` + +### 4. 
运行综合演示 +```bash +# 运行展示所有功能的完整演示 +python demos/partition_and_checkpoint/run_demo.py +``` + +## 📊 监控和调试 + +### 查看作业信息 +```bash +# 检查作业摘要(作业完成时创建) +cat ./outputs/partition-checkpoint-eventlog/{job_id}/job_summary.json + +# 查看事件日志(使用带时间戳的最新事件文件) +cat ./outputs/partition-checkpoint-eventlog/{job_id}/events_*.jsonl + +# 查看人类可读日志 +cat ./outputs/partition-checkpoint-eventlog/{job_id}/logs/export_processed.jsonl_time_*.txt + +# 查看 DAG 执行计划 +cat ./outputs/partition-checkpoint-eventlog/{job_id}/dag_execution_plan.json +``` + +### 列出可用作业 +```bash +# 列出所有作业目录 +ls -la ./outputs/partition-checkpoint-eventlog/ +``` + +### 检查作业结构 +```bash +# 检查作业目录结构 +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/ + +# 检查日志目录 +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/logs/ + +# 检查检查点目录 +ls -la ./outputs/partition-checkpoint-eventlog/{job_id}/checkpoints/ +``` + +## 📈 作业管理工具 + +DataJuicer 提供全面的作业管理工具,用于监控进度和停止正在运行的作业。这些工具位于 `data_juicer/utils/job/` 中,提供命令行和程序化接口。 + +### 📊 作业进度监控器 + +一个全面的工具,用于监控和显示 DataJuicer 作业的进度信息。显示分区状态、操作进度、检查点和整体作业指标。 + +#### 功能特性 + +- **实时进度跟踪**: 监控具有分区级详细信息的作业进度 +- **操作性能**: 查看详细的操作指标,包括吞吐量和数据减少 +- **检查点监控**: 跟踪检查点保存和恢复点 +- **监视模式**: 连续监控作业,自动更新 +- **程序化访问**: 作为 Python 函数使用,集成到其他工具中 + +#### 命令行用法 + +##### 基本用法 +```bash +# 显示作业的基本进度 +python -m data_juicer.utils.job.monitor 20250728_233517_510abf + +# 显示详细进度和操作指标 +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --detailed + +# 监视模式 - 每 10 秒连续更新进度 +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --watch + +# 监视模式,自定义更新间隔(30 秒) +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --watch --interval 30 + +# 使用自定义基础目录 +python -m data_juicer.utils.job.monitor 20250728_233517_510abf --base-dir /custom/path +``` + +##### 命令行选项 +- `job_id`: 要监控的作业 ID(必需) +- `--base-dir`: 包含作业输出的基础目录(默认:`outputs/partition-checkpoint-eventlog`) +- `--detailed`: 显示详细的操作信息 +- `--watch`: 监视模式 - 连续更新进度 +- `--interval`: 监视模式的更新间隔(秒)(默认:10) + +#### Python API + +##### 基本函数用法 +```python +from data_juicer.utils.job.monitor import show_job_progress + +# 显示进度并获取数据 +data = show_job_progress("20250728_233517_510abf") + +# 显示详细进度 +data = show_job_progress("20250728_233517_510abf", detailed=True) + +# 使用自定义基础目录 +data = show_job_progress("20250728_233517_510abf", base_dir="/custom/path") +``` + +##### 基于类的用法 +```python +from data_juicer.utils.job.monitor import JobProgressMonitor + +# 创建监控器实例 +monitor = JobProgressMonitor("20250728_233517_510abf") + +# 显示进度 +monitor.display_progress(detailed=True) + +# 获取进度数据作为字典 +data = monitor.get_progress_data() + +# 访问特定信息 +job_status = data['overall_progress']['job_status'] +progress_percentage = data['overall_progress']['progress_percentage'] +partition_status = data['partition_status'] +``` + +### 🛑 作业停止器 + +一个工具,通过读取事件日志来查找进程和线程 ID,然后终止这些特定的进程和线程来停止正在运行的 DataJuicer 作业。 + +#### 功能特性 + +- **精确进程终止**: 使用事件日志识别要终止的确切进程和线程 +- **优雅关闭**: 首先发送 SIGTERM 进行优雅关闭,然后在需要时发送 SIGKILL +- **安全检查**: 在停止前验证作业存在性和运行状态 +- **全面日志记录**: 终止过程的详细日志记录 +- **程序化访问**: 可以作为 Python 函数或命令行工具使用 + +#### 命令行用法 + +##### 基本用法 +```bash +# 优雅地停止作业(SIGTERM) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf + +# 强制停止作业(SIGKILL) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --force + +# 使用自定义超时停止(60 秒) +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --timeout 60 + +# 使用自定义基础目录 +python -m data_juicer.utils.job.stopper 20250728_233517_510abf --base-dir /custom/path + +# 列出所有正在运行的作业 +python -m data_juicer.utils.job.stopper --list +``` + +##### 命令行选项 +- `job_id`: 要停止的作业 ID(必需,除非使用 
--list) +- `--base-dir`: 包含作业输出的基础目录(默认:`outputs/partition-checkpoint-eventlog`) +- `--force`: 使用 SIGKILL 强制杀死而不是优雅的 SIGTERM +- `--timeout`: 优雅关闭的超时时间(秒)(默认:30) +- `--list`: 列出所有正在运行的作业而不是停止一个 + +#### Python API + +##### 基本函数用法 +```python +from data_juicer.utils.job.stopper import stop_job + +# 优雅地停止作业 +result = stop_job("20250728_233517_510abf") + +# 强制停止作业 +result = stop_job("20250728_233517_510abf", force=True) + +# 使用自定义超时停止 +result = stop_job("20250728_233517_510abf", timeout=60) + +# 使用自定义基础目录 +result = stop_job("20250728_233517_510abf", base_dir="/custom/path") +``` + +##### 基于类的用法 +```python +from data_juicer.utils.job.stopper import JobStopper + +# 创建停止器实例 +stopper = JobStopper("20250728_233517_510abf") + +# 停止作业 +result = stopper.stop_job(force=False, timeout=30) + +# 检查作业是否正在运行 +is_running = stopper.is_job_running() + +# 获取作业摘要 +summary = stopper.get_job_summary() +``` + +### 🔧 通用工具 + +监控器和停止器工具都通过 `data_juicer.utils.job.common` 共享通用功能: + +```python +from data_juicer.utils.job.common import JobUtils, list_running_jobs + +# 列出所有正在运行的作业 +running_jobs = list_running_jobs() + +# 创建作业工具实例 +job_utils = JobUtils("20250728_233517_510abf") + +# 加载作业摘要 +summary = job_utils.load_job_summary() + +# 加载事件日志 +events = job_utils.load_event_logs() + +# 获取分区状态 +partition_status = job_utils.get_partition_status() +``` + +### 输出信息 + +#### 作业概览 +- 作业状态(已完成、处理中、失败等) +- 数据集路径和大小 +- 分区配置 +- 开始时间和持续时间 + +#### 整体进度 +- 进度百分比 +- 分区完成状态 +- 样本处理计数 +- 估计剩余时间(对于运行中的作业) + +#### 分区状态 +- 带有视觉指示器的单个分区状态 +- 每个分区的样本计数 +- 当前操作(如果正在处理) +- 已完成操作的数量 +- 已保存检查点的数量 + +#### 操作详情(使用 --detailed 标志) +- 每个分区的操作性能 +- 持续时间、吞吐量和数据减少指标 +- 操作完成顺序 + +#### 检查点摘要 +- 已保存检查点的总数 +- 按分区和操作的检查点详情 +- 时间戳信息 + +### 示例输出 + +``` +================================================================================ +DataJuicer 作业进度监控器 +作业 ID: 20250728_233517_510abf +================================================================================ + +📊 作业概览 + 状态: 已完成 + 数据集: /Users/yilei.z/Downloads/c4-train.00000-of-01024.jsonl + 总样本数: 356,317 + 分区大小: 50,000 样本 + 开始时间: 2025-07-28 16:35:18 + 持续时间: 441.1 秒 + +🎯 整体进度 + 进度: 100.0% (8/8 分区) + 状态: 8 已完成, 0 处理中, 0 失败 + 样本: 356,317/356,317 + +📦 分区状态 + 分区 0: ✅ 已完成 + 样本: 44,539 + 已完成: 8 个操作 + 检查点: 2 个已保存 + 分区 1: ✅ 已完成 + 样本: 44,540 + 已完成: 8 个操作 + 检查点: 2 个已保存 + ... + +💾 检查点摘要 + 总检查点: 16 +``` + +### 集成示例 + +#### 监控多个作业 +```python +from data_juicer.utils.job.monitor import show_job_progress + +job_ids = ["job1", "job2", "job3"] +for job_id in job_ids: + try: + data = show_job_progress(job_id) + print(f"作业 {job_id}: {data['overall_progress']['progress_percentage']:.1f}%") + except FileNotFoundError: + print(f"作业 {job_id}: 未找到") +``` + +#### 自定义监控脚本 +```python +from data_juicer.utils.job.monitor import JobProgressMonitor +import time + +def monitor_job_until_completion(job_id, check_interval=30): + monitor = JobProgressMonitor(job_id) + + while True: + data = monitor.get_progress_data() + status = data['overall_progress']['job_status'] + + if status == 'completed': + print(f"作业 {job_id} 已完成!") + break + elif status == 'failed': + print(f"作业 {job_id} 失败!") + break + + print(f"作业 {job_id} 仍在运行... 
{data['overall_progress']['progress_percentage']:.1f}%") + time.sleep(check_interval) +``` + +#### 作业管理工作流 +```python +from data_juicer.utils.job.monitor import show_job_progress +from data_juicer.utils.job.stopper import stop_job +from data_juicer.utils.job.common import list_running_jobs + +# 列出所有正在运行的作业 +running_jobs = list_running_jobs() +print(f"发现 {len(running_jobs)} 个正在运行的作业") + +# 监控并可能停止作业 +for job_info in running_jobs: + job_id = job_info['job_id'] + + # 检查进度 + try: + data = show_job_progress(job_id) + progress = data['overall_progress']['progress_percentage'] + + # 停止卡住的作业(1小时后进度仍少于10%) + if progress < 10 and data['overall_progress']['elapsed_time_seconds'] > 3600: + print(f"停止卡住的作业 {job_id}(进度: {progress:.1f}%)") + stop_job(job_id, force=True) + else: + print(f"作业 {job_id}: {progress:.1f}% 完成") + + except Exception as e: + print(f"监控作业 {job_id} 时出错: {e}") +``` + +## 🤖 自动配置系统 + +### **按模态智能分区大小调整** + +DataJuicer 现在包含一个智能自动配置系统,可以根据您的数据特征自动确定最佳分区大小: + +#### **工作原理** + +1. **模态检测**: 分析您的数据集以检测主要模态(文本、图像、音频、视频、多模态) +2. **数据集分析**: 检查样本特征(文本长度、媒体数量、文件大小) +3. **管道复杂性**: 考虑处理操作的复杂性 +4. **资源优化**: 调整分区大小以获得最佳内存使用和容错性 + +#### **模态特定优化** + +| 模态 | 默认大小 | 最大大小 | 内存倍数 | 使用场景 | +|------|----------|----------|----------|----------| +| **文本** | 200 样本 | 1000 | 1.0x | 高效处理,低内存 | +| **图像** | 50 样本 | 200 | 5.0x | 中等内存,图像处理 | +| **音频** | 30 样本 | 100 | 8.0x | 高内存,音频处理 | +| **视频** | 10 样本 | 50 | 20.0x | 极高内存,复杂处理 | +| **多模态** | 20 样本 | 100 | 10.0x | 多种模态,中等复杂性 | + +#### **启用自动配置** + +```yaml +partition: + auto_configure: true # 启用自动优化 + # 当 auto_configure 为 true 时忽略手动设置 + size: 200 + max_size_mb: 32 +``` + +#### **手动覆盖** + +```yaml +partition: + auto_configure: false # 禁用自动配置 + size: 100 # 使用您自己的分区大小 + max_size_mb: 64 +``` + +## 📊 分区大小指南 + +### **为什么较小的分区更好** + +**容错性**: 较小的分区意味着较小的故障单元。如果分区失败,您损失的工作更少。 + +**恢复速度**: 失败的分区可以更快地重试,减少总体作业时间。 + +**进度可见性**: 更细粒度的进度跟踪和更快的反馈。 + +**内存效率**: 每个分区更低的内存使用,更适合资源受限的环境。 + +**调试**: 更容易隔离和调试较小块中的问题。 + +### **分区大小建议** + +| 使用场景 | 分区大小 | 何时使用 | +|----------|----------|----------| +| **调试** | 50-100 样本 | 快速迭代、测试、小数据集 | +| **生产** ⭐ | 100-300 样本 | 大多数用例,良好平衡 | +| **大数据集** | 300-500 样本 | 稳定处理,大数据集 | +| **超大** | 500+ 样本 | 仅在故障风险最小时 | + +### **需要考虑的因素** + +- **数据集大小**: 较大的数据集可以使用较大的分区 +- **处理复杂性**: 复杂操作受益于较小的分区 +- **故障率**: 较高的故障率需要较小的分区 +- **内存约束**: 有限的内存需要较小的分区 +- **时间敏感性**: 更快的反馈需要较小的分区 + +## 🔧 实现细节 + +### 核心组件 + +1. **`EventLoggingMixin`** (`data_juicer/core/executor/event_logging_mixin.py`) + - 为执行器提供事件日志记录功能 + - 管理作业特定目录和灵活存储 + - 处理作业摘要创建和验证 + - 实现 Spark 风格事件日志记录模式 + +2. **`PartitionedRayExecutor`** (`data_juicer/core/executor/ray_executor_partitioned.py`) + - 使用分区和容错扩展 Ray 执行器 + - 实现可配置检查点策略 + - 与 EventLoggingMixin 集成以进行全面日志记录 + - 处理从检查点恢复作业 + +3. **配置集成** (`data_juicer/config/config.py`) + - 添加了作业管理的命令行参数 + - 添加了检查点配置选项 + - 添加了灵活存储路径配置 + +### 事件类型 +- `JOB_START`, `JOB_COMPLETE`, `JOB_FAILED` +- `PARTITION_START`, `PARTITION_COMPLETE`, `PARTITION_FAILED` +- `OP_START`, `OP_COMPLETE`, `OP_FAILED` +- `CHECKPOINT_SAVE`, `CHECKPOINT_LOAD` +- `PROCESSING_START`, `PROCESSING_COMPLETE`, `PROCESSING_ERROR` +- `RESOURCE_USAGE`, `PERFORMANCE_METRIC` +- `WARNING`, `INFO`, `DEBUG` + +## 🎯 使用场景 + +### 1. 大数据集处理 +- 处理对于内存来说太大的数据集 +- 具有容错的自动分区 +- 故障后恢复处理 + +### 2. 实验工作流 +- 使用有意义的作业 ID 跟踪不同实验 +- 比较不同配置的结果 +- 维护实验历史和可重现性 + +### 3. 生产管道 +- 强大的错误处理和恢复 +- 全面监控和日志记录 +- 不同性能要求的灵活存储 + +### 4. 研究和开发 +- 具有检查点恢复的迭代开发 +- 用于分析的详细事件日志记录 +- 不同场景的可配置检查点 + +## 🔍 故障排除 + +### 常见问题 + +1. 
**作业恢复失败** + - 检查作业摘要是否存在:`ls -la ./outputs/{work_dir}/{job_id}/job_summary.json` + - 验证检查点文件是否存在:`ls -la /tmp/large_checkpoints/{job_id}/` + +2. **找不到事件日志** + - 检查灵活存储路径:`ls -la /tmp/fast_event_logs/{job_id}/` + - 验证配置中是否启用了事件日志记录 + +3. **检查点不工作** + - 验证配置中的检查点策略 + - 检查检查点目录是否可写 + - 确保 checkpoint.enabled 为 true + +4. **性能问题** + - 根据可用内存调整分区大小 + - 考虑不同的检查点策略 + - 使用适当的存储格式(大数据集使用 parquet) + +### 调试命令 +```bash +# 检查 Ray 集群状态 +ray status + +# 查看 Ray 仪表板 +open http://localhost:8265 + +# 检查 DataJuicer 日志 +tail -f /tmp/fast_event_logs/{job_id}/event_logs/events.log +``` + +## 📊 理解中间数据 + +### 什么是中间数据? + +中间数据是指在处理管道期间生成的临时结果,存在于操作之间和最终输出之前。在 DataJuicer 的分区处理中,这包括: + +1. **分区级中间数据**: 分区内每个操作后的结果 +2. **操作级中间数据**: 操作之间存在的数据(例如,在 `clean_links_mapper` 之后但在 `whitespace_normalization_mapper` 之前) +3. **检查点中间数据**: 检查点期间创建的临时文件 + +### 何时保留中间数据 + +**当您需要以下功能时启用 `preserve_intermediate_data: true`:** +- **调试**: 检查每个操作后数据的样貌 +- **恢复**: 如果作业失败,查看确切失败位置和数据样貌 +- **分析**: 了解每个操作如何转换数据 +- **开发**: 通过详细检查迭代处理管道 + +**当您想要以下功能时禁用 `preserve_intermediate_data: false`:** +- **性能**: 更快的处理,更少的磁盘 I/O +- **存储效率**: 减少磁盘空间使用 +- **生产**: 无临时文件累积的清洁处理 + +### 带有中间数据的目录结构示例 + +``` +{job_dir}/intermediate/ +├── partition_000000/ +│ ├── op_000_clean_links_mapper.parquet # clean_links_mapper 之后 +│ ├── op_001_clean_email_mapper.parquet # clean_email_mapper 之后 +│ ├── op_002_whitespace_normalization_mapper.parquet +│ └── op_003_fix_unicode_mapper.parquet # fix_unicode_mapper 之后 +└── partition_000001/ + ├── op_000_clean_links_mapper.parquet + └── ... +``` + +## 📈 性能考虑 + +### 检查点开销 +- `EVERY_OP`: 最高开销,最大容错性 +- `EVERY_PARTITION`: 平衡的开销和容错性 +- `EVERY_N_OPS`: 可配置开销 +- `MANUAL`: 最小开销,需要仔细规划 + +### 存储建议 +- **事件日志**: 使用快速存储(SSD)进行实时监控 +- **检查点**: 使用大容量存储(HDD/网络存储)以提高成本效率 +- **分区**: 使用本地存储以提高处理速度 + +### 内存管理 +- 根据可用内存调整 `partition_size` +- 使用 `max_partition_size_mb` 限制分区大小 +- 考虑 `preserve_intermediate_data` 用于调试与性能 + +## 🎉 成功指标 + +实现成功展示了: +- ✅ **容错性**: 作业可以在故障后恢复 +- ✅ **可扩展性**: 通过分区处理大数据集 +- ✅ **可观察性**: 全面日志记录和监控 +- ✅ **灵活性**: 可配置检查点策略 +- ✅ **可用性**: 具有有意义的作业 ID 的简单命令行界面 +- ✅ **性能**: 从检查点快速恢复 +- ✅ **可靠性**: 强大的错误处理和验证 + +## 🔮 未来增强 + +未来开发的潜在领域: +- **分布式检查点**: 多节点检查点协调 +- **增量检查点**: 仅保存更改的数据 +- **检查点压缩**: 减少存储要求 +- **高级监控**: 用于作业监控的基于 Web 的仪表板 +- **检查点版本控制**: 支持多个检查点版本 +- **与外部系统集成**: 云存储、监控系统 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fcb6ede4e0..eb61ae3c45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ distributed = [ "uvloop==0.21.0", # avoid async error before it's fixed in uvloop "pyspark==3.5.5", # distributed data processing "s3fs", # S3 filesystem support for cloud storage + "boto3", # AWS SDK for S3 operations "bitarray", # efficient arrays of booleans ] @@ -209,6 +210,7 @@ extend-ignore = [ "E203", # whitespace before ':' (black handles this) "E501", # line too long (black handles this) "BLK100", # black would make changes (black handles this) + "F541", # f-string is missing placeholders ] [tool.black] @@ -217,3 +219,14 @@ target-version = ['py310'] [tool.isort] profile = "black" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["*Test"] +python_functions = ["test_*"] +addopts = "-v" +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::PendingDeprecationWarning", +] diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 425e256c2b..ff8036b261 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -2,12 +2,14 @@ import sys import copy import unittest +import tempfile +import yaml 
from contextlib import redirect_stdout, redirect_stderr from io import StringIO from jsonargparse import Namespace, namespace_to_dict -from data_juicer.config import init_configs, get_default_cfg, update_op_attr, export_config, merge_config, prepare_side_configs +from data_juicer.config import init_configs, get_default_cfg, validate_work_dir_config, resolve_job_id, resolve_job_directories, update_op_attr, export_config, merge_config, prepare_side_configs from data_juicer.ops import load_ops from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG from data_juicer.utils.constant import RAY_JOB_ENV_VAR @@ -63,6 +65,9 @@ def test_yaml_cfg_file(self): cfg = init_configs(args=f'--config {test_yaml_path}'.split()) self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'test_demo') + + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir = cfg.work_dir self.assertDictEqual( cfg.process[0], { 'whitespace_normalization_mapper': { @@ -85,7 +90,7 @@ def test_yaml_cfg_file(self): 'turbo': False, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -123,7 +128,7 @@ def test_yaml_cfg_file(self): 'num_gpus': None, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -182,6 +187,8 @@ def test_mixture_cfg(self): '--language_id_score_filter.lang=en ' '--language_id_score_filter.min_score=0.5'.split()) print(f'ori_cfg.process[1] = {ori_cfg.process[1]}') + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir = ori_cfg.work_dir self.assertDictEqual( ori_cfg.process[1], { 'language_id_score_filter': { @@ -210,7 +217,7 @@ def test_mixture_cfg(self): 'turbo': False, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -220,6 +227,8 @@ def test_mixture_cfg(self): 'auto_op_parallelism': True } }) + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir_1 = mixed_cfg_1.work_dir self.assertDictEqual( mixed_cfg_1.process[1], { 'language_id_score_filter': { @@ -248,7 +257,7 @@ def test_mixture_cfg(self): 'num_gpus': None, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir_1, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -258,6 +267,8 @@ def test_mixture_cfg(self): 'auto_op_parallelism': True } }) + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir_2 = mixed_cfg_2.work_dir self.assertDictEqual( mixed_cfg_2.process[1], { 'language_id_score_filter': { @@ -286,7 +297,7 @@ def test_mixture_cfg(self): 'num_gpus': None, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir_2, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -296,6 +307,8 @@ def test_mixture_cfg(self): 'auto_op_parallelism': True } }) + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir_3 = mixed_cfg_3.work_dir self.assertDictEqual( mixed_cfg_3.process[1], { 'language_id_score_filter': { @@ -324,7 +337,7 @@ def test_mixture_cfg(self): 'num_gpus': None, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir_3, 'cpu_required': None, 
'gpu_required': None, 'mem_required': None, @@ -334,6 +347,8 @@ def test_mixture_cfg(self): 'auto_op_parallelism': True } }) + # work_dir now includes job_id suffix due to resolve_job_directories + expected_work_dir_4 = mixed_cfg_4.work_dir self.assertDictEqual( mixed_cfg_4.process[1], { 'language_id_score_filter': { @@ -362,7 +377,7 @@ def test_mixture_cfg(self): 'num_gpus': None, 'index_key': None, 'skip_op_error': True, - 'work_dir': WORKDIR, + 'work_dir': expected_work_dir_4, 'cpu_required': None, 'gpu_required': None, 'mem_required': None, @@ -746,6 +761,317 @@ def process_single(self, data): os.environ[RAY_JOB_ENV_VAR] = "0" + def test_validate_work_dir_config_valid_cases(self): + """Test validate_work_dir_config with valid configurations.""" + valid_configs = [ + './outputs/my_project/{job_id}', + '/data/experiments/{job_id}', + 'outputs/{job_id}', + './{job_id}', + 'C:/data/projects/{job_id}', + '/home/user/data/{job_id}', + 'relative/path/to/{job_id}', + '{job_id}', # Just job_id alone + ] + + for work_dir in valid_configs: + with self.subTest(work_dir=work_dir): + # Should not raise any exception + validate_work_dir_config(work_dir) + + def test_validate_work_dir_config_invalid_cases(self): + """Test validate_work_dir_config with invalid configurations.""" + invalid_configs = [ + './outputs/{job_id}/results', + './{job_id}/outputs/data', + 'outputs/{job_id}/intermediate/stuff', + 'data/{job_id}/processed/results', + '/home/user/{job_id}/data/outputs', + 'C:/data/{job_id}/projects/results', + 'relative/{job_id}/path/to/data', + 'outputs/data/{job_id}/processed', + ] + + for work_dir in invalid_configs: + with self.subTest(work_dir=work_dir): + with self.assertRaises(ValueError) as cm: + validate_work_dir_config(work_dir) + + # Check that the error message is helpful + error_msg = str(cm.exception) + self.assertIn('{job_id}', error_msg) + self.assertIn('must be the last part', error_msg) + self.assertIn('Expected format', error_msg) + + def test_validate_work_dir_config_no_job_id(self): + """Test validate_work_dir_config with configurations that don't contain {job_id}.""" + no_job_id_configs = [ + './outputs/my_project', + '/data/experiments', + 'outputs', + './', + 'C:/data/projects', + '/home/user/data', + 'relative/path/to', + '', # Empty string + ] + + for work_dir in no_job_id_configs: + with self.subTest(work_dir=work_dir): + # Should not raise any exception + validate_work_dir_config(work_dir) + + def test_resolve_job_id_with_placeholder(self): + """Test resolve_job_id when {job_id} placeholder is present.""" + cfg = Namespace() + cfg.work_dir = './outputs/my_project/{job_id}' + cfg.export_path = './outputs/{job_id}/results.jsonl' + + # Should auto-generate job_id + cfg = resolve_job_id(cfg) + + self.assertIsNotNone(cfg.job_id) + self.assertFalse(cfg._user_provided_job_id) + self.assertIsInstance(cfg.job_id, str) + # Job ID should be in format: YYYYMMDD_HHMMSS_xxxxxx + self.assertRegex(cfg.job_id, r'^\d{8}_\d{6}_[a-f0-9]{6}$') + + def test_resolve_job_id_without_placeholder(self): + """Test resolve_job_id when no {job_id} placeholder is present.""" + cfg = Namespace() + cfg.work_dir = './outputs/my_project' + cfg.export_path = './outputs/results.jsonl' + + # Should still auto-generate job_id (fallback behavior) + cfg = resolve_job_id(cfg) + + self.assertIsNotNone(cfg.job_id) + self.assertFalse(cfg._user_provided_job_id) + self.assertIsInstance(cfg.job_id, str) + self.assertRegex(cfg.job_id, r'^\d{8}_\d{6}_[a-f0-9]{6}$') + + def test_resolve_job_id_user_provided(self): + 
"""Test resolve_job_id when user provides job_id.""" + cfg = Namespace() + cfg.job_id = 'my_custom_job_123' + cfg.work_dir = './outputs/my_project/{job_id}' + + cfg = resolve_job_id(cfg) + + self.assertEqual(cfg.job_id, 'my_custom_job_123') + self.assertTrue(cfg._user_provided_job_id) + + def test_resolve_job_directories_with_job_id_at_end(self): + """Test resolve_job_directories when {job_id} is at the end of work_dir.""" + cfg = Namespace() + cfg.work_dir = './outputs/my_project/{job_id}' + cfg.job_id = '20250804_143022_abc123' + + cfg = resolve_job_directories(cfg) + + # work_dir should be substituted + self.assertEqual(cfg.work_dir, './outputs/my_project/20250804_143022_abc123') + # Other directories should be under job_dir + self.assertEqual(cfg.event_log_dir, './outputs/my_project/20250804_143022_abc123/logs') + self.assertEqual(cfg.checkpoint_dir, './outputs/my_project/20250804_143022_abc123/checkpoints') + self.assertEqual(cfg.partition_dir, './outputs/my_project/20250804_143022_abc123/partitions') + self.assertEqual(cfg.metadata_dir, './outputs/my_project/20250804_143022_abc123/metadata') + self.assertEqual(cfg.results_dir, './outputs/my_project/20250804_143022_abc123/results') + self.assertEqual(cfg.event_log_file, './outputs/my_project/20250804_143022_abc123/events.jsonl') + + def test_resolve_job_directories_without_job_id_placeholder(self): + """Test resolve_job_directories when work_dir doesn't contain {job_id}.""" + cfg = Namespace() + cfg.job_id = '20250804_143022_abc123' + cfg.work_dir = './outputs/my_project' + cfg = resolve_job_directories(cfg) + + self.assertEqual(cfg.work_dir, './outputs/my_project/20250804_143022_abc123') + self.assertEqual(cfg.event_log_dir, './outputs/my_project/20250804_143022_abc123/logs') + self.assertEqual(cfg.checkpoint_dir, './outputs/my_project/20250804_143022_abc123/checkpoints') + + def test_resolve_job_directories_placeholder_substitution(self): + """Test that placeholders are properly substituted in all relevant paths.""" + cfg = Namespace() + cfg.work_dir = './outputs/{job_id}' + cfg.export_path = '{work_dir}/results.jsonl' + cfg.event_log_dir = '{work_dir}/logs' + cfg.checkpoint_dir = '{work_dir}/checkpoints' + cfg.partition_dir = '{work_dir}/partitions' + cfg.job_id = '20250804_143022_abc123' + + cfg = resolve_job_directories(cfg) + + # All placeholders should be substituted + self.assertEqual(cfg.work_dir, './outputs/20250804_143022_abc123') + self.assertEqual(cfg.export_path, './outputs/20250804_143022_abc123/results.jsonl') + # Note: event_log_dir is overridden by the system to use standard 'logs' directory + self.assertEqual(cfg.event_log_dir, './outputs/20250804_143022_abc123/logs') + self.assertEqual(cfg.checkpoint_dir, './outputs/20250804_143022_abc123/checkpoints') + self.assertEqual(cfg.partition_dir, './outputs/20250804_143022_abc123/partitions') + self.assertEqual(cfg.metadata_dir, './outputs/20250804_143022_abc123/metadata') + self.assertEqual(cfg.results_dir, './outputs/20250804_143022_abc123/results') + self.assertEqual(cfg.event_log_file, './outputs/20250804_143022_abc123/events.jsonl') + + def test_resolve_job_directories_missing_job_id(self): + """Test resolve_job_directories when job_id is not set.""" + cfg = Namespace() + cfg.work_dir = './outputs/my_project' + + with self.assertRaises(ValueError) as cm: + resolve_job_directories(cfg) + + self.assertIn('job_id must be set', str(cm.exception)) + + def test_resolve_job_directories_invalid_work_dir(self): + """Test resolve_job_directories with invalid work_dir 
containing {job_id} in middle.""" + cfg = Namespace() + cfg.work_dir = './outputs/{job_id}/results' + cfg.job_id = '20250804_143022_abc123' + + with self.assertRaises(ValueError) as cm: + resolve_job_directories(cfg) + + error_msg = str(cm.exception) + self.assertIn('{job_id}', error_msg) + self.assertIn('must be the last part', error_msg) + + def test_full_config_loading_with_job_id_placeholder(self): + """Test full config loading with {job_id} placeholder in work_dir.""" + # Create a temporary config file + config_data = { + 'dataset_path': './demos/data/demo-dataset.jsonl', + 'work_dir': './outputs/test_project/{job_id}', + 'export_path': '{work_dir}/results.jsonl', + 'process': [ + {'whitespace_normalization_mapper': {'text_key': 'text'}} + ] + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + temp_config_path = f.name + + try: + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=['--config', temp_config_path]) + + # Verify job_id was auto-generated + self.assertIsNotNone(cfg.job_id) + self.assertRegex(cfg.job_id, r'^\d{8}_\d{6}_[a-f0-9]{6}$') + + # Verify work_dir was substituted + self.assertIn(cfg.job_id, cfg.work_dir) + self.assertNotIn('{job_id}', cfg.work_dir) + + # Verify export_path was substituted + self.assertIn(cfg.job_id, cfg.export_path) + self.assertNotIn('{work_dir}', cfg.export_path) + + finally: + os.unlink(temp_config_path) + + def test_full_config_loading_without_job_id_placeholder(self): + """Test full config loading without {job_id} placeholder in work_dir.""" + # Create a temporary config file + config_data = { + 'dataset_path': './demos/data/demo-dataset.jsonl', + 'work_dir': './outputs/test_project', + 'export_path': '{work_dir}/results.jsonl', + 'process': [ + {'whitespace_normalization_mapper': {'text_key': 'text'}} + ] + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + temp_config_path = f.name + + try: + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=['--config', temp_config_path]) + + # Verify job_id was auto-generated + self.assertIsNotNone(cfg.job_id) + self.assertRegex(cfg.job_id, r'^\d{8}_\d{6}_[a-f0-9]{6}$') + + # Verify work_dir + self.assertEqual(cfg.work_dir, f'./outputs/test_project/{cfg.job_id}') + + # Note: When there's no {job_id} placeholder, {work_dir} in export_path is still substituted + # The system substitutes {work_dir} with the actual work_dir value + self.assertNotIn('{work_dir}', cfg.export_path) + self.assertIn('./outputs/test_project', cfg.export_path) + self.assertNotIn(cfg.job_id, cfg.export_path) + + finally: + os.unlink(temp_config_path) + + def test_full_config_loading_invalid_work_dir(self): + """Test full config loading with invalid work_dir containing {job_id} in middle.""" + # Create a temporary config file with invalid work_dir + config_data = { + 'dataset_path': './demos/data/demo-dataset.jsonl', + 'work_dir': './outputs/{job_id}/results', # Invalid: {job_id} not at end + 'export_path': '{work_dir}/results.jsonl', + 'process': [ + {'whitespace_normalization_mapper': {'text_key': 'text'}} + ] + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + temp_config_path = f.name + + try: + out = StringIO() + with redirect_stdout(out), redirect_stderr(out): + with self.assertRaises(ValueError) as cm: + init_configs(args=['--config', temp_config_path]) + + error_msg = str(cm.exception) + 
self.assertIn('{job_id}', error_msg) + self.assertIn('must be the last part', error_msg) + + finally: + os.unlink(temp_config_path) + + def test_user_provided_job_id(self): + """Test config loading with user-provided job_id.""" + # Create a temporary config file + config_data = { + 'dataset_path': './demos/data/demo-dataset.jsonl', + 'work_dir': './outputs/test_project/{job_id}', + 'export_path': '{work_dir}/results.jsonl', + 'process': [ + {'whitespace_normalization_mapper': {'text_key': 'text'}} + ] + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + temp_config_path = f.name + + try: + out = StringIO() + with redirect_stdout(out): + # Test with user-provided job_id + cfg = init_configs(args=[ + '--config', temp_config_path, + '--job_id', 'my_custom_job_123' + ]) + + # Verify user-provided job_id was used + self.assertEqual(cfg.job_id, 'my_custom_job_123') + self.assertTrue(cfg._user_provided_job_id) + + # Verify work_dir was substituted + self.assertEqual(cfg.work_dir, './outputs/test_project/my_custom_job_123') + + finally: + os.unlink(temp_config_path) if __name__ == '__main__': unittest.main() diff --git a/tests/core/executor/test_dag.py b/tests/core/executor/test_dag.py new file mode 100644 index 0000000000..a326031158 --- /dev/null +++ b/tests/core/executor/test_dag.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Tests for DAG Execution functionality. + +This module tests the strategy-based DAG execution planning +capabilities of the Data-Juicer system. +""" + +import os +import tempfile +import unittest + +from data_juicer.core.pipeline_dag import PipelineDAG, DAGNodeStatus +from data_juicer.core.executor.dag_execution_strategies import ( + NonPartitionedDAGStrategy, + PartitionedDAGStrategy, + is_global_operation +) +from data_juicer.ops import load_ops + + +# Note: PipelineAST tests removed - AST functionality was removed in favor of strategy-based DAG building + + +class TestPipelineDAG(unittest.TestCase): + """Test DAG execution planning functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.temp_dir) + self.sample_config = { + "process": [ + {"text_length_filter": {"min_len": 10, "max_len": 1000}}, + {"character_repetition_filter": {"rep_len": 3}}, + {"words_num_filter": {"min_num": 5, "max_num": 1000}}, + {"language_id_score_filter": {"lang": "en", "min_score": 0.8}}, + {"document_deduplicator": {}}, + {"clean_email_mapper": {}}, + {"clean_links_mapper": {}}, + ] + } + + def tearDown(self): + """Clean up test fixtures.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _build_dag_from_config(self): + """Helper method to build DAG from config using strategy-based approach.""" + # Load operations from config + operations = load_ops(self.sample_config["process"]) + + # Create strategy and build DAG + strategy = NonPartitionedDAGStrategy() + nodes = strategy.generate_dag_nodes(operations) + strategy.build_dependencies(nodes, operations) + + # Assign nodes to DAG + self.dag.nodes = nodes + + def test_dag_build_from_strategy(self): + """Test building DAG using strategy-based approach.""" + self._build_dag_from_config() + + self.assertGreater(len(self.dag.nodes), 0) + # Note: execution_plan is not populated by strategies currently + # self.assertGreater(len(self.dag.execution_plan), 0) + + def test_dag_execution_plan_save_load(self): + """Test saving and loading execution plans.""" + 
self._build_dag_from_config() + + # Save execution plan + plan_path = self.dag.save_execution_plan() + self.assertTrue(os.path.exists(plan_path)) + + # Load execution plan + new_dag = PipelineDAG(self.temp_dir) + success = new_dag.load_execution_plan() + self.assertTrue(success) + self.assertEqual(len(new_dag.nodes), len(self.dag.nodes)) + + def test_dag_visualization(self): + """Test DAG visualization.""" + self._build_dag_from_config() + + viz = self.dag.visualize() + self.assertIsInstance(viz, str) + self.assertIn("DAG Execution Plan", viz) + + def test_dag_node_status_management(self): + """Test DAG node status management.""" + self._build_dag_from_config() + + # Get first node + first_node_id = list(self.dag.nodes.keys())[0] + + # Test status transitions + self.dag.mark_node_started(first_node_id) + # Check status for dict nodes + node = self.dag.nodes[first_node_id] + if isinstance(node, dict): + self.assertEqual(node["status"], DAGNodeStatus.RUNNING.value) + else: + self.assertEqual(node.status, DAGNodeStatus.RUNNING) + + self.dag.mark_node_completed(first_node_id, 1.5) + # Check status for dict nodes + node = self.dag.nodes[first_node_id] + if isinstance(node, dict): + self.assertEqual(node["status"], DAGNodeStatus.COMPLETED.value) + self.assertEqual(node["actual_duration"], 1.5) + else: + self.assertEqual(node.status, DAGNodeStatus.COMPLETED) + self.assertEqual(node.actual_duration, 1.5) + + def test_dag_execution_summary(self): + """Test DAG execution summary generation.""" + self._build_dag_from_config() + + summary = self.dag.get_execution_summary() + + self.assertIn("total_nodes", summary) + self.assertIn("completed_nodes", summary) + self.assertIn("pending_nodes", summary) + self.assertIn("parallel_groups_count", summary) + + +class TestDAGExecutionStrategies(unittest.TestCase): + """Test DAG execution strategies.""" + + def setUp(self): + """Set up test fixtures.""" + # Create mock operations + class MockOperation: + def __init__(self, name): + self._name = name + + self.operations = [ + MockOperation("text_length_filter"), + MockOperation("character_repetition_filter"), + MockOperation("document_deduplicator"), + MockOperation("text_cleaning_mapper"), + ] + + def test_non_partitioned_strategy(self): + """Test non-partitioned execution strategy.""" + strategy = NonPartitionedDAGStrategy() + + # Generate nodes + nodes = strategy.generate_dag_nodes(self.operations) + self.assertEqual(len(nodes), 4) + + # Test node ID generation + node_id = strategy.get_dag_node_id("text_length_filter", 0) + self.assertEqual(node_id, "op_001_text_length_filter") + + # Test dependency building + strategy.build_dependencies(nodes, self.operations) + self.assertGreater(len(nodes["op_002_character_repetition_filter"]["dependencies"]), 0) + + def test_partitioned_strategy(self): + """Test partitioned execution strategy.""" + strategy = PartitionedDAGStrategy(num_partitions=2) + + # Generate nodes + nodes = strategy.generate_dag_nodes(self.operations) + self.assertGreater(len(nodes), 4) # Should have partition-specific nodes + + # Test node ID generation + node_id = strategy.get_dag_node_id("text_length_filter", 0, partition_id=1) + self.assertEqual(node_id, "op_001_text_length_filter_partition_1") + + def test_global_operation_detection(self): + """Test global operation detection.""" + class MockDeduplicator: + def __init__(self): + self._name = "document_deduplicator" + + class MockFilter: + def __init__(self): + self._name = "text_length_filter" + + deduplicator = MockDeduplicator() + filter_op = 
MockFilter() + + self.assertTrue(is_global_operation(deduplicator)) + self.assertFalse(is_global_operation(filter_op)) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/executor/test_ray_executor_partitioned.py b/tests/core/executor/test_ray_executor_partitioned.py new file mode 100644 index 0000000000..a3216843f9 --- /dev/null +++ b/tests/core/executor/test_ray_executor_partitioned.py @@ -0,0 +1,365 @@ +import os +import tempfile +import unittest +from data_juicer.core.executor.ray_executor_partitioned import PartitionedRayExecutor +from data_juicer.config import init_configs +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG + + +class PartitionedRayExecutorTest(DataJuicerTestCaseBase): + root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', '..') + + def setUp(self) -> None: + super().setUp() + # Create temporary directory + self.tmp_dir = tempfile.mkdtemp(prefix='test_ray_executor_partitioned_') + + def tearDown(self) -> None: + super().tearDown() + # Clean up temporary directory + import shutil + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + @TEST_TAG('ray') + def test_end2end_execution_manual_partitioning(self): + """Test end-to-end execution with manual partitioning mode.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '2' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_end2end_execution_manual', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_end2end_execution_manual') + executor = PartitionedRayExecutor(cfg) + executor.run() + + # check result files + self.assertTrue(os.path.exists(cfg.export_path)) + + @TEST_TAG('ray') + def test_end2end_execution_with_checkpointing(self): + """Test end-to-end execution with checkpointing enabled.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '2', + '--checkpoint.enabled', 'true', + '--checkpoint.strategy', 'every_op' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_end2end_execution_checkpointing', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_end2end_execution_checkpointing') + executor = PartitionedRayExecutor(cfg) + executor.run() + + # check result files + self.assertTrue(os.path.exists(cfg.export_path)) + + # check checkpoint directory exists + checkpoint_dir = cfg.checkpoint_dir + self.assertTrue(os.path.exists(checkpoint_dir)) + + # check that checkpoint files were created + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.parquet')] + self.assertGreater(len(checkpoint_files), 0, "No checkpoint files were created") + + # verify checkpoint file naming convention + for checkpoint_file in checkpoint_files: + self.assertTrue(checkpoint_file.startswith('checkpoint_op_'), + f"Checkpoint file {checkpoint_file} doesn't follow naming convention") + self.assertTrue('_partition_' in checkpoint_file, + f"Checkpoint file {checkpoint_file} doesn't contain partition info") + self.assertTrue(checkpoint_file.endswith('.parquet'), + f"Checkpoint file {checkpoint_file} doesn't have .parquet extension") + + # test checkpoint loading functionality + executor2 = PartitionedRayExecutor(cfg) + + # test find_latest_checkpoint method (on checkpoint manager) + for partition_id in 
range(2): + latest_checkpoint = executor2.ckpt_manager.find_latest_checkpoint(partition_id) + if latest_checkpoint: + op_idx, _, checkpoint_path = latest_checkpoint + self.assertIsInstance(op_idx, int) + self.assertTrue(os.path.exists(checkpoint_path)) + self.assertTrue(checkpoint_path.endswith('.parquet')) + + # test resolve_checkpoint_filename method (on checkpoint manager) + test_filename = executor2.ckpt_manager.resolve_checkpoint_filename(0, 1) + expected_pattern = 'checkpoint_op_0000_partition_0001.parquet' + self.assertEqual(test_filename, expected_pattern) + + + @TEST_TAG('ray') + def test_dag_execution_initialization(self): + """Test DAG execution initialization and strategy selection.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '4' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_dag_initialization', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_dag_initialization') + + executor = PartitionedRayExecutor(cfg) + + # Test DAG initialization + executor._initialize_dag_execution(cfg) + + # Verify DAG is initialized + self.assertIsNotNone(executor.pipeline_dag) + self.assertIsNotNone(executor.dag_execution_strategy) + + # Verify partitioned strategy is used + from data_juicer.core.executor.dag_execution_strategies import PartitionedDAGStrategy + self.assertIsInstance(executor.dag_execution_strategy, PartitionedDAGStrategy) + + # Verify DAG nodes are created + self.assertGreater(len(executor.pipeline_dag.nodes), 0) + + @TEST_TAG('ray') + def test_convergence_point_detection(self): + """Test convergence point detection for global operations.""" + # Create a simple config without loading from file + from jsonargparse import Namespace + cfg = Namespace() + cfg.process = [ + {'text_length_filter': {'min_len': 10}}, + {'text_length_filter': {'max_len': 1000}} + ] + cfg.job_id = 'test_convergence_123' # Required for event logging + cfg.work_dir = os.path.join(self.tmp_dir, 'test_convergence') + cfg.event_logging = {'enabled': False} # Disable event logging for this test + + # Create executor without running full initialization + executor = PartitionedRayExecutor.__new__(PartitionedRayExecutor) + executor.cfg = cfg + executor.executor_type = 'ray_partitioned' + executor.work_dir = cfg.work_dir + executor.num_partitions = 2 + + # Initialize only the necessary components + from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin + from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin + EventLoggingMixin.__init__(executor, cfg) + DAGExecutionMixin.__init__(executor) + executor._override_strategy_methods() + + convergence_points = executor._detect_convergence_points(cfg) + + # Should not detect any convergence points for non-global operations + self.assertEqual(len(convergence_points), 0) + + @TEST_TAG('ray') + def test_partition_configuration_manual_mode(self): + """Test manual partition configuration.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '6' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_manual_config', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_manual_config') + + executor = PartitionedRayExecutor(cfg) + + # Verify manual mode configuration + self.assertEqual(executor.partition_mode, 'manual') + 
self.assertEqual(executor.num_partitions, 6) + + @TEST_TAG('ray') + def test_partition_configuration_auto_mode(self): + """Test auto partition configuration.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'auto' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_auto_config', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_auto_config') + + executor = PartitionedRayExecutor(cfg) + + # Verify auto mode configuration + self.assertEqual(executor.partition_mode, 'auto') + # num_partitions should be set to a default value initially + self.assertIsNotNone(executor.num_partitions) + + @TEST_TAG('ray') + def test_checkpoint_strategies(self): + """Test different checkpoint strategies.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '2', + '--checkpoint.enabled', 'true' + ]) + + # Test EVERY_OP strategy + cfg.checkpoint = {'strategy': 'every_op'} + executor = PartitionedRayExecutor(cfg) + self.assertEqual(executor.ckpt_manager.checkpoint_strategy.value, 'every_op') + + # Test EVERY_N_OPS strategy + cfg.checkpoint = {'strategy': 'every_n_ops', 'n_ops': 2} + executor = PartitionedRayExecutor(cfg) + self.assertEqual(executor.ckpt_manager.checkpoint_strategy.value, 'every_n_ops') + self.assertEqual(executor.ckpt_manager.checkpoint_n_ops, 2) + + # Test MANUAL strategy + cfg.checkpoint = {'strategy': 'manual', 'op_names': ['text_length_filter']} + executor = PartitionedRayExecutor(cfg) + self.assertEqual(executor.ckpt_manager.checkpoint_strategy.value, 'manual') + self.assertIn('text_length_filter', executor.ckpt_manager.checkpoint_op_names) + + # Test DISABLED strategy + cfg.checkpoint = {'strategy': 'disabled'} + executor = PartitionedRayExecutor(cfg) + self.assertEqual(executor.ckpt_manager.checkpoint_strategy.value, 'disabled') + self.assertFalse(executor.ckpt_manager.checkpoint_enabled) + + @TEST_TAG('ray') + def test_dag_node_generation(self): + """Test DAG node generation for partitioned execution.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '3' + ]) + cfg.export_path = os.path.join(self.tmp_dir, 'test_dag_nodes', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_dag_nodes') + + executor = PartitionedRayExecutor(cfg) + executor._initialize_dag_execution(cfg) + + # Test DAG node ID generation for different partitions + node_id_0 = executor._get_dag_node_for_operation_partitioned('text_length_filter', 0, partition_id=0) + node_id_1 = executor._get_dag_node_for_operation_partitioned('text_length_filter', 0, partition_id=1) + node_id_2 = executor._get_dag_node_for_operation_partitioned('text_length_filter', 0, partition_id=2) + + # All should be different for different partitions + self.assertNotEqual(node_id_0, node_id_1) + self.assertNotEqual(node_id_1, node_id_2) + self.assertNotEqual(node_id_0, node_id_2) + + # All should contain the partition ID + self.assertIn('_partition_0', node_id_0) + self.assertIn('_partition_1', node_id_1) + self.assertIn('_partition_2', node_id_2) + + @TEST_TAG('ray') + def test_global_operation_detection(self): + """Test detection of global operations that require convergence.""" + from data_juicer.core.executor.dag_execution_strategies import 
is_global_operation + + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'), + '--partition.mode', 'manual', + '--partition.num_of_partitions', '2' + ]) + + executor = PartitionedRayExecutor(cfg) + + # Test deduplicator detection + from data_juicer.ops.deduplicator.ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator + deduplicator = RayBTSMinhashDeduplicator(hash_func='sha1', threshold=0.7) + self.assertTrue(is_global_operation(deduplicator)) + + # Test non-global operation + from data_juicer.ops.filter.text_length_filter import TextLengthFilter + text_filter = TextLengthFilter(min_len=10) + self.assertFalse(is_global_operation(text_filter)) + + @TEST_TAG('ray') + def test_executor_initialization_with_legacy_config(self): + """Test executor initialization with legacy num_partitions config.""" + cfg = init_configs([ + '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml') + ]) + # Use legacy num_partitions instead of partition config + cfg.num_partitions = 5 + cfg.export_path = os.path.join(self.tmp_dir, 'test_legacy_config', 'res.jsonl') + cfg.work_dir = os.path.join(self.tmp_dir, 'test_legacy_config') + + executor = PartitionedRayExecutor(cfg) + + # Should fall back to manual mode with legacy config + self.assertEqual(executor.partition_mode, 'manual') + self.assertEqual(executor.num_partitions, 5) + + @TEST_TAG('ray') + def test_job_resumption_workflow(self): + """Test job resumption workflow with user-provided job_id.""" + from unittest.mock import Mock, patch, MagicMock + import json + + # Create a simple config without loading from file + from jsonargparse import Namespace + cfg = Namespace() + cfg.process = [{'text_length_filter': {'min_len': 10}}] + cfg.dataset_path = 'test.jsonl' + cfg.work_dir = os.path.join(self.tmp_dir, 'test_job_resumption') + cfg.export_path = os.path.join(self.tmp_dir, 'test_job_resumption', 'res.jsonl') + cfg.partition = {'mode': 'manual', 'num_of_partitions': 2} + cfg.checkpoint = {'enabled': True, 'strategy': 'every_op'} + cfg._user_provided_job_id = False + cfg.job_id = 'test_job_resumption_123' # Required for event logging + cfg.event_logging = {'enabled': True} # Enable event logging for this test + + # Create work_dir first + os.makedirs(cfg.work_dir, exist_ok=True) + + # Create executor without running full initialization + executor = PartitionedRayExecutor.__new__(PartitionedRayExecutor) + executor.cfg = cfg + executor.executor_type = 'ray_partitioned' + executor.work_dir = cfg.work_dir + executor.num_partitions = 2 + + # Initialize only the necessary components + from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin + from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin + EventLoggingMixin.__init__(executor, cfg) + DAGExecutionMixin.__init__(executor) + executor._override_strategy_methods() + + # Test 1: Check job resumption when no job exists + cfg._user_provided_job_id = False + result = executor._resume_job('nonexistent_job') + self.assertEqual(result, "failed") + + # Test 2: Test job completion check with mock job directory + job_id = 'test_job_123' + job_dir = os.path.join(cfg.work_dir, f'20250101_120000_{job_id}') + os.makedirs(job_dir, exist_ok=True) + + # Create events file directly in job directory (required for job completion check) + events_file = os.path.join(job_dir, 'events_20250101_120000.jsonl') + with open(events_file, 'w') as f: + f.write('{"timestamp": "2025-01-01T12:00:00", 
"event_type": "job_start", "message": "Job started"}\n') + f.write('{"timestamp": "2025-01-01T12:01:00", "event_type": "job_complete", "message": "Job completed"}\n') + + # Test job completion check directly + is_completed = executor._check_job_completion(job_dir, job_id) + self.assertTrue(is_completed) + + # Test 3: Test job completion check with incomplete job + with open(events_file, 'w') as f: + f.write('{"timestamp": "2025-01-01T12:00:00", "event_type": "job_start", "message": "Job started"}\n') + f.write('{"timestamp": "2025-01-01T12:01:00", "event_type": "op_start", "message": "Operation started"}\n') + + is_completed = executor._check_job_completion(job_dir, job_id) + self.assertFalse(is_completed) + + # Test 4: Test job resumption with proper job directory (mock the directory finding) + cfg._user_provided_job_id = True + cfg.job_id = job_id + + # Mock the work directory finding to return our test directory + with patch.object(executor, '_find_work_directory', return_value=job_dir): + result = executor._resume_job(job_id) + # Should return "failed" due to config validation failure (we didn't save the config) + self.assertEqual(result, "failed") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_image_sam_3d_body_mapper.py b/tests/ops/mapper/test_image_sam_3d_body_mapper.py index 1c0c522856..be20bb0317 100644 --- a/tests/ops/mapper/test_image_sam_3d_body_mapper.py +++ b/tests/ops/mapper/test_image_sam_3d_body_mapper.py @@ -11,6 +11,19 @@ from data_juicer.utils.unittest_utils import TEST_TAG, DataJuicerTestCaseBase +def _is_egl_available(): + """Check if EGL is available for offscreen rendering.""" + try: + from OpenGL.platform import ctypesloader + ctypesloader.loadLibrary(None, 'EGL') + return True + except (ImportError, OSError, TypeError): + return False + + +EGL_AVAILABLE = _is_egl_available() + + class ImageSAM3DBodyMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') @@ -106,10 +119,12 @@ def test_multi_process(self): """ self._run_test(num_proc=2) + @unittest.skipUnless(EGL_AVAILABLE, 'EGL not available for visualization') def test_vis(self): self._run_test(visualization_dir=self.tmp_dir, num_proc=1) @TEST_TAG('ray') + @unittest.skipUnless(EGL_AVAILABLE, 'EGL not available for visualization') def test_ray(self): self._run_test(visualization_dir=self.tmp_dir, ray_mode=True, num_proc=2) diff --git a/tests/ops/mapper/test_text_tagging_by_prompt_mapper.py b/tests/ops/mapper/test_text_tagging_by_prompt_mapper.py index d0123212c6..79371dae3d 100644 --- a/tests/ops/mapper/test_text_tagging_by_prompt_mapper.py +++ b/tests/ops/mapper/test_text_tagging_by_prompt_mapper.py @@ -1,6 +1,7 @@ import unittest from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import TextTaggingByPromptMapper, DEFAULT_CLASSIFICATION_PROMPT, DEFAULT_CLASSIFICATION_LIST from data_juicer.utils.constant import Fields +from data_juicer.utils.resource_utils import is_cuda_available from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase def check_string_in_list(string_list, output): @@ -40,6 +41,7 @@ def test_tagging(self): }] self._run_tagging(samples) + @unittest.skipUnless(is_cuda_available(), 'vLLM requires CUDA') def test_tagging_vllm(self): samples = [ { diff --git a/tools/count_rows.py b/tools/count_rows.py new file mode 100644 index 0000000000..30bc128ec3 --- /dev/null +++ b/tools/count_rows.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +""" +Different ways to count rows in a parquet file 
+""" + +import argparse +from pathlib import Path + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq + + +def get_parquet_info(file_path): + """Get detailed information about the parquet file""" + print(f"\nParquet file information for: {file_path}") + print("-" * 50) + + parquet_file = pq.ParquetFile(file_path) + metadata = parquet_file.metadata + + print(f"Total rows: {metadata.num_rows:,}") + print(f"Total columns: {metadata.num_columns}") + print(f"Number of row groups: {metadata.num_row_groups}") + print(f"File size: {metadata.serialized_size / 1024 / 1024:.2f} MB") + + # Show column information + print("\nColumns:") + for i in range(metadata.num_columns): + col_meta = metadata.row_group(0).column(i) + print(f" {col_meta.path_in_schema}: {col_meta.physical_type}") + + +def count_rows_auto(file_path): + """Automatically choose the best method based on file extension and count rows""" + file_path = Path(file_path) + extension = file_path.suffix.lower() + + if extension == ".parquet": + # Use pyarrow metadata for parquet - fastest and most efficient + parquet_file = pq.ParquetFile(file_path) + row_count = parquet_file.metadata.num_rows + method_used = "pyarrow metadata" + elif extension in [".csv", ".tsv"]: + # For CSV files, use pandas + df = pd.read_csv(file_path) + row_count = len(df) + method_used = "pandas read_csv" + elif extension in [".json", ".jsonl"]: + # For JSON files, try to detect if it's JSONL content + try: + # First try to read as regular JSON + df = pd.read_json(file_path) + row_count = len(df) + method_used = "pandas read_json" + except Exception as e: + # If that fails, try reading as JSONL (one JSON object per line) + if "Trailing data" in str(e) or "Extra data" in str(e): + df = pd.read_json(file_path, lines=True) + row_count = len(df) + method_used = "pandas read_json (lines=True) - detected JSONL content" + else: + # Re-raise the original error if it's not a trailing data issue + raise e + elif extension in [".arrow", ".feather"]: + # For Arrow files, use pyarrow + table = pa.ipc.open_file(file_path).read_all() + row_count = table.num_rows + method_used = "pyarrow arrow" + else: + # Default to pandas for unknown extensions + try: + df = pd.read_csv(file_path) + row_count = len(df) + method_used = "pandas read_csv (default)" + except Exception as e: + print(f"Error: Could not read file with extension {extension}: {e}") + return None, None + + return row_count, method_used + + +def get_supported_extensions(): + """Return list of supported file extensions""" + return [".parquet", ".csv", ".tsv", ".json", ".jsonl", ".arrow", ".feather"] + + +def count_directory(directory_path, show_info=False): + """Count rows for all supported files in a directory""" + directory_path = Path(directory_path) + supported_extensions = get_supported_extensions() + + # Find all supported files in directory (recursive) + files = [] + for ext in supported_extensions: + files.extend(directory_path.rglob(f"*{ext}")) + + if not files: + print(f"No supported files found in directory: {directory_path}") + return + + # Sort files for consistent output + files = sorted(files) + + print(f"Found {len(files)} supported files in: {directory_path}") + print("=" * 80) + + total_rows = 0 + file_counts = [] + + for file_path in files: + try: + row_count, method_used = count_rows_auto(file_path) + if row_count is not None: + file_counts.append( + { + "file": file_path, + "rows": row_count, + "method": method_used, + "size_mb": file_path.stat().st_size / 1024 / 1024, + } + ) + 
+                total_rows += row_count
+                print(f"{file_path.name:<50} {row_count:>10,} rows ({method_used})")
+            else:
+                print(f"{file_path.name:<50} {'ERROR':>10}")
+        except Exception as e:
+            print(f"{file_path.name:<50} {'ERROR':>10} - {e}")
+
+    # Print summary
+    print("=" * 80)
+    print(f"Total files: {len(file_counts)}")
+    print(f"Total rows: {total_rows:,}")
+    # guard against the case where every file failed to parse (avoids a division by zero)
+    if file_counts:
+        print(f"Average rows per file: {total_rows // len(file_counts):,}")
+
+    # Show detailed info for parquet files if requested
+    if show_info:
+        parquet_files = [f for f in file_counts if f["file"].suffix.lower() == ".parquet"]
+        if parquet_files:
+            print("\n" + "=" * 80)
+            print("DETAILED PARQUET FILE INFORMATION")
+            print("=" * 80)
+            for file_info in parquet_files:
+                get_parquet_info(file_info["file"])
+                print()
+
+    return file_counts, total_rows
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Count rows in data files using the most appropriate method")
+    parser.add_argument("path", help="Path to a data file or directory containing data files")
+    parser.add_argument("--info", "-i", action="store_true", help="Show detailed file information (for parquet files)")
+
+    args = parser.parse_args()
+
+    path = Path(args.path)
+
+    if not path.exists():
+        print(f"Error: Path not found: {args.path}")
+        return 1
+
+    if path.is_file():
+        # Single file mode
+        print(f"Counting rows in: {args.path}")
+        print("=" * 60)
+
+        row_count, method_used = count_rows_auto(args.path)
+
+        if row_count is not None:
+            print(f"Row count: {row_count:,}")
+            print(f"Method used: {method_used}")
+        else:
+            return 1
+
+        # Show detailed info for parquet files if requested
+        if args.info and path.suffix.lower() == ".parquet":
+            get_parquet_info(args.path)
+
+    elif path.is_dir():
+        # Directory mode
+        count_directory(args.path, show_info=args.info)
+
+    else:
+        print(f"Error: Path is neither a file nor a directory: {args.path}")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())
diff --git a/tools/process_data.py b/tools/process_data.py
index 075f3aeb62..3e959618fb 100644
--- a/tools/process_data.py
+++ b/tools/process_data.py
@@ -27,6 +27,14 @@ def main():
             from data_juicer.core.executor.ray_executor import RayExecutor
 
             executor = RayExecutor(cfg)
+        elif cfg.executor_type == "ray_partitioned":
+            from data_juicer.core.executor.ray_executor_partitioned import (
+                PartitionedRayExecutor,
+            )
+
+            executor = PartitionedRayExecutor(cfg)
+        else:
+            raise ValueError(f"Unsupported executor type: {cfg.executor_type}")
 
     with timing_context("Running executor"):
         executor.run()
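Note: for quick reference, a minimal sketch of how the new executor is driven outside the test
harness, mirroring test_end2end_execution_manual_partitioning above. The demo config path and
flag names are taken from this diff; adjust the config path to your checkout.

    from data_juicer.config import init_configs
    from data_juicer.core.executor.ray_executor_partitioned import PartitionedRayExecutor

    # build a config the same way the tests do: demo config + manual partitioning + checkpointing
    cfg = init_configs([
        '--config', 'demos/process_on_ray/configs/demo-new-config.yaml',
        '--partition.mode', 'manual',
        '--partition.num_of_partitions', '2',
        '--checkpoint.enabled', 'true',
        '--checkpoint.strategy', 'every_op',
    ])
    executor = PartitionedRayExecutor(cfg)
    executor.run()

The same path is reachable through tools/process_data.py by setting executor_type to
"ray_partitioned", which is the dispatch branch added at the end of this diff.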