From 4ad78becb674c9c0d82b0961a64aa9d3fefeb455 Mon Sep 17 00:00:00 2001
From: SkalskiP <piotr.skalski92@gmail.com>
Date: Thu, 28 Mar 2024 16:08:05 +0100
Subject: [PATCH 1/3] initial comment

---
 docs/how_to/detect_and_annotate.md | 38 ++++++++++++++++++++++
 supervision/detection/core.py      | 51 ++++++++++++++++++++----------
 2 files changed, 73 insertions(+), 16 deletions(-)

diff --git a/docs/how_to/detect_and_annotate.md b/docs/how_to/detect_and_annotate.md
index 45d0487a9..6aaabff30 100644
--- a/docs/how_to/detect_and_annotate.md
+++ b/docs/how_to/detect_and_annotate.md
@@ -362,4 +362,42 @@ that will allow you to draw masks instead of boxes.
         scene=annotated_image, detections=detections)
     ```

+=== "Transformers"
+
+    ```python
+    import torch
+    import supervision as sv
+    from PIL import Image
+    from transformers import DetrImageProcessor, DetrForSegmentation
+
+    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+    model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50")
+
+    image = Image.open(<SOURCE_IMAGE_PATH>)
+    inputs = processor(images=image, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    width, height = image.size
+    target_size = torch.tensor([[height, width]])
+    results = processor.post_process_object_detection(
+        outputs=outputs, target_sizes=target_size)[0]
+    detections = sv.Detections.from_transformers(results)
+
+    mask_annotator = sv.MaskAnnotator()
+    label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
+
+    labels = [
+        f"{model.config.id2label[class_id]} {confidence:.2f}"
+        for class_id, confidence
+        in zip(detections.class_id, detections.confidence)
+    ]
+
+    annotated_image = mask_annotator.annotate(
+        scene=image, detections=detections)
+    annotated_image = label_annotator.annotate(
+        scene=annotated_image, detections=detections)
+    ```
+
 ![segmentation-annotation](https://media.roboflow.com/supervision_detect_and_annotate_example_3.png)

diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index 60ff07f33..51616335b 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -396,25 +396,44 @@ def from_transformers(cls, transformers_results: dict) -> Detections:

         Returns:
             Detections: A new Detections object.
-        """
-        boxes = transformers_results.get("boxes")
-        # If the boxes key is in the transformers_results then we know it's an
-        # object detection result. Else, we can assume it's a segmentation model
-        if boxes:
-            return cls(
-                xyxy=transformers_results["boxes"].cpu().numpy(),
-                confidence=transformers_results["scores"].cpu().numpy(),
-                class_id=transformers_results["labels"].cpu().numpy().astype(int),
-            )
-        else:
-            masks = transformers_results["masks"].cpu().numpy().astype(bool)
+        Example:
+            ```python
+            import torch
+            import supervision as sv
+            from PIL import Image
+            from transformers import DetrImageProcessor, DetrForObjectDetection
+
+            processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+            model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+            image = Image.open(<SOURCE_IMAGE_PATH>)
+            inputs = processor(images=image, return_tensors="pt")
+
+            with torch.no_grad():
+                outputs = model(**inputs)
+
+            width, height = image.size
+            target_size = torch.tensor([[height, width]])
+            results = processor.post_process_object_detection(
+                outputs=outputs, target_sizes=target_size)[0]
+            detections = sv.Detections.from_transformers(results)
+            ```
+        """  # noqa: E501 // docs
+
+        if "boxes" in transformers_results:
             return cls(
-                xyxy=mask_to_xyxy(masks),
-                mask=masks,
-                confidence=transformers_results["scores"].cpu().numpy(),
-                class_id=transformers_results["labels"].cpu().numpy().astype(int),
+                xyxy=transformers_results["boxes"].cpu().detach().numpy(),
+                confidence=transformers_results["scores"].cpu().detach().numpy(),
+                class_id=transformers_results["labels"].cpu().detach().numpy().astype(int),
             )
+        masks = transformers_results["masks"].cpu().detach().numpy().astype(bool)
+        return cls(
+            xyxy=mask_to_xyxy(masks),
+            mask=masks,
+            confidence=transformers_results["scores"].cpu().detach().numpy(),
+            class_id=transformers_results["labels"].cpu().detach().numpy().astype(int),
+        )

     @classmethod
     def from_detectron2(cls, detectron2_results) -> Detections:

From 0c7c746f15c1ed39f691b606d2cdb6c7ca0015f4 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 28 Mar 2024 15:10:28 +0000
Subject: [PATCH 2/3] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?=
 =?UTF-8?q?format=20pre-commit=20hooks?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 supervision/detection/core.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index 51616335b..8c29cfc48 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -403,16 +403,16 @@ def from_transformers(cls, transformers_results: dict) -> Detections:
             import supervision as sv
             from PIL import Image
             from transformers import DetrImageProcessor, DetrForObjectDetection
-
+
             processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
             model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
-
+
             image = Image.open(<SOURCE_IMAGE_PATH>)
             inputs = processor(images=image, return_tensors="pt")
-
+
             with torch.no_grad():
                 outputs = model(**inputs)
-
+
             width, height = image.size
             target_size = torch.tensor([[height, width]])
             results = processor.post_process_object_detection(
@@ -425,7 +425,11 @@ def from_transformers(cls, transformers_results: dict) -> Detections:
             return cls(
                 xyxy=transformers_results["boxes"].cpu().detach().numpy(),
                 confidence=transformers_results["scores"].cpu().detach().numpy(),
-                class_id=transformers_results["labels"].cpu().detach().numpy().astype(int),
+                class_id=transformers_results["labels"]
+                .cpu()
+                .detach()
+                .numpy()
+                .astype(int),
             )
         masks = transformers_results["masks"].cpu().detach().numpy().astype(bool)
transformers_results["masks"].cpu().detach().numpy().astype(bool) return cls( From afc31e42a5b8b2dab2fd544233587231ce1bbcf8 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 28 Mar 2024 16:34:05 +0100 Subject: [PATCH 3/3] few more improvements --- docs/how_to/detect_and_annotate.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/how_to/detect_and_annotate.md b/docs/how_to/detect_and_annotate.md index 6aaabff30..7bc2f9708 100644 --- a/docs/how_to/detect_and_annotate.md +++ b/docs/how_to/detect_and_annotate.md @@ -27,7 +27,7 @@ model. from inference import get_model model = get_model(model_id="yolov8n-640") - image = cv2.imread() + image = cv2.imread() results = model.infer(image)[0] ``` @@ -38,7 +38,7 @@ model. from ultralytics import YOLO model = YOLO("yolov8n.pt") - image = cv2.imread() + image = cv2.imread() results = model(image)[0] ``` @@ -52,7 +52,7 @@ model. processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") - image = Image.open() + image = Image.open() inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): @@ -78,7 +78,7 @@ Now that we have predictions from a model, we can load them into Supervision. from inference import get_model model = get_model(model_id="yolov8n-640") - image = cv2.imread() + image = cv2.imread() results = model.infer(image)[0] detections = sv.Detections.from_inference(results) ``` @@ -93,7 +93,7 @@ Now that we have predictions from a model, we can load them into Supervision. from ultralytics import YOLO model = YOLO("yolov8n.pt") - image = cv2.imread() + image = cv2.imread() results = model(image)[0] detections = sv.Detections.from_ultralytics(results) ``` @@ -111,7 +111,7 @@ Now that we have predictions from a model, we can load them into Supervision. processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") - image = Image.open() + image = Image.open() inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): @@ -144,7 +144,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi from inference import get_model model = get_model(model_id="yolov8n-640") - image = cv2.imread() + image = cv2.imread() results = model.infer(image)[0] detections = sv.Detections.from_inference(results) @@ -165,7 +165,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi from ultralytics import YOLO model = YOLO("yolov8n.pt") - image = cv2.imread() + image = cv2.imread() results = model(image)[0] detections = sv.Detections.from_ultralytics(results) @@ -189,7 +189,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") - image = Image.open() + image = Image.open() inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): @@ -370,8 +370,8 @@ that will allow you to draw masks instead of boxes. 
     from PIL import Image
     from transformers import DetrImageProcessor, DetrForSegmentation

-    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
-    model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50")
+    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
+    model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

     image = Image.open(<SOURCE_IMAGE_PATH>)
     inputs = processor(images=image, return_tensors="pt")

@@ -381,7 +381,7 @@ that will allow you to draw masks instead of boxes.

     width, height = image.size
     target_size = torch.tensor([[height, width]])
-    results = processor.post_process_object_detection(
+    results = processor.post_process_segmentation(
         outputs=outputs, target_sizes=target_size)[0]
     detections = sv.Detections.from_transformers(results)

@@ -397,7 +397,7 @@ that will allow you to draw masks instead of boxes.
     annotated_image = mask_annotator.annotate(
         scene=image, detections=detections)
     annotated_image = label_annotator.annotate(
-        scene=annotated_image, detections=detections)
+        scene=annotated_image, detections=detections, labels=labels)
     ```

 ![segmentation-annotation](https://media.roboflow.com/supervision_detect_and_annotate_example_3.png)
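
A note on why patch 1 swaps `if boxes:` for a key-membership test: the values in a `post_process_object_detection` result are `torch.Tensor`s, and truth-testing a tensor with more than one element raises `RuntimeError: Boolean value of Tensor with more than one element is ambiguous`, so routing should never depend on tensor truthiness. A minimal sketch of the new routing behavior; the result dicts below are hypothetical stand-ins for the processor outputs, not real model results:

```python
import torch

# Hypothetical stand-ins for the two post-processing outputs; real results
# also carry "scores" and "labels", omitted here for brevity.
detection_result = {"boxes": torch.rand(2, 4)}
segmentation_result = {"masks": torch.zeros(2, 4, 4, dtype=torch.bool)}

# Key membership selects the branch without ever truth-testing a tensor,
# which is what `if boxes:` would have done.
for result in (detection_result, segmentation_result):
    branch = "detection" if "boxes" in result else "segmentation"
    print(branch)  # -> detection, then segmentation
```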
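
The added `.detach()` calls follow the same defensive logic: if a caller runs the model without `torch.no_grad()`, the output tensors still carry autograd history and `Tensor.numpy()` refuses to convert them. A small sketch of that failure mode and the fix; the tensor shape here is illustrative:

```python
import torch

# A tensor that still requires grad, as model outputs would outside no_grad().
boxes = torch.rand(2, 4, requires_grad=True) * 640

try:
    boxes.cpu().numpy()  # raises: can't call numpy() on a tensor that requires grad
except RuntimeError as error:
    print(error)

xyxy = boxes.cpu().detach().numpy()  # detach() drops the grad history first
print(xyxy.shape)  # (2, 4)
```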
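
For the segmentation branch, `mask_to_xyxy` (already used by the pre-patch code) reduces each boolean mask to its tight bounding box. A rough NumPy equivalent of the idea, not supervision's actual implementation:

```python
import numpy as np

def mask_to_xyxy_sketch(masks: np.ndarray) -> np.ndarray:
    """Approximate (N, H, W) boolean masks -> (N, 4) xyxy boxes."""
    boxes = np.zeros((masks.shape[0], 4), dtype=int)
    for i, mask in enumerate(masks):
        ys, xs = np.nonzero(mask)  # row/column indices of mask pixels
        if len(xs) > 0:
            boxes[i] = [xs.min(), ys.min(), xs.max(), ys.max()]
    return boxes

demo = np.zeros((1, 8, 8), dtype=bool)
demo[0, 2:5, 3:7] = True
print(mask_to_xyxy_sketch(demo))  # [[3 2 6 4]]
```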