Skip to content

Commit 57432b4

Browse files
committed
dulwich: support kbd-interactive auth in asyncssh vendor
1 parent 4605cd3 commit 57432b4

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

src/scmrepo/git/backend/dulwich/asyncssh_vendor.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""asyncssh SSH vendor for Dulwich."""
22
import asyncio
3+
import os
34
from typing import (
45
TYPE_CHECKING,
56
Callable,
@@ -10,6 +11,7 @@
1011
Sequence,
1112
)
1213

14+
from asyncssh import SSHClient
1315
from dulwich.client import SSHVendor
1416

1517
from scmrepo.asyn import BaseAsyncObject, sync_wrapper
@@ -18,8 +20,10 @@
1820
if TYPE_CHECKING:
1921
from pathlib import Path
2022

23+
from asyncssh.auth import KbdIntPrompts, KbdIntResponse
2124
from asyncssh.config import ConfigPaths, FilePath
2225
from asyncssh.connection import SSHClientConnection
26+
from asyncssh.misc import MaybeAwait
2327
from asyncssh.process import SSHClientProcess
2428
from asyncssh.stream import SSHReader
2529

@@ -131,6 +135,36 @@ def _process_public_key_ok_gh(self, _pkttype, _pktid, packet):
131135
return True
132136

133137

138+
class InteractiveSSHClient(SSHClient):
139+
def kbdint_auth_requested(self) -> "MaybeAwait[Optional[str]]":
140+
return ""
141+
142+
async def kbdint_challenge_received( # pylint: disable=invalid-overridden-method
143+
self,
144+
name: str,
145+
instructions: str,
146+
lang: str,
147+
prompts: "KbdIntPrompts",
148+
) -> Optional["KbdIntResponse"]:
149+
from getpass import getpass
150+
151+
if os.environ.get("GIT_TERMINAL_PROMPT") == "0":
152+
return None
153+
154+
def _getpass(prompt: str) -> str:
155+
return getpass(prompt=prompt).rstrip()
156+
157+
if instructions:
158+
print(instructions)
159+
loop = asyncio.get_running_loop()
160+
return [
161+
await loop.run_in_executor(
162+
None, _getpass, f"({name}) {prompt}" if name else prompt
163+
)
164+
for prompt, _ in prompts
165+
]
166+
167+
134168
class AsyncSSHVendor(BaseAsyncObject, SSHVendor):
135169
def __init__(self, **kwargs) -> None:
136170
super().__init__(**kwargs)
@@ -176,6 +210,7 @@ async def _run_command(
176210
ignore_encrypted=not key_filename,
177211
known_hosts=None,
178212
encoding=None,
213+
client_factory=InteractiveSSHClient,
179214
)
180215
proc = await conn.create_process(command, encoding=None)
181216
except asyncssh.misc.PermissionDenied as exc:
@@ -185,10 +220,6 @@ async def _run_command(
185220
run_command = sync_wrapper(_run_command)
186221

187222

188-
# class ValidatedSSHClientConfig(SSHClientConfig):
189-
# pass
190-
191-
192223
def get_unsupported_opts(config_paths: "ConfigPaths") -> Iterator[str]:
193224
from pathlib import Path, PurePath
194225

0 commit comments

Comments
 (0)