diff --git a/tests/test_client.py b/tests/test_client.py index 1cded1fae..061a5a676 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,7 @@ import logging from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch import httpx import pytest @@ -8,7 +10,13 @@ from starlette.status import HTTP_400_BAD_REQUEST from tiled.adapters.mapping import MapAdapter -from tiled.client import Context, from_context, from_profile, record_history +from tiled.client import ( + Context, + from_context, + from_profile, + from_provider, + record_history, +) from tiled.profiles import load_profiles, paths from tiled.queries import Key from tiled.server.app import build_app @@ -164,3 +172,290 @@ def test_jump_down_tree(): with record_history() as h: client["e"]["d"]["c"]["b"]["a"] assert len(h.requests) == 5 + + +# --------------------------------------------------------------------------- +# from_provider() tests +# --------------------------------------------------------------------------- + +# Patch targets – patch where the names are looked up (in constructors.py), +# not where they are defined (in context.py). +_CONTEXT = "tiled.client.constructors.Context" +_PASSWORD_GRANT = "tiled.client.constructors.password_grant" +_FROM_CONTEXT = "tiled.client.constructors.from_context" + + +def _make_provider_spec(name, mode="internal", auth_endpoint="/auth/provider/endpoint"): + """Build a minimal provider spec object matching tiled's structure.""" + return SimpleNamespace( + provider=name, + mode=mode, + links={"auth_endpoint": auth_endpoint}, + ) + + +def _make_context(providers): + """Build a mock Context with the given provider specs.""" + ctx = MagicMock() + ctx.server_info.authentication.providers = providers + return ctx + + +class TestFromProvider: + """Tests for from_provider().""" + + def test_calls_context_from_any_uri(self): + """Should call Context.from_any_uri with the given URI.""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + mock_client = MagicMock() + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ) as mock_from_uri, + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=mock_client), + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + mock_from_uri.assert_called_once_with("http://localhost:8020") + + def test_resolves_correct_provider(self): + """Should find the named provider and use its auth_endpoint.""" + providers = [ + _make_provider_spec("local", auth_endpoint="/auth/local"), + _make_provider_spec("my_authenticator", auth_endpoint="/auth/aps-dm"), + ] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}) as mock_grant, + patch(_FROM_CONTEXT, return_value=MagicMock()), + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + mock_grant.assert_called_once_with( + mock_context.http_client, + "/auth/aps-dm", + "my_authenticator", + "user", + "pass", + ) + + def test_calls_configure_auth_with_tokens(self): + """Should call context.configure_auth() with the returned tokens.""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + tokens = {"access_token": "abc", "refresh_token": "def"} + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value=tokens), + patch(_FROM_CONTEXT, return_value=MagicMock()), + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + mock_context.configure_auth.assert_called_once_with(tokens) + + def test_sets_has_external_auth(self): + """Should set context.has_external_auth = True after authentication.""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=MagicMock()), + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + assert mock_context.has_external_auth is True + + def test_returns_from_context_result(self): + """Should return the client produced by from_context().""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + mock_client = MagicMock(name="tiled_client") + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=mock_client) as mock_fc, + ): + result = from_provider( + "http://localhost:8020", "my_authenticator", "user", "pass" + ) + + mock_fc.assert_called_once_with( + mock_context, + structure_clients="numpy", + node_path_parts=[], + include_data_sources=False, + ) + assert result is mock_client + + def test_forwards_structure_clients(self): + """Should forward structure_clients to from_context().""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=MagicMock()) as mock_fc, + ): + from_provider( + "http://localhost:8020", "my_authenticator", "user", "pass", "dask" + ) + + assert mock_fc.call_args.kwargs["structure_clients"] == "dask" + + def test_forwards_node_path_parts(self): + """Should forward node_path_parts from Context.from_any_uri to from_context().""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, ["a", "b", "c"]), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=MagicMock()) as mock_fc, + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + assert mock_fc.call_args.kwargs["node_path_parts"] == ["a", "b", "c"] + + def test_unknown_provider_raises_valueerror(self): + """Should raise ValueError listing available providers.""" + providers = [ + _make_provider_spec("local"), + _make_provider_spec("my_authenticator"), + ] + mock_context = _make_context(providers) + + with patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ): + with pytest.raises(ValueError, match="no-such-provider") as exc_info: + from_provider( + "http://localhost:8020", "no-such-provider", "user", "pass" + ) + + # Error message should list available providers. + msg = str(exc_info.value) + assert "local" in msg + assert "my_authenticator" in msg + + def test_no_providers_raises_valueerror(self): + """Should raise ValueError when server has no providers.""" + mock_context = _make_context([]) + + with patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ): + with pytest.raises(ValueError, match="not found"): + from_provider( + "http://localhost:8020", "my_authenticator", "user", "pass" + ) + + def test_external_provider_raises_valueerror(self): + """Should raise ValueError for external (non-password) providers.""" + providers = [_make_provider_spec("oidc_provider", mode="external")] + mock_context = _make_context(providers) + + with patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ): + with pytest.raises(ValueError, match="does not support password-based"): + from_provider("http://localhost:8020", "oidc_provider", "user", "pass") + + def test_password_mode_accepted(self): + """Should accept providers with back-compat mode 'password'.""" + providers = [_make_provider_spec("legacy", mode="password")] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}), + patch(_FROM_CONTEXT, return_value=MagicMock()), + ): + # Should not raise. + from_provider("http://localhost:8020", "legacy", "user", "pass") + + def test_connection_error_propagates(self): + """Connection errors from Context.from_any_uri should propagate.""" + with patch( + _CONTEXT + ".from_any_uri", + side_effect=ConnectionError("refused"), + ): + with pytest.raises(ConnectionError, match="refused"): + from_provider( + "http://localhost:8020", "my_authenticator", "user", "pass" + ) + + def test_auth_error_propagates(self): + """Authentication errors from password_grant should propagate.""" + providers = [_make_provider_spec("my_authenticator")] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch( + _PASSWORD_GRANT, + side_effect=Exception("invalid credentials"), + ), + ): + with pytest.raises(Exception, match="invalid credentials"): + from_provider( + "http://localhost:8020", "my_authenticator", "user", "pass" + ) + + def test_first_matching_provider_is_used(self): + """When multiple providers match, the first one should be used.""" + providers = [ + _make_provider_spec("my_authenticator", auth_endpoint="/auth/first"), + _make_provider_spec("my_authenticator", auth_endpoint="/auth/second"), + ] + mock_context = _make_context(providers) + + with ( + patch( + _CONTEXT + ".from_any_uri", + return_value=(mock_context, []), + ), + patch(_PASSWORD_GRANT, return_value={"access_token": "tok"}) as mock_grant, + patch(_FROM_CONTEXT, return_value=MagicMock()), + ): + from_provider("http://localhost:8020", "my_authenticator", "user", "pass") + + # Should use the first matching provider's endpoint. + assert mock_grant.call_args[0][1] == "/auth/first" diff --git a/tiled/client/__init__.py b/tiled/client/__init__.py index eb5b991e4..d3eb18577 100644 --- a/tiled/client/__init__.py +++ b/tiled/client/__init__.py @@ -1,5 +1,12 @@ from ..utils import tree -from .constructors import SERVERS, from_context, from_profile, from_uri, simple +from .constructors import ( + SERVERS, + from_context, + from_profile, + from_provider, + from_uri, + simple, +) from .container import ASCENDING, DESCENDING from .context import Context from .logger import hide_logs, record_history, show_logs @@ -12,6 +19,7 @@ "DELETE_KEY", "from_context", "from_profile", + "from_provider", "from_uri", "hide_logs", "record_history", diff --git a/tiled/client/constructors.py b/tiled/client/constructors.py index ec354a652..d0eda22b6 100644 --- a/tiled/client/constructors.py +++ b/tiled/client/constructors.py @@ -1,5 +1,6 @@ import collections import collections.abc +import logging import pathlib import threading import warnings @@ -10,9 +11,11 @@ from ..utils import import_object, prepend_to_sys_path from .container import DEFAULT_STRUCTURE_CLIENT_DISPATCH, Container -from .context import DEFAULT_TIMEOUT_PARAMS, UNSET, Context +from .context import DEFAULT_TIMEOUT_PARAMS, UNSET, Context, password_grant from .utils import MSGPACK_MIME_TYPE, client_for_item, handle_error, retry_context +logger = logging.getLogger(__name__) + def from_uri( uri, @@ -269,6 +272,122 @@ def from_profile(name, structure_clients=None, **kwargs): return from_uri(**merged) +def from_provider( + uri: str, + provider: str, + username: str, + password: str, + structure_clients="numpy", + *, + include_data_sources=False, +): + """ + Connect to a tiled server and authenticate via a named provider. + + Uses ``Context.from_any_uri()`` to connect without triggering tiled's + built-in interactive login flow, then authenticates via an OAuth2 + password grant and returns a tiled client. + + This only supports providers with mode ``"internal"`` (password-based + authentication). External (e.g. OIDC) providers are not supported + by this function; use ``from_uri`` with the interactive flow instead. + + Example:: + + from tiled.client import from_provider + + client = from_provider( + uri="http://localhost:8000", + provider="MyAuthenticator", + username="joe_user", + password="secret", + ) + print(f"{list(client)=}") + + Parameters + ---------- + uri : str + Tiled server URI (e.g. ``"http://localhost:8000"``). + provider : str + Authentication provider name as configured on the server + (e.g. ``"MyAuthenticator"``). + username : str + Username for authentication. + password : str + Password for authentication. + structure_clients : str or dict, optional + Use "dask" for delayed data loading and "numpy" for immediate + in-memory structures (e.g. normal numpy arrays, pandas + DataFrames). For advanced use, provide dict mapping a + structure_family or a spec to a client object. + include_data_sources : bool, optional + Default False. If True, fetch information about underlying data sources. + + Returns + ------- + client + An authenticated tiled client (the return value of + ``tiled.client.constructors.from_context``). + + Raises + ------ + ValueError + If the named *provider* is not offered by the server, or if the + provider does not use password-based (internal) authentication. + """ + # Connect without triggering interactive login. + logger.debug("Connecting to %s ...", uri) + context, node_path_parts = Context.from_any_uri(uri) + + # Resolve the authentication provider. + providers = context.server_info.authentication.providers + provider_spec = None + for p in providers: + if p.provider == provider: + provider_spec = p + break + + if provider_spec is None: + available = [p.provider for p in providers] + raise ValueError( + f"Provider {provider!r} not found on server {uri}. " + f"Available providers: {available}" + ) + + # Only internal (password-based) providers support the password grant. + # "password" is a back-compat alias for "internal". + if provider_spec.mode not in ("internal", "password"): + raise ValueError( + f"Provider {provider!r} uses mode {provider_spec.mode!r}, " + f"which does not support password-based authentication. " + f"Use from_uri() with the interactive login flow instead." + ) + + auth_endpoint = provider_spec.links["auth_endpoint"] + + # Authenticate via OAuth2 password grant. + logger.debug("Authenticating %r via %s ...", username, provider) + tokens = password_grant( + context.http_client, + auth_endpoint, + provider, + username, + password, + ) + context.configure_auth(tokens) + + # Mark authentication as externally handled so that from_context() + # does not attempt to re-authenticate or overwrite the tokens. + context.has_external_auth = True + + return from_context( + context, + structure_clients=structure_clients, + node_path_parts=node_path_parts, + include_data_sources=include_data_sources, + ) + + def simple( directory: Optional[Union[str, pathlib.Path]] = None, api_key: Optional[str] = None,