Commit
Add DAB-DETR for object detection (#30803)
* initial commit
* encoder+decoder layer changes WIP
* architecture checks
* working version of detection + segmentation
* fix modeling outputs
* fix return dict + output att/hs
* found the position embedding masking bug
* pre-training version
* added image processors
* typo in init.py
* iterupdate set to false
* fixed num_labels in class_output linear layer bias init
* multihead attention shape fixes
* test improvements
* test update
* dab-detr model_doc update
* dab-detr model_doc update2
* test fix:test_retain_grad_hidden_states_attentions
* config file clean and renaming variables
* config file clean and renaming variables fix
* updated convert_to_hf file
* small fixes
* style and quality checks
* return_dict fix
* Merge branch main into add_dab_detr
* small comment fix
* skip test_inputs_embeds test
* image processor updates + image processor test updates
* check copies test fix update
* updates for check_copies.py test
* updates for check_copies.py test2
* tied weights fix
* fixed image processing tests and fixed shared weights issues
* added numpy nd array option to get_Expected_values method in test_image_processing_dab_detr.py
* delete prints from test file
* SafeTensor modification to solve HF Trainer issue
* removing the safetensor modifications
* make fix copies and hf upload has been added.
* fixed index.md
* fixed repo consistency
* style fix and dabdetrimageprocessor docstring update
* requested modifications after the first review
* Update src/transformers/models/dab_detr/image_processing_dab_detr.py Co-authored-by: Pavel Iakubovskii <[email protected]>
* repo consistency has been fixed
* update copied NestedTensor function after main merge
* Update src/transformers/models/dab_detr/modeling_dab_detr.py Co-authored-by: Pavel Iakubovskii <[email protected]>
* temp commit
* temp commit2
* temp commit 3
* unit tests are fixed
* fixed repo consistency
* updated expected_boxes variable values based on related notebook results in DABDETRIntegrationTests file.
* temporary config modifications and repo consistency fixes
* Put dilation parameter back to config
* pattern embeddings have been added to the rename_keys method
* add dilation comment to config + add as an exception in check_config_attributes SPECIAL CASES
* delete FeatureExtractor part from docs.md
* requested modifications in modeling_dab_detr.py
* [run_slow] dab_detr
* deleted last segmentation code part, updated conversion script and changed the hf path in test files
* temp commit of requested modifications
* temp commit of requested modifications 2
* updated config file, resolved codepaths and refactored conversion script
* updated decodelayer block types and refactored conversion script
* style and quality update
* small modifications based on the request
* attentions are refactored
* removed loss functions from modeling file, added loss function to lossutils, tried to move the MLP layer generation to config but it failed
* deleted imageprocessor
* fixed conversion script + quality and style
* fixed config_att
* [run_slow] dab_detr
* changing model path in conversion file and in test file
* fix Decoder variable naming
* testing the old loss function
* switched back to the new loss function and testing with the old attention functions
* switched back to the new last good result modeling file
* moved back to the version when I asked the review
* missing new line at the end of the file
* old version test
* turn back to newest model version but change image processor
* style fix
* style fix after merge main
* [run_slow] dab_detr
* [run_slow] dab_detr
* added device and type for head bias data part
* [run_slow] dab_detr
* fixed model head bias data fill
* changed test_inference_object_detection_head assertTrues to torch test assert_close
* fixes part 1
* quality update
* self.bbox_embed in decoder has been restored
* changed Assert true torch closeall methods to torch testing assertclose
* modelcard markdown file has been updated
* deleted intermediate list from decoder module

---------

Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent fe52679, commit 8d73a38
Showing 20 changed files with 3,259 additions and 2 deletions.
@@ -0,0 +1,119 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# DAB-DETR

## Overview

The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, Lei Zhang.
DAB-DETR is an enhanced variant of Conditional DETR. It uses dynamically updated anchor boxes to provide both a reference query point (x, y) and a reference anchor size (w, h), which improves the cross-attention computation. This approach achieves 45.7% AP on MS-COCO when trained for 50 epochs with a single ResNet50-DC5 backbone.
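
To make the idea of "box coordinates as queries" more concrete, the following is a minimal sketch of how a single anchor box (x, y, w, h) could be turned into a positional vector by encoding each coordinate with sinusoids and concatenating the results. It is not the implementation in `modeling_dab_detr.py`: the function names, the tiny `dim`, and the `temperature` value are assumptions chosen for illustration, and the learned projection that a real model would apply afterwards is omitted.

```python
import math


def sine_embed(value, dim=8, temperature=10000.0):
    """Encode one normalized coordinate in [0, 1] as `dim` sin/cos features."""
    scaled = 2 * math.pi * value
    features = []
    for i in range(dim // 2):
        freq = temperature ** (2 * i / dim)  # lower frequency for later pairs
        features.append(math.sin(scaled / freq))
        features.append(math.cos(scaled / freq))
    return features


def anchor_embedding(anchor, dim=8):
    """Concatenate the encodings of x, y, w and h (hypothetical helper)."""
    x, y, w, h = anchor
    return sine_embed(x, dim) + sine_embed(y, dim) + sine_embed(w, dim) + sine_embed(h, dim)


embedding = anchor_embedding((0.5, 0.5, 0.2, 0.3))
print(len(embedding))  # four coordinates, `dim` features each
```

Because width and height are encoded alongside the center point, the resulting query carries the anchor's size as well as its location, which is the property the overview attributes to DAB-DETR.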

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dab_detr_convergence_plot.png"
alt="drawing" width="600"/>

The abstract from the paper is the following:

*We present in this paper a novel query formulation using dynamic anchor boxes
for DETR (DEtection TRansformer) and offer a deeper understanding of the role
of queries in DETR. This new formulation directly uses box coordinates as queries
in Transformer decoders and dynamically updates them layer-by-layer. Using box
coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR,
but also allows us to modulate the positional attention map using the box width
and height information. Such a design makes it clear that queries in DETR can be
implemented as performing soft ROI pooling layer-by-layer in a cascade manner.
As a result, it leads to the best performance on MS-COCO benchmark among
the DETR-like detection models under the same setting, e.g., AP 45.7% using
ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive
experiments to confirm our analysis and verify the effectiveness of our methods.*
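
The layer-by-layer update mentioned in the abstract can be sketched in a few lines. The example below assumes that each decoder layer predicts four offsets and that these are added to the anchor in logit space before being squashed back into (0, 1); that update rule, the helper names, and the offset values are illustrative assumptions, not code taken from this commit.

```python
import math


def inverse_sigmoid(p, eps=1e-5):
    """Map a value in (0, 1) to logit space, clamping to avoid log(0)."""
    p = min(max(p, eps), 1 - eps)
    return math.log(p / (1 - p))


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def refine_anchor(anchor, delta):
    """One update step: add the offsets in logit space, then map back to (0, 1)."""
    return tuple(sigmoid(inverse_sigmoid(a) + d) for a, d in zip(anchor, delta))


anchor = (0.5, 0.5, 0.2, 0.3)  # normalized (x, y, w, h)
# Hypothetical per-layer offsets; in a real model a prediction head produces them.
layer_deltas = [(0.4, -0.2, 0.3, 0.1), (0.1, 0.0, -0.1, 0.05)]
for delta in layer_deltas:
    anchor = refine_anchor(anchor, delta)
    print([round(v, 3) for v in anchor])
```

Working in logit space keeps every refined coordinate inside (0, 1) no matter how large the predicted offset is, so the anchors remain valid normalized boxes after each layer.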

This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu).
The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR).

## How to Get Started with the Model

Use the code below to get started with the model.

```python
import torch
import requests

from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# image.size is (width, height); target_sizes expects (height, width), hence the [::-1]
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
This should output:
```
cat: 0.87 [14.7, 49.39, 320.52, 469.28]
remote: 0.86 [41.08, 72.37, 173.39, 117.2]
cat: 0.86 [344.45, 19.43, 639.85, 367.86]
remote: 0.61 [334.27, 75.93, 367.92, 188.81]
couch: 0.59 [-0.04, 1.34, 639.9, 477.09]
```
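
The boxes printed above are in absolute pixel coordinates `(x_min, y_min, x_max, y_max)`. As a rough illustration of what the post-processing step does with `target_sizes`, the sketch below assumes the model emits normalized center-format boxes `(cx, cy, w, h)`, as DETR-style detectors commonly do, and rescales one such box. The helper name and the sample values are hypothetical; the library's own post-processing may differ in detail.

```python
def to_absolute_corners(box, target_size):
    """Convert a normalized (cx, cy, w, h) box to pixel (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = box
    height, width = target_size  # same (height, width) order as target_sizes
    return [
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    ]


pil_size = (640, 480)          # PIL's Image.size is (width, height)
target_size = pil_size[::-1]   # reversed to (height, width)
print(to_absolute_corners((0.5, 0.5, 0.25, 0.5), target_size))
# [240.0, 120.0, 400.0, 360.0]
```

Passing the size in the wrong order would scale x by the image height and y by the width, which is why the example reverses `image.size` before building `target_sizes`.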

There are three other ways to instantiate a DAB-DETR model (depending on what you prefer):

Option 1: Instantiate DAB-DETR with pre-trained weights for the entire model
```py
>>> from transformers import DabDetrForObjectDetection

>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
```

Option 2: Instantiate DAB-DETR with randomly initialized weights for the Transformer, but pre-trained weights for the backbone
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig()
>>> model = DabDetrForObjectDetection(config)
```

Option 3: Instantiate DAB-DETR with randomly initialized weights for the backbone + Transformer
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig(use_pretrained_backbone=False)
>>> model = DabDetrForObjectDetection(config)
```

## DabDetrConfig

[[autodoc]] DabDetrConfig

## DabDetrModel

[[autodoc]] DabDetrModel
    - forward

## DabDetrForObjectDetection

[[autodoc]] DabDetrForObjectDetection
    - forward
@@ -63,6 +63,7 @@
cpmant,
ctrl,
cvt,
dab_detr,
dac,
data2vec,
dbrx,
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_dab_detr import *
    from .modeling_dab_detr import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
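
For readers unfamiliar with the pattern in this `__init__.py`, here is a toy sketch of a lazily resolving module. It only illustrates the general idea of replacing the entry in `sys.modules` with an object that imports submodules on first attribute access; the `ToyLazyModule` class, its constructor arguments, and the `toy_lazy` name are assumptions made for this example and do not reflect how `_LazyModule` or `define_import_structure` are actually implemented.

```python
import importlib
import sys
import types


class ToyLazyModule(types.ModuleType):
    """Toy stand-in: resolves exported names on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the module that defines it.
        self._name_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Only reached when normal lookup fails, i.e. before the name is cached.
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(self._name_to_module[attr]), attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


# Standard-library modules stand in for the real submodules here.
lazy = ToyLazyModule("toy_lazy", {"json": ["dumps"], "math": ["sqrt"]})
sys.modules["toy_lazy"] = lazy
print(lazy.sqrt(16.0))
```

The `TYPE_CHECKING` branch in the real file serves a complementary purpose: static analyzers see ordinary imports, while at runtime nothing heavy is imported until a name is requested.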