Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def __init__(
self._last_record_routing_stats_time: float = 0.0
self._ingress: bool = False

# Outbound deployments polling state
self._outbound_deployments: Optional[List[DeploymentID]] = None

@property
def replica_id(self) -> str:
return self._replica_id
Expand Down Expand Up @@ -775,6 +778,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
self._grpc_port,
self._rank,
self._route_patterns,
self._outbound_deployments,
) = ray.get(self._ready_obj_ref)
except RayTaskError as e:
logger.exception(
Expand Down Expand Up @@ -1047,6 +1051,9 @@ def force_stop(self, log_shutdown_message: bool = False):
except ValueError:
pass

def get_outbound_deployments(self) -> Optional[List[DeploymentID]]:
return self._outbound_deployments


class DeploymentReplica:
"""Manages state transitions for deployment replicas.
Expand Down Expand Up @@ -1327,6 +1334,9 @@ def resource_requirements(self) -> Tuple[str, str]:
# https://github.com/ray-project/ray/issues/26210 for the issue.
return json.dumps(required), json.dumps(available)

def get_outbound_deployments(self) -> Optional[List[DeploymentID]]:
return self._actor.get_outbound_deployments()


class ReplicaStateContainer:
"""Container for mapping ReplicaStates to lists of DeploymentReplicas."""
Expand Down Expand Up @@ -3093,6 +3103,27 @@ def _stop_one_running_replica_for_testing(self):
def is_ingress(self) -> bool:
return self._target_state.info.ingress

def get_outbound_deployments(self) -> Optional[List[DeploymentID]]:
"""Get the outbound deployments.

Returns:
Sorted list of deployment IDs that this deployment calls. None if
outbound deployments are not yet polled.
"""
result: Set[DeploymentID] = set()
has_outbound_deployments = False
for replica in self._replicas.get([ReplicaState.RUNNING]):
if replica.version != self._target_state.version:
# Only consider replicas of the target version
continue
outbound_deployments = replica.get_outbound_deployments()
if outbound_deployments is not None:
result.update(outbound_deployments)
has_outbound_deployments = True
if not has_outbound_deployments:
return None
return sorted(result, key=lambda d: (d.name))


class DeploymentStateManager:
"""Manages all state for deployments in the system.
Expand Down Expand Up @@ -3701,3 +3732,21 @@ def _get_replica_ranks_mapping(self, deployment_id: DeploymentID) -> Dict[str, i
return {}

return deployment_state._get_replica_ranks_mapping()

def get_deployment_outbound_deployments(
self, deployment_id: DeploymentID
) -> Optional[List[DeploymentID]]:
"""Get the cached outbound deployments for a specific deployment.

Args:
deployment_id: The deployment ID to get outbound deployments for.

Returns:
List of deployment IDs that this deployment calls, or None if
the deployment doesn't exist or hasn't been polled yet.
"""
deployment_state = self._deployment_states.get(deployment_id)
if deployment_state is None:
return None

return deployment_state.get_outbound_deployments()
83 changes: 44 additions & 39 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
int,
int, # rank
Optional[List[str]], # route_patterns
Optional[List[DeploymentID]], # outbound_deployments
]


Expand Down Expand Up @@ -604,11 +605,51 @@ def get_metadata(self) -> ReplicaMetadata:
self._grpc_port,
current_rank,
route_patterns,
self.list_outbound_deployments(),
)

def get_dynamically_created_handles(self) -> Set[DeploymentID]:
return self._dynamically_created_handles

def list_outbound_deployments(self) -> List[DeploymentID]:
"""List all outbound deployment IDs this replica calls into.

This includes:
- Handles created via get_deployment_handle()
- Handles passed as init args/kwargs to the deployment constructor

This is used to determine which deployments are reachable from this replica.
The list of DeploymentIDs can change over time as new handles can be created at runtime.
Also its not guaranteed that the list of DeploymentIDs are identical across replicas
because it depends on user code.

Returns:
A list of DeploymentIDs that this replica calls into.
"""
seen_deployment_ids: Set[DeploymentID] = set()

# First, collect dynamically created handles
for deployment_id in self.get_dynamically_created_handles():
seen_deployment_ids.add(deployment_id)

# Get the init args/kwargs
init_args = self._user_callable_wrapper._init_args
init_kwargs = self._user_callable_wrapper._init_kwargs

# Use _PyObjScanner to find all DeploymentHandle objects in:
# The init_args and init_kwargs (handles might be passed as init args)
scanner = _PyObjScanner(source_type=DeploymentHandle)
try:
handles = scanner.find_nodes((init_args, init_kwargs))

