Skip to content

Commit 01b0516

Browse files
makseqmicaelakaplancaitlinwheeless
authored
feat: RND-116: YOLOv8 ML Backend (#607)
Co-authored-by: Micaela Kaplan <[email protected]> Co-authored-by: caitlinwheeless <[email protected]> Co-authored-by: micaelakaplan <[email protected]>
1 parent 49e8929 commit 01b0516

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+8928
-27
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ Check the **Required parameters** column to see if you need to set any additiona
6161
| [spacy](/label_studio_ml/examples/spacy) | NER by [SpaCy](https://spacy.io/) |||| None | Set [(see documentation)](https://spacy.io/usage/linguistic-features) |
6262
| [tesseract](/label_studio_ml/examples/tesseract) | Interactive OCR. [Details](https://github.com/tesseract-ocr/tesseract) |||| None | Set (characters) |
6363
| [watsonX](/label_studio_ml/exampels/watsonx)| LLM inference with [WatsonX](https://www.ibm.com/products/watsonx-ai) and integration with [WatsonX.data](watsonx.data)|||| None| Arbitrary|
64+
| [yolo](/label_studio_ml/examples/yolo) | Object detection with [YOLO](https://docs.ultralytics.com/tasks/) |||| None | Arbitrary |
65+
6466
# (Advanced usage) Develop your model
6567

6668
To start developing your own ML backend, follow the instructions below.

label_studio_ml/examples/mmdetection-3/mmdetection.py

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def build_labels_from_labeling_config(self, schema):
100100
for ls_label, label_attrs in self.labels_attrs.items():
101101
predicted_values = label_attrs.get("predicted_values", "").split(",")
102102
for predicted_value in predicted_values:
103+
predicted_value = predicted_value.strip() # remove spaces at the beginning and at the end
103104
if predicted_value: # it shouldn't be empty (like '')
104105
if predicted_value not in mmdet_labels:
105106
print(
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
gunicorn==22.0.0
2-
label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git
2+
label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git@fix/rnd-117

label_studio_ml/examples/mmdetection-3/test_model.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from mmdetection import MMDetection
44

5-
from pytest import approx
5+
from label_studio_ml.utils import compare_nested_structures
66

77
label_config = """
88
<View>
@@ -41,22 +41,6 @@
4141
]
4242

4343

44-
def compare_nested_structures(a, b, path=""):
45-
"""Compare two dicts or list with approx() for float values"""
46-
if isinstance(a, dict) and isinstance(b, dict):
47-
assert a.keys() == b.keys(), f"Keys mismatch at {path}"
48-
for key in a.keys():
49-
compare_nested_structures(a[key], b[key], path + "." + str(key))
50-
elif isinstance(a, list) and isinstance(b, list):
51-
assert len(a) == len(b), f"List size mismatch at {path}"
52-
for i, (act_item, exp_item) in enumerate(zip(a, b)):
53-
compare_nested_structures(act_item, exp_item, path + f"[{i}]")
54-
elif isinstance(a, float) and isinstance(b, float):
55-
assert a == approx(b), f"Mismatch at {path}"
56-
else:
57-
assert a == b, f"Mismatch at {path}"
58-
59-
6044
def test_mmdetection_model_predict():
6145
model = MMDetection(label_config=label_config)
6246
predictions = model.predict([task])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Exclude everything
2+
**
3+
4+
# Include Dockerfile and docker-compose for reference (optional, decide based on your use case)
5+
!Dockerfile
6+
!docker-compose.yml
7+
8+
# Include Python application files
9+
!*.py
10+
!*.yaml
11+
!tests/*
12+
!control_models/*
13+
!models/*
14+
15+
# Include requirements files
16+
!requirements*.txt
17+
18+
# Include script
19+
!*.sh
20+
21+
# Exclude specific requirements if necessary
22+
# requirements-test.txt (Uncomment if you decide to exclude this)
+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
2+
ARG DEBIAN_FRONTEND=noninteractive
3+
ARG TEST_ENV
4+
5+
WORKDIR /app
6+
7+
RUN conda update conda -y
8+
9+
RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \
10+
--mount=type=cache,target="/var/lib/apt/lists",sharing=locked \
11+
apt-get -y update \
12+
&& apt-get install -y git \
13+
&& apt-get install -y wget \
14+
&& apt-get install -y g++ freeglut3-dev build-essential libx11-dev \
15+
libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev libfreeimage-dev \
16+
&& apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev python3-pip gcc
17+
18+
ENV PYTHONUNBUFFERED=1 \
19+
PYTHONDONTWRITEBYTECODE=1 \
20+
PIP_CACHE_DIR=/.cache \
21+
PORT=9090 \
22+
WORKERS=2 \
23+
THREADS=4 \
24+
CUDA_HOME=/usr/local/cuda
25+
26+
RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y
27+
ENV CUDA_HOME=/opt/conda \
28+
TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0"
29+
30+
# install base requirements
31+
COPY requirements-base.txt .
32+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
33+
pip install -r requirements-base.txt
34+
35+
# install model requirements
36+
COPY requirements.txt .
37+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
38+
pip3 install -r requirements.txt
39+
40+
# install test requirements if needed
41+
COPY requirements-test.txt .
42+
# build only when TEST_ENV="true"
43+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
44+
if [ "$TEST_ENV" = "true" ]; then \
45+
pip3 install -r requirements-test.txt; \
46+
fi
47+
48+
WORKDIR /app
49+
50+
COPY . ./
51+
52+
WORKDIR /app/models
53+
54+
# Download the YOLO models
55+
RUN yolo predict model=yolov8m.pt source=/app/tests/car.jpg \
56+
&& yolo predict model=yolov8n.pt source=/app/tests/car.jpg \
57+
&& yolo predict model=yolov8n-cls.pt source=/app/tests/car.jpg \
58+
&& yolo predict model=yolov8n-seg.pt source=/app/tests/car.jpg
59+
60+
WORKDIR /app
61+
62+
CMD ["/app/start.sh"]

label_studio_ml/examples/yolo/README.md

+810
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
```mermaid
2+
classDiagram
3+
class ControlModel {
4+
+str type
5+
+ControlTag control
6+
+str from_name
7+
+str to_name
8+
+str value
9+
+YOLO model
10+
+float model_score_threshold
11+
+Optional[Dict[str, str]] label_map
12+
+LabelStudioMLBase label_studio_ml_backend
13+
+get_cached_model(path: str) YOLO
14+
+create(cls, mlbackend: LabelStudioMLBase, control: ControlTag) ControlModel
15+
+predict_regions(path: str) List[Dict]
16+
+debug_plot(image)
17+
}
18+
19+
class RectangleLabelsModel {
20+
+predict_regions(path: str) List[Dict]
21+
+create_rectangles(results, path) List[Dict]
22+
}
23+
24+
class RectangleLabelsObbModel {
25+
+predict_regions(path: str) List[Dict]
26+
+create_rotated_rectangles(results, path) List[Dict]
27+
}
28+
29+
30+
class PolygonLabelsModel {
31+
+predict_regions(path: str) List[Dict]
32+
+create_polygons(results, path) List[Dict]
33+
}
34+
35+
class KeyPointLabelsModel {
36+
+predict_regions(path: str) List[Dict]
37+
+create_keypoints(results, path) List[Dict]
38+
}
39+
40+
class ChoicesModel {
41+
+predict_regions(path: str) List[Dict]
42+
+create_choices(results, path) List[Dict]
43+
}
44+
45+
class VideoRectangleModel {
46+
+predict_regions(path: str) List[Dict]
47+
+create_video_rectangles(results, path) List[Dict]
48+
+update_tracker_params(yaml_path: str, prefix: str) str | None
49+
}
50+
51+
ControlModel <|-- RectangleLabelsModel
52+
ControlModel <|-- RectangleLabelsObbModel
53+
ControlModel <|-- PolygonLabelsModel
54+
ControlModel <|-- ChoicesModel
55+
ControlModel <|-- KeyPointLabelsModel
56+
ControlModel <|-- VideoRectangleModel
57+
58+
```
59+
60+
### 1. **Architecture Overview**
61+
62+
The architecture of the project is modular and is primarily centered around integrating YOLO-based models with Label Studio to automate the labeling of images and videos. The system is organized into several Python modules that interact with each other to perform this task. The main components of the architecture include:
63+
64+
1. **Main YOLO Integration Module (`model.py`)**:
65+
- This is the central module that connects Label Studio with YOLO models. It handles the overall process of detecting control tags from Label Studio’s configuration, running predictions on tasks, and returning the predictions in the format that Label Studio expects.
66+
67+
2. **Control Models (`control_models/`)**:
68+
- The control models are specialized modules that correspond to different annotation types in Label Studio (e.g., RectangleLabels, PolygonLabels, Choices, VideoRectangle). Each control model is responsible for handling specific types of annotations by using the YOLO model to predict the necessary regions or labels.
69+
70+
3. **Base Control Model (`control_models/base.py`)**:
71+
- This is an abstract base class that provides common functionality for all control models. It handles tasks like loading the YOLO model, caching it for efficiency, and providing a template for the predict and create methods.
72+
73+
4. **Specific Control Models**:
74+
- **RectangleLabelsModel (`control_models/rectanglelabels.py`)**: Handles bounding boxes (both simple and oriented bounding boxes) for images.
75+
- **PolygonLabelsModel (`control_models/polygonlabels.py`)**: Deals with polygon annotations, typically used for segmentation tasks.
76+
- **ChoicesModel (`control_models/choices.py`)**: Manages classification tasks where the model predicts one or more labels for the entire image.
77+
- **KeyPointLabelsModel (`control_models/keypointlabels.py`)**: Supports keypoint annotations, where the model predicts the locations of keypoints on an image.
78+
- **VideoRectangleModel (`control_models/videorectangle.py`)**: Focuses on tracking objects across video frames, generating bounding boxes for each frame.
79+
80+
### 2. **Module Descriptions**
81+
82+
1. **`model.py` (Main YOLO Integration Module)**:
83+
- **Purpose**: This module serves as the entry point for integrating YOLO models with Label Studio. It is responsible for setting up the YOLO model, detecting which control models are needed based on the Label Studio configuration, running predictions on tasks, and returning the results in the required format.
84+
- **Key Functions**:
85+
- `setup()`: Initializes the YOLO model parameters.
86+
- `detect_control_models()`: Scans the Label Studio configuration to determine which control models to use.
87+
- `predict()`: Runs predictions on a batch of tasks and formats the results for Label Studio.
88+
- `fit()`: (Not implemented) Placeholder for updating the model based on new annotations.
89+
90+
2. **`control_models/base.py` (Base Control Model)**:
91+
- **Purpose**: Provides a common interface and shared functionality for all specific control models. It includes methods for loading and caching the YOLO model, plotting results for debugging, and abstract methods that need to be implemented by subclasses.
92+
- **Key Functions**:
93+
- `get_cached_model()`: Retrieves a YOLO model from cache or loads it if not cached.
94+
- `create()`: Factory method to instantiate a control model.
95+
- `predict_regions()`: Abstract method to be implemented by subclasses to perform predictions.
96+
97+
3. **`control_models/choices.py` (ChoicesModel)**:
98+
- **Purpose**: Handles classification tasks where the model predicts one or more labels for an image. It converts the YOLO model’s classification output into Label Studio’s choices format.
99+
- **Key Functions**:
100+
- `create_choices()`: Processes the YOLO model’s output and maps it to the Label Studio choices format.
101+
102+
4. **`control_models/rectanglelabels.py` (RectangleLabelsModel)**:
103+
- **Purpose**: Manages the creation of bounding box annotations, both simple (axis-aligned) and oriented (rotated), from the YOLO model’s output.
104+
- **Key Functions**:
105+
- `create_rectangles()`: Converts the YOLO model’s bounding box predictions into Label Studio’s rectangle labels format.
106+
- `create_rotated_rectangles()`: Handles oriented bounding boxes (OBB) by processing rotation angles and converting them to the required format.
107+
108+
5. **`control_models/polygonlabels.py` (PolygonLabelsModel)**:
109+
- **Purpose**: Converts segmentation masks generated by the YOLO model into polygon annotations for Label Studio. This is useful for tasks where precise boundaries around objects are required.
110+
- **Key Functions**:
111+
- `create_polygons()`: Transforms the YOLO model’s segmentation output into polygon annotations.
112+
113+
6. **`control_models/keypointlabels.py` (KeyPointLabelsModel)**:
114+
- **Purpose**: Supports keypoint annotations by predicting the locations of keypoints on an image using the pose YOLO model.
115+
- **Key Functions**:
116+
- `create_keypoints()`: Processes the YOLO model’s keypoint predictions and converts them into Label Studio’s keypoint labels format.
117+
118+
7. **`control_models/videorectangle.py` (VideoRectangleModel)**:
119+
- **Purpose**: Focuses on tracking objects across video frames, using YOLO’s tracking capabilities to generate bounding box annotations for each frame in a video sequence.
120+
- **Key Functions**:
121+
- `predict_regions()`: Runs YOLO’s tracking model on a video and converts the results into Label Studio’s video rectangle format.
122+
- `create_video_rectangles()`: Processes the output of the tracking model to create a sequence of bounding boxes across video frames.
123+
- `update_tracker_params()`: Customizes the tracking parameters based on settings in Label Studio’s configuration.
124+
125+
### **Module Interaction**
126+
127+
- **Workflow**: The main workflow begins with `model.py`, which reads tasks and the Label Studio configuration to detect and instantiate the appropriate control models. These control models are responsible for making predictions using the YOLO model and converting the results into a format that Label Studio can use for annotations.
128+
129+
- **Inter-Module Communication**: Each control model inherits from `ControlModel` in `base.py`, ensuring that they all share common methods for loading the YOLO model, handling predictions, and caching. The specific control models (e.g., RectangleLabelsModel, PolygonLabelsModel) implement the abstract methods defined in `ControlModel` to provide the specialized behavior needed for different types of annotations.
130+
131+
This modular structure allows for easy extension and modification, where new control models can be added to handle additional annotation types or new model architectures.

0 commit comments

Comments
 (0)