diff --git a/modelcontextprotocol/utils/search.py b/modelcontextprotocol/utils/search.py index b69377d..96b3b50 100644 --- a/modelcontextprotocol/utils/search.py +++ b/modelcontextprotocol/utils/search.py @@ -6,6 +6,22 @@ class SearchUtils: + OPERATOR_MAP = { + "eq": lambda attr, v, ci: attr.eq(v, case_insensitive=ci), + "neq": lambda attr, v, ci: attr.neq(v, case_insensitive=ci), + "startswith": lambda attr, v, ci: attr.startswith(v, case_insensitive=ci), + "contains": lambda attr, v, ci: attr.contains(v, case_insensitive=ci), + "lt": lambda attr, v: attr.lt(v), + "lte": lambda attr, v: attr.lte(v), + "gt": lambda attr, v: attr.gt(v), + "gte": lambda attr, v: attr.gte(v), + "match": lambda attr, v: attr.match(v), + "has_any_value": lambda attr: attr.has_any_value(), + } + + CASE_INSENSITIVE_OPS = {"lt", "lte", "gt", "gte", "match"} + NO_VALUE_OPS = {"has_any_value"} + @staticmethod def process_results(results: Any) -> Dict[str, Any]: """ @@ -70,27 +86,8 @@ def _apply_operator_condition( f"Applying operator '{operator}' with value '{value}' (case_insensitive={case_insensitive})" ) - if operator == "startswith": - return attr.startswith(value, case_insensitive=case_insensitive) - elif operator == "match": - return attr.match(value) - elif operator == "eq": - return attr.eq(value, case_insensitive=case_insensitive) - elif operator == "neq": - return attr.neq(value, case_insensitive=case_insensitive) - elif operator == "gte": - return attr.gte(value) - elif operator == "lte": - return attr.lte(value) - elif operator == "gt": - return attr.gt(value) - elif operator == "lt": - return attr.lt(value) - elif operator == "has_any_value": - return attr.has_any_value() - elif operator == "contains": - return attr.contains(value, case_insensitive=case_insensitive) - elif operator == "between": + # Special case for between - needs custom handling + if operator == "between": # Expecting value to be a list/tuple with [start, end] if isinstance(value, (list, tuple)) and len(value) == 2: return attr.between(value[0], value[1]) @@ -98,6 +95,17 @@ def _apply_operator_condition( raise ValueError( f"Invalid value format for 'between' operator: {value}, expected [start, end]" ) + + if operator in SearchUtils.OPERATOR_MAP: + op_func = SearchUtils.OPERATOR_MAP[operator] + # Handle operators that don't need value or case_insensitive + if operator in SearchUtils.NO_VALUE_OPS: + return op_func(attr) + elif operator in SearchUtils.CASE_INSENSITIVE_OPS: + return op_func(attr, value) + else: + return op_func(attr, value, case_insensitive) + else: # Try to get the operator method from the attribute op_method = getattr(attr, operator, None) @@ -168,5 +176,6 @@ def _process_condition( logger.debug( f"Applying {search_method_name} equality condition {attr_name}={condition}" ) + search = search_method(attr.eq(condition)) return search