for handle in handles:
deployment_id = handle.deployment_id
seen_deployment_ids.add(deployment_id)
finally:
scanner.clear()

return list(seen_deployment_ids)

def _set_internal_replica_context(
self, *, servable_object: Callable = None, rank: int = None
):
Expand Down Expand Up @@ -1219,45 +1260,6 @@ def get_num_ongoing_requests(self) -> int:
"""
return self._replica_impl.get_num_ongoing_requests()

def list_outbound_deployments(self) -> List[DeploymentID]:
"""List all outbound deployment IDs this replica calls into.

This includes:
- Handles created via get_deployment_handle()
- Handles passed as init args/kwargs to the deployment constructor

This is used to determine which deployments are reachable from this replica.
The list of DeploymentIDs can change over time as new handles can be created at runtime.
Also its not guaranteed that the list of DeploymentIDs are identical across replicas
because it depends on user code.

Returns:
A list of DeploymentIDs that this replica calls into.
"""
seen_deployment_ids: Set[DeploymentID] = set()

# First, collect dynamically created handles
for deployment_id in self._replica_impl.get_dynamically_created_handles():
seen_deployment_ids.add(deployment_id)

# Get the init args/kwargs
init_args = self._replica_impl._user_callable_wrapper._init_args
init_kwargs = self._replica_impl._user_callable_wrapper._init_kwargs

# Use _PyObjScanner to find all DeploymentHandle objects in:
# The init_args and init_kwargs (handles might be passed as init args)
scanner = _PyObjScanner(source_type=DeploymentHandle)
try:
handles = scanner.find_nodes((init_args, init_kwargs))

for handle in handles:
deployment_id = handle.deployment_id
seen_deployment_ids.add(deployment_id)
finally:
scanner.clear()

return list(seen_deployment_ids)

async def is_allocated(self) -> str:
"""poke the replica to check whether it's alive.

Expand All @@ -1281,6 +1283,9 @@ async def is_allocated(self) -> str:
get_component_logger_file_path(),
)

def list_outbound_deployments(self) -> Optional[List[DeploymentID]]:
return self._replica_impl.list_outbound_deployments()

