Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion python/rain/capnp

This file was deleted.

210 changes: 108 additions & 102 deletions python/rain/client/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import capnp
import json

from . import rpc
from ..common import RainException, SessionException, TaskException
from ..common.attributes import ObjectInfo, TaskInfo
from ..common.data_instance import DataInstance
from ..common.ids import governor_id_from_capnp, id_from_capnp, id_to_capnp
from .rpc import WsCommunicator, ALL_TASKS_ID
from .data import DataObject
from .session import Session
from .task import Task
from ..common import RainException, SessionException, TaskException
from ..common.attributes import ObjectInfo, TaskInfo
from ..common.data_instance import DataInstance
from ..common.ids import ID

CLIENT_PROTOCOL_VERSION = 1
FETCH_SIZE = 8 << 20 # 8MB


def check_result(sessions, result):
if result.which() == "ok":
if result is None:
return

status = result if isinstance(result, str) else result["status"]

if status == "Ok":
return # Do nothing
elif result.which() == "error":
task_id = id_from_capnp(result.error.task)
elif isinstance(status, list) and status[0] == "Error":
data = status[1]

task_id = ID._from_json(data["task"])
message = []

if task_id.session_id == -1:
Expand All @@ -41,13 +47,13 @@ def check_result(sessions, result):

message.append("Task {} failed".format(task))

message.append("Message: " + result.error.message)
message.append("Message: " + data["message"])

if task:
message.append("Task created at:\n" + task._stack)

if result.error.debug:
message.append("Debug:\n" + result.error.debug)
if data["debug"]:
message.append("Debug:\n" + data["debug"])
message = "\n".join(message)
raise cls(message)
else:
Expand All @@ -61,12 +67,10 @@ class Client:
"""

def __init__(self, address, port):
self._rpc_client = capnp.TwoPartyClient("{}:{}".format(address, port))

bootstrap = self._rpc_client.bootstrap().cast_as(
rpc.server.ServerBootstrap)
registration = bootstrap.registerAsClient(CLIENT_PROTOCOL_VERSION)
self._service = registration.wait().service
self._rpc_client = WsCommunicator(address, port)
self._rpc_client.request("RegisterClient", {
"version": CLIENT_PROTOCOL_VERSION
})

def new_session(self, name="Unnamed Session", default=False):
"""
Expand All @@ -78,7 +82,9 @@ def new_session(self, name="Unnamed Session", default=False):
:class:`Session`: A new session
"""
spec = json.dumps({"name": str(name)})
session_id = self._service.newSession(spec).wait().sessionId
session_id = self._rpc_client.request("NewSession", {
"spec": spec
})["session_id"]
return Session(self, session_id, default)

def get_server_info(self):
Expand All @@ -88,31 +94,30 @@ def get_server_info(self):
Returns:
dict: A JSON-like dictionary.
"""
info = self._service.getServerInfo().wait()
info = self._rpc_client.request("GetServerInfo")["governors"]
return {
"governors": [{"governor_id": governor_id_from_capnp(w.governorId),
"tasks": [id_from_capnp(t) for t in w.tasks],
"objects": [id_from_capnp(o) for o in w.objects],
"objects_to_delete": [id_from_capnp(o) for o in w.objectsToDelete],
"resources": {"cpus": w.resources.nCpus}}
for w in info.governors]
"governors": [{
"governor_id": g["governor_id"],
"tasks": [ID._from_json(id) for id in g["tasks"]],
"objects": [ID._from_json(id) for id in g["objects"]],
"objects_to_delete": [ID._from_json(id) for id in g["objects_to_delete"]],
"resources": g["resources"],
} for g in info]
}

def _submit(self, tasks, dataobjs):
req = self._service.submit_request()

# Serialize tasks print(tasks, dataobjs)

req.init("tasks", len(tasks))
for i in range(len(tasks)):
req.tasks[i].spec = json.dumps(tasks[i].spec._to_json())
# Serialize tasks
tasks_data = [{
"spec": json.dumps(t.spec._to_json())
} for t in tasks]

