[Ray] Ray execution state (#3002)
* Ray execution state

* Fix stop pool

* Do not fetch chunk meta when tiling HeadOptimizedDataSource

* Fix

* Fix

* Use named actor for Ray task state

* Improve coverage

* Fix lint

Co-authored-by: 刘宝 <[email protected]>
fyrestone and 刘宝 authored May 7, 2022
1 parent c2e334f commit 03ed810
Showing 11 changed files with 254 additions and 71 deletions.
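The commit's headline change is keeping per-task execution state for the Ray backend in a named Ray actor, so that the driver and remote workers can reach the same state object by name. A minimal, hedged sketch of that idea follows; the actor class and the actor name are illustrative placeholders, not identifiers from this commit:

import ray

@ray.remote
class TaskStateActor:
    """Illustrative holder of mutable per-task execution state."""

    def __init__(self):
        self._state = {}

    def put(self, key, value):
        self._state[key] = value

    def get(self, key):
        return self._state.get(key)

ray.init()
# The driver registers the actor under a well-known name...
state = TaskStateActor.options(name="task_state_demo").remote()

# ...and any worker process can later look it up by that name.
actor = ray.get_actor("task_state_demo")
ray.get(actor.put.remote("stage", "tiling"))
print(ray.get(actor.get.remote("stage")))  # -> "tiling"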
4 changes: 1 addition & 3 deletions mars/dataframe/datasource/core.py
@@ -58,9 +58,7 @@ def _tile_head(cls, op: "HeadOptimizedDataSource"):
# execute first chunk
yield chunks[:1]

ctx = get_context()
chunk_shape = ctx.get_chunks_meta([chunks[0].key], fields=["shape"])[0]["shape"]

chunk_shape = chunks[0].shape
if chunk_shape[0] == op.nrows:
# the first chunk has enough data
tileds[0]._nsplits = tuple((s,) for s in chunk_shape)
44 changes: 25 additions & 19 deletions mars/deploy/oscar/tests/test_local.py
@@ -490,7 +490,7 @@ async def test_web_session(create_cluster, config):
)


@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}])
@pytest.mark.parametrize("config", [{"backend": "mars"}])
def test_sync_execute(config):
session = new_session(
backend=config["backend"], n_cpu=2, web=False, use_uvloop=False
@@ -518,25 +518,31 @@ def test_sync_execute(config):
assert d is c
assert abs(session.fetch(d) - raw.sum()) < 0.001

# TODO(fyrestone): Remove this when the Ray backend support incremental index.
if config["incremental_index"]:
with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
pdf = pd.DataFrame(
np.random.RandomState(0).rand(100, 10),
columns=[f"col{i}" for i in range(10)],
)
pdf.to_csv(file_path, index=False)

df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
result = df.sum(axis=1).execute().fetch()
expected = pd.read_csv(file_path).sum(axis=1)
pd.testing.assert_series_equal(result, expected)
with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
pdf = pd.DataFrame(
np.random.RandomState(0).rand(100, 10),
columns=[f"col{i}" for i in range(10)],
)
pdf.to_csv(file_path, index=False)

df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
result = df.head(10).execute().fetch()
expected = pd.read_csv(file_path).head(10)
pd.testing.assert_frame_equal(result, expected)
df = md.read_csv(
file_path,
chunk_bytes=os.stat(file_path).st_size / 5,
incremental_index=True,
)
result = df.sum(axis=1).execute().fetch()
expected = pd.read_csv(file_path).sum(axis=1)
pd.testing.assert_series_equal(result, expected)

df = md.read_csv(
file_path,
chunk_bytes=os.stat(file_path).st_size / 5,
incremental_index=True,
)
result = df.head(10).execute().fetch()
expected = pd.read_csv(file_path).head(10)
pd.testing.assert_frame_equal(result, expected)

for worker_pool in session._session.client._cluster._worker_pools:
_assert_storage_cleaned(
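One note on the reworked test above: as I understand it, incremental_index=True asks md.read_csv to renumber the per-chunk RangeIndex into one continuous 0..n-1 index across chunks, which is what makes the results comparable with plain pandas.read_csv. A condensed, hedged restatement of that portion of the test:

df = md.read_csv(
    file_path,
    chunk_bytes=os.stat(file_path).st_size / 5,
    incremental_index=True,  # stitch per-chunk indexes into one continuous range
)
result = df.head(10).execute().fetch()
pd.testing.assert_frame_equal(result, pd.read_csv(file_path).head(10))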
3 changes: 1 addition & 2 deletions mars/deploy/oscar/tests/test_ray_dag.py
@@ -112,8 +112,7 @@ async def test_iterative_tiling(ray_start_regular_shared2, create_cluster):
await test_local.test_iterative_tiling(create_cluster)


# TODO(fyrestone): Support incremental index in ray backend.
@require_ray
@pytest.mark.parametrize("config", [{"backend": "ray", "incremental_index": False}])
@pytest.mark.parametrize("config", [{"backend": "ray"}])
def test_sync_execute(config):
test_local.test_sync_execute(config)
4 changes: 3 additions & 1 deletion mars/oscar/backends/pool.py
@@ -457,7 +457,9 @@ async def join(self, timeout: float = None):
async def stop(self):
try:
# clean global router
Router.get_instance().remove_router(self._router)
router = Router.get_instance()
if router is not None:
router.remove_router(self._router)
stop_tasks = []
# stop all servers
stop_tasks.extend([server.stop() for server in self._servers])
12 changes: 12 additions & 0 deletions mars/services/task/execution/mars/executor.py
@@ -20,6 +20,7 @@

from ..... import oscar as mo
from .....core import ChunkGraph, TileContext
from .....core.context import set_context
from .....core.operand import (
Fetch,
MapReduceOperand,
@@ -33,6 +34,7 @@
from .....resource import Resource
from .....typing import TileableType, BandType
from .....utils import Timer
from ....context import ThreadedServiceContext
from ....cluster.api import ClusterAPI
from ....lifecycle.api import LifecycleAPI
from ....meta.api import MetaAPI
@@ -121,6 +123,7 @@ async def create(
task_id=task.task_id,
cluster_api=cluster_api,
)
await cls._init_context(session_id, address)
return cls(
config,
task,
@@ -142,6 +145,15 @@ async def _get_apis(cls, session_id: str, address: str):
MetaAPI.create(session_id, address),
)

@classmethod
async def _init_context(cls, session_id: str, address: str):
loop = asyncio.get_running_loop()
context = ThreadedServiceContext(
session_id, address, address, address, loop=loop
)
await context.init()
set_context(context)

async def __aenter__(self):
profiling = ProfilingData[self._task.task_id, "general"]
# incref fetch tileables to ensure fetch data not deleted
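The new _init_context hook means that code running inside the Mars backend can recover this context later through get_context(). A small hedged sketch, assuming get_context lives alongside set_context in mars.core.context (chunk_key is a hypothetical placeholder):

from mars.core.context import get_context

ctx = get_context()  # the ThreadedServiceContext installed by _init_context above
shape = ctx.get_chunks_meta([chunk_key], fields=["shape"])[0]["shape"]  # chunk_key is hypothetical

This mirrors the call that the first hunk removed from HeadOptimizedDataSource tiling; the context itself remains available for code that still needs chunk metadata.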
103 changes: 101 additions & 2 deletions mars/services/task/execution/ray/context.py
@@ -12,9 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
from typing import Union

from .....core.context import Context
from .....utils import implements, lazy_import
from ....context import ThreadedServiceContext

ray = lazy_import("ray")


class RayRemoteObjectManager:
"""The remote object manager in task state actor."""

def __init__(self):
self._named_remote_objects = {}

def create_remote_object(self, name: str, object_cls, *args, **kwargs):
remote_object = object_cls(*args, **kwargs)
self._named_remote_objects[name] = remote_object

def destroy_remote_object(self, name: str):
self._named_remote_objects.pop(name, None)

async def call_remote_object(self, name: str, attr: str, *args, **kwargs):
remote_object = self._named_remote_objects[name]
meth = getattr(remote_object, attr)
async_meth = self._sync_to_async(meth)
return await async_meth(*args, **kwargs)

@staticmethod
@functools.lru_cache(100)
def _sync_to_async(func):
if inspect.iscoroutinefunction(func):
return func
else:

async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)

return async_wrapper
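
Taken on its own, the manager can be exercised locally without Ray; a minimal sketch in which Counter is a hypothetical user-defined class, not part of the commit:

import asyncio

class Counter:
    def __init__(self, start=0):
        self.value = start

    def incr(self, delta=1):
        self.value += delta
        return self.value

async def _demo():
    manager = RayRemoteObjectManager()
    manager.create_remote_object("counter", Counter, 10)
    # call_remote_object wraps the synchronous incr in a coroutine via _sync_to_async
    assert await manager.call_remote_object("counter", "incr", 5) == 15
    manager.destroy_remote_object("counter")

asyncio.run(_demo())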


class _RayRemoteObjectWrapper:
def __init__(self, task_state_actor: "ray.actor.ActorHandle", name: str):
self._task_state_actor = task_state_actor
self._name = name

def __getattr__(self, attr):
def wrap(*args, **kwargs):
r = self._task_state_actor.call_remote_object.remote(
self._name, attr, *args, **kwargs
)
return ray.get(r)

return wrap


class _RayRemoteObjectContext:
def __init__(
self, actor_name_or_handle: Union[str, "ray.actor.ActorHandle"], *args, **kwargs
):
super().__init__(*args, **kwargs)
self._actor_name_or_handle = actor_name_or_handle
self._task_state_actor = None

def _get_task_state_actor(self) -> "ray.actor.ActorHandle":
if self._task_state_actor is None:
if isinstance(self._actor_name_or_handle, ray.actor.ActorHandle):
self._task_state_actor = self._actor_name_or_handle
else:
self._task_state_actor = ray.get_actor(self._actor_name_or_handle)
return self._task_state_actor

@implements(Context.create_remote_object)
def create_remote_object(self, name: str, object_cls, *args, **kwargs):
task_state_actor = self._get_task_state_actor()
task_state_actor.create_remote_object.remote(name, object_cls, *args, **kwargs)
return _RayRemoteObjectWrapper(task_state_actor, name)

@implements(Context.get_remote_object)
def get_remote_object(self, name: str):
task_state_actor = self._get_task_state_actor()
return _RayRemoteObjectWrapper(task_state_actor, name)

@implements(Context.destroy_remote_object)
def destroy_remote_object(self, name: str):
task_state_actor = self._get_task_state_actor()
task_state_actor.destroy_remote_object.remote(name)
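
Putting the pieces together, a hedged end-to-end sketch: the task state actor is assumed to be a Ray actor exposing RayRemoteObjectManager's methods under a known name; the TaskStateActor class, the "task_state" name, and Counter are illustrative assumptions rather than identifiers from this commit, and the mixin is instantiated directly only for demonstration:

import ray

@ray.remote
class TaskStateActor(RayRemoteObjectManager):
    """Illustrative async actor wrapping the manager above."""

ray.init()
actor = TaskStateActor.options(name="task_state").remote()

class Counter:  # hypothetical user object, as in the earlier sketch
    def __init__(self):
        self.value = 0

    def incr(self):
        self.value += 1
        return self.value

ctx = _RayRemoteObjectContext("task_state")
counter = ctx.create_remote_object("counter", Counter)
counter.incr()                                   # forwarded via call_remote_object.remote + ray.get
print(ctx.get_remote_object("counter").incr())   # -> 2
ctx.destroy_remote_object("counter")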


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
"""The context for tiling."""

pass


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
"""The context for executing operands."""

# TODO(fyrestone): Should implement the mars.core.context.Context.
class RayExecutionContext(dict):
@staticmethod
def new_custom_log_dir():
return None
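
Finally, a brief hedged sketch of how the worker context is intended to behave: it is a plain dict for chunk data during operand execution while also forwarding the remote-object API to the named task state actor (the "task_state" name and the chunk key are illustrative, and a running Ray cluster with that actor is assumed):

worker_ctx = RayExecutionWorkerContext("task_state")

# dict behaviour: operands read and write chunk data by key
worker_ctx["chunk-0"] = 42
assert worker_ctx["chunk-0"] == 42

# remote-object behaviour: proxies resolve through the named Ray actor
counter = worker_ctx.get_remote_object("counter")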
