Skip to content

Commit 1a197d9

Browse files
committed
refactor(core): move ButtonRequest-related code into ButtonRequestHandler
Also, minimize context-related access in `ui.Layout`. [no changelog]
1 parent 195382b commit 1a197d9

2 files changed

Lines changed: 105 additions & 98 deletions

File tree

core/src/trezor/ui/__init__.py

Lines changed: 44 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from trezor import io, log, loop, utils, wire, workflow
88
from trezor.messages import ButtonRequest
99
from trezor.wire import context
10-
from trezor.wire.protocol_common import Context
10+
from trezor.wire.protocol_common import ButtonRequestHandler, Context
1111
from trezorui_api import (
1212
AttachType,
1313
BacklightLevels,
@@ -146,24 +146,28 @@ def callback(*args: str) -> None:
146146
def __str__(self) -> str:
147147
return f"{repr(self)}({self._trace(self.layout)[:150]})"
148148

149+
def is_layout_attached(self) -> bool:
150+
return self._is_attached
151+
149152
def __init__(self, layout: LayoutObj[T]) -> None:
150153
"""Set up a layout."""
151154
self.layout = layout
152155
self.tasks: set[loop.Task[None]] = set()
153156
self.timers: dict[int, loop.Task[None]] = {}
154157
self.result_box: loop.mailbox[Any] = loop.mailbox()
155-
self.button_request_ack_pending: bool = False
156-
self.button_request_box: loop.mailbox[ButtonRequest | None] = loop.mailbox()
158+
self.button_request_handler: ButtonRequestHandler | None = None
157159
self.button_request_task: loop.Task[None] | None = None
158160
self.transition_out: AttachType | None = None
159161
self.backlight_level = BacklightLevels.NORMAL
160162
self.context: Context | None = None
161-
self.state: LayoutState = LayoutState.INITIAL
162163

163164
# Indicates whether we should use Resume attach style when launching.
164165
# Homescreen layouts can override this.
165166
self.should_resume = False
166167

168+
if __debug__:
169+
self._is_attached: bool = False
170+
167171
def is_ready(self) -> bool:
168172
"""True if the layout is in READY state."""
169173
return CURRENT_LAYOUT is not self and self.result_box.is_empty()
@@ -176,9 +180,6 @@ def is_finished(self) -> bool:
176180
"""True if the layout is in FINISHED state."""
177181
return CURRENT_LAYOUT is not self and not self.result_box.is_empty()
178182

179-
def is_layout_attached(self) -> bool:
180-
return self.state is LayoutState.ATTACHED
181-
182183
def start(self) -> None:
183184
"""Start the layout, stopping any other RUNNING layout.
184185
@@ -206,8 +207,7 @@ def start(self) -> None:
206207
set_current_layout(self)
207208

208209
try:
209-
# save context (if exists)
210-
self.context = context.get_context()
210+
self.button_request_handler = context.get_context().button_request_handler
211211
except context.NoWireContext:
212212
pass
213213

@@ -251,65 +251,44 @@ def stop(self, _close_all: bool = True) -> None:
251251
# shut down anyone who is waiting for the result
252252
if _close_all:
253253
self.result_box.maybe_close()
254+
if __debug__ and self.button_request_task is not None:
255+
# Don't raise in production to avoid THP desync
256+
raise wire.FirmwareError("button request ack pending")
254257

255258
if CURRENT_LAYOUT is self:
256259
# fade to black -- backlight is off while no layout is running
257260
backlight_fade(BacklightLevels.NONE)
258261

259262
set_current_layout(None)
260263
if __debug__:
261-
if self.button_request_ack_pending:
262-
msg = "button request ack pending"
263-
if utils.USE_THP:
264-
# Don't raise to avoid THP desync
265-
log.error(__name__, msg)
266-
else:
267-
raise wire.FirmwareError(msg)
268264
notify_layout_change(None)
269265

270266
async def get_result(self) -> T:
271267
"""Wait for, and return, the result of this UI layout."""
272268
if self.is_ready():
273269
self.start()
274270
# else we are (a) still running or (b) already finished
275-
is_done = None
276271
try:
277-
if (ctx := self.context) is not None and self.result_box.is_empty():
278-
is_done = loop.mailbox() # (see below)
279-
280-
def _button_request_task() -> Generator[Any, Any, None]:
281-
try:
282-
yield from ctx.button_request_handler.handle(
283-
button_requests=self.button_request_box,
284-
ack_callback=self._button_request_acked,
285-
)
286-
finally:
287-
is_done.put(None)
288-
289-
self.button_request_task = _button_request_task()
290-
self._start_task(self.button_request_task)
291-
elif __debug__ and not self.button_request_box.is_empty():
292-
log.debug(
293-
__name__,
294-
"ButtonRequest task not started, %s ignored",
295-
self.button_request_box.value,
296-
)
272+
br_handler = self.button_request_handler
273+
if br_handler is not None:
274+
# Keep a reference to ButtonRequest handling task (to avoid prematurely closing it).
275+
br_task = br_handler.br_task(self._button_request_acked)
276+
self.button_request_task = br_task
277+
self._start_task(br_task)
297278

298279
result = await self.result_box
299280
assert CURRENT_LAYOUT is None # the screen is blank now
300281

301-
if is_done is not None:
282+
if br_handler is not None:
302283
# Make sure ButtonRequest is ACKed, before the result is returned.
303284
# Otherwise, THP channel may become desynced (due to two consecutive writes).
304-
self.put_button_request(None)
305-
task = loop.spawn(_waiting_screen())
306-
try:
307-
await is_done
308-
finally:
309-
task.close()
285+
await br_handler.join(_waiting_screen())
310286

311287
return result
312288
finally:
289+
# No more ButtonRequests will be sent
290+
self.button_request_handler = None
291+
self.button_request_task = None
313292
# Close all tasks (including ButtonRequest handler)
314293
self.stop()
315294

@@ -340,50 +319,32 @@ def _event(self, event_call: Callable[..., LayoutState | None], *args: Any) -> N
340319

341320
if state is LayoutState.DONE:
342321
self._emit_message(self.layout.return_value())
343-
322+
# Shutdown is raised after emitting the return value.
344323
elif state is LayoutState.ATTACHED:
345324
first_paint = True
346-
self.button_request_ack_pending = self._button_request()
347-
if self.button_request_ack_pending:
348-
state = LayoutState.TRANSITIONING
349-
elif __debug__:
350-
notify_layout_change(self)
325+
# Process a button request coming out of the Rust layout.
326+
has_br = self.put_button_request(self.layout.button_request())
327+
if __debug__:
328+
self._is_attached = not has_br
329+
if self._is_attached:
330+
notify_layout_change(self)
351331

352-
if state is not None:
353-
self.state = state
332+
elif __debug__ and state is not None:
333+
self._is_attached = False
354334

355335
if first_paint:
356336
self._first_paint()
357337
else:
358338
self._paint()
359339

360-
def _button_request(self) -> bool:
361-
"""Process a button request coming out of the Rust layout."""
362-
res = self.layout.button_request()
363-
if res is None:
364-
return False
365-
366-
if self.context is None:
367-
if __debug__:
368-
log.debug(__name__, "ButtonRequest ignored: %s", res)
340+
def put_button_request(self, msg: ButtonRequestMsg | None) -> bool:
341+
if self.button_request_handler is None or msg is None:
369342
return False
370343

371-
if __debug__ and not self.button_request_box.is_empty():
372-
raise wire.FirmwareError(
373-
"button request already pending -- "
374-
"don't forget to yield your input flow from time to time ^_^"
375-
)
376-
377-
self.put_button_request(res)
344+
br = ButtonRequest(code=msg[0], name=msg[1], pages=self.layout.page_count())
345+
self.button_request_handler.put(br)
378346
return True
379347

380-
def put_button_request(self, msg: ButtonRequestMsg | None) -> None:
381-
br = msg and ButtonRequest(
382-
code=msg[0], name=msg[1], pages=self.layout.page_count()
383-
)
384-
# in production, we don't want this to fail, hence replace=True
385-
self.button_request_box.put(br, replace=True)
386-
387348
def _paint(self) -> None:
388349
"""Paint the layout and ensure that homescreen cache is properly invalidated."""
389350
import storage.cache as storage_cache
@@ -489,11 +450,9 @@ def _handle_touch_events(self) -> Generator[Any, tuple[int, int, int], None]:
489450
touch.close()
490451

491452
def _button_request_acked(self) -> None:
492-
if self.button_request_ack_pending and self.state is LayoutState.TRANSITIONING:
493-
self.button_request_ack_pending = False
494-
self.state = LayoutState.ATTACHED
495-
if __debug__:
496-
notify_layout_change(self)
453+
if __debug__:
454+
self._is_attached = True
455+
notify_layout_change(self)
497456

498457
if utils.USE_BLE:
499458

@@ -579,15 +538,17 @@ class ProgressLayout:
579538
is currently displayed, who needs to redraw and when.
580539
"""
581540

541+
if __debug__:
542+
543+
def is_layout_attached(self) -> bool:
544+
return True
545+
582546
def __init__(self, layout: LayoutObj[UiResult]) -> None:
583547
self.layout = layout
584548
self.transition_out = None
585549
self.value = 0
586550
self.progress_step = 20
587551

588-
def is_layout_attached(self) -> bool:
589-
return True
590-
591552
def report(self, value: int, description: str | None = None) -> None:
592553
"""Report a progress step.
593554

core/src/trezor/wire/protocol_common.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Awaitable,
1414
Callable,
1515
Container,
16+
Generator,
1617
Literal,
1718
NoReturn,
1819
TypeVar,
@@ -22,6 +23,8 @@
2223
from storage.cache_common import DataCache
2324
from trezor.messages import ButtonRequest
2425

26+
AckCallback = Callable[[], None]
27+
2528
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
2629
T = TypeVar("T")
2730

@@ -132,31 +135,78 @@ class ButtonRequestHandler:
132135
"""Handle button requests and unexpected messages from host."""
133136

134137
def __init__(self, ctx: Context) -> None:
135-
self.ctx = ctx
138+
self.ctx = ctx # used for communication with the host.
136139

137-
async def handle(
138-
self,
139-
button_requests: loop.mailbox[ButtonRequest | None],
140-
ack_callback: Callable[[], None] | None,
141-
) -> None:
140+
# Receives ButtonRequest notifications from the active layout,
141+
# or `None` when the layout is closed.
142+
self.box: loop.mailbox[ButtonRequest | None] = loop.mailbox()
143+
144+
# Allows the layout to block until ButtonRequest handling is over,
145+
# using `join()` method.
146+
self.is_done: loop.mailbox[None] = loop.mailbox()
147+
148+
if __debug__:
149+
# Is there a pending ButtonRequest (still waiting for an ButtonAck)?
150+
# Used for detecting missing ButtonAck in debug builds.
151+
self.pending = False
152+
153+
def put(self, br: ButtonRequest) -> None:
154+
if __debug__:
155+
if self.pending:
156+
from . import FirmwareError
157+
158+
raise FirmwareError(
159+
"button request already pending -- "
160+
"don't forget to yield your input flow from time to time ^_^"
161+
)
162+
self.pending = True
163+
164+
# in production, we don't want this to fail, hence replace=True
165+
self.box.put(br, replace=True)
166+
167+
def br_task(self, ack_callback: AckCallback) -> Generator[Any, Any, None]:
168+
assert self.is_done.is_empty()
169+
try:
170+
yield from self._handle(ack_callback)
171+
finally:
172+
# no pending I/O - mark as done, to unblock `join()`.
173+
self.is_done.put(None)
174+
175+
async def join(self, wait_task: loop.Task[None]) -> None:
176+
# `br_task()` must be scheduled before joining.
177+
178+
# notify the handler that no more button requests are expected
179+
# in production, we don't want this to fail, hence replace=True
180+
self.box.put(None, replace=True)
181+
182+
task = loop.spawn(wait_task)
183+
try:
184+
await self.is_done
185+
finally:
186+
assert self.is_done.is_empty()
187+
task.close()
188+
189+
async def _handle(self, ack_callback: AckCallback) -> None:
142190
from trezor.messages import ButtonAck
143191

144192
while True:
145193
# The following task will raise on any incoming message.
146194
unexpected_read = self.ctx.read(None)
147-
br = await loop.race(unexpected_read, button_requests)
195+
br = await loop.race(unexpected_read, self.box)
148196

149197
# Exit the loop when the layout is done.
150198
if br is None:
199+
if __debug__:
200+
self.pending = False
151201
return
152202

153203
if __debug__:
154204
log.info(__name__, "ButtonRequest sent: %s", br.name)
155205
await self.ctx.call(br, ButtonAck)
156206
if __debug__:
207+
self.pending = False
157208
log.info(__name__, "ButtonRequest acked: %s", br.name)
158-
if ack_callback is not None:
159-
ack_callback()
209+
ack_callback()
160210

161211

162212
class ContinueOnErrors(ButtonRequestHandler):
@@ -167,18 +217,14 @@ def __init__(self, ctx: Context, msg: str) -> None:
167217
self._prev_handler: ButtonRequestHandler | None = None
168218
self.msg = msg
169219

170-
async def handle(
171-
self,
172-
button_requests: loop.mailbox[ButtonRequest | None],
173-
ack_callback: Callable[[], None] | None,
174-
) -> None:
220+
async def _handle(self, ack_callback: AckCallback) -> None:
175221
"""Unexpected messages will not cause the handler to fail."""
176222
from .context import UnexpectedMessageException
177223

178224
while True:
179225
try:
180226
# Exit the loop when the layout is done.
181-
return await super().handle(button_requests, ack_callback)
227+
return await super()._handle(ack_callback)
182228
except UnexpectedMessageException as exc:
183229
# in case of THP channel preemption, `msg` is not set.
184230
# TRANSPORT_BUSY error has been already sent by `InterfaceContext.handle_packet()`.

0 commit comments

Comments
 (0)