Skip to content

Commit 99b9ed6

Browse files
authored
Merge pull request #134 from stealthrocket/remote-functions
Remote endpoints
2 parents ada9942 + e90a64c commit 99b9ed6

File tree

4 files changed

+34
-25
lines changed

4 files changed

+34
-25
lines changed

src/dispatch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import dispatch.integrations
66
from dispatch.coroutine import all, any, call, gather, race
7-
from dispatch.function import DEFAULT_API_URL, Client
7+
from dispatch.function import DEFAULT_API_URL, Client, Registry
88
from dispatch.id import DispatchID
99
from dispatch.proto import Call, Error, Input, Output
1010
from dispatch.status import Status
@@ -23,4 +23,5 @@
2323
"all",
2424
"any",
2525
"race",
26+
"Registry",
2627
]

src/dispatch/fastapi.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ def __init__(
115115
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
116116
)
117117

118-
self.client = Client(api_key=api_key, api_url=api_url)
119-
super().__init__(endpoint, self.client)
118+
super().__init__(endpoint, api_key=api_key, api_url=api_url)
120119

121120
function_service = _new_app(self, verification_key)
122121
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
@@ -225,7 +224,7 @@ async def execute(request: fastapi.Request):
225224
raise _ConnectError(400, "invalid_argument", "function is required")
226225

227226
try:
228-
func = function_registry._functions[req.function]
227+
func = function_registry.functions[req.function]
229228
except KeyError:
230229
logger.debug("function '%s' not found", req.function)
231230
raise _ConnectError(

src/dispatch/function.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def __init__(
100100
client: Client,
101101
name: str,
102102
primitive_func: PrimitiveFunctionType,
103-
func: Callable,
104103
):
105104
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)
106105

@@ -158,21 +157,30 @@ def build_call(
158157

159158

160159
class Registry:
161-
"""Registry of local functions."""
160+
"""Registry of functions."""
162161

163-
__slots__ = ("_functions", "_endpoint", "_client")
162+
__slots__ = ("functions", "endpoint", "client")
164163

165-
def __init__(self, endpoint: str, client: Client):
166-
"""Initialize a local function registry.
164+
def __init__(
165+
self, endpoint: str, api_key: str | None = None, api_url: str | None = None
166+
):
167+
"""Initialize a function registry.
167168
168169
Args:
169170
endpoint: URL of the endpoint that the function is accessible from.
170-
client: Client for the Dispatch API. Used to dispatch calls to
171-
local functions.
171+
172+
api_key: Dispatch API key to use for authentication when
173+
dispatching calls to functions. Uses the value of the
174+
DISPATCH_API_KEY environment variable by default.
175+
176+
api_url: The URL of the Dispatch API to use when dispatching calls
177+
to functions. Uses the value of the DISPATCH_API_URL environment
178+
variable if set, otherwise defaults to the public Dispatch API
179+
(DEFAULT_API_URL).
172180
"""
173-
self._functions: Dict[str, PrimitiveFunction] = {}
174-
self._endpoint = endpoint
175-
self._client = client
181+
self.functions: Dict[str, PrimitiveFunction] = {}
182+
self.endpoint = endpoint
183+
self.client = Client(api_key=api_key, api_url=api_url)
176184

177185
@overload
178186
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
@@ -215,9 +223,7 @@ def primitive_func(input: Input) -> Output:
215223
primitive_func.__qualname__ = f"{name}_primitive"
216224
primitive_func = durable(primitive_func)
217225

218-
wrapped_func = Function[P, T](
219-
self._endpoint, self._client, name, primitive_func, func
220-
)
226+
wrapped_func = Function[P, T](self.endpoint, self.client, name, primitive_func)
221227
self._register(name, wrapped_func)
222228
return wrapped_func
223229

@@ -228,20 +234,20 @@ def primitive_function(
228234
name = primitive_func.__qualname__
229235
logger.info("registering primitive function: %s", name)
230236
wrapped_func = PrimitiveFunction(
231-
self._endpoint, self._client, name, primitive_func
237+
self.endpoint, self.client, name, primitive_func
232238
)
233239
self._register(name, wrapped_func)
234240
return wrapped_func
235241

236242
def _register(self, name: str, wrapped_func: PrimitiveFunction):
237-
if name in self._functions:
243+
if name in self.functions:
238244
raise ValueError(f"function already registered with name '{name}'")
239-
self._functions[name] = wrapped_func
245+
self.functions[name] = wrapped_func
240246

241247
def set_client(self, client: Client):
242-
"""Set the Client instance used to dispatch calls to local functions."""
243-
self._client = client
244-
for fn in self._functions.values():
248+
"""Set the Client instance used to dispatch calls to registered functions."""
249+
self.client = client
250+
for fn in self.functions.values():
245251
fn._client = client
246252

247253

tests/dispatch/test_function.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
class TestFunction(unittest.TestCase):
88
def setUp(self):
9-
self.client = Client(api_url="http://dispatch.com", api_key="foobar")
10-
self.dispatch = Registry(endpoint="http://example.com", client=self.client)
9+
self.dispatch = Registry(
10+
endpoint="http://example.com",
11+
api_url="http://dispatch.com",
12+
api_key="foobar",
13+
)
1114

1215
def test_serializable(self):
1316
@self.dispatch.function

0 commit comments

Comments
 (0)