Skip to content

Commit f750d93

Browse files
committed
feat(weakref): support weakref.WeakMethod in state.change and ctrl
1 parent 0866ed1 commit f750d93

File tree

4 files changed

+91
-5
lines changed

4 files changed

+91
-5
lines changed

tests/test_controller.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import weakref
34

45
import pytest
56

@@ -76,6 +77,38 @@ def fn_3():
7677
controller.trigger = fn
7778

7879

80+
def test_weakrefs(controller):
81+
82+
class Obj:
83+
method_call_count = 0
84+
destructor_call_count = 0
85+
86+
def __del__(self):
87+
Obj.destructor_call_count += 1
88+
89+
def fn(self):
90+
Obj.method_call_count += 1
91+
print('Obj.fn called')
92+
return 1
93+
o = Obj()
94+
95+
controller.func.add(weakref.WeakMethod(o.fn))
96+
97+
@controller.add("func")
98+
def fn_1():
99+
return 1.5
100+
101+
controller.func()
102+
assert Obj.method_call_count == 1
103+
104+
del o
105+
106+
assert Obj.destructor_call_count == 1
107+
108+
controller.func()
109+
assert Obj.method_call_count == 1
110+
111+
79112
@pytest.mark.asyncio
80113
async def test_tasks(controller):
81114
@controller.add("async_fn")

tests/test_state.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import weakref
23

34
import pytest
45

@@ -339,3 +340,37 @@ def trigger_side_effect(**_):
339340
print(result)
340341

341342
assert expected == result
343+
344+
345+
def test_weakref():
346+
server = FakeServer()
347+
state = State(commit_fn=server._push_state, hot_reload=True)
348+
state.ready()
349+
350+
class Obj:
351+
method_call_count = 0
352+
destructor_call_count = 0
353+
354+
def __del__(self):
355+
Obj.destructor_call_count += 1
356+
357+
def fn(self, *_args, **_kwargs):
358+
Obj.method_call_count += 1
359+
print('Obj.fn called')
360+
return 1
361+
o = Obj()
362+
363+
state.a = 1
364+
365+
state.change("a")(weakref.WeakMethod(o.fn))
366+
367+
state.a = 2
368+
state.flush()
369+
assert Obj.method_call_count == 1
370+
371+
del o
372+
assert Obj.destructor_call_count == 1
373+
374+
state.a = 3
375+
state.flush()
376+
assert Obj.method_call_count == 1

trame_server/controller.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import weakref
23

34
from .utils import asynchronous, is_dunder, share
45
from .utils.hot_reload import reload
@@ -7,6 +8,12 @@
78
logger = logging.getLogger(__name__)
89

910

11+
def _safe_call(f, *args, **kwargs):
12+
return (f() and f()(*args, **kwargs)
13+
if isinstance(f, weakref.WeakMethod)
14+
else f(*args, **kwargs))
15+
16+
1017
class TriggerCounter:
1118
def __init__(self, init=0):
1219
self._count = init
@@ -333,17 +340,18 @@ def __call__(self, *args, **kwargs):
333340
else:
334341
f = self.func
335342

336-
result = f(*args, **kwargs)
343+
result = _safe_call(f, *args, **kwargs)
337344

338345
if self.hot_reload:
339346
copy_list = list(map(reload, copy_list))
340347

341348
# Exec added fn after
342-
results = [f(*args, **kwargs) for f in copy_list]
349+
results = [_safe_call(f, *args, **kwargs) for f in copy_list]
343350

344351
# Schedule any task
345352
for task_fn in list(self.task_funcs):
346-
results.append(asynchronous.create_task(task_fn(*args, **kwargs)))
353+
results.append(asynchronous.create_task(
354+
_safe_call(task_fn, *args, **kwargs)))
347355

348356
# Figure out return
349357
if self.func is None:

trame_server/state.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import logging
3+
import weakref
34

45
from .utils import asynchronous, is_dunder, is_private, share
56
from .utils.hot_reload import reload
@@ -162,7 +163,10 @@ def has(self, key):
162163
return result
163164

164165
def setdefault(self, key, value):
165-
"""Set an initial value if the key is not present yet"""
166+
"""
167+
Set an initial value if the key is not present yet
168+
:returns the value in the state for the given key
169+
"""
166170
key = self._translator.translate_key(key)
167171
if key in self._pushed_state:
168172
return self._pushed_state[key]
@@ -279,7 +283,13 @@ def flush(self):
279283
# Execute state listeners
280284
self._state_listeners.add_all(_keys)
281285
for fn in self._state_listeners:
282-
callback = fn
286+
if isinstance(fn, weakref.WeakMethod):
287+
callback = fn()
288+
if callback is None:
289+
continue
290+
else:
291+
callback = fn
292+
283293
if self._hot_reload:
284294
if not inspect.iscoroutinefunction(callback):
285295
callback = reload(callback)

0 commit comments

Comments
 (0)