diff --git a/nsi_auth.py b/nsi_auth.py index 423b910..664cc18 100644 --- a/nsi_auth.py +++ b/nsi_auth.py @@ -51,6 +51,8 @@ class Settings(BaseSettings): allowed_client_subject_dn_path: FilePath = FilePath("/config/allowed_client_dn.txt") ssl_client_subject_dn_header: str = "ssl-client-subject-dn" + pem_header: str = "X-Forwarded-Tls-Client-Cert" + traefik_cert_info_header: str = "X-Forwarded-Tls-Client-Cert-Info" use_watchdog: bool = False log_level: str = "INFO" @@ -107,6 +109,27 @@ def _escape_dn_value(value: str) -> str: return value +def _first_der_cert(data: bytes) -> bytes: + """Return just the first DER certificate from potentially concatenated DER bytes. + + Certificates are ASN.1 SEQUENCE structures (tag 0x30). Reads the length from the + tag+length header to slice off exactly the first cert and discard the rest. + """ + if not data or data[0] != 0x30: + raise ValueError("not an ASN.1 SEQUENCE") + idx = 1 + b = data[idx] + if b & 0x80 == 0: + length = b + idx += 1 + else: + n = b & 0x7F + idx += 1 + length = int.from_bytes(data[idx : idx + n], "big") + idx += n + return data[: idx + length] + + def extract_dn_from_pem_header(header_value: str) -> str | None: """Extract DN from Traefik's X-Forwarded-Tls-Client-Cert header (URL-encoded PEM). @@ -115,13 +138,17 @@ def extract_dn_from_pem_header(header_value: str) -> str | None: Returns a normalized DN string in DER field order, or None on parse failure. """ try: - # Traefik strips newlines from the PEM before URL-encoding (to prevent header injection), - # so load_pem_x509_certificate would fail on the re-assembled string. Instead, extract - # the base64 between the PEM markers and load as DER. - # Use unquote (not unquote_plus) to preserve '+' characters valid in base64. + # Use unquote (not unquote_plus) to preserve literal '+' which is valid in base64. + # Traefik sends raw base64 DER without PEM markers or URL-encoding, but unquote + # handles any %XX sequences if ever present. pem_str = unquote(header_value) - b64 = re.sub(r"-----[^-]+-----", "", pem_str).replace(" ", "") - cert = x509.load_der_x509_certificate(base64.b64decode(b64)) + # Support both raw base64 (Traefik) and PEM-wrapped base64. + match = re.search(r"-----BEGIN CERTIFICATE-----([^-]*)-----END CERTIFICATE-----", pem_str) + b64 = match.group(1) if match else pem_str + # Strip whitespace, then extract only the first cert — Traefik may send the full chain + # as concatenated DER bytes with no markers or separators between certs. + raw = base64.b64decode(b64.replace(" ", "").replace("\n", "").replace("\r", "")) + cert = x509.load_der_x509_certificate(_first_der_cert(raw)) except Exception as e: app.logger.warning(f"failed to parse PEM from X-Forwarded-Tls-Client-Cert: {e!s}") return None @@ -158,21 +185,24 @@ def get_client_dn() -> tuple[str | None, str]: Returns: Tuple of (dn, source) where source indicates which header was used. """ - pem_header = request.headers.get("X-Forwarded-Tls-Client-Cert") + pem_header = request.headers.get(settings.pem_header) if pem_header: dn = extract_dn_from_pem_header(pem_header) if dn: - return dn, "traefik-pem" + app.logger.debug(f"extracted DN from {settings.pem_header} (PEM): {dn}") + return dn, settings.pem_header - traefik_header = request.headers.get("X-Forwarded-Tls-Client-Cert-Info") - if traefik_header: - dn = extract_dn_from_traefik_header(traefik_header) + traefik_cert_info_header = request.headers.get(settings.traefik_cert_info_header) + if traefik_cert_info_header: + dn = extract_dn_from_traefik_header(traefik_cert_info_header) if dn: - return dn, "traefik" + app.logger.debug(f"extracted DN from {settings.traefik_cert_info_header}: {dn}") + return dn, settings.traefik_cert_info_header nginx_header = request.headers.get(settings.ssl_client_subject_dn_header) if nginx_header: - return nginx_header, "nginx" + app.logger.debug(f"extracted DN from {settings.ssl_client_subject_dn_header}: {nginx_header}") + return nginx_header, settings.ssl_client_subject_dn_header return None, "none" @@ -184,8 +214,8 @@ def validate() -> tuple[str, int]: if not dn: app.logger.warning( - f"no client DN found in headers (tried X-Forwarded-Tls-Client-Cert, " - f"X-Forwarded-Tls-Client-Cert-Info, {settings.ssl_client_subject_dn_header})" + f"no client DN found in headers (tried {settings.pem_header}, " + f"{settings.traefik_cert_info_header}, {settings.ssl_client_subject_dn_header})" ) return "Forbidden", 403 diff --git a/tests/conftest.py b/tests/conftest.py index c23ce90..ff13762 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,13 +10,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime from collections.abc import Generator from pathlib import Path +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID from flask import Flask from flask.testing import FlaskClient from pytest import MonkeyPatch, fixture +_OID_ORGANIZATION_IDENTIFIER = x509.ObjectIdentifier("2.5.4.97") + @fixture def allowed_client_dn(tmp_path: Path) -> Path: @@ -47,3 +54,43 @@ def application(allowed_client_dn: Path, monkeypatch: MonkeyPatch) -> Generator[ def client(application: Flask) -> FlaskClient: """A test client for the application instance.""" return application.test_client() + + +@fixture(scope="session") +def test_cert() -> x509.Certificate: + """Self-signed certificate with extended subject fields for testing.""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Michigan"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Organization"), + x509.NameAttribute(_OID_ORGANIZATION_IDENTIFIER, "NTRUS+MI-123456"), + x509.NameAttribute(NameOID.EMAIL_ADDRESS, "test@example.com"), + x509.NameAttribute(NameOID.COMMON_NAME, "Test Client"), + ]) + return ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + + +@fixture(scope="session") +def test_cert_dn() -> str: + """Expected DN string for test_cert (DER field order, RFC 4514 escaped).""" + return r"C=US,ST=Michigan,O=Test Organization,organizationIdentifier=NTRUS\+MI-123456,emailAddress=test@example.com,CN=Test Client" + + +@fixture(scope="session") +def pem_header_value(test_cert: x509.Certificate) -> str: + """Traefik X-Forwarded-Tls-Client-Cert header value for test_cert. + + Traefik sends raw base64 DER with no PEM markers and no URL-encoding. + """ + import base64 + return base64.b64encode(test_cert.public_bytes(serialization.Encoding.DER)).decode("ascii") diff --git a/tests/functional/test_application.py b/tests/functional/test_application.py index b7aaa7c..6059ad1 100644 --- a/tests/functional/test_application.py +++ b/tests/functional/test_application.py @@ -10,6 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from urllib.parse import quote_plus + import pytest from flask.testing import FlaskClient @@ -71,3 +73,95 @@ def test_validate_rejects_non_get_methods(client: FlaskClient, method: str) -> N """Verify that the /validate endpoint only accepts GET requests.""" response = getattr(client, method)("/validate") assert response.status_code == 405 + + +# --------------------------------------------------------------------------- +# PEM header (X-Forwarded-Tls-Client-Cert) +# --------------------------------------------------------------------------- + + +def test_validate_pem_header_allowed(client: FlaskClient, pem_header_value: str, test_cert_dn: str) -> None: + """PEM header with DN in allow-list returns 200.""" + from nsi_auth import state + + state.allowed_client_subject_dn = [test_cert_dn] + response = client.get("/validate", headers={"X-Forwarded-Tls-Client-Cert": pem_header_value}) + assert response.status_code == 200 + assert response.data == b"OK" + + +def test_validate_pem_header_not_in_allowlist(client: FlaskClient, pem_header_value: str) -> None: + """PEM header with DN not in allow-list returns 403.""" + from nsi_auth import state + + state.allowed_client_subject_dn = ["CN=SomeoneElse,C=NL"] + response = client.get("/validate", headers={"X-Forwarded-Tls-Client-Cert": pem_header_value}) + assert response.status_code == 403 + + +# --------------------------------------------------------------------------- +# Traefik Info header (X-Forwarded-Tls-Client-Cert-Info) +# --------------------------------------------------------------------------- + + +def test_validate_traefik_info_header_allowed(client: FlaskClient) -> None: + """URL-encoded Subject= info header with DN in allow-list returns 200.""" + from nsi_auth import state + + dn = "CN=Test,O=Org,C=US" + state.allowed_client_subject_dn = [dn] + encoded = quote_plus(f'Subject="{dn}"') + response = client.get("/validate", headers={"X-Forwarded-Tls-Client-Cert-Info": encoded}) + assert response.status_code == 200 + assert response.data == b"OK" + + +def test_validate_traefik_info_header_not_in_allowlist(client: FlaskClient) -> None: + """URL-encoded Subject= info header with DN not in allow-list returns 403.""" + from nsi_auth import state + + state.allowed_client_subject_dn = ["CN=SomeoneElse,C=NL"] + encoded = quote_plus('Subject="CN=Test,O=Org,C=US"') + response = client.get("/validate", headers={"X-Forwarded-Tls-Client-Cert-Info": encoded}) + assert response.status_code == 403 + + +# --------------------------------------------------------------------------- +# Priority and fallback behaviour +# --------------------------------------------------------------------------- + + +def test_validate_pem_takes_priority_over_info( + client: FlaskClient, pem_header_value: str, test_cert_dn: str +) -> None: + """When both headers are present, PEM DN is used (not Info DN).""" + from nsi_auth import state + + # Only the PEM cert's DN is allowed; Info header carries a different DN + state.allowed_client_subject_dn = [test_cert_dn] + info_encoded = quote_plus('Subject="CN=Different,C=NL"') + response = client.get( + "/validate", + headers={ + "X-Forwarded-Tls-Client-Cert": pem_header_value, + "X-Forwarded-Tls-Client-Cert-Info": info_encoded, + }, + ) + assert response.status_code == 200 + + +def test_validate_pem_parse_failure_falls_back_to_info(client: FlaskClient) -> None: + """Garbage PEM header falls back to Info header for DN extraction.""" + from nsi_auth import state + + dn = "CN=Test,O=Org,C=US" + state.allowed_client_subject_dn = [dn] + info_encoded = quote_plus(f'Subject="{dn}"') + response = client.get( + "/validate", + headers={ + "X-Forwarded-Tls-Client-Cert": "not-a-valid-pem", + "X-Forwarded-Tls-Client-Cert-Info": info_encoded, + }, + ) + assert response.status_code == 200 diff --git a/tests/unit/test_dn_extraction.py b/tests/unit/test_dn_extraction.py new file mode 100644 index 0000000..8c1a133 --- /dev/null +++ b/tests/unit/test_dn_extraction.py @@ -0,0 +1,186 @@ +# Copyright 2026 SURF. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +from urllib.parse import quote_plus + +import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + + +# --------------------------------------------------------------------------- +# _escape_dn_value +# --------------------------------------------------------------------------- + + +def test_escape_plus() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("NTRUS+MI-123456") == r"NTRUS\+MI-123456" + + +def test_escape_comma() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("a,b") == r"a\,b" + + +def test_escape_backslash() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("a\\b") == r"a\\b" + + +def test_escape_double_quote() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value('say "hi"') == r'say \"hi\"' + + +def test_escape_hash_at_start() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("#value") == r"\#value" + + +def test_escape_hash_not_at_start() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("val#ue") == "val#ue" + + +def test_escape_leading_space() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value(" value") == r"\ value" + + +def test_escape_trailing_space() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("value ") == r"value\ " + + +def test_escape_plain_string_unchanged() -> None: + from nsi_auth import _escape_dn_value + + assert _escape_dn_value("University Corporation") == "University Corporation" + + +# --------------------------------------------------------------------------- +# extract_dn_from_traefik_header +# --------------------------------------------------------------------------- + + +def test_traefik_header_valid(application: object) -> None: # noqa: ARG001 + """Plain (unencoded) Subject= wrapper is handled.""" + from nsi_auth import extract_dn_from_traefik_header + + raw = 'Subject="CN=Test,O=Org,C=US"' + assert extract_dn_from_traefik_header(raw) == "CN=Test,O=Org,C=US" + + +def test_traefik_header_url_encoded(application: object) -> None: # noqa: ARG001 + """URL-encoded form (as Traefik sends it) is decoded correctly.""" + from nsi_auth import extract_dn_from_traefik_header + + encoded = quote_plus('Subject="CN=Test Client,O=Org,C=US"') + assert extract_dn_from_traefik_header(encoded) == "CN=Test Client,O=Org,C=US" + + +def test_traefik_header_percent_encoded_chars(application: object) -> None: # noqa: ARG001 + """Percent-encoded delimiters (%3D %22 %2C) are decoded correctly.""" + from nsi_auth import extract_dn_from_traefik_header + + # Subject%3D%22CN%3DTest%2CO%3DOrg%22 → Subject="CN=Test,O=Org" + encoded = "Subject%3D%22CN%3DTest%2CO%3DOrg%22" + assert extract_dn_from_traefik_header(encoded) == "CN=Test,O=Org" + + +def test_traefik_header_no_subject_wrapper(application: object) -> None: # noqa: ARG001 + """Header without Subject= wrapper returns None.""" + from nsi_auth import extract_dn_from_traefik_header + + assert extract_dn_from_traefik_header("CN=Test,O=Org,C=US") is None + + +def test_traefik_header_empty(application: object) -> None: # noqa: ARG001 + from nsi_auth import extract_dn_from_traefik_header + + assert extract_dn_from_traefik_header("") is None + + +# --------------------------------------------------------------------------- +# extract_dn_from_pem_header +# --------------------------------------------------------------------------- + + +def test_pem_header_valid_dn(application: object, pem_header_value: str, test_cert_dn: str) -> None: # noqa: ARG001 + """Valid PEM header returns correct full DN.""" + from nsi_auth import extract_dn_from_pem_header + + assert extract_dn_from_pem_header(pem_header_value) == test_cert_dn + + +def test_pem_header_extended_fields_present(application: object, pem_header_value: str) -> None: # noqa: ARG001 + """organizationIdentifier and emailAddress appear in the extracted DN.""" + from nsi_auth import extract_dn_from_pem_header + + dn = extract_dn_from_pem_header(pem_header_value) + assert dn is not None + assert "organizationIdentifier=" in dn + assert "emailAddress=" in dn + + +def test_pem_header_plus_escaped(application: object, pem_header_value: str) -> None: # noqa: ARG001 + """'+' in organizationIdentifier value is escaped as '\\+'.""" + from nsi_auth import extract_dn_from_pem_header + + dn = extract_dn_from_pem_header(pem_header_value) + assert dn is not None + assert r"NTRUS\+MI-123456" in dn + + +def test_pem_header_cert_chain_uses_first_cert( + application: object, test_cert: x509.Certificate # noqa: ARG001 +) -> None: + """When header contains a chain, only the first cert's DN is returned.""" + from nsi_auth import extract_dn_from_pem_header + + # Traefik sends concatenated raw DER bytes, base64-encoded, no markers + der = test_cert.public_bytes(serialization.Encoding.DER) + chain_b64 = base64.b64encode(der + der).decode("ascii") + + dn = extract_dn_from_pem_header(chain_b64) + assert dn is not None + assert "CN=Test Client" in dn + + +def test_pem_header_garbage_returns_none(application: object) -> None: # noqa: ARG001 + from nsi_auth import extract_dn_from_pem_header + + assert extract_dn_from_pem_header("not-a-cert") is None + + +def test_pem_header_valid_base64_but_not_cert(application: object) -> None: # noqa: ARG001 + """Valid base64 that is not a DER certificate returns None.""" + from nsi_auth import extract_dn_from_pem_header + + assert extract_dn_from_pem_header(base64.b64encode(b"not a cert").decode()) is None + + +@pytest.mark.parametrize("value", ["", " ", "%ZZ"]) +def test_pem_header_malformed_returns_none(application: object, value: str) -> None: # noqa: ARG001 + from nsi_auth import extract_dn_from_pem_header + + assert extract_dn_from_pem_header(value) is None