How to use Data-Juicer with Ray (Python API, Checkpointing, and Pipeline Improvements) #846

@atulydvv

Description

Before Asking

  • I have read the README carefully.

  • I have pulled the latest code of the main branch and rerun it, and the problem still exists.

Search before asking

  • I have searched the Data-Juicer issues and found no similar questions.

Question

Hi everyone,
I’m new to Data-Juicer and Ray, and I’m building a pipeline for LLM data preprocessing. I’ve managed to get a basic workflow running, but I’m still unsure how to fully leverage Data-Juicer using the Python API (I prefer not to use the YAML config approach).
I haven't been able to find many examples of Python-based usage, so I wanted to share my main.py below and ask a few questions:

1. How can I properly use Data-Juicer's checkpointing mechanism via the Python API, and how does it work exactly? The documentation mostly shows YAML examples, so I'm not sure how to translate that into Python-based pipeline code (a rough guess at what I'd try is sketched below).
2. For URL deduplication, I currently need to call materialize() afterward. Is there a way to avoid materializing manually at that step, or can Data-Juicer handle this more elegantly within the Python pipeline? (See the combined-pipeline sketch after the script.)
3. If you notice anything that could improve or simplify my pipeline, I'd really appreciate suggestions. More generally, I'd like to understand how to make better use of Data-Juicer's capabilities.
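
For context on question 1, this is roughly what I imagine the Python equivalent of a YAML run looks like, based on the `use_checkpoint` key in config_all.yaml and the `init_configs`/`Executor` entry points that tools/process_data.py uses. I'm not sure it's correct (in particular, whether checkpointing is honoured when executor_type is 'ray'), so please treat it as a guess rather than working code:

from data_juicer.config import init_configs
from data_juicer.core import Executor

# Build the same config namespace a YAML run would produce; the CLI-style flags
# mirror config keys, and the config file name here is just a placeholder.
cfg = init_configs(args=[
    "--config", "my_pipeline.yaml",   # hypothetical config file
    "--use_checkpoint", "true",       # keep intermediate results under the work dir
    "--executor_type", "ray",         # unclear to me whether this combines with checkpointing
])

executor = Executor(cfg)
executor.run()
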
Here is my current main.py:


import os
import math
import boto3
import s3fs
import asyncio
import fasttext
import multiprocessing
import ray.data as rd
from pyarrow.fs import S3FileSystem
from ray.data import DatasetContext
from huggingface_hub import hf_hub_download
import ray
from ray import serve
from data_juicer.core.data.ray_dataset import RayDataset

num_cpus = multiprocessing.cpu_count()
print(f"Number of CPUs available: {num_cpus}")

# 1. Initialize Ray
# This starts a local cluster capped at a quarter of the available CPUs.
ray.init(num_cpus=num_cpus//4, ignore_reinit_error=True)
# ray.init(address="localhost:8265", ignore_reinit_error=True)

from dj_normalizer import ArabicNormalizerMapper
from dj_lid import ArabicLIDFilter, GlotLIDModel
from dj_url_dedup_bloom import DJUrlBloomDeduplicator
from dj_url_dedup_exact import DJUrlExactDeduplicator
from dj_url_filter import DjUrlFilter
from dj_quality_heuristics import QHMapper


os.environ["ARROW_S3_CLIENT_TIMEOUT"] = "600"   # 10 min
os.environ["ARROW_S3_READ_TIMEOUT"] = "600"
os.environ["ARROW_S3_CONNECT_TIMEOUT"] = "300"
os.environ["ARROW_S3_RETRY_LIMIT"] = "10"
os.environ["AWS_MAX_ATTEMPTS"] = "10"

ctx = DatasetContext.get_current()
max_input_block_size = 100 * 1024 * 1024  # 100MB
ctx.target_max_block_size = max_input_block_size

bucket = "bucket-name"
input_paths = ["test/raw/5_10pct*.parquet", "test/raw/1.parquet"]
s3 = s3fs.S3FileSystem()  # reuse one client for globbing and metadata lookups
file_paths = []
for input_path in input_paths:
    s3_input_path = f"s3://{bucket}/{input_path}"
    file_paths.extend(s3.glob(s3_input_path))

file_infos = [s3.info(p) for p in file_paths]
print("File Infos: ", file_infos)

s3_file_list = [f"s3://{fp}" for fp in file_paths]
print(f"S3 File List: {s3_file_list}")

output_key = "test/processed/"
s3_output_path = f"s3://{bucket}/{output_key}"
print(f"Output Path: {s3_output_path}")


#TODO: Need to improve this chunking logic based on size
def chunk_list(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
file_batches = list(chunk_list(s3_file_list, 5))
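
# A possible size-based alternative to the fixed-count chunking above (just a
# sketch, not wired in). It assumes the s3fs `info()` results above expose a
# "size" field, which is what I see locally but haven't verified everywhere.
def chunk_by_size(paths, sizes, max_batch_bytes=2 * 1024 ** 3):
    """Greedily group files so each batch stays under max_batch_bytes."""
    batch, batch_bytes = [], 0
    for path, size in zip(paths, sizes):
        if batch and batch_bytes + size > max_batch_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append(path)
        batch_bytes += size
    if batch:
        yield batch

# Example usage with a hypothetical 2 GiB cap per batch:
# sizes = [info["size"] for info in file_infos]
# file_batches = list(chunk_by_size(s3_file_list, sizes))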

# 2. Download the language-ID model once in the driver
print("Downloading GlotLID model in driver...")
model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model.bin")

print("Deploying Model Service...")
serve.run(GlotLIDModel.bind(model_path), name="glotlid")

fs = S3FileSystem(
        region="us-west-2", 
        request_timeout=600,
        connect_timeout=300
    )

for i, batch_files in enumerate(file_batches):
    print(f"Processing batch {i+1}/{len(file_batches)} with {len(batch_files)} files...")
    
    # 3. Read Input Data from S3 into Ray Dataset
    # TODO: Here check checksum and s3 file location in redis
    # if already processed or not. Also at the end of the pipeline
    # Store this in redis as well.
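    # Rough sketch of the Redis bookkeeping described in the TODO above
    # (hypothetical set name "dj:processed_files"; it needs `import redis` and a
    # reachable server, so it is left commented out here):
    #   r = redis.Redis(host="localhost", port=6379)
    #   batch_files = [p for p in batch_files if not r.sismember("dj:processed_files", p)]
    #   # ...and after a successful write_parquet: r.sadd("dj:processed_files", *batch_files)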
    ray_ds = rd.read_parquet_bulk(
        batch_files,
        filesystem=fs,
    )

    input_files = ray_ds.input_files()
    print(f"Input files: {input_files}")

    input_size = ray_ds.size_bytes()
    print(f"Input Dataset Size (bytes): {input_size}")
    output_target_number_of_blocks = max(1, input_size // max_input_block_size)
    print(f"Output Target Number of Blocks: {output_target_number_of_blocks}")

    ds = RayDataset(
        dataset=ray_ds
    )

    # 4. Define the pipeline (a list of operator instances)
    # These classes can be instantiated directly with Python arguments.
    dedup_ops = [
        DJUrlBloomDeduplicator(),
    ]

    # 5. Execute the pipeline
    # The .process() method runs the operators in sequence using Ray.
    print("Starting processing with Ray...")
    dedup_ds = ds.process(dedup_ops)
    dedup_ds.data = dedup_ds.data.materialize()

    print(f"Deduplication complete., count after dedup: {dedup_ds.data.count()}, dropped by dedup: {ray_ds.count() - dedup_ds.data.count()}, size after dedup: {dedup_ds.data.size_bytes()} bytes, Stats: {dedup_ds.data.stats()}")

    url_filtering_ops = [
        DjUrlFilter(), # 10k rows with bloom filter, 17.4k
        ArabicLIDFilter(),
        ArabicNormalizerMapper(),
        QHMapper()
    ]

    normalized_ds = dedup_ds.process(url_filtering_ops)

    # print(f"Filtering complete., count after url filter: {normalized_ds.data.count()}, dropped by filter: {dedup_ds.data.count() - normalized_ds.data.count()}, size after filter: {normalized_ds.data.size_bytes()} bytes, Stats: {normalized_ds.data.stats()}")

    final_ds = normalized_ds.data.repartition(num_blocks=output_target_number_of_blocks, shuffle=False)

    # 6. Export the results to S3
    print("Exporting data to S3...")
    final_ds.write_parquet(
        path=s3_output_path,
        filesystem=fs
    )
    print(f"Batch {i+1} finished successfully.")

print("Job finished successfully.")

Thanks in advance for the help! I really appreciate any guidance or examples from others who have used Data-Juicer with Ray via Python.

Additional

No response
