diff --git a/examples/child-server/multi_server_router.py b/examples/child-server/multi_server_router.py index a48743f..268bfe0 100644 --- a/examples/child-server/multi_server_router.py +++ b/examples/child-server/multi_server_router.py @@ -9,12 +9,16 @@ class FirstApp(TrameApp): - def __init__(self, server: trame_server.Server | str | None = None) -> None: + def __init__( + self, + server: trame_server.Server | str | None = None, + template_name: str = "main", + ) -> None: super().__init__(server) self.state.test = "first" - self._build_ui() + self._build_ui(template_name) @trigger("test") def test_trigger(self) -> None: @@ -24,21 +28,25 @@ def test_trigger(self) -> None: def test_controller(self) -> None: print(self.state.test) - def _build_ui(self) -> None: - with VAppLayout(self.server, full_height=True), v3.VContainer(), v3.VCard( - title="This is the first app" - ): + def _build_ui(self, template_name) -> None: + with VAppLayout( + self.server, template_name=template_name, full_height=True + ), v3.VContainer(), v3.VCard(title="This is the first app"): v3.VBtn("Test Trigger", click="console.log(test); trame.trigger('test');") v3.VBtn("Test Controller", click=self.ctrl.test_controller) class SecondApp(TrameApp): - def __init__(self, server: trame_server.Server | str | None = None) -> None: + def __init__( + self, + server: trame_server.Server | str | None = None, + template_name: str = "main", + ) -> None: super().__init__(server) self.state.test = "second" - self._build_ui() + self._build_ui(template_name) @trigger("test") def trigger_test(self) -> None: @@ -48,10 +56,10 @@ def trigger_test(self) -> None: def test_controller(self) -> None: print(self.state.test) - def _build_ui(self) -> None: - with VAppLayout(self.server, full_height=True), v3.VContainer(), v3.VCard( - title="This is the second app" - ): + def _build_ui(self, template_name) -> None: + with VAppLayout( + self.server, template_name=template_name, full_height=True + ), v3.VContainer(), v3.VCard(title="This is the second app"): v3.VBtn("Test Trigger", click="console.log(test); trame.trigger('test');") v3.VBtn("Test Controller", click=self.ctrl.test_controller) @@ -70,10 +78,17 @@ def test_trigger(self) -> None: def _build_ui(self) -> None: # Register routes - with RouterViewLayout(self.server, "/"): - FirstApp(self.server.create_child_server(prefix="first_route_")) - with RouterViewLayout(self.server, "/second"): - SecondApp(self.server.create_child_server(prefix="second_route_")) + first_layout = RouterViewLayout(self.server, "/") + second_layout = RouterViewLayout(self.server, "/second") + + FirstApp( + self.server.create_child_server(prefix="first_route_"), + template_name=first_layout.template_name, + ) + SecondApp( + self.server.create_child_server(prefix="second_route_"), + template_name=second_layout.template_name, + ) with SinglePageLayout(self.server, full_height=True) as layout: with layout.toolbar: diff --git a/trame_server/state.py b/trame_server/state.py index e16456c..405c95e 100644 --- a/trame_server/state.py +++ b/trame_server/state.py @@ -1,6 +1,7 @@ import inspect import logging import weakref +from contextlib import contextmanager from .utils import asynchronous, is_dunder, is_private, share from .utils.hot_reload import reload @@ -15,6 +16,32 @@ TRAME_NON_INIT_VALUE = "__trame__: non_init_value_that_is_not_None" +class StateStatus: + """ + Tracks status flags for a State. + """ + + def __init__(self, flushing: bool = False, ready: bool = False): + self.flushing = flushing + self.ready = ready + + def mark_ready(self): + self.ready = True + + @property + def skip_flushing(self) -> bool: + return self.flushing or not self.ready + + @contextmanager + def flushing_context(self): + """Context manager for flushing state safely.""" + self.flushing = True + try: + yield + finally: + self.flushing = False + + class StateChangeHandler: def __init__(self, listeners): self._all_listeners = listeners @@ -67,38 +94,30 @@ def __init__( self._state_listeners = share( internal, "_state_listeners", StateChangeHandler(self._change_callbacks) ) + self._status = share(internal, "_status", StateStatus(ready=ready)) self._parent_state = internal self._children_state = [] - self._ready_flag = ready if internal: internal._children_state.append(self) - def ready(self) -> None: - """Mark the state as ready for synchronization.""" - if self._ready_flag: - return - - self._ready_flag = True - self.flush() - - if self._parent_state: - self._parent_state.ready() - - for child in self._children_state: - child.ready() - @property def is_ready(self) -> bool: """Return True is the instance is ready for synchronization, False otherwise.""" - if self._parent_state: - return self._parent_state.is_ready - return self._ready_flag + return self._status.ready @property def translator(self) -> Translator: """Return the translator instance used to namespace the variable names.""" return self._translator + def ready(self) -> None: + """Mark the state as ready for synchronization.""" + if self.is_ready: + return + + self._status.mark_ready() + self.flush() + def __getitem__(self, key): key = self._translator.translate_key(key) return self._pending_update.get(key, self._pushed_state.get(key)) @@ -267,6 +286,43 @@ def modified_keys(self): # for child server we may need to run the translator on them return self._modified_keys + def _flush_pending_keys(self) -> set[str]: + _keys = set(self._pending_update.keys()) + + # update modified keys for current update batch + self._modified_keys.clear() + self._modified_keys |= _keys + + # Do the flush + if self._push_state_fn: + self._push_state_fn(self._pending_update) + self._pushed_state.update(self._pending_update) + self._pending_update.clear() + + # Execute state listeners + self._state_listeners.add_all(_keys) + for fn, translator in self._state_listeners: + if isinstance(fn, weakref.WeakMethod): + callback = fn() + if callback is None: + continue + else: + callback = fn + + if self._hot_reload: + if not inspect.iscoroutinefunction(callback): + callback = reload(callback) + + reverse_translated_state = translator.reverse_translate_dict( + self._pushed_state + ) + coroutine = callback(**reverse_translated_state) + if inspect.isawaitable(coroutine): + asynchronous.create_task(coroutine) + + self._state_listeners.clear() + return _keys + def flush(self): """ Force pushing modified state and execute any @state.change listener @@ -274,51 +330,13 @@ def flush(self): previous value or if `dirty` has been flagged on the variable and it has not been unflagged since. """ - if not self.is_ready: + if self._status.skip_flushing: return None keys = set() - if len(self._pending_update): - _keys = set(self._pending_update.keys()) - - while len(_keys): - keys |= _keys - - # update modified keys for current update batch - self._modified_keys.clear() - self._modified_keys |= _keys - - # Do the flush - if self._push_state_fn: - self._push_state_fn(self._pending_update) - self._pushed_state.update(self._pending_update) - self._pending_update.clear() - - # Execute state listeners - self._state_listeners.add_all(_keys) - for fn, translator in self._state_listeners: - if isinstance(fn, weakref.WeakMethod): - callback = fn() - if callback is None: - continue - else: - callback = fn - - if self._hot_reload: - if not inspect.iscoroutinefunction(callback): - callback = reload(callback) - - reverse_translated_state = translator.reverse_translate_dict( - self._pushed_state - ) - coroutine = callback(**reverse_translated_state) - if inspect.isawaitable(coroutine): - asynchronous.create_task(coroutine) - - self._state_listeners.clear() - - # Check if state change from state listeners - _keys = set(self._pending_update.keys()) + with self._status.flushing_context(): + while bool(self._pending_update): + keys |= self._flush_pending_keys() return keys