[mmm-18320] operator create registered model registered model version #136

6 changes: 4 additions & 2 deletions CHANGES.md
@@ -3,8 +3,10 @@
## Unreleased

### New features
- Introduce `CreateWranglingRecipeOperator <datarobot_provider.operators.ai_catalog.CreateWranglingRecipeOperator>`
and `CreateDatasetFromRecipeOperator <datarobot_provider.operators.ai_catalog.CreateDatasetFromRecipeOperator>`
- Introduce `CreateRegisteredModelVersionOperator <datarobot_provider.operators.model_registry.CreateRegisteredModelVersionOperator>`
to create registered model versions. A registered model is a generic container that groups multiple deployable versions of a model.
- Introduce `CreateWranglingRecipeOperator <datarobot_provider.operators.ai_catalog.CreateWranglingRecipeOperator>`
and `CreateDatasetFromRecipeOperator <datarobot_provider.operators.ai_catalog.CreateDatasetFromRecipeOperator>`
to create a wrangling recipe and publish it as a dataset into an existing use case.
- Introduce `CreateDatasetFromProjectOperator <datarobot_provider.operators.ai_catalog.CreateDatasetFromProjectOperator>`
to create datasets from project data.
90 changes: 90 additions & 0 deletions README.md
@@ -1394,6 +1394,96 @@ Example of DAG config params:

For more information, see the [prediction explanations](https://datarobot-public-api-client.readthedocs-hosted.com/en/latest-release/autodoc/api_reference.html?highlight=PredictionExplanationsInitialization#prediction-explanations) section of the DataRobot documentation.

---
#### `CreateRegisteredModelVersionOperator`

Dynamically creates a registered model version using one of three methods:
- Leaderboard Model
- Custom Model
- External Model

Parameters:
- **Leaderboard Model**

| Parameter | Type | Description |
|------------------------------------|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `model_type` | str | Type of model version to create (leaderboard, custom, or external). |
| `model_id` | str | The ID of the leaderboard model. |
| `name` | Optional[str] | Name of the version (model package). |
| `prediction_threshold` | Optional[float] | Threshold used for binary classification in predictions. |
| `distribution_prediction_model_id` | Optional[str] | ID of the DataRobot distribution prediction model trained on predictions from the DataRobot model. |
| `description` | Optional[str] | Description of the version (model package). |
| `compute_all_ts_intervals` | Optional[bool] | Whether to compute all time series prediction intervals (1-100 percentiles). |
| `registered_model_name` | Optional[str] | Name of the new registered model that will be created from this model package (version). The model package (version) will be created as version 1 of the created registered model. |
| `registered_model_id` | Optional[str] | Creates a model package (version) as a new version for the provided registered model ID. Mutually exclusive with `registered_model_name`. |
| `tags` | Optional[List[Tag]] | Tags for the registered model version. |
| `registered_model_tags` | Optional[List[Tag]] | Tags for the registered model. |
| `registered_model_description` | Optional[str] | Description for the registered model. |

- **Custom Model**

| Parameter | Type | Description |
|------------------------------------|---------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `model_type` | str | Type of model version to create (leaderboard, custom, or external). |
| `custom_model_version_id` | str | ID of the custom model version. |
| `name` | Optional[str] | Name of the registered model version. |
| `description` | Optional[str] | Description of the version (model package). |
| `registered_model_name` | Optional[str] | Name of the new registered model that will be created from this model package (version). The model package (version) will be created as version 1 of the created registered model. |
| `registered_model_id` | Optional[str] | Creates a model package (version) as a new version for the provided registered model ID. Mutually exclusive with `registered_model_name`. |
| `tags` | Optional[List[Tag]] | Tags for the registered model version. |
| `registered_model_tags` | Optional[List[Tag]] | Tags for the registered model. |
| `registered_model_description` | Optional[str] | Description for the registered model. |

- **External Model**

| Parameter | Type | Description |
|-----------------------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `model_type` | str | Type of model version to create (leaderboard, custom, or external). |
| `name` | str | Name of the registered model version. |
| `target` | ExternalTarget | Target information for the registered model version. |
| `model_id` | Optional[str] | Model ID of the registered model version. |
| `model_description` | Optional[ModelDescription] | Information about the model. |
| `datasets` | Optional[ExternalDatasets] | Dataset information for the registered model version. |
| `timeseries` | Optional[Timeseries] | Timeseries properties for the registered model version. |
| `registered_model_name` | Optional[str] | Name of the new registered model that will be created from this model package (version). The model package (version) will be created as version 1 of the created registered model. |
| `registered_model_id` | Optional[str] | Creates a model package (version) as a new version for the provided registered model ID. Mutually exclusive with `registered_model_name`. |
| `tags` | Optional[List[Tag]] | Tags for the registered model version. |
| `registered_model_tags` | Optional[List[Tag]] | Tags for the registered model. |
| `registered_model_description` | Optional[str] | Description for the registered model. |


Example of DAG config params:

Leaderboard Model
```
{
    "model_type": "leaderboard",
    "model_id": "123456789",
    "name": "My Registered Model",
    "registered_model_name": "My Model Registry"
}
```

Custom Model
```
{
    "model_type": "custom",
    "custom_model_version_id": "987654321",
    "name": "Custom Model Version"
}
```

External Model
```
{
    "model_type": "external",
    "name": "External Model Version",
    "target": {"name": "Target", "type": "Regression"},
    "registered_model_id": "667c694a772c6e79dd61e6e1"
}
```
For more information, see the [Model Registry](https://datarobot-public-api-client.readthedocs-hosted.com/en/latest-release/reference/mlops/model_registry.html#create-registered-model-version) section of the DataRobot documentation.
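
The tables above describe `registered_model_name` and `registered_model_id` as mutually exclusive. A minimal sketch of checking that before a config is handed to the operator follows; `check_registered_model_target` is a hypothetical helper for illustration and is not part of the provider.

```python
from typing import Any, Dict


def check_registered_model_target(params: Dict[str, Any]) -> None:
    """Reject configs that set both registered-model selectors.

    Hypothetical helper for illustration only; not part of the provider.
    """
    if params.get("registered_model_name") and params.get("registered_model_id"):
        raise ValueError(
            "'registered_model_name' and 'registered_model_id' are mutually exclusive."
        )


# A config that targets an existing registered model passes silently.
check_registered_model_target(
    {
        "model_type": "external",
        "name": "External Model Version",
        "registered_model_id": "667c694a772c6e79dd61e6e1",
    }
)
```

Whether the DataRobot API itself rejects a request that sets both keys is not stated here, so a local check like this only provides an earlier and clearer failure.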

---

### [Sensors](https://github.com/datarobot/airflow-provider-datarobot/blob/main/datarobot_provider/sensors/datarobot.py)
92 changes: 92 additions & 0 deletions datarobot_provider/operators/model_registry.py
Contributor

Looks like this new file is missing a copyright - I had assumed each new file should have that.

CC @mjnitz02 maybe we should be sure (assuming we need copyrights) that we're running a copyright checker as part of PR checks?

@@ -0,0 +1,92 @@
from enum import Enum
from typing import Any
from typing import Dict
from typing import Sequence

import datarobot as dr
from airflow.exceptions import AirflowException
from airflow.utils.context import Context

from datarobot_provider.hooks.datarobot import DataRobotHook
from datarobot_provider.operators.base_datarobot_operator import BaseDatarobotOperator


class ModelType(Enum):
    LEADERBOARD = "leaderboard"
    CUSTOM = "custom"
    EXTERNAL = "external"


class CreateRegisteredModelVersionOperator(BaseDatarobotOperator):
    """
    Dynamically creates a registered model version using one of three methods:
    - Leaderboard Model
    - Custom Model Version
    - External Model

    :param model_version_params: Dictionary with parameters for creating model version.
    :param datarobot_conn_id: Airflow connection ID for DataRobot.
Contributor

Sorry, probably last comment related to the new base class, but I think we can also remove this.

Suggested change
-    :param datarobot_conn_id: Airflow connection ID for DataRobot.

    """

    template_fields: Sequence[str] = ["model_version_params"]
    template_fields_renderers: dict[str, str] = {}
    template_ext: Sequence[str] = ()
    ui_color = "#f4a460"
Contributor

You should be able to remove this boiler plate code thanks to using the BaseDatarobotOperator 😎

Suggested change
-    template_fields_renderers: dict[str, str] = {}
-    template_ext: Sequence[str] = ()
-    ui_color = "#f4a460"


    def __init__(
        self,
        *,
        model_version_params: Dict[str, Any],
        datarobot_conn_id: str = "datarobot_default",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_version_params = model_version_params
        self.datarobot_conn_id = datarobot_conn_id
Contributor

Same here with some boiler plate code - it could be useful to check against this PR's changes for examples:
https://github.com/datarobot/airflow-provider-datarobot/pull/135/files

Suggested change
-        datarobot_conn_id: str = "datarobot_default",
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.model_version_params = model_version_params
-        self.datarobot_conn_id = datarobot_conn_id
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.model_version_params = model_version_params


    def execute(self, context: Context) -> str:
        """Executes the operator to create a registered model version."""
        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
Contributor

Also taken care of in the newer base class 👍

Suggested change
-        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()

Collaborator

This happens in the base operator

        model_type = self.model_version_params.get("model_type")

        if not model_type:
            raise ValueError("'model_type' must be specified in model_version_params.")

        try:
            model_type_enum = ModelType(model_type)
        except ValueError:
            raise AirflowException(f"Invalid model_type: {model_type}") from None
Contributor

Could you move these lines into the validate() method now provided by the newer base class?

I think lines should work as is as long as you alter:
if not model_type:
to be:
if not self.model_version_params.get("model_type"):
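
A rough sketch of what that move could look like, written without Airflow so it runs standalone. `OperatorSketch` and `InvalidConfigError` are stand-ins for the real operator and `AirflowException`. The only thing taken from this thread is that the base class provides a `validate()` method; its exact signature and when it is called are assumptions.

```python
from enum import Enum
from typing import Any, Dict


class ModelType(Enum):
    LEADERBOARD = "leaderboard"
    CUSTOM = "custom"
    EXTERNAL = "external"


class InvalidConfigError(Exception):
    """Stand-in for airflow.exceptions.AirflowException (keeps the sketch dependency-free)."""


class OperatorSketch:
    """Stand-in for the operator; the real base class contract may differ."""

    def __init__(self, model_version_params: Dict[str, Any]) -> None:
        self.model_version_params = model_version_params

    def validate(self) -> None:
        # Same checks as in execute(), reading the value from the params dict.
        model_type = self.model_version_params.get("model_type")
        if not model_type:
            raise ValueError("'model_type' must be specified in model_version_params.")
        try:
            ModelType(model_type)
        except ValueError:
            raise InvalidConfigError(f"Invalid model_type: {model_type}") from None


OperatorSketch({"model_type": "custom"}).validate()  # passes silently
```

With the checks in `validate()`, `execute()` would still need to build the enum value itself (for example `ModelType(self.model_version_params["model_type"])`) before looking up the creation method.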


        model_creation_methods = {
            ModelType.LEADERBOARD: self.create_for_leaderboard,
            ModelType.CUSTOM: self.create_for_custom,
            ModelType.EXTERNAL: self.create_for_external,
        }

        create_method = model_creation_methods.get(model_type_enum)
        if not create_method:
            raise AirflowException(f"Unsupported model_type: {model_type}")

        extra_params = {k: v for k, v in self.model_version_params.items() if k != "model_type"}
        version = create_method(**extra_params)

        self.log.info(f"Successfully created model version: {version.id}")
        return version.id

    def create_for_leaderboard(self, **kwargs):
        """Creates a registered model version from a leaderboard model."""
        return dr.RegisteredModelVersion.create_for_leaderboard_item(
            model_id=kwargs.pop("model_id"), **kwargs
        )

    def create_for_custom(self, **kwargs):
        """Creates a registered model version for a custom model."""
        return dr.RegisteredModelVersion.create_for_custom_model_version(
            custom_model_version_id=kwargs.pop("custom_model_version_id"), **kwargs
        )

    def create_for_external(self, **kwargs):
        """Creates a registered model version for an external model."""
        return dr.RegisteredModelVersion.create_for_external(
            target=kwargs.pop("target"), name=kwargs.pop("name"), **kwargs
        )
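
Each `create_for_*` method above pulls its required arguments out with `kwargs.pop(...)`, so a config that omits one fails with a bare `KeyError`. A minimal sketch of reporting the missing keys up front is below. The key names are taken from the `pop` calls in this file; `REQUIRED_KEYS` and `missing_required_keys` are hypothetical names, not part of the provider.

```python
from typing import Any, Dict, List, Tuple

# Required keys per model_type, mirroring the kwargs.pop(...) calls in the operator.
REQUIRED_KEYS: Dict[str, Tuple[str, ...]] = {
    "leaderboard": ("model_id",),
    "custom": ("custom_model_version_id",),
    "external": ("name", "target"),
}


def missing_required_keys(params: Dict[str, Any]) -> List[str]:
    """Return the required keys absent from params for its model_type.

    Hypothetical helper; an unknown or missing model_type yields an empty list
    because that case is handled by the model_type check itself.
    """
    required = REQUIRED_KEYS.get(params.get("model_type", ""), ())
    return [key for key in required if key not in params]


print(missing_required_keys({"model_type": "external", "name": "External Model"}))
```

A check like this could raise an `AirflowException` that names the missing keys, which is easier to act on in task logs than a `KeyError` raised from inside a helper method.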
88 changes: 88 additions & 0 deletions tests/unit/operators/test_model_registry.py
@@ -0,0 +1,88 @@
import datarobot as dr

from datarobot_provider.hooks.datarobot import DataRobotHook
from datarobot_provider.operators.model_registry import CreateRegisteredModelVersionOperator


def test_create_registered_model_version_leaderboard(mocker):
    model_version_params = {
        "model_type": "leaderboard",
        "name": "Test Model",
        "model_id": "123456",
        "registered_model_name": "Test Registry",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_leaderboard_item", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(
        model_id="123456", name="Test Model", registered_model_name="Test Registry"
    )
    assert result == "version-123"


def test_create_registered_model_version_custom(mocker):
    model_version_params = {
        "model_type": "custom",
        "name": "Custom Model",
        "custom_model_version_id": "987654",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_custom_model_version", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_custom",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(custom_model_version_id="987654", name="Custom Model")
    assert result == "version-123"


def test_create_registered_model_version_external(mocker):
    model_version_params = {
        "model_type": "external",
        "name": "External Model",
        "registered_model_id": "123456",
        "target": "classification",
    }

    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-1234"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_external", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_external",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    create_mock.assert_called_with(
        name="External Model", target="classification", registered_model_id="123456"
    )
    assert result == "version-1234"