[mmm-18320] operator create registered model registered model version #136

52 changes: 52 additions & 0 deletions README.md
@@ -1394,6 +1394,58 @@ Example of DAG config params:

For more [prediction-explanations](https://datarobot-public-api-client.readthedocs-hosted.com/en/latest-release/autodoc/api_reference.html?highlight=PredictionExplanationsInitialization#prediction-explanations), see the DataRobot documentation.

---
#### `CreateRegisteredModelVersionOperator`

Dynamically creates a registered model version using one of three methods:
- Leaderboard Model
- Custom Model
- External Model

Parameters:

| Parameter | Type | Description |
|-----------------------|------|--------------------------------|
| `model_type` | str | Type of model version to create (leaderboard, custom, or external).|
| `name` | str | Name of the registered model version.|
| `model_id` | str | (Required for leaderboard) The ID of the leaderboard model. |
| `registered_model_name` | str | (Required for leaderboard) Name of the registered model. |
Review comment: This parameter has nothing to do with whether the model is a leaderboard model.

| `custom_model_version_id` | str | (Required for custom) The ID of the custom model version. |
| `registered_model_id` | str | (Required for external) The ID of the registered model. |
Review comment: The logic and description for registered_model_name and registered_model_id are incorrect. These parameters are mutually exclusive and do not depend on the underlying model type.
When you provide registered_model_name, you are saying: take this model, create a model package (registered model version), and assign it to a new registered model with the given name.
When you provide registered_model_id, you are assigning the newly created model package to an existing registered model.
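
The mutually exclusive relationship described in this comment could be checked with a small helper along these lines. This is only a sketch: the name `resolve_registered_model_kwargs` and the error message are hypothetical and appear nowhere in the PR or the DataRobot client. It also assumes that omitting both parameters should be left for the API to judge, which this thread does not confirm.

```python
from typing import Any, Dict


def resolve_registered_model_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
    """Pick at most one of registered_model_name / registered_model_id.

    Hypothetical helper for illustration; not part of the provider.
    """
    name = params.get("registered_model_name")
    model_id = params.get("registered_model_id")

    if name and model_id:
        raise ValueError(
            "Provide either 'registered_model_name' or 'registered_model_id', not both."
        )
    if name:
        # The new version is assigned to a new registered model with this name.
        return {"registered_model_name": name}
    if model_id:
        # The new version is assigned to an existing registered model.
        return {"registered_model_id": model_id}
    # Neither given: pass nothing and let the API decide (assumption).
    return {}
```

The returned dictionary could then be unpacked into whichever `create_for_*` call the operator makes, independent of `model_type`.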

Contributor Author

You are right; I will correct this. As for the other optional parameters, I will extend the methods to pass them as kwargs and update the README file.
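
One way the kwargs forwarding mentioned here might look is to split the parameter dictionary into keys the operator consumes itself and keys passed straight through to the client call. This is a rough sketch: `OPERATOR_KEYS` and `split_params` are hypothetical names, and the choice of which keys the operator keeps is an assumption.

```python
from typing import Any, Dict, Tuple

# Keys the operator itself consumes; assumed here for illustration.
OPERATOR_KEYS = frozenset({"model_type", "name"})


def split_params(params: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate operator-level keys from those forwarded to the client call."""
    own = {k: v for k, v in params.items() if k in OPERATOR_KEYS}
    forwarded = {k: v for k, v in params.items() if k not in OPERATOR_KEYS}
    return own, forwarded
```

With this split, a creation method could call the client with `name=own["name"], **forwarded`, so that optional parameters reach the client without the operator listing each one.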

| `target` | str | (Required for external) The target for the external model. |

Example of DAG config params:

Leaderboard Model
```
{
    "model_type": "leaderboard",
    "name": "My Registered Model",
    "model_id": "123456789",
    "registered_model_name": "My Model Registry"
}
```

Custom Model
```
{
    "model_type": "custom",
    "name": "Custom Model Version",
    "custom_model_version_id": "987654321"
}
```

External Model
```
{
    "model_type": "external",
    "name": "External Model Version",
    "registered_model_id": "1234567891011",
    "target": "prediction"
}
```
For more [Model Registry](https://datarobot-public-api-client.readthedocs-hosted.com/en/latest-release/reference/mlops/model_registry.html#create-registered-model-version), see the DataRobot documentation.

---

### [Sensors](https://github.com/datarobot/airflow-provider-datarobot/blob/main/datarobot_provider/sensors/datarobot.py)
97 changes: 97 additions & 0 deletions datarobot_provider/operators/model_registry.py
Contributor

Looks like this new file is missing a copyright - I had assumed each new file should have that.

CC @mjnitz02 maybe we should be sure (assuming we need copyrights) that we're running a copyright checker as part of PR checks?
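
A copyright check of the kind suggested here could be as small as the sketch below. It assumes the header is a comment containing the word "Copyright" within the first few lines of each file. The function name, the marker string, and the line count are all hypothetical, since the project's actual header format is not shown in this thread.

```python
from pathlib import Path
from typing import Iterable, List


def files_missing_copyright(
    paths: Iterable[Path], marker: str = "Copyright", head_lines: int = 5
) -> List[Path]:
    """Return the files whose first few lines do not contain the marker."""
    missing = []
    for path in paths:
        with path.open(encoding="utf-8") as handle:
            # Read at most head_lines lines; short files yield empty strings.
            head = [next(handle, "") for _ in range(head_lines)]
        if not any(marker in line for line in head):
            missing.append(path)
    return missing
```

A PR check could run this over the changed `.py` files and fail when the returned list is not empty.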

@@ -0,0 +1,97 @@
from enum import Enum
from typing import Any
from typing import Dict
from typing import Sequence

import datarobot as dr
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.context import Context

from datarobot_provider.hooks.datarobot import DataRobotHook


class ModelType(Enum):
    LEADERBOARD = "leaderboard"
    CUSTOM = "custom"
    EXTERNAL = "external"


class CreateRegisteredModelVersionOperator(BaseOperator):
    """
    Dynamically creates a registered model version using one of three methods:
    - Leaderboard Model
    - Custom Model Version
    - External Model

    :param model_version_params: Dictionary with parameters for creating model version.
    :param datarobot_conn_id: Airflow connection ID for DataRobot.
Contributor

Sorry, probably last comment related to the new base class, but I think we can also remove this.

Suggested change
:param datarobot_conn_id: Airflow connection ID for DataRobot.

    """

    template_fields: Sequence[str] = ["model_version_params"]
    template_fields_renderers: dict[str, str] = {}
    template_ext: Sequence[str] = ()
    ui_color = "#f4a460"
Contributor

You should be able to remove this boilerplate code thanks to using the BaseDatarobotOperator 😎

Suggested change
template_fields_renderers: dict[str, str] = {}
template_ext: Sequence[str] = ()
ui_color = "#f4a460"


    def __init__(
        self,
        *,
        model_version_params: Dict[str, Any],
        datarobot_conn_id: str = "datarobot_default",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_version_params = model_version_params
        self.datarobot_conn_id = datarobot_conn_id
Contributor

Same here with some boilerplate code. It could be useful to check this PR's changes for examples:
https://github.com/datarobot/airflow-provider-datarobot/pull/135/files

Suggested change
datarobot_conn_id: str = "datarobot_default",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model_version_params = model_version_params
self.datarobot_conn_id = datarobot_conn_id
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model_version_params = model_version_params


    def execute(self, context: Context) -> str:
        """Executes the operator to create a registered model version."""
        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
Contributor

Also taken care of in the newer base class 👍

Suggested change
DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()

Collaborator

This happens in the base operator


        model_type = self.model_version_params.get("model_type")
        model_name = self.model_version_params.get("name")

        if not model_type:
            raise ValueError("'model_type' must be specified in model_version_params.")

        try:
            model_type_enum = ModelType(model_type)
        except ValueError:
            raise AirflowException(f"Invalid model_type: {model_type}") from None
Contributor

Could you move these lines into the validate() method now provided by the newer base class?

I think lines should work as is as long as you alter:
if not model_type:
to be:
if not self.model_version_params.get("model_type"):
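
The move suggested here might look roughly like the sketch below. `OperatorStub` is a hypothetical stand-in for the newer base class, whose real interface is not shown in this thread. How and when that base class calls `validate()` is assumed, and `ValueError` replaces `AirflowException` only to keep the sketch free of an Airflow dependency.

```python
from enum import Enum
from typing import Any, Dict


class ModelType(Enum):
    LEADERBOARD = "leaderboard"
    CUSTOM = "custom"
    EXTERNAL = "external"


class OperatorStub:
    """Stand-in for the operator; only the validate() hook is modelled."""

    def __init__(self, model_version_params: Dict[str, Any]) -> None:
        self.model_version_params = model_version_params

    def validate(self) -> None:
        # Same checks as in execute(), reading from self instead of a local.
        if not self.model_version_params.get("model_type"):
            raise ValueError("'model_type' must be specified in model_version_params.")
        model_type = self.model_version_params["model_type"]
        try:
            ModelType(model_type)
        except ValueError:
            # The PR raises AirflowException at this point.
            raise ValueError(f"Invalid model_type: {model_type}") from None
```

Keeping the checks in `validate()` would leave `execute()` with only the dispatch and the client call.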


        model_creation_methods = {
            ModelType.LEADERBOARD: self.create_for_leaderboard,
            ModelType.CUSTOM: self.create_for_custom,
            ModelType.EXTERNAL: self.create_for_external,
        }

        create_method = model_creation_methods.get(model_type_enum)
        if not create_method:
            raise AirflowException(f"Unsupported model_type: {model_type}")

        version = create_method(model_name)
        self.log.info(f"Successfully created model version: {version.id}")
        return version.id

    def create_for_leaderboard(self, model_name):
Review comment: Given the above comments, this should probably be reworked with regard to registered_model_name/id.

        """Creates a registered model version from a leaderboard model."""
        return dr.RegisteredModelVersion.create_for_leaderboard_item(
            model_id=self.model_version_params["model_id"],
            name=model_name,
            registered_model_name=self.model_version_params["registered_model_name"],
        )

    def create_for_custom(self, model_name):
        """Creates a registered model version for a custom model."""
        return dr.RegisteredModelVersion.create_for_custom_model_version(
            custom_model_version_id=self.model_version_params["custom_model_version_id"],
            name=model_name,
        )

    def create_for_external(self, model_name):
        """Creates a registered model version for an external model."""
        return dr.RegisteredModelVersion.create_for_external(
            name=model_name,
            target=self.model_version_params["target"],
            registered_model_id=self.model_version_params["registered_model_id"],
        )
88 changes: 88 additions & 0 deletions tests/unit/operators/test_model_registry.py
@@ -0,0 +1,88 @@
import datarobot as dr

from datarobot_provider.hooks.datarobot import DataRobotHook
from datarobot_provider.operators.model_registry import CreateRegisteredModelVersionOperator


def test_create_registered_model_version_leaderboard(mocker):
    model_version_params = {
        "model_type": "leaderboard",
        "name": "Test Model",
        "model_id": "123456",
        "registered_model_name": "Test Registry",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_leaderboard_item", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(
        model_id="123456", name="Test Model", registered_model_name="Test Registry"
    )
    assert result == "version-123"


def test_create_registered_model_version_custom(mocker):
    model_version_params = {
        "model_type": "custom",
        "name": "Custom Model",
        "custom_model_version_id": "987654",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_custom_model_version", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_custom",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(custom_model_version_id="987654", name="Custom Model")
    assert result == "version-123"


def test_create_registered_model_version_external(mocker):
    model_version_params = {
        "model_type": "external",
        "name": "External Model",
        "registered_model_id": "123456",
        "target": "classification",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-1234"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_external", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_external",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(
        name="External Model", target="classification", registered_model_id="123456"
    )
    assert result == "version-1234"