Add DAB-DETR for object detection (#30803)
* initial commit

* encoder+decoder layer changes WIP

* architecture checks

* working version of detection + segmentation

* fix modeling outputs

* fix return dict + output att/hs

* found the position embedding masking bug

* pre-training version

* added image processors

* typo in init.py

* iterupdate set to false

* fixed num_labels in class_output linear layer bias init

* multihead attention shape fixes

* test improvements

* test update

* dab-detr model_doc update

* dab-detr model_doc update2

* test fix:test_retain_grad_hidden_states_attentions

* config file clean and renaming variables

* config file clean and renaming variables fix

* updated convert_to_hf file

* small fixes

* style and quality checks

* return_dict fix

* Merge branch main into add_dab_detr

* small comment fix

* skip test_inputs_embeds test

* image processor updates + image processor test updates

* check copies test fix update

* updates for check_copies.py test

* updates for check_copies.py test2

* tied weights fix

* fixed image processing tests and fixed shared weights issues

* added numpy nd array option to get_Expected_values method in test_image_processing_dab_detr.py

* delete prints from test file

* SafeTensor modification to solve HF Trainer issue

* removing the safetensor modifications

* make fix copies and hf upload has been added.

* fixed index.md

* fixed repo consistency

* style fix and dabdetrimageprocessor docstring update

* requested modifications after the first review

* Update src/transformers/models/dab_detr/image_processing_dab_detr.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* repo consistency has been fixed

* update copied NestedTensor function after main merge

* Update src/transformers/models/dab_detr/modeling_dab_detr.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* temp commit

* temp commit2

* temp commit 3

* unit tests are fixed

* fixed repo consistency

* updated expected_boxes variable values based on related notebook results in DABDETRIntegrationTests file.

* temporary config modifications and repo consistency fixes

* Put dilation parameter back to config

* pattern embeddings have been added to the rename_keys method

* add dilation comment to config + add as an exception in check_config_attributes SPECIAL CASES

* delete FeatureExtractor part from docs.md

* requested modifications in modeling_dab_detr.py

* [run_slow] dab_detr

* deleted last segmentation code part, updated conversion script and changed the hf path in test files

* temp commit of requested modifications

* temp commit of requested modifications 2

* updated config file, resolved codepaths and refactored conversion script

* updated decodelayer block types and refactored conversion script

* style and quality update

* small modifications based on the request

* attentions are refactored

* removed loss functions from modeling file, added loss function to lossutils, tried to move the MLP layer generation to config but it failed

* deleted imageprocessor

* fixed conversion script + quality and style

* fixed config_att

* [run_slow] dab_detr

* changing model path in conversion file and in test file

* fix Decoder variable naming

* testing the old loss function

* switched back to the new loss function and testing with the old attention functions

* switched back to the new last good result modeling file

* moved back to the version when I asked the review

* missing new line at the end of the file

* old version test

* turn back to newest model version but change image processor

* style fix

* style fix after merge main

* [run_slow] dab_detr

* [run_slow] dab_detr

* added device and type for head bias data part

* [run_slow] dab_detr

* fixed model head bias data fill

* changed test_inference_object_detection_head assertTrues to torch test assert_close

* fixes part 1

* quality update

* self.bbox_embed in decoder has been restored

* changed Assert true torch closeall methods to torch testing assertclose

* modelcard markdown file has been updated

* deleted intermediate list from decoder module

---------

Co-authored-by: Pavel Iakubovskii <[email protected]>
conditionedstimulus and qubvel authored Feb 4, 2025
1 parent fe52679 commit 8d73a38
Showing 20 changed files with 3,259 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -643,6 +643,8 @@
title: ConvNeXTV2
- local: model_doc/cvt
title: CvT
- local: model_doc/dab-detr
title: DAB-DETR
- local: model_doc/deformable_detr
title: Deformable DETR
- local: model_doc/deit
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -110,6 +110,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CPM-Ant](model_doc/cpmant) ||||
| [CTRL](model_doc/ctrl) ||||
| [CvT](model_doc/cvt) ||||
| [DAB-DETR](model_doc/dab-detr) ||||
| [DAC](model_doc/dac) ||||
| [Data2VecAudio](model_doc/data2vec) ||||
| [Data2VecText](model_doc/data2vec) ||||
119 changes: 119 additions & 0 deletions docs/source/en/model_doc/dab-detr.md
@@ -0,0 +1,119 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# DAB-DETR

## Overview

The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, Lei Zhang.
DAB-DETR is an enhanced variant of Conditional DETR. It utilizes dynamically updated anchor boxes to provide both a reference query point (x, y) and a reference anchor size (w, h), improving cross-attention computation. This new approach achieves 45.7% AP when trained for 50 epochs with a single ResNet-50 model as the backbone.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dab_detr_convergence_plot.png"
alt="drawing" width="600"/>

The abstract from the paper is the following:

*We present in this paper a novel query formulation using dynamic anchor boxes
for DETR (DEtection TRansformer) and offer a deeper understanding of the role
of queries in DETR. This new formulation directly uses box coordinates as queries
in Transformer decoders and dynamically updates them layer-by-layer. Using box
coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR,
but also allows us to modulate the positional attention map using the box width
and height information. Such a design makes it clear that queries in DETR can be
implemented as performing soft ROI pooling layer-by-layer in a cascade manner.
As a result, it leads to the best performance on MS-COCO benchmark among
the DETR-like detection models under the same setting, e.g., AP 45.7% using
ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive
experiments to confirm our analysis and verify the effectiveness of our methods.*
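
The layer-by-layer anchor update described in the abstract can be sketched in a few lines. This is an illustrative sketch, not the code added in this commit: it assumes each decoder layer predicts an offset in logit space, which is added to the inverse sigmoid of the current anchor and squashed back into [0, 1]. The names `inverse_sigmoid`, `refine_anchor`, and all numeric values are hypothetical.

```python
import math


def inverse_sigmoid(x: float, eps: float = 1e-5) -> float:
    """Map a value in (0, 1) back to logit space, clamping to avoid log(0)."""
    x = min(max(x, eps), 1.0 - eps)
    return math.log(x / (1.0 - x))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def refine_anchor(anchor, delta):
    """Apply one refinement step to a normalized (cx, cy, w, h) anchor box.

    `delta` stands in for the offsets a decoder layer would predict.
    """
    return [sigmoid(inverse_sigmoid(a) + d) for a, d in zip(anchor, delta)]


# Hypothetical starting anchor and per-layer offsets, for illustration only.
anchor = [0.5, 0.5, 0.2, 0.2]
layer_deltas = [
    [0.4, -0.2, 0.3, 0.1],
    [0.1, 0.0, 0.2, 0.0],
]

for delta in layer_deltas:
    anchor = refine_anchor(anchor, delta)

print([round(v, 3) for v in anchor])
```

Working in logit space keeps every refined coordinate inside (0, 1) no matter how large the predicted offset is, which is one plausible reason for this formulation.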

This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu).
The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR).

## How to Get Started with the Model

Use the code below to get started with the model.

```python
import torch
import requests

from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
This should output:
```
cat: 0.87 [14.7, 49.39, 320.52, 469.28]
remote: 0.86 [41.08, 72.37, 173.39, 117.2]
cat: 0.86 [344.45, 19.43, 639.85, 367.86]
remote: 0.61 [334.27, 75.93, 367.92, 188.81]
couch: 0.59 [-0.04, 1.34, 639.9, 477.09]
```
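
The boxes printed above are absolute `(x_min, y_min, x_max, y_max)` pixel coordinates. DETR-style heads typically predict normalized `(center_x, center_y, width, height)` boxes, so post-processing has to convert and rescale them. The following is a minimal plain-Python sketch of that step under this assumed convention; the helper name `cxcywh_to_xyxy` and the sample box are hypothetical and are not taken from the library's implementation.

```python
def cxcywh_to_xyxy(box, image_width, image_height):
    """Convert a normalized (cx, cy, w, h) box to absolute (x0, y0, x1, y1) pixels."""
    cx, cy, w, h = box
    x0 = (cx - w / 2) * image_width
    y0 = (cy - h / 2) * image_height
    x1 = (cx + w / 2) * image_width
    y1 = (cy + h / 2) * image_height
    return [round(v, 2) for v in (x0, y0, x1, y1)]


# Hypothetical normalized prediction on a 640x480 image.
print(cxcywh_to_xyxy([0.5, 0.5, 0.25, 0.5], image_width=640, image_height=480))
# [240.0, 120.0, 400.0, 360.0]
```

Note that `target_sizes` in the example above expects `(height, width)`, which is why the PIL `image.size` tuple, given as `(width, height)`, is reversed with `[::-1]`.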

There are three other ways to instantiate a DAB-DETR model (depending on what you prefer):

Option 1: Instantiate DAB-DETR with pre-trained weights for entire model
```py
>>> from transformers import DabDetrForObjectDetection

>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
```

Option 2: Instantiate DAB-DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig()
>>> model = DabDetrForObjectDetection(config)
```

Option 3: Instantiate DAB-DETR with randomly initialized weights for backbone + Transformer
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig(use_pretrained_backbone=False)
>>> model = DabDetrForObjectDetection(config)
```


## DabDetrConfig

[[autodoc]] DabDetrConfig

## DabDetrModel

[[autodoc]] DabDetrModel
    - forward

## DabDetrForObjectDetection

[[autodoc]] DabDetrForObjectDetection
    - forward
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -328,6 +328,7 @@
"CTRLTokenizer",
],
"models.cvt": ["CvtConfig"],
"models.dab_detr": ["DabDetrConfig"],
"models.dac": ["DacConfig", "DacFeatureExtractor"],
"models.data2vec": [
"Data2VecAudioConfig",
@@ -1898,6 +1899,13 @@
"CvtPreTrainedModel",
]
)
_import_structure["models.dab_detr"].extend(
[
"DabDetrForObjectDetection",
"DabDetrModel",
"DabDetrPreTrainedModel",
]
)
_import_structure["models.dac"].extend(
[
"DacModel",
@@ -5387,6 +5395,9 @@
CTRLTokenizer,
)
from .models.cvt import CvtConfig
from .models.dab_detr import (
DabDetrConfig,
)
from .models.dac import (
DacConfig,
DacFeatureExtractor,
@@ -6926,6 +6937,11 @@
CvtModel,
CvtPreTrainedModel,
)
from .models.dab_detr import (
DabDetrForObjectDetection,
DabDetrModel,
DabDetrPreTrainedModel,
)
from .models.dac import (
DacModel,
DacPreTrainedModel,
1 change: 1 addition & 0 deletions src/transformers/activations.py
@@ -217,6 +217,7 @@ def __getitem__(self, key):
"silu": nn.SiLU,
"swish": nn.SiLU,
"tanh": nn.Tanh,
"prelu": nn.PReLU,
}
ACT2FN = ClassInstantier(ACT2CLS)
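
Registering `"prelu"` in `ACT2CLS` makes the activation selectable by name from a config string. The sketch below illustrates the lookup pattern with stand-in classes so that it runs without PyTorch. It assumes the mapping instantiates the registered class on every lookup; `Instantiator`, `FakeSiLU`, and `FakePReLU` are hypothetical names, not the library's code.

```python
import math
from collections import OrderedDict


class Instantiator(OrderedDict):
    """Dict that returns a fresh instance of the stored class on each lookup."""

    def __getitem__(self, key):
        content = super().__getitem__(key)
        # Entries are either a class or a (class, kwargs) tuple.
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


class FakeSiLU:
    def __call__(self, x):
        return x / (1.0 + math.exp(-x))


class FakePReLU:
    def __init__(self, slope=0.25):
        self.slope = slope  # learnable in a real PReLU; fixed in this sketch

    def __call__(self, x):
        return x if x >= 0 else self.slope * x


NAME_TO_CLASS = {"silu": FakeSiLU, "prelu": FakePReLU}
NAME_TO_FN = Instantiator(NAME_TO_CLASS)

activation = NAME_TO_FN["prelu"]
print(activation(-2.0))  # -0.5
```

Returning a new instance per lookup matters for an activation such as PReLU, which carries a learnable parameter: two layers that request `"prelu"` would not share that parameter under this pattern.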

1 change: 1 addition & 0 deletions src/transformers/loss/loss_utils.py
@@ -128,6 +128,7 @@ def ForTokenClassification(logits, labels, config, **kwargs):
"ForObjectDetection": ForObjectDetectionLoss,
"DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -63,6 +63,7 @@
cpmant,
ctrl,
cvt,
dab_detr,
dac,
data2vec,
dbrx,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -79,6 +79,7 @@
("cpmant", "CpmAntConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("dab-detr", "DabDetrConfig"),
("dac", "DacConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
@@ -399,6 +400,7 @@
("cpmant", "CPM-Ant"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("dab-detr", "DAB-DETR"),
("dac", "DAC"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
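
The model type is registered as `"dab-detr"` (hyphen) while the package added in this commit is `dab_detr` (underscore). A simplified sketch of how such a key could be resolved to an import location follows. It assumes hyphens are replaced by underscores and ignores any special cases the library may handle; `CONFIG_NAMES`, `module_name_for`, and `config_import_path` are hypothetical names.

```python
# Hypothetical excerpt of a model-type -> config-class-name registry.
CONFIG_NAMES = {
    "cvt": "CvtConfig",
    "dab-detr": "DabDetrConfig",
    "dac": "DacConfig",
}


def module_name_for(model_type: str) -> str:
    """Turn a model type into a valid Python package name."""
    return model_type.replace("-", "_")


def config_import_path(model_type: str):
    """Return (module path, class name) for a registered model type."""
    module_path = f"transformers.models.{module_name_for(model_type)}"
    return module_path, CONFIG_NAMES[model_type]


print(config_import_path("dab-detr"))
# ('transformers.models.dab_detr', 'DabDetrConfig')
```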
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -78,6 +78,7 @@
("cpmant", "CpmAntModel"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("dab-detr", "DabDetrModel"),
("dac", "DacModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
@@ -592,6 +593,7 @@
("conditional_detr", "ConditionalDetrModel"),
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("dab-detr", "DabDetrModel"),
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
@@ -890,6 +892,7 @@
[
# Model for Object Detection mapping
("conditional_detr", "ConditionalDetrForObjectDetection"),
("dab-detr", "DabDetrForObjectDetection"),
("deformable_detr", "DeformableDetrForObjectDetection"),
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
@@ -52,7 +52,7 @@ class ConditionalDetrConfig(PretrainedConfig):
Number of object queries, i.e. detection slots. This is the maximal number of objects
[`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
d_model (`int`, *optional*, defaults to 256):
-Dimension of the layers.
+This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
encoder_layers (`int`, *optional*, defaults to 6):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 6):
@@ -74,6 +74,8 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
Reference points (reference points of each layer of the decoder).
"""

intermediate_hidden_states: Optional[torch.FloatTensor] = None
@@ -116,6 +118,8 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
Reference points (reference points of each layer of the decoder).
"""

intermediate_hidden_states: Optional[torch.FloatTensor] = None
28 changes: 28 additions & 0 deletions src/transformers/models/dab_detr/__init__.py
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_dab_detr import *
    from .modeling_dab_detr import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
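
The `__init__.py` above replaces the package in `sys.modules` with a lazy module so that heavy submodules are only imported when one of their names is first used. The general technique can be sketched with the standard library alone. This is a simplified illustration, not the behavior of `_LazyModule` itself; the class `LazyModule` and the attribute-to-module table are hypothetical.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Module whose attributes are imported from other modules on first access."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        self._attr_to_module = attr_to_module

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        if attr.startswith("_") or attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(self._attr_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


# Hypothetical table: attribute name -> module that defines it.
lazy = LazyModule("lazy_demo", {"sqrt": "math", "dumps": "json"})
print(lazy.sqrt(16.0))  # 4.0
```

Here `json` is not imported by `lazy` until `lazy.dumps` is first accessed, which mirrors the reason for deferring imports in a large package.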