Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import weakref

import pytest

Expand Down Expand Up @@ -76,6 +77,38 @@ def fn_3():
controller.trigger = fn


def test_weakrefs(controller):
class Obj:
method_call_count = 0
destructor_call_count = 0

def __del__(self):
Obj.destructor_call_count += 1

def fn(self):
Obj.method_call_count += 1
print("Obj.fn called")
return 1

o = Obj()

controller.func.add(weakref.WeakMethod(o.fn))

@controller.add("func")
def fn_1():
return 1.5

controller.func()
assert Obj.method_call_count == 1

del o

assert Obj.destructor_call_count == 1

controller.func()
assert Obj.method_call_count == 1


@pytest.mark.asyncio
async def test_tasks(controller):
@controller.add("async_fn")
Expand Down
36 changes: 36 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import weakref

import pytest

Expand Down Expand Up @@ -339,3 +340,38 @@ def trigger_side_effect(**_):
print(result)

assert expected == result


def test_weakref():
server = FakeServer()
state = State(commit_fn=server._push_state, hot_reload=True)
state.ready()

class Obj:
method_call_count = 0
destructor_call_count = 0

def __del__(self):
Obj.destructor_call_count += 1

def fn(self, *_args, **_kwargs):
Obj.method_call_count += 1
print("Obj.fn called")
return 1

o = Obj()

state.a = 1

state.change("a")(weakref.WeakMethod(o.fn))

state.a = 2
state.flush()
assert Obj.method_call_count == 1

del o
assert Obj.destructor_call_count == 1

state.a = 3
state.flush()
assert Obj.method_call_count == 1
17 changes: 14 additions & 3 deletions trame_server/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import weakref

from .utils import asynchronous, is_dunder, share
from .utils.hot_reload import reload
Expand All @@ -7,6 +8,14 @@
logger = logging.getLogger(__name__)


def _safe_call(f, *args, **kwargs):
return (
f() and f()(*args, **kwargs)
if isinstance(f, weakref.WeakMethod)
else f(*args, **kwargs)
)


class TriggerCounter:
def __init__(self, init=0):
self._count = init
Expand Down Expand Up @@ -333,17 +342,19 @@ def __call__(self, *args, **kwargs):
else:
f = self.func

result = f(*args, **kwargs)
result = _safe_call(f, *args, **kwargs)

if self.hot_reload:
copy_list = list(map(reload, copy_list))

# Exec added fn after
results = [f(*args, **kwargs) for f in copy_list]
results = [_safe_call(f, *args, **kwargs) for f in copy_list]

# Schedule any task
for task_fn in list(self.task_funcs):
results.append(asynchronous.create_task(task_fn(*args, **kwargs)))
results.append(
asynchronous.create_task(_safe_call(task_fn, *args, **kwargs))
)

# Figure out return
if self.func is None:
Expand Down
14 changes: 12 additions & 2 deletions trame_server/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import logging
import weakref

from .utils import asynchronous, is_dunder, is_private, share
from .utils.hot_reload import reload
Expand Down Expand Up @@ -162,7 +163,10 @@ def has(self, key):
return result

def setdefault(self, key, value):
"""Set an initial value if the key is not present yet"""
"""
Set an initial value if the key is not present yet
:returns the value in the state for the given key
"""
key = self._translator.translate_key(key)
if key in self._pushed_state:
return self._pushed_state[key]
Expand Down Expand Up @@ -279,7 +283,13 @@ def flush(self):
# Execute state listeners
self._state_listeners.add_all(_keys)
for fn in self._state_listeners:
callback = fn
if isinstance(fn, weakref.WeakMethod):
callback = fn()
if callback is None:
continue
else:
callback = fn

if self._hot_reload:
if not inspect.iscoroutinefunction(callback):
callback = reload(callback)
Expand Down