From 43f051dda60aa2f6977f76188698719580ca068c Mon Sep 17 00:00:00 2001
From: Nathan Levesque
Date: Tue, 11 Feb 2025 17:43:55 -0500
Subject: [PATCH] Fixed SCIM search for large queries

---
 main/settings.py                    |   2 +-
 poetry.lock                         |  19 ++-
 pyproject.toml                      |   2 +
 scim/filters.py                     | 130 ++++++++++++++++---
 scim/parser/grammar.py              | 189 ++++++++++++++++++++++++++++
 scim/parser/grammar_test.py         |  62 +++++++++
 scim/parser/queries/__init__.py     |   0
 scim/parser/queries/sql.py          |  15 ---
 scim/parser/transpilers/__init__.py |   0
 scim/parser/transpilers/sql.py      |  20 ---
 scim/views_test.py                  |  27 ++++
 11 files changed, 410 insertions(+), 56 deletions(-)
 create mode 100644 scim/parser/grammar.py
 create mode 100644 scim/parser/grammar_test.py
 delete mode 100644 scim/parser/queries/__init__.py
 delete mode 100644 scim/parser/queries/sql.py
 delete mode 100644 scim/parser/transpilers/__init__.py
 delete mode 100644 scim/parser/transpilers/sql.py

diff --git a/main/settings.py b/main/settings.py
index 47f712eaa5..96c79c83ec 100644
--- a/main/settings.py
+++ b/main/settings.py
@@ -145,7 +145,7 @@
     "SERVICE_PROVIDER_CONFIG_MODEL": "scim.config.LearnSCIMServiceProviderConfig",
     "USER_ADAPTER": "scim.adapters.LearnSCIMUser",
     "USER_MODEL_GETTER": "scim.adapters.get_user_model_for_scim",
-    "USER_FILTER_PARSER": "scim.filters.LearnUserFilterQuery",
+    "USER_FILTER_PARSER": "scim.filters.UserFilterQuery",
     "GET_IS_AUTHENTICATED_PREDICATE": "scim.utils.is_authenticated_predicate",
 }

diff --git a/poetry.lock b/poetry.lock
index c053b4b5ec..87ac07fa96 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5469,13 +5469,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]

 [[package]]
 name = "pyparsing"
-version = "3.2.0"
+version = "3.2.1"
 description = "pyparsing module - Classes and methods to define and execute parsing grammars"
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"},
-    {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"},
+    {file = "pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1"},
+    {file = "pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a"},
 ]

 [package.extras]
@@ -7305,6 +7305,17 @@ notebook = ["ipywidgets (>=6)"]
 slack = ["slack-sdk"]
 telegram = ["requests"]

+[[package]]
+name = "traceback-with-variables"
+version = "2.1.1"
+description = "Adds variables to python traceback. Simple, lightweight, controllable. Debug reasons of exceptions by logging or pretty printing colorful variable contexts for each frame in a stacktrace, showing every value. Dump locals environments after errors to console, files, and loggers. Works with Jupyter and IPython."
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "traceback-with-variables-2.1.1.tar.gz", hash = "sha256:ca1ab9cd2871c3be3bbc57bb7b2dfe4b427763f81da6c632663d27231eaab132"},
+    {file = "traceback_with_variables-2.1.1-py3-none-any.whl", hash = "sha256:a98566d3931d151f43b1307e228b13fff5022a5c67defc3f53dbde64e8128e0b"},
+]
+
 [[package]]
 name = "traitlets"
 version = "5.14.3"
@@ -7834,4 +7845,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "3.12.6"
-content-hash = "5ee980b1260ffb80a8dea5001f14d22669e094e1af2f5b195b8caa3a49579068"
+content-hash = "43ceb10754931fd2856e37fe9c5d43ca4728f8990e658be769d36dd981383447"

diff --git a/pyproject.toml b/pyproject.toml
index a1febb47e5..78437705a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,7 @@ llama-index-agent-openai = "^0.4.1"
 langchain-experimental = "^0.3.4"
 langchain-openai = "^0.3.2"
 deepmerge = "^2.0"
+pyparsing = "^3.2.1"


 [tool.poetry.group.dev.dependencies]
@@ -117,6 +118,7 @@ freezegun = "^1.4.0"
 pytest-xdist = { version = "^3.6.1", extras = ["psutil"] }
 anys = "^0.3.0"
 locust = "^2.31.2"
+traceback-with-variables = "^2.1.1"


 [build-system]

diff --git a/scim/filters.py b/scim/filters.py
index a971e7fe9f..cae44f5956 100644
--- a/scim/filters.py
+++ b/scim/filters.py
@@ -1,28 +1,126 @@
+import operator
+from collections.abc import Callable
 from typing import Optional

-from django_scim.filters import UserFilterQuery
+from django.contrib.auth import get_user_model
+from django.db.models import Model, Q
+from pyparsing import ParseResults

-from scim.parser.queries.sql import PatchedSQLQuery
+from scim.parser.grammar import Filters, TermType


-class LearnUserFilterQuery(UserFilterQuery):
+class FilterQuery:
     """Filters for users"""

-    query_class = PatchedSQLQuery
+    model_cls: type[Model]

-    attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = {
-        ("userName", None, None): "auth_user.username",
-        ("emails", "value", None): "auth_user.email",
-        ("active", None, None): "auth_user.is_active",
-        ("fullName", None, None): "profiles_profile.name",
-        ("name", "givenName", None): "auth_user.first_name",
-        ("name", "familyName", None): "auth_user.last_name",
+    attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]]
+
+    related_selects: list[str] = []
+
+    dj_op_mapping = {
+        "eq": "exact",
+        "ne": "exact",
+        "gt": "gt",
+        "ge": "gte",
+        "lt": "lt",
+        "le": "lte",
+        "pr": "isnull",
+        "co": "contains",
+        "sw": "startswith",
+        "ew": "endswith",
+    }
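+
+    # Illustrative note (an editor's sketch, not part of the original change):
+    # with the attr_map defined on UserFilterQuery below, a SCIM expression such
+    # as
+    #
+    #     name.givenName co "Jo"
+    #
+    # resolves to the ORM lookup path "first_name__contains" and is applied as
+    # Q(first_name__contains="Jo"); operators listed in dj_negated_ops ("ne",
+    # "pr") additionally negate the resulting Q object.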
 
-    joins: tuple[str, ...] = (
-        "INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id",
-    )
+    dj_negated_ops = ("ne", "pr")
+
+    @classmethod
+    def _filter_expr(cls, parsed: ParseResults) -> Q:
+        if parsed is None:
+            msg = "Expected a filter, got: None"
+            raise ValueError(msg)
+
+        if parsed.term_type == TermType.attr_expr:
+            return cls._attr_expr(parsed)
+
+        msg = f"Unsupported term type: {parsed.term_type}"
+        raise ValueError(msg)
+
+    @classmethod
+    def _attr_expr(cls, parsed: ParseResults) -> Q:
+        dj_op = cls.dj_op_mapping[parsed.comparison_operator.lower()]
+
+        scim_keys = (parsed.attr_name, parsed.sub_attr)
+
+        path_parts = list(
+            filter(
+                lambda part: part is not None,
+                (
+                    *cls.attr_map.get(scim_keys, scim_keys),
+                    dj_op,
+                ),
+            )
+        )
+        path = "__".join(path_parts)
+
+        q = Q(**{path: parsed.value})
+
+        if parsed.comparison_operator in cls.dj_negated_ops:
+            q = ~q
+
+        return q
+
+    @classmethod
+    def _filters(cls, parsed: ParseResults) -> Q:
+        parsed_iter = iter(parsed)
+        q = cls._filter_expr(next(parsed_iter))
+
+        try:
+            while op := cls._logical_op(next(parsed_iter)):
+                filter_q = cls._filter_expr(next(parsed_iter))
+
+                # combine the previous and next Q() objects using the bitwise operator
+                q = op(q, filter_q)
+        except StopIteration:
+            pass
+
+        return q
+
+    @classmethod
+    def _logical_op(cls, parsed: ParseResults) -> Callable[[Q, Q], Q] | None:
+        """Convert a defined operator to the corresponding bitwise operator"""
+        if parsed is None:
+            return None
+
+        if parsed.logical_operator.lower() == "and":
+            return operator.and_
+        elif parsed.logical_operator.lower() == "or":
+            return operator.or_
+        else:
+            msg = f"Unexpected operator: {parsed.logical_operator}"
+            raise ValueError(msg)
 
     @classmethod
-    def search(cls, filter_query, request=None):
-        return super().search(filter_query, request=request)
+    def search(cls, filter_query, request=None):  # noqa: ARG003
+        """Create a search query"""
+        parsed = Filters.parse_string(filter_query, parse_all=True)
+
+        return cls.model_cls.objects.select_related(*cls.related_selects).filter(
+            cls._filters(parsed)
+        )
+
+
+class UserFilterQuery(FilterQuery):
+    """FilterQuery for User"""
+
+    attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]] = {
+        ("userName", None): ("username",),
+        ("emails", "value"): ("email",),
+        ("active", None): ("is_active",),
+        ("fullName", None): ("profile", "name"),
+        ("name", "givenName"): ("first_name",),
+        ("name", "familyName"): ("last_name",),
+    }
+
+    related_selects = ["profile"]
+
+    model_cls = get_user_model()

diff --git a/scim/parser/grammar.py b/scim/parser/grammar.py
new file mode 100644
index 0000000000..228ebd181e
--- /dev/null
+++ b/scim/parser/grammar.py
@@ -0,0 +1,189 @@
+"""
+SCIM filter parsers
+
+This module aims to parse SCIM filter queries in compliance with the spec:
+https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
+
+Note that this implementation defines things slightly differently from the
+RFC grammar: a naive implementation that matches the grammar exactly hits
+Python's recursion limit, because the grammar defines logical lists
+(AND/OR chains) as a recursive relationship.
+
+This implementation avoids that by defining FilterExpr and Filters
+separately. As a result, some definitions are collapsed and removed
+(e.g. valFilter => FilterExpr).
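+
+Illustrative usage (an editor's sketch; the filter string below is an arbitrary
+example, not part of the original change):
+
+    parsed = Filters.parse_string(
+        'userName eq "bjensen" and active eq true', parse_all=True
+    )
+    parsed[0].attr_name            # "userName"
+    parsed[0].comparison_operator  # "eq"
+    parsed[0].value                # "bjensen"
+    parsed[1].logical_operator     # "and"
+
+Consumers such as scim.filters.FilterQuery walk these named results and the
+term_type tags rather than raw tokens.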
+""" + +from enum import StrEnum, auto + +from pyparsing import ( + CaselessKeyword, + Char, + Combine, + DelimitedList, + FollowedBy, + Forward, + Group, + Literal, + Suppress, + Tag, + alphanums, + alphas, + common, + dbl_quoted_string, + nested_expr, + one_of, + remove_quotes, + ungroup, +) + + +class TagName(StrEnum): + """Tag names""" + + term_type = auto() + value_type = auto() + + +class TermType(StrEnum): + """Tag term type""" + + urn = auto() + attr_name = auto() + attr_path = auto() + attr_expr = auto() + value_path = auto() + presence = auto() + + logical_op = auto() + compare_op = auto() + negation_op = auto() + + filter_expr = auto() + filters = auto() + + +class ValueType(StrEnum): + """Tag value_type""" + + boolean = auto() + number = auto() + string = auto() + null = auto() + + +def _tag_term_type(term_type: TermType) -> Tag: + return Tag(TagName.term_type.name, term_type) + + +def _tag_value_type(value_type: ValueType) -> Tag: + return Tag(TagName.value_type.name, value_type) + + +NameChar = Char(alphanums + "_-") +AttrName = Combine( + Char(alphas) + + NameChar[...] + # ensure we're not somehow parsing an URN + + ~FollowedBy(":") +).set_results_name("attr_name") + _tag_term_type(TermType.attr_name) + +# Example URN-qualifed attr: +# urn:ietf:params:scim:schemas:core:2.0:User:userName +# |--------------- URN --------------------|:| attr | +UrnAttr = Combine( + Combine( + Literal("urn:") + + DelimitedList( + # characters ONLY if followed by colon + Char(alphanums + ".-_")[1, ...] + FollowedBy(":"), + # separator + Literal(":"), + # combine everything back into a singular token + combine=True, + )[1, ...] + ).set_results_name("urn") + # separator between URN and attribute name + + Literal(":") + + AttrName + + _tag_term_type(TermType.urn) +) + + +SubAttr = ungroup(Combine(Suppress(".") + AttrName)).set_results_name("sub_attr") ^ ( + Tag("sub_attr", None) +) + +AttrPath = ( + ( + # match on UrnAttr first + UrnAttr ^ AttrName + ) + + SubAttr + + _tag_term_type(TermType.attr_path) +) + +ComparisonOperator = one_of( + ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le"], + caseless=True, + as_keyword=True, +).set_results_name("comparison_operator") + _tag_term_type(TermType.compare_op) + +LogicalOperator = Group( + one_of(["or", "and"], caseless=True).set_results_name("logical_operator") + + _tag_term_type(TermType.logical_op) +) + +NegationOperator = Group( + ( + CaselessKeyword("not") + + _tag_term_type(TermType.negation_op) + + Tag("negated", True) # noqa: FBT003 + )[..., 1] + ^ Tag("negated", False) # noqa: FBT003 +) + +ValueTrue = Literal("true").set_parse_action(lambda: True) + _tag_value_type( + ValueType.boolean +) +ValueFalse = Literal("false").set_parse_action(lambda: False) + _tag_value_type( + ValueType.boolean +) +ValueNull = Literal("null").set_parse_action(lambda: None) + _tag_value_type( + ValueType.null +) +ValueNumber = (common.integer | common.fnumber) + _tag_value_type(ValueType.number) +ValueString = dbl_quoted_string.set_parse_action(remove_quotes) + _tag_value_type( + ValueType.string +) + +ComparisonValue = ungroup( + ValueTrue | ValueFalse | ValueNull | ValueNumber | ValueString +).set_results_name("value") + +AttrPresence = Group( + AttrPath + Literal("pr").set_results_name("presence").set_parse_action(lambda: True) +) + _tag_term_type(TermType.presence) +AttrExpression = AttrPresence | Group( + AttrPath + ComparisonOperator + ComparisonValue + _tag_term_type(TermType.attr_expr) +) + +# these are forward references, so that we can have +# parsers 
+
+# these are forward references, so that we can have
+# parsers circularly reference themselves
+FilterExpr = Forward()
+Filters = Forward()
+
+ValuePath = Group(AttrPath + nested_expr("[", "]", Filters)).set_results_name(
+    "value_path"
+) + _tag_term_type(TermType.value_path)
+
+FilterExpr <<= (
+    AttrExpression | ValuePath | (NegationOperator + nested_expr("(", ")", Filters))
+) + _tag_term_type(TermType.filter_expr)
+
+Filters <<= (
+    # comment to force it to wrap the below for operator precedence
+    (FilterExpr + (LogicalOperator + FilterExpr)[...])
+    + _tag_term_type(TermType.filters)
+)

diff --git a/scim/parser/grammar_test.py b/scim/parser/grammar_test.py
new file mode 100644
index 0000000000..27ddff7940
--- /dev/null
+++ b/scim/parser/grammar_test.py
@@ -0,0 +1,62 @@
+import pytest
+from faker import Faker
+
+from scim.parser.grammar import Filters
+
+faker = Faker()
+
+
+def test_scim_filter_parser():
+    """Run the parser tests"""
+    success, results = Filters.run_tests("""\
+    userName eq "bjensen"
+
+    name.familyName co "O'Malley"
+
+    userName sw "J"
+
+    urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J"
+
+    title pr
+
+    meta.lastModified gt "2011-05-13T04:42:34Z"
+
+    meta.lastModified ge "2011-05-13T04:42:34Z"
+
+    meta.lastModified lt "2011-05-13T04:42:34Z"
+
+    meta.lastModified le "2011-05-13T04:42:34Z"
+
+    title pr and userType eq "Employee"
+
+    title pr or userType eq "Intern"
+
+    schemas eq "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
+
+    userType eq "Employee" and (emails co "example.com" or emails.value co "example.org")
+
+    userType ne "Employee" and not (emails co "example.com" or emails.value co "example.org")
+
+    userType eq "Employee" and (emails.type eq "work")
+
+    userType eq "Employee" and emails[type eq "work" and value co "@example.com"]
+
+    emails[type eq "work" and value co "@example.com"] or ims[type eq "xmpp" and value co "@foo.com"]
+    """)
+
+    # run_tests will output error messages
+    assert success
+
+
+@pytest.mark.parametrize("count", [10, 100, 1000, 5000])
+def test_large_filter(count):
+    """Test that the parser can handle large filters"""
+
+    filter_str = " OR ".join(
+        [f'email.value eq "{faker.email()}"' for _ in range(count)]
+    )
+
+    success, _ = Filters.run_tests(filter_str)
+
+    # run_tests will output error messages
+    assert success
diff --git a/scim/parser/queries/__init__.py b/scim/parser/queries/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/scim/parser/queries/sql.py b/scim/parser/queries/sql.py
deleted file mode 100644
index d6d8fcde52..0000000000
--- a/scim/parser/queries/sql.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from scim2_filter_parser.lexer import SCIMLexer
-from scim2_filter_parser.parser import SCIMParser
-from scim2_filter_parser.queries.sql import SQLQuery
-
-from scim.parser.transpilers.sql import PatchedTranspiler
-
-
-class PatchedSQLQuery(SQLQuery):
-    """Patched SQLQuery to use the patch transpiler"""
-
-    def build_where_sql(self):
-        self.token_stream = SCIMLexer().tokenize(self.filter)
-        self.ast = SCIMParser().parse(self.token_stream)
-        self.transpiler = PatchedTranspiler(self.attr_map)
-        self.where_sql, self.params_dict = self.transpiler.transpile(self.ast)
diff --git a/scim/parser/transpilers/__init__.py b/scim/parser/transpilers/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/scim/parser/transpilers/sql.py b/scim/parser/transpilers/sql.py
deleted file mode 100644
index e095cce2b3..0000000000
--- a/scim/parser/transpilers/sql.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import string
-
-from scim2_filter_parser.transpilers.sql import Transpiler
-
-
-class PatchedTranspiler(Transpiler):
-    """
-    This is a fixed version of the upstream sql transpiler that converts SCIM queries
-    to SQL queries.
-
-    Specifically it fixes the upper limit of 26 conditions for the search endpoint due
-    to the upstream library using the ascii alphabet for query parameters.
-    """
-
-    def get_next_id(self):
-        """Convert the current index to a base26 string"""
-        chars = string.ascii_lowercase
-        index = len(self.params)
-
-        return (chars[-1] * int(index / len(chars))) + chars[index % len(chars)]
diff --git a/scim/views_test.py b/scim/views_test.py
index c1ee44f9db..a3b86a962a 100644
--- a/scim/views_test.py
+++ b/scim/views_test.py
@@ -413,3 +413,30 @@ def test_bulk_post(scim_client, bulk_test_data):
             assert actual_value is expected_value
         else:
             assert actual_value == expected_value
+
+
+def test_user_search(scim_client):
+    """Test the user search endpoint"""
+    users = UserFactory.create_batch(1500)
+    emails = [user.email for user in users[:1000]]
+
+    resp = scim_client.post(
+        f"{reverse('scim:users-search')}?count={len(emails)}",
+        content_type="application/scim+json",
+        data=json.dumps(
+            {
+                "schemas": [djs_constants.SchemaURI.SERACH_REQUEST],
+                "filter": " OR ".join([f'email EQ "{email}"' for email in emails]),
+            }
+        ),
+    )
+
+    assert resp.status_code == 200
+
+    data = resp.json()
+
+    assert data["totalResults"] == len(emails)
+    assert len(data["Resources"]) == len(emails)
+
+    for resource in data["Resources"]:
+        assert resource["emails"][0]["value"] in emails