Skip to content

Commit

Permalink
Merge pull request #1069 from roboflow/from_transformers_improvements
Browse files Browse the repository at this point in the history
`from_transformers` improvements
  • Loading branch information
SkalskiP authored Mar 28, 2024
2 parents ccc7118 + 68e1bae commit 77e38ba
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
56 changes: 47 additions & 9 deletions docs/how_to/detect_and_annotate.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ model.
from inference import get_model

model = get_model(model_id="yolov8n-640")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model.infer(image)[0]
```

Expand All @@ -38,7 +38,7 @@ model.
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model(image)[0]
```

Expand All @@ -52,7 +52,7 @@ model.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open(<PATH TO IMAGE>)
image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
Expand All @@ -78,7 +78,7 @@ Now that we have predictions from a model, we can load them into Supervision.
from inference import get_model

model = get_model(model_id="yolov8n-640")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model.infer(image)[0]
detections = sv.Detections.from_inference(results)
```
Expand All @@ -93,7 +93,7 @@ Now that we have predictions from a model, we can load them into Supervision.
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model(image)[0]
detections = sv.Detections.from_ultralytics(results)
```
Expand All @@ -111,7 +111,7 @@ Now that we have predictions from a model, we can load them into Supervision.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open(<PATH TO IMAGE>)
image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
Expand Down Expand Up @@ -144,7 +144,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi
from inference import get_model

model = get_model(model_id="yolov8n-640")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model.infer(image)[0]
detections = sv.Detections.from_inference(results)

Expand All @@ -165,7 +165,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
image = cv2.imread(<PATH TO IMAGE>)
image = cv2.imread(<SOURCE_IMAGE_PATH>)
results = model(image)[0]
detections = sv.Detections.from_ultralytics(results)

Expand All @@ -189,7 +189,7 @@ Finally, we can annotate the image with the predictions. Since we are working wi
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open(<PATH TO IMAGE>)
image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
Expand Down Expand Up @@ -362,4 +362,42 @@ that will allow you to draw masks instead of boxes.
scene=annotated_image, detections=detections)
```

=== "Transformers"

```python
import torch
import supervision as sv
from PIL import Image
from transformers import DetrImageProcessor, DetrForSegmentation

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)

width, height = image.size
target_size = torch.tensor([[height, width]])
results = processor.post_process_segmentation(
outputs=outputs, target_sizes=target_size)[0]
detections = sv.Detections.from_transformers(results)

mask_annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

labels = [
f"{model.config.id2label[class_id]} {confidence:.2f}"
for class_id, confidence
in zip(detections.class_id, detections.confidence)
]

annotated_image = mask_annotator.annotate(
scene=image, detections=detections)
annotated_image = label_annotator.annotate(
scene=annotated_image, detections=detections, labels=labels)
```

![segmentation-annotation](https://media.roboflow.com/supervision_detect_and_annotate_example_3.png)
55 changes: 39 additions & 16 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,48 @@ def from_transformers(cls, transformers_results: dict) -> Detections:
Returns:
Detections: A new Detections object.
"""
boxes = transformers_results.get("boxes")
# If the boxes key is in the transformers_results then we know it's an
# object detection result. Else, we can assume it's a segmentation model
if boxes:
return cls(
xyxy=transformers_results["boxes"].cpu().numpy(),
confidence=transformers_results["scores"].cpu().numpy(),
class_id=transformers_results["labels"].cpu().numpy().astype(int),
)
else:
masks = transformers_results["masks"].cpu().numpy().astype(bool)
Example:
```python
import torch
import supervision as sv
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
image = Image.open(<SOURCE_IMAGE_PATH>)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
width, height = image.size
target_size = torch.tensor([[height, width]])
results = processor.post_process_object_detection(
outputs=outputs, target_sizes=target_size)[0]
detections = sv.Detections.from_transformers(results)
```
""" # noqa: E501 // docs

if "boxes" in transformers_results:
return cls(
xyxy=mask_to_xyxy(masks),
mask=masks,
confidence=transformers_results["scores"].cpu().numpy(),
class_id=transformers_results["labels"].cpu().numpy().astype(int),
xyxy=transformers_results["boxes"].cpu().detach().numpy(),
confidence=transformers_results["scores"].cpu().detach().numpy(),
class_id=transformers_results["labels"]
.cpu()
.detach()
.numpy()
.astype(int),
)
masks = transformers_results["masks"].cpu().detach().numpy().astype(bool)
return cls(
xyxy=mask_to_xyxy(masks),
mask=masks,
confidence=transformers_results["scores"].cpu().detach().numpy(),
class_id=transformers_results["labels"].cpu().detach().numpy().astype(int),
)

@classmethod
def from_detectron2(cls, detectron2_results) -> Detections:
Expand Down

0 comments on commit 77e38ba

Please sign in to comment.