async def initialize_and_get_metadata(
self, deployment_config: DeploymentConfig = None, _after: Optional[Any] = None
) -> ReplicaMetadata:
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_controller_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __call__(self, *args):
replica_version_hash = None
for replica in deployment_dict[id]:
ref = replica.get_actor_handle().initialize_and_get_metadata.remote()
_, version, _, _, _, _, _, _, _ = ray.get(ref)
_, version, _, _, _, _, _, _, _, _ = ray.get(ref)
if replica_version_hash is None:
replica_version_hash = hash(version)
assert replica_version_hash == hash(version), (
Expand Down Expand Up @@ -118,7 +118,7 @@ def __call__(self, *args):
for replica_name in recovered_replica_names:
actor_handle = ray.get_actor(replica_name, namespace=SERVE_NAMESPACE)
ref = actor_handle.initialize_and_get_metadata.remote()
_, version, _, _, _, _, _, _, _ = ray.get(ref)
_, version, _, _, _, _, _, _, _, _ = ray.get(ref)
assert replica_version_hash == hash(
version
), "Replica version hash should be the same after recover from actor names"
Expand Down
141 changes: 141 additions & 0 deletions python/ray/serve/tests/unit/test_deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ def check_health(self):
def get_routing_stats(self) -> Dict[str, Any]:
return {}

def get_outbound_deployments(self) -> Optional[List[DeploymentID]]:
return getattr(self, "_outbound_deployments", None)

@property
def route_patterns(self) -> Optional[List[str]]:
return None
Expand Down Expand Up @@ -5600,5 +5603,143 @@ def test_rank_assignment_with_replica_failures(self, mock_deployment_state_manag
}, f"Expected ranks [0, 1, 2], got {ranks_mapping.values()}"


class TestGetOutboundDeployments:
def test_basic_outbound_deployments(self, mock_deployment_state_manager):
"""Test that outbound deployments are returned."""
create_dsm, _, _, _ = mock_deployment_state_manager
dsm: DeploymentStateManager = create_dsm()

deployment_id = DeploymentID(name="test_deployment", app_name="test_app")
b_info_1, _ = deployment_info(num_replicas=1)
dsm.deploy(deployment_id, b_info_1)

# Create a RUNNING replica
ds = dsm._deployment_states[deployment_id]
dsm.update() # Transitions to STARTING
for replica in ds._replicas.get([ReplicaState.STARTING]):
replica._actor.set_ready()
dsm.update() # Transitions to RUNNING

# Set outbound deployments on the mock replica
running_replicas = ds._replicas.get([ReplicaState.RUNNING])
assert len(running_replicas) == 1
d1 = DeploymentID(name="dep1", app_name="test_app")
d2 = DeploymentID(name="dep2", app_name="test_app")
running_replicas[0]._actor._outbound_deployments = [d1, d2]

outbound_deployments = ds.get_outbound_deployments()
assert outbound_deployments == [d1, d2]

# Verify it's accessible through DeploymentStateManager
assert dsm.get_deployment_outbound_deployments(deployment_id) == [
d1,
d2,
]

def test_deployment_state_manager_returns_none_for_nonexistent_deployment(
self, mock_deployment_state_manager
):
"""Test that DeploymentStateManager returns None for nonexistent deployments."""
(
create_dsm,
timer,
cluster_node_info_cache,
autoscaling_state_manager,
) = mock_deployment_state_manager
dsm = create_dsm()

deployment_id = DeploymentID(name="nonexistent", app_name="test_app")
assert dsm.get_deployment_outbound_deployments(deployment_id) is None

def test_returns_none_if_replicas_are_not_running(
self, mock_deployment_state_manager
):
"""Test that DeploymentStateManager returns None if replicas are not running."""
create_dsm, _, _, _ = mock_deployment_state_manager
dsm: DeploymentStateManager = create_dsm()

deployment_id = DeploymentID(name="test_deployment", app_name="test_app")
b_info_1, _ = deployment_info(num_replicas=2)
dsm.deploy(deployment_id, b_info_1)
ds = dsm._deployment_states[deployment_id]
dsm.update()
replicas = ds._replicas.get([ReplicaState.STARTING])
assert len(replicas) == 2
d1 = DeploymentID(name="dep1", app_name="test_app")
d2 = DeploymentID(name="dep2", app_name="test_app")
d3 = DeploymentID(name="dep3", app_name="test_app")
d4 = DeploymentID(name="dep4", app_name="test_app")
replicas[0]._actor._outbound_deployments = [d1, d2]
replicas[1]._actor._outbound_deployments = [d3, d4]
dsm.update()

outbound_deployments = ds.get_outbound_deployments()
assert outbound_deployments is None

# Set replicas ready
replicas[0]._actor.set_ready()
dsm.update()
outbound_deployments = ds.get_outbound_deployments()
assert outbound_deployments == [d1, d2]

def test_only_considers_replicas_matching_target_version(
self, mock_deployment_state_manager
):
"""Test that only replicas with target version are considered.

When a new version is deployed, old version replicas that are still
running should not be included in the outbound deployments result.
"""
create_dsm, _, _, _ = mock_deployment_state_manager
dsm: DeploymentStateManager = create_dsm()

# Deploy version 1
b_info_1, v1 = deployment_info(version="1")
dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1)
ds = dsm._deployment_states[TEST_DEPLOYMENT_ID]
dsm.update()

# Get v1 replica to RUNNING state
ds._replicas.get()[0]._actor.set_ready()
dsm.update()

# Set outbound deployments for v1 replica
d1 = DeploymentID(name="dep1", app_name="test_app")
d2 = DeploymentID(name="dep2", app_name="test_app")
ds._replicas.get()[0]._actor._outbound_deployments = [d1, d2]

# Verify v1 outbound deployments are returned
assert ds.get_outbound_deployments() == [d1, d2]

# Deploy version 2 - this triggers rolling update
b_info_2, v2 = deployment_info(version="2")
dsm.deploy(TEST_DEPLOYMENT_ID, b_info_2)
dsm.update()

# Now we have v1 stopping and v2 starting
check_counts(
ds,
total=2,
by_state=[(ReplicaState.STOPPING, 1, v1), (ReplicaState.STARTING, 1, v2)],
)

# Key test: Even though v1 replica exists (stopping), it should not be
# included because target version is v2. Since v2 is not RUNNING yet,
# should return None.
assert ds.get_outbound_deployments() is None

# Set outbound deployments for v2 replica and mark it ready
d3 = DeploymentID(name="dep3", app_name="test_app")
ds._replicas.get(states=[ReplicaState.STARTING])[
0
]._actor._outbound_deployments = [d3]
ds._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready()
dsm.update()

# Now v2 is running. Should only return v2's outbound deployments (d3),
# not v1's outbound deployments (d1, d2).
assert ds.get_outbound_deployments() == [d3]


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))