diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2a2cf4512af4..7557f295d968 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -643,6 +643,8 @@ title: ConvNeXTV2 - local: model_doc/cvt title: CvT + - local: model_doc/dab-detr + title: DAB-DETR - local: model_doc/deformable_detr title: Deformable DETR - local: model_doc/deit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 92a082f0a911..9c3c5c76954d 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -110,6 +110,7 @@ Flax), PyTorch, and/or TensorFlow. | [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ | | [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ | | [CvT](model_doc/cvt) | ✅ | ✅ | ❌ | +| [DAB-DETR](model_doc/dab-detr) | ✅ | ❌ | ❌ | | [DAC](model_doc/dac) | ✅ | ❌ | ❌ | | [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ | | [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/dab-detr.md b/docs/source/en/model_doc/dab-detr.md new file mode 100644 index 000000000000..6071ee6ca460 --- /dev/null +++ b/docs/source/en/model_doc/dab-detr.md @@ -0,0 +1,119 @@ + + +# DAB-DETR + +## Overview + +The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, Lei Zhang. +DAB-DETR is an enhanced variant of Conditional DETR. It utilizes dynamically updated anchor boxes to provide both a reference query point (x, y) and a reference anchor size (w, h), improving cross-attention computation. This new approach achieves 45.7% AP when trained for 50 epochs with a single ResNet-50 model as the backbone. + + + +The abstract from the paper is the following: + +*We present in this paper a novel query formulation using dynamic anchor boxes +for DETR (DEtection TRansformer) and offer a deeper understanding of the role +of queries in DETR. This new formulation directly uses box coordinates as queries +in Transformer decoders and dynamically updates them layer-by-layer. Using box +coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR, +but also allows us to modulate the positional attention map using the box width +and height information. Such a design makes it clear that queries in DETR can be +implemented as performing soft ROI pooling layer-by-layer in a cascade manner. +As a result, it leads to the best performance on MS-COCO benchmark among +the DETR-like detection models under the same setting, e.g., AP 45.7% using +ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive +experiments to confirm our analysis and verify the effectiveness of our methods.* + +This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu). +The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR). + +## How to Get Started with the Model + +Use the code below to get started with the model. 
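+If you only need predictions, the high-level [`pipeline`] API wraps the same pre- and post-processing steps. A minimal sketch (it assumes the `IDEA-Research/dab-detr-resnet-50` checkpoint used below ships a compatible image processor config):
+
+```python
+from transformers import pipeline
+
+# object-detection pipeline; the 0.3 threshold mirrors the detailed example that follows
+detector = pipeline("object-detection", model="IDEA-Research/dab-detr-resnet-50")
+results = detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.3)
+print(results)
+```
+
+For full control over pre- and post-processing, use the image processor and model directly: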
+ +```python +import torch +import requests + +from PIL import Image +from transformers import AutoModelForObjectDetection, AutoImageProcessor + +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) + +image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50") +model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50") + +inputs = image_processor(images=image, return_tensors="pt") + +with torch.no_grad(): + outputs = model(**inputs) + +results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3) + +for result in results: + for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]): + score, label = score.item(), label_id.item() + box = [round(i, 2) for i in box.tolist()] + print(f"{model.config.id2label[label]}: {score:.2f} {box}") +``` +This should output +``` +cat: 0.87 [14.7, 49.39, 320.52, 469.28] +remote: 0.86 [41.08, 72.37, 173.39, 117.2] +cat: 0.86 [344.45, 19.43, 639.85, 367.86] +remote: 0.61 [334.27, 75.93, 367.92, 188.81] +couch: 0.59 [-0.04, 1.34, 639.9, 477.09] +``` + +There are three other ways to instantiate a DAB-DETR model (depending on what you prefer): + +Option 1: Instantiate DAB-DETR with pre-trained weights for entire model +```py +>>> from transformers import DabDetrForObjectDetection + +>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50") +``` + +Option 2: Instantiate DAB-DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone +```py +>>> from transformers import DabDetrConfig, DabDetrForObjectDetection + +>>> config = DabDetrConfig() +>>> model = DabDetrForObjectDetection(config) +``` +Option 3: Instantiate DAB-DETR with randomly initialized weights for backbone + Transformer +```py +>>> config = DabDetrConfig(use_pretrained_backbone=False) +>>> model = DabDetrForObjectDetection(config) +``` + + +## DabDetrConfig + +[[autodoc]] DabDetrConfig + +## DabDetrModel + +[[autodoc]] DabDetrModel + - forward + +## DabDetrForObjectDetection + +[[autodoc]] DabDetrForObjectDetection + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ae92f21dcc2a..ea832eded471 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -328,6 +328,7 @@ "CTRLTokenizer", ], "models.cvt": ["CvtConfig"], + "models.dab_detr": ["DabDetrConfig"], "models.dac": ["DacConfig", "DacFeatureExtractor"], "models.data2vec": [ "Data2VecAudioConfig", @@ -1898,6 +1899,13 @@ "CvtPreTrainedModel", ] ) + _import_structure["models.dab_detr"].extend( + [ + "DabDetrForObjectDetection", + "DabDetrModel", + "DabDetrPreTrainedModel", + ] + ) _import_structure["models.dac"].extend( [ "DacModel", @@ -5387,6 +5395,9 @@ CTRLTokenizer, ) from .models.cvt import CvtConfig + from .models.dab_detr import ( + DabDetrConfig, + ) from .models.dac import ( DacConfig, DacFeatureExtractor, @@ -6926,6 +6937,11 @@ CvtModel, CvtPreTrainedModel, ) + from .models.dab_detr import ( + DabDetrForObjectDetection, + DabDetrModel, + DabDetrPreTrainedModel, + ) from .models.dac import ( DacModel, DacPreTrainedModel, diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 2355fb5fed67..15f0397535e8 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -217,6 +217,7 @@ def __getitem__(self, key): "silu": nn.SiLU, "swish": nn.SiLU, "tanh": 
nn.Tanh, + "prelu": nn.PReLU, } ACT2FN = ClassInstantier(ACT2CLS) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 0f39fde40a7c..86f8634a45e0 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -128,6 +128,7 @@ def ForTokenClassification(logits, labels, config, **kwargs): "ForObjectDetection": ForObjectDetectionLoss, "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss, "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss, + "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss, "GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss, "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss, "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f62d5d71672b..1667edbf37ab 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -63,6 +63,7 @@ cpmant, ctrl, cvt, + dab_detr, dac, data2vec, dbrx, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 699e307ac1b6..bd6fcb4a9d82 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -79,6 +79,7 @@ ("cpmant", "CpmAntConfig"), ("ctrl", "CTRLConfig"), ("cvt", "CvtConfig"), + ("dab-detr", "DabDetrConfig"), ("dac", "DacConfig"), ("data2vec-audio", "Data2VecAudioConfig"), ("data2vec-text", "Data2VecTextConfig"), @@ -399,6 +400,7 @@ ("cpmant", "CPM-Ant"), ("ctrl", "CTRL"), ("cvt", "CvT"), + ("dab-detr", "DAB-DETR"), ("dac", "DAC"), ("data2vec-audio", "Data2VecAudio"), ("data2vec-text", "Data2VecText"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3b023251e1d9..87e2dab68708 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -78,6 +78,7 @@ ("cpmant", "CpmAntModel"), ("ctrl", "CTRLModel"), ("cvt", "CvtModel"), + ("dab-detr", "DabDetrModel"), ("dac", "DacModel"), ("data2vec-audio", "Data2VecAudioModel"), ("data2vec-text", "Data2VecTextModel"), @@ -592,6 +593,7 @@ ("conditional_detr", "ConditionalDetrModel"), ("convnext", "ConvNextModel"), ("convnextv2", "ConvNextV2Model"), + ("dab-detr", "DabDetrModel"), ("data2vec-vision", "Data2VecVisionModel"), ("deformable_detr", "DeformableDetrModel"), ("deit", "DeiTModel"), @@ -890,6 +892,7 @@ [ # Model for Object Detection mapping ("conditional_detr", "ConditionalDetrForObjectDetection"), + ("dab-detr", "DabDetrForObjectDetection"), ("deformable_detr", "DeformableDetrForObjectDetection"), ("deta", "DetaForObjectDetection"), ("detr", "DetrForObjectDetection"), diff --git a/src/transformers/models/conditional_detr/configuration_conditional_detr.py b/src/transformers/models/conditional_detr/configuration_conditional_detr.py index 8dae72edff08..7eecc6eda090 100644 --- a/src/transformers/models/conditional_detr/configuration_conditional_detr.py +++ b/src/transformers/models/conditional_detr/configuration_conditional_detr.py @@ -52,7 +52,7 @@ class ConditionalDetrConfig(PretrainedConfig): Number of object queries, i.e. detection slots. This is the maximal number of objects [`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries. d_model (`int`, *optional*, defaults to 256): - Dimension of the layers. 
+ This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. encoder_layers (`int`, *optional*, defaults to 6): Number of encoder layers. decoder_layers (`int`, *optional*, defaults to 6): diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 0aa4b2afa6bf..d020b94cffde 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -74,6 +74,8 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions): intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). """ intermediate_hidden_states: Optional[torch.FloatTensor] = None @@ -116,6 +118,8 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput): intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). """ intermediate_hidden_states: Optional[torch.FloatTensor] = None diff --git a/src/transformers/models/dab_detr/__init__.py b/src/transformers/models/dab_detr/__init__.py new file mode 100644 index 000000000000..bfa364bd2152 --- /dev/null +++ b/src/transformers/models/dab_detr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dab_detr import * + from .modeling_dab_detr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/dab_detr/configuration_dab_detr.py b/src/transformers/models/dab_detr/configuration_dab_detr.py new file mode 100644 index 000000000000..398e6f26591f --- /dev/null +++ b/src/transformers/models/dab_detr/configuration_dab_detr.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DAB-DETR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class DabDetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate + a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DAB-DETR + [IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which + case it will default to `ResNetConfig()`. + backbone (`str`, *optional*, defaults to `"resnet50"`): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `True`): + Whether to use pretrained weights for the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + encoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. 
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicates whether the transformer model architecture is an encoder-decoder or not. + activation_function (`str` or `function`, *optional*, defaults to `"prelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_size (`int`, *optional*, defaults to 256): + This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1.0): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + auxiliary_loss (`bool`, *optional*, defaults to `False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`. + class_cost (`float`, *optional*, defaults to 2): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + cls_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the classification loss in the object detection loss function. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + temperature_height (`int`, *optional*, defaults to 20): + Temperature parameter to tune the flatness of positional attention (HEIGHT) + temperature_width (`int`, *optional*, defaults to 20): + Temperature parameter to tune the flatness of positional attention (WIDTH) + query_dim (`int`, *optional*, defaults to 4): + Query dimension parameter represents the size of the output vector. + random_refpoints_xy (`bool`, *optional*, defaults to `False`): + Whether to fix the x and y coordinates of the anchor boxes with random initialization. + keep_query_pos (`bool`, *optional*, defaults to `False`): + Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer. + num_patterns (`int`, *optional*, defaults to 0): + Number of pattern embeddings. + normalize_before (`bool`, *optional*, defaults to `False`): + Whether we use a normalization layer in the Encoder or not. 
+ sine_position_embedding_scale (`float`, *optional*, defaults to 'None'): + Scaling factor applied to the normalized positional encodings. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + + + Examples: + + ```python + >>> from transformers import DabDetrConfig, DabDetrModel + + >>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration + >>> configuration = DabDetrConfig() + + >>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration + >>> model = DabDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dab-detr" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + use_timm_backbone=True, + backbone_config=None, + backbone="resnet50", + use_pretrained_backbone=True, + backbone_kwargs=None, + num_queries=300, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + is_encoder_decoder=True, + activation_function="prelu", + hidden_size=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + auxiliary_loss=False, + dilation=False, + class_cost=2, + bbox_cost=5, + giou_cost=2, + cls_loss_coefficient=2, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + focal_alpha=0.25, + temperature_height=20, + temperature_width=20, + query_dim=4, + random_refpoints_xy=False, + keep_query_pos=False, + num_patterns=0, + normalize_before=False, + sine_position_embedding_scale=None, + initializer_bias_prior_prob=None, + **kwargs, + ): + if query_dim != 4: + raise ValueError("The query dimensions has to be 4.") + + # We default to values which were previously hard-coded in the model. This enables configurability of the config + # while keeping the default behavior the same. + if use_timm_backbone and backbone_kwargs is None: + backbone_kwargs = {} + if dilation: + backbone_kwargs["output_stride"] = 16 + backbone_kwargs["out_indices"] = [1, 2, 3, 4] + backbone_kwargs["in_chans"] = 3 # num_channels + # Backwards compatibility + elif not use_timm_backbone and backbone in (None, "resnet50"): + if backbone_config is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + backbone = None + # set timm attributes to None + dilation = None + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.use_timm_backbone = use_timm_backbone + self.backbone_config = backbone_config + self.num_queries = num_queries + self.hidden_size = hidden_size + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.num_hidden_layers = encoder_layers + self.auxiliary_loss = auxiliary_loss + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.backbone_kwargs = backbone_kwargs + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.cls_loss_coefficient = cls_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.focal_alpha = focal_alpha + self.query_dim = query_dim + self.random_refpoints_xy = random_refpoints_xy + self.keep_query_pos = keep_query_pos + self.num_patterns = num_patterns + self.normalize_before = normalize_before + self.temperature_width = temperature_width + self.temperature_height = temperature_height + self.sine_position_embedding_scale = sine_position_embedding_scale + self.initializer_bias_prior_prob = initializer_bias_prior_prob + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +__all__ = ["DabDetrConfig"] diff --git a/src/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..a6e5081b484c --- /dev/null +++ b/src/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,233 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert DAB-DETR checkpoints.""" + +import argparse +import gc +import json +import re +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download + +from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads + # for dab-DETR, also convert reference point head and query scale MLP + r"input_proj\.(bias|weight)": r"input_projection.\1", + r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight", + r"class_embed\.(bias|weight)": r"class_embed.\1", + # negative lookbehind because of the overlap + r"(?DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points) +class DabDetrDecoderOutput(BaseModelOutputWithCrossAttentions): + """ + Base class for outputs of the Conditional DETR decoder. This class adds one attribute to + BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output + of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary + decoding losses. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). 
+ """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrModelOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points) +class DabDetrModelOutput(Seq2SeqModelOutput): + """ + Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to + Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder + layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding + losses. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each + layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. 
+ reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`): + Reference points (reference points of each layer of the decoder). + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + reference_points: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->DabDetr +class DabDetrObjectDetectionOutput(ModelOutput): + """ + Output type of [`DabDetrForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DabDetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each + layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DabDetr +class DabDetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DabDetr +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `DabDetrFrozenBatchNorm2d`. 
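+    Weights, biases and running statistics are copied over from the replaced modules (unless they live on the
+    `meta` device), so the backbone normalization statistics stay frozen during fine-tuning.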
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = DabDetrFrozenBatchNorm2d(module.num_features) + + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +# Modified from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->DabDetr +class DabDetrConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by DabDetrFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config: DabDetrConfig): + super().__init__() + + self.config = config + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DabDetr +class DabDetrConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. + """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrSinePositionEmbedding with ConditionalDetr->DabDetr +class DabDetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. 
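+
+    Given `pixel_values` of shape `(batch_size, num_channels, height, width)` and a `pixel_mask` of shape
+    `(batch_size, height, width)`, the returned tensor has shape `(batch_size, hidden_size, height, width)`,
+    with `hidden_size // 2` channels encoding the y coordinate and `hidden_size // 2` channels the x coordinate.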
+ """ + + def __init__(self, config: DabDetrConfig): + super().__init__() + self.config = config + self.embedding_dim = config.hidden_size / 2 + self.temperature_height = config.temperature_height + self.temperature_width = config.temperature_width + scale = config.sine_position_embedding_scale + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + # We use float32 to ensure reproducibility of the original implementation + dim_tx = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + # Modifying dim_tx in place to avoid extra memory allocation -> dim_tx = self.temperature_width ** (2 * (dim_tx // 2) / self.embedding_dim) + dim_tx //= 2 + dim_tx.mul_(2 / self.embedding_dim) + dim_tx.copy_(self.temperature_width**dim_tx) + pos_x = x_embed[:, :, :, None] / dim_tx + + # We use float32 to ensure reproducibility of the original implementation + dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + # Modifying dim_ty in place to avoid extra memory allocation -> dim_ty = self.temperature_height ** (2 * (dim_ty // 2) / self.embedding_dim) + dim_ty //= 2 + dim_ty.mul_(2 / self.embedding_dim) + dim_ty.copy_(self.temperature_height**dim_ty) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# function to generate sine positional embedding for 4d coordinates +def gen_sine_position_embeddings(pos_tensor, hidden_size=256): + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. 
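+
+    Illustrative shape check (assuming the default `hidden_size=256`):
+
+        >>> boxes = torch.rand(2, 300, 4)  # (batch_size, num_queries, 4), normalized (x, y, w, h)
+        >>> gen_sine_position_embeddings(boxes).shape
+        torch.Size([2, 300, 512])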
+ """ + scale = 2 * math.pi + dim = hidden_size // 2 + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# Modified from transformers.models.detr.modeling_detr.DetrAttention +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + """ + + def __init__( + self, + config: DabDetrConfig, + bias: bool = True, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.encoder_attention_heads + self.attention_dropout = config.attention_dropout + self.head_dim = self.hidden_size // self.num_heads + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if object_queries is not None: + hidden_states_original = hidden_states + hidden_states = hidden_states + object_queries + + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states_original) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(batch_size, q_len, embed_dim) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrAttention with ConditionalDetr->DABDETR,Conditional DETR->DabDetr +class DabDetrAttention(nn.Module): + """ + Cross-Attention used in DAB-DETR 'DAB-DETR for Fast Training Convergence' paper. + + The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be + different to v. + """ + + def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = False): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size * 2 if is_cross else config.hidden_size + self.output_dim = config.hidden_size + self.attention_heads = config.decoder_attention_heads + self.attention_dropout = config.attention_dropout + self.attention_head_dim = self.embed_dim // self.attention_heads + if self.attention_head_dim * self.attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `attention_heads`:" + f" {self.attention_heads})." 
+ ) + # head dimension of values + self.values_head_dim = self.output_dim // self.attention_heads + if self.values_head_dim * self.attention_heads != self.output_dim: + raise ValueError( + f"output_dim must be divisible by attention_heads (got `output_dim`: {self.output_dim} and `attention_heads`: {self.attention_heads})." + ) + self.scaling = self.attention_head_dim**-0.5 + self.output_proj = nn.Linear(self.output_dim, self.output_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + # scaling query and refactor key-, value states + query_states = hidden_states * self.scaling + query_states = query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) + + if attn_output.size() != (batch_size, self.attention_heads, q_len, self.values_head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.attention_heads, q_len, self.values_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(batch_size, q_len, self.output_dim) + attn_output = self.output_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DabDetrDecoderLayerSelfAttention(nn.Module): + def __init__(self, config: DabDetrConfig): + super().__init__() + self.dropout = config.dropout + self.self_attn_query_content_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_query_pos_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_key_content_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_key_pos_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn_value_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.self_attn = DabDetrAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + query_position_embeddings: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + residual = hidden_states + query_content = self.self_attn_query_content_proj(hidden_states) + query_pos = self.self_attn_query_pos_proj(query_position_embeddings) + key_content = self.self_attn_key_content_proj(hidden_states) + key_pos = self.self_attn_key_pos_proj(query_position_embeddings) + value = self.self_attn_value_proj(hidden_states) + + query = query_content + query_pos + key = key_content + key_pos + 
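+        # Conditional DETR-style self-attention: queries and keys are the sums of separately projected
+        # content and positional parts, while the values are computed from the content alone.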
+ hidden_states, attn_weights = self.self_attn( + hidden_states=query, + attention_mask=attention_mask, + key_states=key, + value_states=value, + output_attentions=True, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + return hidden_states, attn_weights + + +class DabDetrDecoderLayerCrossAttention(nn.Module): + def __init__(self, config: DabDetrConfig, is_first: bool = False): + super().__init__() + hidden_size = config.hidden_size + self.cross_attn_query_content_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_pos_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_key_content_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_key_pos_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_value_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_pos_sine_proj = nn.Linear(hidden_size, hidden_size) + self.decoder_attention_heads = config.decoder_attention_heads + self.cross_attn_layer_norm = nn.LayerNorm(hidden_size) + self.cross_attn = DabDetrAttention(config, is_cross=True) + + self.keep_query_pos = config.keep_query_pos + + if not self.keep_query_pos and not is_first: + self.cross_attn_query_pos_proj = None + + self.is_first = is_first + self.dropout = config.dropout + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + query_sine_embed: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + query_content = self.cross_attn_query_content_proj(hidden_states) + key_content = self.cross_attn_key_content_proj(encoder_hidden_states) + value = self.cross_attn_value_proj(encoder_hidden_states) + + batch_size, num_queries, n_model = query_content.shape + _, height_width, _ = key_content.shape + + key_pos = self.cross_attn_key_pos_proj(object_queries) + + # For the first decoder layer, we add the positional embedding predicted from + # the object query (the positional embedding) into the original query (key) in DETR. 
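+        # In the remaining layers the anchor-box information enters only through `query_sine_embed`, which is
+        # concatenated head-wise to the content query below (and `key_pos` to the key), doubling the embedding
+        # dimension seen by the cross-attention module.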
+ if self.is_first or self.keep_query_pos: + query_pos = self.cross_attn_query_pos_proj(query_position_embeddings) + query = query_content + query_pos + key = key_content + key_pos + else: + query = query_content + key = key_content + + query = query.view( + batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + query_sine_embed = self.cross_attn_query_pos_sine_proj(query_sine_embed) + query_sine_embed = query_sine_embed.view( + batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + query = torch.cat([query, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2) + key = key.view(batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads) + key_pos = key_pos.view( + batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads + ) + key = torch.cat([key, key_pos], dim=3).view(batch_size, height_width, n_model * 2) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.cross_attn( + hidden_states=query, + attention_mask=encoder_attention_mask, + key_states=key, + value_states=value, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.cross_attn_layer_norm(hidden_states) + + return hidden_states, cross_attn_weights + + +class DabDetrDecoderLayerFFN(nn.Module): + def __init__(self, config: DabDetrConfig): + super().__init__() + hidden_size = config.hidden_size + self.final_layer_norm = nn.LayerNorm(hidden_size) + self.fc1 = nn.Linear(hidden_size, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, hidden_size) + self.activation_fn = ACT2FN[config.activation_function] + self.dropout = config.dropout + self.activation_dropout = config.activation_dropout + self.keep_query_pos = config.keep_query_pos + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig +class DabDetrEncoderLayer(nn.Module): + def __init__(self, config: DabDetrConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = DetrAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = nn.Linear(self.hidden_size, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.hidden_size) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + object_queries: torch.Tensor, + output_attentions: Optional[bool] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + 
attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr +class DabDetrDecoderLayer(nn.Module): + def __init__(self, config: DabDetrConfig, is_first: bool = False): + super().__init__() + self.self_attn = DabDetrDecoderLayerSelfAttention(config) + self.cross_attn = DabDetrDecoderLayerCrossAttention(config, is_first) + self.mlp = DabDetrDecoderLayerFFN(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + query_sine_embed: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + object_queries that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
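+            query_sine_embed (`torch.FloatTensor`, *optional*):
+                Sinusoidal positional embedding of the reference anchor boxes, concatenated to the content queries
+                in the cross-attention layer.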
+ + """ + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + query_position_embeddings=query_position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_position_embeddings=query_position_embeddings, + object_queries=object_queries, + encoder_attention_mask=encoder_attention_mask, + query_sine_embed=query_sine_embed, + output_attentions=output_attentions, + ) + + hidden_states = self.mlp(hidden_states=hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Modified from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->DabDetrMLP +class DabDetrMLP(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, input_tensor): + for i, layer in enumerate(self.layers): + input_tensor = nn.functional.relu(layer(input_tensor)) if i < self.num_layers - 1 else layer(input_tensor) + return input_tensor + + +# Modified from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->DabDetr +class DabDetrPreTrainedModel(PreTrainedModel): + config_class = DabDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [r"DabDetrConvEncoder", r"DabDetrEncoderLayer", r"DabDetrDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + xavier_std = self.config.init_xavier_std + + if isinstance(module, DabDetrMHAttentionMap): + nn.init.zeros_(module.k_linear.bias) + nn.init.zeros_(module.q_linear.bias) + nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) + nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DabDetrForObjectDetection): + nn.init.constant_(module.bbox_predictor.layers[-1].weight.data, 0) + nn.init.constant_(module.bbox_predictor.layers[-1].bias.data, 0) + + # init prior_prob setting for focal loss + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias_value = -math.log((1 - prior_prob) / prior_prob) + module.class_embed.bias.data.fill_(bias_value) + + +DAB_DETR_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DabDetrConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DAB_DETR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. + + Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] + for details. + + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Modified from transformers.models.detr.modeling_detr.DetrEncoder with Detr->DabDetr,DETR->ConditionalDETR +class DabDetrEncoder(DabDetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`DabDetrEncoderLayer`]. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweak for DAB-DETR: + + - object_queries are added to the forward pass. 
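+    - the object queries (positional embeddings) are rescaled by a small MLP (`query_scale`) conditioned on the
+      current hidden states before being added in each encoder layer.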
+ + Args: + config: DabDetrConfig + """ + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.query_scale = DabDetrMLP(config.hidden_size, config.hidden_size, config.hidden_size, 2) + self.layers = nn.ModuleList([DabDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.norm = nn.LayerNorm(config.hidden_size) if config.normalize_before else None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds, + attention_mask, + object_queries, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + Object queries that are added to the queries in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # pos scaler + pos_scales = self.query_scale(hidden_states) + # we add object_queries * pos_scaler as extra input to the encoder_layer + scaled_object_queries = object_queries * pos_scales + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + scaled_object_queries, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + object_queries=scaled_object_queries, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if self.norm: + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoder with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR +class DabDetrDecoder(DabDetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DabDetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for DAB-DETR: + + - object_queries and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. 
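+    - the sinusoidal query embeddings are modulated by the anchor width and height predicted by `ref_anchor_head`.
+    - when `bbox_embed` is set, the reference anchor boxes are refined after every decoder layer and the stacked
+      reference points are returned (iterative box refinement).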
+ + Args: + config: DabDetrConfig + """ + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + self.config = config + self.dropout = config.dropout + self.num_layers = config.decoder_layers + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [DabDetrDecoderLayer(config, is_first=(layer_id == 0)) for layer_id in range(config.decoder_layers)] + ) + # in DAB-DETR, the decoder uses layernorm after the last decoder layer output + self.hidden_size = config.hidden_size + self.layernorm = nn.LayerNorm(self.hidden_size) + + # Default cond-elewise + self.query_scale = DabDetrMLP(self.hidden_size, self.hidden_size, self.hidden_size, 2) + + self.ref_point_head = DabDetrMLP( + config.query_dim // 2 * self.hidden_size, self.hidden_size, self.hidden_size, 2 + ) + + self.bbox_embed = None + + # Default decoder_modulate_hw_attn is True + self.ref_anchor_head = DabDetrMLP(self.hidden_size, self.hidden_size, 2, 2) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds, + encoder_hidden_states, + memory_key_padding_mask, + object_queries, + query_position_embeddings, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(encoder_sequence_length, batch_size, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + memory_key_padding_mask (`torch.Tensor.bool` of shape `(batch_size, sequence_length)`): + The memory_key_padding_mask indicates which positions in the memory (encoder outputs) should be ignored during the attention computation, + ensuring padding tokens do not influence the attention mechanism. + object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, number_of_anchor_points)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + intermediate = [] + reference_points = query_position_embeddings.sigmoid() + ref_points = [reference_points] + + # expand encoder attention mask + if encoder_hidden_states is not None and memory_key_padding_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + memory_key_padding_mask = _prepare_4d_attention_mask( + memory_key_padding_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + for layer_id, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + obj_center = reference_points[..., : self.config.query_dim] + query_sine_embed = gen_sine_position_embeddings(obj_center, self.hidden_size) + query_pos = self.ref_point_head(query_sine_embed) + + # For the first decoder layer, we do not apply transformation over p_s + pos_transformation = 1 if layer_id == 0 else self.query_scale(hidden_states) + + # apply transformation + query_sine_embed = query_sine_embed[..., : self.hidden_size] * pos_transformation + + # modulated Height Width attentions + reference_anchor_size = self.ref_anchor_head(hidden_states).sigmoid() # nq, bs, 2 + query_sine_embed[..., self.hidden_size // 2 :] *= ( + reference_anchor_size[..., 0] / obj_center[..., 2] + ).unsqueeze(-1) + query_sine_embed[..., : self.hidden_size // 2] *= ( + reference_anchor_size[..., 1] / obj_center[..., 3] + ).unsqueeze(-1) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + None, + object_queries, + query_pos, + query_sine_embed, + encoder_hidden_states, + memory_key_padding_mask, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_pos, + query_sine_embed=query_sine_embed, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=memory_key_padding_mask, + output_attentions=output_attentions, + ) + + # iter update + hidden_states = layer_outputs[0] + + if self.bbox_embed is not None: + new_reference_points = self.bbox_embed(hidden_states) + + new_reference_points[..., : self.config.query_dim] += inverse_sigmoid(reference_points) + new_reference_points = new_reference_points[..., : self.config.query_dim].sigmoid() + if layer_id != self.num_layers - 1: + ref_points.append(new_reference_points) + reference_points = new_reference_points.detach() + + intermediate.append(self.layernorm(hidden_states)) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Layer normalization on hidden states + hidden_states = self.layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output_intermediate_hidden_states = 
torch.stack(intermediate) + output_reference_points = torch.stack(ref_points) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attns, + all_cross_attentions, + output_intermediate_hidden_states, + output_reference_points, + ] + if v is not None + ) + return DabDetrDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=output_intermediate_hidden_states, + reference_points=output_reference_points, + ) + + +@add_start_docstrings( + """ + The bare DAB-DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw + hidden-states, intermediate hidden states, reference points, output coordinates without any specific head on top. + """, + DAB_DETR_START_DOCSTRING, +) +class DabDetrModel(DabDetrPreTrainedModel): + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.auxiliary_loss = config.auxiliary_loss + + # Create backbone + positional encoding + self.backbone = DabDetrConvEncoder(config) + object_queries = DabDetrSinePositionEmbedding(config) + + self.query_refpoint_embeddings = nn.Embedding(config.num_queries, config.query_dim) + self.random_refpoints_xy = config.random_refpoints_xy + if self.random_refpoints_xy: + self.query_refpoint_embeddings.weight.data[:, :2].uniform_(0, 1) + self.query_refpoint_embeddings.weight.data[:, :2] = inverse_sigmoid( + self.query_refpoint_embeddings.weight.data[:, :2] + ) + self.query_refpoint_embeddings.weight.data[:, :2].requires_grad = False + + # Create projection layer + self.input_projection = nn.Conv2d( + self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1 + ) + self.backbone = DabDetrConvModel(self.backbone, object_queries) + + self.encoder = DabDetrEncoder(config) + self.decoder = DabDetrDecoder(config) + + # decoder related variables + self.hidden_size = config.hidden_size + self.num_queries = config.num_queries + + self.num_patterns = config.num_patterns + if not isinstance(self.num_patterns, int): + logger.warning("num_patterns should be int but {}".format(type(self.num_patterns))) + self.num_patterns = 0 + if self.num_patterns > 0: + self.patterns = nn.Embedding(self.num_patterns, self.hidden_size) + + self.aux_loss = config.auxiliary_loss + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + @add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DabDetrModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], DabDetrModelOutput]: + r""" + Returns: + + Examples: + + ```python + 
        >>> from transformers import AutoImageProcessor, AutoModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
+        >>> model = AutoModel.from_pretrained("IDEA-Research/dab-detr-resnet-50")
+
+        >>> # prepare image for the model
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+
+        >>> # the last hidden states are the final query embeddings of the Transformer decoder
+        >>> # these are of shape (batch_size, num_queries, hidden_size)
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 300, 256]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, _, height, width = pixel_values.shape
+        device = pixel_values.device
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+
+        # First, send pixel_values + pixel_mask through Backbone to obtain the features
+        # pixel_values should be of shape (batch_size, num_channels, height, width)
+        # pixel_mask should be of shape (batch_size, height, width)
+        features, object_queries_list = self.backbone(pixel_values, pixel_mask)
+
+        # get final feature map and downsampled mask
+        feature_map, mask = features[-1]
+
+        if mask is None:
+            raise ValueError("Backbone does not return downsampled pixel mask")
+
+        flattened_mask = mask.flatten(1)
+
+        # Second, apply 1x1 convolution to reduce the channel dimension to hidden_size (256 by default)
+        projected_feature_map = self.input_projection(feature_map)
+
+        # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
+        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
+        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
+        reference_position_embeddings = self.query_refpoint_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # Fourth, send flattened_features + flattened_mask + object_queries through encoder
+        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
+        # flattened_mask is a Tensor of shape (batch_size, height*width)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                inputs_embeds=flattened_features,
+                attention_mask=flattened_mask,
+                object_queries=object_queries,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
+        num_queries = 
reference_position_embeddings.shape[1] + if self.num_patterns == 0: + queries = torch.zeros(batch_size, num_queries, self.hidden_size, device=device) + else: + queries = ( + self.patterns.weight[:, None, None, :] + .repeat(1, self.num_queries, batch_size, 1) + .flatten(0, 1) + .permute(1, 0, 2) + ) # bs, n_q*n_pat, hidden_size + reference_position_embeddings = reference_position_embeddings.repeat( + 1, self.num_patterns, 1 + ) # bs, n_q*n_pat, hidden_size + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + query_position_embeddings=reference_position_embeddings, + object_queries=object_queries, + encoder_hidden_states=encoder_outputs[0], + memory_key_padding_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + # last_hidden_state + output = (decoder_outputs[0],) + reference_points = decoder_outputs[-1] + intermediate_hidden_states = decoder_outputs[-2] + + # it has to follow the order of DABDETRModelOutput that is based on ModelOutput + # If we only use one of the variables then the indexing will change. + # E.g: if we return everything then 'decoder_attentions' is decoder_outputs[2], if we only use output_attentions then its decoder_outputs[1] + if output_hidden_states and output_attentions: + output += ( + decoder_outputs[1], + decoder_outputs[2], + decoder_outputs[3], + encoder_outputs[0], + encoder_outputs[1], + encoder_outputs[2], + ) + elif output_hidden_states: + # decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states + output += ( + decoder_outputs[1], + encoder_outputs[0], + encoder_outputs[1], + ) + elif output_attentions: + # decoder_self_attention, decoder_cross_attention, encoder_attentions + output += ( + decoder_outputs[1], + decoder_outputs[2], + encoder_outputs[1], + ) + + output += (intermediate_hidden_states, reference_points) + + return output + + reference_points = decoder_outputs.reference_points + intermediate_hidden_states = decoder_outputs.intermediate_hidden_states + + return DabDetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states if output_hidden_states else None, + decoder_attentions=decoder_outputs.attentions if output_attentions else None, + cross_attentions=decoder_outputs.cross_attentions if output_attentions else None, + encoder_last_hidden_state=encoder_outputs.last_hidden_state if output_hidden_states else None, + encoder_hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + encoder_attentions=encoder_outputs.attentions if output_attentions else None, + intermediate_hidden_states=intermediate_hidden_states, + reference_points=reference_points, + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->DabDetr +class DabDetrMHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: Optional[Tensor] = None): + q = 
self.q_linear(q) + k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) + + if mask is not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min) + weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size()) + weights = self.dropout(weights) + return weights + + +@add_start_docstrings( + """ + DAB_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """, + DAB_DETR_START_DOCSTRING, +) +class DabDetrForObjectDetection(DabDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = [ + r"bbox_predictor\.layers\.\d+\.(weight|bias)", + r"model\.decoder\.bbox_embed\.layers\.\d+\.(weight|bias)", + ] + + def __init__(self, config: DabDetrConfig): + super().__init__(config) + + self.config = config + self.auxiliary_loss = config.auxiliary_loss + self.query_dim = config.query_dim + # DAB-DETR encoder-decoder model + self.model = DabDetrModel(config) + + _bbox_embed = DabDetrMLP(config.hidden_size, config.hidden_size, 4, 3) + # Object detection heads + self.class_embed = nn.Linear(config.hidden_size, config.num_labels) + + # Default bbox_embed_diff_each_layer is False + self.bbox_predictor = _bbox_embed + + # Default iter_update is True + self.model.decoder.bbox_embed = self.bbox_predictor + + # Initialize weights and apply final processing + self.post_init() + + # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/dab_detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + @add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DabDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
+        >>> from PIL import Image
+        >>> import requests
+        >>> import torch
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
+        >>> model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+        >>> target_sizes = torch.tensor([(image.height, image.width)])
+        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
+        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
+        Detected remote with confidence 0.833 at location [38.31, 72.1, 177.63, 118.45]
+        Detected cat with confidence 0.831 at location [9.2, 51.38, 321.13, 469.0]
+        Detected cat with confidence 0.804 at location [340.3, 16.85, 642.93, 370.95]
+        Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
+        Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # First, send images through DAB_DETR base model to obtain encoder + decoder outputs
+        model_outputs = self.model(
+            pixel_values,
+            pixel_mask=pixel_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        reference_points = model_outputs.reference_points if return_dict else model_outputs[-1]
+        intermediate_hidden_states = model_outputs.intermediate_hidden_states if return_dict else model_outputs[-2]
+
+        # class logits + predicted bounding boxes
+        logits = self.class_embed(intermediate_hidden_states[-1])
+
+        reference_before_sigmoid = inverse_sigmoid(reference_points)
+        bbox_with_refinement = self.bbox_predictor(intermediate_hidden_states)
+        bbox_with_refinement[..., : self.query_dim] += reference_before_sigmoid
+        outputs_coord = bbox_with_refinement.sigmoid()
+
+        pred_boxes = outputs_coord[-1]
+
+        loss, loss_dict, auxiliary_outputs = None, None, None
+        if labels is not None:
+            outputs_class = None
+            if self.config.auxiliary_loss:
+                outputs_class = self.class_embed(intermediate_hidden_states)
+            loss, loss_dict, auxiliary_outputs = self.loss_function(
+                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
+            )
+
+        if not return_dict:
+            if auxiliary_outputs is not None:
+                output = (logits, pred_boxes) + auxiliary_outputs + model_outputs
+            else:
+                output = (logits, pred_boxes) + model_outputs
+            # Since DabDetrObjectDetectionOutput doesn't
have reference points + intermedieate_hidden_states we cut down. + return ((loss, loss_dict) + output) if loss is not None else output[:-2] + + return DabDetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=model_outputs.last_hidden_state, + decoder_hidden_states=model_outputs.decoder_hidden_states if output_hidden_states else None, + decoder_attentions=model_outputs.decoder_attentions if output_attentions else None, + cross_attentions=model_outputs.cross_attentions if output_attentions else None, + encoder_last_hidden_state=model_outputs.encoder_last_hidden_state if output_hidden_states else None, + encoder_hidden_states=model_outputs.encoder_hidden_states if output_hidden_states else None, + encoder_attentions=model_outputs.encoder_attentions if output_attentions else None, + ) + + +__all__ = [ + "DabDetrForObjectDetection", + "DabDetrModel", + "DabDetrPreTrainedModel", +] diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index 90cd3b1345e3..3dd37c36a4ae 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -52,7 +52,7 @@ class DetrConfig(PretrainedConfig): Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can detect in a single image. For COCO, we recommend 100 queries. d_model (`int`, *optional*, defaults to 256): - Dimension of the layers. + This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others. encoder_layers (`int`, *optional*, defaults to 6): Number of encoder layers. decoder_layers (`int`, *optional*, defaults to 6): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index dbfb1ef53491..349ad988df40 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2482,6 +2482,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DabDetrForObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DabDetrModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DabDetrPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class DacModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/dab_detr/__init__.py b/tests/models/dab_detr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/dab_detr/test_modeling_dab_detr.py b/tests/models/dab_detr/test_modeling_dab_detr.py new file mode 100644 index 000000000000..d3d70d67d4c3 --- /dev/null +++ b/tests/models/dab_detr/test_modeling_dab_detr.py @@ -0,0 +1,839 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch DAB-DETR model.""" + +import inspect +import math +import unittest +from typing import Dict, List, Tuple + +from transformers import DabDetrConfig, ResNetConfig, is_torch_available, is_vision_available +from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + from transformers import ( + DabDetrForObjectDetection, + DabDetrModel, + ) + + +if is_vision_available(): + from PIL import Image + + from transformers import ConditionalDetrImageProcessor + + +class DabDetrModelTester: + def __init__( + self, + parent, + batch_size=8, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + num_queries=12, + num_channels=3, + min_size=200, + max_size=200, + n_targets=8, + num_labels=91, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_queries = num_queries + self.num_channels = num_channels + self.min_size = min_size + self.max_size = max_size + self.n_targets = n_targets + self.num_labels = num_labels + + # we also set the expected seq length for both encoder and decoder + self.encoder_seq_length = math.ceil(self.min_size / 32) * math.ceil(self.max_size / 32) + self.decoder_seq_length = self.num_queries + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]) + + pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device) + + labels = None + if self.use_labels: + # labels is a list of Dict (each Dict being the labels for a given example in the batch) + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size, device=torch_device) + labels.append(target) + + config = self.get_config() + return config, pixel_values, pixel_mask, labels + + def get_config(self): + resnet_config = ResNetConfig( + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + 
hidden_act="relu", + num_labels=3, + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + ) + return DabDetrConfig( + hidden_size=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + num_queries=self.num_queries, + num_labels=self.num_labels, + use_timm_backbone=False, + backbone_config=resnet_config, + backbone=None, + use_pretrained_backbone=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + return config, inputs_dict + + def create_and_check_dab_detr_model(self, config, pixel_values, pixel_mask, labels): + model = DabDetrModel(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size) + ) + + def create_and_check_dab_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = DabDetrForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class DabDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + DabDetrModel, + DabDetrForObjectDetection, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "image-feature-extraction": DabDetrModel, + "object-detection": DabDetrForObjectDetection, + } + if is_torch_available() + else {} + ) + is_encoder_decoder = True + test_torchscript = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + zero_init_hidden_state = True + + # special case for head models + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ in ["DabDetrForObjectDetection"]: + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + target["masks"] = torch.ones( + self.model_tester.n_targets, + self.model_tester.min_size, + self.model_tester.max_size, + device=torch_device, + dtype=torch.float, + ) + labels.append(target) + inputs_dict["labels"] = labels + + return 
inputs_dict + + def setUp(self): + self.model_tester = DabDetrModelTester(self) + self.config_tester = ConfigTester(self, config_class=DabDetrConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_dab_detr_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_dab_detr_model(*config_and_inputs) + + def test_dab_detr_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_dab_detr_object_detection_head_model(*config_and_inputs) + + # TODO: check if this works again for PyTorch 2.x.y + @unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="DETR does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="DETR does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="DETR does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="DETR does not have a get_input_embeddings method") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="DETR is not a generative model") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="DETR does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @slow + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + print(t) + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + torch.testing.assert_close( + set_nan_tensor_to_zero(tuple_object), + set_nan_tensor_to_zero(dict_object), + atol=1e-5, + rtol=1e-5, + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + if self.has_attentions: + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence( + model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + + self.assertEqual(len(hidden_states), expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + + self.assertListEqual( + [hidden_states[0].shape[1], hidden_states[0].shape[2]], + [seq_length, self.model_tester.hidden_size], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + self.assertIsInstance(hidden_states, (list, tuple)) + + self.assertEqual(len(hidden_states), expected_num_layers) + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + + self.assertListEqual( + [hidden_states[0].shape[1], hidden_states[0].shape[2]], + [decoder_seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + 
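+            # first, check that output_hidden_states works when passed explicitly as a forward kwarg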
check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Had to modify the threshold to 2 decimals instead of 3 because sometimes it threw an error + def test_batching_equivalence(self): + """ + Tests that the model supports batching and that the output is the nearly the same for the same input in + different batch sizes. + (Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to + different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535) + """ + + def get_tensor_equivalence_function(batched_input): + # models operating on continuous spaces have higher abs difference than LMs + # instead, we can rely on cos distance for image/speech models, similar to `diffusers` + if "input_ids" not in batched_input: + return lambda tensor1, tensor2: ( + 1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38) + ) + return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2)) + + def recursive_check(batched_object, single_row_object, model_name, key): + if isinstance(batched_object, (list, tuple)): + for batched_object_value, single_row_object_value in zip(batched_object, single_row_object): + recursive_check(batched_object_value, single_row_object_value, model_name, key) + elif isinstance(batched_object, dict): + for batched_object_value, single_row_object_value in zip( + batched_object.values(), single_row_object.values() + ): + recursive_check(batched_object_value, single_row_object_value, model_name, key) + # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects + elif batched_object is None or not isinstance(batched_object, torch.Tensor): + return + elif batched_object.dim() == 0: + return + else: + # indexing the first element does not always work + # e.g. models that output similarity scores of size (N, M) would need to index [0, 0] + slice_ids = [slice(0, index) for index in single_row_object.shape] + batched_row = batched_object[slice_ids] + self.assertFalse( + torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" + ) + self.assertTrue( + (equivalence(batched_row, single_row_object)) <= 1e-02, + msg=( + f"Batched and Single row outputs are not equal in {model_name} for key={key}. " + f"Difference={equivalence(batched_row, single_row_object)}." 
+ ), + ) + + config, batched_input = self.model_tester.prepare_config_and_inputs_for_common() + equivalence = get_tensor_equivalence_function(batched_input) + + for model_class in self.all_model_classes: + config.output_hidden_states = True + + model_name = model_class.__name__ + if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"): + config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class) + batched_input_prepared = self._prepare_for_class(batched_input, model_class) + model = model_class(config).to(torch_device).eval() + + batch_size = self.model_tester.batch_size + single_row_input = {} + for key, value in batched_input_prepared.items(): + if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0: + # e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size + single_batch_shape = value.shape[0] // batch_size + single_row_input[key] = value[:single_batch_shape] + else: + single_row_input[key] = value + + with torch.no_grad(): + model_batched_output = model(**batched_input_prepared) + model_row_output = model(**single_row_input) + + if isinstance(model_batched_output, torch.Tensor): + model_batched_output = {"model_output": model_batched_output} + model_row_output = {"model_output": model_row_output} + + for key in model_batched_output: + # DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan` + if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key: + model_batched_output[key] = model_batched_output[key][1:] + model_row_output[key] = model_row_output[key][1:] + recursive_check(model_batched_output[key], model_row_output[key], model_name, key) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + decoder_seq_length = self.model_tester.decoder_seq_length + encoder_seq_length = self.model_tester.encoder_seq_length + decoder_key_length = self.model_tester.decoder_seq_length + encoder_key_length = self.model_tester.encoder_seq_length + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + del inputs_dict["output_hidden_states"] + config.output_attentions = True + config.output_hidden_states = False + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + if self.is_encoder_decoder: + correct_outlen = 6 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + if "past_key_values" in outputs: + correct_outlen += 
1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + # decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states + added_hidden_states = 3 + else: + added_hidden_states = 1 + + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_retain_grad_hidden_states_attentions(self): + # removed retain_grad and grad on decoder_hidden_states, as queries don't require grad + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs, output_attentions=True, output_hidden_states=True) + + # logits + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + encoder_attentions = outputs.encoder_attentions[0] + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_auxiliary_loss(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.auxiliary_loss = True + + # only test for object detection and segmentation model + for model_class in self.all_model_classes[1:]: + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + outputs = model(**inputs) + + self.assertIsNotNone(outputs.auxiliary_outputs) + self.assertEqual(len(outputs.auxiliary_outputs), 
self.model_tester.num_hidden_layers - 1)
+
+    def test_training(self):
+        if not self.model_tester.is_training:
+            self.skipTest(reason="ModelTester is not configured to run training tests")
+
+        # Only the object detection model computes a loss
+        model_class = self.all_model_classes[-1]
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        model = model_class(config)
+        model.to(torch_device)
+        model.train()
+        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+        loss = model(**inputs).loss
+        loss.backward()
+
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.forward)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            if model.config.is_encoder_decoder:
+                expected_arg_names = ["pixel_values", "pixel_mask"]
+                expected_arg_names.extend(
+                    ["head_mask", "decoder_head_mask", "encoder_outputs"]
+                    if "head_mask" in arg_names and "decoder_head_mask" in arg_names
+                    else []
+                )
+                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+            else:
+                expected_arg_names = ["pixel_values", "pixel_mask"]
+                self.assertListEqual(arg_names[:1], expected_arg_names)
+
+    def test_different_timm_backbone(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        # let's pick a random timm backbone
+        config.backbone = "tf_mobilenetv3_small_075"
+        config.backbone_config = None
+        config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            if model_class.__name__ == "DabDetrForObjectDetection":
+                expected_shape = (
+                    self.model_tester.batch_size,
+                    self.model_tester.num_queries,
+                    self.model_tester.num_labels,
+                )
+                self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+
+            self.assertTrue(outputs)
+
+    def test_initialization(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        configs_no_init.init_xavier_std = 1e9
+        # Copied from RT-DETR
+        configs_no_init.initializer_bias_prior_prob = 0.2
+        bias_value = -1.3863  # -log_e((1 - 0.2) / 0.2)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    if "bbox_attention" in name and "bias" not in name:
+                        self.assertLess(
+                            100000,
+                            abs(param.data.max().item()),
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    # Modified from RT-DETR
+                    elif "class_embed" in name and "bias" in name:
+                        bias_tensor = torch.full_like(param.data, bias_value)
+                        torch.testing.assert_close(
+                            param.data,
+                            bias_tensor,
+                            atol=1e-4,
+                            rtol=1e-4,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    elif "activation_fn" in name and config.activation_function == "prelu":
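+                        # nn.PReLU initializes its weight to 0.25 by default in PyTorch, hence the expected mean
+                        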
self.assertTrue( + param.data.mean() == 0.25, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif "backbone.conv_encoder.model" in name: + continue + elif "self_attn.in_proj_weight" in name: + self.assertIn( + ((param.data.mean() * 1e2).round() / 1e2).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + +TOLERANCE = 1e-4 +CHECKPOINT = "IDEA-Research/dab-detr-resnet-50" + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_timm +@require_vision +@slow +class DabDetrModelIntegrationTests(unittest.TestCase): + @cached_property + def default_image_processor(self): + return ConditionalDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None + + def test_inference_no_head(self): + model = DabDetrModel.from_pretrained(CHECKPOINT).to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + encoding = image_processor(images=image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values=encoding.pixel_values) + + expected_shape = torch.Size((1, 300, 256)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + expected_slice = torch.tensor( + [[-0.4879, -0.2594, 0.4524], [-0.4997, -0.4258, 0.4329], [-0.8220, -0.4996, 0.0577]] + ).to(torch_device) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=2e-4, rtol=2e-4) + + def test_inference_object_detection_head(self): + model = DabDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + encoding = image_processor(images=image, return_tensors="pt").to(torch_device) + pixel_values = encoding["pixel_values"].to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values) + + # verify logits + box predictions + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + expected_slice_logits = torch.tensor( + [[-10.1765, -5.5243, -8.9324], [-9.8138, -5.6721, -7.5161], [-10.3054, -5.6081, -8.5931]] + ).to(torch_device) + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice_logits, atol=3e-4, rtol=3e-4) + + expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + expected_slice_boxes = torch.tensor( + [[0.3708, 0.3000, 0.2753], [0.5211, 0.6125, 0.9495], [0.2897, 0.6730, 0.5459]] + ).to(torch_device) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4, rtol=1e-4) + + # verify postprocessing + results = image_processor.post_process_object_detection( + outputs, threshold=0.3, target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.8732, 0.8563, 0.8554, 0.6079, 0.5896]).to(torch_device) + expected_labels = [17, 75, 17, 75, 63] + expected_boxes = torch.tensor([14.6970, 49.3892, 320.5165, 469.2765]).to(torch_device) + + self.assertEqual(len(results["scores"]), 5) + torch.testing.assert_close(results["scores"], expected_scores, atol=1e-4, rtol=1e-4) + 
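# expected label ids follow the COCO category mapping: 17 = cat, 75 = remote, 63 = couch
+        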
self.assertSequenceEqual(results["labels"].tolist(), expected_labels) + torch.testing.assert_close(results["boxes"][0, :], expected_boxes, atol=1e-4, rtol=1e-4) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 5b5fd52e6fa1..cfc0de5dc0ca 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -161,6 +161,16 @@ "giou_loss_coefficient", "mask_loss_coefficient", ], + "DabDetrConfig": [ + "dilation", + "bbox_cost", + "bbox_loss_coefficient", + "class_cost", + "cls_loss_coefficient", + "focal_alpha", + "giou_cost", + "giou_loss_coefficient", + ], "DetrConfig": [ "bbox_cost", "bbox_loss_coefficient",