Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_additional_upstream_nodes to FlytekitPlugin #2708

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@
```
"""

from typing import Optional, Protocol, runtime_checkable
from contextvars import Context
from typing import TYPE_CHECKING, List, Optional, Protocol, runtime_checkable

from click import Group
from importlib_metadata import entry_points

from flytekit.configuration import Config, get_config_file
from flytekit.core.node import Node
from flytekit.loggers import logger
from flytekit.remote import FlyteRemote

if TYPE_CHECKING:
from flytekit.core.promise import SupportsNodeCreation


@runtime_checkable
class FlytekitPluginProtocol(Protocol):
Expand Down Expand Up @@ -90,6 +95,10 @@ def get_auth_success_html(endpoint: str) -> Optional[str]:
"""Get default success html. Return None to use flytekit's default success html."""
return None

@staticmethod
def get_additional_upstream_nodes(ctx: Context, entity: "SupportsNodeCreation") -> List[Node]:
return []


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,11 @@ def create_and_link_node(
# These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(set([n for n in nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID]))

from flytekit.configuration.plugin import get_plugin

additional_nodes = get_plugin().get_additional_upstream_nodes(ctx, entity)
upstream_nodes.extend(additional_nodes)

flytekit_node = Node(
# TODO: Better naming, probably a derivative of the function name.
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
Expand Down
Loading