diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index fcb882fdb19..c468fb55f1f 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -22,6 +22,7 @@ Generator, List, Optional, + Set, Tuple, Union, ) @@ -37,6 +38,7 @@ from ray._common.filters import CoreContextFilter from ray._common.utils import get_or_create_event_loop from ray.actor import ActorClass, ActorHandle +from ray.dag.py_obj_scanner import _PyObjScanner from ray.remote_function import RemoteFunction from ray.serve import metrics from ray.serve._private.common import ( @@ -113,6 +115,7 @@ DeploymentUnavailableError, RayServeException, ) +from ray.serve.handle import DeploymentHandle from ray.serve.schema import EncodingType, LoggingConfig logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -542,6 +545,9 @@ def __init__( self._user_callable_initialized_lock = asyncio.Lock() self._initialization_latency: Optional[float] = None + # Track deployment handles created dynamically via get_deployment_handle() + self._dynamically_created_handles: Set[DeploymentID] = set() + # Flipped to `True` when health checks pass and `False` when they fail. May be # used by replica subclass implementations. 
self._healthy = False @@ -600,17 +606,26 @@ def get_metadata(self) -> ReplicaMetadata: route_patterns, ) + def get_dynamically_created_handles(self) -> Set[DeploymentID]: + return set(self._dynamically_created_handles) + def _set_internal_replica_context( self, *, servable_object: Callable = None, rank: int = None ): # Calculate world_size from deployment config instead of storing it world_size = self._deployment_config.num_replicas + + # Create callback for registering dynamically created handles + def register_handle_callback(deployment_id: DeploymentID) -> None: + self._dynamically_created_handles.add(deployment_id) + ray.serve.context._set_internal_replica_context( replica_id=self._replica_id, servable_object=servable_object, _deployment_config=self._deployment_config, rank=rank, world_size=world_size, + handle_registration_callback=register_handle_callback, ) def _configure_logger_and_profilers( @@ -1204,6 +1219,45 @@ def get_num_ongoing_requests(self) -> int: """ return self._replica_impl.get_num_ongoing_requests() + def list_outbound_deployments(self) -> List[DeploymentID]: + """List all outbound deployment IDs this replica calls into. + + This includes: + - Handles created via get_deployment_handle() + - Handles passed as init args/kwargs to the deployment constructor + + This is used to determine which deployments are reachable from this replica. + The list of DeploymentIDs can change over time as new handles can be created at runtime. + Also it's not guaranteed that the list of DeploymentIDs is identical across replicas + because it depends on user code. + + Returns: + A list of DeploymentIDs that this replica calls into. 
+ """ + seen_deployment_ids: Set[DeploymentID] = set() + + # First, collect dynamically created handles + for deployment_id in self._replica_impl.get_dynamically_created_handles(): + seen_deployment_ids.add(deployment_id) + + # Get the init args/kwargs + init_args = self._replica_impl._user_callable_wrapper._init_args + init_kwargs = self._replica_impl._user_callable_wrapper._init_kwargs + + # Use _PyObjScanner to find all DeploymentHandle objects in: + # The init_args and init_kwargs (handles might be passed as init args) + scanner = _PyObjScanner(source_type=DeploymentHandle) + try: + handles = scanner.find_nodes((init_args, init_kwargs)) + + for handle in handles: + deployment_id = handle.deployment_id + seen_deployment_ids.add(deployment_id) + finally: + scanner.clear() + + return list(seen_deployment_ids) + async def is_allocated(self) -> str: """poke the replica to check whether it's alive. diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 30560f99502..d18bbd3d0e6 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -1065,4 +1065,15 @@ async def __call__(self, val: int) -> int: if _record_telemetry: ServeUsageTag.SERVE_GET_DEPLOYMENT_HANDLE_API_USED.record("1") - return client.get_handle(deployment_name, app_name, check_exists=_check_exists) + handle: DeploymentHandle = client.get_handle( + deployment_name, app_name, check_exists=_check_exists + ) + + # Track handle creation if called from within a replica + if ( + internal_replica_context is not None + and internal_replica_context._handle_registration_callback is not None + ): + internal_replica_context._handle_registration_callback(handle.deployment_id) + + return handle diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index b6ad7bb2a68..3986430d4d1 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -13,7 +13,7 @@ import ray from ray.exceptions import RayActorError from ray.serve._private.client import ServeControllerClient 
-from ray.serve._private.common import ReplicaID +from ray.serve._private.common import DeploymentID, ReplicaID from ray.serve._private.config import DeploymentConfig from ray.serve._private.constants import ( SERVE_CONTROLLER_NAME, @@ -50,6 +50,7 @@ class ReplicaContext: _deployment_config: DeploymentConfig rank: int world_size: int + _handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None @property def app_name(self) -> str: @@ -114,6 +115,7 @@ def _set_internal_replica_context( _deployment_config: DeploymentConfig, rank: int, world_size: int, + handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None, ): global _INTERNAL_REPLICA_CONTEXT _INTERNAL_REPLICA_CONTEXT = ReplicaContext( @@ -122,6 +124,7 @@ def _set_internal_replica_context( _deployment_config=_deployment_config, rank=rank, world_size=world_size, + _handle_registration_callback=handle_registration_callback, ) diff --git a/python/ray/serve/tests/BUILD.bazel b/python/ray/serve/tests/BUILD.bazel index 06b4005ad62..1f1a2909a18 100644 --- a/python/ray/serve/tests/BUILD.bazel +++ b/python/ray/serve/tests/BUILD.bazel @@ -119,6 +119,7 @@ py_test_module_list( "test_http_headers.py", "test_http_routes.py", "test_https_proxy.py", + "test_list_outbound_deployments.py", "test_max_replicas_per_node.py", "test_multiplex.py", "test_proxy.py", diff --git a/python/ray/serve/tests/test_list_outbound_deployments.py b/python/ray/serve/tests/test_list_outbound_deployments.py new file mode 100644 index 00000000000..26f44485092 --- /dev/null +++ b/python/ray/serve/tests/test_list_outbound_deployments.py @@ -0,0 +1,196 @@ +import sys +from typing import List + +import pytest + +import ray +from ray import serve +from ray.serve._private.common import DeploymentID +from ray.serve._private.constants import SERVE_NAMESPACE +from ray.serve.handle import DeploymentHandle + + +@serve.deployment +class DownstreamA: + def __call__(self, x: int) -> int: + return x * 2 + + +@serve.deployment +class 
DownstreamB: + def process(self, x: int) -> int: + return x + 10 + + +@serve.deployment +class UpstreamWithStoredHandles: + def __init__(self, handle_a: DeploymentHandle, handle_b: DeploymentHandle): + self.handle_a = handle_a + self.handle_b = handle_b + + async def __call__(self, x: int) -> int: + result_a = await self.handle_a.remote(x) + result_b = await self.handle_b.process.remote(x) + return result_a + result_b + + +@serve.deployment +class UpstreamWithNestedHandles: + def __init__(self, handles_dict: dict, handles_list: list): + self.handles = handles_dict # {"a": handle_a, "b": handle_b} + self.handle_list = handles_list # [handle_a, handle_b] + + async def __call__(self, x: int) -> int: + result_a = await self.handles["a"].remote(x) + result_b = await self.handles["b"].process.remote(x) + return result_a + result_b + + +@serve.deployment +class DynamicDeployment: + async def __call__(self, x: int, app_name1: str, app_name2: str) -> int: + handle_a = serve.get_deployment_handle("DownstreamA", app_name=app_name1) + handle_b = serve.get_deployment_handle("DownstreamB", app_name=app_name2) + result_a = await handle_a.remote(x) + result_b = await handle_b.process.remote(x) + return result_a + result_b + + +def get_replica_actor_handle(deployment_name: str, app_name: str): + actors = ray.util.list_named_actors(all_namespaces=True) + replica_actor_name = None + for actor in actors: + # Match pattern: SERVE_REPLICA::{app_name}#{deployment_name}# + if actor["name"].startswith(f"SERVE_REPLICA::{app_name}#{deployment_name}#"): + replica_actor_name = actor["name"] + break + + if replica_actor_name is None: + # Debug: print all actor names to help diagnose + all_actors = [a["name"] for a in actors if "SERVE" in a["name"]] + raise RuntimeError( + f"Could not find replica actor for {deployment_name} in app {app_name}. 
" + f"Available serve actors: {all_actors}" + ) + + return ray.get_actor(replica_actor_name, namespace=SERVE_NAMESPACE) + + +@pytest.mark.asyncio +class TestListOutboundDeployments: + """Test suite for list_outbound_deployments() method.""" + + async def test_stored_handles_in_init(self, serve_instance): + """Test listing handles that are passed to __init__ and stored as attributes.""" + app_name = "test_stored_handles" + + # Build and deploy the app + handle_a = DownstreamA.bind() + handle_b = DownstreamB.bind() + app = UpstreamWithStoredHandles.bind(handle_a, handle_b) + + serve.run(app, name=app_name) + + # Get the replica actor for the upstream deployment + replica_actor = get_replica_actor_handle("UpstreamWithStoredHandles", app_name) + + # Call list_outbound_deployments + outbound_deployments: List[DeploymentID] = ray.get( + replica_actor.list_outbound_deployments.remote() + ) + + # Verify results + deployment_names = {dep_id.name for dep_id in outbound_deployments} + assert "DownstreamA" in deployment_names + assert "DownstreamB" in deployment_names + assert len(outbound_deployments) == 2 + + # Verify app names match + for dep_id in outbound_deployments: + assert dep_id.app_name == app_name + + async def test_nested_handles_in_dict_and_list(self, serve_instance): + """Test listing handles stored in nested data structures (dict, list).""" + app_name = "test_nested_handles" + + # Build and deploy the app + handle_a = DownstreamA.bind() + handle_b = DownstreamB.bind() + handles_dict = {"a": handle_a, "b": handle_b} + handles_list = [handle_a, handle_b] + app = UpstreamWithNestedHandles.bind(handles_dict, handles_list) + + serve.run(app, name=app_name) + + # Get the replica actor + replica_actor = get_replica_actor_handle("UpstreamWithNestedHandles", app_name) + + # Call list_outbound_deployments + outbound_deployments: List[DeploymentID] = ray.get( + replica_actor.list_outbound_deployments.remote() + ) + + # Verify results (should find handles despite being in 
nested structures) + deployment_names = {dep_id.name for dep_id in outbound_deployments} + assert "DownstreamA" in deployment_names + assert "DownstreamB" in deployment_names + + # Verify no duplicates (handle_a and handle_b appear in both dict and list) + assert len(outbound_deployments) == 2 + + async def test_no_handles(self, serve_instance): + """Test deployment with no outbound handles.""" + app_name = "test_no_handles" + + # Deploy a simple deployment with no handles + app = DownstreamA.bind() + serve.run(app, name=app_name) + + # Get the replica actor + replica_actor = get_replica_actor_handle("DownstreamA", app_name) + + # Call list_outbound_deployments + outbound_deployments: List[DeploymentID] = ray.get( + replica_actor.list_outbound_deployments.remote() + ) + + # Should be empty + assert len(outbound_deployments) == 0 + + async def test_dynamic_handles(self, serve_instance): + app1 = DownstreamA.bind() + app2 = DownstreamB.bind() + app3 = DynamicDeployment.bind() + + serve.run(app1, name="app1", route_prefix="/app1") + serve.run(app2, name="app2", route_prefix="/app2") + handle = serve.run(app3, name="app3", route_prefix="/app3") + + # Make requests to trigger dynamic handle creation + # x=1: DownstreamA returns 1*2=2, DownstreamB returns 1+10=11, total=2+11=13 + results = [await handle.remote(1, "app1", "app2") for _ in range(10)] + for result in results: + assert result == 13 + + # Get the replica actor + replica_actor = get_replica_actor_handle("DynamicDeployment", "app3") + + # Call list_outbound_deployments + outbound_deployments: List[DeploymentID] = ray.get( + replica_actor.list_outbound_deployments.remote() + ) + + # Verify results - should include dynamically created handles + deployment_names = {dep_id.name for dep_id in outbound_deployments} + assert "DownstreamA" in deployment_names + assert "DownstreamB" in deployment_names + assert len(outbound_deployments) == 2 + + # Verify the app names are correct + app_names = {dep_id.app_name for 
dep_id in outbound_deployments} + assert "app1" in app_names + assert "app2" in app_names + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__]))