diff --git a/inference/Qwen2.5-VL-3B-instruct.py b/inference/Qwen2.5-VL-3B-instruct.py new file mode 100644 index 000000000..5999ddaec --- /dev/null +++ b/inference/Qwen2.5-VL-3B-instruct.py @@ -0,0 +1,61 @@ +from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor +from qwen_vl_utils import process_vision_info +from datasets import load_dataset + +# default: Load the model on the available device(s) +model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "smolagents/Qwen2.5-VL-3B-Instruct-Agentic", torch_dtype="auto", device_map="auto" +) + +# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. +# model = Qwen2_5_VLForConditionalGeneration.from_pretrained( +# "Qwen/Qwen2.5-VL-3B-Instruct", +# torch_dtype=torch.bfloat16, +# attn_implementation="flash_attention_2", +# device_map="auto", +# ) + +# default processer +processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct") + +# The default range for the number of visual tokens per image in the model is 4-16384. +# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost. +# min_pixels = 256*28*28 +# max_pixels = 1280*28*28 +# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) + +dataset = load_dataset("smolagents/aguvis-stage-2", "mind2web", split="train") + +for example in dataset: + messages = [ + {"role": "system", "content": example["system"]}, + {"role": "user", "content": [ + {"type": "image", "image": example["image"]}, + {"type": "text", "text": example["user"]} + ]}, + ] + break + +# Preparation for inference +text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) +image_inputs, video_inputs = process_vision_info(messages) +inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", +) +inputs = inputs.to("cuda") + +# Inference: Generation of the output +generated_ids = model.generate(**inputs, max_new_tokens=4096) +generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) +] +output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) \ No newline at end of file diff --git a/recipes/Qwen2.5-VL-3B-Instruct/sft/config_gui.yaml b/recipes/Qwen2.5-VL-3B-Instruct/sft/config_gui.yaml new file mode 100644 index 000000000..81653343e --- /dev/null +++ b/recipes/Qwen2.5-VL-3B-Instruct/sft/config_gui.yaml @@ -0,0 +1,142 @@ +# Model arguments +# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768 +model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct +vision_model: true +model_revision: main +torch_dtype: bfloat16 +attn_implementation: sdpa + +# Data training arguments +dataset_name: smolagents/aguvis-stage-2 +dataset_num_proc: 48 + +#SFT hyperparam +max_length: 4096 +optim: adamw_torch +lr_scheduler_type: cosine_with_min_lr +lr_scheduler_kwargs: + min_lr_rate: 0.1 +max_grad_norm: 0.2 +warmup_ratio: 0.03 +learning_rate: 2.0e-05 +gradient_accumulation_steps: 16 +per_device_eval_batch_size: 4 +per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS. 
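The commented `min_pixels` / `max_pixels` knobs in the inference script above translate directly into a visual-token budget: Qwen2.5-VL spends roughly one visual token per 28x28-pixel area (14-px patches merged 2x2), which is why the example values are written as `256*28*28` and `1280*28*28`. A minimal sketch of that arithmetic (the helper name is ours, not part of the script):

```python
# One visual token covers a 28x28 pixel area after 2x2 patch merging.
PATCH_AREA = 28 * 28

def visual_token_range(min_pixels: int, max_pixels: int) -> tuple[int, int]:
    """Approximate (min, max) number of visual tokens per image."""
    return min_pixels // PATCH_AREA, max_pixels // PATCH_AREA

print(visual_token_range(256 * 28 * 28, 1280 * 28 * 28))  # -> (256, 1280)
```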
+ +# Image resize arguments +image_resize: + factor: 28 + min_pixels: 200704 + max_pixels: 1003520 + +# SFT trainer config +max_steps: -1 +num_train_epochs: 1 +bf16: true +do_eval: true +eval_strategy: 'steps' +eval_steps: 100 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: A-Mahla/Qwen2.5-VL-3B-Instruct-Agentic-GUI +hub_strategy: end +push_to_hub: true +log_level: info +logging_steps: 5 +logging_strategy: steps +output_dir: /fsx/amir_mahla/smolagents-Qwen2.5-VL-3B-Instruct-Agentic +overwrite_output_dir: true +report_to: +- wandb +wandb_project: smolagents +save_strategy: "epoch" +save_steps: 1 +save_total_limit: 1 +seed: 42 + +dataset_mixture: + datasets: # List of datasets to include in the mixture + - id: smolagents/aguvis-stage-2 # Hub dataset ID + config: mind2web # Name of the dataset config + split: train # Split to use from the dataset + columns: # Columns to keep + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: guiact-web-single + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: guiact-web-multi + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: miniwob + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: coat + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: android_control + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: gui-odyssey + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: amex + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: aitw + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + seed: 42 # Seed for shuffling the combined dataset + test_split_size: 0.01 \ No newline at end of file diff --git a/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_1152.yaml b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_1152.yaml new file mode 100644 index 000000000..c6af88c36 --- /dev/null +++ b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_1152.yaml @@ -0,0 +1,116 @@ +# Model arguments +# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768 +model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct +vision_model: true +model_revision: main +torch_dtype: bfloat16 +attn_implementation: sdpa + +# Data training arguments +dataset_name: smolagents/aguvis-stage-2 +dataset_num_proc: 48 + +#SFT hyperparam +max_length: 4096 +optim: adamw_torch +lr_scheduler_type: cosine_with_min_lr +lr_scheduler_kwargs: + min_lr_rate: 0.1 +max_grad_norm: 0.2 +warmup_ratio: 0.03 +learning_rate: 2.0e-05 +gradient_accumulation_steps: 32 +per_device_eval_batch_size: 2 +per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS. 
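The `image_resize` block of the Qwen recipe above (`factor: 28`, `min_pixels: 200704`, `max_pixels: 1003520`) suggests a Qwen-style "smart resize": snap both sides to multiples of 28 while keeping the total pixel count inside the budget. A sketch of that rule, modeled on `qwen_vl_utils`; the trainer's own resize code is authoritative and may differ in rounding details:

```python
import math

def smart_resize(height: int, width: int, factor: int = 28,
                 min_pixels: int = 200_704, max_pixels: int = 1_003_520) -> tuple[int, int]:
    """Snap (height, width) to multiples of `factor`, keeping height*width
    within [min_pixels, max_pixels] and the aspect ratio roughly intact."""
    h = round(height / factor) * factor
    w = round(width / factor) * factor
    if h * w > max_pixels:
        scale = math.sqrt(height * width / max_pixels)
        h = math.floor(height / scale / factor) * factor
        w = math.floor(width / scale / factor) * factor
    elif h * w < min_pixels:
        scale = math.sqrt(min_pixels / (height * width))
        h = math.ceil(height * scale / factor) * factor
        w = math.ceil(width * scale / factor) * factor
    return h, w

print(smart_resize(1080, 1920))  # -> (728, 1316): a full-HD screenshot squeezed under max_pixels
```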
+ +image_resize: + resolution_max_side: 1152 + to_pixel_coordinates: true + +# SFT trainer config +max_steps: -1 +num_train_epochs: 1 +bf16: true +do_eval: false +eval_strategy: 'steps' +eval_steps: 100 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI +hub_model_revision: main +hub_strategy: end +push_to_hub: false +log_level: info +logging_steps: 5 +logging_strategy: steps +output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152-pixel-coordinates +overwrite_output_dir: true +report_to: +- wandb +wandb_project: smolagents +save_strategy: steps +save_steps: 800 +save_total_limit: 1 +seed: 42 + +dataset_mixture: + datasets: # List of datasets to include in the mixture + - id: smolagents/aguvis-stage-1 # Hub dataset ID + config: guienv # Name of the dataset config + split: train # Split to use from the dataset + columns: # Columns to keep + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: omniact + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricoig16k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricosca + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: seeclick + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ui_refexp + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: webui350k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: widget_captioning + split: train + columns: + - images + - texts + weight: 1. + seed: 42 # Seed for shuffling the combined dataset + test_split_size: 0.007 \ No newline at end of file diff --git a/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_384.yaml b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_384.yaml new file mode 100644 index 000000000..a102e738e --- /dev/null +++ b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_384.yaml @@ -0,0 +1,116 @@ +# Model arguments +# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768 +model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct +vision_model: true +model_revision: main +torch_dtype: bfloat16 +attn_implementation: sdpa + +# Data training arguments +dataset_name: smolagents/aguvis-stage-2 +dataset_num_proc: 48 + +#SFT hyperparam +max_length: 4096 +optim: adamw_torch +lr_scheduler_type: cosine_with_min_lr +lr_scheduler_kwargs: + min_lr_rate: 0.1 +max_grad_norm: 0.2 +warmup_ratio: 0.03 +learning_rate: 2.0e-05 +gradient_accumulation_steps: 32 +per_device_eval_batch_size: 2 +per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS. 
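The SmolVLM2 phase-1 recipes pair `resolution_max_side` with `to_pixel_coordinates: true`, and the stage-1 labels use normalized [0, 1] coordinates, so the targets presumably become pixel positions on the resized screenshot. One plausible reading of that conversion (whether smaller images are left un-upscaled is an assumption):

```python
def to_pixel(x_norm: float, y_norm: float, width: int, height: int,
             resolution_max_side: int = 1152) -> tuple[int, int]:
    """Map normalized coordinates to pixels on a screenshot whose longest
    side has been capped at `resolution_max_side` (no upscaling assumed)."""
    scale = min(resolution_max_side / max(width, height), 1.0)
    new_w, new_h = round(width * scale), round(height * scale)
    return round(x_norm * new_w), round(y_norm * new_h)

print(to_pixel(0.8102, 0.9463, 1920, 1080))  # -> (933, 613) on the 1152-px-wide resized image
```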
+ +image_resize: + resolution_max_side: 384 + to_pixel_coordinates: true + +# SFT trainer config +max_steps: -1 +num_train_epochs: 1 +bf16: true +do_eval: false +eval_strategy: 'steps' +eval_steps: 100 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI +hub_model_revision: main +hub_strategy: end +push_to_hub: false +log_level: info +logging_steps: 5 +logging_strategy: steps +output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-384-pixel-coordinates +overwrite_output_dir: true +report_to: +- wandb +wandb_project: smolagents +save_strategy: steps +save_steps: 800 +save_total_limit: 1 +seed: 42 + +dataset_mixture: + datasets: # List of datasets to include in the mixture + - id: smolagents/aguvis-stage-1 # Hub dataset ID + config: guienv # Name of the dataset config + split: train # Split to use from the dataset + columns: # Columns to keep + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: omniact + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricoig16k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricosca + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: seeclick + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ui_refexp + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: webui350k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: widget_captioning + split: train + columns: + - images + - texts + weight: 1. + seed: 42 # Seed for shuffling the combined dataset + test_split_size: 0.007 \ No newline at end of file diff --git a/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_764.yaml b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_764.yaml new file mode 100644 index 000000000..2b78d8f6b --- /dev/null +++ b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_764.yaml @@ -0,0 +1,116 @@ +# Model arguments +# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768 +model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct +vision_model: true +model_revision: main +torch_dtype: bfloat16 +attn_implementation: sdpa + +# Data training arguments +dataset_name: smolagents/aguvis-stage-2 +dataset_num_proc: 48 + +#SFT hyperparam +max_length: 4096 +optim: adamw_torch +lr_scheduler_type: cosine_with_min_lr +lr_scheduler_kwargs: + min_lr_rate: 0.1 +max_grad_norm: 0.2 +warmup_ratio: 0.03 +learning_rate: 2.0e-05 +gradient_accumulation_steps: 32 +per_device_eval_batch_size: 2 +per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS. 
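Each recipe ends with a `dataset_mixture` block; its comments ("Seed for shuffling the combined dataset", `test_split_size`) suggest the listed configs are loaded, reduced to the listed columns, combined, shuffled, and split. A rough sketch with the `datasets` library, using two of the stage-1 configs as stand-ins (the real trainer may honor non-1.0 weights by subsampling):

```python
from datasets import load_dataset, concatenate_datasets

specs = [
    ("smolagents/aguvis-stage-1", "seeclick"),
    ("smolagents/aguvis-stage-1", "omniact"),
]
parts = [
    load_dataset(repo, config, split="train").select_columns(["images", "texts"])
    for repo, config in specs
]
# weight: 1. for every entry -> each config contributes all of its rows
mixture = concatenate_datasets(parts).shuffle(seed=42)
splits = mixture.train_test_split(test_size=0.007, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]
```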
+ +image_resize: + resolution_max_side: 764 + to_pixel_coordinates: true + +# SFT trainer config +max_steps: -1 +num_train_epochs: 1 +bf16: true +do_eval: false +eval_strategy: 'steps' +eval_steps: 100 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI +hub_model_revision: main +hub_strategy: end +push_to_hub: false +log_level: info +logging_steps: 5 +logging_strategy: steps +output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-764-pixel-coordinates +overwrite_output_dir: true +report_to: +- wandb +wandb_project: smolagents +save_strategy: steps +save_steps: 800 +save_total_limit: 1 +seed: 42 + +dataset_mixture: + datasets: # List of datasets to include in the mixture + - id: smolagents/aguvis-stage-1 # Hub dataset ID + config: guienv # Name of the dataset config + split: train # Split to use from the dataset + columns: # Columns to keep + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: omniact + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricoig16k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ricosca + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: seeclick + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: ui_refexp + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: webui350k + split: train + columns: + - images + - texts + weight: 1. + - id: smolagents/aguvis-stage-1 + config: widget_captioning + split: train + columns: + - images + - texts + weight: 1. + seed: 42 # Seed for shuffling the combined dataset + test_split_size: 0.007 \ No newline at end of file diff --git a/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_2.yaml b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_2.yaml new file mode 100644 index 000000000..6c396a0b8 --- /dev/null +++ b/recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_2.yaml @@ -0,0 +1,137 @@ +# Model arguments +# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768 +model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct +vision_model: true +model_revision: main +torch_dtype: bfloat16 +attn_implementation: sdpa + +# Data training arguments +dataset_name: smolagents/aguvis-stage-2 +dataset_num_proc: 48 + +#SFT hyperparam +max_length: 4096 +optim: adamw_torch +lr_scheduler_type: cosine_with_min_lr +lr_scheduler_kwargs: + min_lr_rate: 0.1 +max_grad_norm: 0.2 +warmup_ratio: 0.03 +learning_rate: 2.0e-05 +gradient_accumulation_steps: 16 +per_device_eval_batch_size: 4 +per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS. 
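The recurring comment on `per_device_train_batch_size` ("keep a 500M GBS") points at holding the effective global batch constant when the context length changes. A back-of-the-envelope check for the phase-2 settings above; the GPU count is an assumption, not something the recipe states:

```python
# Effective batch size implied by the phase-2 hyperparameters above.
per_device_train_batch_size = 4
gradient_accumulation_steps = 16
num_gpus = 8          # assumption: one 8-GPU node
max_length = 4096

sequences_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
tokens_per_step_upper_bound = sequences_per_step * max_length
print(sequences_per_step, tokens_per_step_upper_bound)  # 512 sequences, ~2.1M tokens at most
```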
+ +# SFT trainer config +max_steps: -1 +num_train_epochs: 1 +bf16: true +do_eval: true +eval_strategy: 'steps' +eval_steps: 100 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI +hub_model_revision: max-resolution-1152-without-system +hub_strategy: end +push_to_hub: true +log_level: info +logging_steps: 5 +logging_strategy: steps +output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-max-resolution-1152-without-system +overwrite_output_dir: true +report_to: +- wandb +wandb_project: smolagents +save_strategy: "epoch" +save_steps: 1 +save_total_limit: 1 +seed: 42 + +dataset_mixture: + datasets: # List of datasets to include in the mixture + - id: smolagents/aguvis-stage-2 # Hub dataset ID + config: mind2web # Name of the dataset config + split: train # Split to use from the dataset + columns: # Columns to keep + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: guiact-web-single + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: guiact-web-multi + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: miniwob + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: coat + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: android_control + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: gui-odyssey + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: amex + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + - id: smolagents/aguvis-stage-2 + config: aitw + split: train + columns: + - system + - user + - assistant + - image + weight: 1. + seed: 42 # Seed for shuffling the combined dataset + test_split_size: 0.01 \ No newline at end of file diff --git a/scripts/agents/action_conversion.py b/scripts/agents/action_conversion.py new file mode 100644 index 000000000..d25fcd6d6 --- /dev/null +++ b/scripts/agents/action_conversion.py @@ -0,0 +1,194 @@ +from function_parser import FunctionCall +from copy import deepcopy + +# from aguvis aguvis action space to custom action space: + +# mobile.home() -> navigate_home() +# mobile.open_app(app_name='drupe') -> open_app(app_name: str) -> str: +# mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518]) -> swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518]) +# mobile.back() -> navigate_back() +# mobile.long_press(x=0.799, y=0.911) -> long_press(x, y) +# mobile.terminate(status='success') -> final_answer(answer: str) + +# answer('text') -> final_answer('text') OK +# mobile.wait(seconds=3) -> wait(seconds=3) OK +# pyautogui.hscroll(page=-0.1) +# ? 
+# pyautogui.scroll(page=-0.1) or pyautogui.scroll(0.13) OK +# -> negative: scroll(direction: Literal["up", "down"] = "up", amount: int = abs(page * 10)) +# -> positive: scroll(direction: Literal["up", "down"] = "down", amount: int = abs(page * 10)) +# pyautogui.click(x=0.8102, y=0.9463) -> click(x: int, y: int) OK +# pyautogui.doubleClick() -> double_click() OK +# pyautogui.hotkey(keys=['ctrl', 'c']) -> press(keys: str | list) OK +# pyautogui.press(keys='enter') or pyautogui.press(keys=['enter']) -> press(keys: str | list) OK +# pyautogui.moveTo(x=0.04, y=0.405) -> move_mouse(x: int, y: int) OK +# pyautogui.write(message='bread buns') -> type(text: str) OK +# pyautogui.dragTo(x=0.8102, y=0.9463) -> drag(x1, y1, x2, y2) OK but to recheck formatage in official dataset + + +def convert_to_pixel_coordinates(action: FunctionCall, resolution: tuple[int, int]) -> None: + if "arg_0" in action.parameters: + if isinstance(action.parameters["arg_0"], (list, tuple)): + action.parameters["from_coord"] = (int(action.parameters["arg_0"][0] * resolution[0]), int(action.parameters["arg_0"][1] * resolution[1])) + else: + action.parameters["x"] = int(action.parameters["arg_0"] * resolution[0]) + del action.parameters["arg_0"] + if "arg_1" in action.parameters: + if isinstance(action.parameters["arg_1"], (list, tuple)): + action.parameters["to_coord"] = (int(action.parameters["arg_1"][0] * resolution[0]), int(action.parameters["arg_1"][1] * resolution[1])) + else: + action.parameters["y"] = int(action.parameters["arg_1"] * resolution[1]) + del action.parameters["arg_1"] + +def change_argument_name(action: FunctionCall) -> None: + if "arg_0" in action.parameters: + if isinstance(action.parameters["arg_0"], (list, tuple)): + action.parameters["from_coord"] = (float(action.parameters["arg_0"][0]), float(action.parameters["arg_0"][1])) + else: + action.parameters["x"] = float(action.parameters["arg_0"]) + del action.parameters["arg_0"] + if "arg_1" in action.parameters: + if isinstance(action.parameters["arg_1"], (list, tuple)): + action.parameters["to_coord"] = (float(action.parameters["arg_1"][0]), float(action.parameters["arg_1"][1])) + else: + action.parameters["y"] = float(action.parameters["arg_1"]) + del action.parameters["arg_1"] + + +def rename_parameters(action: FunctionCall) -> None: + """ + Reorder FunctionCall parameters to use arg_0, arg_1, arg_2, etc. as keys. + Preserves the order of the original parameters. 
+ + Args: + action: FunctionCall object to reorder parameters for + + """ + if not action.parameters: + return + + for i, (key, value) in enumerate(deepcopy(action.parameters).items()): + tmp = value + del action.parameters[key] + action.parameters[f"arg_{i}"] = tmp + + + +def action_conversion( + actions: list[FunctionCall], resolution: tuple[int, int] +) -> list[FunctionCall]: + for i, action in enumerate(actions): + rename_parameters(action) + # MOBILE ACTIONS + if action.function_name == "mobile.home": + actions[i].function_name = "navigate_home" + + elif action.function_name == "mobile.open_app": + actions[i].function_name = "open_app" + + elif action.function_name == "mobile.swipe": + actions[i].function_name = "swipe" + change_argument_name(actions[i]) + + elif action.function_name == "mobile.back": + actions[i].function_name = "navigate_back" + + elif action.function_name == "mobile.long_press": + actions[i].function_name = "long_press" + change_argument_name(actions[i]) + + elif action.function_name in ["mobile.terminate", "answer"]: + actions[i].function_name = "final_answer" + + elif action.function_name == "mobile.wait": + actions[i].function_name = "wait" + if "arg_0" in actions[i].parameters: + actions[i].parameters["seconds"] = int(actions[i].parameters["arg_0"]) + del actions[i].parameters["arg_0"] + + # OS ACTION + elif action.function_name == "pyautogui.click": + actions[i].function_name = "click" + change_argument_name(actions[i]) + + elif action.function_name == "pyautogui.doubleClick": + actions[i].function_name = "double_click" + change_argument_name(actions[i]) + + elif action.function_name == "pyautogui.rightClick": + actions[i].function_name = "right_click" + change_argument_name(actions[i]) + + elif action.function_name in ["pyautogui.hotkey", "pyautogui.press"]: + actions[i].function_name = "press" + if "arg_0" in actions[i].parameters: + actions[i].parameters["keys"] = actions[i].parameters["arg_0"] + del actions[i].parameters["arg_0"] + + elif action.function_name == "pyautogui.moveTo": + actions[i].function_name = "move_mouse" + change_argument_name(actions[i]) + + elif action.function_name == "pyautogui.write": + actions[i].function_name = "type" + + elif action.function_name in ["pyautogui.scroll", "pyautogui.hscroll"]: + arg_value = actions[i].parameters["arg_0"] + if arg_value < 0: + if action.function_name == "pyautogui.hscroll": + actions[i].parameters["direction"] = "left" + else: + actions[i].parameters["direction"] = "up" + else: + if action.function_name == "pyautogui.hscroll": + actions[i].parameters["direction"] = "right" + else: + actions[i].parameters["direction"] = "down" + del actions[i].parameters["arg_0"] + actions[i].function_name = "scroll" + actions[i].parameters["amount"] = int(abs(arg_value * 100)) + + elif action.function_name == "pyautogui.dragTo": + actions[i].function_name = "drag" + change_argument_name(actions[i]) + + else: + ValueError("Error FonctionCall Formatting") + + actions[i].original_string = actions[i].to_string() + + return actions + +if __name__ == "__main__": + from function_parser import FunctionCall + + # Example actions for all function types + actions = [ + # MOBILE ACTIONS + FunctionCall("mobile.home", {}, "mobile.home()"), + FunctionCall("mobile.open_app", {"app_name": "drupe"}, "mobile.open_app(app_name='drupe')"), + FunctionCall("mobile.swipe", {"from_coord": [0.581, 0.898], "to_coord": [0.601, 0.518]}, "mobile.swipe(from_coord=[0.581,0.898],to_coord=[0.601,0.518])"), + FunctionCall("mobile.back", {}, 
"mobile.back()"), + FunctionCall("mobile.long_press", {"x": 0.799, "y": 0.911}, "mobile.long_press(x=0.799, y=0.911)"), + FunctionCall("mobile.terminate", {"status": "success"}, "mobile.terminate(status='success')"), + FunctionCall("answer", {"arg_0": "text"}, "answer('text')"), + FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)"), + # OS ACTIONS + FunctionCall("pyautogui.hscroll", {"page": -0.1}, "pyautogui.hscroll(page=-0.1)"), + FunctionCall("pyautogui.scroll", {"page": 0.13}, "pyautogui.scroll(page=0.13)"), + FunctionCall("pyautogui.click", {"x": 0.8102, "y": 0.9463}, "pyautogui.click(x=0.8102, y=0.9463)"), + FunctionCall("pyautogui.doubleClick", {}, "pyautogui.doubleClick()"), + FunctionCall("pyautogui.hotkey", {"keys": ["ctrl", "c"]}, "pyautogui.hotkey(keys=['ctrl','c'])"), + FunctionCall("pyautogui.press", {"keys": "enter"}, "pyautogui.press(keys='enter')"), + FunctionCall("pyautogui.moveTo", {"x": 0.04, "y": 0.405}, "pyautogui.moveTo(x=0.04, y=0.405)"), + FunctionCall("pyautogui.write", {"message": "bread buns"}, "pyautogui.write(message='bread buns')"), + FunctionCall("pyautogui.dragTo", {"from_coord": [0.87, 0.423], "to_coord": [0.8102, 0.9463]}, "pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])"), + ] + resolution = (1080, 1920) + print("Before conversion:") + for action in actions: + print(action) + print("\nAfter conversion:") + converted = action_conversion(actions, resolution) + for action in converted: + print(action) diff --git a/scripts/agents/config.py b/scripts/agents/config.py new file mode 100644 index 000000000..b83aa64dc --- /dev/null +++ b/scripts/agents/config.py @@ -0,0 +1,231 @@ + +# aguvis json file with mobile action space +MOBILE_FILE = [ + "android_control.json", + "gui-odyssey-l1.json", + "aitw-l3.json", + "coat.jsonamex-l2.json", + "amex-l1.json", + "amex-l3.json", + "gui-odyssey-l3.json", + "aitw-l1.json", + "aitw-l2.json", + "gui-odyssey-l2.json", +] + +# Processing: guienv +# Max conversations by image: 5 conversations +# Duplicates: 0 +# Duplicate images in guienv.json. difference: 257578 +# len(images_path): 327972 +# len(images_set_path): 70394 +# user/assistant by image: 3.6590902633747193 +# +# Processing: omniact +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in omniact.json +# len(images_path): 6720 +# len(images_set_path): 6720 +# user/assistant by image: 0.0 +# +# Processing: ricoig16k +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in ricoig16k.json +# len(images_path): 16133 +# len(images_set_path): 16133 +# user/assistant by image: 0.0 +# +# Processing: ricosca +# Max conversations by image: 20 conversations +# Duplicates: 0 +# Duplicate images in ricosca.json. difference: 155066 +# len(images_path): 173212 +# len(images_set_path): 18146 +# user/assistant by image: 8.54546456519343 +# +# Processing: seeclick +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in seeclick.json +# len(images_path): 271121 +# len(images_set_path): 271121 +# user/assistant by image: 0.0 +# +# Processing: webui350k +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in webui350k.json +# len(images_path): 57389 +# len(images_set_path): 57389 +# user/assistant by image: 0.0 +# +# Processing: ui_refexp +# Max conversations by image: 15 conversations +# Duplicates: 32 +# Duplicate images in ui_refexp.json. 
difference: 10978 +# len(images_path): 15624 +# len(images_set_path): 4646 +# user/assistant by image: 2.3628928110202323 +# +# Processing: widget_captioning +# Max conversations by image: 161 conversations +# Duplicates: 4877 +# Duplicate images in widget_captioning.json. difference: 87017 +# len(images_path): 101426 +# len(images_set_path): 14409 +# user/assistant by image: 6.039072801721146 +# +# total_samples = 458958 + +config_dict_stage_1 = [ + { + "json_path": "guienv.json", + "images_folder": "guienvs/images/", + }, + { + "json_path": "omniact.json", + "images_folder": "omniact/images/", + }, + { + "json_path": "ricoig16k.json", + "images_folder": "ricoig16k/images/", + }, + { + "json_path": "ricosca.json", + "images_folder": "ricosca/images/", + }, + { + "json_path": "seeclick.json", + "images_folder": "seeclick/seeclick_web_imgs/", + }, + { + "json_path": "webui350k.json", + "images_folder": "webui350k/images/", + }, + { + "json_path": "ui_refexp.json", + "images_folder": "ui_refexp/images/", + }, + { + "json_path": "widget_captioning.json", + "images_folder": "widget_captioning/images/", + }, + +] + + +# Processing: mind2web-l3 +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in mind2web-l3.json +# len(images_path): 7591 +# len(images_set_path): 7591 +# user/assistant by image: 0.0 +# +# Processing: guiact-web-single +# Max conversations by image: 12 conversations +# Duplicates: 0 +# Duplicate images in guiact-web-single.json. difference: 54134 +# len(images_path): 67396 +# len(images_set_path): 13262 +# user/assistant by image: 4.081888101342181 +# +# Processing: guiact-web-multi-l3 +# Max conversations by image: 2 conversations +# Duplicates: 0 +# Duplicate images in guiact-web-multi-l3.json. difference: 24 +# len(images_path): 16704 +# len(images_set_path): 16680 +# user/assistant by image: 0.0014388489208633094 +# +# Processing: miniwob-l3 +# Max conversations by image: 6 conversations +# Duplicates: 0 +# Duplicate images in miniwob-l3.json. difference: 161 +# len(images_path): 9826 +# len(images_set_path): 9665 +# user/assistant by image: 0.016658044490429385 +# +# Processing: coat +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in coat.json +# len(images_path): 11921 +# len(images_set_path): 11921 +# user/assistant by image: 0.0 +# +# Processing: android_control +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in android_control.json +# len(images_path): 74714 +# len(images_set_path): 74714 +# user/assistant by image: 0.0 +# +# Processing: gui-odyssey-l3 +# Max conversations by image: 2 conversations +# Duplicates: 0 +# Duplicate images in gui-odyssey-l3.json. 
difference: 24 +# len(images_path): 118282 +# len(images_set_path): 118258 +# user/assistant by image: 0.0002029461008980365 +# +# Processing: amex-l3 +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in amex-l3.json +# len(images_path): 38469 +# len(images_set_path): 38469 +# user/assistant by image: 0.0 +# +# Processing: aitw-l3 +# Max conversations by image: 0 conversations +# Duplicates: 0 +# No duplicate images in aitw-l3.json +# len(images_path): 18992 +# len(images_set_path): 18992 +# user/assistant by image: 0.0 +# +# Total samples: 309552 + + +config_dict_stage_2 = [ + { + "json_path": "mind2web-l3.json", + "images_folder": "mind2web/", + }, + { + "json_path": "guiact-web-single.json", + "images_folder": "guiact-web-single/images/", + }, + { + "json_path": "guiact-web-multi-l3.json", + "images_folder": "guiact-web-multi-v2/images", + }, + { + "json_path": "miniwob-l3.json", + "images_folder": "images", + }, + { + "json_path": "coat.json", + "images_folder": "coat/images/", + }, + { + "json_path": "android_control.json", + "images_folder": "android_control/images/", + }, + { + "json_path": "gui-odyssey-l3.json", + "images_folder": "gui-odyssey/images/", + }, + { + "json_path": "amex-l3.json", + "images_folder": "amex/images/", + }, + { + "json_path": "aitw-l3.json", + "images_folder": "aitw-v1/images/", + }, +] diff --git a/scripts/agents/function_parser.py b/scripts/agents/function_parser.py new file mode 100644 index 000000000..9b409a627 --- /dev/null +++ b/scripts/agents/function_parser.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +""" +Function parser for extracting function names, parameter names, and values from string function calls. +Supports both mobile and pyautogui function patterns. +""" + +import re +from typing import Dict, List, Tuple, Any, Union +from dataclasses import dataclass +from collections import OrderedDict + +@dataclass +class FunctionCall: + """Represents a parsed function call with its parameters.""" + function_name: str + parameters: Dict[str, Any] + original_string: str + description: str = "" + + def to_string(self) -> str: + """ + Reconstruct the function call string from the parsed data. + + Returns: + String representation of the function call + + Examples: + >>> call = FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)") + >>> call.to_string() + "mobile.wait(seconds=3)" + + >>> call = FunctionCall("function", {"arg_0": 1, "arg_1": 2, "x": 0.5}, "function(1, 2, x=0.5)") + >>> call.to_string() + "function(1, 2, x=0.5)" + """ + if not self.parameters: + return f"{self.function_name}()" + + # Separate positional and named arguments + positional_args = [] + named_args = [] + + for name, value in self.parameters.items(): + if name.startswith("arg_"): + # Positional argument + positional_args.append((int(name.split("_")[1]), value)) + else: + # kwargs + named_args.append((name, value)) + + # Sort positional arguments by index + positional_args.sort(key=lambda x: x[0]) + + # Build parameter string + param_parts = [] + + # Add positional arguments + for _, value in positional_args: + param_parts.append(self._value_to_string(value)) + + # Add named arguments + for name, value in named_args: + param_parts.append(f"{name}={self._value_to_string(value)}") + + return f"{self.function_name}({', '.join(param_parts)})" + + def _value_to_string(self, value: Any) -> str: + """ + Convert a value to its string representation for function calls. 
+ + Args: + value: The value to convert + + Returns: + String representation of the value + """ + if isinstance(value, str): + # Quote strings + return f"'{value}'" + elif isinstance(value, (list, tuple)): + # Convert lists/tuples to string representation + items = [self._value_to_string(item) for item in value] + return f"[{', '.join(items)}]" + elif isinstance(value, dict): + # Convert dictionaries to string representation + items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()] + return f"{{{', '.join(items)}}}" + elif isinstance(value, bool): + # Convert booleans to lowercase + return str(value).lower() + elif value is None: + return "None" + else: + # Numbers and other types + return str(value) + + +def parse_function_call(function_string: str, pattern_to_match: list[str] = []) -> List[FunctionCall]: + """ + Parse a function call string and extract all function calls found. + + Args: + function_string: String representation of function calls + + Returns: + List of FunctionCall objects with parsed information + + Examples: + >>> parse_function_call("mobile.wait(seconds=3)") + [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)] + + >>> parse_function_call("mobile. wait(seconds=3)") + [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)] + + >>> parse_function_call("mobile.wait(seconds=3) mobile.home()") + [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='home', parameters={}, ...)] + """ + # Remove any leading/trailing whitespace + function_string = function_string.strip() + + # Pattern to match function calls with parameters + # Matches: function_name(param1=value1, param2=value2, ...) + # Can have any characters before the function call, extracts just the function name + pattern = r'.*?([a-zA-Z_][a-zA-Z0-9_.]*)\(([^)]*)\)' + + matches = re.findall(pattern, function_string) + if not matches: + # No valid function calls found in: {function_string} + return [] + + results = [] + for match in matches: + function_name = match[0] + params_string = match[1] + + if pattern_to_match and all(pattern not in function_name for pattern in pattern_to_match): + continue + + # Parse parameters + parameters = parse_parameters(params_string) + + # Create the original string for this specific function call + original_string = f"{function_name}({params_string})" + + results.append(FunctionCall( + function_name=function_name, + parameters=parameters, + original_string=original_string + )) + + return results + + +def parse_parameters(params_string: str) -> Dict[str, Any]: + """ + Parse parameter string and extract parameter names and values. 
+ + Args: + params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'") + + Returns: + Dictionary mapping parameter names to their values + + Examples: + >>> parse_parameters("x=0.5, y=0.6") + {'x': 0.5, 'y': 0.6} + + >>> parse_parameters("app_name='drupe'") + {'app_name': 'drupe'} + + >>> parse_parameters("'text'") + {'arg_0': 'text'} + + >>> parse_parameters("1, 3, 4") + {'arg_0': 1, 'arg_1': 3, 'arg_2': 4} + + >>> parse_parameters("arg1, arg2, x=0.5") + {'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5} + """ + if not params_string.strip(): + return {} + + parameters = OrderedDict() + + # Split by commas, but be careful with commas inside quotes or brackets + param_parts = split_parameters(params_string) + + positional_index = 0 + + for part in param_parts: + part = part.strip() + if not part: + continue + + # Parse individual parameter + name, value = parse_single_parameter(part) + + # For positional arguments, use index-based naming + if name.startswith("arg_"): + name = f"arg_{positional_index}" + positional_index += 1 + + parameters[name] = value + + return parameters + + +def split_parameters(params_string: str) -> List[str]: + """ + Split parameter string by commas, respecting quotes and brackets. + + Args: + params_string: String containing parameters + + Returns: + List of individual parameter strings + """ + parts = [] + current_part = "" + paren_count = 0 + bracket_count = 0 + brace_count = 0 + in_quotes = False + quote_char = None + + for char in params_string: + if char in ['"', "'"] and (not in_quotes or char == quote_char): + if not in_quotes: + in_quotes = True + quote_char = char + else: + in_quotes = False + quote_char = None + elif not in_quotes: + if char == '(': + paren_count += 1 + elif char == ')': + paren_count -= 1 + elif char == '[': + bracket_count += 1 + elif char == ']': + bracket_count -= 1 + elif char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + elif char == ',' and paren_count == 0 and bracket_count == 0 and brace_count == 0: + parts.append(current_part.strip()) + current_part = "" + continue + + current_part += char + + if current_part.strip(): + parts.append(current_part.strip()) + + return parts + + +def parse_single_parameter(param_string: str) -> Tuple[str, Any]: + """ + Parse a single parameter string into name and value. + + Args: + param_string: String like "x=0.5" or "app_name='drupe'" or just "value" + + Returns: + Tuple of (parameter_name, parameter_value) + + Examples: + >>> parse_single_parameter("x=0.5") + ('x', 0.5) + + >>> parse_single_parameter("app_name='drupe'") + ('app_name', 'drupe') + + >>> parse_single_parameter("'text'") + ('arg_0', 'text') + + >>> parse_single_parameter("123") + ('arg_0', 123) + + >>> parse_single_parameter("3") + ('arg_0', 3) + """ + # Pattern to match parameter name and value + pattern = r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$' + + match = re.match(pattern, param_string) + if match: + # Named parameter + param_name = match.group(1) + param_value_str = match.group(2).strip() + param_value = parse_value(param_value_str) + return param_name, param_value + else: + # Positional parameter - treat as unnamed argument + param_value = parse_value(param_string) + return "arg_0", param_value + + +def parse_value(value_string: str) -> Any: + """ + Parse a value string into appropriate Python type. + + Args: + value_string: String representation of a value + + Returns: + Parsed value (int, float, str, list, etc.) 
+ + Examples: + >>> parse_value("3") + 3 + + >>> parse_value("3.14") + 3.14 + + >>> parse_value("'hello'") + 'hello' + + >>> parse_value("[0.581, 0.898]") + [0.581, 0.898] + """ + value_string = value_string.strip() + + # String values (quoted) + if (value_string.startswith("'") and value_string.endswith("'")) or \ + (value_string.startswith('"') and value_string.endswith('"')): + return value_string[1:-1] + + # List values + if value_string.startswith('[') and value_string.endswith(']'): + return parse_list(value_string) + + # Dictionary values + if value_string.startswith('{') and value_string.endswith('}'): + return parse_dict(value_string) + + # Boolean values + if value_string.lower() in ['true', 'false']: + return value_string.lower() == 'true' + + # None value + if value_string.lower() == 'none': + return None + + # Numeric values + try: + # Try integer first + if '.' not in value_string: + return int(value_string) + else: + return float(value_string) + except ValueError: + # If it's not a number, return as string (remove quotes if present) + if value_string.startswith("'") and value_string.endswith("'"): + return value_string[1:-1] + elif value_string.startswith('"') and value_string.endswith('"'): + return value_string[1:-1] + else: + return value_string + + +def parse_list(list_string: str) -> List[Any]: + """ + Parse a list string into a Python list. + + Args: + list_string: String like "[0.581, 0.898]" + + Returns: + List of parsed values + + Examples: + >>> parse_list("[0.581, 0.898]") + [0.581, 0.898] + """ + # Remove outer brackets + content = list_string[1:-1].strip() + if not content: + return [] + + # Split by commas, respecting nested structures + parts = split_parameters(content) + + return [parse_value(part.strip()) for part in parts] + + +def parse_dict(dict_string: str) -> Dict[str, Any]: + """ + Parse a dictionary string into a Python dict. + + Args: + dict_string: String like "{'key': 'value'}" + + Returns: + Dictionary of parsed key-value pairs + """ + # Remove outer braces + content = dict_string[1:-1].strip() + if not content: + return {} + + # Split by commas, respecting nested structures + parts = split_parameters(content) + + result = {} + for part in parts: + part = part.strip() + if ':' in part: + key_str, value_str = part.split(':', 1) + key = parse_value(key_str.strip()) + value = parse_value(value_str.strip()) + result[key] = value + + return result + + +def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]: + """ + Parse multiple function call strings. + + Args: + function_strings: List of function call strings + + Returns: + List of FunctionCall objects + """ + results = [] + for func_str in function_strings: + try: + result_list = parse_function_call(func_str) + results.extend(result_list) + except Exception as e: + print(f"Warning: Could not parse function call '{func_str}': {e}") + continue + + return results + + +def extract_function_calls_from_text(text: str) -> List[FunctionCall]: + """ + Extract and parse function calls from a text block. 
+ + Args: + text: Text containing function calls + + Returns: + List of FunctionCall objects + """ + # Pattern to find function calls in text + # Matches: function_name(param1=value1, param2=value2) + pattern = r'[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)' + + matches = re.findall(pattern, text) + return parse_multiple_functions(matches) + + +# Example usage and testing +if __name__ == "__main__": + test_cases = [ + "mobile.home()", + "mobile.open_app(app_name='drupe')", + "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", + "mobile.back()", + "mobile.long_press(x=0.799, y=0.911)", + "mobile.terminate(status='success')", + "answer('text')", + "pyautogui.hscroll(page=-0.1)", + "pyautogui.scroll(page=-0.1)", + "pyautogui.scroll(0.13)", + "pyautogui.click(x=0.8102, y=0.9463)", + "pyautogui.hotkey(keys=['ctrl', 'c'])", + "pyautogui.doubleClick()", + "pyautogui.press(keys='enter')", + "pyautogui.press(keys=['enter'])", + "pyautogui.moveTo(x=0.04, y=0.405)", + "pyautogui.write(message='bread buns')", + "pyautogui.dragTo(x=0.8102, y=0.9463)", + "mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", + # Additional test cases for multiple positional arguments + "function(arg1, arg2, arg3)", + "function('hello', 123, x=0.5)", + "function(arg1, arg2, named_param='value')", + "function(1, 2, 3, 4, 5)", + "function('a', 'b', 'c', x=1, y=2)", + ] + + print("Testing function parser:") + print("=" * 50) + + for test_case in test_cases: + try: + results = parse_function_call(test_case) + print(f"✓ {test_case}") + for result in results: + print(f" Function: {result.function_name}") + print(f" Parameters: {result.parameters}") + print() + except Exception as e: + print(f"✗ {test_case}") + print(f" Error: {e}") + print() + + # Test extracting from text + print("Testing text extraction:") + print("=" * 50) + + sample_text = """ + mobile.wait(seconds=3) + mobile.open_app(app_name='drupe') + pyautogui.click(x=0.8102, y=0.9463) + pyautogui.write(message='bread buns') + """ + + extracted = extract_function_calls_from_text(sample_text) + for func_call in extracted: + print(f"Found: {func_call.function_name} with params: {func_call.parameters}") + + # Test reconstruction + print("\nTesting function call reconstruction:") + print("=" * 50) + + reconstruction_tests = [ + "mobile.wait(seconds=3)", + "mobile.home()", + "mobile.open_app(app_name='drupe')", + "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", + "answer('text')", + "pyautogui.scroll(0.13)", + "pyautogui.click(x=0.8102, y=0.9463)", + "pyautogui.hotkey(keys=['ctrl', 'c'])", + "function(1, 2, 3)", + "function('hello', 123, x=0.5, y=0.8)", + "function([1, 3], 'arg2', named_param='value')", + ] + + for test_case in reconstruction_tests: + parsed_list = parse_function_call(test_case) + for parsed in parsed_list: + reconstructed = parsed.to_string() + print(f"Original: {test_case}") + print(f"Reconstructed: {reconstructed}") + print(f"Match: {test_case == reconstructed}") + assert test_case == reconstructed + print() \ No newline at end of file diff --git a/scripts/agents/get_aguvis_data.py b/scripts/agents/get_aguvis_data.py new file mode 100644 index 000000000..25452501b --- /dev/null +++ b/scripts/agents/get_aguvis_data.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python3 +""" +Script to download, process, and upload the aguvis-stage2 dataset. 
+Downloads from huggingface.co/datasets/xlangai/aguvis-stage2 and uploads to smolagents/aguvis-stage-2 +""" + +import re +import gc +import sys +import json +import os +import shutil +import zipfile +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Generator, Callable, Literal +from tqdm import tqdm +from datasets import Dataset, load_dataset, concatenate_datasets +from dotenv import load_dotenv +from huggingface_hub import HfApi, login, snapshot_download +from collections import defaultdict +from PIL import Image +import tarfile +from itertools import islice +import multiprocessing as mp +from multiprocessing import Pool, Manager +from prompts import OS_SYSTEM_PROMPT, MOBILE_SYSTEM_PROMPT +from models import ConversationDataList, ConversationData, ChatMessage, DataRow +from function_parser import parse_function_call +from action_conversion import action_conversion +from pydantic import BaseModel +from config import config_dict_stage_1, config_dict_stage_2, MOBILE_FILE + + +api = HfApi() + + +def authenticate_huggingface(): + """Authenticate with HuggingFace Hub using token.""" + hf_token = os.getenv("HF_TOKEN") + if hf_token: + print("Authenticating with HuggingFace Hub using token...") + login(token=hf_token) + else: + raise ValueError("HF_TOKEN environment variable not set.") + + +def discover_dataset_config(dataset_path: str, config_dict: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Discover dataset configuration by scanning the data directory.""" + dataset_dir = Path(dataset_path) + train_dir = dataset_dir + + if not train_dir.exists(): + raise FileNotFoundError(f"Train directory not found: {train_dir}") + + configs = [] + processed_splits = set() + + # Find all JSON files in the train directory + for config in config_dict: + subset_name = ( + config["json_path"] + .replace(".json", "") + .replace("-l1", "") + .replace("-l2", "") + .replace("-l3", "") + ) + + # Skip if we already processed this split + if subset_name in processed_splits: + continue + + config["subset_name"] = subset_name + configs.append(config) + processed_splits.add(subset_name) + print( + f"Discovered config: {config['subset_name']} -> {config['images_folder']}" + ) + + return configs + + +def download_dataset( + repo_id: str = "xlangai/aguvis-stage2", local_dir: str = "./aguvis_raw" +) -> str: + """Download the dataset using snapshot_download.""" + print(f"Downloading dataset from {repo_id}...") + local_path = snapshot_download( + repo_id=repo_id, local_dir=local_dir, repo_type="dataset" + ) + print(f"Dataset downloaded to: {local_path}") + return local_path + + +def extract_zip_files(dataset_path: str): + """Extract all zip files found in the dataset directory, but only if not already extracted.""" + print("Extracting zip files...") + dataset_dir = Path(dataset_path) + + for zip_file in dataset_dir.rglob("*.zip"): + extract_dir = zip_file.parent / zip_file.stem + if extract_dir.exists() and any(extract_dir.iterdir()): + print( + f"Skipping extraction for {zip_file} (already extracted at {extract_dir})" + ) + continue + + print(f"Extracting: {zip_file}") + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(extract_dir) + print(f"Extracted to: {extract_dir}") + + +def extract_tar_parts_grouped(dataset_path: str): + """ + Finds all .tar.gz.part_* groups, merges them, and extracts them into directories + named after their common prefix. 
+ """ + dataset_dir = Path(dataset_path) + part_files = list(dataset_dir.glob("*.tar.gz.part_*")) + + if not part_files: + print("No split .tar.gz.part_* files found.") + return + + # Group part files by prefix + groups = defaultdict(list) + for part in part_files: + prefix = part.name.split(".tar.gz.part_")[0] + groups[prefix].append(part) + + for prefix, parts in groups.items(): + parts = sorted(parts) # Ensure correct order + merged_tar_path = dataset_dir / f"{prefix}.tar.gz" + extract_dir = dataset_dir / prefix + + if extract_dir.exists() and any(extract_dir.iterdir()): + print( + f"Skipping extraction for '{prefix}' (already extracted at {extract_dir})" + ) + continue + + # Merge parts + CHUNK_SIZE = 1024 * 1024 + print(f"Merging parts for '{prefix}'...") + with open(merged_tar_path, "wb") as outfile: + for part in parts: + print(f" Adding: {part.name}") + with open(part, "rb") as infile: + while chunk := infile.read(CHUNK_SIZE): + outfile.write(chunk) + + print(f"Merged to: {merged_tar_path}") + + # Extract + print(f"Extracting to: {extract_dir}") + with tarfile.open(merged_tar_path, "r:gz") as tar: + tar.extractall(path=extract_dir) + print(f"Done extracting '{prefix}'\n") + + +def check_subset_exists(repo_id: str, subset_name: str) -> bool: + """Check if a subset already exists in the remote dataset.""" + try: + # Try to get dataset info with specific subset + from datasets import get_dataset_config_names + + config_names = get_dataset_config_names(repo_id) + return subset_name in config_names + except Exception as e: + print(f"Could not check if subset exists: {e}") + return False + + +def load_image_from_folder(images_folder: Path, img_path: str) -> Image.Image: + """Load images from the specified folder.""" + full_path = images_folder / img_path + img = Image.open(full_path) + new_img = img.copy() + img.close() + return new_img + + +def convert_to_code_agent_format(messages: list[ChatMessage], json_path: str, reasoning: bool): + for i, message in enumerate(messages): + content = message.content + + if message.role == "system": + if json_path in MOBILE_FILE: + content = MOBILE_SYSTEM_PROMPT + else: + content = OS_SYSTEM_PROMPT + + if message.role == "user": + content = content.replace("\n", "").replace("", "") + + elif message.role == "assistant": + content = ( + content.replace("Action: ", "") + .replace("Observation: ", "") + .replace("Thought: ", "") + ) + if reasoning and i == len(messages) - 1: + content = ( + "\n" + content.strip() + "\n" + ) + elif reasoning: + # TODO: Check if there is always only 2 assistants + content = ( + "\n" + + content.strip() + + "\n\n" + ) + else: + content = content.strip() + + messages[i].content = content + + # Fuse subsequent messages have the same role, merge it + if i > 0 and messages[i].role == messages[i - 1].role: + # Need to fuse both messages + if reasoning: + messages[i - 1].content += messages[i].content + else: + messages[i - 1].content += "\n" + messages[i].content + messages.pop(i) + + return messages + + +def convert_to_chat_format( + data: ConversationData, json_path: str, reasoning: bool +) -> list[ChatMessage]: + """Convert data item to chat template format.""" + # This is a placeholder - you'll need to adapt this based on the actual data structure + # The exact conversion depends on how the original data is structured + chat_messages = data.to_chat_messages() + # mobile = json_path in open("mobile_files.txt", "r").read() + # os = json_path in open("os_files.txt", "r").read() + # if not mobile and not os: + # for message in 
chat_messages: + # if mobile and os: + # break + # if message.role == "assistant": + # if not mobile and "mobile" in message.content: + # with open("mobile_files.txt", "a") as mobile_files: + # mobile_files.write(json_path + "\n") + # mobile = True + # if not os and "pyautogui" in message.content: + # with open("os_files.txt", "a") as os_files: + # os_files.write(json_path + "\n") + # os = True + # Aguvis stage 1 + chat_messages = convert_to_code_agent_format(chat_messages, json_path, reasoning) + return chat_messages + + +def convert_to_new_action_space( + messages: list[ChatMessage], resolution: tuple[int, int], code_format: bool = True +) -> list[ChatMessage]: + regex_match: re.Match | str | None = None + index = -1 + regex = r"\n(.*?)\n" + assistant_msg = [(i, message) for i, message in enumerate(messages) if message.role == "assistant"] + if assistant_msg: + for index, msg in assistant_msg: + + if code_format: + regex_match = re.search(regex, msg.content, re.DOTALL) + else: + regex_match = msg.content + + if regex_match is not None: + function_calls = parse_function_call( + regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, + pattern_to_match=["pyautogui", "mobile", "terminate", "answer"], + ) + + + if len(function_calls) > 0: + + for i, function_call in enumerate(deepcopy(function_calls)): + + if function_call.function_name == "pyautogui.dragTo" and not isinstance(list(function_calls[i].parameters.values())[0], (list, tuple)): + x1, y1 = islice(function_calls[i-1].parameters.values(), 2) + x2, y2 = islice(function_calls[i].parameters.values(), 2) + function_calls[i].parameters = {"from_coord": (x1, y1), "to_coord": (x2, y2)} + function_calls[i].original_string = function_calls[i].to_string() + function_calls.pop(i-1) + + function_calls = action_conversion(function_calls, resolution=resolution) + + new_action_string = "\n".join( + [function_call.to_string() for function_call in function_calls] + ) + messages[index].content = messages[index].content.replace( + regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, new_action_string + ) + + + return messages + + +def process_subset( + config: Dict[str, Any], + dataset_path: str, +) -> tuple[ConversationDataList, Path]: + """Process a single dataset subset.""" + subset_name = config["subset_name"] + + print(f"Processing split: {subset_name}") + + dataset_dir = Path(dataset_path) + images_folder = dataset_dir / config["subset_name"] / config["images_folder"] + + if not images_folder.exists(): + print(f"Images folder not found: {images_folder}") + else: + print(f"Images folder: {images_folder}") + + json_config_path = dataset_dir / config["json_path"] + with open(json_config_path, "r") as f: + data = ConversationDataList.model_validate_json(f.read()) + # data = f.read() + print(f"Added '{json_config_path}'") + + return data, images_folder + + +def row_generator( + data: ConversationDataList, images_folder: Path, json_path: str, reasoning: bool +) -> Generator[Dict[str, Any], None, None]: + conversations: list[ConversationData] = data.root + for item in tqdm(conversations): + # Extract image paths from the data item + try: + # Load images + image = load_image_from_folder(images_folder, item.image) + chat_message = convert_to_chat_format(item, json_path, reasoning) + chat_message = convert_to_new_action_space(chat_message, image.size, code_format=reasoning) + if len(chat_message) == 0: + continue + + row = DataRow.from_chat_messages(chat_message, image, 
source=json_path.split("/")[-1].split(".")[0]) + yield row.model_dump(exclude_none=True) + del image + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error processing item: {e}", item) + continue + + +class DatasetConfig(BaseModel): + huggingface_repo_id: str + local_path: str + config_dict: List[Dict[str, Any]] + smolagents_repo_id: str + reasoning: bool + + +def process_single_config(config: Dict[str, Any], dataset_path: str, smolagents_repo_id: str, reasoning: bool) -> bool: + """Process a single config in a separate process.""" + try: + # Authenticate in this process + authenticate_huggingface() + + print(f"\n{'=' * 50}") + print(f"Processing config: {config}") + + # Check if the subset already exists in the remote dataset + subset_name = config["subset_name"] + # if check_subset_exists(smolagents_repo_id, subset_name): + # print( + # f"Subset '{subset_name}' already exists in {smolagents_repo_id}, skipping processing." + # ) + # return True + + json_path = config["json_path"] + data, image_folder = process_subset(config, dataset_path) + + # Collect all rows first + rows = [] + datasets = [] + for row in row_generator(data, image_folder, json_path, reasoning): + rows.append(row) + if len(rows) > 20000: + print("Creating batch dataset") + dataset = Dataset.from_list(rows) + datasets.append(dataset) + rows = [] + gc.collect() + + if len(rows) > 0: + # Create dataset from collected data + dataset = Dataset.from_list(rows) + datasets.append(dataset) + rows = [] + + dataset_to_push = concatenate_datasets(datasets) + + # Push to hub + dataset_to_push.push_to_hub( + smolagents_repo_id, + # config_name=subset_name, # This sets the subset name + split="train", # This should be "train" not the subset name + ) + + print(f"Processed and uploaded subset: {config['subset_name']}") + + # Force garbage collection to manage memory + gc.collect() + + return True + + except Exception as e: + print(f"Error processing config {config.get('subset_name', 'unknown')}: {e}") + import traceback + traceback.print_exc() + return False + + +def make_dataset_from_original_data(dataset_config: DatasetConfig, max_processes: int | None = None): + """Main function to orchestrate the entire process.""" + load_dotenv(override=True) + + print(f"Starting {dataset_config.smolagents_repo_id} dataset processing...") + + # Step 0: Authenticate with HuggingFace Hub + authenticate_huggingface() + + dataset_path = download_dataset( + dataset_config.huggingface_repo_id, dataset_config.local_path + ) + + # extract_zip_files(dataset_path) + # extract_tar_parts_grouped(dataset_path) + + dataset_configs = discover_dataset_config(dataset_path, dataset_config.config_dict) + converted_repo_id = dataset_config.smolagents_repo_id + reasoning = dataset_config.reasoning + + # Use multiprocessing to process configs in parallel + available_cpus = mp.cpu_count() + if max_processes is None: + max_processes = available_cpus + num_processes = min(max_processes, len(dataset_configs)) + print(f"Using {num_processes} processes (out of {available_cpus} available CPUs) to process {len(dataset_configs)} configs") + + # Prepare arguments for multiprocessing + process_args = [ + (config, dataset_path, converted_repo_id, reasoning) + for config in dataset_configs if config["subset_name"] if config["subset_name"] in ["guiact-web-single"] + ] + + # Process configs in parallel with progress tracking + print(f"Starting parallel processing of {len(dataset_configs)} configs...") + try: + with Pool(processes=num_processes) as pool: + 
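Aside on the 20 000-row batching in process_single_config above: if memory is tight, datasets can also build the Arrow table directly from row_generator instead of accumulating Python lists. A rough sketch under the assumption that the yielded rows are plain dicts (an explicit features schema may be needed for the PIL image column); illustrative only, not part of the patch:

from datasets import Dataset

def build_dataset_streaming(data, images_folder, json_path, reasoning):
    # Dataset.from_generator writes examples out as they are produced.
    return Dataset.from_generator(
        row_generator,
        gen_kwargs={
            "data": data,
            "images_folder": images_folder,
            "json_path": json_path,
            "reasoning": reasoning,
        },
    )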
results = [] + for i, result in enumerate(pool.starmap(process_single_config, process_args)): + results.append(result) + print(f"Completed {i+1}/{len(dataset_configs)} configs") + except Exception as e: + print(f"Multiprocessing failed: {e}") + print("Falling back to sequential processing...") + results = [] + for i, args in enumerate(process_args): + result = process_single_config(*args) + results.append(result) + print(f"Completed {i+1}/{len(dataset_configs)} configs (sequential)") + + # Check results + successful = sum(results) + total = len(dataset_configs) + print(f"\nProcessing complete: {successful}/{total} configs processed successfully") + + if successful < total: + failed_count = total - successful + print(f"Warning: {failed_count} configs failed to process. Check the logs above for details.") + else: + print("All configs processed successfully!") + +# # Cleanup +# print("\nCleaning up temporary files...") +# # shutil.rmtree(dataset_path, ignore_errors=True) +# +# # api.upload_large_folder(folder_path=converted_folder, repo_id="smolagents/aguvis-stage-2", repo_type="dataset") +# +# shutil.rmtree(converted_folder, ignore_errors=True) +# +# print("All done!") + + +if __name__ == "__main__": + # dataset_config_1 = DatasetConfig( + # huggingface_repo_id="xlangai/aguvis-stage1", + # local_path="/fsx/amir_mahla/aguvis_raw_stage_1", + # config_dict=config_dict_stage_1, + # smolagents_repo_id="smolagents/aguvis-stage-1", + # reasoning=False, + # ) + # dataset_config_2 = DatasetConfig( + # huggingface_repo_id="xlangai/aguvis-stage2", + # local_path="/fsx/amir_mahla/aguvis_raw_stage_2", + # config_dict=config_dict_stage_2, + # smolagents_repo_id="smolagents/aguvis-stage-2", + # reasoning=True, + # ) + dataset_config_3 = DatasetConfig( + huggingface_repo_id="xlangai/aguvis-stage2", + local_path="/fsx/amir_mahla/aguvis_raw_stage_2", + config_dict=config_dict_stage_2, + smolagents_repo_id="smolagents/guiact-web-single", + reasoning=True, + ) + # You can specify max_processes to limit the number of parallel processes + # make_dataset_from_original_data(dataset_config_1, max_processes=4) + make_dataset_from_original_data(dataset_config_3, 1) diff --git a/scripts/agents/models.py b/scripts/agents/models.py new file mode 100644 index 000000000..a92366f40 --- /dev/null +++ b/scripts/agents/models.py @@ -0,0 +1,154 @@ +from typing import List, Optional, Literal +from pydantic import BaseModel, Field, RootModel, field_validator, model_validator +from copy import deepcopy +from PIL import Image +from collections import OrderedDict + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system"] + content: str + + @staticmethod + def from_conversation_list(data: list[dict[str, str]]) -> list["ChatMessage"]: + messages = [] + system_added = False + for item in data: + if item["from"] == "system": + if not system_added: + role: Literal["user", "assistant", "system"] = "system" + messages.append(ChatMessage(role=role, content=item["value"])) + system_added = True + elif item["from"] == "human": + role = "user" + messages.append(ChatMessage(role=role, content=item["value"])) + else: + role = "assistant" + messages.append(ChatMessage(role=role, content=item["value"])) + + return messages + + +class ConversationEntry(BaseModel): + from_: Literal["system", "human", "gpt"] = Field(alias="from") + value: str + recipient: Optional[str] = None + end_turn: Optional[bool] = None + + def to_chat_message(self) -> ChatMessage: + if self.from_ == "system": + role: Literal["user", "assistant", "system"] 
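Illustrative usage of ChatMessage.from_conversation_list above on a ShareGPT-style record (the values are made up for the example):

raw = [
    {"from": "system", "value": "You are a helpful GUI agent."},
    {"from": "human", "value": "Click the search button."},
    {"from": "gpt", "value": "click(x=0.42, y=0.17)"},
]
messages = ChatMessage.from_conversation_list(raw)
assert [m.role for m in messages] == ["system", "user", "assistant"]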
= "system" + elif self.from_ == "human": + role = "user" + else: + role = "assistant" + return ChatMessage(role=role, content=self.value) + +class ConversationData(BaseModel): + image: str + conversations: List[ConversationEntry] + recipient: Optional[str] = None + end_turn: Optional[bool] = None + + @field_validator("image", mode="before") + def validate_image(cls, v): + if isinstance(v, list): + if len(v) == 1: + return v[0] + elif len(v) == 2: + return v[1] + else: + raise ValueError("Expected 1 or 2 images, got multiple") + return v + + + def to_chat_messages(self) -> list[ChatMessage]: + return [conversation.to_chat_message() for conversation in self.conversations] + +class ConversationDataList(RootModel[List[ConversationData]]): + + @model_validator(mode="after") + def validate_conversation(self): + new_conversations: dict[str, List[ConversationData]] = {} + + # merge image duplicates + for conversation in self.root: + if conversation.image not in new_conversations: + new_conversations[conversation.image] = [conversation] + else: + new_conversations[conversation.image].append(conversation) + + # delete text duplicates + duplicates = 0 + for data in new_conversations.values(): + if isinstance(data, list): + index_to_pop = set() + for i in range(len(data) - 1): + for j in range(i + 1, len(data)): + if [c1.model_dump() for c1 in data[i].conversations] == [c2.model_dump() for c2 in data[j].conversations]: + if j not in index_to_pop: + duplicates += 1 + index_to_pop.add(j) + for index in sorted(index_to_pop, reverse=True): + data.pop(index) + + # delete text duplicates + new_data = [] + for data in new_conversations.values(): + for i in range(len(data)): + if i == 0: + new_data.append(data[i]) + else: + new_data[-1].conversations.extend(data[i].conversations) + + + self.root = new_data + + return self + +class DataRow(BaseModel): + images: list[Image.Image] + texts: list[OrderedDict[str, str]] + source: str + + class Config: + arbitrary_types_allowed = True + + @classmethod + def from_chat_messages(cls, messages: list[ChatMessage], image: Image.Image, source: str) -> "DataRow": + + system, user, assistant = None, None, None + have_system = any(message.role == "system" for message in messages) + texts: list[OrderedDict[str, str]] = [] + images = [image] + chat_messages: OrderedDict[str, str] = OrderedDict() + for message in messages: + if message.role == "system": + system = message.content + elif message.role == "user": + user = message.content + elif message.role == "assistant": + assistant = message.content + + if have_system and user is not None and assistant is not None and system is not None: + chat_messages["system"] = system + chat_messages["user"] = user + chat_messages["assistant"] = assistant + texts.append(chat_messages) + chat_messages = OrderedDict() + user, assistant = None, None + + elif not have_system and user is not None and assistant is not None: + chat_messages["user"] = user + chat_messages["assistant"] = assistant + texts.append(chat_messages) + chat_messages = OrderedDict() + user, assistant = None, None + + return cls(images=images, texts=texts, source=source) + + def to_model_dump(self) -> dict: + return { + "images": self.images, + "texts": self.texts, + "source": self.source, + } \ No newline at end of file diff --git a/scripts/agents/prompts.py b/scripts/agents/prompts.py new file mode 100644 index 000000000..0593162e2 --- /dev/null +++ b/scripts/agents/prompts.py @@ -0,0 +1,145 @@ +from typing import Literal + +OS_ACTIONS = """ +def final_answer(answer: any) 
-> any: + \"\"\" + Provides a final answer to the given problem. + Args: + answer: The final answer to the problem + \"\"\" + +def move_mouse(self, x: float, y: float) -> str: + \"\"\" + Moves the mouse cursor to the specified coordinates + Args: + x: The x coordinate (horizontal position) + y: The y coordinate (vertical position) + \"\"\" + +def click(x: Optional[float] = None, y: Optional[float] = None) -> str: + \"\"\" + Performs a left-click at the specified normalized coordinates + Args: + x: The x coordinate (horizontal position) + y: The y coordinate (vertical position) + \"\"\" + +def double_click(x: Optional[float] = None, y: Optional[float] = None) -> str: + \"\"\" + Performs a double-click at the specified normalized coordinates + Args: + x: The x coordinate (horizontal position) + y: The y coordinate (vertical position) + \"\"\" + +def type(text: str) -> str: + \"\"\" + Types the specified text at the current cursor position. + Args: + text: The text to type + \"\"\" + +def press(keys: str | list[str]) -> str: + \"\"\" + Presses a keyboard key + Args: + keys: The key or list of keys to press (e.g. "enter", "space", "backspace", "ctrl", etc.). + \"\"\" + +def navigate_back() -> str: + \"\"\" + Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly. + \"\"\" + +def drag(from_coord: list[float], to_coord: list[float]) -> str: + \"\"\" + Clicks [x1, y1], drags mouse to [x2, y2], then release click. + Args: + x1: origin x coordinate + y1: origin y coordinate + x2: end x coordinate + y2: end y coordinate + \"\"\" + +def scroll(direction: Literal["up", "down"] = "down", amount: int = 1) -> str: + \"\"\" + Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus. + Args: + x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates + y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates + direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out. + amount: The amount to scroll. A good amount is 1 or 2. + \"\"\" + +def wait(seconds: float) -> str: + \"\"\" + Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps) + Args: + seconds: Number of seconds to wait, generally 2 is enough. + \"\"\" +""" + +MOBILE_ACTIONS = """ +def navigate_back() -> str: + \"\"\" + Return to home page + \"\"\" + +def open_app(app_name: str) -> str: + \"\"\" + Launches the specified application. + Args: + app_name: the name of the application to launch + \"\"\" + +def swipe(from_coord: list[str], to_coord: list[str]) -> str: + \"\"\" + swipe from 'from_coord' to 'to_coord' + Args: + from_coord: origin coordinates + to_coord: end coordinates + \"\"\" + +def long_press(x: int, y: int) -> str: + \"\"\" + Performs a long-press at the specified coordinates + Args: + x: The x coordinate (horizontal position) + y: The y coordinate (vertical position) + \"\"\" +""" + +OS_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls. + +For each step: + • First, to express the thought process guiding your next action and the reasoning behind it. 
+ • Then, use to perform the action. it will be executed in a stateful environment. + +The following functions are exposed to the Python interpreter: + +{OS_ACTIONS} + + +The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. +""" + +MOBILE_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls. + +For each step: + • First, to express the thought process guiding your next action and the reasoning behind it. + • Then, use to perform the action. it will be executed in a stateful environment. + +The following functions are exposed to the Python interpreter: + + +# OS ACTIONS + +{OS_ACTIONS} + +# MOBILE ACTIONS + +{MOBILE_ACTIONS} + + +The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. +""" \ No newline at end of file diff --git a/scripts/agents/qwenvl_collator.py b/scripts/agents/qwenvl_collator.py new file mode 100644 index 000000000..e28134afa --- /dev/null +++ b/scripts/agents/qwenvl_collator.py @@ -0,0 +1,153 @@ +from PIL import Image +from scripts.agents.function_parser import parse_function_call + +from qwen_vl_utils import smart_resize + +def resize_images_in_messages(batch_messages, script_args) -> list[Image.Image]: + + min_pixels = script_args.image_resize["min_pixels"] + max_pixels = script_args.image_resize["max_pixels"] + factor = script_args.image_resize["factor"] + + all_image_inputs = [] + for messages in batch_messages: + + old_image = messages[1]["content"][0]["image"] + resized_height, resized_width = smart_resize( + old_image.height, + old_image.width, + factor=factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + new_image = old_image.resize((resized_width, resized_height)) + messages[1]["content"][0]["image"] = new_image + + function_calls = parse_function_call(messages[2]["content"]) + old_function_call_strings = [ + function_call.to_string() for function_call in function_calls + ] + for function_call, old_function_call_string in zip(function_calls, old_function_call_strings): + if function_call.function_name in [ + "click", + "long_press", + "double_click", + "move_mouse", + ]: + function_call.parameters["arg_0"] = ( + int(function_call.parameters["arg_0"] + / old_image.width + * new_image.width) + ) + function_call.parameters["arg_1"] = ( + int(function_call.parameters["arg_1"] + / old_image.height + * new_image.height) + ) + elif function_call.function_name in ["swipe", "drag"]: + function_call.parameters["arg_0"] = ( + int(function_call.parameters["arg_0"][0] + / old_image.width + * new_image.width), + int(function_call.parameters["arg_0"][1] + / old_image.height + * new_image.height) + ) + function_call.parameters["arg_1"] = ( + int(function_call.parameters["arg_1"][0] + / old_image.width + * new_image.width), + int(function_call.parameters["arg_1"][1] + / old_image.height + * new_image.height) + ) + messages[2]["content"] = messages[2]["content"].replace(old_function_call_string, function_call.to_string()) + + all_image_inputs.append([new_image]) + return all_image_inputs + +def create_vlm_collate_fn(processor, script_args): + """Optimized collate function for VLM training that masks system prompt tokens.""" + + def collate_fn(examples: list[dict[str, str | Image.Image]]): + batch_messages = [] + system_prompts = [] + user_prompts = [] + for example in examples: + system = example["system"] + user = 
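The coordinate handling in resize_images_in_messages above boils down to scaling pixel positions by the ratio between the original image and the smart_resize'd image. A compact, hypothetical helper that captures the same arithmetic for a single point (the default factor/min_pixels/max_pixels mirror the recipe configs; not part of the patch):

from qwen_vl_utils import smart_resize

def rescale_point(x: float, y: float, old_w: int, old_h: int,
                  factor: int = 28, min_pixels: int = 200_704, max_pixels: int = 1_003_520) -> tuple[int, int]:
    # smart_resize returns (height, width) of the image the model will actually see.
    new_h, new_w = smart_resize(old_h, old_w, factor=factor,
                                min_pixels=min_pixels, max_pixels=max_pixels)
    return int(x / old_w * new_w), int(y / old_h * new_h)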
example["user"] + assistant = example["assistant"] + image = example["image"] + + system_prompts.append(system) + user_prompts.append(user) + batch_messages.append( + [ + {"role": "system", "content": system}, + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": user}, + ], + }, + {"role": "assistant", "content": assistant}, + ] + ) + + all_image_inputs = [] + if script_args.image_resize is not None: + all_image_inputs = resize_images_in_messages(batch_messages, script_args) + + + texts = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + for messages in batch_messages + ] + + batch = processor( + text=texts, + images=all_image_inputs if all_image_inputs else None, + padding=True, + return_tensors="pt", + max_length=4096, + ) + + input_ids = batch["input_ids"] + labels = input_ids.clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + + if hasattr(processor, "image_token"): + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) + if image_token_id is not None: + labels[labels == image_token_id] = -100 + else: + raise ValueError("Processor does not have image_token") + + system_encodings = processor.tokenizer( + system_prompts, add_special_tokens=False, padding=False + )["input_ids"] + + user_encodings = processor.tokenizer( + user_prompts, add_special_tokens=False, padding=False + )["input_ids"] + + for encodings in [system_encodings, user_encodings]: + for i, system_ids in enumerate(encodings): + if input_ids[i, : len(system_ids)].tolist() == system_ids: + labels[i, : len(system_ids)] = -100 + else: + seq = input_ids[i].tolist() + for j in range(len(seq) - len(system_ids) + 1): + if seq[j : j + len(system_ids)] == system_ids: + labels[i, j : j + len(system_ids)] = -100 + break # early exit + + batch["labels"] = labels + return batch + + return collate_fn diff --git a/scripts/agents/smolvlm2_collator.py b/scripts/agents/smolvlm2_collator.py new file mode 100644 index 000000000..41ed1e439 --- /dev/null +++ b/scripts/agents/smolvlm2_collator.py @@ -0,0 +1,224 @@ +from PIL import Image +from scripts.agents.function_parser import parse_function_call +import numpy as np +from transformers.models.smolvlm.image_processing_smolvlm import ( + get_resize_output_image_size, +) +from transformers.image_utils import ChannelDimension + + +def transform_messages( + batch_messages, + image_resize: dict[str, int | bool], +) -> list[list[Image.Image]]: + + resolution_max_side = image_resize["resolution_max_side"] if "resolution_max_side" in image_resize else None + to_pixel_coordinates = image_resize["to_pixel_coordinates"] if "to_pixel_coordinates" in image_resize else False + + if not to_pixel_coordinates and resolution_max_side is None: + return batch_messages + + all_image_inputs: list[list[Image.Image]] = [] + for messages in batch_messages: + new_image = None + for i in range(len(messages)): + if "image" in messages[i]["content"][0]: + old_image = messages[i]["content"][0]["image"] + + if resolution_max_side is not None: + resized_height, resized_width = get_resize_output_image_size( + np.array(old_image), + resolution_max_side=resolution_max_side, + input_data_format=ChannelDimension.LAST, + ) + new_image = old_image.resize((resized_width, resized_height)) + else: + resized_height, resized_width = old_image.height, old_image.width + new_image = old_image + + messages[i]["content"][0]["image"] = new_image + all_image_inputs.append([new_image]) + + if 
messages[i]["role"] == "assistant" and to_pixel_coordinates: + assert new_image is not None, "new_image is None" + + function_calls = parse_function_call(messages[i]["content"][0]["text"]) + old_function_call_strings = [ + function_call.to_string() for function_call in function_calls + ] + for function_call, old_function_call_string in zip( + function_calls, old_function_call_strings + ): + if function_call.function_name in [ + "click", + "long_press", + "double_click", + "move_mouse", + ]: + function_call.parameters["x"] = int( + function_call.parameters["x"] * new_image.width + ) + function_call.parameters["y"] = int( + function_call.parameters["y"] * new_image.height + ) + elif function_call.function_name in ["swipe", "drag"]: + function_call.parameters["from_coord"] = ( + int(function_call.parameters["from_coord"][0] * new_image.width), + int( + function_call.parameters["from_coord"][1] * new_image.height + ), + ) + function_call.parameters["to_coord"] = ( + int(function_call.parameters["to_coord"][0] * new_image.width), + int( + function_call.parameters["to_coord"][1] * new_image.height + ), + ) + messages[i]["content"][0]["text"] = messages[i]["content"][0][ + "text" + ].replace(old_function_call_string, function_call.to_string()) + + return all_image_inputs + + +def create_vlm_collate_fn(processor, training_args, script_args): + """Optimized collate function for VLM training that masks system prompt tokens.""" + + def collate_fn(examples: list[dict[str, list | str | Image.Image]]): + batch_messages: list[list[dict[str, list | str | Image.Image]]] = [] + assistant_messages: list[list[str]] = [] + all_image_inputs: list[list[Image.Image]] = [] + for example in examples: + images: list[Image.Image] = example["images"] + is_first_user = True + sample: list[dict[str, list | str | Image.Image]] = [] + assistant: list[str] = [] + for text in example["texts"]: + if "system" in text.keys(): + sample.append( + { + "role": "system", + "content": [{"type": "text", "text": text["system"]}], + } + ) + + if is_first_user: + sample.append( + { + "role": "user", + "content": [ + {"type": "image", "image": images[0]}, + {"type": "text", "text": text["user"]}, + ], + } + ) + is_first_user = False + else: + sample.append( + { + "role": "user", + "content": [ + {"type": "text", "text": text["user"]}, + ], + } + ) + + sample.append( + { + "role": "assistant", + "content": [{"type": "text", "text": "\n" + text["assistant"]}], + } + ) + assistant.append(text["assistant"] + "") + + batch_messages.append(sample) + assistant_messages.append(assistant) + all_image_inputs.append(images) + + if script_args.image_resize is not None and "to_pixel_coordinates" in script_args.image_resize and script_args.image_resize["to_pixel_coordinates"]: + all_image_inputs = transform_messages( + batch_messages, + image_resize=script_args.image_resize, + ) + + + texts = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + for messages in batch_messages + ] + + batch = processor( + text=texts, + images=all_image_inputs if all_image_inputs else None, + max_length=training_args.max_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + + input_ids = batch["input_ids"] + labels = input_ids.clone() + + assistant_encodings = [ + processor.tokenizer( + assistant_message, add_special_tokens=False, padding=False + )["input_ids"] + for assistant_message in assistant_messages + ] + + # Mask out all except the assistant messages + for i, assistant_ids_list in 
enumerate(assistant_encodings): + seq = input_ids[i].tolist() + assistant_positions: list[int] = [] + for ids in assistant_ids_list: + start_pos = 0 + while start_pos < len(seq) - len(ids) + 1: + found = False + for j in range(start_pos, len(seq) - len(ids) + 1): + if seq[j : j + len(ids)] == ids: + assistant_positions.extend(range(j, j + len(ids))) + start_pos = j + len(ids) + found = True + break + if not found: + break + + for pos in range(len(seq)): + if pos not in assistant_positions: + labels[i, pos] = -100 + + + batch["labels"] = labels + return batch + + return collate_fn + + +if __name__ == "__main__": + from transformers import AutoProcessor + from datasets import load_dataset + + class ScriptArguments: + image_resize = None + + processor = AutoProcessor.from_pretrained( + "HuggingFaceTB/SmolVLM2-2.2B-Instruct" + ) + processor.image_processor.size = {"longest_edge": 384} + collate_fn = create_vlm_collate_fn(processor, script_args=ScriptArguments) + max_length = [] + for dataset_name in ['ricosca']: + dataset_max_length = 0 + data = load_dataset("smolagents/aguvis-stage-1", dataset_name, split="train") + print("processing", dataset_name) + for example in data: + batch = collate_fn([example]) + dataset_max_length = max(dataset_max_length, batch["input_ids"].shape[1]) + print("dataset_max_length", dataset_name, dataset_max_length) + max_length.append(dataset_max_length) + + print(max_length) + print("max_length", max(max_length)) + open("max_length_384_phase_1.txt", "a").write(str(max(max_length))) diff --git a/scripts/agents/smolvlm_inference.py b/scripts/agents/smolvlm_inference.py new file mode 100644 index 000000000..f9d5710b4 --- /dev/null +++ b/scripts/agents/smolvlm_inference.py @@ -0,0 +1,54 @@ +import torch +from transformers import AutoModelForImageTextToText, AutoProcessor +from transformers.models.smolvlm.image_processing_smolvlm import SmolVLMImageProcessor + + +class TransformersModel: + def __init__(self, model_id: str, to_device: str = "cuda"): + self.model_id = model_id + self.processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct") + self.processor.image_processor.size = {"longest_edge": 3 * 384} + self.model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(to_device) + + def generate(self, messages: list[dict], **kwargs): + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(self.model.device, dtype=torch.bfloat16) + generated_ids = self.model.generate(**inputs, **kwargs) + return self.processor.batch_decode( + generated_ids[:, len(inputs["input_ids"][0]) :], skip_special_tokens=True + )[0] + + +if __name__ == "__main__": + from PIL import Image + + model = TransformersModel( + model_id="/fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152/checkpoint-800", + to_device="cuda:0", + ) + + image = Image.open("/admin/home/amir_mahla/screensuite/examples/sample_image.png") + + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image, + }, + { + "type": "text", + "text": "Given the screenshot, and the instruction, output a click that completes the instruction or targets the given element (always target the center of the element).\n\nOutput the click position as follows:\n\n(thought process)click(x, y)\nWith x the number of pixels from the left edge and y the number of pixels from the top edge.\n\nNow write the click needed to complete the 
instruction:\nInstruction: view more information about bomber\n", + }, + ], + }, + ] + + + print(model.generate(messages, max_new_tokens=128)) \ No newline at end of file diff --git a/setup.py b/setup.py index a88508b94..0229fa2ff 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,8 @@ "async-lru>=2.0.5", "aiofiles>=24.1.0", "pandas>=2.2.3", + "qwen-vl-utils>=0.1.0", + "setuptools>=80.9.0", ] # this is a lookup table with items like: diff --git a/slurm/train.slurm b/slurm/train.slurm index 15a70d62c..b222870eb 100644 --- a/slurm/train.slurm +++ b/slurm/train.slurm @@ -33,8 +33,8 @@ START_TIME=$(date +%s) echo "START TIME: $(date)" # Refresh Weka on h4 cache -echo "Refreshing Weka filesystem..." -find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch +# echo "Refreshing Weka filesystem..." +# find -L /fsx/${USER}/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch # Default values MODEL="" diff --git a/src/open_r1/configs.py b/src/open_r1/configs.py index ddb6e53b0..0d6988ccf 100644 --- a/src/open_r1/configs.py +++ b/src/open_r1/configs.py @@ -68,19 +68,51 @@ class ScriptArguments(trl.ScriptArguments): # Override the dataset_name to make it optional dataset_name: Optional[str] = field( - default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."} + default=None, + metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."}, ) dataset_mixture: Optional[dict[str, Any]] = field( default=None, - metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."}, + metadata={ + "help": "Configuration for creating dataset mixtures with advanced options like shuffling." + }, + ) + single_gpu: bool = field( + default=False, + metadata={ + "help": "Force training on single GPU only, disabling distributed training." + }, + ) + + image_resize: Optional[dict[str, int]] = field( + default=None, + metadata={"help": "Resize the image to the given minimum and maximum pixels."}, ) def __post_init__(self): if self.dataset_name is None and self.dataset_mixture is None: - raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided") + raise ValueError( + "Either `dataset_name` or `dataset_mixture` must be provided" + ) + + if self.image_resize is not None: + if ( + not isinstance(self.image_resize, dict) + # or "min_pixels" not in self.image_resize + # or "max_pixels" not in self.image_resize + # or "factor" not in self.image_resize + or "to_pixel_coordinates" not in self.image_resize + or "resolution_max_side" not in self.image_resize + ): + raise ValueError( + f"image_resize must be a dictionary with a 'min_pixels', 'max_pixels' and 'factor' key. {self.image_resize}" + ) if self.dataset_mixture is not None: - if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture: + if ( + not isinstance(self.dataset_mixture, dict) + or "datasets" not in self.dataset_mixture + ): raise ValueError( "dataset_mixture must be a dictionary with a 'datasets' key. 
" "Expected format: {'datasets': [...], 'seed': int}" @@ -110,7 +142,11 @@ def __post_init__(self): ) # Check that column names are consistent across all dataset configs - columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None] + columns_sets = [ + set(dataset.columns) + for dataset in datasets_list + if dataset.columns is not None + ] if columns_sets: first_columns = columns_sets[0] if not all(columns == first_columns for columns in columns_sets): @@ -135,13 +171,21 @@ class GRPOConfig(trl.GRPOConfig): default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}, ) - chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) + chat_template: Optional[str] = field( + default=None, metadata={"help": "The chat template to use."} + ) hub_model_revision: Optional[str] = field( default="main", metadata={"help": "The Hub model branch to push the model to."} ) - num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."}) - overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) - push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) + num_completions_to_print: int = field( + default=0, metadata={"help": "Number of completions to print."} + ) + overwrite_hub_revision: bool = field( + default=False, metadata={"help": "Whether to overwrite the Hub revision."} + ) + push_to_hub_revision: bool = field( + default=False, metadata={"help": "Whether to push to a Hub revision/branch."} + ) system_prompt: Optional[str] = field( default=None, metadata={"help": "The optional system prompt to use."}, @@ -149,7 +193,9 @@ class GRPOConfig(trl.GRPOConfig): wandb_log_unique_prompts: bool = field( default=True, metadata={ - "help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.") + "help": ( + "Whether to log the unique prompts to wandb. This will create a new run for each unique prompt." 
+ ) }, ) wandb_entity: Optional[str] = field( @@ -180,17 +226,27 @@ class SFTConfig(trl.SFTConfig): default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}, ) - chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) + chat_template: Optional[str] = field( + default=None, metadata={"help": "The chat template to use."} + ) system_prompt: Optional[str] = field( default=None, metadata={"help": "The optional system prompt to use for benchmarking."}, ) + vision_model: bool = field( + default=False, + metadata={"help": "Whether this is a vision-language model training."}, + ) hub_model_revision: Optional[str] = field( default="main", metadata={"help": "The Hub model branch to push the model to."}, ) - overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) - push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) + overwrite_hub_revision: bool = field( + default=False, metadata={"help": "Whether to overwrite the Hub revision."} + ) + push_to_hub_revision: bool = field( + default=False, metadata={"help": "Whether to push to a Hub revision/branch."} + ) wandb_entity: Optional[str] = field( default=None, metadata={"help": ("The entity to store runs under.")}, @@ -263,7 +319,9 @@ class GRPOScriptArguments(ScriptArguments): ) repetition_max_penalty: float = field( default=-1.0, - metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"}, + metadata={ + "help": "Maximum (negative) penalty for for repetition penalty reward" + }, ) code_language: str = field( default="python", @@ -281,7 +339,9 @@ class GRPOScriptArguments(ScriptArguments): ) code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field( default="weighted_sum", - metadata={"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."}, + metadata={ + "help": "use fraction of passed test cases as reward. If false, use 0/1 scoring." + }, ) parallel_code_exec_per_proc: int = field( default=2, diff --git a/src/open_r1/sft.py b/src/open_r1/sft.py index c11c023ca..c08d09858 100644 --- a/src/open_r1/sft.py +++ b/src/open_r1/sft.py @@ -13,18 +13,18 @@ # limitations under the License. """ -Supervised fine-tuning script for decoder language models. +Supervised fine-tuning script for decoder language models and vision-language models. 
Usage: # One 1 node of 8 x H100s accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \ - --model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \ - --dataset_name open-r1/Mixture-of-Thoughts \ + --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset_name smolagents/gaia-traces \ + --num_train_epochs 1 \ --dataset_config all \ --eos_token '<|im_end|>' \ --learning_rate 4.0e-5 \ - --num_train_epochs 5 \ --max_seq_length 32768 \ --per_device_train_batch_size 2 \ --gradient_checkpointing \ @@ -39,20 +39,41 @@ import datasets import transformers -from transformers import set_seed +from transformers import ( + set_seed, + AutoModelForVision2Seq, + AutoProcessor, + LlavaForConditionalGeneration, +) from transformers.trainer_utils import get_last_checkpoint +from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format from open_r1.configs import ScriptArguments, SFTConfig -from open_r1.utils import get_dataset, get_model, get_tokenizer +from open_r1.utils import get_dataset, get_model, get_tokenizer, get_processor from open_r1.utils.callbacks import get_callbacks from open_r1.utils.wandb_logging import init_wandb_training -from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format - +from PIL import Image +from transformers import Qwen2VLProcessor +from typing import Any +from scripts.agents.function_parser import parse_function_call +from scripts.agents.smolvlm2_collator import create_vlm_collate_fn logger = logging.getLogger(__name__) +from dotenv import load_dotenv + +load_dotenv() + def main(script_args, training_args, model_args): + # Force single GPU mode if requested + # if hasattr(script_args, 'single_gpu') and script_args.single_gpu: + # logger.info("Single GPU mode requested - setting CUDA_VISIBLE_DEVICES=0") + # # Disable distributed training + # os.environ["CUDA_VISIBLE_DEVICES"] = "0" + # training_args.local_rank = -1 + # training_args.ddp_backend = None + set_seed(training_args.seed) ############### @@ -85,15 +106,42 @@ def main(script_args, training_args, model_args): init_wandb_training(training_args) ###################################### - # Load dataset, tokenizer, and model # + # Load dataset, processor/tokenizer, and model # ###################################### dataset = get_dataset(script_args) - tokenizer = get_tokenizer(model_args, training_args) - model = get_model(model_args, training_args) - if tokenizer.chat_template is None: - logger.info("No chat template provided, defaulting to ChatML.") - model, tokenizer = setup_chat_format(model, tokenizer, format="chatml") + if training_args.vision_model: + logger.info("Setting up vision-language model training") + + # Set VLM-specific training arguments (following TRL reference) + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + training_args.ddp_find_unused_parameters = True + + # Load processor and model for VLM + processor = get_processor(model_args, training_args, script_args) + model = get_model( + model_args, training_args + ) # This should return AutoModelForVision2Seq + data_collator = create_vlm_collate_fn(processor, training_args, script_args) + processing_class = processor.tokenizer + model_tags = ["open-r1", "vision-language", "vlm"] + + else: + logger.info("Setting up text-only model training") + + # Load tokenizer and model for text-only + tokenizer = get_tokenizer(model_args, 
training_args) + model = get_model(model_args, training_args) + + if tokenizer.chat_template is None: + logger.info("No chat template provided, defaulting to ChatML.") + model, tokenizer = setup_chat_format(model, tokenizer, format="chatml") + + data_collator = None # Use default + processing_class = tokenizer + model_tags = ["open-r1"] ############################ # Initialize the SFT Trainer @@ -101,9 +149,14 @@ def main(script_args, training_args, model_args): trainer = SFTTrainer( model=model, args=training_args, + data_collator=data_collator, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None), - processing_class=tokenizer, + eval_dataset=( + dataset[script_args.dataset_test_split] + if training_args.eval_strategy != "no" + else None + ), + processing_class=processing_class, peft_config=get_peft_config(model_args), callbacks=get_callbacks(training_args, model_args), ) @@ -128,16 +181,17 @@ def main(script_args, training_args, model_args): # Save model and create model card ################################## logger.info("*** Save model ***") - # Align the model's generation config with the tokenizer's eos token - # to avoid unbounded generation in the transformers `pipeline()` function - trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id trainer.save_model(training_args.output_dir) logger.info(f"Model saved to {training_args.output_dir}") + try: + processor.save_pretrained(training_args.output_dir) + except Exception as e: + logger.error(f"Error saving processor: {e}") # Save everything else on main process kwargs = { "dataset_name": script_args.dataset_name, - "tags": ["open-r1"], + "tags": model_tags, } if trainer.accelerator.is_main_process: trainer.create_model_card(**kwargs) @@ -160,7 +214,10 @@ def main(script_args, training_args, model_args): ############# if training_args.push_to_hub: logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) + trainer.push_to_hub(**kwargs, token=os.getenv("HF_TOKEN")) + # Also push processor for VLM models + if training_args.vision_model and trainer.accelerator.is_main_process: + processor.push_to_hub(training_args.hub_model_id) if __name__ == "__main__": diff --git a/src/open_r1/utils/__init__.py b/src/open_r1/utils/__init__.py index d3b84a99d..5c247bd02 100644 --- a/src/open_r1/utils/__init__.py +++ b/src/open_r1/utils/__init__.py @@ -1,6 +1,6 @@ from .data import get_dataset from .import_utils import is_e2b_available, is_morph_available -from .model_utils import get_model, get_tokenizer +from .model_utils import get_model, get_tokenizer, get_processor -__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"] +__all__ = ["get_tokenizer", "get_processor", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"] diff --git a/src/open_r1/utils/model_utils.py b/src/open_r1/utils/model_utils.py index 8191c17ea..4dd103df5 100644 --- a/src/open_r1/utils/model_utils.py +++ b/src/open_r1/utils/model_utils.py @@ -1,12 +1,20 @@ import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedTokenizer, + AutoProcessor, + AutoModelForImageTextToText, +) from trl import ModelConfig, get_kbit_device_map, get_quantization_config -from ..configs import GRPOConfig, SFTConfig +from ..configs import GRPOConfig, SFTConfig, ScriptArguments -def 
get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer: +def get_tokenizer( + model_args: ModelConfig, training_args: SFTConfig | GRPOConfig +) -> PreTrainedTokenizer: """Get the tokenizer for the model.""" tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, @@ -20,10 +28,42 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig return tokenizer -def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM: - """Get the model""" +def get_processor( + model_args: ModelConfig, + training_args: SFTConfig | GRPOConfig, + script_args: ScriptArguments, +) -> AutoProcessor: + """Get the processor for VLM models.""" + + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + + # Set the image processor resize size + if script_args.image_resize is not None and "resolution_max_side" in script_args.image_resize: + processor.image_processor.size = { + "longest_edge": script_args.image_resize["resolution_max_side"] + } + if hasattr(processor, "tokenizer"): + processor.tokenizer.truncation_side = "right" + processor.tokenizer.padding_side = "right" + + if training_args.chat_template is not None: + processor.chat_template = training_args.chat_template + + return processor + + +def get_model( + model_args: ModelConfig, training_args: SFTConfig | GRPOConfig +) -> AutoModelForCausalLM | AutoModelForImageTextToText: + """Get the model - supports both text-only and vision-language models""" torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) ) quantization_config = get_quantization_config(model_args) model_kwargs = dict( @@ -35,8 +75,19 @@ def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - **model_kwargs, - ) + + # Check if this is a VLM model using the explicit flag + if hasattr(training_args, "vision_model") and training_args.vision_model: + # Load as vision-language model + model = AutoModelForImageTextToText.from_pretrained( + model_args.model_name_or_path, + **model_kwargs, + ) + else: + # Load as text-only model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + **model_kwargs, + ) + return model diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/slow/test_code_reward.py b/tests/slow/test_code_reward.py deleted file mode 100644 index 8718eb35a..000000000 --- a/tests/slow/test_code_reward.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import unittest - -from datasets import load_dataset - -from e2b_code_interpreter.models import Execution, ExecutionError -from open_r1.rewards import code_reward, ioi_code_reward -from open_r1.utils.routed_morph import RoutedMorphSandbox -from open_r1.utils.routed_sandbox import RoutedSandbox - - -class TestCodeRewards(unittest.TestCase): - def test_python_code_reward(self): - # requires E2B, see the README.md file - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - NUM_SAMPLES = 20 - samples = code_dataset["train"].select(range(NUM_SAMPLES)) - test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples] - reward_kwargs = {"verification_info": [sample["verification_info"] for sample in samples]} - rewards = code_reward(test_completions, **reward_kwargs) - print(rewards) - assert rewards == [1.0] * NUM_SAMPLES - - def test_e2b_router(self): - # run router locally: python scripts/e2b_router.py - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - NUM_SAMPLES = 128 - samples = code_dataset["train"].select(range(NUM_SAMPLES)) - test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples] - reward_kwargs = {"verification_info": [sample["verification_info"] for sample in samples]} - rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs) - print(rewards) - assert rewards == [1.0] * NUM_SAMPLES - - def test_e2b_router_parallel(self): - # run router locally: python scripts/e2b_router.py - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - - BATCH_SIZE = 32 - NUM_SAMPLES = 256 - - def batch_code_reward(examples): - test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]] - reward_kwargs = { - "verification_info": [verification_info for verification_info in examples["verification_info"]] - } - rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs) - assert rewards == [1.0] * BATCH_SIZE - return examples - - code_dataset = code_dataset["train"].select(range(NUM_SAMPLES)) - code_dataset = code_dataset.map( - batch_code_reward, - batched=True, - batch_size=BATCH_SIZE, - num_proc=4, - load_from_cache_file=False, - ) - - def test_ioi_code_reward(self): - # This slow test case requires spinning up a bunch (I tested with ~64) of piston workers, see docs here - # slurm/piston/README.md - code_dataset = load_dataset("open-r1/ioi-reward-test-dataset") - NUM_SAMPLES = 16 - samples = code_dataset["train"].select(range(NUM_SAMPLES)) - test_completions = [[{"content": f"```cpp\n{sample['sample_solution']}```"}] for sample in samples] - keys = [key for key in samples[0] if key not in ["prompt", "completion"]] - reward_kwargs = {key: [example[key] for example in samples] for key in keys} - rewards = ioi_code_reward(test_completions, **reward_kwargs) - print(rewards) - assert rewards == [1.0] * NUM_SAMPLES - - def test_e2b_router_run_code_success(self): - # run router locally: python scripts/e2b_router.py - routed_sandbox = RoutedSandbox(router_url="localhost:8000") - scripts = [ - "print('hello from integration test')", - "result = 2 + 2\nprint(result)", - ] - - results = routed_sandbox.run_code(scripts) - - assert len(results) == 2 - - for result in results: - assert 
isinstance(result, Execution) - # assert result.exit_code == 0 - assert result.error is None - assert "hello" in result.logs["stdout"][0] or "4" in result.logs["stdout"][0] - - def test_e2b_router_run_code_with_error(self): - # run router locally: python scripts/e2b_router.py - - routed_sandbox = RoutedSandbox(router_url="localhost:8000") - scripts = ["print('this is fine')", "print('unterminated string"] - - results = routed_sandbox.run_code(scripts) - - assert len(results) == 2 - - # First one should be okay - # assert results[0].exit_code == 0 # Execution object has no attribute 'exit_code' - assert results[0].error is None - assert "this is fine" in results[0].logs["stdout"][0] - - # Second one should have a syntax error - - # assert results[1].exit_code != 0 # Execution object has no attribute 'exit_code' - assert results[1].error is not None - assert isinstance(results[1].error, ExecutionError) - assert "SyntaxError" in results[1].error.name - - def test_python_code_reward_morph(self): - # requires MorphCloud, see the README.md file - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - NUM_SAMPLES = 20 - samples = code_dataset["train"].select(range(NUM_SAMPLES)) - test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples] - reward_kwargs = { - "verification_info": [sample["verification_info"] for sample in samples], - "provider_type": "morph", - } - rewards = code_reward(test_completions, **reward_kwargs) - print(rewards) - assert rewards == [1.0] * NUM_SAMPLES - - def test_morph_router(self): - # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20 - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - NUM_SAMPLES = 32 - samples = code_dataset["train"].select(range(NUM_SAMPLES)) - test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples] - reward_kwargs = { - "verification_info": [sample["verification_info"] for sample in samples], - "provider_type": "morph", - "morph_router_url": "0.0.0.0:8001", - } - rewards = code_reward(test_completions, **reward_kwargs) - print(rewards) - assert rewards == [1.0] * NUM_SAMPLES - - def test_morph_router_parallel(self): - # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20 - code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled") - - BATCH_SIZE = 32 - NUM_SAMPLES = 256 - - def batch_code_reward(examples): - test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]] - reward_kwargs = { - "verification_info": [verification_info for verification_info in examples["verification_info"]], - "provider_type": "morph", - "morph_router_url": "0.0.0.0:8001", - } - rewards = code_reward(test_completions, **reward_kwargs) - assert rewards == [1.0] * BATCH_SIZE - return examples - - code_dataset = code_dataset["train"].select(range(NUM_SAMPLES)) - code_dataset = code_dataset.map( - batch_code_reward, - batched=True, - batch_size=BATCH_SIZE, - num_proc=4, - load_from_cache_file=False, - ) - - def test_morph_router_run_code_success(self): - # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20 - - routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001") - scripts = [ - "print('hello from morph integration test')", - "result = 2 + 2\nprint(result)", - ] - - results = routed_sandbox.run_code(scripts) - - 
assert len(results) == 2 - - for result in results: - assert result.exception_str is None - assert "hello" in result.text or "4" in result.text - - def test_morph_router_run_code_with_error(self): - # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20 - - routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001") - scripts = ["print('this is fine with morph')", "print('unterminated string"] - - results = routed_sandbox.run_code(scripts) - - assert len(results) == 2 - - # First one should be okay - assert results[0].exception_str is None - assert "this is fine with morph" in results[0].text - - # Second one should have a syntax error - assert "SyntaxError" in results[1].text - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_rewards.py b/tests/test_rewards.py deleted file mode 100644 index 03ac517c9..000000000 --- a/tests/test_rewards.py +++ /dev/null @@ -1,568 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import unittest - -from dotenv import load_dotenv -from open_r1.configs import GRPOScriptArguments -from open_r1.rewards import ( - accuracy_reward, - format_reward, - get_code_format_reward, - get_cosine_scaled_reward, - get_repetition_penalty_reward, - get_reward_funcs, - get_soft_overlong_punishment, - len_reward, - reasoning_steps_reward, - tag_count_reward, -) - - -load_dotenv() - - -class TestGetRewardFuncs(unittest.TestCase): - def test_get_reward_funcs(self): - """Test get_reward_funcs with various reward functions.""" - reward_names = [ - "accuracy", - "format", - "reasoning_steps", - "cosine", - "repetition_penalty", - "length", - "tag_count", - "code", - "ioi_code", - "code_format", - "binary_code", - ] - reward_func_names = [ - "accuracy_reward", - "format_reward", - "reasoning_steps_reward", - "cosine_scaled_reward", - "repetition_penalty_reward", - "len_reward", - "tag_count_reward", - "code_reward", - "ioi_code_reward", - "code_format_reward", - "binary_code_reward", - ] - - args = GRPOScriptArguments( - dataset_name="dummy", - reward_funcs=reward_names, - ) - - reward_funcs = get_reward_funcs(args) - self.assertEqual(len(reward_funcs), 11) - for func_name, func in zip(reward_func_names, reward_funcs): - self.assertEqual(func_name, func.__name__) - - -class TestRewards(unittest.TestCase): - def test_accuracy_reward_correct_answer(self): - """Test accuracy_reward with a correct answer.""" - completion = [[{"content": r"\boxed{\frac{63}{400}}"}]] - solution = [r"\frac{63}{400}"] - rewards = accuracy_reward(completion, solution) - self.assertEqual(rewards[0], 1.0) - - def test_accuracy_reward_wrong_answer(self): - """Test accuracy_reward with an incorrect answer.""" - completion = [[{"content": r"\boxed{\frac{64}{400}}"}]] - solution = [r"\frac{63}{400}"] - rewards = accuracy_reward(completion, solution) - self.assertEqual(rewards[0], 0.0) - - def test_accuracy_reward_wrong_answer_no_latex(self): - """Test accuracy_reward with an incorrect 
answer and gold solution with no latex.""" - completion = [[{"content": r"\boxed{3}"}]] - solution = ["6"] - rewards = accuracy_reward(completion, solution) - self.assertEqual(rewards[0], 0.0) - - def test_format_reward_correct(self): - """Test format_reward with correct format.""" - completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]] - rewards = format_reward(completion) - self.assertEqual(rewards[0], 1.0) - - def test_format_reward_incorrect(self): - """Test format_reward with incorrect format.""" - incorrect_formats = [ - "<think>Only thinking</think>", - "<answer>Only answer</answer>", - "No tags at all", - "<think>Missing closing</think><answer>Missing closing", - "<answer>Wrong order</answer><think>Wrong order</think>", - ] - - for fmt in incorrect_formats: - completion = [[{"content": fmt}]] - rewards = format_reward(completion) - self.assertEqual(rewards[0], 0.0) - - def test_reasoning_steps_reward(self): - """Test reasoning_steps_reward with various formats.""" - test_cases = [ - # Full credit cases (3 or more steps) - ("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0), - ("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0), - # Partial credit cases (less than 3 steps) - ("Step 1: Only step", 1 / 3), - ("First, we do this.\nFinally, we conclude.", 2 / 3), - # No credit case - ("Just plain text without any clear steps", 0.0), - ] - - for content, expected_reward in test_cases: - completion = [[{"content": content}]] - rewards = reasoning_steps_reward(completion) - self.assertAlmostEqual(rewards[0], expected_reward) - - def test_multiple_completions(self): - """Test handling multiple completions at once.""" - completions = [ - [{"content": r"\boxed{\frac{63}{400}}"}], - [{"content": r"\boxed{\frac{64}{400}}"}], - ] - solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - - rewards = accuracy_reward(completions, solutions) - self.assertEqual(len(rewards), 2) - self.assertEqual(rewards[0], 1.0) - self.assertEqual(rewards[1], 0.0) - - def test_cosine_scaled_reward(self): - """Test cosine_scaled_reward with various cases.""" - # Test parameters - test_params = { - "min_value_wrong": -1.0, - "max_value_wrong": -0.5, - "min_value_correct": 0.5, - "max_value_correct": 1.0, - "max_len": 100, - } - - test_cases = [ - # Correct answers with different lengths - ( - r"\boxed{\frac{63}{400}}", - r"\frac{63}{400}", - 20, - 0.943, - ), # Short correct answer - ( - r"\boxed{\frac{63}{400}}", - r"\frac{63}{400}", - 80, - 0.547, - ), # Long correct answer - # Wrong answers with different lengths - ( - r"\boxed{\frac{64}{400}}", - r"\frac{63}{400}", - 20, - -0.942, - ), # Short wrong answer - ( - r"\boxed{\frac{64}{400}}", - r"\frac{63}{400}", - 80, - -0.547, - ), # Long wrong answer - ] - - for content, solution, content_len, expected_reward in test_cases: - # Pad content to desired length - padded_content = content + " " * (content_len - len(content)) - completion = [[{"content": padded_content}]] - - rewards = get_cosine_scaled_reward(**test_params)(completion, [solution]) - self.assertAlmostEqual(rewards[0], expected_reward, places=2) - - def test_format_reward_specific_multiline(self): - """Test format_reward with a specific multiline input.""" - inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9.
Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>\n9\n</answer>" - completion = [[{"content": inputs}]] - rewards = format_reward(completion) - self.assertEqual(rewards[0], 1.0) - - def test_same_length_responses(self): - """Test len_reward when all responses have the same length.""" - completions = [ - [{"content": r"\boxed{\frac{63}{400}}"}], - [{"content": r"\boxed{\frac{64}{400}}"}], - ] - solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - - rewards = len_reward(completions, solutions) - self.assertEqual(rewards, [0.0, 0.0]) - - def test_different_lengths_correct_answers(self): - """Test len_reward with different length correct answers.""" - completions = [ - [{"content": r"\boxed{\frac{63}{400}}"}], # shorter - [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # longer - ] - solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - - rewards = len_reward(completions, solutions) - self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward - self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward - - def test_different_lengths_incorrect_answers(self): - """Test len_reward with different length incorrect answers.""" - completions = [ - [{"content": r"\boxed{\frac{64}{400}}"}], # shorter - [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # longer - ] - solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - - rewards = len_reward(completions, solutions) - self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards - self.assertLessEqual(rewards[1], 0.0) - self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less - - def test_mixed_correctness(self): - """Test len_reward with mix of correct and incorrect answers of different lengths.""" - completions = [ - [{"content": r"\boxed{\frac{63}{400}}"}], # correct, shorter - [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # correct, longer - [{"content": r"\boxed{\frac{64}{400}}"}], # incorrect, shorter - [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # incorrect, longer - ] - solutions = [r"\frac{63}{400}"] * 4 - - rewards = len_reward(completions, solutions) - - # Shortest correct answer should get positive reward - self.assertGreater(rewards[0], 0.0) - - # Longer correct answer might get negative reward: - self.assertGreater(rewards[2], rewards[1]) - self.assertGreaterEqual(rewards[1], rewards[3]) - - # Incorrect answers should get non-positive rewards - self.assertLessEqual(rewards[2], 0.0) - self.assertLessEqual(rewards[3], 0.0) - - # Shorter answers should get better rewards within their correctness category - self.assertGreater(rewards[0], rewards[1]) # correct answers - self.assertGreater(rewards[2], rewards[3]) # incorrect answers - - def test_unparseable_solution(self): - """Test len_reward with unparseable solution.""" - completions = [ - [{"content": r"\boxed{answer}"}], - [{"content": r"\boxed{answer} " + "x" * 10}], - ] - solutions = ["unparseable_latex", "unparseable_latex"] - - rewards = len_reward(completions, solutions) - self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward - self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward - - -class TestRepetitionPenaltyReward(unittest.TestCase): - def test_positive_max_penalty_raises_value_error(self): - with self.assertRaises(ValueError): - get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0) - with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be
positive"): - get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5) - - def test_no_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) - completions = [[{"content": "this is a test sentence"}]] - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_full_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) - completions = [[{"content": "this this this this this"}]] - - rewards = reward_fn(completions) - # (1 - 1/4) * -1 = -0.75 - self.assertEqual(rewards, [-0.75]) - - def test_partial_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) - completions = [[{"content": "this is a this is a test"}]] - - rewards = reward_fn(completions) - # Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total - # (1 - 4/6) * -1 = -1/3 = -0.3333... - self.assertAlmostEqual(rewards[0], -1 / 3) - - def test_multiple_completions(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) - completions = [ - [{"content": "this is a test"}], - [{"content": "test test test test"}], - ] - - rewards = reward_fn(completions) - # Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0 - # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25 - self.assertAlmostEqual(rewards[0], 0.0) - self.assertAlmostEqual(rewards[1], -0.25) - - def test_empty_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) - completions = [[{"content": ""}]] - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_different_ngram_size(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0) - completions = [[{"content": "this is a this is a test"}]] - - rewards = reward_fn(completions) - self.assertAlmostEqual(rewards[0], -0.4) - - def test_mixed_case(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) - completions = [ - [{"content": "This is A Test"}], - [{"content": "this IS a test"}], - ] - - rewards = reward_fn(completions) - # both completions should produce the same reward, because the text gets lowercased - self.assertAlmostEqual(rewards[0], rewards[1]) - - def test_one_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "word"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_two_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "two words"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_three_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "three different words"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_three_word_repetition_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "word word word word"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.5]) - - def test_four_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "one two one two"}]] - - rewards = reward_fn(completions) - # ngrams are (one two one) (two 
one two). unique is 2 and count is 2, therefore (1 - 2/2) * -1 = 0. - self.assertEqual(rewards, [0.0]) - - def test_five_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) - completions = [[{"content": "A B C A B"}]] - - rewards = reward_fn(completions) - # (A B C) (B C A) (C A B). unique is 3, count is 3, so (1 - 3/3) * -0.5 = 0 - self.assertEqual(rewards, [0.0]) - - def test_six_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "A B C A B C"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.25]) - - def test_long_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "A B C A B C E F G A B C A B C"}]] - rewards = reward_fn(completions) - self.assertAlmostEqual(rewards[0], -0.3846, places=4) - - def test_long_completion_without_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) - completions = [[{"content": "A B C D E F G H I J K L"}]] - - rewards = reward_fn(completions) - self.assertEqual(rewards, [0.0]) - - def test_tag_count_rewards_all_correct(self): - """Test tag_count_reward with correct tags.""" - completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 1.0) - - def test_tag_count_rewards_missing_think_begin(self): - """Test tag_count_reward with missing <think> tag.""" - completion = [[{"content": "Some reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 0.75) - - def test_tag_count_rewards_missing_think_end(self): - """Test tag_count_reward with missing </think> tag.""" - completion = [[{"content": "<think>\nSome reasoning\n<answer>\nThe answer\n</answer>"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 0.75) - - def test_tag_count_rewards_missing_answer_begin(self): - """Test tag_count_reward with missing <answer> tag.""" - completion = [[{"content": "<think>\nSome reasoning\n</think>\nThe answer\n</answer>"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 0.75) - - def test_tag_count_rewards_missing_answer_end(self): - """Test tag_count_reward with missing </answer> tag.""" - completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 0.75) - - def test_tag_count_rewards_missing_all_tags(self): - """Test tag_count_reward with missing all tags.""" - completion = [[{"content": "Some reasoning\nThe answer"}]] - rewards = tag_count_reward(completion) - self.assertEqual(rewards[0], 0.0) - - def test_full_repetition_with_language(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="en") - completions = [[{"content": "that that that that that"}]] - rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.75]) - # begin test for zh language - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="zh") - completions = [[{"content": "这个这个这个这个这个"}]] - rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.75]) - - def test_soft_overlong_punishment_short_completion(self): - """Test soft overlong punishment reward function with a short completion.""" - # length 50, with max=100 and soft cache=20, reward should be 0.
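- # Between max_completion_len - soft_punish_cache (80) and max_completion_len (100), the penalty ramps linearly from 0 down to -1; see the intermediate and long tests below.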
- reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) - completion_ids = [[1] * 50] # 50 <= 80 - rewards = reward_fn(completion_ids=completion_ids) - self.assertEqual(rewards, [0]) - - def test_soft_overlong_punishment_long_completion(self): - """Test soft overlong punishment reward function with a longer than max completion.""" - # 110 > 100, reward should be -1. - reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) - completion_ids = [[1] * 110] - rewards = reward_fn(completion_ids) - self.assertEqual(rewards, [-1]) - - def test_soft_overlong_punishment_intermediate_completion(self): - """Test soft overlong punishment reward function for intermediate length completion.""" - reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) - completion_ids = [[1] * 90] # 90 is between 80 and 100 - rewards = reward_fn(completion_ids) - self.assertAlmostEqual(rewards[0], -0.5, places=4) - - -class TestCodeFormat(unittest.TestCase): - def test_correct_python_format(self): - """Test code format reward with correct Python format.""" - completion = [ - [ - { - "content": "<think>\nLet's solve this\nStep 1: First step\n</think>\n<answer>\n```python\ndef hello():\n print('world')\n```\n</answer>" - } - ] - ] - reward_fn = get_code_format_reward(language="python") - rewards = reward_fn(completion) - self.assertEqual(rewards[0], 1.0) - - def test_incorrect_formats(self): - """Test code format reward with various incorrect formats.""" - incorrect_formats = [ - # Missing think/answer tags - "```python\ndef hello():\n print('world')\n```", - # Missing code block - "<think>Some thinking</think><answer>Just plain text</answer>", - # Wrong language - "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>", - # Missing language identifier - "<think>Analysis</think><answer>```\ndef hello(): pass\n```</answer>", - # Wrong order of tags - "<answer>```python\ndef hello(): pass\n```</answer><think>Analysis</think>", - ] - - reward_fn = get_code_format_reward(language="python") - for fmt in incorrect_formats: - completion = [[{"content": fmt}]] - rewards = reward_fn(completion) - self.assertEqual(rewards[0], 0.0) - - def test_multiple_code_blocks(self): - """Test format reward with multiple code blocks in think and answer sections.""" - completion = [ - [ - { - "content": "<think>\nHere's an example:\n```python\nx = 1\n```\nNow the solution:\n</think>\n<answer>\n```python\ndef solution():\n return 42\n```\n</answer>" - } - ] - ] - reward_fn = get_code_format_reward(language="python") - rewards = reward_fn(completion) - self.assertEqual(rewards[0], 1.0) - - def test_different_languages(self): - """Test code format reward with different programming languages.""" - completion = [ - [ - { - "content": "<think>\nAnalysis\n</think>\n<answer>\n```javascript\nconsole.log('hello');\n```\n</answer>" - } - ] - ] - - # Test with JavaScript - js_reward_fn = get_code_format_reward(language="javascript") - rewards = js_reward_fn(completion) - self.assertEqual(rewards[0], 1.0) - - # Same completion should fail for Python - py_reward_fn = get_code_format_reward(language="python") - rewards = py_reward_fn(completion) - self.assertEqual(rewards[0], 0.0) - - def test_multiline_code(self): - """Test format reward with complex multiline code blocks.""" - completion = [ - [ - { - "content": "<think>\nHere's the analysis\n</think>\n<answer>\n```python\nclass Solution:\n def __init__(self):\n self.value = 42\n \n def get_value(self):\n return self.value\n```\n</answer>" - } - ] - ] - reward_fn = get_code_format_reward(language="python") - rewards = reward_fn(completion) - self.assertEqual(rewards[0], 1.0) - - -if __name__ == "__main__": - unittest.main() diff --git
a/tests/utils/test_data.py b/tests/utils/test_data.py deleted file mode 100644 index 669057e78..000000000 --- a/tests/utils/test_data.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest -from dataclasses import asdict - -from datasets import DatasetDict, load_dataset - -from open_r1.configs import DatasetConfig, DatasetMixtureConfig, ScriptArguments -from open_r1.utils.data import get_dataset - - -class TestGetDataset(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dataset_name = "trl-internal-testing/zen" - cls.dataset_config = "conversational_preference" - cls.ref_dataset = load_dataset(cls.dataset_name, cls.dataset_config) - - def test_dataset_and_config_name(self): - args = ScriptArguments(dataset_name=self.dataset_name, dataset_config=self.dataset_config) - dataset = get_dataset(args) - self.assertIsInstance(dataset, DatasetDict) - self.assertIn("train", dataset) - self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"])) - - def test_unweighted_mixture(self): - """Mix train and test splits of the same dataset.""" - dataset_configs = [ - DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=None), - DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=None), - ] - dataset_mixture = DatasetMixtureConfig( - datasets=dataset_configs, - ) - args = ScriptArguments(dataset_mixture=asdict(dataset_mixture)) - dataset = get_dataset(args) - self.assertIsInstance(dataset, DatasetDict) - self.assertIn("train", dataset) - self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]) + len(self.ref_dataset["test"])) - - def test_weighted_mixture(self): - """Test loading a dataset mixture with weights.""" - dataset_configs = [ - DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=0.25), - DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=0.5), - ] - dataset_mixture = DatasetMixtureConfig( - datasets=dataset_configs, - ) - args = ScriptArguments(dataset_mixture=asdict(dataset_mixture)) - dataset = get_dataset(args) - self.assertIsInstance(dataset, DatasetDict) - self.assertIn("train", dataset) - self.assertEqual( - len(dataset["train"]), len(self.ref_dataset["train"]) // 4 + len(self.ref_dataset["test"]) // 2 - ) - - def test_mixture_and_test_split(self): - """Test loading a dataset mixture with test split.""" - dataset_configs = [ - DatasetConfig( - id=self.dataset_name, config=self.dataset_config, split="train[:10]", columns=None, weight=None - ), - ] - dataset_mixture = DatasetMixtureConfig(datasets=dataset_configs, test_split_size=0.2) - args = ScriptArguments(dataset_name=None, dataset_mixture=asdict(dataset_mixture)) - dataset = get_dataset(args) - self.assertIsInstance(dataset, DatasetDict) - self.assertIn("train", dataset) - self.assertIn("test", dataset) - 
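- # train[:10] with test_split_size=0.2 should leave 8 training examples and 2 test examples.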
self.assertEqual(len(dataset["train"]), 8) - self.assertEqual(len(dataset["test"]), 2) - - def test_mixture_column_selection(self): - """Test loading a dataset mixture with column selection.""" - dataset_configs = [ - DatasetConfig( - id=self.dataset_name, - config=self.dataset_config, - split="train", - columns=["prompt", "chosen"], - weight=None, - ), - ] - dataset_mixture = DatasetMixtureConfig( - datasets=dataset_configs, - ) - args = ScriptArguments(dataset_mixture=asdict(dataset_mixture)) - dataset = get_dataset(args) - self.assertIsInstance(dataset, DatasetDict) - self.assertIn("train", dataset) - self.assertIn("prompt", dataset["train"].column_names) - self.assertIn("chosen", dataset["train"].column_names) - - def test_mixture_with_mismatched_columns(self): - dataset_configs = [ - DatasetConfig( - id=self.dataset_name, config=self.dataset_config, split="train", columns=["prompt"], weight=None - ), - DatasetConfig( - id=self.dataset_name, config=self.dataset_config, split="train", columns=["chosen"], weight=None - ), - ] - dataset_mixture = DatasetMixtureConfig( - datasets=dataset_configs, - ) - with self.assertRaises(ValueError) as context: - _ = ScriptArguments(dataset_mixture=asdict(dataset_mixture)) - self.assertIn("Column names must be consistent", str(context.exception)) - - def test_no_dataset_name_or_mixture(self): - with self.assertRaises(ValueError) as context: - _ = ScriptArguments(dataset_name=None, dataset_mixture=None) - self.assertIn("Either `dataset_name` or `dataset_mixture` must be provided", str(context.exception)) - - -if __name__ == "__main__": - unittest.main() diff --git a/token_stat b/token_stat new file mode 100644 index 000000000..8ed5e8249 --- /dev/null +++ b/token_stat @@ -0,0 +1,5 @@ +max length: 3678 +total user token: 52272359 +total system token: 326634875 +total answer token: 32224010 +total token: 843473151
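The `token_stat` file records aggregate token counts but not how they were produced. Below is a minimal sketch of how per-column totals like these could be recomputed; it assumes (this is not stated in the diff) a chat-style dataset with `system`, `user` and `assistant` text columns, a Hugging Face tokenizer, and that `max length` is measured over the concatenated text columns of a single example. The dataset and tokenizer IDs are placeholders. Note that the reported `total token` figure is larger than the sum of the three text columns, so the original count presumably also included other inputs (for example image tokens), which this sketch does not cover.

```python
# Hypothetical sketch: recompute token statistics similar to `token_stat`.
# Assumptions: the dataset exposes "system", "user" and "assistant" string
# columns, and the chosen tokenizer approximates the reported counts.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("org/tokenizer-id")      # placeholder tokenizer ID
dataset = load_dataset("org/dataset-id", split="train")            # placeholder dataset ID

totals = {"system": 0, "user": 0, "assistant": 0}
max_length = 0

for example in dataset:
    example_length = 0
    for column in totals:
        # Count tokens for this column and accumulate the running totals.
        n_tokens = len(tokenizer(example[column])["input_ids"])
        totals[column] += n_tokens
        example_length += n_tokens
    # Track the longest example, measured over the three text columns only.
    max_length = max(max_length, example_length)

print(f"max length: {max_length}")
print(f"total user token: {totals['user']}")
print(f"total system token: {totals['system']}")
print(f"total answer token: {totals['assistant']}")
```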