diff --git a/tests/test_controller.py b/tests/test_controller.py index 1bb2dd9..4a6a591 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1,5 +1,6 @@ import asyncio import logging +import weakref import pytest @@ -76,6 +77,38 @@ def fn_3(): controller.trigger = fn +def test_weakrefs(controller): + class Obj: + method_call_count = 0 + destructor_call_count = 0 + + def __del__(self): + Obj.destructor_call_count += 1 + + def fn(self): + Obj.method_call_count += 1 + print("Obj.fn called") + return 1 + + o = Obj() + + controller.func.add(weakref.WeakMethod(o.fn)) + + @controller.add("func") + def fn_1(): + return 1.5 + + controller.func() + assert Obj.method_call_count == 1 + + del o + + assert Obj.destructor_call_count == 1 + + controller.func() + assert Obj.method_call_count == 1 + + @pytest.mark.asyncio async def test_tasks(controller): @controller.add("async_fn") diff --git a/tests/test_state.py b/tests/test_state.py index e2132e1..6220618 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,4 +1,5 @@ import asyncio +import weakref import pytest @@ -339,3 +340,38 @@ def trigger_side_effect(**_): print(result) assert expected == result + + +def test_weakref(): + server = FakeServer() + state = State(commit_fn=server._push_state, hot_reload=True) + state.ready() + + class Obj: + method_call_count = 0 + destructor_call_count = 0 + + def __del__(self): + Obj.destructor_call_count += 1 + + def fn(self, *_args, **_kwargs): + Obj.method_call_count += 1 + print("Obj.fn called") + return 1 + + o = Obj() + + state.a = 1 + + state.change("a")(weakref.WeakMethod(o.fn)) + + state.a = 2 + state.flush() + assert Obj.method_call_count == 1 + + del o + assert Obj.destructor_call_count == 1 + + state.a = 3 + state.flush() + assert Obj.method_call_count == 1 diff --git a/trame_server/controller.py b/trame_server/controller.py index ebc47a7..1b18602 100644 --- a/trame_server/controller.py +++ b/trame_server/controller.py @@ -1,4 +1,5 @@ import logging +import weakref from .utils import asynchronous, is_dunder, share from .utils.hot_reload import reload @@ -7,6 +8,14 @@ logger = logging.getLogger(__name__) +def _safe_call(f, *args, **kwargs): + return ( + f() and f()(*args, **kwargs) + if isinstance(f, weakref.WeakMethod) + else f(*args, **kwargs) + ) + + class TriggerCounter: def __init__(self, init=0): self._count = init @@ -333,17 +342,19 @@ def __call__(self, *args, **kwargs): else: f = self.func - result = f(*args, **kwargs) + result = _safe_call(f, *args, **kwargs) if self.hot_reload: copy_list = list(map(reload, copy_list)) # Exec added fn after - results = [f(*args, **kwargs) for f in copy_list] + results = [_safe_call(f, *args, **kwargs) for f in copy_list] # Schedule any task for task_fn in list(self.task_funcs): - results.append(asynchronous.create_task(task_fn(*args, **kwargs))) + results.append( + asynchronous.create_task(_safe_call(task_fn, *args, **kwargs)) + ) # Figure out return if self.func is None: diff --git a/trame_server/state.py b/trame_server/state.py index 67d1465..51c742f 100644 --- a/trame_server/state.py +++ b/trame_server/state.py @@ -1,5 +1,6 @@ import inspect import logging +import weakref from .utils import asynchronous, is_dunder, is_private, share from .utils.hot_reload import reload @@ -162,7 +163,10 @@ def has(self, key): return result def setdefault(self, key, value): - """Set an initial value if the key is not present yet""" + """ + Set an initial value if the key is not present yet + :returns the value in the state for the given key + """ key = self._translator.translate_key(key) if key in self._pushed_state: return self._pushed_state[key] @@ -279,7 +283,13 @@ def flush(self): # Execute state listeners self._state_listeners.add_all(_keys) for fn in self._state_listeners: - callback = fn + if isinstance(fn, weakref.WeakMethod): + callback = fn() + if callback is None: + continue + else: + callback = fn + if self._hot_reload: if not inspect.iscoroutinefunction(callback): callback = reload(callback)