Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Generator,
List,
Optional,
Set,
Tuple,
Union,
)
Expand All @@ -37,6 +38,7 @@
from ray._common.filters import CoreContextFilter
from ray._common.utils import get_or_create_event_loop
from ray.actor import ActorClass, ActorHandle
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.remote_function import RemoteFunction
from ray.serve import metrics
from ray.serve._private.common import (
Expand Down Expand Up @@ -113,6 +115,7 @@
DeploymentUnavailableError,
RayServeException,
)
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import EncodingType, LoggingConfig

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -542,6 +545,9 @@ def __init__(
self._user_callable_initialized_lock = asyncio.Lock()
self._initialization_latency: Optional[float] = None

# Track deployment handles created dynamically via get_deployment_handle()
self._dynamically_created_handles: Set[DeploymentID] = set()

# Flipped to `True` when health checks pass and `False` when they fail. May be
# used by replica subclass implementations.
self._healthy = False
Expand Down Expand Up @@ -600,17 +606,26 @@ def get_metadata(self) -> ReplicaMetadata:
route_patterns,
)

def get_dynamically_created_handles(self) -> Set[DeploymentID]:
    """Return the deployment IDs recorded for dynamically created handles.

    Returns:
        A snapshot copy of ``self._dynamically_created_handles``. A copy is
        returned instead of the internal set so that callers cannot mutate
        the replica's bookkeeping, and so that the result can be iterated
        safely even if more IDs are registered afterwards.
    """
    return set(self._dynamically_created_handles)

def _set_internal_replica_context(
    self, *, servable_object: Callable = None, rank: int = None
):
    """Install this replica's context via ``ray.serve.context``.

    Args:
        servable_object: Forwarded unchanged as ``servable_object``.
        rank: Forwarded unchanged as ``rank``.

    The context also receives this replica's ID and deployment config, a
    world size read from the deployment config, and a callback that records
    deployment IDs in ``self._dynamically_created_handles``.
    """

    def _record_outbound_handle(deployment_id: DeploymentID) -> None:
        # Invoked with the deployment ID of a dynamically created handle.
        self._dynamically_created_handles.add(deployment_id)

    ray.serve.context._set_internal_replica_context(
        replica_id=self._replica_id,
        servable_object=servable_object,
        _deployment_config=self._deployment_config,
        rank=rank,
        # world_size is derived from the deployment config, not stored.
        world_size=self._deployment_config.num_replicas,
        handle_registration_callback=_record_outbound_handle,
    )

