@@ -921,6 +921,7 @@ def test_run_mpierr():
921921 pyuvsim .run_uvdata_uvsim (UVData (), ['beamlist' ])
922922
923923
924+ @pytest .mark .parallel (2 )
924925@pytest .mark .parametrize ("order" , [("bda" ,), ("baseline" , "time" ), ("ant2" , "time" )])
925926def test_ordering (uvdata_two_redundant_bls_triangle_sources , order ):
926927 pytest .importorskip ('mpi4py' )
@@ -934,41 +935,62 @@ def test_ordering(uvdata_two_redundant_bls_triangle_sources, order):
934935 beam_dict = beam_dict ,
935936 catalog = sky_model ,
936937 )
937- assert out_uv .blt_order == order
938- assert out_uv .blt_order == uvdata_linear .blt_order
938+ rank = pyuvsim .mpi .get_rank ()
939+ if rank == 0 :
940+ print (rank , out_uv )
941+ assert out_uv .blt_order == order
942+ assert out_uv .blt_order == uvdata_linear .blt_order
939943
940- uvdata_linear .data_array = out_uv .data_array
944+ uvdata_linear .data_array = out_uv .data_array
941945
942- uvdata_linear .reorder_blts (order = "time" , minor_order = "baseline" , conj_convention = "ant1<ant2" )
946+ uvdata_linear .reorder_blts (
947+ order = "time" , minor_order = "baseline" , conj_convention = "ant1<ant2"
948+ )
943949
944- assert np .allclose (
945- uvdata_linear .get_data ((0 , 1 )), uvdata_linear .get_data ((1 , 2 ))
946- )
947- assert not np .allclose (
948- uvdata_linear .get_data ((0 , 1 )), uvdata_linear .get_data ((0 , 2 ))
949- )
950+ assert np .allclose (
951+ uvdata_linear .get_data ((0 , 1 )), uvdata_linear .get_data ((1 , 2 ))
952+ )
953+ assert not np .allclose (
954+ uvdata_linear .get_data ((0 , 1 )), uvdata_linear .get_data ((0 , 2 ))
955+ )
950956
951957
958+ @pytest .mark .parallel (2 )
952959@pytest .mark .parametrize ("order" , [("bda" ,), ("baseline" , "time" ), ("ant2" , "time" )])
953960def test_order_warning (uvdata_two_redundant_bls_triangle_sources , order ):
954961 pytest .importorskip ('mpi4py' )
962+ # need to get the mpi initialized
963+ # now that simulations require at least 2 PUs
964+ pyuvsim .mpi .start_mpi ()
965+ rank = pyuvsim .mpi .get_rank ()
955966 uvdata_linear , beam_list , beam_dict , sky_model = uvdata_two_redundant_bls_triangle_sources
956967
957968 uvdata_linear .reorder_blts (* order )
958969 # delete the order like we forgot to set it
959970 uvdata_linear .blt_order = None
960- with uvtest .check_warnings (
961- UserWarning , match = "The parameter `blt_order` could not be identified."
962- ):
971+ if rank == 0 :
972+ with uvtest .check_warnings (
973+ UserWarning , match = "The parameter `blt_order` could not be identified."
974+ ):
975+
976+ out_uv = pyuvsim .uvsim .run_uvdata_uvsim (
977+ input_uv = uvdata_linear .copy (),
978+ beam_list = beam_list ,
979+ beam_dict = beam_dict ,
980+ catalog = sky_model ,
981+ )
982+
983+ assert out_uv .blt_order == ("time" , "baseline" )
984+ else :
963985 out_uv = pyuvsim .uvsim .run_uvdata_uvsim (
964986 input_uv = uvdata_linear .copy (),
965987 beam_list = beam_list ,
966988 beam_dict = beam_dict ,
967989 catalog = sky_model ,
968990 )
969- assert out_uv .blt_order == ("time" , "baseline" )
970991
971992
993+ @pytest .mark .parallel (2 )
972994def test_nblts_not_square (uvdata_two_redundant_bls_triangle_sources ):
973995 pytest .importorskip ('mpi4py' )
974996 uvdata_linear , beam_list , beam_dict , sky_model = uvdata_two_redundant_bls_triangle_sources
@@ -980,7 +1002,7 @@ def test_nblts_not_square(uvdata_two_redundant_bls_triangle_sources):
9801002 indices = np .nonzero (
9811003 uvdata_linear .baseline_array == uvdata_linear .antnums_to_baseline (0 , 2 )
9821004 )[0 ]
983- print ( uvdata_linear . antnums_to_baseline ( 0 , 2 ), indices )
1005+
9841006 # discard half of them
9851007 indices = indices [::2 ]
9861008 blt_inds = np .delete (np .arange (uvdata_linear .Nblts ), indices )
@@ -994,12 +1016,13 @@ def test_nblts_not_square(uvdata_two_redundant_bls_triangle_sources):
9941016 beam_dict = beam_dict ,
9951017 catalog = sky_model ,
9961018 )
997-
998- assert np .allclose (
999- out_uv .get_data ((0 , 1 )), out_uv .get_data ((1 , 2 ))
1000- )
1001- # make sure (0, 2) has fewer times
1002- assert out_uv .get_data ((0 , 2 )).shape == (out_uv .Ntimes // 2 , out_uv .Nfreqs , out_uv .Npols )
1019+ rank = pyuvsim .mpi .get_rank ()
1020+ if rank == 0 :
1021+ assert np .allclose (
1022+ out_uv .get_data ((0 , 1 )), out_uv .get_data ((1 , 2 ))
1023+ )
1024+ # make sure (0, 2) has fewer times
1025+ assert out_uv .get_data ((0 , 2 )).shape == (out_uv .Ntimes // 2 , out_uv .Nfreqs , out_uv .Npols )
10031026
10041027
10051028def test_tqdm_import_error ():
0 commit comments