Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ZMQ #54

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ venv*/

dask-worker-space/*
.pre-commit-config.yaml

.o
scaler/io/cpp/build/
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ cloudpickle
graphlib-backport; python_version < '3.9'
psutil
pycapnp
pyzmq
tblib
27 changes: 14 additions & 13 deletions scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import threading
from typing import Optional

import zmq.asyncio

from scaler.client.agent.disconnect_manager import ClientDisconnectManager
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.agent.heartbeat_manager import ClientHeartbeatManager
Expand All @@ -28,16 +26,16 @@
from scaler.protocol.python.mixins import Message
from scaler.utility.event_loop import create_async_loop_routine
from scaler.utility.exceptions import ClientCancelledException, ClientQuitException, ClientShutdownException
from scaler.utility.zmq_config import ZMQConfig

from scaler.io.model import Session, ConnectorType, TCPAddress

class ClientAgent(threading.Thread):
def __init__(
self,
identity: bytes,
client_agent_address: ZMQConfig,
scheduler_address: ZMQConfig,
context: zmq.Context,
client_agent_address: TCPAddress,
scheduler_address: TCPAddress,
session: Session,
future_manager: ClientFutureManager,
stop_event: threading.Event,
timeout_seconds: int,
Expand All @@ -54,25 +52,28 @@ def __init__(
self._identity = identity
self._client_agent_address = client_agent_address
self._scheduler_address = scheduler_address
self._context = context
self._session = session

self._future_manager = future_manager

self._connector_internal = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
session=self._session,
name="client_agent_internal",
socket_type=zmq.PAIR,
type_=ConnectorType.Pair,
bind_or_connect="bind",
address=self._client_agent_address,
callback=self.__on_receive_from_client,
identity=None,
)

print("client agent internal connected")

self._connector_external = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
session=self._session,
name="client_agent_external",
socket_type=zmq.DEALER,
address=self._scheduler_address,
type_=ConnectorType.Dealer,
bind_or_connect="connect",
address=self._scheduler_address,
callback=self.__on_receive_from_scheduler,
identity=self._identity,
)
Expand Down Expand Up @@ -191,7 +192,7 @@ async def __get_loops(self):
logging.info("ClientAgent: client quitting")
self._future_manager.set_all_futures_with_exception(exception)
elif isinstance(exception, TimeoutError):
logging.error(f"ClientAgent: client timeout when connecting to {self._scheduler_address.to_address()}")
logging.error(f"ClientAgent: client timeout when connecting to {self._scheduler_address}")
self._future_manager.set_all_futures_with_exception(exception)
else:
raise exception
34 changes: 17 additions & 17 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from collections import Counter
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import zmq
import zmq.asyncio
import random

from scaler.client.agent.client_agent import ClientAgent
from scaler.client.agent.future_manager import ClientFutureManager
Expand All @@ -26,9 +24,10 @@
from scaler.utility.graph.topological_sorter import TopologicalSorter
from scaler.utility.metadata.profile_result import ProfileResult
from scaler.utility.metadata.task_flags import TaskFlags, retrieve_task_flags_from_task
from scaler.utility.zmq_config import ZMQConfig, ZMQType
from scaler.worker.agent.processor.processor import Processor

from scaler.io.model import Session, ConnectorType, TCPAddress, IntraProcessAddress


@dataclasses.dataclass
class _CallNode:
Expand Down Expand Up @@ -83,23 +82,20 @@ def __initialize__(
self._profiling = profiling
self._identity = f"{os.getpid()}|Client|{uuid.uuid4().bytes.hex()}".encode()

self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}")
self._scheduler_address = ZMQConfig.from_string(address)
self._client_agent_address = TCPAddress.localhost(random.randint(10000, 20000)) #InprocAddr(f"scaler_client_{uuid.uuid4().hex}")
self._scheduler_address = TCPAddress.from_str(address)
self._timeout_seconds = timeout_seconds
self._heartbeat_interval_seconds = heartbeat_interval_seconds

self._stop_event = threading.Event()
self._context = zmq.Context()
self._connector = SyncConnector(
context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
)
self._session = Session(2)

self._future_manager = ClientFutureManager(self._serializer)
self._agent = ClientAgent(
identity=self._identity,
client_agent_address=self._client_agent_address,
scheduler_address=ZMQConfig.from_string(address),
context=self._context,
scheduler_address=self._scheduler_address,
session=self._session,
future_manager=self._future_manager,
stop_event=self._stop_event,
timeout_seconds=self._timeout_seconds,
Expand All @@ -108,7 +104,11 @@ def __initialize__(
)
self._agent.start()

logging.info(f"ScalerClient: connect to {self._scheduler_address.to_address()}")
self._connector = SyncConnector(
session=self._session, type_=ConnectorType.Pair, address=self._client_agent_address, identity=self._identity
)

logging.info(f"ScalerClient: connect to {self._scheduler_address}")

self._object_buffer = ObjectBuffer(self._identity, self._serializer, self._connector)
self._future_factory = functools.partial(ScalerFuture, connector=self._connector)
Expand Down Expand Up @@ -158,7 +158,7 @@ def fibonacci(client: Client, n: int):
"""

return {
"address": self._scheduler_address.to_address(),
"address": str(self._scheduler_address),
"profiling": self._profiling,
"timeout_seconds": self._timeout_seconds,
"heartbeat_interval_seconds": self._heartbeat_interval_seconds,
Expand Down Expand Up @@ -326,7 +326,7 @@ def disconnect(self):
self.__destroy()
return

logging.info(f"ScalerClient: disconnect from {self._scheduler_address.to_address()}")
logging.info(f"ScalerClient: disconnect from {self._scheduler_address}")

self._future_manager.cancel_all_futures()

Expand All @@ -353,7 +353,7 @@ def shutdown(self):
self.__destroy()
return

logging.info(f"ScalerClient: request shutdown for {self._scheduler_address.to_address()}")
logging.info(f"ScalerClient: request shutdown for {self._scheduler_address}")

self._future_manager.cancel_all_futures()

Expand Down Expand Up @@ -545,7 +545,7 @@ def __assert_client_not_stopped(self):

def __destroy(self):
self._agent.join()
self._context.destroy(linger=1)
# self._context.destroy(linger=1)

@staticmethod
def __get_parent_task_priority() -> Optional[int]:
Expand Down
7 changes: 4 additions & 3 deletions scaler/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from typing import List, Optional, Tuple

from scaler.utility.logging.utility import setup_logger
from scaler.utility.zmq_config import ZMQConfig
from scaler.worker.worker import Worker

from scaler.io.model import TCPAddress


class Cluster(multiprocessing.get_context("spawn").Process): # type: ignore[misc]
def __init__(
self,
address: ZMQConfig,
address: TCPAddress,
worker_io_threads: int,
worker_names: List[str],
heartbeat_interval_seconds: int,
Expand Down Expand Up @@ -55,7 +56,7 @@ def __destroy(self, *args):
logging.info(f"{self.__get_prefix()} received signal, shutting down")
for worker in self._workers:
logging.info(f"{self.__get_prefix()} shutting down worker[{worker.pid}]")
os.kill(worker.pid, signal.SIGINT)
worker.terminate()

def __register_signal(self):
signal.signal(signal.SIGINT, self.__destroy)
Expand Down
6 changes: 3 additions & 3 deletions scaler/cluster/combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
DEFAULT_WORKER_DEATH_TIMEOUT,
DEFAULT_WORKER_TIMEOUT_SECONDS,
)
from scaler.utility.zmq_config import ZMQConfig

from scaler.io.model import TCPAddress

class SchedulerClusterCombo:
def __init__(
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
logging_config_file: Optional[str] = None,
):
self._cluster = Cluster(
address=ZMQConfig.from_string(address),
address=TCPAddress.from_str(address),
worker_io_threads=worker_io_threads,
worker_names=[f"{socket.gethostname().split('.')[0]}_{i}" for i in range(n_workers)],
heartbeat_interval_seconds=heartbeat_interval_seconds,
Expand All @@ -65,7 +65,7 @@ def __init__(
logging_config_file=logging_config_file,
)
self._scheduler = SchedulerProcess(
address=ZMQConfig.from_string(address),
address=TCPAddress.from_str(address),
io_threads=scheduler_io_threads,
max_number_of_tasks_waiting=max_number_of_tasks_waiting,
per_worker_queue_size=per_worker_queue_size,
Expand Down
5 changes: 3 additions & 2 deletions scaler/cluster/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from scaler.scheduler.scheduler import Scheduler, scheduler_main
from scaler.utility.event_loop import register_event_loop
from scaler.utility.logging.utility import setup_logger
from scaler.utility.zmq_config import ZMQConfig

from scaler.io.model import TCPAddress


class SchedulerProcess(multiprocessing.get_context("spawn").Process): # type: ignore[misc]
def __init__(
self,
address: ZMQConfig,
address: TCPAddress,
io_threads: int,
max_number_of_tasks_waiting: int,
per_worker_queue_size: int,
Expand Down
8 changes: 4 additions & 4 deletions scaler/entry_points/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DEFAULT_WORKER_DEATH_TIMEOUT,
)
from scaler.utility.event_loop import EventLoopType, register_event_loop
from scaler.utility.zmq_config import ZMQConfig
from scaler.io.model import TCPAddress


def get_args():
Expand Down Expand Up @@ -76,10 +76,10 @@ def get_args():
"When set, suspends worker processors using the SIGTSTP signal instead of a synchronization event, "
"fully halting computation on suspended tasks. Note that this may cause some tasks to fail if they "
"do not support being paused at the OS level (e.g. tasks requiring active network connections)."
),
)
)
parser.add_argument(
"--log-hub-address", "-la", default=None, type=ZMQConfig.from_string, help="address for Worker send logs"
"--log-hub-address", "-la", default=None, type=TCPAddress.from_str, help="address for Worker send logs"
)
parser.add_argument(
"--logging-paths",
Expand All @@ -105,7 +105,7 @@ def get_args():
help="use standard python the .conf file the specify python logging file configuration format, this will "
"bypass --logging-paths and --logging-level at the same time, and this will not work on per worker logging",
)
parser.add_argument("address", type=ZMQConfig.from_string, help="scheduler address to connect to")
parser.add_argument("address", type=TCPAddress.from_str, help="scheduler address to connect to")
return parser.parse_args()


Expand Down
7 changes: 4 additions & 3 deletions scaler/entry_points/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from scaler.scheduler.scheduler import scheduler_main
from scaler.utility.event_loop import EventLoopType, register_event_loop
from scaler.utility.logging.utility import setup_logger
from scaler.utility.zmq_config import ZMQConfig

from scaler.io.model import TCPAddress


def get_args():
parser = argparse.ArgumentParser("scaler scheduler", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--io-threads", type=int, default=DEFAULT_IO_THREADS, help="number of io threads for zmq")
parser.add_argument("--io-threads", type=int, default=DEFAULT_IO_THREADS, help="number of io threads")
parser.add_argument(
"--max-number-of-tasks-waiting",
"-mt",
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_args():
help="use standard python the .conf file the specify python logging file configuration format, this will "
"bypass --logging-path",
)
parser.add_argument("address", type=ZMQConfig.from_string, help="scheduler address to connect to")
parser.add_argument("address", type=TCPAddress.from_str, help="scheduler address to connect to")
return parser.parse_args()


Expand Down
1 change: 0 additions & 1 deletion scaler/entry_points/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
format_percentage,
format_seconds,
)
from scaler.utility.zmq_config import ZMQConfig

SORT_BY_OPTIONS = {
ord("n"): "worker",
Expand Down
Loading