Skip to content

Commit 03cd247

Browse files
committed
Test decorated actions
Using a decorator on an action will result in a callable that doesn't have a signature, and instead has a __wrapped__ attribute. These tests check that doesn't cause problems with LabThings. test_locking_decorator.py is a more realistic test of decorators that anticipates their use to lock functions.
1 parent 46ff8e6 commit 03cd247

File tree

2 files changed

+195
-0
lines changed

2 files changed

+195
-0
lines changed

tests/test_actions.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi.testclient import TestClient
22
import pytest
3+
import functools
34
from .temp_client import poll_task, get_link
45
from labthings_fastapi.example_things import MyThing
56
import labthings_fastapi as lt
@@ -94,3 +95,61 @@ def test_openapi():
9495
with TestClient(server.app) as client:
9596
r = client.get("/openapi.json")
9697
r.raise_for_status()
98+
99+
100+
def example_decorator(func):
101+
"""Decorate a function using functools.wraps."""
102+
103+
@functools.wraps(func)
104+
def action_wrapper(*args, **kwargs):
105+
result = func(*args, **kwargs)
106+
return result
107+
108+
return action_wrapper
109+
110+
111+
def assert_input_models_equivalent(model_a, model_b):
112+
"""Check two basemodels are equivalent."""
113+
keys = list(model_a.model_fields.keys())
114+
assert list(model_b.model_fields.keys()) == keys
115+
116+
for k in keys:
117+
field_a = model_a.model_fields[k]
118+
field_b = model_b.model_fields[k]
119+
assert str(field_a.annotation) == str(field_b.annotation)
120+
assert field_a.default == field_b.default
121+
122+
123+
def test_wrapped_action():
124+
"""Check functools.wraps does not confuse schema generation"""
125+
126+
class Example(lt.Thing):
127+
@lt.thing_action
128+
def action(
129+
self,
130+
portal: lt.deps.BlockingPortal,
131+
param1: int = 0,
132+
param2: str = "string",
133+
) -> float | None:
134+
"""An example action with type annotations."""
135+
return 0.5
136+
137+
@lt.thing_action
138+
@example_decorator
139+
def decorated(
140+
self,
141+
portal: lt.deps.BlockingPortal,
142+
param1: int = 0,
143+
param2: str = "string",
144+
) -> float | None:
145+
"""An example decorated action with type annotations."""
146+
return 0.5
147+
148+
assert_input_models_equivalent(
149+
Example.action.input_model, Example.decorated.input_model
150+
)
151+
assert Example.action.output_model == Example.decorated.output_model
152+
153+
# Check we can make the thing and it has a valid TD
154+
example = Example()
155+
example.validate_thing_description()

tests/test_locking_decorator.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import time
2+
from typing import Callable, TypeVar, ParamSpec
3+
import functools
4+
from threading import RLock, Event, Thread
5+
6+
from fastapi.testclient import TestClient
7+
import pytest
8+
9+
import labthings_fastapi as lt
10+
from .temp_client import poll_task
11+
12+
13+
Value = TypeVar("Value")
14+
Params = ParamSpec("Params")
15+
16+
17+
def requires_lock(func: Callable[Params, Value]) -> Callable[Params, Value]:
18+
"""Decorate an action to require a lock."""
19+
20+
@functools.wraps(func)
21+
def locked_func(*args, **kwargs):
22+
lock: RLock = args[0]._lock
23+
if not lock.acquire(blocking=False):
24+
raise TimeoutError("Could not lock.")
25+
try:
26+
return func(*args, **kwargs)
27+
finally:
28+
lock.release()
29+
30+
return locked_func
31+
32+
33+
class LockedExample(lt.Thing):
34+
"""A Thing where only one action may happen at a time."""
35+
36+
flag: bool = lt.property(default=False)
37+
38+
def __init__(self):
39+
"""Initialise the lock."""
40+
self._lock = RLock() # This lock is used by @requires_lock
41+
self._event = Event() # This is used to keep tests quick
42+
# by stopping waits as soon as they are no longer needed
43+
44+
@lt.thing_action
45+
@requires_lock
46+
def wait_wrapper(self, time: float = 1) -> None:
47+
"""Wait a specified time, calling wait_with_flag.
48+
49+
This lets us check the RLock correctly allows one locked
50+
function to call another.
51+
"""
52+
self.wait_with_flag(time)
53+
54+
@lt.thing_action
55+
@requires_lock
56+
def echo(self, message: str) -> str:
57+
"""Echo a message back to the sender."""
58+
return message
59+
60+
@lt.thing_action
61+
@requires_lock
62+
def wait_with_flag(self, time: float = 1) -> None:
63+
"""Wait a specified time with the flag True."""
64+
assert self.flag is False
65+
self.flag = True
66+
self._event.wait(time)
67+
self.flag = False
68+
69+
70+
@pytest.fixture
71+
def thing(mocker) -> LockedExample:
72+
"""Instantiate the LockedExample thing."""
73+
thing = LockedExample()
74+
thing._labthings_blocking_portal = mocker.Mock()
75+
return thing
76+
77+
78+
def test_echo(thing: LockedExample) -> None:
79+
"""Check the example function works.
80+
81+
Having this in a test function means if it raises an
82+
exception, we can be sure it's because of the lock, and
83+
not just a typo in the test.
84+
"""
85+
assert thing.echo("test") == "test"
86+
87+
88+
def wait_for_flag(thing: LockedExample) -> None:
89+
"""Wait until the flag is set, so we know the lock is acquired."""
90+
while not thing.flag:
91+
time.sleep(0.001)
92+
93+
94+
def test_locking(thing: LockedExample) -> None:
95+
"""Check the lock prevents concurrent access."""
96+
thread = Thread(target=thing.wait_wrapper)
97+
thread.start()
98+
wait_for_flag(thing)
99+
with pytest.raises(TimeoutError):
100+
# This should fail because the lock is acquired
101+
test_echo(thing)
102+
thing._event.set() # tell the thread to stop
103+
thread.join()
104+
# Check the lock is now released - other actions should work
105+
test_echo(thing)
106+
107+
108+
def echo_via_client(client):
109+
"""Use a POST request to run the echo action."""
110+
r = client.post("/thing/echo", json={"message": "test"})
111+
r.raise_for_status()
112+
return poll_task(client, r.json())
113+
114+
115+
def test_locking_in_server():
116+
"""Check the lock works within LabThings."""
117+
server = lt.ThingServer()
118+
thing = LockedExample()
119+
server.add_thing(thing, "/thing")
120+
with TestClient(server.app) as client:
121+
# Start a long task
122+
r1 = client.post("/thing/wait_wrapper", json={})
123+
# Wait for it to start
124+
while client.get("/thing/flag").json() is not True:
125+
time.sleep(0.01)
126+
# Try another action and check it fails
127+
inv2 = echo_via_client(client)
128+
assert inv2["status"] == "error"
129+
# Instruct the first task to stop
130+
thing._event.set() # stop the first action
131+
inv1 = poll_task(client, r1.json()) # wait for it to complete
132+
assert inv1["status"] == "completed" # check there's no error
133+
# This action should succeed now
134+
inv3 = echo_via_client(client)
135+
assert inv3["status"] == "completed"
136+
assert inv3["output"] == "test"

0 commit comments

Comments
 (0)