@@ -633,5 +633,95 @@ def pass_reference(v):
633633 self .assertTrue (np .array_equal (rvecf , np .array ([1.0 , 4.0 ])))
634634
635635
636+ class NumbaDeclareInferred (unittest .TestCase ):
637+ """
638+ Test decorator created with a reconstructed list of arguments using RDF column types,
639+ and a return type inferred from the numba jitted function.
640+ """
641+
642+ def test_fund_types (self ):
643+ """
644+ Test fundamental types
645+ """
646+ df = ROOT .RDataFrame (4 ).Define ("x" , "rdfentry_" )
647+
648+ with self .subTest ("function" ):
649+ def is_even (x ):
650+ return x % 2 == 0
651+ df = df .Define ("is_even_x_1" , is_even , ["x" ])
652+ results = df .Take ["bool" ]("is_even_x_1" ).GetValue ()[0 ]
653+ self .assertEqual (results , True )
654+
655+ with self .subTest ("lambda" ):
656+ df = df .Define ("is_even_x_2" , lambda x : x % 2 == 0 , ["x" ])
657+ results = df .Take ["bool" ]("is_even_x_2" ).GetValue ()[0 ]
658+ self .assertEqual (results , True )
659+
660+ def test_rvec (self ):
661+ """
662+ Test RVec
663+ """
664+ df = ROOT .RDataFrame (4 ).Define ("x" , "ROOT::VecOps::RVec<int>({1, 2, 3})" )
665+
666+ with self .subTest ("function" ):
667+ def square_rvec (v ):
668+ return v * v
669+ df = df .Define ("square_rvec_1" , square_rvec , ["x" ])
670+ results = df .Take ["RVec<int>" ]("square_rvec_1" ).GetValue ()[0 ]
671+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
672+
673+ with self .subTest ("lambda" ):
674+ df = df .Define ("square_rvec_2" , lambda v : v * v , ["x" ])
675+ results = df .Take ["RVec<int>" ]("square_rvec_2" ).GetValue ()[0 ]
676+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
677+
678+ def test_std_vec (self ):
679+ """
680+ Test std::vector
681+ """
682+ df = ROOT .RDataFrame (4 ).Define ("x" , "std::vector<int>({1, 2, 3})" )
683+
684+ with self .subTest ("function" ):
685+ def square_std_vec (v ):
686+ return v * v
687+ df = df .Define ("square_std_vec_1" , square_std_vec , ["x" ])
688+ results = df .Take ["RVec<int>" ]("square_std_vec_1" ).GetValue ()[0 ]
689+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
690+
691+ with self .subTest ("lambda" ):
692+ df = df .Define ("square_std_vec_2" , lambda v : v * v , ["x" ])
693+ results = df .Take ["RVec<int>" ]("square_std_vec_2" ).GetValue ()[0 ]
694+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
695+
696+ def test_std_array (self ):
697+ """
698+ Test std::array
699+ """
700+ df = ROOT .RDataFrame (4 ).Define ("x" , "std::array<int, 3>({1, 2, 3})" )
701+
702+ with self .subTest ("function" ):
703+ def square_std_arr (v ):
704+ return v * v
705+ df = df .Define ("square_std_arr_1" , square_std_arr , ["x" ])
706+ results = df .Take ["RVec<int>" ]("square_std_arr_1" ).GetValue ()[0 ]
707+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
708+
709+ with self .subTest ("lambda" ):
710+ df = df .Define ("square_std_arr_2" , lambda v : v * v , ["x" ])
711+ results = df .Take ["RVec<int>" ]("square_std_arr_2" ).GetValue ()[0 ]
712+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
713+
714+ def test_missing_signature_raises (self ):
715+ """
716+ Ensure an Exception is raised when return type cannot be inferred
717+ and no explicit signature is provided in the decorator.
718+ """
719+ def f (x ):
720+ return x .M ()
721+
722+ with self .assertRaises (Exception ):
723+ ROOT .RDataFrame (4 ).Define ("v" , "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)" ).Define ("m" , f , ["v" ])
724+
725+
636726if __name__ == "__main__" :
637727 unittest .main ()
0 commit comments