# Serialize objects
req.init("objects", len(dataobjs))
for i in range(len(dataobjs)):
dataobjs[i]._to_capnp(req.objects[i])
objects_data = [d._to_json() for d in dataobjs]

req.send().wait()
return self._rpc_client.request("Submit", {
"tasks": tasks_data,
"objects": objects_data
})

def _fetch(self, dataobj):
"Fetch the object data and update its state."
Expand All @@ -124,130 +129,131 @@ def _fetch(self, dataobj):
raise RainException(
"Object {} is not submitted.".format(dataobj))

req = self._service.fetch_request()
id_to_capnp(dataobj.id, req.id)
req.offset = 0
req.size = FETCH_SIZE
req.includeInfo = True
result = req.send().wait()
check_result((dataobj._session,), result.status)
msg = {
"id": dataobj.id,
"include_info": True,
"offset": 0,
"size": FETCH_SIZE
}

dataobj._info = ObjectInfo._from_json(json.loads(result.info))
result = self._rpc_client.request("Fetch", msg)
check_result((dataobj._session,), result)

size = result.transportSize
offset = len(result.data)
data = [result.data]
dataobj._info = ObjectInfo._from_json(json.loads(result["info"]))

size = result["transport_size"]
offset = len(result["data"])
data = [bytearray(result["data"])]

while offset < size:
req = self._service.fetch_request()
id_to_capnp(dataobj.id, req.id)
req.offset = offset
req.size = FETCH_SIZE
req.includeInfo = False
r = req.send().wait()
check_result((dataobj._session,), r.status)
data.append(r.data)
offset += len(r.data)
msg = {
"id": dataobj.id,
"include_info": False,
"offset": offset,
"size": FETCH_SIZE
}

result = self._rpc_client.request("Fetch", msg)
check_result((dataobj._session,), result["status"])
data.append(result["data"])
offset += len(result["data"])
rawdata = b"".join(data)

return DataInstance(data=rawdata,
data_object=dataobj,
data_type=dataobj.spec.data_type)

def _wait(self, tasks, dataobjs):
req = self._service.wait_request()

req.init("taskIds", len(tasks))
sessions = []
for i in range(len(tasks)):
task = tasks[i]
if task.state is None:
raise RainException("Task {} is not submitted".format(task))
id_to_capnp(task.id, req.taskIds[i])
sessions.append(task._session)

req.init("objectIds", len(dataobjs))
msg = {
"task_ids": [t.id for t in tasks],
"object_ids": [d.id for d in dataobjs]
}

for i in range(len(dataobjs)):
id_to_capnp(dataobjs[i].id, req.objectIds[i])
sessions.append(dataobjs[i]._session)

result = req.send().wait()
result = self._rpc_client.request("Wait", msg)
check_result(sessions, result)

def _close_session(self, session):
self._service.closeSession(session.session_id).wait()
return self._rpc_client.request("CloseSession", {
"session_id": session.session_id
}, allow_failure=True)

def _wait_some(self, tasks, dataobjs):
req = self._service.waitSome_request()

tasks_dict = {}
req.init("taskIds", len(tasks))
for i in range(len(tasks)):
tasks_dict[tasks[i].id] = tasks[i]
id_to_capnp(tasks[i].id, req.taskIds[i])

dataobjs_dict = {}
req.init("objectIds", len(dataobjs))
for i in range(len(dataobjs)):
dataobjs_dict[dataobjs[i].id] = dataobjs[i]
id_to_capnp(dataobjs[i].id, req.objectIds[i])

finished = req.send().wait()
finished_tasks = [tasks_dict[f_task.id]
for f_task in finished.finishedTasks]
finished_dataobjs = [dataobjs_dict[f_dataobj.id]
for f_dataobj in finished.finishedObjects]
msg = {
"task_ids": [t.id for t in tasks],
"object_ids": [d.id for d in dataobjs]
}

