-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mmm-18320] operator create registered model registered model version #136
Changes from 8 commits
13f72f6
70781fc
bcff644
53783d8
fbf1e3b
3354c8e
dfc9077
f17af95
6dc326d
261d63f
4469ebc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,92 @@ | ||||||||||||||||||||||
from enum import Enum | ||||||||||||||||||||||
from typing import Any | ||||||||||||||||||||||
from typing import Dict | ||||||||||||||||||||||
from typing import Sequence | ||||||||||||||||||||||
|
||||||||||||||||||||||
import datarobot as dr | ||||||||||||||||||||||
from airflow.exceptions import AirflowException | ||||||||||||||||||||||
from airflow.utils.context import Context | ||||||||||||||||||||||
|
||||||||||||||||||||||
from datarobot_provider.hooks.datarobot import DataRobotHook | ||||||||||||||||||||||
from datarobot_provider.operators.base_datarobot_operator import BaseDatarobotOperator | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class ModelType(Enum):
    """Kinds of model a registered model version can be created from.

    The string values are what callers supply under the ``model_type`` key of
    ``model_version_params``.
    """

    LEADERBOARD = "leaderboard"
    CUSTOM = "custom"
    EXTERNAL = "external"
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class CreateRegisteredModelVersionOperator(BaseDatarobotOperator):
    """
    Dynamically creates a registered model version using one of three methods:
    - Leaderboard Model
    - Custom Model Version
    - External Model

    The method is selected by the ``model_type`` entry of ``model_version_params``
    (one of the ``ModelType`` values). Every other entry is forwarded unchanged as
    a keyword argument to the matching ``dr.RegisteredModelVersion.create_for_*``
    call.

    :param model_version_params: Dictionary with parameters for creating model version.
    :param datarobot_conn_id: Airflow connection ID for DataRobot.
    """

    # NOTE(review): a reviewer suggested the ``datarobot_conn_id`` documentation
    # above, the template attributes below, and the ``__init__`` boilerplate may
    # all be redundant with the newer BaseDatarobotOperator. The base class is
    # not visible here -- confirm against it before removing anything.
    template_fields: Sequence[str] = ["model_version_params"]
    template_fields_renderers: dict[str, str] = {}
    template_ext: Sequence[str] = ()
    ui_color = "#f4a460"

    def __init__(
        self,
        *,
        model_version_params: Dict[str, Any],
        datarobot_conn_id: str = "datarobot_default",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_version_params = model_version_params
        self.datarobot_conn_id = datarobot_conn_id

    def execute(self, context: Context) -> str:
        """Executes the operator to create a registered model version.

        :param context: Airflow task context (not read by this method).
        :return: The ``id`` of the created registered model version.
        :raises ValueError: If ``model_type`` is missing or empty in
            ``model_version_params``.
        :raises AirflowException: If ``model_type`` is not a ``ModelType`` value.
        :raises KeyError: If a parameter required by the selected creation
            method is missing from ``model_version_params``.
        """
        # NOTE(review): reviewers noted this hook call is already performed by
        # the base operator -- confirm against BaseDatarobotOperator.
        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()

        # NOTE(review): a reviewer asked for the model_type validation below to
        # be moved elsewhere; the target is not visible here -- TODO confirm.
        model_type = self.model_version_params.get("model_type")

        if not model_type:
            raise ValueError("'model_type' must be specified in model_version_params.")

        try:
            model_type_enum = ModelType(model_type)
        except ValueError:
            # Suppress the enum's own ValueError so only the Airflow error surfaces.
            raise AirflowException(f"Invalid model_type: {model_type}") from None

        model_creation_methods = {
            ModelType.LEADERBOARD: self.create_for_leaderboard,
            ModelType.CUSTOM: self.create_for_custom,
            ModelType.EXTERNAL: self.create_for_external,
        }

        # Defensive: only reachable if ModelType gains a member without a
        # matching entry in the dispatch table above.
        create_method = model_creation_methods.get(model_type_enum)
        if not create_method:
            raise AirflowException(f"Unsupported model_type: {model_type}")

        # Build a filtered copy so the operator's own params dict is not mutated
        # by the ``pop`` calls inside the create_for_* helpers.
        extra_params = {k: v for k, v in self.model_version_params.items() if k != "model_type"}
        version = create_method(**extra_params)

        self.log.info("Successfully created model version: %s", version.id)
        return version.id

    def create_for_leaderboard(self, **kwargs: Any) -> "dr.RegisteredModelVersion":
        """Creates a registered model version from a leaderboard model.

        :param kwargs: Must contain ``model_id``; remaining entries are passed
            through to ``create_for_leaderboard_item``.
        :raises KeyError: If ``model_id`` is missing.
        """
        return dr.RegisteredModelVersion.create_for_leaderboard_item(
            model_id=kwargs.pop("model_id"), **kwargs
        )

    def create_for_custom(self, **kwargs: Any) -> "dr.RegisteredModelVersion":
        """Creates a registered model version for a custom model.

        :param kwargs: Must contain ``custom_model_version_id``; remaining entries
            are passed through to ``create_for_custom_model_version``.
        :raises KeyError: If ``custom_model_version_id`` is missing.
        """
        return dr.RegisteredModelVersion.create_for_custom_model_version(
            custom_model_version_id=kwargs.pop("custom_model_version_id"), **kwargs
        )

    def create_for_external(self, **kwargs: Any) -> "dr.RegisteredModelVersion":
        """Creates a registered model version for an external model.

        :param kwargs: Must contain ``target`` and ``name``; remaining entries are
            passed through to ``create_for_external``.
        :raises KeyError: If ``target`` or ``name`` is missing.
        """
        return dr.RegisteredModelVersion.create_for_external(
            target=kwargs.pop("target"), name=kwargs.pop("name"), **kwargs
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import datarobot as dr | ||
|
||
from datarobot_provider.hooks.datarobot import DataRobotHook | ||
from datarobot_provider.operators.model_registry import CreateRegisteredModelVersionOperator | ||
|
||
|
||
def test_create_registered_model_version_leaderboard(mocker):
    """A ``leaderboard`` model_type dispatches to ``create_for_leaderboard_item``."""
    model_version_params = {
        "model_type": "leaderboard",
        "name": "Test Model",
        "model_id": "123456",
        "registered_model_name": "Test Registry",
    }

    # Stub out the hook call so the test never touches a real connection.
    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_leaderboard_item", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    # ``model_type`` must be stripped, and the version must be created exactly once.
    create_mock.assert_called_once_with(
        model_id="123456", name="Test Model", registered_model_name="Test Registry"
    )
    assert result == "version-123"
|
||
|
||
def test_create_registered_model_version_custom(mocker):
    """A ``custom`` model_type dispatches to ``create_for_custom_model_version``."""
    model_version_params = {
        "model_type": "custom",
        "name": "Custom Model",
        "custom_model_version_id": "987654",
    }

    # Stub out the hook call so the test never touches a real connection.
    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-123"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_custom_model_version", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_custom",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    # ``model_type`` must be stripped, and the version must be created exactly once.
    create_mock.assert_called_once_with(custom_model_version_id="987654", name="Custom Model")
    assert result == "version-123"
|
||
|
||
def test_create_registered_model_version_external(mocker):
    """An ``external`` model_type dispatches to ``create_for_external``."""
    model_version_params = {
        "model_type": "external",
        "name": "External Model",
        "registered_model_id": "123456",
        "target": "classification",
    }

    # Stub out the hook call so the test never touches a real connection.
    mocker.patch.object(DataRobotHook, "run", return_value=None)
    mock_version = mocker.Mock()
    mock_version.id = "version-1234"

    create_mock = mocker.patch.object(
        dr.RegisteredModelVersion, "create_for_external", return_value=mock_version
    )

    operator = CreateRegisteredModelVersionOperator(
        task_id="test_create_model_version_external",
        model_version_params=model_version_params,
    )

    result = operator.execute(context={})

    # ``model_type`` must be stripped, and the version must be created exactly once.
    create_mock.assert_called_once_with(
        name="External Model", target="classification", registered_model_id="123456"
    )
    assert result == "version-1234"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this new file is missing a copyright - I had assumed each new file should have that.
CC @mjnitz02 maybe we should be sure (assuming we need copyrights) that we're running a copyright checker as part of PR checks?