def _configure_logger_and_profilers(
Expand Down Expand Up @@ -1204,6 +1219,45 @@ def get_num_ongoing_requests(self) -> int:
"""
return self._replica_impl.get_num_ongoing_requests()

def list_outbound_deployments(self) -> List[DeploymentID]:
    """List all outbound deployment IDs this replica calls into.

    This includes:
    - Handles created via get_deployment_handle()
    - Handles passed as init args/kwargs to the deployment constructor

    This is used to determine which deployments are reachable from this
    replica. The list of DeploymentIDs can change over time as new handles
    can be created at runtime. Also, it's not guaranteed that the list of
    DeploymentIDs is identical across replicas because it depends on user
    code.

    Returns:
        A list of DeploymentIDs that this replica calls into, without
        duplicates and in no particular order.
    """
    # Copy the dynamically created handle IDs in one call rather than
    # iterating the replica's set element by element, so a registration that
    # happens while we read cannot break the iteration.
    # NOTE(review): confirm whether registrations can run on another thread.
    outbound_ids: Set[DeploymentID] = set(
        self._replica_impl.get_dynamically_created_handles()
    )

    # Handles may also have been passed to the deployment constructor.
    user_callable_wrapper = self._replica_impl._user_callable_wrapper
    init_args = user_callable_wrapper._init_args
    init_kwargs = user_callable_wrapper._init_kwargs

    # Use _PyObjScanner to find every DeploymentHandle inside the init
    # args/kwargs, including ones nested in containers.
    # NOTE: per the PR review discussion, this scan could be cached, but that
    # is not a priority because this method is expected to be called
    # infrequently (exponential backoff starting at 1s, capped at 10 minutes).
    scanner = _PyObjScanner(source_type=DeploymentHandle)
    try:
        outbound_ids.update(
            handle.deployment_id
            for handle in scanner.find_nodes((init_args, init_kwargs))
        )
    finally:
        # Always release whatever the scanner holds, even if the scan fails.
        scanner.clear()

    return list(outbound_ids)

async def is_allocated(self) -> str:
"""poke the replica to check whether it's alive.

Expand Down
13 changes: 12 additions & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,4 +1065,15 @@ async def __call__(self, val: int) -> int:
if _record_telemetry:
ServeUsageTag.SERVE_GET_DEPLOYMENT_HANDLE_API_USED.record("1")

return client.get_handle(deployment_name, app_name, check_exists=_check_exists)
handle: DeploymentHandle = client.get_handle(
deployment_name, app_name, check_exists=_check_exists
)

# Track handle creation if called from within a replica
if (
internal_replica_context is not None
and internal_replica_context._handle_registration_callback is not None
):
internal_replica_context._handle_registration_callback(handle.deployment_id)

return handle
5 changes: 4 additions & 1 deletion python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import ray
from ray.exceptions import RayActorError
from ray.serve._private.client import ServeControllerClient
from ray.serve._private.common import ReplicaID
from ray.serve._private.common import DeploymentID, ReplicaID
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import (
SERVE_CONTROLLER_NAME,
Expand Down Expand Up @@ -50,6 +50,7 @@ class ReplicaContext:
_deployment_config: DeploymentConfig
rank: int
world_size: int
_handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None

@property
def app_name(self) -> str:
Expand Down Expand Up @@ -114,6 +115,7 @@ def _set_internal_replica_context(
_deployment_config: DeploymentConfig,
rank: int,
world_size: int,
handle_registration_callback: Optional[Callable[[str, str], None]] = None,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Type mismatch in callback for replica context

The handle_registration_callback parameter in _set_internal_replica_context has a type annotation mismatch. It's currently Callable[[str, str], None], but the ReplicaContext field and its actual invocation expect Callable[[DeploymentID], None]. This difference could lead to a runtime type error.

Fix in Cursor Fix in Web

):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
Expand All @@ -122,6 +124,7 @@ def _set_internal_replica_context(
_deployment_config=_deployment_config,
rank=rank,
world_size=world_size,
_handle_registration_callback=handle_registration_callback,
)


Expand Down
1 change: 1 addition & 0 deletions python/ray/serve/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ py_test_module_list(
"test_http_headers.py",
"test_http_routes.py",
"test_https_proxy.py",
"test_list_outbound_deployments.py",
"test_max_replicas_per_node.py",
"test_multiplex.py",
"test_proxy.py",
Expand Down
196 changes: 196 additions & 0 deletions python/ray/serve/tests/test_list_outbound_deployments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import sys
from typing import List

import pytest

import ray
from ray import serve
from ray.serve._private.common import DeploymentID
from ray.serve._private.constants import SERVE_NAMESPACE
from ray.serve.handle import DeploymentHandle


@serve.deployment
class DownstreamA:
    """Downstream deployment whose ``__call__`` doubles its input."""

    def __call__(self, x: int) -> int:
        doubled = x * 2
        return doubled


@serve.deployment
class DownstreamB:
    """Downstream deployment whose ``process`` method adds 10 to its input."""

    def process(self, x: int) -> int:
        shifted = x + 10
        return shifted


@serve.deployment
class UpstreamWithStoredHandles:
    """Deployment that receives two handles in ``__init__`` and stores them."""

    def __init__(self, handle_a: DeploymentHandle, handle_b: DeploymentHandle):
        self.handle_a, self.handle_b = handle_a, handle_b

    async def __call__(self, x: int) -> int:
        # Call handle_a directly and handle_b through its `process` method,
        # one after the other, then add the two results.
        from_a = await self.handle_a.remote(x)
        from_b = await self.handle_b.process.remote(x)
        return from_a + from_b


@serve.deployment
class UpstreamWithNestedHandles:
    """Deployment that receives its handles nested inside a dict and a list."""

    def __init__(self, handles_dict: dict, handles_list: list):
        # Expected shape: {"a": handle_a, "b": handle_b}
        self.handles = handles_dict
        # Expected shape: [handle_a, handle_b]
        self.handle_list = handles_list

    async def __call__(self, x: int) -> int:
        # Only the dict entries are used for the actual calls.
        from_a = await self.handles["a"].remote(x)
        from_b = await self.handles["b"].process.remote(x)
        return from_a + from_b


@serve.deployment
class DynamicDeployment:
    """Deployment that looks up its downstream handles on every request."""

    async def __call__(self, x: int, app_name1: str, app_name2: str) -> int:
        # Both handles are created before either downstream call is made.
        downstream_a = serve.get_deployment_handle("DownstreamA", app_name=app_name1)
        downstream_b = serve.get_deployment_handle("DownstreamB", app_name=app_name2)
        from_a = await downstream_a.remote(x)
        from_b = await downstream_b.process.remote(x)
        return from_a + from_b


def get_replica_actor_handle(deployment_name: str, app_name: str):
    """Return an actor handle for a replica of the given deployment.

    Scans the named actors across all namespaces and picks the first one
    whose name starts with ``SERVE_REPLICA::{app_name}#{deployment_name}#``,
    then fetches it from ``SERVE_NAMESPACE``.

    Raises:
        RuntimeError: If no actor name matches. The message lists every
            actor name containing "SERVE" to help diagnose the failure.
    """
    name_prefix = f"SERVE_REPLICA::{app_name}#{deployment_name}#"
    named_actors = ray.util.list_named_actors(all_namespaces=True)

    matching_name = next(
        (
            info["name"]
            for info in named_actors
            if info["name"].startswith(name_prefix)
        ),
        None,
    )
    if matching_name is not None:
        return ray.get_actor(matching_name, namespace=SERVE_NAMESPACE)

    # Nothing matched: report the serve-related actor names that do exist.
    serve_actor_names = [
        info["name"] for info in named_actors if "SERVE" in info["name"]
    ]
    raise RuntimeError(
        f"Could not find replica actor for {deployment_name} in app {app_name}. "
        f"Available serve actors: {serve_actor_names}"
    )


@pytest.mark.asyncio
class TestListOutboundDeployments:
    """Test suite for the replica's list_outbound_deployments() method."""

    async def test_stored_handles_in_init(self, serve_instance):
        """Handles passed to __init__ and stored as attributes are listed."""
        app_name = "test_stored_handles"

        # Build and deploy an upstream that is given both downstream handles.
        handle_a = DownstreamA.bind()
        handle_b = DownstreamB.bind()
        serve.run(UpstreamWithStoredHandles.bind(handle_a, handle_b), name=app_name)

        replica_actor = get_replica_actor_handle("UpstreamWithStoredHandles", app_name)
        outbound_deployments: List[DeploymentID] = ray.get(
            replica_actor.list_outbound_deployments.remote()
        )

        # Both downstream deployments are reported, and nothing else.
        deployment_names = {dep_id.name for dep_id in outbound_deployments}
        assert {"DownstreamA", "DownstreamB"} <= deployment_names
        assert len(outbound_deployments) == 2

        # Every reported deployment belongs to the same application.
        assert all(dep_id.app_name == app_name for dep_id in outbound_deployments)

    async def test_nested_handles_in_dict_and_list(self, serve_instance):
        """Handles nested inside a dict and a list are listed once each."""
        app_name = "test_nested_handles"

        # The same two bound deployments appear in both containers.
        handle_a = DownstreamA.bind()
        handle_b = DownstreamB.bind()
        handles_dict = {"a": handle_a, "b": handle_b}
        handles_list = [handle_a, handle_b]
        serve.run(
            UpstreamWithNestedHandles.bind(handles_dict, handles_list),
            name=app_name,
        )

        replica_actor = get_replica_actor_handle("UpstreamWithNestedHandles", app_name)
        outbound_deployments: List[DeploymentID] = ray.get(
            replica_actor.list_outbound_deployments.remote()
        )

        # Handles are found despite being in nested structures.
        deployment_names = {dep_id.name for dep_id in outbound_deployments}
        assert {"DownstreamA", "DownstreamB"} <= deployment_names

        # No duplicates, even though each handle appears in the dict and list.
        assert len(outbound_deployments) == 2

    async def test_no_handles(self, serve_instance):
        """A deployment without outbound handles reports an empty list."""
        app_name = "test_no_handles"

        # Deploy a simple deployment that holds no handles.
        serve.run(DownstreamA.bind(), name=app_name)

        replica_actor = get_replica_actor_handle("DownstreamA", app_name)
        outbound_deployments: List[DeploymentID] = ray.get(
            replica_actor.list_outbound_deployments.remote()
        )

        assert len(outbound_deployments) == 0

    async def test_dynamic_handles(self, serve_instance):
        """Handles created via get_deployment_handle() at runtime are listed."""
        app1 = DownstreamA.bind()
        app2 = DownstreamB.bind()
        app3 = DynamicDeployment.bind()

        serve.run(app1, name="app1", route_prefix="/app1")
        serve.run(app2, name="app2", route_prefix="/app2")
        handle = serve.run(app3, name="app3", route_prefix="/app3")

        # Make requests to trigger dynamic handle creation.
        # x=1: DownstreamA returns 1*2=2, DownstreamB returns 1+10=11, total=13
        results = [await handle.remote(1, "app1", "app2") for _ in range(10)]
        assert all(result == 13 for result in results)

        replica_actor = get_replica_actor_handle("DynamicDeployment", "app3")
        outbound_deployments: List[DeploymentID] = ray.get(
            replica_actor.list_outbound_deployments.remote()
        )

        # The dynamically created handles are reported, and nothing else.
        deployment_names = {dep_id.name for dep_id in outbound_deployments}
        assert {"DownstreamA", "DownstreamB"} <= deployment_names
        assert len(outbound_deployments) == 2

        # The reported IDs carry the app names of the downstream apps.
        app_names = {dep_id.app_name for dep_id in outbound_deployments}
        assert {"app1", "app2"} <= app_names


if __name__ == "__main__":
    # Run this file's tests directly and exit with pytest's status code.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)