diff --git a/label_studio_ml/examples/yolov8/Dockerfile b/label_studio_ml/examples/yolov8/Dockerfile new file mode 100644 index 000000000..1e52aed48 --- /dev/null +++ b/label_studio_ml/examples/yolov8/Dockerfile @@ -0,0 +1,31 @@ +FROM python:3.11-slim + +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git wget && \ + apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev gcc + +ENV PYTHONUNBUFFERED=True \ + PORT=9090 + +WORKDIR /tmp +COPY requirements.txt . + +RUN pip install --no-cache-dir -r requirements.txt + +RUN wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-oiv7.pt +RUN wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt + +COPY uwsgi.ini /etc/uwsgi/ +COPY supervisord.conf /etc/supervisor/conf.d/ + +WORKDIR /app + +RUN mkdir -p datasets/temp/images +RUN mkdir -p datasets/temp/labels + +COPY * /app/ + +EXPOSE 9090 + +CMD ["/usr/local/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] \ No newline at end of file diff --git a/label_studio_ml/examples/yolov8/README.md b/label_studio_ml/examples/yolov8/README.md new file mode 100644 index 000000000..67d3fefce --- /dev/null +++ b/label_studio_ml/examples/yolov8/README.md @@ -0,0 +1,71 @@ +This project integrates the YOLOv8 model with Label Studio. + + + +https://github.com/HumanSignal/label-studio-ml-backend/assets/106922533/82f539f1-dbee-47bf-b129-f7b5df83af43 + + + +## How The Project Works + +This project helps you detect objects in Label Studio by doing two things. + +1 - Uses a pretrained YOLOv8 model on Google's Open Images V7 (OIV7) to provide a pretrained model on 600 classes! + +2 - Use a custom model for classes in cases that don't fit under the 600 classes in the OIV7 dataset + +While annotating in label studio, you predefine which one of your labels overlap with the first pretrained model and custom labels that don't fit under the 600 classes are automatically used in the second custom model for predictions that is trained as you submit annotations in Label Studio. + +Predictions are then gathered using the OIV7 pretrained model and the custom model in Label Studio in milliseconds, where you can adjust annotations and fine tune your custom model for even more precise predictions. + + +## Setup + +1. Defining Classes for Pretrained and Custom Models + +Edit your labeling config to something like the following + +```xml + + + + + +``` + +In the `class_matching.yml` edit the `labels_to_coco` dictionary to where the keys are the exact names of your rectangular labels in label studio and the values are the exact names of the same classes in [open-images-v7.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/open-images-v7.yaml). + +Any classes in your labeling config that you do not add to the `labels_to_coco` dictionary in `class_matching.yml` will be trained using the second, custom model. + +In the `all_classes` dictionary add all of the classes in your Label Studio labeling config that are under the rectangular labels. + +Note: if you leave the `labels_to_coco` dictionary empty with no keys and values, only the custom model will be trained and then used for predictions. In such a case, the model trained on 600 classes will not be used at all. + +2. Editing `docker-compose.yml` + +Set `LABEL_STUDIO_HOST` to your private IP address (which starts with 192 so ex. 192.168.1.1) with the port that label studio is running on. For example, your docker compose may look like `LABEL_STUDIO_HOST=192.168.1.1:8080` + +Set `LABEL_STUDIO_ACCESS_TOKEN` by going to your Label Studio Accounts & Settings, and then copying the Access Token. Paste it into the docker file. Ex. `LABEL_STUDIO_ACCESS_TOKEN=cjneskn2keoqpejleed8d8frje9992jdjdasvbfnwe2jsx` + +3. Running the backend + +Run `docker compose up` to start the backend. Under the `Machine Learning` settings in your project in Label Studio enter the following URL while adding the model: `http://{your_private_ip}:9090`. Note: if you changed the port before running the backend, you will have to change it here as well. + +## Training With ML Backend + +In the machine learning tab for label studio, make sure the first toggle for training the model when annotations are submitted is turned on. This will allow training the custom model for custom classes that you defined in the previous steps when you submit annotations. + +If you would like to train multiple images at once, which is preferred, run label studio from docker using the [`feature/batch-train`](https://github.com/HumanSignal/label-studio/tree/feature/batch-train) branch. Under the app and inside the environment variables in the `docker-compose.yml` add `EXPERIMENTAL_FEATURES=True`. Then, run the instance. + +In the task menu, select all the tasks you would like to train your ML backend custom model on and under the toggle menu in the top left hand corner, select `Batch Train` and select `Ok` in the next popup menu. + + +## Notes + +If you would like to save your model inside of your docker container or move it into your local machine, you will need to access the terminal of your docker container. See how to do this [here](https://stackoverflow.com/a/30173220). + +If you want to train a new custom model, move the `yolov8n(custom).pt` out of your container's directory. It will automatically realize there is no custom model, and will create a new one from scratch to use when training custom models. diff --git a/label_studio_ml/examples/yolov8/_wsgi.py b/label_studio_ml/examples/yolov8/_wsgi.py new file mode 100644 index 000000000..f7cbc094d --- /dev/null +++ b/label_studio_ml/examples/yolov8/_wsgi.py @@ -0,0 +1,114 @@ +import os +import argparse +import json +import logging +import logging.config + +logging.config.dictConfig({ + "version": 1, + "formatters": { + "standard": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": os.getenv('LOG_LEVEL'), + "stream": "ext://sys.stdout", + "formatter": "standard" + } + }, + "root": { + "level": os.getenv('LOG_LEVEL'), + "handlers": [ + "console" + ], + "propagate": True + } +}) + +from label_studio_ml.api import init_app +from model import YOLO_LS + + +_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json') + + +def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH): + if not os.path.exists(config_path): + return dict() + with open(config_path) as f: + config = json.load(f) + assert isinstance(config, dict) + return config + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Label studio') + parser.add_argument( + '-p', '--port', dest='port', type=int, default=9090, + help='Server port') + parser.add_argument( + '--host', dest='host', type=str, default='0.0.0.0', + help='Server host') + parser.add_argument( + '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='), + help='Additional LabelStudioMLBase model initialization kwargs') + parser.add_argument( + '-d', '--debug', dest='debug', action='store_true', + help='Switch debug mode') + parser.add_argument( + '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None, + help='Logging level') + parser.add_argument( + '--model-dir', dest='model_dir', default=os.path.dirname(__file__), + help='Directory where models are stored (relative to the project directory)') + parser.add_argument( + '--check', dest='check', action='store_true', + help='Validate model instance before launching server') + + args = parser.parse_args() + + # setup logging level + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs(): + param = dict() + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == 'True' or v == 'true': + param[k] = True + elif v == 'False' or v == 'false': + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + YOLO_LS.__name__ + '" instance creation..') + model = YOLO_LS(**kwargs) + + app = init_app(model_class=YOLO_LS) + + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # for uWSGI use + app = init_app(model_class=YOLO_LS) diff --git a/label_studio_ml/examples/yolov8/class_matching.yml b/label_studio_ml/examples/yolov8/class_matching.yml new file mode 100644 index 000000000..1ffb4d67d --- /dev/null +++ b/label_studio_ml/examples/yolov8/class_matching.yml @@ -0,0 +1,13 @@ +# Note: check here to match class names +# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/open-images-v7.yaml +# the key should be what you've named your label, the value is what class it matches from the link above +labels_to_coco: + cats: Cat + lights: Traffic light + cars: Car +all_classes: + - "cats" + - "cars" + - "taxi" + - "lights" + - "others" \ No newline at end of file diff --git a/label_studio_ml/examples/yolov8/custom_config.yml b/label_studio_ml/examples/yolov8/custom_config.yml new file mode 100644 index 000000000..48fbf9a53 --- /dev/null +++ b/label_studio_ml/examples/yolov8/custom_config.yml @@ -0,0 +1,7 @@ +names: + 0: taxi + 1: other +path: ./temp +test: null +train: images +val: images diff --git a/label_studio_ml/examples/yolov8/docker-compose.yml b/label_studio_ml/examples/yolov8/docker-compose.yml new file mode 100644 index 000000000..a55cf4a55 --- /dev/null +++ b/label_studio_ml/examples/yolov8/docker-compose.yml @@ -0,0 +1,32 @@ +version: "3.8" + +services: + redis: + image: redis:alpine + container_name: redis + hostname: redis + volumes: + - "./data/redis:/data" + expose: + - 6379 + server: + container_name: ml-backend + build: . + environment: + - MODEL_DIR=/data/models + - LOG_LEVEL=DEBUG + - LABEL_STUDIO_HOST= + - LABEL_STUDIO_ACCESS_TOKEN= + - RQ_QUEUE_NAME=default + - REDIS_HOST=redis + - REDIS_PORT=6379 + - LABEL_STUDIO_USE_REDIS=true + ports: + - "9090:9090" + depends_on: + - redis + links: + - redis + volumes: + - "./data/server:/data" + - "./logs:/tmp" diff --git a/label_studio_ml/examples/yolov8/model.py b/label_studio_ml/examples/yolov8/model.py new file mode 100644 index 000000000..37cce93eb --- /dev/null +++ b/label_studio_ml/examples/yolov8/model.py @@ -0,0 +1,307 @@ +from typing import List, Dict, Optional +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.utils import get_image_local_path +from label_studio_tools.core.utils.io import get_local_path + +import os +from PIL import Image +from uuid import uuid4 +from ultralytics import YOLO +import torch +import os +import yaml +import shutil + + +LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN") +LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST") + + +with open("class_matching.yml", "r") as file: + ls_config = yaml.safe_load(file) + +label_to_COCO = ls_config["labels_to_coco"] +all_classes = ls_config["all_classes"] + +JUST_CUSTOM = True if len(label_to_COCO) == 0 else False + +# checks if you have already built a custom model +# if you want to do it for a new task, move this model out of the directory +NEW_START = os.path.isfile('yolov8n(custom).pt') + + +class YOLO_LS(LabelStudioMLBase): + + def __init__(self, project_id, **kwargs): + super(YOLO_LS, self).__init__(**kwargs) + + self.device = "cuda" if torch.cuda.is_available else "cpu" # can to mps + + if not JUST_CUSTOM: + self.pretrained_model = YOLO('yolov8n-oiv7.pt') + + if not NEW_START: + shutil.copyfile('./yolov8n.pt', 'yolov8n(custom).pt') + self.custom_model = YOLO('yolov8n(custom).pt') + FIRST_USE = True + else: + self.custom_model = YOLO('yolov8n(custom).pt') + FIRST_USE = False + + self.first_use = FIRST_USE + + self.from_name = "label" + self.to_name = "image" + + classes = all_classes + + self.NEW_START = NEW_START + self.JUST_CUSTOM = JUST_CUSTOM + + self.COCO_to_label = {v:k for k, v in label_to_COCO.items()} + + first_label_classes = list(label_to_COCO.keys()) # raw labels from labelling config + second_label_classes = [x for x in classes if x not in set(first_label_classes)] # raw labels from labelling config + + input_file = "custom_config.yml" + with open(input_file, "r") as file: + data = yaml.safe_load(file) + + if self.NEW_START: + + self.custom_num_to_name = {i:v for i,v in enumerate(second_label_classes)} + + data["names"] = self.custom_num_to_name + + with open(input_file, "w") as file: + yaml.dump(data, file, default_flow_style=False) + else: + self.custom_num_to_name = data["names"] + + self.custom_name_to_num = {v:k for k, v in self.custom_num_to_name.items()} + + + + def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]: + """ Inference logic for YOLO model """ + + print("..... we here 3.9") + + imgs = [] + lengths = [] + + # loading all images into lists + for task in tasks: + + raw_img_path = task['data']['image'] + + try: + img_path = get_local_path( + url=raw_img_path, + hostname=LABEL_STUDIO_HOST, + access_token=LABEL_STUDIO_ACCESS_TOKEN + ) + + except: + img_path = raw_img_path + + img = Image.open(img_path) + + + imgs.append(img) + + W, H = img.size + lengths.append((H, W)) + + # predicting from PIL loaded images + if not self.JUST_CUSTOM: + try: + results_1 = self.pretrained_model.predict(imgs) # define model earlier + except Exception as e: + print(f"the error was {e}") + else: + results_1 = None + + # we don't want the predictions from the pretrained version of the custom model + # because it hasn't reshaped to the new classes yet + if not self.first_use: + results_2 = self.custom_model.predict(source=imgs) + else: + results_2 = None + + # each item will be the predictions for a task + predictions = [] + + # basically, running this loop for each task + for res_num, results in enumerate([results_1, results_2]): + if results == None: + continue + + for (result, len) in zip(results, lengths): + boxes = result.boxes.cpu().numpy() + pretrained = True if res_num == 0 else False + # results names + predictions.append(self.get_results(boxes.xywh, boxes.cls, len, boxes.conf, result.names, pretrained=pretrained)) + + return predictions + + def get_results(self, boxes, classes, length, confidences, num_to_names_dict, pretrained=True): + """This method returns annotation results that will be packaged and sent to Label Studio frontend""" + + results = [] + + for box, class_num, conf in zip(boxes, classes, confidences): + + label_id = str(uuid4())[:9] + + x, y, w, h = box + + height, width = length + + if pretrained: + name = num_to_names_dict[int(class_num)] + label = self.COCO_to_label.get(name) + else: # then, we are using the custom model + label = num_to_names_dict[int(class_num)] + + if label==None: + continue + + results.append({ + 'id': label_id, + 'from_name': self.from_name, + 'to_name': self.to_name, + 'original_width': int(width), + 'original_height': int(height), + 'image_rotation': 0, + 'value': { + 'rotation': 0, + 'rectanglelabels': [label], + 'width': w / width * 100, + 'height': h / height * 100, + 'x': (x - 0.5*w) / width * 100, + 'y': (y-0.5*h) / height * 100 + }, + 'score': conf.item(), + 'type': 'rectanglelabels', + 'readonly': False + }) + + return { + 'result': results + } + + def fit(self, event, data, **kwargs): + """ + This method is called each time an annotation is created or updated + You can run your logic here to update the model + """ + + all_new_paths = [] + + try: + total_results = data['annotations'] + except: # then this is a submission of just one image from the annotation image page + total_results = [data] + + multiple_results = True if len(total_results) > 1 else False + + for task in total_results: + + if not multiple_results: + raw_img_path = task["task"]["data"]["image"] + else: + raw_img_path = task["data"]["image"] + + + try: + img_path = get_image_local_path( + raw_img_path, + label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN, + label_studio_host=LABEL_STUDIO_HOST + ) + except: + img_path = raw_img_path + + img = Image.open(img_path) + + sample_img_path = img_path + + img = Image.open(sample_img_path) + + image_name = sample_img_path.split("/")[-1] + + img1 = img.save(f"./datasets/temp/images/{image_name}") + + if not multiple_results: + img2 = img.save(f"./datasets/temp/images/(2){image_name}") + + all_new_paths.append(f"./datasets/temp/images/{image_name}") + + if not multiple_results: + all_new_paths.append(f"./datasets/temp/images/(2){image_name}") + + # now saving text file labels + txt_name = image_name.rsplit('.', 1)[0] + + with open(f'./datasets/temp/labels/{txt_name}.txt', 'w') as f: + f.write("") + + + if not multiple_results: + with open(f'./datasets/temp/labels/(2){txt_name}.txt', 'w') as f: + f.write("") + + all_new_paths.append(f'./datasets/temp/labels/{txt_name}.txt') + + if not multiple_results: + all_new_paths.append(f'./datasets/temp/labels/(2){txt_name}.txt') + + if not multiple_results: + results = task["annotation"]["result"] + else: + results = task["annotations"][0]["result"] + + + for result in results: + + value = result['value'] + label = value['rectanglelabels'][0] + + if label in self.custom_name_to_num: + + # these are out of 100, so you need to convert them back + x = value['x'] + y = value['y'] + width = value['width'] + height = value['height'] + + orig_width = result['original_width'] + orig_height = result['original_height'] + + w = width / 100 + h = height / 100 + trans_x = (x / 100) + (0.5 * w) + trans_y = (y / 100) + (0.5 * h) + + # now getting the class label + label_num = self.custom_name_to_num.get(label) + + with open(f'./datasets/temp/labels/{txt_name}.txt', 'a') as f: + f.write(f"{label_num} {trans_x} {trans_y} {w} {h}\n") + + if not multiple_results: + with open(f'./datasets/temp/labels/(2){txt_name}.txt', 'a') as f: + f.write(f"{label_num} {trans_x} {trans_y} {w} {h}\n") + + results = self.custom_model.train(data='custom_config.yml', epochs = 1, imgsz=640) + + self.first_use = False + + # remove all these files so train starts from nothing next time + self.remove_train_files(all_new_paths) + + def remove_train_files(self, file_paths): + """This cleans the dataset directory""" + for path in file_paths: + os.remove(path) \ No newline at end of file diff --git a/label_studio_ml/examples/yolov8/requirements.txt b/label_studio_ml/examples/yolov8/requirements.txt new file mode 100644 index 000000000..c751da30d --- /dev/null +++ b/label_studio_ml/examples/yolov8/requirements.txt @@ -0,0 +1,55 @@ +gunicorn==20.1.0 +label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git + +appdirs==1.4.4 +blinker==1.7.0 +certifi==2023.7.22 +charset-normalizer==3.3.1 +click==8.1.7 +colorama==0.4.6 +contourpy==1.1.1 +cycler==0.12.1 +filelock==3.12.4 +# Flask==1.1.2 +fonttools==4.43.1 +fsspec==2023.10.0 +idna==3.4 +# itsdangerous==2.0.1 +# Jinja2==3.0.3 +kiwisolver==1.4.5 +label-studio-tools==0.0.3 +lxml==4.9.3 +MarkupSafe==2.1.3 +matplotlib==3.8.0 +mpmath==1.3.0 +networkx==3.2 +numpy==1.26.1 +opencv-python==4.8.1.78 +packaging==23.2 +pandas==2.1.1 +Pillow==10.1.0 +psutil==5.9.6 +py-cpuinfo==9.0.0 +pyparsing==3.1.1 +python-dateutil==2.8.2 +pytz==2023.3.post1 +PyYAML==6.0.1 +requests==2.31.0 +scipy==1.11.3 +seaborn==0.13.0 +six==1.16.0 +sympy==1.12 +thop==0.1.1.post2209072238 + +# pinned torch and torchvision to previous versions due to signal 11 killing worker errors +torch==2.0.1 +torchvision==0.15.2 + +tqdm==4.66.1 +typing_extensions==4.8.0 +tzdata==2023.3 +ultralytics==8.0.200 +urllib3==2.0.7 +supervisor==4.2.2 +uwsgi==2.0.21 +rq diff --git a/label_studio_ml/examples/yolov8/supervisord.conf b/label_studio_ml/examples/yolov8/supervisord.conf new file mode 100644 index 000000000..4079c2132 --- /dev/null +++ b/label_studio_ml/examples/yolov8/supervisord.conf @@ -0,0 +1,40 @@ +[supervisord] +nodaemon = true +loglevel = info +logfile = supervisord.log + +[inet_http_server] +port=127.0.0.1:9001 + +[supervisorctl] +serverurl=http://127.0.0.1:9001 + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[program:rq] +process_name=%(program_name)s_%(process_num)02d +command = rq worker --url redis://%(ENV_REDIS_HOST)s:6379/0 %(ENV_RQ_QUEUE_NAME)s +stopsignal = TERM +autostart = true +autorestart = true +killasgroup = true +stopasgroup = true +numprocs = 1 +stderr_logfile = /dev/stderr +stderr_logfile_maxbytes = 0 +stdout_logfile = /dev/stdout +stdout_logfile_maxbytes = 0 + +[program:wsgi] +environment = + RQ_QUEUE_NAME="%(ENV_RQ_QUEUE_NAME)s", + REDIS_HOST="%(ENV_REDIS_HOST)s" +command = uwsgi --ini /etc/uwsgi/uwsgi.ini +autostart = true +autorestart = true +stopsignal = QUIT +stderr_logfile = /dev/stderr +stderr_logfile_maxbytes = 0 +stdout_logfile = /dev/stdout +stdout_logfile_maxbytes = 0 \ No newline at end of file diff --git a/label_studio_ml/examples/yolov8/uwsgi.ini b/label_studio_ml/examples/yolov8/uwsgi.ini new file mode 100644 index 000000000..9f4667019 --- /dev/null +++ b/label_studio_ml/examples/yolov8/uwsgi.ini @@ -0,0 +1,10 @@ +[uwsgi] +protocol = http +socket = 0.0.0.0:9090 +module = _wsgi:app +master = true +processes = 1 +vacuum = true +die-on-term = true +plugins = python37 +pidfile = /tmp/%n.pid \ No newline at end of file