finished = self._rpc_client.request("WaitSome", msg)
finished_tasks = [tasks_dict[ID._from_json(f_task["id"])]
for f_task in finished["finished_tasks"]]
finished_dataobjs = [dataobjs_dict[ID._from_json(f_dataobj["id"])]
for f_dataobj in finished["finished_objects"]]

return finished_tasks, finished_dataobjs

def _wait_all(self, session):
req = self._service.wait_request()
req.init("taskIds", 1)
req.taskIds[0].id = rpc.common.allTasksId
req.taskIds[0].sessionId = session.session_id
result = req.send().wait()
msg = {
"task_ids": [ID(session_id=session.session_id, id=ALL_TASKS_ID)],
"object_ids": []
}

result = self._rpc_client.request("Wait", msg)
check_result((session,), result)

def _unkeep(self, dataobjs):
req = self._service.unkeep_request()

req.init("objectIds", len(dataobjs))
for i in range(len(dataobjs)):
id_to_capnp(dataobjs[i].id, req.objectIds[i])

result = req.send().wait()
result = self._rpc_client.request("Unkeep", {
"object_ids": [d.id for d in dataobjs]
}, allow_failure=True)
check_result([o._session for o in dataobjs], result)

def update(self, items):
tasks, dataobjects = split_items(items)
self._get_state(tasks, dataobjects)

def _get_state(self, tasks, dataobjs):
req = self._service.getState_request()
sessions = []
req.init("taskIds", len(tasks))
for i in range(len(tasks)):
id_to_capnp(tasks[i].id, req.taskIds[i])
sessions.append(tasks[i]._session)

dataobjs_dict = {}
req.init("objectIds", len(dataobjs))
for i in range(len(dataobjs)):
dataobjs_dict[dataobjs[i].id.id] = dataobjs[i]
id_to_capnp(dataobjs[i].id, req.objectIds[i])
sessions.append(dataobjs[i]._session)

results = req.send().wait()
check_result(sessions, results.state)
msg = {
"task_ids": [t.id for t in tasks],
"object_ids": [d.id for d in dataobjs]
}

results = self._rpc_client.request("GetState", msg)["update"]
check_result(sessions, results)

for task_update, task in zip(results.tasks, tasks):
task._state = task_update.state
task._info = TaskInfo._from_json(json.loads(task_update.info))
for task_update, task in zip(results["tasks"], tasks):
task._state = task_update["state"]
task._info = TaskInfo._from_json(json.loads(task_update["info"]))

for object_update in results.objects:
dataobj = dataobjs_dict[object_update.id.id]
dataobj._state = object_update.state
dataobj._info = ObjectInfo._from_json(json.loads(object_update.info))
for object_update in results["objects"]:
dataobj = dataobjs_dict[ID._from_json(object_update["id"]).id]
dataobj._state = object_update["state"]
dataobj._info = ObjectInfo._from_json(json.loads(object_update["info"]))


def split_items(items):
Expand Down
25 changes: 8 additions & 17 deletions python/rain/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import json
import tarfile

import capnp

from ..common import ID, DataType, RainException
from ..common.attributes import ObjectSpec
from ..common.content_type import (check_content_type, encode_value,
Expand Down Expand Up @@ -83,15 +81,13 @@ def is_kept(self):
"""Returns the value of self._keep"""
return self._keep

def _to_capnp(self, out):
out.spec = json.dumps(self._spec._to_json())
out.keep = self._keep

if self._data is not None:
out.data = self._data
out.hasData = True
else:
out.hasData = False
def _to_json(self):
return {
"spec": json.dumps(self._spec._to_json()),
"keep": self._keep,
"has_data": self._data is not None,
"data": b'' if not self._data else self._data
}

def wait(self):
self._session.wait((self,))
Expand All @@ -115,12 +111,7 @@ def update(self):

def __del__(self):
if self.state is not None and self._keep:
try:
self._session.client._unkeep((self,))
except capnp.lib.capnp.KjException:
# Ignore capnp exception, since this constructor may be
# called when connection is closed
pass
self._session.client._unkeep((self,))

def __reduce__(self):
"""Speciaization to replace with executor.unpickle_input_object
Expand Down
Loading