From 20b13d8413a584e312c70dbc7de078bf4c91b116 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 28 Jan 2025 13:10:23 +0100 Subject: [PATCH 1/4] update --- .../models/aria/image_processing_aria.py | 2 + .../models/aria/processing_aria.py | 4 + .../models/emu3/image_processing_emu3.py | 2 +- .../models/emu3/processing_emu3.py | 3 +- .../idefics2/image_processing_idefics2.py | 2 +- .../idefics3/image_processing_idefics3.py | 2 +- .../models/idefics3/processing_idefics3.py | 3 +- .../llava_next/image_processing_llava_next.py | 2 +- .../models/mllama/processing_mllama.py | 10 +- src/transformers/processing_utils.py | 75 +++--- tests/models/aria/test_processor_aria.py | 49 ++++ tests/models/emu3/test_processor_emu3.py | 5 + tests/models/llava/test_processor_llava.py | 33 +-- .../llava_next/test_processor_llava_next.py | 6 +- .../test_processor_llava_onevision.py | 6 +- tests/models/mllama/test_processor_mllama.py | 3 + .../qwen2_5_vl/test_processor_qwen2_5_vl.py | 145 ++++++++++- .../qwen2_vl/test_processor_qwen2_vl.py | 145 ++++++++++- tests/test_modeling_common.py | 6 + tests/test_processing_common.py | 225 +++++++++++++++++- 20 files changed, 660 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 7b00665aa285..10ffde2f9952 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -133,6 +133,8 @@ class AriaImageProcessor(BaseImageProcessor): The resampling filter to use if resizing the image. """ + model_input_names = ["pixel_values", "pixel_mask", "num_crops"] + def __init__( self, image_mean: List[float] = None, diff --git a/src/transformers/models/aria/processing_aria.py b/src/transformers/models/aria/processing_aria.py index 2cfbd72a0020..4b7163db8fd3 100644 --- a/src/transformers/models/aria/processing_aria.py +++ b/src/transformers/models/aria/processing_aria.py @@ -158,6 +158,10 @@ def decode(self, *args, **kwargs): def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names + + # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when remocing + # otherwise `self.image_processor.model_input_names` is also modified + image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"] return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/models/emu3/image_processing_emu3.py b/src/transformers/models/emu3/image_processing_emu3.py index f28bc501ba16..1cc02f58ddce 100644 --- a/src/transformers/models/emu3/image_processing_emu3.py +++ b/src/transformers/models/emu3/image_processing_emu3.py @@ -132,7 +132,7 @@ class Emu3ImageProcessor(BaseImageProcessor): The spatial downsample factor the image will be downsampled in feature extracting phase """ - model_input_names = ["pixel_values"] + model_input_names = ["pixel_values", "image_sizes"] def __init__( self, diff --git a/src/transformers/models/emu3/processing_emu3.py b/src/transformers/models/emu3/processing_emu3.py index 2c536f5f2463..01966e470bdf 100644 --- a/src/transformers/models/emu3/processing_emu3.py +++ b/src/transformers/models/emu3/processing_emu3.py @@ -63,6 +63,7 @@ class Emu3Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") image_processor_class = "Emu3ImageProcessor" @@ -179,7 +180,7 @@ def __call__( data = self.tokenizer(text, **output_kwargs["text_kwargs"]) data.update(**image_features) - return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].pop("return_tensors", None)) def calculate_generate_size(self, ratio, image_area, spatial_factor): width, height = map(int, ratio.split(":")) diff --git a/src/transformers/models/idefics2/image_processing_idefics2.py b/src/transformers/models/idefics2/image_processing_idefics2.py index 65d5a8285416..daf31174366a 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2.py +++ b/src/transformers/models/idefics2/image_processing_idefics2.py @@ -217,7 +217,7 @@ class Idefics2ImageProcessor(BaseImageProcessor): strategy was first introduced in https://arxiv.org/abs/2311.06607. """ - model_input_names = ["pixel_values"] + model_input_names = ["pixel_values", "pixel_attention_mask"] def __init__( self, diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index df71a8bf0e85..452b5ebfee2a 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -323,7 +323,7 @@ class Idefics3ImageProcessor(BaseImageProcessor): sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width). """ - model_input_names = ["pixel_values"] + model_input_names = ["pixel_values", "pixel_attention_mask"] def __init__( self, diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index 40c8829fe76e..2cf605cf65fc 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -129,6 +129,7 @@ class Idefics3Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["image_seq_len", "chat_template"] image_processor_class = "Idefics3ImageProcessor" tokenizer_class = "AutoTokenizer" @@ -354,7 +355,7 @@ def decode(self, *args, **kwargs): def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names - return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names)) __all__ = ["Idefics3Processor"] diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 8e2a4f4644fc..b4ee3bf6dc3f 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -186,7 +186,7 @@ class LlavaNextImageProcessor(BaseImageProcessor): Whether to convert the image to RGB. """ - model_input_names = ["pixel_values"] + model_input_names = ["pixel_values", "image_sizes"] def __init__( self, diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py index 5905f3313f78..03c95a085156 100644 --- a/src/transformers/models/mllama/processing_mllama.py +++ b/src/transformers/models/mllama/processing_mllama.py @@ -209,10 +209,11 @@ class MllamaProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] image_processor_class = "MllamaImageProcessor" tokenizer_class = "PreTrainedTokenizerFast" - def __init__(self, image_processor, tokenizer): + def __init__(self, image_processor, tokenizer, chat_template=None): if not hasattr(tokenizer, "image_token"): self.image_token = "<|image|>" self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) @@ -223,8 +224,7 @@ def __init__(self, image_processor, tokenizer): self.python_token = "<|python_tag|>" self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token) self.bos_token = tokenizer.bos_token - self.chat_template = tokenizer.chat_template - super().__init__(image_processor, tokenizer) + super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( self, @@ -367,6 +367,10 @@ def post_process_image_text_to_text(self, generated_outputs): def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names + + # Remove `num_tiles`, it is popped and used only when processing. Make a copy of list when remocing + # otherwise `self.image_processor.model_input_names` is also modified + image_processor_input_names = [name for name in image_processor_input_names if name != "num_tiles"] return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"]) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b94230c7d4a1..ca55e9d63cdb 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -762,7 +762,11 @@ def get_processor_dict( # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception) # However, for models added in the future, we won't get the expected error if this file is missing. if resolved_processor_file is None: - return {}, kwargs + # In any case we need to pass `chat_template` if it is available + processor_dict = {} + if "chat_template" in kwargs: + processor_dict = {"chat_template": kwargs.pop("chat_template")} + return processor_dict, kwargs try: # Load processor dict @@ -786,6 +790,9 @@ def get_processor_dict( "in the processor's config. Make sure to move your template to its own file." ) + if "chat_template" in kwargs: + processor_dict["chat_template"] = kwargs.pop("chat_template") + if not is_local: if "auto_map" in processor_dict: processor_dict["auto_map"] = add_model_info_to_auto_map( @@ -817,7 +824,6 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): """ processor_dict = processor_dict.copy() return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - chat_template = kwargs.pop("chat_template", None) # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs # If we don't pop, some specific kwargs will raise a warning @@ -829,8 +835,6 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs) processor = cls(*args, **processor_dict) - if chat_template is not None: - setattr(processor, "chat_template", chat_template) # Update processor with kwargs if needed for key in set(kwargs.keys()): @@ -1199,12 +1203,6 @@ def apply_chat_template( "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." ) - text_kwargs = {} - for key in TextKwargs.__annotations__.keys(): - value = kwargs.pop(key, None) - if value is not None: - text_kwargs[key] = value - chat_template_kwargs = {} for key in ChatTemplateKwargs.__annotations__.keys(): value = kwargs.pop(key, getattr(ChatTemplateKwargs, key)) @@ -1221,31 +1219,52 @@ def apply_chat_template( chat_template=chat_template, tokenize=False, return_dict=False, - **text_kwargs, **chat_template_kwargs, ) - # we will have to return all processed inputs in a dict + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") + ): + conversations = conversation + is_batched = True + else: + conversations = [conversation] + is_batched = False + + # We will have to return all processed inputs in a dict + # Currently all processors can accept flat list of visuals, but not all can accept nested list of batches + # So we'll make a simple list of images in the order they appear if tokenize: - images, videos = [], [] - for message in conversation: - visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] - for vision_info in visuals: - if vision_info["type"] == "image": - for key in ["image", "url", "path", "base64"]: - if key in vision_info: - images.append(load_image(vision_info[key])) - elif vision_info["type"] == "video": - for key in ["video", "url", "path"]: - if key in vision_info: - videos.append( - load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend) - ) + batch_images, batch_videos = [], [] + for conversation in conversations: + for message in conversation: + visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] + for vision_info in visuals: + if vision_info["type"] == "image": + for key in ["image", "url", "path", "base64"]: + if key in vision_info: + batch_images.append(load_image(vision_info[key])) + elif vision_info["type"] == "video": + for key in ["video", "url", "path"]: + if key in vision_info: + batch_videos.append( + load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend) + ) + + # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing + # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt + # and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling + # everything internally. The below line is to keep BC for that and be able to work with model that have + # special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line + # without actionable solution for users + single_prompt = prompt[0] if is_batched else prompt + if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token): + kwargs["add_special_tokens"] = False out = self( text=prompt, - images=images if images else None, - videos=videos if videos else None, + images=batch_images if batch_images else None, + videos=batch_videos if batch_videos else None, **kwargs, ) if return_dict: diff --git a/tests/models/aria/test_processor_aria.py b/tests/models/aria/test_processor_aria.py index 7e23d861c775..623153b6798a 100644 --- a/tests/models/aria/test_processor_aria.py +++ b/tests/models/aria/test_processor_aria.py @@ -237,6 +237,55 @@ def test_apply_chat_template(self): """ self.assertEqual(rendered, expected_rendered) + # Override as AriaImageProcessor doesn't accept `do_rescale` + def test_chat_template_accepts_processing_kwargs(self): + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + ] + + formatted_prompt_tokenized = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + padding="max_length", + max_length=50, + ) + self.assertEqual(len(formatted_prompt_tokenized[0]), 50) + + formatted_prompt_tokenized = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + truncation=True, + max_length=5, + ) + self.assertEqual(len(formatted_prompt_tokenized[0]), 5) + + # Now test the ability to return dict + messages[0][0]["content"].append( + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"} + ) + out_dict = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + max_image_size=980, + return_tensors="np", + ) + self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980]) + # Override as AriaProcessor needs image tokens in prompts def prepare_text_inputs(self, batch_size: Optional[int] = None): if batch_size is None: diff --git a/tests/models/emu3/test_processor_emu3.py b/tests/models/emu3/test_processor_emu3.py index 7bc77075b1a6..c7b792c2a391 100644 --- a/tests/models/emu3/test_processor_emu3.py +++ b/tests/models/emu3/test_processor_emu3.py @@ -52,6 +52,11 @@ def setUp(self): ) processor.save_pretrained(self.tmpdirname) + def prepare_processor_dict(self): + return { + "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", + } # fmt: skip + def test_processor_for_generation(self): processor_components = self.prepare_components() processor = self.processor_class(**processor_components) diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index fd0ba8cacc18..a411430a1e8a 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -17,7 +17,7 @@ import unittest from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor -from transformers.testing_utils import require_torch, require_vision +from transformers.testing_utils import require_vision from transformers.utils import is_torch_available, is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -27,7 +27,7 @@ from transformers import CLIPImageProcessor if is_torch_available: - import torch + pass @require_vision @@ -53,7 +53,11 @@ def tearDown(self): shutil.rmtree(self.tmpdirname) def prepare_processor_dict(self): - return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} + return { + "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", + "patch_size": 3, + "vision_feature_select_strategy": "default" + } # fmt: skip @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" @@ -123,29 +127,6 @@ def test_chat_template_dict(self): ) self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"]) - @require_torch - def test_chat_template_dict_torch(self): - processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - - out_dict_tensors = processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ) - self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"]) - self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor)) - def test_chat_template_with_continue_final_message(self): processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") expected_prompt = "USER: \nDescribe this image. ASSISTANT: There is a dog and" diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 234e47911000..af1457841163 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -50,7 +50,11 @@ def get_image_processor(self, **kwargs): return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor def prepare_processor_dict(self): - return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} + return { + "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", + "patch_size": 3, + "vision_feature_select_strategy": "default" + } # fmt: skip @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" diff --git a/tests/models/llava_onevision/test_processor_llava_onevision.py b/tests/models/llava_onevision/test_processor_llava_onevision.py index 04aafa11a8b0..01e5e1d8384d 100644 --- a/tests/models/llava_onevision/test_processor_llava_onevision.py +++ b/tests/models/llava_onevision/test_processor_llava_onevision.py @@ -61,7 +61,11 @@ def get_video_processor(self, **kwargs): return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor def prepare_processor_dict(self): - return {"chat_template": "dummy_template", "num_image_tokens": 6, "vision_feature_select_strategy": "default"} + return { + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '