Skip to content

Commit

Permalink
Fixed SCIM search for large queries
Browse files Browse the repository at this point in the history
  • Loading branch information
rhysyngsun committed Feb 19, 2025
1 parent eb61853 commit 43f051d
Show file tree
Hide file tree
Showing 11 changed files with 410 additions and 56 deletions.
2 changes: 1 addition & 1 deletion main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"SERVICE_PROVIDER_CONFIG_MODEL": "scim.config.LearnSCIMServiceProviderConfig",
"USER_ADAPTER": "scim.adapters.LearnSCIMUser",
"USER_MODEL_GETTER": "scim.adapters.get_user_model_for_scim",
"USER_FILTER_PARSER": "scim.filters.LearnUserFilterQuery",
"USER_FILTER_PARSER": "scim.filters.UserFilterQuery",
"GET_IS_AUTHENTICATED_PREDICATE": "scim.utils.is_authenticated_predicate",
}

Expand Down
19 changes: 15 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ llama-index-agent-openai = "^0.4.1"
langchain-experimental = "^0.3.4"
langchain-openai = "^0.3.2"
deepmerge = "^2.0"
pyparsing = "^3.2.1"


[tool.poetry.group.dev.dependencies]
Expand All @@ -117,6 +118,7 @@ freezegun = "^1.4.0"
pytest-xdist = { version = "^3.6.1", extras = ["psutil"] }
anys = "^0.3.0"
locust = "^2.31.2"
traceback-with-variables = "^2.1.1"


[build-system]
Expand Down
130 changes: 114 additions & 16 deletions scim/filters.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,126 @@
import operator
from collections.abc import Callable
from typing import Optional

from django_scim.filters import UserFilterQuery
from django.contrib.auth import get_user_model
from django.db.models import Model, Q
from pyparsing import ParseResults

from scim.parser.queries.sql import PatchedSQLQuery
from scim.parser.grammar import Filters, TermType


class LearnUserFilterQuery(UserFilterQuery):
class FilterQuery:
"""Filters for users"""

query_class = PatchedSQLQuery
model_cls: type[Model]

attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = {
("userName", None, None): "auth_user.username",
("emails", "value", None): "auth_user.email",
("active", None, None): "auth_user.is_active",
("fullName", None, None): "profiles_profile.name",
("name", "givenName", None): "auth_user.first_name",
("name", "familyName", None): "auth_user.last_name",
attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]]

related_selects: list[str] = []

dj_op_mapping = {
"eq": "exact",
"ne": "exact",
"gt": "gt",
"ge": "gte",
"lt": "lt",
"le": "lte",
"pr": "isnull",
"co": "contains",
"sw": "startswith",
"ew": "endswith",
}

joins: tuple[str, ...] = (
"INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id",
)
dj_negated_ops = ("ne", "pr")

@classmethod
def _filter_expr(cls, parsed: ParseResults) -> Q:
if parsed is None:
msg = "Expected a filter, got: None"
raise ValueError(msg)

if parsed.term_type == TermType.attr_expr:
return cls._attr_expr(parsed)

msg = f"Unsupported term type: {parsed.term_type}"
raise ValueError(msg)

@classmethod
def _attr_expr(cls, parsed: ParseResults) -> Q:
dj_op = cls.dj_op_mapping[parsed.comparison_operator.lower()]

scim_keys = (parsed.attr_name, parsed.sub_attr)

path_parts = list(
filter(
lambda part: part is not None,
(
*cls.attr_map.get(scim_keys, scim_keys),
dj_op,
),
)
)
path = "__".join(path_parts)

q = Q(**{path: parsed.value})

if parsed.comparison_operator in cls.dj_negated_ops:
q = ~q

return q

@classmethod
def _filters(cls, parsed: ParseResults) -> Q:
parsed_iter = iter(parsed)
q = cls._filter_expr(next(parsed_iter))

try:
while operator := cls._logical_op(next(parsed_iter)):
filter_q = cls._filter_expr(next(parsed_iter))

# combine the previous and next Q() objects using the bitwise operator
q = operator(q, filter_q)
except StopIteration:
pass

return q

@classmethod
def _logical_op(cls, parsed: ParseResults) -> Callable[[Q, Q], Q] | None:
"""Convert a defined operator to the corresponding bitwise operator"""
if parsed is None:
return None

if parsed.logical_operator.lower() == "and":
return operator.and_
elif parsed.logical_operator.lower() == "or":
return operator.or_
else:
msg = f"Unexpected operator: {parsed.operator}"
raise ValueError(msg)

@classmethod
def search(cls, filter_query, request=None):
return super().search(filter_query, request=request)
def search(cls, filter_query, request=None): # noqa: ARG003
"""Create a search query"""
parsed = Filters.parse_string(filter_query, parse_all=True)

return cls.model_cls.objects.select_related(*cls.related_selects).filter(
cls._filters(parsed)
)


class UserFilterQuery(FilterQuery):
"""FilterQuery for User"""

attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]] = {
("userName", None): ("username",),
("emails", "value"): ("email",),
("active", None): ("is_active",),
("fullName", None): ("profile", "name"),
("name", "givenName"): ("first_name",),
("name", "familyName"): ("last_name",),
}

