Commit 5ff2d29

Authored by deppp, Mikhail Maluyk, Michael Malyuk, and nik
feat: Basic Auth support, Azure openAI, Semver, and using LabelInterface (#437)
* Adding basic authentication, and updating how model_version is handled. Also fixing links to the docs. Tests to come
* Adding azure openai, plus response model
* Adding a bit more tests :> fixing typos
* fixing response
* Merge azure into llm_interactive
* git sdk install
* Fix errors
* Updating readme
* Change versions
* Add pytest verbosity
* Change functional test CI location

---------

Co-authored-by: Mikhail Maluyk <[email protected]>
Co-authored-by: Michael Malyuk <[email protected]>
Co-authored-by: nik <[email protected]>
1 parent 3fcf6f6 commit 5ff2d29

34 files changed: +708 -251 lines

.github/workflows/tests.yml (+1 -2)

@@ -104,8 +104,7 @@ jobs:
 
       - name: Run general functional tests
         run: |
-          cd label_studio_ml/
-          pytest --ignore-glob='**/logs/*' --ignore-glob='**/data/*' --cov=. --cov-report=xml
+          pytest -vvv --ignore-glob='**/logs/*' --ignore-glob='**/data/*' --cov=. --cov-report=xml
 
       - name: Pull the logs
         if: always()

Makefile (+7)

@@ -0,0 +1,7 @@
+SHELL := /bin/bash
+
+install:
+	pip install -r requirements.txt
+
+test:
+	pytest tests

label_studio_ml/__init__.py (+1 -1)

@@ -2,4 +2,4 @@
 package_name = 'label-studio-ml'
 
 # Package version
-__version__ = '2.0.0.dev0'
+__version__ = '2.0.1dev0'
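A side note on the bump: the new string drops the dot before `dev0`, but PEP 440 normalization treats `2.0.1dev0` and `2.0.1.dev0` as the same dev pre-release, so pip tooling is unaffected. A quick check with the `packaging` library (not part of this diff, just an illustration):

from packaging.version import Version

old = Version("2.0.0.dev0")
new = Version("2.0.1dev0")   # normalized to 2.0.1.dev0 under PEP 440

print(new)                   # 2.0.1.dev0
print(new.is_prerelease)     # True
print(old < new)             # True: the bump orders correctly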

label_studio_ml/api.py (+57 -17)

@@ -1,24 +1,30 @@
+import hmac
 import logging
 
-from flask import Flask, request, jsonify
+from flask import Flask, request, jsonify, Response
 
+from .response import ModelResponse
 from .model import LabelStudioMLBase
 from .exceptions import exception_handler
 
-
 logger = logging.getLogger(__name__)
 
 _server = Flask(__name__)
 MODEL_CLASS = LabelStudioMLBase
+BASIC_AUTH = None
 
 
-def init_app(model_class):
+def init_app(model_class, basic_auth_user=None, basic_auth_pass=None):
     global MODEL_CLASS
+    global BASIC_AUTH
 
     if not issubclass(model_class, LabelStudioMLBase):
         raise ValueError('Inference class should be the subclass of ' + LabelStudioMLBase.__class__.__name__)
 
     MODEL_CLASS = model_class
+    if basic_auth_user and basic_auth_pass:
+        BASIC_AUTH = (basic_auth_user, basic_auth_pass)
+
     return _server
 
 
@@ -46,20 +52,38 @@ def _predict():
     """
     data = request.json
     tasks = data.get('tasks')
-    params = data.get('params') or {}
-    project = data.get('project')
-    if project:
-        project_id = data.get('project').split('.', 1)[0]
-    else:
-        project_id = None
     label_config = data.get('label_config')
+    project = data.get('project')
+    project_id = project.split('.', 1)[0] if project else None
+    params = data.get('params', {})
     context = params.pop('context', {})
 
-    model = MODEL_CLASS(project_id)
-    model.use_label_config(label_config)
+    model = MODEL_CLASS(project_id=project_id,
+                        label_config=label_config)
+
+    # model.use_label_config(label_config)
+
+    response = model.predict(tasks, context=context, **params)
+
+    # if there is no model version we will take the default
+    if isinstance(response, ModelResponse):
+        if not response.has_model_version:
+            mv = model.model_version
+            if mv:
+                response.set_version(mv)
+        else:
+            response.update_predictions_version()
+
+        response = response.serialize()
+
+    res = response
+    if res is None:
+        res = []
 
-    predictions = model.predict(tasks, context=context, **params)
-    return jsonify({'results': predictions})
+    if isinstance(res, dict):
+        res = response.get("predictions", response)
+
+    return jsonify({'results': res})
 
 
 @_server.route('/setup', methods=['POST'])
@@ -68,8 +92,13 @@ def _setup():
     data = request.json
     project_id = data.get('project').split('.', 1)[0]
     label_config = data.get('schema')
-    model = MODEL_CLASS(project_id)
-    model.use_label_config(label_config)
+    extra_params = data.get('extra_params')
+    model = MODEL_CLASS(project_id=project_id,
+                        label_config=label_config)
+
+    if extra_params:
+        model.set_extra_params(extra_params)
+
     model_version = model.get('model_version')
     return jsonify({'model_version': model_version})
 
@@ -90,8 +119,7 @@ def webhook():
         return jsonify({'status': 'Unknown event'}), 200
     project_id = str(data['project']['id'])
     label_config = data['project']['label_config']
-    model = MODEL_CLASS(project_id)
-    model.use_label_config(label_config)
+    model = MODEL_CLASS(project_id, label_config=label_config)
     model.fit(event, data)
     return jsonify({}), 201
 
@@ -130,6 +158,18 @@ def index_error(error):
     return str(error), 500
 
 
+def safe_str_cmp(a, b):
+    return hmac.compare_digest(a, b)
+
+
+@_server.before_request
+def check_auth():
+    if BASIC_AUTH is not None:
+        auth = request.authorization
+        if not auth or not (safe_str_cmp(auth.username, BASIC_AUTH[0]) and safe_str_cmp(auth.password, BASIC_AUTH[1])):
+            return Response('Unauthorized', 401, {'WWW-Authenticate': 'Basic realm="Login required"'})
+
+
 @_server.before_request
 def log_request_info():
     logger.debug('Request headers: %s', request.headers)
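To see the new guard from the client side, here is a minimal sketch (not part of the commit), assuming a backend launched with user `user` and password `secret` on `localhost:9090`, and a `/health` route on the server:

import requests

BASE_URL = 'http://localhost:9090'

# No credentials: check_auth() short-circuits the request with 401
# and a WWW-Authenticate challenge before any route handler runs.
r = requests.get(f'{BASE_URL}/health')
assert r.status_code == 401

# Matching credentials: hmac.compare_digest accepts both values in
# constant time and the request reaches the route as before.
r = requests.get(f'{BASE_URL}/health', auth=('user', 'secret'))
print(r.status_code, r.text)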

label_studio_ml/default_configs/_wsgi.py.tmpl (+8 -1)

@@ -67,7 +67,14 @@ if __name__ == "__main__":
     parser.add_argument(
         '--check', dest='check', action='store_true',
         help='Validate model instance before launching server')
-
+    parser.add_argument('--basic-auth-user',
+                        default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
+                        help='Basic auth user')
+
+    parser.add_argument('--basic-auth-pass',
+                        default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
+                        help='Basic auth pass')
+
     args = parser.parse_args()
 
     # setup logging level
# setup logging level

label_studio_ml/default_configs/model.py (+29 -3)

@@ -3,21 +3,47 @@
 
 
 class NewModel(LabelStudioMLBase):
+    """Custom ML Backend model
+    """
+
+    def setup(self):
+        """Configure any paramaters of your model here
+        """
+        self.set("model_version", "0.0.1")
 
+
     def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
         """ Write your inference logic here
             :param tasks: [Label Studio tasks in JSON format](https://labelstud.io/guide/task_format.html)
-            :param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml.html#Passing-data-to-ML-backend)
-            :return predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Raw-JSON-format-of-completed-tasks)
+            :param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml_create#Implement-prediction-logic)
+            :return predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Label-Studio-JSON-format-of-annotated-tasks)
         """
         print(f'''\
         Run prediction on {tasks}
         Received context: {context}
         Project ID: {self.project_id}
         Label config: {self.label_config}
-        Parsed JSON Label config: {self.parsed_label_config}''')
+        Parsed JSON Label config: {self.parsed_label_config}
+        Extra params: {self.extra_params}''')
+
+        # example for simple classification
+        # return [{
+        #     "model_version": self.get("model_version"),
+        #     "score": 0.12,
+        #     "result": [{
+        #         "id": "vgzE336-a8",
+        #         "from_name": "sentiment",
+        #         "to_name": "text",
+        #         "type": "choices",
+        #         "value": {
+        #             "choices": [ "Negative" ]
+        #         }
+        #     }]
+        # }]
+
         return []
 
+
     def fit(self, event, data, **kwargs):
         """
         This method is called each time an annotation is created or updated
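For reference, here is the commented classification example from the diff turned into a runnable subclass. A sketch only: it assumes a labeling config with `<Choices name="sentiment" toName="text">` over a `<Text name="text">` tag, and always predicts `Negative`:

from typing import Dict, List, Optional

from label_studio_ml.model import LabelStudioMLBase


class SentimentModel(LabelStudioMLBase):

    def setup(self):
        # Stored once per project; surfaced by /setup and attached
        # to every prediction below via self.get().
        self.set("model_version", "0.0.1")

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
        # One prediction per task, in the shape the commented
        # example above sketches out.
        return [{
            "model_version": self.get("model_version"),
            "score": 0.12,
            "result": [{
                "from_name": "sentiment",  # <Choices name="sentiment">
                "to_name": "text",         # <Text name="text">
                "type": "choices",
                "value": {"choices": ["Negative"]},
            }],
        } for _ in tasks]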

label_studio_ml/examples/easyocr/easyocr_labeling.py (+1 -2)

@@ -10,8 +10,7 @@
 import easyocr
 
 from label_studio_ml.model import LabelStudioMLBase
-from label_studio_ml.utils import get_image_size, \
-    get_single_tag_keys, DATA_UNDEFINED_NAME
+from label_studio_ml.utils import get_image_size, DATA_UNDEFINED_NAME
 from label_studio_tools.core.utils.io import get_data_dir
 from botocore.exceptions import ClientError
 from urllib.parse import urlparse

label_studio_ml/examples/llm_interactive/README.md (+14 -3)

@@ -1,6 +1,6 @@
 ## Interactive LLM labeling
 
-This example server connects Label Studio to OpenAI's API to interact with GPT chat models (gpt-3.5-turbo, gpt-4, etc.).
+This example server connects Label Studio to [OpenAI](https://platform.openai.com/) or [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service) API to interact with GPT chat models (gpt-3.5-turbo, gpt-4, etc.).
 
 The interactive flow allows you to perform the following scenarios:
 
@@ -231,7 +231,18 @@ When deploying the server, you can specify the following parameters as environment variables:
 - `PROMPT_TEMPLATE` (default: `"Source Text: {text}\n\nTask Directive: {prompt}"`): The prompt template to use. If `USE_INTERNAL_PROMPT_TEMPLATE` is set to `1`, the server will use
   the default internal prompt template. If `USE_INTERNAL_PROMPT_TEMPLATE` is set to `0`, the server will use the prompt template provided
   in the input prompt (i.e. the user input from `<TextArea name="my-prompt" ...>`). In the later case, the user has to provide the placeholders that match input task fields. For example, if the user wants to use the `input_text` and `instruction` field from the input task `{"input_text": "user text", "instruction": "user instruction"}`, the user has to provide the prompt template like this: `"Source Text: {input_text}, Custom instruction : {instruction}"`.
-- `OPENAI_MODEL` (default: `gpt-3.5-turbo`) : The OpenAI model to use.
+- `OPENAI_MODEL` (default: `gpt-3.5-turbo`) : The OpenAI model to use.
+- `OPENAI_PROVIDER` (available options: `openai`, `azure`, default - `openai`) : The OpenAI provider to use.
 - `TEMPERATURE` (default: `0.7`): The temperature to use for the model.
 - `NUM_RESPONSES` (default: `1`): The number of responses to generate in `<TextArea>` output fields. Useful if you want to generate multiple responses and let the user rank the best one.
-- `OPENAI_API_KEY`: The OpenAI API key to use. Must be set before deploying the server.
+- `OPENAI_API_KEY`: The OpenAI or Azure API key to use. Must be set before deploying the server.
+
+### Azure Configuration
+
+If you are using Azure as your OpenAI provider (`OPENAI_PROVIDER=azure`), you need to specify the following environment variables:
+
+- `AZURE_RESOURCE_ENDPOINT`: This is the endpoint for your Azure resource. It should be set to the appropriate value based on your Azure setup.
+
+- `AZURE_DEPLOYMENT_NAME`: This is the name of your Azure deployment. It should match the name you've given to your deployment in Azure.
+
+- `AZURE_API_VERSION`: This is the version of the Azure API you are using. The default value is `2023-05-15`.
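For orientation, the provider switch typically reduces to a few lines with the pre-1.0 `openai` Python SDK; a sketch under that assumption (the example backend's actual wiring may differ):

import os
import openai

provider = os.environ.get("OPENAI_PROVIDER", "openai")
openai.api_key = os.environ["OPENAI_API_KEY"]

if provider == "azure":
    # Azure routes requests to a named deployment on your resource
    # endpoint instead of a global model name.
    openai.api_type = "azure"
    openai.api_base = os.environ["AZURE_RESOURCE_ENDPOINT"]
    openai.api_version = os.environ.get("AZURE_API_VERSION", "2023-05-15")

kwargs = {"engine": os.environ["AZURE_DEPLOYMENT_NAME"]} if provider == "azure" \
    else {"model": os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")}

completion = openai.ChatCompletion.create(
    messages=[{"role": "user", "content": "Say hi"}],
    temperature=float(os.environ.get("TEMPERATURE", 0.7)),
    **kwargs,
)
print(completion.choices[0].message["content"])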

label_studio_ml/examples/llm_interactive/docker-compose.yml (+20 -2)

@@ -7,12 +7,30 @@ services:
     build: .
     environment:
       - MODEL_DIR=/data/models
+      # Specify openai model provider: "openai" or "azure"
+      - OPENAI_PROVIDER=openai
+      # Specify API key for openai or azure
       - OPENAI_API_KEY=
-      - OPENAI_MODEL=gpt-4
-      - PROMPT_PREFIX=
+      # Specify model name for openai or azure (by default it uses "gpt-3.5-turbo-instruct")
+      - OPENAI_MODEL=
+      # Internal prompt template for the model is:
+      # **Source Text**:\n\n"{text}"\n\n**Task Directive**:\n\n"{prompt}"
+      # if you want to specify task data keys in the prompt (i.e. input <TextArea name="$PROMPT_PREFIX..."/>, set this to 0
+      - USE_INTERNAL_PROMPT_TEMPLATE=1
+      # Prompt prefix for the TextArea component in the frontend to be used for the user input
+      - PROMPT_PREFIX=prompt
+      # Log level for the server
       - LOG_LEVEL=DEBUG
+      # Number of responses to generate for each request
       - NUM_RESPONSES=1
+      # Temperature for the model
       - TEMPERATURE=0.7
+      # Azure resourse endpoint (in case OPENAI_PROVIDER=azure)
+      - AZURE_RESOURCE_ENDPOINT=
+      # Azure deployment name (in case OPENAI_PROVIDER=azure)
+      - AZURE_DEPLOYMENT_NAME=
+      # Azure API version (in case OPENAI_PROVIDER=azure)
+      - AZURE_API_VERSION=2023-05-15
     ports:
       - 9090:9090
     volumes:
