Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion core/src/apps/debug/n4w1_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from trezor import log, loop
from trezor.messages import DebugLinkN4W1Read, DebugLinkN4W1Response, DebugLinkN4W1Write
from trezor.ui import Layout

if TYPE_CHECKING:
from buffer_types import AnyBytes
from typing import Any
from typing import Any, Awaitable, Iterator

from trezor.wire.context import Context
from typing_extensions import Self
Expand All @@ -28,6 +29,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
log.debug(__name__, "N4W1 exchange done")
self.tx.put(None)

async def connect(self) -> None:
"""Wait for N4W1 connection notification."""
res = await self.rx
assert res.value is None

async def read(self, key: str) -> AnyBytes | None:
"""Read a specific entry from N4W1."""
log.debug(__name__, "N4W1 read: %s", key)
Expand All @@ -50,9 +56,51 @@ async def write(self, key: str, value: AnyBytes | None) -> AnyBytes | None:

async def handle(self, ctx: Context) -> None:
"""Called from `apps.debug.dispatch_DebugLinkConnected()`."""
self.rx.put(DebugLinkN4W1Response(value=None)) # notify `self.connect()`
while (req := await self.tx) is not None:
res = await ctx.call(req, DebugLinkN4W1Response)
self.rx.put(res)

def confirm_connect(
self, *, title: str, description: str, button: str, br_name: str | None
) -> Awaitable[None]:
"""Show a layout waiting for N4W1 connection, allowing cancellation."""

from trezor import TR
from trezor.ui.layouts.menu import Menu, confirm_with_menu
from trezorui_api import show_info

self_ctx: N4W1Context = self

class _Connect(Layout):

def create_tasks(self) -> Iterator[loop.Task[None]]:
from trezor.ui import Shutdown
from trezorui_api import CONFIRMED

async def _task() -> None:
await self_ctx.connect() # blocks until N4W1 is connected.
try:
# emitting a message raises Shutdown exception
self._emit_message(CONFIRMED)
except Shutdown:
pass

yield from super().create_tasks()
yield _task()

main = show_info(
title=title,
description=description,
button=(button, False),
external_menu=True,
)
return confirm_with_menu(
main,
Menu.root(cancel=TR.buttons__cancel),
br_name=br_name,
layout_type=_Connect,
)


ctx = N4W1Context()
148 changes: 2 additions & 146 deletions core/src/apps/management/recovery_device/homescreen.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING

import storage.device as storage_device
import storage.recovery as storage_recovery
import storage.recovery_shares as storage_recovery_shares
from trezor import TR, utils, wire
from trezor.messages import Success
from trezor.wire import message_handler

from apps.common import backup_types
from apps.management.recovery_device.recover import RecoveryAborted

from . import layout, recover

if TYPE_CHECKING:
from trezor.enums import BackupMethod, BackupType, RecoveryType

from .layout import RemainingSharesInfo


async def recovery_homescreen() -> None:
from trezor import workflow
Expand Down Expand Up @@ -90,76 +86,10 @@ async def _continue_repeated_backup() -> None:
backup.deactivate_repeated_backup()


if TYPE_CHECKING:

class RecoveryHandler(Protocol):
@classmethod
async def load(cls, recovery_type: RecoveryType) -> "RecoveryHandler": ...

async def show_state(self, is_retry: bool) -> None: ...
async def request_mnemonic(self) -> str | None: ...


class _DisplayHandler:
def __init__(
self,
recovery_type: RecoveryType,
word_count: int,
backup_type: BackupType | None,
) -> None:
self.recovery_type = recovery_type
self.word_count = word_count
self.backup_type = backup_type

