-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
eb61853
commit 43f051d
Showing
11 changed files
with
410 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,126 @@ | ||
import operator | ||
from collections.abc import Callable | ||
from typing import Optional | ||
|
||
from django_scim.filters import UserFilterQuery | ||
from django.contrib.auth import get_user_model | ||
from django.db.models import Model, Q | ||
from pyparsing import ParseResults | ||
|
||
from scim.parser.queries.sql import PatchedSQLQuery | ||
from scim.parser.grammar import Filters, TermType | ||
|
||
|
||
class LearnUserFilterQuery(UserFilterQuery): | ||
class FilterQuery: | ||
"""Filters for users""" | ||
|
||
query_class = PatchedSQLQuery | ||
model_cls: type[Model] | ||
|
||
attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = { | ||
("userName", None, None): "auth_user.username", | ||
("emails", "value", None): "auth_user.email", | ||
("active", None, None): "auth_user.is_active", | ||
("fullName", None, None): "profiles_profile.name", | ||
("name", "givenName", None): "auth_user.first_name", | ||
("name", "familyName", None): "auth_user.last_name", | ||
attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]] | ||
|
||
related_selects: list[str] = [] | ||
|
||
dj_op_mapping = { | ||
"eq": "exact", | ||
"ne": "exact", | ||
"gt": "gt", | ||
"ge": "gte", | ||
"lt": "lt", | ||
"le": "lte", | ||
"pr": "isnull", | ||
"co": "contains", | ||
"sw": "startswith", | ||
"ew": "endswith", | ||
} | ||
|
||
joins: tuple[str, ...] = ( | ||
"INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id", | ||
) | ||
dj_negated_ops = ("ne", "pr") | ||
|
||
@classmethod | ||
def _filter_expr(cls, parsed: ParseResults) -> Q: | ||
if parsed is None: | ||
msg = "Expected a filter, got: None" | ||
raise ValueError(msg) | ||
|
||
if parsed.term_type == TermType.attr_expr: | ||
return cls._attr_expr(parsed) | ||
|
||
msg = f"Unsupported term type: {parsed.term_type}" | ||
raise ValueError(msg) | ||
|
||
@classmethod | ||
def _attr_expr(cls, parsed: ParseResults) -> Q: | ||
dj_op = cls.dj_op_mapping[parsed.comparison_operator.lower()] | ||
|
||
scim_keys = (parsed.attr_name, parsed.sub_attr) | ||
|
||
path_parts = list( | ||
filter( | ||
lambda part: part is not None, | ||
( | ||
*cls.attr_map.get(scim_keys, scim_keys), | ||
dj_op, | ||
), | ||
) | ||
) | ||
path = "__".join(path_parts) | ||
|
||
q = Q(**{path: parsed.value}) | ||
|
||
if parsed.comparison_operator in cls.dj_negated_ops: | ||
q = ~q | ||
|
||
return q | ||
|
||
@classmethod | ||
def _filters(cls, parsed: ParseResults) -> Q: | ||
parsed_iter = iter(parsed) | ||
q = cls._filter_expr(next(parsed_iter)) | ||
|
||
try: | ||
while operator := cls._logical_op(next(parsed_iter)): | ||
filter_q = cls._filter_expr(next(parsed_iter)) | ||
|
||
# combine the previous and next Q() objects using the bitwise operator | ||
q = operator(q, filter_q) | ||
except StopIteration: | ||
pass | ||
|
||
return q | ||
|
||
@classmethod | ||
def _logical_op(cls, parsed: ParseResults) -> Callable[[Q, Q], Q] | None: | ||
"""Convert a defined operator to the corresponding bitwise operator""" | ||
if parsed is None: | ||
return None | ||
|
||
if parsed.logical_operator.lower() == "and": | ||
return operator.and_ | ||
elif parsed.logical_operator.lower() == "or": | ||
return operator.or_ | ||
else: | ||
msg = f"Unexpected operator: {parsed.operator}" | ||
raise ValueError(msg) | ||
|
||
@classmethod | ||
def search(cls, filter_query, request=None): | ||
return super().search(filter_query, request=request) | ||
def search(cls, filter_query, request=None): # noqa: ARG003 | ||
"""Create a search query""" | ||
parsed = Filters.parse_string(filter_query, parse_all=True) | ||
|
||
return cls.model_cls.objects.select_related(*cls.related_selects).filter( | ||
cls._filters(parsed) | ||
) | ||
|
||
|
||
class UserFilterQuery(FilterQuery): | ||
"""FilterQuery for User""" | ||
|
||
attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]] = { | ||
("userName", None): ("username",), | ||
("emails", "value"): ("email",), | ||
("active", None): ("is_active",), | ||
("fullName", None): ("profile", "name"), | ||
("name", "givenName"): ("first_name",), | ||
("name", "familyName"): ("last_name",), | ||
} | ||
|
||
related_selects = ["profile"] | ||
|
||
model_cls = get_user_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
""" | ||
SCIM filter parsers | ||
+ _tag_term_type(TermType.attr_name) | ||
This module aims to compliantly parse SCIM filter queries per the spec: | ||
https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2 | ||
Note that this implementation defines things slightly differently | ||
because a naive implementation exactly matching the filter grammar will | ||
result in hitting Python's recursion limit because the grammar defines | ||
logical lists (AND/OR chains) as a recursive relationship. | ||
This implementation avoids that by defining separately FilterExpr and | ||
Filter. As a result of this, some definitions are collapsed and removed | ||
(e.g. valFilter => FilterExpr). | ||
""" | ||
|
||
from enum import StrEnum, auto | ||
|
||
from pyparsing import ( | ||
CaselessKeyword, | ||
Char, | ||
Combine, | ||
DelimitedList, | ||
FollowedBy, | ||
Forward, | ||
Group, | ||
Literal, | ||
Suppress, | ||
Tag, | ||
alphanums, | ||
alphas, | ||
common, | ||
dbl_quoted_string, | ||
nested_expr, | ||
one_of, | ||
remove_quotes, | ||
ungroup, | ||
) | ||
|
||
|
||
class TagName(StrEnum): | ||
"""Tag names""" | ||
|
||
term_type = auto() | ||
value_type = auto() | ||
|
||
|
||
class TermType(StrEnum): | ||
"""Tag term type""" | ||
|
||
urn = auto() | ||
attr_name = auto() | ||
attr_path = auto() | ||
attr_expr = auto() | ||
value_path = auto() | ||
presence = auto() | ||
|
||
logical_op = auto() | ||
compare_op = auto() | ||
negation_op = auto() | ||
|
||
filter_expr = auto() | ||
filters = auto() | ||
|
||
|
||
class ValueType(StrEnum): | ||
"""Tag value_type""" | ||
|
||
boolean = auto() | ||
number = auto() | ||
string = auto() | ||
null = auto() | ||
|
||
|
||
def _tag_term_type(term_type: TermType) -> Tag: | ||
return Tag(TagName.term_type.name, term_type) | ||
|
||
|
||
def _tag_value_type(value_type: ValueType) -> Tag: | ||
return Tag(TagName.value_type.name, value_type) | ||
|
||
|
||
NameChar = Char(alphanums + "_-") | ||
AttrName = Combine( | ||
Char(alphas) | ||
+ NameChar[...] | ||
# ensure we're not somehow parsing an URN | ||
+ ~FollowedBy(":") | ||
).set_results_name("attr_name") + _tag_term_type(TermType.attr_name) | ||
|
||
# Example URN-qualifed attr: | ||
# urn:ietf:params:scim:schemas:core:2.0:User:userName | ||
# |--------------- URN --------------------|:| attr | | ||
UrnAttr = Combine( | ||
Combine( | ||
Literal("urn:") | ||
+ DelimitedList( | ||
# characters ONLY if followed by colon | ||
Char(alphanums + ".-_")[1, ...] + FollowedBy(":"), | ||
# separator | ||
Literal(":"), | ||
# combine everything back into a singular token | ||
combine=True, | ||
)[1, ...] | ||
).set_results_name("urn") | ||
# separator between URN and attribute name | ||
+ Literal(":") | ||
+ AttrName | ||
+ _tag_term_type(TermType.urn) | ||
) | ||
|
||
|
||
SubAttr = ungroup(Combine(Suppress(".") + AttrName)).set_results_name("sub_attr") ^ ( | ||
Tag("sub_attr", None) | ||
) | ||
|
||
AttrPath = ( | ||
( | ||
# match on UrnAttr first | ||
UrnAttr ^ AttrName | ||
) | ||
+ SubAttr | ||
+ _tag_term_type(TermType.attr_path) | ||
) | ||
|
||
ComparisonOperator = one_of( | ||
["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le"], | ||
caseless=True, | ||
as_keyword=True, | ||
).set_results_name("comparison_operator") + _tag_term_type(TermType.compare_op) | ||
|
||
LogicalOperator = Group( | ||
one_of(["or", "and"], caseless=True).set_results_name("logical_operator") | ||
+ _tag_term_type(TermType.logical_op) | ||
) | ||
|
||
NegationOperator = Group( | ||
( | ||
CaselessKeyword("not") | ||
+ _tag_term_type(TermType.negation_op) | ||
+ Tag("negated", True) # noqa: FBT003 | ||
)[..., 1] | ||
^ Tag("negated", False) # noqa: FBT003 | ||
) | ||
|
||
ValueTrue = Literal("true").set_parse_action(lambda: True) + _tag_value_type( | ||
ValueType.boolean | ||
) | ||
ValueFalse = Literal("false").set_parse_action(lambda: False) + _tag_value_type( | ||
ValueType.boolean | ||
) | ||
ValueNull = Literal("null").set_parse_action(lambda: None) + _tag_value_type( | ||
ValueType.null | ||
) | ||
ValueNumber = (common.integer | common.fnumber) + _tag_value_type(ValueType.number) | ||
ValueString = dbl_quoted_string.set_parse_action(remove_quotes) + _tag_value_type( | ||
ValueType.string | ||
) | ||
|
||
ComparisonValue = ungroup( | ||
ValueTrue | ValueFalse | ValueNull | ValueNumber | ValueString | ||
).set_results_name("value") | ||
|
||
AttrPresence = Group( | ||
AttrPath + Literal("pr").set_results_name("presence").set_parse_action(lambda: True) | ||
) + _tag_term_type(TermType.presence) | ||
AttrExpression = AttrPresence | Group( | ||
AttrPath + ComparisonOperator + ComparisonValue + _tag_term_type(TermType.attr_expr) | ||
) | ||
|
||
# these are forward references, so that we can have | ||
# parsers circularly reference themselves | ||
FilterExpr = Forward() | ||
Filters = Forward() | ||
|
||
ValuePath = Group(AttrPath + nested_expr("[", "]", Filters)).set_results_name( | ||
"value_path" | ||
) + _tag_term_type(TermType.value_path) | ||
|
||
FilterExpr <<= ( | ||
AttrExpression | ValuePath | (NegationOperator + nested_expr("(", ")", Filters)) | ||
) + _tag_term_type(TermType.filter_expr) | ||
|
||
Filters <<= ( | ||
# comment to force it to wrap the below for operator precedence | ||
(FilterExpr + (LogicalOperator + FilterExpr)[...]) | ||
+ _tag_term_type(TermType.filters) | ||
) |
Oops, something went wrong.