@@ -164,14 +164,14 @@ def _rel_merge_helper(
164
164
165
165
_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL , _SPECIAL_OPERATOR_ISNOTNULL )
166
166
167
- _REGEX_INSESITIVE = _SPECIAL_OPERATOR_INSENSITIVE + "{}"
167
+ _REGEX_INSENSITIVE = _SPECIAL_OPERATOR_INSENSITIVE + "{}"
168
168
_REGEX_CONTAINS = ".*{}.*"
169
169
_REGEX_STARTSWITH = "{}.*"
170
170
_REGEX_ENDSWITH = ".*{}"
171
171
172
172
# regex operations that require escaping
173
173
_STRING_REGEX_OPERATOR_TABLE = {
174
- "iexact" : _REGEX_INSESITIVE ,
174
+ "iexact" : _REGEX_INSENSITIVE ,
175
175
"contains" : _REGEX_CONTAINS ,
176
176
"icontains" : _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_CONTAINS ,
177
177
"startswith" : _REGEX_STARTSWITH ,
@@ -181,7 +181,7 @@ def _rel_merge_helper(
181
181
}
182
182
# regex operations that do not require escaping
183
183
_REGEX_OPERATOR_TABLE = {
184
- "iregex" : _REGEX_INSESITIVE ,
184
+ "iregex" : _REGEX_INSENSITIVE ,
185
185
}
186
186
# list all regex operations, these will require formatting of the value
187
187
_REGEX_OPERATOR_TABLE .update (_STRING_REGEX_OPERATOR_TABLE )
@@ -409,6 +409,7 @@ def __init__(
409
409
match : TOptional [list [str ]] = None ,
410
410
optional_match : TOptional [list [str ]] = None ,
411
411
where : TOptional [list [str ]] = None ,
412
+ optional_where : TOptional [list [str ]] = None ,
412
413
with_clause : TOptional [str ] = None ,
413
414
return_clause : TOptional [str ] = None ,
414
415
order_by : TOptional [list [str ]] = None ,
@@ -422,6 +423,7 @@ def __init__(
422
423
self .match = match if match else []
423
424
self .optional_match = optional_match if optional_match else []
424
425
self .where = where if where else []
426
+ self .optional_where = optional_where if optional_where else []
425
427
self .with_clause = with_clause
426
428
self .return_clause = return_clause
427
429
self .order_by = order_by
@@ -482,6 +484,8 @@ async def build_source(
482
484
if hasattr (source , "order_by_elements" ):
483
485
self .build_order_by (ident , source )
484
486
487
+ # source.filters seems to be used only by Traversal objects
488
+ # source.q_filters is used by NodeSet objects
485
489
if source .filters or source .q_filters :
486
490
self .build_where_stmt (
487
491
ident = ident ,
@@ -676,11 +680,19 @@ def build_label(self, ident: str, cls: type[AsyncStructuredNode]) -> str:
676
680
"""
677
681
ident_w_label = ident + ":" + cls .__label__
678
682
679
- if not self ._ast .return_clause and (
680
- not self ._ast .additional_return or ident not in self ._ast .additional_return
681
- ):
683
+ if not self ._ast .return_clause :
684
+ if (
685
+ not self ._ast .additional_return
686
+ or ident not in self ._ast .additional_return
687
+ ):
688
+ self ._ast .match .append (f"({ ident_w_label } )" )
689
+ self ._ast .return_clause = ident
690
+ self ._ast .result_class = cls
691
+ elif not self ._ast .match :
692
+ # If we get here, it means return_clause was filled because of an
693
+ # optional match, so we add a regular match for root node.
694
+ # Not very elegant, this part would deserve a refactoring...
682
695
self ._ast .match .append (f"({ ident_w_label } )" )
683
- self ._ast .return_clause = ident
684
696
self ._ast .result_class = cls
685
697
return ident
686
698
@@ -693,16 +705,16 @@ def build_additional_match(self, ident: str, node_set: "AsyncNodeSet") -> None:
693
705
for _ , value in node_set .must_match .items ():
694
706
if isinstance (value , dict ):
695
707
label = ":" + value ["node_class" ].__label__
696
- stmt = _rel_helper (lhs = source_ident , rhs = label , ident = "" , ** value )
708
+ stmt = f"EXISTS ( { _rel_helper (lhs = source_ident , rhs = label , ident = '' , ** value )} )"
697
709
self ._ast .where .append (stmt )
698
710
else :
699
711
raise ValueError ("Expecting dict got: " + repr (value ))
700
712
701
713
for _ , val in node_set .dont_match .items ():
702
714
if isinstance (val , dict ):
703
715
label = ":" + val ["node_class" ].__label__
704
- stmt = _rel_helper (lhs = source_ident , rhs = label , ident = "" , ** val )
705
- self ._ast .where .append ("NOT " + stmt )
716
+ stmt = f"NOT EXISTS ( { _rel_helper (lhs = source_ident , rhs = label , ident = '' , ** val )} )"
717
+ self ._ast .where .append (stmt )
706
718
else :
707
719
raise ValueError ("Expecting dict got: " + repr (val ))
708
720
@@ -718,13 +730,14 @@ def _register_place_holder(self, key: str) -> str:
718
730
719
731
def _parse_path (
720
732
self , source_class : type [AsyncStructuredNode ], prop : str
721
- ) -> Tuple [str , str , str , Any ]:
733
+ ) -> Tuple [str , str , str , Any , bool ]:
722
734
is_rel_filter = "|" in prop
723
735
if is_rel_filter :
724
736
path , prop = prop .rsplit ("|" , 1 )
725
737
else :
726
738
path , prop = prop .rsplit ("__" , 1 )
727
739
result = self .lookup_query_variable (path , return_relation = is_rel_filter )
740
+ is_optional_relation = False
728
741
if not result :
729
742
ident , target_class = self .build_traversal_from_path (
730
743
{
@@ -735,8 +748,8 @@ def _parse_path(
735
748
source_class ,
736
749
)
737
750
else :
738
- ident , target_class = result
739
- return ident , path , prop , target_class
751
+ ident , target_class , is_optional_relation = result
752
+ return ident , path , prop , target_class , is_optional_relation
740
753
741
754
def _finalize_filter_statement (
742
755
self , operator : str , ident : str , prop : str , val : Any
@@ -762,41 +775,66 @@ def _build_filter_statements(
762
775
self ,
763
776
ident : str ,
764
777
filters : dict [str , tuple ],
765
- target : list [str ],
778
+ target : list [tuple [ str , bool ] ],
766
779
source_class : type [AsyncStructuredNode ],
767
780
) -> None :
768
781
for prop , op_and_val in filters .items ():
769
782
path = None
770
783
is_rel_filter = "|" in prop
771
784
target_class = source_class
785
+ is_optional_relation = False
772
786
if "__" in prop or is_rel_filter :
773
- ident , path , prop , target_class = self ._parse_path (source_class , prop )
787
+ (
788
+ ident ,
789
+ path ,
790
+ prop ,
791
+ target_class ,
792
+ is_optional_relation ,
793
+ ) = self ._parse_path (source_class , prop )
774
794
operator , val = op_and_val
775
795
if not is_rel_filter :
776
796
prop = target_class .defined_properties (rels = False )[
777
797
prop
778
798
].get_db_property_name (prop )
779
799
statement = self ._finalize_filter_statement (operator , ident , prop , val )
780
- target .append (statement )
800
+ target .append (( statement , is_optional_relation ) )
781
801
782
802
def _parse_q_filters (
783
803
self , ident : str , q : Union [QBase , Any ], source_class : type [AsyncStructuredNode ]
784
- ) -> str :
785
- target = []
804
+ ) -> tuple [str , str ]:
805
+ target : list [tuple [str , bool ]] = []
806
+
807
+ def add_to_target (statement : str , connector : Q , optional : bool ) -> None :
808
+ if not statement :
809
+ return
810
+ if connector == Q .OR :
811
+ statement = f"({ statement } )"
812
+ target .append ((statement , optional ))
813
+
786
814
for child in q .children :
787
815
if isinstance (child , QBase ):
788
- q_childs = self ._parse_q_filters (ident , child , source_class )
789
- if child .connector == Q .OR :
790
- q_childs = "(" + q_childs + ")"
791
- target .append (q_childs )
816
+ q_childs , q_opt_childs = self ._parse_q_filters (
817
+ ident , child , source_class
818
+ )
819
+ add_to_target (q_childs , child .connector , False )
820
+ add_to_target (q_opt_childs , child .connector , True )
792
821
else :
793
822
kwargs = {child [0 ]: child [1 ]}
794
823
filters = process_filter_args (source_class , kwargs )
795
824
self ._build_filter_statements (ident , filters , target , source_class )
796
- ret = f" { q .connector } " .join (target )
797
- if q .negated :
825
+ match_filters = [filter [0 ] for filter in target if not filter [1 ]]
826
+ opt_match_filters = [filter [0 ] for filter in target if filter [1 ]]
827
+ if q .connector == Q .OR and match_filters and opt_match_filters :
828
+ raise ValueError (
829
+ "Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements"
830
+ )
831
+ ret = f" { q .connector } " .join (match_filters )
832
+ if ret and q .negated :
798
833
ret = f"NOT ({ ret } )"
799
- return ret
834
+ opt_ret = f" { q .connector } " .join (opt_match_filters )
835
+ if opt_ret and q .negated :
836
+ opt_ret = f"NOT ({ opt_ret } )"
837
+ return ret , opt_ret
800
838
801
839
def build_where_stmt (
802
840
self ,
@@ -806,12 +844,18 @@ def build_where_stmt(
806
844
q_filters : Union [QBase , Any , None ] = None ,
807
845
) -> None :
808
846
"""
809
- construct a where statement from some filters
847
+ Construct a where statement from some filters.
848
+
849
+ We make a difference between filters applied to variables coming from MATCH and
850
+ OPTIONAL MATCH statements.
851
+
810
852
"""
811
853
if q_filters is not None :
812
- stmt = self ._parse_q_filters (ident , q_filters , source_class )
854
+ stmt , opt_stmt = self ._parse_q_filters (ident , q_filters , source_class )
813
855
if stmt :
814
856
self ._ast .where .append (stmt )
857
+ if opt_stmt :
858
+ self ._ast .optional_where .append (opt_stmt )
815
859
else :
816
860
stmts = []
817
861
for row in filters :
@@ -839,7 +883,7 @@ def build_where_stmt(
839
883
840
884
def lookup_query_variable (
841
885
self , path : str , return_relation : bool = False
842
- ) -> TOptional [Tuple [str , Any ]]:
886
+ ) -> TOptional [Tuple [str , Any , bool ]]:
843
887
"""Retrieve the variable name generated internally for the given traversal path."""
844
888
subgraph = self ._ast .subgraph
845
889
if not subgraph :
@@ -849,10 +893,19 @@ def lookup_query_variable(
849
893
raise ValueError ("Can only lookup traversal variables" )
850
894
if traversals [0 ] not in subgraph :
851
895
return None
896
+
897
+ # Check if relation is coming from an optional MATCH
898
+ # (declared using fetch|traverse_relations)
899
+ is_optional_relation = False
900
+ for relation in self .node_set .relations_to_fetch :
901
+ if relation ["path" ] == path :
902
+ is_optional_relation = relation .get ("optional" , False )
903
+ break
904
+
852
905
subgraph = subgraph [traversals [0 ]]
853
906
if len (traversals ) == 1 :
854
907
variable_to_return = f"{ subgraph ['rel_variable_name' if return_relation else 'variable_name' ]} "
855
- return variable_to_return , subgraph ["target" ]
908
+ return variable_to_return , subgraph ["target" ], is_optional_relation
856
909
variable_to_return = ""
857
910
last_property = traversals [- 1 ]
858
911
for part in traversals [1 :]:
@@ -864,7 +917,7 @@ def lookup_query_variable(
864
917
# if last part of prop is the last traversal
865
918
# we are safe to lookup the variable from the query
866
919
variable_to_return = f"{ subgraph ['rel_variable_name' if return_relation else 'variable_name' ]} "
867
- return variable_to_return , subgraph ["target" ]
920
+ return variable_to_return , subgraph ["target" ], is_optional_relation
868
921
869
922
def build_query (self ) -> str :
870
923
query : str = ""
@@ -881,16 +934,19 @@ def build_query(self) -> str:
881
934
query += " MATCH "
882
935
query += " MATCH " .join (i for i in self ._ast .match )
883
936
937
+ if self ._ast .where :
938
+ query += " WHERE "
939
+ query += " AND " .join (self ._ast .where )
940
+
884
941
if self ._ast .optional_match :
885
942
query += " OPTIONAL MATCH "
886
943
query += " OPTIONAL MATCH " .join (i for i in self ._ast .optional_match )
887
944
888
- if self ._ast .where :
889
- if self ._ast .optional_match :
890
- # Make sure filtering works as expected with optional match, even if it's not performant...
891
- query += " WITH *"
945
+ if self ._ast .optional_where :
946
+ # Make sure filtering works as expected with optional match, even if it's not performant...
947
+ query += " WITH *"
892
948
query += " WHERE "
893
- query += " AND " .join (self ._ast .where )
949
+ query += " AND " .join (self ._ast .optional_where )
894
950
895
951
if self ._ast .with_clause :
896
952
query += " WITH "
0 commit comments