related_selects = ["profile"]

model_cls = get_user_model()
189 changes: 189 additions & 0 deletions scim/parser/grammar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""
SCIM filter parsers
+ _tag_term_type(TermType.attr_name)
This module aims to compliantly parse SCIM filter queries per the spec:
https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
Note that this implementation defines things slightly differently
because a naive implementation exactly matching the filter grammar will
result in hitting Python's recursion limit because the grammar defines
logical lists (AND/OR chains) as a recursive relationship.
This implementation avoids that by defining separately FilterExpr and
Filter. As a result of this, some definitions are collapsed and removed
(e.g. valFilter => FilterExpr).
"""

from enum import StrEnum, auto

from pyparsing import (
CaselessKeyword,
Char,
Combine,
DelimitedList,
FollowedBy,
Forward,
Group,
Literal,
Suppress,
Tag,
alphanums,
alphas,
common,
dbl_quoted_string,
nested_expr,
one_of,
remove_quotes,
ungroup,
)


class TagName(StrEnum):
"""Tag names"""

term_type = auto()
value_type = auto()


class TermType(StrEnum):
"""Tag term type"""

urn = auto()
attr_name = auto()
attr_path = auto()
attr_expr = auto()
value_path = auto()
presence = auto()

logical_op = auto()
compare_op = auto()
negation_op = auto()

filter_expr = auto()
filters = auto()


class ValueType(StrEnum):
"""Tag value_type"""

boolean = auto()
number = auto()
string = auto()
null = auto()


def _tag_term_type(term_type: TermType) -> Tag:
return Tag(TagName.term_type.name, term_type)


def _tag_value_type(value_type: ValueType) -> Tag:
return Tag(TagName.value_type.name, value_type)


NameChar = Char(alphanums + "_-")
AttrName = Combine(
Char(alphas)
+ NameChar[...]
# ensure we're not somehow parsing an URN
+ ~FollowedBy(":")
).set_results_name("attr_name") + _tag_term_type(TermType.attr_name)

# Example URN-qualifed attr:
# urn:ietf:params:scim:schemas:core:2.0:User:userName
# |--------------- URN --------------------|:| attr |
UrnAttr = Combine(
Combine(
Literal("urn:")
+ DelimitedList(
# characters ONLY if followed by colon
Char(alphanums + ".-_")[1, ...] + FollowedBy(":"),
# separator
Literal(":"),
# combine everything back into a singular token
combine=True,
)[1, ...]
).set_results_name("urn")
# separator between URN and attribute name
+ Literal(":")
+ AttrName
+ _tag_term_type(TermType.urn)
)


SubAttr = ungroup(Combine(Suppress(".") + AttrName)).set_results_name("sub_attr") ^ (
Tag("sub_attr", None)
)

AttrPath = (
(
# match on UrnAttr first
UrnAttr ^ AttrName
)
+ SubAttr
+ _tag_term_type(TermType.attr_path)
)

ComparisonOperator = one_of(
["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le"],
caseless=True,
as_keyword=True,
).set_results_name("comparison_operator") + _tag_term_type(TermType.compare_op)

LogicalOperator = Group(
one_of(["or", "and"], caseless=True).set_results_name("logical_operator")
+ _tag_term_type(TermType.logical_op)
)

NegationOperator = Group(
(
CaselessKeyword("not")
+ _tag_term_type(TermType.negation_op)
+ Tag("negated", True) # noqa: FBT003
)[..., 1]
^ Tag("negated", False) # noqa: FBT003
)

ValueTrue = Literal("true").set_parse_action(lambda: True) + _tag_value_type(
ValueType.boolean
)
ValueFalse = Literal("false").set_parse_action(lambda: False) + _tag_value_type(
ValueType.boolean
)
ValueNull = Literal("null").set_parse_action(lambda: None) + _tag_value_type(
ValueType.null
)
ValueNumber = (common.integer | common.fnumber) + _tag_value_type(ValueType.number)
ValueString = dbl_quoted_string.set_parse_action(remove_quotes) + _tag_value_type(
ValueType.string
)

ComparisonValue = ungroup(
ValueTrue | ValueFalse | ValueNull | ValueNumber | ValueString
).set_results_name("value")

AttrPresence = Group(
AttrPath + Literal("pr").set_results_name("presence").set_parse_action(lambda: True)
) + _tag_term_type(TermType.presence)
AttrExpression = AttrPresence | Group(
AttrPath + ComparisonOperator + ComparisonValue + _tag_term_type(TermType.attr_expr)
)

# these are forward references, so that we can have
# parsers circularly reference themselves
FilterExpr = Forward()
Filters = Forward()

ValuePath = Group(AttrPath + nested_expr("[", "]", Filters)).set_results_name(
"value_path"
) + _tag_term_type(TermType.value_path)

FilterExpr <<= (
AttrExpression | ValuePath | (NegationOperator + nested_expr("(", ")", Filters))
) + _tag_term_type(TermType.filter_expr)

Filters <<= (
# comment to force it to wrap the below for operator precedence
(FilterExpr + (LogicalOperator + FilterExpr)[...])
+ _tag_term_type(TermType.filters)
)
Loading

0 comments on commit 43f051d

Please sign in to comment.