@classmethod
async def load(cls, recovery_type: RecoveryType) -> "RecoveryHandler":
# `slip39_state is None` indicates that we are (re)starting the first recovery step,
# which includes word count selection.
if (slip39_state := recover.load_slip39_state()) is None:
# If we are starting recovery, ask for word count first...
try:
word_count = await layout.request_word_count(recovery_type)
except wire.ActionCancelled:
raise RecoveryAborted
# ...and only then show the starting screen with word count.
# Backup type will be deduced from the first share.
backup_type = None
else:
# SLIP-39 recovery is ongoing (at least one share was entered).
word_count, backup_type = slip39_state

return cls(recovery_type, word_count, backup_type)

async def show_state(self, is_retry: bool) -> None:
if is_retry and self.backup_type is not None:
# skip showing recovery state on retries (if first share was entered)
return
await _request_share_first_screen(self.word_count, self.recovery_type)

async def request_mnemonic(self) -> str | None:
"""Return the mnemonic or `None` on cancellation/validation error."""
from .word_validity import WordValidityResult

try:
# returns `None` on cancellation
return await layout.request_mnemonic(self.word_count, self.backup_type)
except WordValidityResult as exc:
# if they were invalid or some checks failed we continue and request them again
await exc.show_error()
return None


async def _recover_secret(
recovery_type: RecoveryType, method: BackupMethod | None
) -> tuple[bytes, BackupType]:
from trezor.enums import BackupMethod

if method not in (None, BackupMethod.Display):
from trezor import log

log.warning(__name__, "Unsupported backup method: %s", method)

handler_type = _DisplayHandler
handler_type = await layout.choose_handler(method)

# Show recovery state in the beginning, on some failures, and after a successful share entry.
is_retry = False
Expand Down Expand Up @@ -337,77 +267,3 @@ async def _process_words(words: str) -> tuple[bytes, BackupType] | None:
return None # more shares are needed

return secret, backup_type


async def _request_share_first_screen(
word_count: int, recovery_type: RecoveryType
) -> None:
from trezor.enums import RecoveryType

if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares()
if remaining:
group_count = storage_recovery.get_slip39_group_count()
if group_count > 1:
await layout.enter_share(
remaining_shares_info=_get_remaining_groups_and_shares()
)
else:
entered = len(storage_recovery_shares.fetch_group(0))
await layout.enter_share(entered_remaining=(entered, remaining[0]))
else:
if recovery_type == RecoveryType.UnlockRepeatedBackup:
text = TR.recovery__enter_backup
button_label = TR.buttons__continue
else:
text = TR.recovery__enter_any_share
button_label = TR.buttons__enter_share
await layout.homescreen_dialog(
button_label,
text,
TR.recovery__word_count_template.format(word_count),
show_instructions=True,
)
else: # BIP-39
await layout.homescreen_dialog(
TR.buttons__continue,
TR.recovery__enter_backup,
TR.recovery__word_count_template.format(word_count),
show_instructions=True,
)


def _get_remaining_groups_and_shares() -> "RemainingSharesInfo":
"""
Prepare data for Slip39 Advanced - what shares are to be entered.
"""
from trezor.crypto import slip39

shares_remaining = storage_recovery.fetch_slip39_remaining_shares()
assert shares_remaining # should be stored at this point

groups = set()
first_entered_index = -1
for i, group_count in enumerate(shares_remaining):
if group_count < slip39.MAX_SHARE_COUNT:
first_entered_index = i
break

share = None
for index, remaining in enumerate(shares_remaining):
if 0 <= remaining < slip39.MAX_SHARE_COUNT:
m = storage_recovery_shares.fetch_group(index)[0]
if not share:
share = slip39.decode_mnemonic(m)
identifier = tuple(m.split(" ")[0:3])
groups.add(identifier)
elif remaining == slip39.MAX_SHARE_COUNT: # no shares yet
identifier = tuple(
storage_recovery_shares.fetch_group(first_entered_index)[0].split(" ")[
0:2
]
)
groups.add(identifier)

assert share # share needs to be set
return groups, shares_remaining, share.group_threshold
Loading
Loading