|
1 | | -from typing import Optional |
| 1 | +import os |
| 2 | +from typing import Any, Dict, Optional |
2 | 3 |
|
3 | | -from dulwich.client import Urllib3HttpGitClient |
| 4 | +from dulwich.client import HTTPUnauthorized, Urllib3HttpGitClient |
4 | 5 |
|
5 | 6 | from scmrepo.git.credentials import Credential, CredentialNotFoundError |
6 | 7 |
|
@@ -37,8 +38,51 @@ def __init__( |
37 | 38 | self.pool_manager.headers.update(basic_auth) |
38 | 39 | self._store_credentials = creds |
39 | 40 |
|
40 | | - def _http_request(self, *args, **kwargs): |
41 | | - result = super()._http_request(*args, **kwargs) |
| 41 | + def _http_request( |
| 42 | + self, |
| 43 | + url: str, |
| 44 | + headers: Optional[Dict[str, str]] = None, |
| 45 | + data: Any = None, |
| 46 | + ): |
| 47 | + try: |
| 48 | + result = super()._http_request(url, headers=headers, data=data) |
| 49 | + except HTTPUnauthorized: |
| 50 | + auth_header = self._get_auth() |
| 51 | + if not auth_header: |
| 52 | + raise |
| 53 | + if headers: |
| 54 | + headers.update(auth_header) |
| 55 | + else: |
| 56 | + headers = auth_header |
| 57 | + result = super()._http_request(url, headers=headers, data=data) |
42 | 58 | if self._store_credentials is not None: |
43 | 59 | self._store_credentials.approve() |
44 | 60 | return result |
| 61 | + |
| 62 | + def _get_auth(self) -> Dict[str, str]: |
| 63 | + from getpass import getpass |
| 64 | + |
| 65 | + from urllib3.util import make_headers |
| 66 | + |
| 67 | + try: |
| 68 | + creds = Credential(username=self._username, url=self._base_url).fill() |
| 69 | + self._store_credentials = creds |
| 70 | + return make_headers(basic_auth=f"{creds.username}:{creds.password}") |
| 71 | + except CredentialNotFoundError: |
| 72 | + pass |
| 73 | + |
| 74 | + if os.environ.get("GIT_TERMINAL_PROMPT") == "0": |
| 75 | + return {} |
| 76 | + |
| 77 | + try: |
| 78 | + if self._username: |
| 79 | + username = self._username |
| 80 | + else: |
| 81 | + username = input(f"Username for '{self._base_url}': ") |
| 82 | + if self._password: |
| 83 | + password = self._password |
| 84 | + else: |
| 85 | + password = getpass(f"Password for '{self._base_url}': ") |
| 86 | + return make_headers(basic_auth=f"{username}:{password}") |
| 87 | + except KeyboardInterrupt: |
| 88 | + return {} |
0 commit comments