From e9052cf1d034c98f4f81a84960940cdc3d287beb Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 15 Jul 2025 15:43:51 -0700 Subject: [PATCH 01/16] add tests that breaks due to the bug --- tests/uvdata/test_uvdata.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index abf7fe291..1a1bebfa6 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -3474,6 +3474,7 @@ def test_add_times(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_bls(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) ant_list = list(range(15)) # Roughly half the antennas in the data # All blts where ant_1 is in list @@ -4170,6 +4171,7 @@ def test_fast_concat_times(casa_uvfits): @pytest.mark.parametrize("in_order", [True, False]) def test_fast_concat_bls(casa_uvfits, in_order): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) if in_order: # divide in half to keep in order From bdb3b40ff4202db7b6b139fc47eb6e2411ccd32a Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 15 Jul 2025 16:37:56 -0700 Subject: [PATCH 02/16] start work on making add and fast_concat use forms --- src/pyuvdata/uvbase.py | 41 ++++++++ src/pyuvdata/uvdata/uvdata.py | 181 ++++++++++------------------------ tests/uvdata/test_uvdata.py | 6 ++ 3 files changed, 101 insertions(+), 127 deletions(-) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index 7bb3d9a81..2684dbd98 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -754,6 +754,47 @@ def copy(self): """ return copy.deepcopy(self) + def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): + """ + Get a mapping of parameters that have a given axis to the axis number. + + Parameters + ---------- + axis_name : str + A named parameter within the object (e.g., "Nblts", "Ntimes", "Nants"). 
+ single_named_axis : bool + Option to only include parameters with a single named axis. + + Returns + ------- + dict + The keys are UVParameter names that have an axis with axis_name + (axis_name appears in their form). The values are a list of the axis + indices where axis_name appears in their form. + """ + ret_dict = {} + for param in self: + # For each attribute, if the value is None, then bail, otherwise + # attempt to figure out along which axis ind_arr will apply. + + attr = getattr(self, param) + if ( + attr.value is not None + and isinstance(attr.form, tuple) + and axis_name in attr.form + ): + if ( + single_named_axis + and sum([isinstance(entry, str) for entry in attr.form]) > 1 + ): + continue + + # Only look at where form is a tuple, since that's the only case we + # can have a dynamically defined shape. Note that index doesn't work + # here in the case of a repeated param_name in the form. + ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0] + return ret_dict + def _select_along_param_axis(self, param_dict: dict): """ Downselect values along a parameterized axis. 
diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index b55c1e1ce..741909338 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -40,6 +40,43 @@ ) +def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): + update_params = this._get_param_axis(axis_name, single_named_axis=True) + other_form_dict = {axis_name: other_inds} + for param, axis_list in update_params.items(): + axis = axis_list[0] + new_array = np.concatenate( + [ + getattr(this, param), + getattr(other, "_" + param).get_from_form(other_form_dict), + ], + axis=axis, + ) + if param == "scan_number_array": + print() + print(new_array) + if final_order is not None: + new_array = np.take(new_array, final_order, axis=axis) + if param == "scan_number_array": + print(new_array) + + setattr(this, param, new_array) + + +def _axis_fast_concat_helper(this, other, axis_name: str): + update_params = this._get_param_axis(axis_name) + for param, axis_list in update_params.items(): + axis = axis_list[0] + setattr( + this, + param, + np.concatenate( + [getattr(this, param)] + [getattr(obj, param) for obj in other], + axis=axis, + ), + ) + + class UVData(UVBase): """ A class for defining a radio interferometer dataset. @@ -5568,6 +5605,12 @@ def __add__( "_uvw_array", ] compatibility_params.extend(extra_params) + # TODO: make this list programmatically if possible? 
+ if ( + this.scan_number_array is not None + or other.scan_number_array is not None + ): + compatibility_params.append("_scan_number_array") # find the freq indices in "other" but not in "this" if (this.flex_spw_polarization_array is None) != ( @@ -5645,6 +5688,7 @@ def __add__( "_phase_center_app_dec", "_phase_center_frame_pa", "_phase_center_id_array", + "_scan_number_array", ] for cp in compatibility_params: if cp in blt_inds_params: @@ -5731,6 +5775,9 @@ def __add__( if len(bnew_inds) > 0: this_blts = np.concatenate((this_blts, new_blts)) blt_order = np.argsort(this_blts) + + _axis_add_helper(this, other, "Nblts", bnew_inds, blt_order) + if not self.metadata_only: zero_pad = np.zeros((len(bnew_inds), this.Nfreqs, this.Npols)) this.data_array = np.concatenate([this.data_array, zero_pad], axis=0) @@ -5740,53 +5787,11 @@ def __add__( this.flag_array = np.concatenate( [this.flag_array, 1 - zero_pad], axis=0 ).astype(np.bool_) - this.uvw_array = np.concatenate( - [this.uvw_array, other.uvw_array[bnew_inds, :]], axis=0 - )[blt_order, :] - this.time_array = np.concatenate( - [this.time_array, other.time_array[bnew_inds]] - )[blt_order] - this.integration_time = np.concatenate( - [this.integration_time, other.integration_time[bnew_inds]] - )[blt_order] - this.lst_array = np.concatenate( - [this.lst_array, other.lst_array[bnew_inds]] - )[blt_order] - this.ant_1_array = np.concatenate( - [this.ant_1_array, other.ant_1_array[bnew_inds]] - )[blt_order] - this.ant_2_array = np.concatenate( - [this.ant_2_array, other.ant_2_array[bnew_inds]] - )[blt_order] - this.baseline_array = np.concatenate( - [this.baseline_array, other.baseline_array[bnew_inds]] - )[blt_order] - this.phase_center_app_ra = np.concatenate( - [this.phase_center_app_ra, other.phase_center_app_ra[bnew_inds]] - )[blt_order] - this.phase_center_app_dec = np.concatenate( - [this.phase_center_app_dec, other.phase_center_app_dec[bnew_inds]] - )[blt_order] - this.phase_center_frame_pa = np.concatenate( - 
[this.phase_center_frame_pa, other.phase_center_frame_pa[bnew_inds]] - )[blt_order] - this.phase_center_id_array = np.concatenate( - [this.phase_center_id_array, other.phase_center_id_array[bnew_inds]] - )[blt_order] f_order = None if len(fnew_inds) > 0: - this.freq_array = np.concatenate( - [this.freq_array, other.freq_array[fnew_inds]] - ) - this.channel_width = np.concatenate( - [this.channel_width, other.channel_width[fnew_inds]] - ) + _axis_add_helper(this, other, "Nfreqs", fnew_inds) - this.flex_spw_id_array = np.concatenate( - [this.flex_spw_id_array, other.flex_spw_id_array[fnew_inds]] - ) - this.spw_array = np.concatenate([this.spw_array, other.spw_array]) # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( @@ -5833,9 +5838,8 @@ def __add__( p_order = None if len(pnew_inds) > 0: - this.polarization_array = np.concatenate( - [this.polarization_array, other.polarization_array[pnew_inds]] - ) + _axis_add_helper(this, other, "Npols", pnew_inds) + p_order = np.argsort(np.abs(this.polarization_array)) if not self.metadata_only: zero_pad = np.zeros( @@ -6217,18 +6221,8 @@ def fast_concat( if axis == "freq": this.Nfreqs = sum([this.Nfreqs] + [obj.Nfreqs for obj in other]) - this.freq_array = np.concatenate( - [this.freq_array] + [obj.freq_array for obj in other] - ) - this.channel_width = np.concatenate( - [this.channel_width] + [obj.channel_width for obj in other] - ) - this.flex_spw_id_array = np.concatenate( - [this.flex_spw_id_array] + [obj.flex_spw_id_array for obj in other] - ) - this.spw_array = np.concatenate( - [this.spw_array] + [obj.spw_array for obj in other] - ) + _axis_fast_concat_helper(this, other, "Nfreqs") + _axis_fast_concat_helper(this, other, "Nspws") # We want to preserve per-spw information based on first appearance # in the concatenated array. 
unique_index = np.sort( @@ -6238,83 +6232,16 @@ def fast_concat( this.Nspws = len(this.spw_array) - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=1 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=1 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=1 - ) elif axis == "polarization": - this.polarization_array = np.concatenate( - [this.polarization_array] + [obj.polarization_array for obj in other] - ) + _axis_fast_concat_helper(this, other, "Npols") this.Npols = sum([this.Npols] + [obj.Npols for obj in other]) - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=2 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=2 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=2 - ) elif axis == "blt": this.Nblts = sum([this.Nblts] + [obj.Nblts for obj in other]) - this.ant_1_array = np.concatenate( - [this.ant_1_array] + [obj.ant_1_array for obj in other] - ) - this.ant_2_array = np.concatenate( - [this.ant_2_array] + [obj.ant_2_array for obj in other] - ) + _axis_fast_concat_helper(this, other, "Nblts") this.Nants_data = this._calc_nants_data() - this.uvw_array = np.concatenate( - [this.uvw_array] + [obj.uvw_array for obj in other], axis=0 - ) - this.time_array = np.concatenate( - [this.time_array] + [obj.time_array for obj in other] - ) this.Ntimes = len(np.unique(this.time_array)) - this.lst_array = np.concatenate( - [this.lst_array] + [obj.lst_array for obj in other] - ) - this.baseline_array = np.concatenate( - [this.baseline_array] + [obj.baseline_array for obj in other] - ) this.Nbls = len(np.unique(this.baseline_array)) - this.integration_time = np.concatenate( - [this.integration_time] + 
[obj.integration_time for obj in other] - ) - this.phase_center_app_ra = np.concatenate( - [this.phase_center_app_ra] + [obj.phase_center_app_ra for obj in other] - ) - this.phase_center_app_dec = np.concatenate( - [this.phase_center_app_dec] - + [obj.phase_center_app_dec for obj in other] - ) - this.phase_center_frame_pa = np.concatenate( - [this.phase_center_frame_pa] - + [obj.phase_center_frame_pa for obj in other] - ) - this.phase_center_id_array = np.concatenate( - [this.phase_center_id_array] - + [obj.phase_center_id_array for obj in other] - ) - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=0 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=0 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=0 - ) # update filename attribute for obj in other: diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index 1a1bebfa6..eeb12971d 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -3385,6 +3385,7 @@ def test_sum_vis_errors(hera_uvh5, attr_to_get, attr_to_set, arg_dict, msg): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_freq(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv1 = uv_full.select(freq_chans=np.arange(0, 32), inplace=False) uv2 = uv_full.select(freq_chans=np.arange(32, 64), inplace=False) @@ -3415,6 +3416,7 @@ def test_add_freq(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_pols(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv1 = uv_full.select(polarizations=uv_full.polarization_array[0:2], inplace=False) uv2 = uv_full.select(polarizations=uv_full.polarization_array[2:4], inplace=False) @@ -3453,6 +3455,7 @@ def 
test_add_pols(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_times(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) times = np.unique(uv_full.time_array) uv1 = uv_full.select(times=times[0 : len(times) // 2], inplace=False) @@ -3474,6 +3477,7 @@ def test_add_times(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_bls(casa_uvfits): uv_full = casa_uvfits + uv_full.reorder_blts() uv_full.scan_number_array = np.arange(uv_full.Nblts) ant_list = list(range(15)) # Roughly half the antennas in the data @@ -3515,6 +3519,7 @@ def test_add_bls(casa_uvfits): uv3.ant_1_array = uv3.ant_1_array[-1::-1] uv3.ant_2_array = uv3.ant_2_array[-1::-1] uv3.baseline_array = uv3.baseline_array[-1::-1] + uv3.scan_number_array = uv3.scan_number_array[-1::-1] uv1 += uv3 uv1 += uv2 assert utils.history._check_histories( @@ -3533,6 +3538,7 @@ def test_add_bls(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_multi_axis(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv_ref = uv_full.copy() times = np.unique(uv_full.time_array) From 1f9213121716d36ae57594bbb8864ab29438c7fb Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Mon, 18 Aug 2025 14:50:26 -0700 Subject: [PATCH 03/16] finish up add and fast concat work using forms --- src/pyuvdata/uvdata/uvdata.py | 615 ++++++++++++++++------------------ tests/uvdata/test_uvdata.py | 8 +- 2 files changed, 293 insertions(+), 330 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 741909338..0afed4a18 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -52,13 +52,8 @@ def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): ], axis=axis, ) - if param == "scan_number_array": - print() - 
print(new_array) if final_order is not None: new_array = np.take(new_array, final_order, axis=axis) - if param == "scan_number_array": - print(new_array) setattr(this, param, new_array) @@ -5487,75 +5482,124 @@ def __add__( # Define parameters that must be the same to add objects compatibility_params = ["_vis_units"] - # Build up history string - history_update_string = " Combined data along " - n_axes = 0 - - # Create blt arrays for convenience - prec_t = -2 * np.floor(np.log10(this._time_array.tols[-1])).astype(int) - prec_b = 8 - this_blts = np.array( - [ - "_".join( - ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] + # identify params that are not explicitly included in overlap calc per axis + axes = ["Nblts", "Nfreqs", "Npols"] + axis_params_check = {} + axis_overlap_params = { + "Nblts": ["time_array", "baseline_array"], + "Nfreqs": ["freq_array"], + "Npols": ["polarization_array"], + } + axis_dict = {} + for axis in axes: + axis_dict[axis] = this._get_param_axis(axis) + axis_params_check[axis] = [] + for param in axis_dict[axis]: + if ( + param not in this._data_params + and param not in axis_overlap_params[axis] + ): + axis_params_check[axis].append("_" + param) + + # build this/other arrays for checking for overlap. More complicated for + # the blt axis because we need a combo of time and baseline. 
+ axis_vals = {} + for axis, overlap_params in axis_overlap_params.items(): + if axis == "Nblts": + # Create combined arrays for convenience + prec_t = -2 * np.floor(np.log10(this._time_array.tols[-1])).astype(int) + prec_b = 8 + this_blts = np.array( + [ + "_".join( + [ + "{1:.{0}f}".format(prec_t, blt[0]), + str(blt[1]).zfill(prec_b), + ] + ) + for blt in zip( + this.time_array, this.baseline_array, strict=True + ) + ] ) - for blt in zip(this.time_array, this.baseline_array, strict=True) - ] - ) - other_blts = np.array( - [ - "_".join( - ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] + other_blts = np.array( + [ + "_".join( + [ + "{1:.{0}f}".format(prec_t, blt[0]), + str(blt[1]).zfill(prec_b), + ] + ) + for blt in zip( + other.time_array, other.baseline_array, strict=True + ) + ] ) - for blt in zip(other.time_array, other.baseline_array, strict=True) - ] - ) + axis_vals[axis] = {"this": this_blts, "other": other_blts} + else: + axis_vals[axis] = { + "this": getattr(this, overlap_params[0]), + "other": getattr(other, overlap_params[0]), + } # Check we don't have overlapping data - both_pol, this_pol_ind, other_pol_ind = np.intersect1d( - this.polarization_array, other.polarization_array, return_indices=True - ) - - # If we have a flexible spectral window, the handling here becomes a bit funky, - # because we are allowed to have channels with the same frequency *if* they - # belong to different spectral windows (one real-life example: you might want - # to preserve guard bands in the correlator, which can have overlaping RF - # frequency channels) - this_freq_ind = np.array([], dtype=np.int64) - other_freq_ind = np.array([], dtype=np.int64) - both_freq = np.array([], dtype=float) - both_spw = np.intersect1d(this.spw_array, other.spw_array) - for idx in both_spw: - this_mask = np.where(this.flex_spw_id_array == idx)[0] - other_mask = np.where(other.flex_spw_id_array == idx)[0] - both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d( - 
this.freq_array[this_mask], - other.freq_array[other_mask], - return_indices=True, - ) - this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind]) - other_freq_ind = np.append(other_freq_ind, other_mask[other_spw_ind]) - both_freq = np.append(both_freq, both_spw_freq) + axis_inds = {} + for axis, val_arr in axis_vals.items(): + if axis == "Nfreqs": + # This is more complicated because we are allowed to have channels + # with the same frequency *if* they belong to different spectral + # windows (one real-life example: you might want to preserve + # guard bands in the correlator, which can have overlaping RF + # frequency channels) + this_freq_ind = np.array([], dtype=np.int64) + other_freq_ind = np.array([], dtype=np.int64) + both_freq_ind = np.array([], dtype=float) + both_spw = np.intersect1d(this.spw_array, other.spw_array) + for idx in both_spw: + this_mask = np.where(this.flex_spw_id_array == idx)[0] + other_mask = np.where(other.flex_spw_id_array == idx)[0] + both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d( + this.freq_array[this_mask], + other.freq_array[other_mask], + return_indices=True, + ) + this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind]) + other_freq_ind = np.append( + other_freq_ind, other_mask[other_spw_ind] + ) + both_freq_ind = np.append(both_freq_ind, both_spw_freq) + axis_inds["Nfreqs"] = { + "this": this_freq_ind, + "other": other_freq_ind, + "both": both_freq_ind, + } + else: + both_inds, this_inds, other_inds = np.intersect1d( + val_arr["this"], val_arr["other"], return_indices=True + ) + axis_inds[axis] = { + "this": this_inds, + "other": other_inds, + "both": both_inds, + } - both_blts, this_blts_ind, other_blts_ind = np.intersect1d( - this_blts, other_blts, return_indices=True - ) - if not self.metadata_only and ( - len(both_pol) > 0 and len(both_freq) > 0 and len(both_blts) > 0 + history_update_string = "" + if not self.metadata_only and np.all( + [len(axis_inds[axis]["both"]) > 0 for axis in axis_inds] 
): # check that overlapping data is not valid this_inds = np.ravel_multi_index( ( - this_blts_ind[:, np.newaxis, np.newaxis], - this_freq_ind[np.newaxis, :, np.newaxis], - this_pol_ind[np.newaxis, np.newaxis, :], + axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis], + axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis], + axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :], ), this.data_array.shape, ).flatten() other_inds = np.ravel_multi_index( ( - other_blts_ind[:, np.newaxis, np.newaxis], - other_freq_ind[np.newaxis, :, np.newaxis], - other_pol_ind[np.newaxis, np.newaxis, :], + axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis], + axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis], + axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :], ), other.data_array.shape, ).flatten() @@ -5567,7 +5611,6 @@ def __add__( if this_all_zero and this_all_flag: # we're fine to overwrite; update history accordingly history_update_string = " Overwrote invalid data using pyuvdata." - this.history += history_update_string elif other_all_zero and other_all_flag: raise ValueError( "To combine these data, please run the add operation again, " @@ -5579,145 +5622,99 @@ def __add__( "These objects have overlapping data and cannot be combined." 
) - # find the blt indices in "other" but not in "this" - temp = np.nonzero(~np.isin(other_blts, this_blts))[0] - if len(temp) > 0: - bnew_inds = temp - new_blts = other_blts[temp] - history_update_string += "baseline-time" - n_axes += 1 - else: - bnew_inds, new_blts = ([], []) - - # if there's any overlap in blts, check extra params - temp = np.nonzero(np.isin(other_blts, this_blts))[0] - if len(temp) > 0: - # add metadata to be checked to compatibility params - extra_params = [ - "_integration_time", - "_lst_array", - "_phase_center_catalog", - "_phase_center_id_array", - "_phase_center_app_ra", - "_phase_center_app_dec", - "_phase_center_frame_pa", - "_Nphase", - "_uvw_array", - ] - compatibility_params.extend(extra_params) - # TODO: make this list programmatically if possible? - if ( - this.scan_number_array is not None - or other.scan_number_array is not None - ): - compatibility_params.append("_scan_number_array") - - # find the freq indices in "other" but not in "this" - if (this.flex_spw_polarization_array is None) != ( - other.flex_spw_polarization_array is None - ): - raise ValueError( - "Cannot add a flex-pol and non-flex-pol UVData objects. Use " - "the `remove_flex_pol` method to convert the objects to " - "have a regular polarization axis." - ) - elif this.flex_spw_polarization_array is not None: - this_flexpol_dict = dict( - zip(this.spw_array, this.flex_spw_polarization_array, strict=True) - ) - other_flexpol_dict = dict( - zip(other.spw_array, other.flex_spw_polarization_array, strict=True) - ) - for key in other_flexpol_dict: - try: - if this_flexpol_dict[key] != other_flexpol_dict[key]: - raise ValueError( - "Cannot add a flex-pol UVData objects where the same " - "spectral window contains different polarizations. Use " - "the `remove_flex_pol` method to convert the objects " - "to have a regular polarization axis." 
- ) - except KeyError: - this_flexpol_dict[key] = other_flexpol_dict[key] - - other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool) - for idx in np.intersect1d(this.spw_array, other.spw_array): - other_mask[other.flex_spw_id_array == idx] = np.isin( - other.freq_array[other.flex_spw_id_array == idx], - this.freq_array[this.flex_spw_id_array == idx], - invert=True, - ) - temp = np.where(other_mask)[0] - if len(temp) > 0: - fnew_inds = temp - if n_axes > 0: - history_update_string += ", frequency" + new_inds = {} + additions = [] + axis_descriptions = { + "Nblts": "baseline-time", + "Nfreqs": "frequency", + "Npols": "polarization", + } + # find the indices in "other" but not in "this" + for axis in axes: + if axis != "Nfreqs": + temp = np.nonzero( + ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"]) + )[0] else: - history_update_string += "frequency" - n_axes += 1 - else: - fnew_inds = [] - - # if channel width is an array and there's any overlap in freqs, - # check extra params - temp = np.nonzero(np.isin(other.freq_array, this.freq_array))[0] - if len(temp) > 0: - # add metadata to be checked to compatibility params - extra_params = ["_channel_width"] - compatibility_params.extend(extra_params) - - # find the pol indices in "other" but not in "this" - temp = np.nonzero(~np.isin(other.polarization_array, this.polarization_array))[ - 0 - ] - if len(temp) > 0: - pnew_inds = temp - if n_axes > 0: - history_update_string += ", polarization" + # more complicated because of spws + if (this.flex_spw_polarization_array is None) != ( + other.flex_spw_polarization_array is None + ): + raise ValueError( + "Cannot add a flex-pol and non-flex-pol UVData objects. Use " + "the `remove_flex_pol` method to convert the objects to " + "have a regular polarization axis." 
+ ) + elif this.flex_spw_polarization_array is not None: + this_flexpol_dict = dict( + zip( + this.spw_array, + this.flex_spw_polarization_array, + strict=True, + ) + ) + other_flexpol_dict = dict( + zip( + other.spw_array, + other.flex_spw_polarization_array, + strict=True, + ) + ) + for key in other_flexpol_dict: + try: + if this_flexpol_dict[key] != other_flexpol_dict[key]: + raise ValueError( + "Cannot add a flex-pol UVData objects where " + "the same spectral window contains different " + "polarizations. Use the `remove_flex_pol` " + "method to convert the objects to have a " + "regular polarization axis." + ) + except KeyError: + this_flexpol_dict[key] = other_flexpol_dict[key] + + other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool) + for idx in np.intersect1d(this.spw_array, other.spw_array): + other_mask[other.flex_spw_id_array == idx] = np.isin( + other.freq_array[other.flex_spw_id_array == idx], + this.freq_array[this.flex_spw_id_array == idx], + invert=True, + ) + temp = np.where(other_mask)[0] + if len(temp) > 0: + new_inds[axis] = temp + # add params associated with the other axes to compatibility_params + for axis2 in axes: + if axis2 != axis: + compatibility_params.extend(axis_params_check[axis2]) + if axis == "Nblts": + new_blts = other_blts[temp] + additions.append(axis_descriptions[axis]) else: - history_update_string += "polarization" - n_axes += 1 - else: - pnew_inds = [] + new_inds[axis] = [] + if axis == "Nblts": + new_blts = ([], []) # Actually check compatibility parameters - blt_inds_params = [ - "_integration_time", - "_lst_array", - "_phase_center_app_ra", - "_phase_center_app_dec", - "_phase_center_frame_pa", - "_phase_center_id_array", - "_scan_number_array", - ] for cp in compatibility_params: - if cp in blt_inds_params: - # only check that overlapping blt indices match - this_param = getattr(this, cp) - other_param = getattr(other, cp) - params_match = np.allclose( - this_param.value[this_blts_ind], - 
other_param.value[other_blts_ind], - rtol=this_param.tols[0], - atol=this_param.tols[1], - ) - elif cp == "_uvw_array": - # only check that overlapping blt indices match - params_match = np.allclose( - this.uvw_array[this_blts_ind, :], - other.uvw_array[other_blts_ind, :], - rtol=this._uvw_array.tols[0], - atol=this._uvw_array.tols[1], - ) - elif cp == "_channel_width": - # only check that overlapping freq indices match - params_match = np.allclose( - this.channel_width[this_freq_ind], - other.channel_width[other_freq_ind], - rtol=this._channel_width.tols[0], - atol=this._channel_width.tols[1], - ) - else: + params_match = None + for axis, check_list in axis_params_check.items(): + if cp in check_list: + # only check that overlapping blt indices match + this_param = getattr(this, cp) + this_param_overlap = this_param.get_from_form( + {axis: axis_inds[axis]["this"]} + ) + other_param_overlap = getattr(other, cp).get_from_form( + {axis: axis_inds[axis]["other"]} + ) + params_match = np.allclose( + this_param_overlap, + other_param_overlap, + rtol=this_param.tols[0], + atol=this_param.tols[1], + ) + if params_match is None: params_match = getattr(this, cp) == getattr(other, cp) if not params_match: msg = ( @@ -5736,62 +5733,53 @@ def __add__( # Next, we want to make sure that the ordering of the _overlapping_ data is # the same, so that things can get plugged together in a sensible way. 
- if len(this_blts_ind) != 0: - this_argsort = np.argsort(this_blts_ind) - other_argsort = np.argsort(other_blts_ind) - - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Nblts) - temp_ind[this_blts_ind[this_argsort]] = temp_ind[ - this_blts_ind[other_argsort] - ] - - this.reorder_blts(order=temp_ind) - - if len(this_freq_ind) != 0: - this_argsort = np.argsort(this_freq_ind) - other_argsort = np.argsort(other_freq_ind) - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Nfreqs) - temp_ind[this_freq_ind[this_argsort]] = temp_ind[ - this_freq_ind[other_argsort] - ] - - this.reorder_freqs(channel_order=temp_ind) - - if len(this_pol_ind) != 0: - this_argsort = np.argsort(this_pol_ind) - other_argsort = np.argsort(other_pol_ind) - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Npols) - temp_ind[this_pol_ind[this_argsort]] = temp_ind[ - this_pol_ind[other_argsort] - ] + reorder_method = { + "Nblts": {"method": "reorder_blts", "parameter": "order"}, + "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"}, + "Npols": {"method": "reorder_pols", "parameter": "order"}, + } + order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} + for axis, ind_dict in axis_inds.items(): + if len(ind_dict["this"]) != 0: + this_argsort = np.argsort(ind_dict["this"]) + other_argsort = np.argsort(ind_dict["other"]) + + if np.any(this_argsort != other_argsort): + temp_ind = np.arange(getattr(this, axis)) + temp_ind[ind_dict["this"][this_argsort]] = temp_ind[ + ind_dict["this"][other_argsort] + ] + kwargs = {reorder_method[axis]["parameter"]: temp_ind} - this.reorder_pols(temp_ind) + getattr(this, reorder_method[axis]["method"])(**kwargs) # Pad out self to accommodate new data - blt_order = None - if len(bnew_inds) > 0: - this_blts = np.concatenate((this_blts, new_blts)) - blt_order = np.argsort(this_blts) + for axis_ind, axis in enumerate(axes): + if len(new_inds[axis]) > 0: + if axis == "Nblts": + this_blts = 
np.concatenate((this_blts, new_blts)) + order_dict["Nblts"] = np.argsort(this_blts) + order_use = order_dict["Nblts"] + else: + order_use = None - _axis_add_helper(this, other, "Nblts", bnew_inds, blt_order) + _axis_add_helper(this, other, axis, new_inds[axis], order_use) - if not self.metadata_only: - zero_pad = np.zeros((len(bnew_inds), this.Nfreqs, this.Npols)) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=0) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=0 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=0 - ).astype(np.bool_) - - f_order = None - if len(fnew_inds) > 0: - _axis_add_helper(this, other, "Nfreqs", fnew_inds) + if not self.metadata_only: + pad_shape = list(this.data_array.shape) + pad_shape[axis_ind] = len(new_inds[axis]) + zero_pad = np.zeros(tuple(pad_shape)) + this.data_array = np.concatenate( + [this.data_array, zero_pad], axis=axis_ind + ) + this.nsample_array = np.concatenate( + [this.nsample_array, zero_pad], axis=axis_ind + ) + this.flag_array = np.concatenate( + [this.flag_array, 1 - zero_pad], axis=axis_ind + ).astype(np.bool_) + if len(new_inds["Nfreqs"]) > 0: # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( @@ -5805,7 +5793,7 @@ def __add__( [this_flexpol_dict[key] for key in this.spw_array] ) # Need to sort out the order of the individual windows first. - f_order = np.concatenate( + order_dict["Nfreqs"] = np.concatenate( [ np.where(this.flex_spw_id_array == idx)[0] for idx in sorted(this.spw_array) @@ -5816,42 +5804,18 @@ def __add__( # windows need sorting. If they are ordered in ascending or descending # fashion, leave them be. 
If not, sort in ascending order for idx in this.spw_array: - select_mask = this.flex_spw_id_array[f_order] == idx - check_freqs = this.freq_array[f_order[select_mask]] + select_mask = this.flex_spw_id_array[order_dict["Nfreqs"]] == idx + check_freqs = this.freq_array[order_dict["Nfreqs"][select_mask]] if (not np.all(check_freqs[1:] > check_freqs[:-1])) and ( not np.all(check_freqs[1:] < check_freqs[:-1]) ): - subsort_order = f_order[select_mask] - f_order[select_mask] = subsort_order[np.argsort(check_freqs)] - - if not self.metadata_only: - zero_pad = np.zeros( - (this.data_array.shape[0], len(fnew_inds), this.Npols) - ) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=1) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=1 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=1 - ).astype(np.bool_) - - p_order = None - if len(pnew_inds) > 0: - _axis_add_helper(this, other, "Npols", pnew_inds) + subsort_order = order_dict["Nfreqs"][select_mask] + order_dict["Nfreqs"][select_mask] = subsort_order[ + np.argsort(check_freqs) + ] - p_order = np.argsort(np.abs(this.polarization_array)) - if not self.metadata_only: - zero_pad = np.zeros( - (this.data_array.shape[0], this.data_array.shape[1], len(pnew_inds)) - ) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=2) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=2 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=2 - ).astype(np.bool_) + if len(new_inds["Npols"]) > 0: + order_dict["Npols"] = np.argsort(np.abs(this.polarization_array)) # Now populate the data pol_t2o = np.nonzero( @@ -5875,29 +5839,28 @@ def __add__( this.flag_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.flag_array # Fix ordering - axis_dict = { - 0: {"inds": bnew_inds, "order": blt_order}, - 1: {"inds": fnew_inds, "order": f_order}, - 2: {"inds": pnew_inds, "order": p_order}, - } - for axis, subdict in 
axis_dict.items(): + for axis_ind, axis in enumerate(axes): for name, param in zip( this._data_params, this.data_like_parameters, strict=True ): - if len(subdict["inds"]) > 0: - unique_order_diffs = np.unique(np.diff(subdict["order"])) + if len(new_inds[axis]) > 0: + unique_order_diffs = np.unique(np.diff(order_dict[axis])) if np.array_equal(unique_order_diffs, np.array([1])): # everything is already in order continue - setattr(this, name, np.take(param, subdict["order"], axis=axis)) - - if len(fnew_inds) > 0: - this.freq_array = this.freq_array[f_order] - this.channel_width = this.channel_width[f_order] - this.flex_spw_id_array = this.flex_spw_id_array[f_order] + setattr( + this, name, np.take(param, order_dict[axis], axis=axis_ind) + ) - if len(pnew_inds) > 0: - this.polarization_array = this.polarization_array[p_order] + # reorder freq, pol axes but not blt axis because that was already done. + for axis in axes[1:]: + params_to_update = axis_params_check[axis] + [ + "_" + param for param in axis_overlap_params[axis] + ] + if len(new_inds[axis]) > 0: + for param in params_to_update: + this_param = getattr(this, param) + this_param.value = this_param.value[order_dict[axis]] # Update N parameters (e.g. Npols) this.Ntimes = len(np.unique(this.time_array)) @@ -5912,9 +5875,14 @@ def __add__( if this.filename is not None: this._filename.form = (len(this.filename),) - if n_axes > 0: - history_update_string += " axis using pyuvdata." + if len(additions) > 0: + # Build up history string + history_update_string += ( + " Combined data along " + ", ".join(additions) + " axis using pyuvdata." + ) + if len(history_update_string) > 0: + # this can be true even if len(additions)=0 b/c of filling in invalid data. 
histories_match = utils.history._check_histories( this.history, other.history ) @@ -6148,39 +6116,28 @@ def fast_concat( history_update_string = " Combined data along " - if axis == "freq": - history_update_string += "frequency" - compatibility_params += [ - "_polarization_array", - "_ant_1_array", - "_ant_2_array", - "_integration_time", - "_uvw_array", - "_lst_array", - "_phase_center_id_array", - ] - elif axis == "polarization": - history_update_string += "polarization" - compatibility_params += [ - "_freq_array", - "_channel_width", - "_flex_spw_id_array", - "_ant_1_array", - "_ant_2_array", - "_integration_time", - "_uvw_array", - "_lst_array", - "_phase_center_id_array", - ] - elif axis == "blt": - history_update_string += "baseline-time" - compatibility_params += [ - "_freq_array", - "_polarization_array", - "_flex_spw_id_array", - ] + # identify params that are not explicitly included in overlap calc per axis + axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"} + axis_params_check = {} + axis_dict = {} + for axis2, ax_shape in axis_shape.items(): + axis_dict[axis2] = this._get_param_axis(ax_shape) + axis_params_check[axis2] = [] + for param in axis_dict[axis2]: + if param not in this._data_params: + axis_params_check[axis2].append("_" + param) + + for axis2 in axis_shape: + if axis2 != axis: + compatibility_params.extend(axis_params_check[axis2]) + + axis_descriptions = { + "blt": "baseline-time", + "freq": "frequency", + "polarization": "polarization", + } - history_update_string += " axis using pyuvdata." + history_update_string += f" {axis_descriptions[axis]} axis using pyuvdata." 
histories_match = [] for obj in other: diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index eeb12971d..56fd7b4c5 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -4012,7 +4012,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): [ [], [["unproject_phase", {}], ["select", {"freq_chans": np.arange(32, 64)}]], - "UVParameter phase_center_catalog does not match. Cannot combine objects.", + "UVParameter phase_center_app_dec does not match. Cannot combine objects.", ], [ [["vis_units", "Jy"]], @@ -4024,6 +4024,11 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): [["select", {"freq_chans": np.arange(32, 64)}]], "UVParameter integration_time does not match.", ], + [ + [["scan_number_array", np.arange(1360)]], + [["select", {"freq_chans": np.arange(32, 64)}]], + "UVParameter scan_number_array does not match.", + ], ], ) def test_break_add(casa_uvfits, attr_to_set, attr_to_get, msg): @@ -4033,6 +4038,7 @@ def test_break_add(casa_uvfits, attr_to_set, attr_to_get, msg): """ # Test failure modes of add function uv1 = casa_uvfits + uv1._set_scan_numbers() uv2 = uv1.copy() uv1.select(freq_chans=np.arange(0, 32)) From 82a380f0b0bbb1e17285492f36450deb17274705 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 10:36:55 -0700 Subject: [PATCH 04/16] refactor with more convenience methods --- src/pyuvdata/uvdata/uvdata.py | 82 +++++++++++++---------------------- 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 0afed4a18..f2b2c423d 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5395,6 +5395,23 @@ def fix_phase(self, *, use_ant_pos=True): use_ant_pos=False, ) + def blt_str_arr(self): + """Create a string array with baseline and time info for matching purposes.""" + prec_t = -2 * np.floor(np.log10(self._time_array.tols[-1])).astype(int) + prec_b = 8 + return 
np.array( + [ + "_".join( + ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] + ) + for blt in zip(self.time_array, self.baseline_array, strict=True) + ] + ) + + def flexpol_dict(self): + """Create a dict with flexpol information for comparison.""" + return dict(zip(self.spw_array, self.flex_spw_polarization_array, strict=True)) + def __add__( self, other, @@ -5507,34 +5524,8 @@ def __add__( for axis, overlap_params in axis_overlap_params.items(): if axis == "Nblts": # Create combined arrays for convenience - prec_t = -2 * np.floor(np.log10(this._time_array.tols[-1])).astype(int) - prec_b = 8 - this_blts = np.array( - [ - "_".join( - [ - "{1:.{0}f}".format(prec_t, blt[0]), - str(blt[1]).zfill(prec_b), - ] - ) - for blt in zip( - this.time_array, this.baseline_array, strict=True - ) - ] - ) - other_blts = np.array( - [ - "_".join( - [ - "{1:.{0}f}".format(prec_t, blt[0]), - str(blt[1]).zfill(prec_b), - ] - ) - for blt in zip( - other.time_array, other.baseline_array, strict=True - ) - ] - ) + this_blts = this.blt_str_arr() + other_blts = other.blt_str_arr() axis_vals[axis] = {"this": this_blts, "other": other_blts} else: axis_vals[axis] = { @@ -5646,20 +5637,8 @@ def __add__( "have a regular polarization axis." 
) elif this.flex_spw_polarization_array is not None: - this_flexpol_dict = dict( - zip( - this.spw_array, - this.flex_spw_polarization_array, - strict=True, - ) - ) - other_flexpol_dict = dict( - zip( - other.spw_array, - other.flex_spw_polarization_array, - strict=True, - ) - ) + this_flexpol_dict = this.flexpol_dict() + other_flexpol_dict = other.flexpol_dict() for key in other_flexpol_dict: try: if this_flexpol_dict[key] != other_flexpol_dict[key]: @@ -6176,26 +6155,23 @@ def fast_concat( this.telescope = tel_obj + # update the relevant shape parameter + _axis_fast_concat_helper(this, other, axis_shape[axis]) + new_shape = sum( + [getattr(this, axis_shape[axis])] + + [getattr(obj, axis_shape[axis]) for obj in other] + ) + setattr(this, axis_shape[axis], new_shape) + if axis == "freq": - this.Nfreqs = sum([this.Nfreqs] + [obj.Nfreqs for obj in other]) - _axis_fast_concat_helper(this, other, "Nfreqs") - _axis_fast_concat_helper(this, other, "Nspws") # We want to preserve per-spw information based on first appearance # in the concatenated array. 
unique_index = np.sort( np.unique(this.flex_spw_id_array, return_index=True)[1] ) this.spw_array = this.flex_spw_id_array[unique_index] - this.Nspws = len(this.spw_array) - - elif axis == "polarization": - _axis_fast_concat_helper(this, other, "Npols") - this.Npols = sum([this.Npols] + [obj.Npols for obj in other]) - elif axis == "blt": - this.Nblts = sum([this.Nblts] + [obj.Nblts for obj in other]) - _axis_fast_concat_helper(this, other, "Nblts") this.Nants_data = this._calc_nants_data() this.Ntimes = len(np.unique(this.time_array)) this.Nbls = len(np.unique(this.baseline_array)) From c4cd2f16ed042a0294c9fd73fb079f058b5768fa Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 14:35:37 -0700 Subject: [PATCH 05/16] handle spw, freq in the same way as time, bl --- src/pyuvdata/uvdata/uvdata.py | 207 +++++++++++++--------------------- tests/uvdata/test_mir.py | 4 +- 2 files changed, 81 insertions(+), 130 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index f2b2c423d..ad047d8fd 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -40,6 +40,19 @@ ) +def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True): + """Create a string array built from float and integer arrays for matching.""" + prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) + prec_int = 8 + flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] + int_str_list = [str(intv).zfill(prec_int) for intv in intarr] + if flt_first: + zipped = zip(flt_str_list, int_str_list, strict=True) + else: + zipped = zip(int_str_list, flt_str_list, strict=True) + return np.array(["_".join(zpval) for zpval in zipped]) + + def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): update_params = this._get_param_axis(axis_name, single_named_axis=True) other_form_dict = {axis_name: other_inds} @@ -5397,15 +5410,20 @@ def fix_phase(self, *, use_ant_pos=True): def blt_str_arr(self): """Create a 
string array with baseline and time info for matching purposes.""" - prec_t = -2 * np.floor(np.log10(self._time_array.tols[-1])).astype(int) - prec_b = 8 - return np.array( - [ - "_".join( - ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] - ) - for blt in zip(self.time_array, self.baseline_array, strict=True) - ] + return flt_ind_str_arr( + fltarr=self.time_array, + intarr=self.baseline_array, + flt_tols=self._time_array.tols, + flt_first=True, + ) + + def spw_freq_str_arr(self): + """Create a string array with spw and freq info for matching purposes.""" + return flt_ind_str_arr( + fltarr=self.freq_array, + intarr=self.flex_spw_id_array, + flt_tols=self._freq_array.tols, + flt_first=False, ) def flexpol_dict(self): @@ -5504,9 +5522,10 @@ def __add__( axis_params_check = {} axis_overlap_params = { "Nblts": ["time_array", "baseline_array"], - "Nfreqs": ["freq_array"], + "Nfreqs": ["freq_array", "flex_spw_id_array"], "Npols": ["polarization_array"], } + axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"} axis_dict = {} for axis in axes: axis_dict[axis] = this._get_param_axis(axis) @@ -5522,56 +5541,28 @@ def __add__( # the blt axis because we need a combo of time and baseline. 
axis_vals = {} for axis, overlap_params in axis_overlap_params.items(): - if axis == "Nblts": - # Create combined arrays for convenience - this_blts = this.blt_str_arr() - other_blts = other.blt_str_arr() - axis_vals[axis] = {"this": this_blts, "other": other_blts} + if len(overlap_params) > 1: + axis_vals[axis] = { + "this": getattr(this, axis_combined_func[axis])(), + "other": getattr(other, axis_combined_func[axis])(), + } else: axis_vals[axis] = { "this": getattr(this, overlap_params[0]), "other": getattr(other, overlap_params[0]), } + # Check we don't have overlapping data axis_inds = {} for axis, val_arr in axis_vals.items(): - if axis == "Nfreqs": - # This is more complicated because we are allowed to have channels - # with the same frequency *if* they belong to different spectral - # windows (one real-life example: you might want to preserve - # guard bands in the correlator, which can have overlaping RF - # frequency channels) - this_freq_ind = np.array([], dtype=np.int64) - other_freq_ind = np.array([], dtype=np.int64) - both_freq_ind = np.array([], dtype=float) - both_spw = np.intersect1d(this.spw_array, other.spw_array) - for idx in both_spw: - this_mask = np.where(this.flex_spw_id_array == idx)[0] - other_mask = np.where(other.flex_spw_id_array == idx)[0] - both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d( - this.freq_array[this_mask], - other.freq_array[other_mask], - return_indices=True, - ) - this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind]) - other_freq_ind = np.append( - other_freq_ind, other_mask[other_spw_ind] - ) - both_freq_ind = np.append(both_freq_ind, both_spw_freq) - axis_inds["Nfreqs"] = { - "this": this_freq_ind, - "other": other_freq_ind, - "both": both_freq_ind, - } - else: - both_inds, this_inds, other_inds = np.intersect1d( - val_arr["this"], val_arr["other"], return_indices=True - ) - axis_inds[axis] = { - "this": this_inds, - "other": other_inds, - "both": both_inds, - } + both_inds, this_inds, other_inds 
= np.intersect1d( + val_arr["this"], val_arr["other"], return_indices=True + ) + axis_inds[axis] = { + "this": this_inds, + "other": other_inds, + "both": both_inds, + } history_update_string = "" if not self.metadata_only and np.all( @@ -5622,12 +5613,11 @@ def __add__( } # find the indices in "other" but not in "this" for axis in axes: - if axis != "Nfreqs": - temp = np.nonzero( - ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"]) - )[0] - else: - # more complicated because of spws + if axis == "Nfreqs" and ( + this.flex_spw_polarization_array is not None + or other.flex_spw_polarization_array is not None + ): + # special checking for flex_spw if (this.flex_spw_polarization_array is None) != ( other.flex_spw_polarization_array is None ): @@ -5652,27 +5642,18 @@ def __add__( except KeyError: this_flexpol_dict[key] = other_flexpol_dict[key] - other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool) - for idx in np.intersect1d(this.spw_array, other.spw_array): - other_mask[other.flex_spw_id_array == idx] = np.isin( - other.freq_array[other.flex_spw_id_array == idx], - this.freq_array[this.flex_spw_id_array == idx], - invert=True, - ) - temp = np.where(other_mask)[0] + temp = np.nonzero( + ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"]) + )[0] if len(temp) > 0: new_inds[axis] = temp # add params associated with the other axes to compatibility_params for axis2 in axes: if axis2 != axis: compatibility_params.extend(axis_params_check[axis2]) - if axis == "Nblts": - new_blts = other_blts[temp] additions.append(axis_descriptions[axis]) else: new_inds[axis] = [] - if axis == "Nblts": - new_blts = ([], []) # Actually check compatibility parameters for cp in compatibility_params: @@ -5720,6 +5701,7 @@ def __add__( order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} for axis, ind_dict in axis_inds.items(): if len(ind_dict["this"]) != 0: + # there is some overlap, so sorting matters this_argsort = np.argsort(ind_dict["this"]) other_argsort = 
np.argsort(ind_dict["other"]) @@ -5733,16 +5715,18 @@ def __add__( getattr(this, reorder_method[axis]["method"])(**kwargs) # Pad out self to accommodate new data + new_axis_inds = {} for axis_ind, axis in enumerate(axes): if len(new_inds[axis]) > 0: - if axis == "Nblts": - this_blts = np.concatenate((this_blts, new_blts)) - order_dict["Nblts"] = np.argsort(this_blts) - order_use = order_dict["Nblts"] + new_axis_inds[axis] = np.concatenate( + (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]]) + ) + if axis == "Npols": + order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis])) else: - order_use = None + order_dict[axis] = np.argsort(new_axis_inds[axis]) - _axis_add_helper(this, other, axis, new_inds[axis], order_use) + _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis]) if not self.metadata_only: pad_shape = list(this.data_array.shape) @@ -5757,6 +5741,8 @@ def __add__( this.flag_array = np.concatenate( [this.flag_array, 1 - zero_pad], axis=axis_ind ).astype(np.bool_) + else: + new_axis_inds[axis] = axis_vals[axis]["this"] if len(new_inds["Nfreqs"]) > 0: # We want to preserve per-spw information based on first appearance @@ -5771,51 +5757,24 @@ def __add__( this.flex_spw_polarization_array = np.array( [this_flexpol_dict[key] for key in this.spw_array] ) - # Need to sort out the order of the individual windows first. - order_dict["Nfreqs"] = np.concatenate( - [ - np.where(this.flex_spw_id_array == idx)[0] - for idx in sorted(this.spw_array) - ] - ) - - # With spectral windows sorted, check and see if channels within - # windows need sorting. If they are ordered in ascending or descending - # fashion, leave them be. 
If not, sort in ascending order - for idx in this.spw_array: - select_mask = this.flex_spw_id_array[order_dict["Nfreqs"]] == idx - check_freqs = this.freq_array[order_dict["Nfreqs"][select_mask]] - if (not np.all(check_freqs[1:] > check_freqs[:-1])) and ( - not np.all(check_freqs[1:] < check_freqs[:-1]) - ): - subsort_order = order_dict["Nfreqs"][select_mask] - order_dict["Nfreqs"][select_mask] = subsort_order[ - np.argsort(check_freqs) - ] - - if len(new_inds["Npols"]) > 0: - order_dict["Npols"] = np.argsort(np.abs(this.polarization_array)) # Now populate the data - pol_t2o = np.nonzero( - np.isin(this.polarization_array, other.polarization_array) - )[0] - this_freqs = this.freq_array - other_freqs = other.freq_array - - freq_t2o = np.zeros(this_freqs.shape, dtype=bool) - for spw_id in set(this.spw_array).intersection(other.spw_array): - mask = this.flex_spw_id_array == spw_id - freq_t2o[mask] |= np.isin( - this_freqs[mask], other_freqs[other.flex_spw_id_array == spw_id] - ) - freq_t2o = np.nonzero(freq_t2o)[0] - blt_t2o = np.nonzero(np.isin(this_blts, other_blts))[0] + t2o_dict = {} + for axis, inds_dict in axis_vals.items(): + t2o_dict[axis] = np.nonzero( + np.isin(new_axis_inds[axis], inds_dict["other"]) + )[0] if not self.metadata_only: - this.data_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.data_array - this.nsample_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.nsample_array - this.flag_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.flag_array + this.data_array[ + np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) + ] = other.data_array + this.nsample_array[ + np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) + ] = other.nsample_array + this.flag_array[ + np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) + ] = other.flag_array # Fix ordering for axis_ind, axis in enumerate(axes): @@ -5831,16 +5790,6 @@ def __add__( this, name, np.take(param, order_dict[axis], axis=axis_ind) ) - # reorder freq, pol axes 
but not blt axis because that was already done. - for axis in axes[1:]: - params_to_update = axis_params_check[axis] + [ - "_" + param for param in axis_overlap_params[axis] - ] - if len(new_inds[axis]) > 0: - for param in params_to_update: - this_param = getattr(this, param) - this_param.value = this_param.value[order_dict[axis]] - # Update N parameters (e.g. Npols) this.Ntimes = len(np.unique(this.time_array)) this.Nbls = len(np.unique(this.baseline_array)) @@ -5881,8 +5830,8 @@ def __add__( ) # Reset blt_order if blt axis was added to and it is set - if len(blt_t2o) > 0: - this.blt_order = None + if len(t2o_dict["Nblts"]) > 0: + this.blt_order = ("time", "baseline") this.set_rectangularity(force=True) diff --git a/tests/uvdata/test_mir.py b/tests/uvdata/test_mir.py index 30ce53d1d..db1e3f6b8 100644 --- a/tests/uvdata/test_mir.py +++ b/tests/uvdata/test_mir.py @@ -592,9 +592,11 @@ def test_flex_pol_add(sma_mir_filt): sma_yy_copy._make_flex_pol() # Add the two back together here, and make sure we can the same value out, - # modulo the history. + # modulo the history and sorting. sma_check = sma_yy_copy + sma_xx_copy + sma_mir_filt.reorder_freqs(channel_order="freq") + assert sma_check.history != sma_mir_filt.history sma_check.history = sma_mir_filt.history = None From 34b2cefe2c02fd94e6a52d61ae352d60fe2933b6 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 14:47:19 -0700 Subject: [PATCH 06/16] update code comments --- src/pyuvdata/uvdata/uvdata.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index ad047d8fd..f7b92c3f6 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5537,8 +5537,9 @@ def __add__( ): axis_params_check[axis].append("_" + param) - # build this/other arrays for checking for overlap. More complicated for - # the blt axis because we need a combo of time and baseline. 
+ # build this/other arrays for checking for overlap. + # Use a combined string if there are multiple arrays defining overlap + # (e.g. baseline-time, spw-freq) axis_vals = {} for axis, overlap_params in axis_overlap_params.items(): if len(overlap_params) > 1: @@ -5660,7 +5661,7 @@ def __add__( params_match = None for axis, check_list in axis_params_check.items(): if cp in check_list: - # only check that overlapping blt indices match + # only check that overlapping indices match this_param = getattr(this, cp) this_param_overlap = this_param.get_from_form( {axis: axis_inds[axis]["this"]} @@ -5829,7 +5830,7 @@ def __add__( + extra_history ) - # Reset blt_order if blt axis was added to and it is set + # Reset blt_order if blt axis was added to if len(t2o_dict["Nblts"]) > 0: this.blt_order = ("time", "baseline") From aab95b8890c6cb3619cd29a5fa467b8a8073f82b Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 15:55:03 -0700 Subject: [PATCH 07/16] fix handling for spws with descending freqs --- src/pyuvdata/uvdata/uvdata.py | 44 ++++++++++++++++++++++++++++++++--- tests/uvdata/test_mir.py | 4 +--- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index f7b92c3f6..c92d30f75 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -46,13 +46,34 @@ def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True): prec_int = 8 flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] int_str_list = [str(intv).zfill(prec_int) for intv in intarr] + list_of_lists = [] if flt_first: - zipped = zip(flt_str_list, int_str_list, strict=True) + list_of_lists = [flt_str_list, int_str_list] else: - zipped = zip(int_str_list, flt_str_list, strict=True) + list_of_lists = [int_str_list, flt_str_list] + zipped = zip(*list_of_lists, strict=True) return np.array(["_".join(zpval) for zpval in zipped]) +def _get_freq_order(spw_id, freq_arr): + spws = np.unique(spw_id) 
+ f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) + + # With spectral windows sorted, check and see if channels within + # windows need sorting. If they are ordered in ascending or descending + # fashion, leave them be. If not, sort in ascending order + for spw in spws: + select_mask = spw_id[f_order] == spw + check_freqs = freq_arr[f_order[select_mask]] + if not np.all(np.diff(check_freqs) > 0) and not np.all( + np.diff(check_freqs) < 0 + ): + subsort_order = f_order[select_mask] + f_order[select_mask] = subsort_order[np.argsort(check_freqs)] + + return f_order + + def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): update_params = this._get_param_axis(axis_name, single_named_axis=True) other_form_dict = {axis_name: other_inds} @@ -5699,7 +5720,6 @@ def __add__( "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"}, "Npols": {"method": "reorder_pols", "parameter": "order"}, } - order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} for axis, ind_dict in axis_inds.items(): if len(ind_dict["this"]) != 0: # there is some overlap, so sorting matters @@ -5717,6 +5737,7 @@ def __add__( # Pad out self to accommodate new data new_axis_inds = {} + order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} for axis_ind, axis in enumerate(axes): if len(new_inds[axis]) > 0: new_axis_inds[axis] = np.concatenate( @@ -5724,6 +5745,23 @@ def __add__( ) if axis == "Npols": order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis])) + elif axis == "Nfreqs" and ( + np.any(np.diff(this.freq_array) < 0) + or np.any(np.diff(other.freq_array) < 0) + ): + # deal with the possibility of spws with channels in + # descending order. 
+ order_dict[axis] = _get_freq_order( + np.concatenate( + ( + this.flex_spw_id_array, + other.flex_spw_id_array[new_inds[axis]], + ) + ), + np.concatenate( + (this.freq_array, other.freq_array[new_inds[axis]]) + ), + ) else: order_dict[axis] = np.argsort(new_axis_inds[axis]) diff --git a/tests/uvdata/test_mir.py b/tests/uvdata/test_mir.py index db1e3f6b8..30ce53d1d 100644 --- a/tests/uvdata/test_mir.py +++ b/tests/uvdata/test_mir.py @@ -592,11 +592,9 @@ def test_flex_pol_add(sma_mir_filt): sma_yy_copy._make_flex_pol() # Add the two back together here, and make sure we can the same value out, - # modulo the history and sorting. + # modulo the history. sma_check = sma_yy_copy + sma_xx_copy - sma_mir_filt.reorder_freqs(channel_order="freq") - assert sma_check.history != sma_mir_filt.history sma_check.history = sma_mir_filt.history = None From 0aab0a27589d7d87dce42634026bdfe8408ab593 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 23 Sep 2025 15:18:05 -0700 Subject: [PATCH 08/16] docstrings and annotations for convenience methods --- src/pyuvdata/uvdata/uvdata.py | 96 +++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 9 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index c92d30f75..8d6701a5b 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -27,6 +27,7 @@ from ..utils import phasing as phs_utils from ..utils.io import hdf5 as hdf5_utils from ..utils.phasing import _get_focus_xyz, _get_nearfield_delay +from ..utils.types import FloatArray, IntArray, StrArray from ..uvbase import UVBase from .initializers import new_uvdata @@ -40,8 +41,33 @@ ) -def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True): - """Create a string array built from float and integer arrays for matching.""" +def flt_ind_str_arr( + *, + fltarr: FloatArray, + intarr: IntArray, + flt_tols: tuple[float, float], + flt_first: bool = True, +) -> StrArray: + """ + Create a string array built from 
float and integer arrays for matching. + + Parameters + ---------- + fltarr : np.ndarray of float + float array to be used in output string array + intarr : np.ndarray of int + integer array to be used in output string array + flt_tols : 2-tuple of float + Tolerances (relative, absolute) to use in formatting the floats as strings. + flt_first : bool + Whether to put the float first in the output string or not (if False + the int comes first.) + + Returns + ------- + np.ndarray of str + String array that combines the float and integer values, useful for matching. + """ prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) prec_int = 8 flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] @@ -55,7 +81,25 @@ def flt_ind_str_arr(*, fltarr, intarr, flt_tols, flt_first=True): return np.array(["_".join(zpval) for zpval in zipped]) -def _get_freq_order(spw_id, freq_arr): +def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: + """ + Get the sorting order for the frequency axis after an add. + + Sort first by spw then by channel, but don't reorder channels if they are + changing monotonically (all ascending or descending). + + Parameters + ---------- + spw_id : np.ndarray of int + SPW id array of combined data to be sorted. + freq_arr : np.ndarray of float + Frequency array of combined data to be sorted. + + Returns + ------- + f_order : np.ndarray of int + index array giving the sort order. + """ spws = np.unique(spw_id) f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) @@ -74,7 +118,29 @@ def _get_freq_order(spw_id, freq_arr): return f_order -def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): +def _axis_add_helper( + this: UVData, + other: UVData, + axis_name: str, + other_inds: IntArray, + final_order: IntArray | None = None, +): + """ + Combine UVData objects along an axis. + + Parameters + ---------- + this : UVData + The left UVData object in the add.
+ other : UVData + The right UVData object in the add. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + other_inds : np.ndarray of int + Indices into the other object along this axis to include. + final_order : np.ndarray of int + Final ordering array giving the sort order after concatenation. + """ update_params = this._get_param_axis(axis_name, single_named_axis=True) other_form_dict = {axis_name: other_inds} for param, axis_list in update_params.items(): @@ -92,7 +158,19 @@ def _axis_add_helper(this, other, axis_name: str, other_inds, final_order=None): setattr(this, param, new_array) -def _axis_fast_concat_helper(this, other, axis_name: str): +def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str): + """ + Concatenate UVData objects along an axis assuming no overlap. + + Parameters + ---------- + this : UVData + The left UVData object in the add. + other : UVData + The right UVData object in the add. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). 
+ """ update_params = this._get_param_axis(axis_name) for param, axis_list in update_params.items(): axis = axis_list[0] @@ -5429,7 +5507,7 @@ def fix_phase(self, *, use_ant_pos=True): use_ant_pos=False, ) - def blt_str_arr(self): + def blt_str_arr(self) -> StrArray: """Create a string array with baseline and time info for matching purposes.""" return flt_ind_str_arr( fltarr=self.time_array, @@ -5438,7 +5516,7 @@ def blt_str_arr(self): flt_first=True, ) - def spw_freq_str_arr(self): + def spw_freq_str_arr(self) -> StrArray: """Create a string array with spw and freq info for matching purposes.""" return flt_ind_str_arr( fltarr=self.freq_array, @@ -5447,7 +5525,7 @@ def spw_freq_str_arr(self): flt_first=False, ) - def flexpol_dict(self): + def flexpol_dict(self) -> dict: """Create a dict with flexpol information for comparison.""" return dict(zip(self.spw_array, self.flex_spw_polarization_array, strict=True)) @@ -5751,7 +5829,7 @@ def __add__( ): # deal with the possibility of spws with channels in # descending order. - order_dict[axis] = _get_freq_order( + order_dict[axis] = _add_freq_order( np.concatenate( ( this.flex_spw_id_array, From 038768c65b1511d7a3ca5ac952d6c2b79b8f33e1 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 18:13:19 -0700 Subject: [PATCH 09/16] handle multidimensional arrays programmatically --- src/pyuvdata/uvbase.py | 17 ++++ src/pyuvdata/uvdata/uvdata.py | 151 +++++++++++++++++++++++----------- 2 files changed, 120 insertions(+), 48 deletions(-) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index 2684dbd98..5f87d8453 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -771,6 +771,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): The keys are UVParameter names that have an axis with axis_name (axis_name appears in their form). The values are a list of the axis indices where axis_name appears in their form. 
+ """ ret_dict = {} for param in self: @@ -795,6 +796,22 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0] return ret_dict + def _get_multi_axis_params(self) -> list[str]: + """Get a list of all multidimensional parameters.""" + ret_list = [] + for param in self: + # For each attribute, if the value is None, then bail, otherwise + # attempt to figure out along which axis ind_arr will apply. + + attr = getattr(self, param) + if ( + attr.value is not None + and isinstance(attr.form, tuple) + and sum([isinstance(entry, str) for entry in attr.form]) > 1 + ): + ret_list.append(attr.name) + return ret_list + def _select_along_param_axis(self, param_dict: dict): """ Downselect values along a parameterized axis. diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 8d6701a5b..39a1c3826 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -67,6 +67,7 @@ def flt_ind_str_arr( ------- np.ndarray of str String array that combines the float and integer values, useful for matching. + """ prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) prec_int = 8 @@ -99,6 +100,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: ------- f_order : np.ndarray of int index array giving the sort order. + """ spws = np.unique(spw_id) f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) @@ -126,7 +128,7 @@ def _axis_add_helper( final_order: IntArray | None = None, ): """ - Combine UVData objects along an axis. + Combine UVParameter objects with a single axis along an axis. Parameters ---------- @@ -140,6 +142,7 @@ def _axis_add_helper( Indices into the other object along this axis to include. final_order : np.ndarray of int Final ordering array giving the sort order after concatenation. 
+ """ update_params = this._get_param_axis(axis_name, single_named_axis=True) other_form_dict = {axis_name: other_inds} @@ -158,9 +161,90 @@ def _axis_add_helper( setattr(this, param, new_array) +def _axis_pad_helper(this: UVData, axis_name: str, add_len: int): + """ + Pad out UVParameter objects with multiple dimensions along an axis. + + Parameters + ---------- + this : UVData + The left UVData object in the add. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + add_len : int + The extra length to be padded on for this axis. + + """ + update_params = this._get_param_axis(axis_name) + multi_axis_params = this._get_multi_axis_params() + for param, axis_list in update_params.items(): + if param not in multi_axis_params: + continue + this_param_shape = getattr(this, param).shape + this_param_type = getattr(this, "_" + param).expected_type + bool_type = this_param_type is bool or bool in this_param_type + pad_shape = list(this_param_shape) + for ax in axis_list: + pad_shape[ax] = add_len + if bool_type: + pad_array = np.ones(tuple(pad_shape), dtype=bool) + else: + pad_array = np.zeros(tuple(pad_shape)) + new_array = np.concatenate([getattr(this, param), pad_array], axis=ax) + if bool_type: + new_array = new_array.astype(np.bool_) + setattr(this, param, new_array) + + +def _fill_multi_helper( + this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict +): + """ + Fill UVParameter objects with multiple dimensions from the right side object. + + Parameters + ---------- + this : UVData + The left UVData object in the add. + other : UVData + The right UVData object in the add. + t2o_dict : dict + dict giving the indices in the left object to be filled from the right + object for each axis (keys are axes, values are index arrays). + sort_axes : list of str + The axes that need to be sorted along. + order_dict : dict + dict giving the final sort indices for each axis (keys are axes, values + are index arrays for sorting). 
+ + """ + multi_axis_params = this._get_multi_axis_params() + for param in multi_axis_params: + form = getattr(this, "_" + param).form + index_list = [] + for axis in form: + index_list.append(t2o_dict[axis]) + new_arr = getattr(this, param) + new_arr[np.ix_(*index_list)] = getattr(other, param) + setattr(this, param, new_arr) + + # Fix ordering + for axis_ind, axis in enumerate(form): + if axis in sort_axes: + unique_order_diffs = np.unique(np.diff(order_dict[axis])) + if np.array_equal(unique_order_diffs, np.array([1])): + # everything is already in order + continue + setattr( + this, + param, + np.take(getattr(this, param), order_dict[axis], axis=axis_ind), + ) + + def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str): """ - Concatenate UVData objects along an axis assuming no overlap. + Concatenate UVParameter objects along an axis assuming no overlap. Parameters ---------- @@ -5816,7 +5900,7 @@ def __add__( # Pad out self to accommodate new data new_axis_inds = {} order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} - for axis_ind, axis in enumerate(axes): + for axis in axes: if len(new_inds[axis]) > 0: new_axis_inds[axis] = np.concatenate( (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]]) @@ -5843,24 +5927,27 @@ def __add__( else: order_dict[axis] = np.argsort(new_axis_inds[axis]) + # first handle parameters with a single axis _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis]) - if not self.metadata_only: - pad_shape = list(this.data_array.shape) - pad_shape[axis_ind] = len(new_inds[axis]) - zero_pad = np.zeros(tuple(pad_shape)) - this.data_array = np.concatenate( - [this.data_array, zero_pad], axis=axis_ind - ) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=axis_ind - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=axis_ind - ).astype(np.bool_) + # then pad out parameters with multiple axes + _axis_pad_helper(this, axis, 
len(new_inds[axis])) else: new_axis_inds[axis] = axis_vals[axis]["this"] + # Now fill in multidimensional arrays + t2o_dict = {} + for axis, inds_dict in axis_vals.items(): + t2o_dict[axis] = np.nonzero( + np.isin(new_axis_inds[axis], inds_dict["other"]) + )[0] + + sort_axes = [] + for axis in axes: + if len(new_inds[axis]) > 0: + sort_axes.append(axis) + _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict) + if len(new_inds["Nfreqs"]) > 0: # We want to preserve per-spw information based on first appearance # in the concatenated array. @@ -5875,38 +5962,6 @@ def __add__( [this_flexpol_dict[key] for key in this.spw_array] ) - # Now populate the data - t2o_dict = {} - for axis, inds_dict in axis_vals.items(): - t2o_dict[axis] = np.nonzero( - np.isin(new_axis_inds[axis], inds_dict["other"]) - )[0] - - if not self.metadata_only: - this.data_array[ - np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) - ] = other.data_array - this.nsample_array[ - np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) - ] = other.nsample_array - this.flag_array[ - np.ix_(t2o_dict["Nblts"], t2o_dict["Nfreqs"], t2o_dict["Npols"]) - ] = other.flag_array - - # Fix ordering - for axis_ind, axis in enumerate(axes): - for name, param in zip( - this._data_params, this.data_like_parameters, strict=True - ): - if len(new_inds[axis]) > 0: - unique_order_diffs = np.unique(np.diff(order_dict[axis])) - if np.array_equal(unique_order_diffs, np.array([1])): - # everything is already in order - continue - setattr( - this, name, np.take(param, order_dict[axis], axis=axis_ind) - ) - # Update N parameters (e.g. 
Npols) this.Ntimes = len(np.unique(this.time_array)) this.Nbls = len(np.unique(this.baseline_array)) From 0101a86aebead622a34ec9cbf2d9ca64ec34bcd2 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 19 Aug 2025 18:31:15 -0700 Subject: [PATCH 10/16] minor fixes --- src/pyuvdata/uvbase.py | 3 +- src/pyuvdata/uvdata/uvdata.py | 80 +++++++++++++++++------------------ 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index 5f87d8453..437035b44 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -801,8 +801,7 @@ def _get_multi_axis_params(self) -> list[str]: ret_list = [] for param in self: # For each attribute, if the value is None, then bail, otherwise - # attempt to figure out along which axis ind_arr will apply. - + # check if it's multidimensional attr = getattr(self, param) if ( attr.value is not None diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 39a1c3826..cd66e77cc 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -78,8 +78,7 @@ def flt_ind_str_arr( list_of_lists = [flt_str_list, int_str_list] else: list_of_lists = [int_str_list, flt_str_list] - zipped = zip(*list_of_lists, strict=True) - return np.array(["_".join(zpval) for zpval in zipped]) + return np.array(["_".join(zpval) for zpval in zip(*list_of_lists, strict=True)]) def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: @@ -87,7 +86,7 @@ def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: Get the sorting order for the frequency axis after an add. Sort first by spw then by channel, but don't reorder channels if they are - changing monotonically (all ascending or descending). + changing monotonically (all ascending or descending) within the spw. 
Parameters ---------- @@ -5697,6 +5696,35 @@ def __add__( strict_uvw_antpos_check=strict_uvw_antpos_check, ) + if ( + this.flex_spw_polarization_array is not None + or other.flex_spw_polarization_array is not None + ): + # special checking for flex_spw + if (this.flex_spw_polarization_array is None) != ( + other.flex_spw_polarization_array is None + ): + raise ValueError( + "Cannot add a flex-pol and non-flex-pol UVData objects. Use " + "the `remove_flex_pol` method to convert the objects to " + "have a regular polarization axis." + ) + elif this.flex_spw_polarization_array is not None: + this_flexpol_dict = this.flexpol_dict() + other_flexpol_dict = other.flexpol_dict() + for key in other_flexpol_dict: + try: + if this_flexpol_dict[key] != other_flexpol_dict[key]: + raise ValueError( + "Cannot add a flex-pol UVData objects where " + "the same spectral window contains different " + "polarizations. Use the `remove_flex_pol` " + "method to convert the objects to have a " + "regular polarization axis." 
+ ) + except KeyError: + this_flexpol_dict[key] = other_flexpol_dict[key] + # Define parameters that must be the same to add objects compatibility_params = ["_vis_units"] @@ -5736,7 +5764,7 @@ def __add__( "other": getattr(other, overlap_params[0]), } - # Check we don't have overlapping data + # Check if we have overlapping data axis_inds = {} for axis, val_arr in axis_vals.items(): both_inds, this_inds, other_inds = np.intersect1d( @@ -5749,10 +5777,11 @@ def __add__( } history_update_string = "" + # TODO do this programmatically for multidimensional parameters if not self.metadata_only and np.all( [len(axis_inds[axis]["both"]) > 0 for axis in axis_inds] ): - # check that overlapping data is not valid + # We have overlaps, check that overlapping data is not valid this_inds = np.ravel_multi_index( ( axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis], @@ -5797,35 +5826,6 @@ def __add__( } # find the indices in "other" but not in "this" for axis in axes: - if axis == "Nfreqs" and ( - this.flex_spw_polarization_array is not None - or other.flex_spw_polarization_array is not None - ): - # special checking for flex_spw - if (this.flex_spw_polarization_array is None) != ( - other.flex_spw_polarization_array is None - ): - raise ValueError( - "Cannot add a flex-pol and non-flex-pol UVData objects. Use " - "the `remove_flex_pol` method to convert the objects to " - "have a regular polarization axis." - ) - elif this.flex_spw_polarization_array is not None: - this_flexpol_dict = this.flexpol_dict() - other_flexpol_dict = other.flexpol_dict() - for key in other_flexpol_dict: - try: - if this_flexpol_dict[key] != other_flexpol_dict[key]: - raise ValueError( - "Cannot add a flex-pol UVData objects where " - "the same spectral window contains different " - "polarizations. Use the `remove_flex_pol` " - "method to convert the objects to have a " - "regular polarization axis." 
- ) - except KeyError: - this_flexpol_dict[key] = other_flexpol_dict[key] - temp = np.nonzero( ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"]) )[0] @@ -5884,7 +5884,7 @@ def __add__( } for axis, ind_dict in axis_inds.items(): if len(ind_dict["this"]) != 0: - # there is some overlap, so sorting matters + # there is some overlap, so check sorting this_argsort = np.argsort(ind_dict["this"]) other_argsort = np.argsort(ind_dict["other"]) @@ -5897,7 +5897,7 @@ def __add__( getattr(this, reorder_method[axis]["method"])(**kwargs) - # Pad out self to accommodate new data + # start updating parameters new_axis_inds = {} order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} for axis in axes: @@ -5935,7 +5935,7 @@ def __add__( else: new_axis_inds[axis] = axis_vals[axis]["this"] - # Now fill in multidimensional arrays + # Now fill in multidimensional parameters t2o_dict = {} for axis, inds_dict in axis_vals.items(): t2o_dict[axis] = np.nonzero( @@ -5965,9 +5965,9 @@ def __add__( # Update N parameters (e.g. 
Npols) this.Ntimes = len(np.unique(this.time_array)) this.Nbls = len(np.unique(this.baseline_array)) - this.Nblts = this.uvw_array.shape[0] + this.Nblts = this.baseline_array.size this.Nfreqs = this.freq_array.size - this.Npols = this.polarization_array.shape[0] + this.Npols = this.polarization_array.size this.Nants_data = this._calc_nants_data() # Update filename parameter @@ -6002,7 +6002,7 @@ def __add__( ) # Reset blt_order if blt axis was added to - if len(t2o_dict["Nblts"]) > 0: + if "Nblts" in sort_axes: this.blt_order = ("time", "baseline") this.set_rectangularity(force=True) From caea2164d5944cf9dca426a64e4219de494aac94 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Wed, 20 Aug 2025 12:17:30 -0700 Subject: [PATCH 11/16] check that overlap is not valid data programmatically --- src/pyuvdata/uvdata/uvdata.py | 67 +++++++++++++++++++++-------------- tests/uvdata/test_uvdata.py | 1 + 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index cd66e77cc..f6ca6b28f 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5777,36 +5777,49 @@ def __add__( } history_update_string = "" - # TODO do this programmatically for multidimensional parameters - if not self.metadata_only and np.all( - [len(axis_inds[axis]["both"]) > 0 for axis in axis_inds] - ): + + if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]): # We have overlaps, check that overlapping data is not valid - this_inds = np.ravel_multi_index( - ( - axis_inds["Nblts"]["this"][:, np.newaxis, np.newaxis], - axis_inds["Nfreqs"]["this"][np.newaxis, :, np.newaxis], - axis_inds["Npols"]["this"][np.newaxis, np.newaxis, :], - ), - this.data_array.shape, - ).flatten() - other_inds = np.ravel_multi_index( - ( - axis_inds["Nblts"]["other"][:, np.newaxis, np.newaxis], - axis_inds["Nfreqs"]["other"][np.newaxis, :, np.newaxis], - axis_inds["Npols"]["other"][np.newaxis, np.newaxis, :], - ), - 
other.data_array.shape, - ).flatten() - this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0) - this_all_flag = np.all(this.flag_array.flatten()[this_inds]) - other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0) - other_all_flag = np.all(other.flag_array.flatten()[other_inds]) - - if this_all_zero and this_all_flag: + multi_axis_params = this._get_multi_axis_params() + this_test = [] + other_test = [] + for param in multi_axis_params: + form = getattr(this, "_" + param).form + this_shape = getattr(this, param).shape + other_shape = getattr(other, param).shape + this_param_type = getattr(this, "_" + param).expected_type + bool_type = this_param_type is bool or bool in this_param_type + + this_index_list = [] + other_index_list = [] + for ax_ind, axis in enumerate(form): + expand_axes = [ax for ax in range(len(form)) if ax != ax_ind] + this_index_list.append( + np.expand_dims(axis_inds[axis]["this"], axis=expand_axes) + ) + other_index_list.append( + np.expand_dims(axis_inds[axis]["other"], axis=expand_axes) + ) + this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten() + + other_inds = np.ravel_multi_index( + other_index_list, other_shape + ).flatten() + + this_arr = getattr(this, param).flatten()[this_inds] + other_arr = getattr(other, param).flatten()[other_inds] + + if bool_type: + this_test.append(np.all(this_arr)) + other_test.append(np.all(other_arr)) + else: + this_test.append(np.all(this_arr == 0)) + other_test.append(np.all(other_arr == 0)) + + if np.all(this_test): # we're fine to overwrite; update history accordingly history_update_string = " Overwrote invalid data using pyuvdata." 
- elif other_all_zero and other_all_flag: + elif np.all(other_test): raise ValueError( "To combine these data, please run the add operation again, " "but with the object whose data is to be overwritten as the " diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index 56fd7b4c5..6ad1c62de 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -3983,6 +3983,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): if np.any(np.logical_and(screen1, screen2)): flag_screen = screen2[screen1] uv1.data_array[:, flag_screen] = 0.0 + uv1.nsample_array[:, flag_screen] = 0.0 uv1.flag_array[:, flag_screen] = True uv_recomb = getattr(uv1, add_method[0])(uv2, **add_method[1]) From 881b2dd8fe12ee1d1828e2d48bc01a7670e7469b Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Fri, 22 Aug 2025 14:15:01 -0700 Subject: [PATCH 12/16] Better comments and variable names --- src/pyuvdata/uvbase.py | 6 +- src/pyuvdata/uvdata/uvdata.py | 170 +++++++++++++++++++--------------- 2 files changed, 99 insertions(+), 77 deletions(-) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index 437035b44..c3c82cade 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -776,7 +776,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): ret_dict = {} for param in self: # For each attribute, if the value is None, then bail, otherwise - # attempt to figure out along which axis ind_arr will apply. + # find the axis number(s) with the named shape. attr = getattr(self, param) if ( @@ -791,8 +791,8 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): continue # Only look at where form is a tuple, since that's the only case we - # can have a dynamically defined shape. Note that index doesn't work - # here in the case of a repeated param_name in the form. + # can have a dynamically defined shape. Handle a repeated + # param_name in the form. 
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0] return ret_dict diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index f6ca6b28f..1505dbde3 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -195,9 +195,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int): setattr(this, param, new_array) -def _fill_multi_helper( - this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict -): +def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict): """ Fill UVParameter objects with multiple dimensions from the right side object. @@ -210,8 +208,6 @@ def _fill_multi_helper( t2o_dict : dict dict giving the indices in the left object to be filled from the right object for each axis (keys are axes, values are index arrays). - sort_axes : list of str - The axes that need to be sorted along. order_dict : dict dict giving the final sort indices for each axis (keys are axes, values are index arrays for sorting). @@ -229,7 +225,7 @@ def _fill_multi_helper( # Fix ordering for axis_ind, axis in enumerate(form): - if axis in sort_axes: + if order_dict[axis] is not None: unique_order_diffs = np.unique(np.diff(order_dict[axis])) if np.array_equal(unique_order_diffs, np.array([1])): # everything is already in order @@ -5728,49 +5724,61 @@ def __add__( # Define parameters that must be the same to add objects compatibility_params = ["_vis_units"] - # identify params that are not explicitly included in overlap calc per axis axes = ["Nblts", "Nfreqs", "Npols"] - axis_params_check = {} - axis_overlap_params = { + # axis_key_params defines which parameters to use as the defining + # parameters along each axis. These are used to identify overlapping data. 
+ axis_key_params = { "Nblts": ["time_array", "baseline_array"], "Nfreqs": ["freq_array", "flex_spw_id_array"], "Npols": ["polarization_array"], } - axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"} - axis_dict = {} - for axis in axes: - axis_dict[axis] = this._get_param_axis(axis) - axis_params_check[axis] = [] - for param in axis_dict[axis]: + # specify a function to form a combined string if there are multiple + # key arrays (e.g. baseline-time, spw-freq) + axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"} + multi_axis_params = this._get_multi_axis_params() + # axis_parameters gives parameters whose form contains each axis + axis_parameters = {} + # axis_check_params gives parameters that should be checked if adding + # along other axes + axis_check_params = {} + # axis_key_arrays gives the arrays to use for checking for overlap per axis + axis_key_arrays = {} + # axis_overlap_inds has the outcomes of np.intersect1d on the + # axis_key_arrays per axis. So it has the both/this inds/other inds + # for any overlaps. + axis_overlap_inds = {} + for axis, overlap_params in axis_key_params.items(): + axis_parameters[axis] = this._get_param_axis(axis) + axis_check_params[axis] = [] + for param in axis_parameters[axis]: + # get parameters for compatibility checking. Exclude parameters + # that define overlap and multidimensional parameters which are + # handled separately later. if ( - param not in this._data_params - and param not in axis_overlap_params[axis] + param not in multi_axis_params + and param not in axis_key_params[axis] ): - axis_params_check[axis].append("_" + param) + axis_check_params[axis].append("_" + param) - # build this/other arrays for checking for overlap. - # Use a combined string if there are multiple arrays defining overlap - # (e.g. baseline-time, spw-freq) - axis_vals = {} - for axis, overlap_params in axis_overlap_params.items(): + # build this/other arrays for checking for overlap. 
if len(overlap_params) > 1: - axis_vals[axis] = { - "this": getattr(this, axis_combined_func[axis])(), - "other": getattr(other, axis_combined_func[axis])(), + axis_key_arrays[axis] = { + "this": getattr(this, axis_key_func[axis])(), + "other": getattr(other, axis_key_func[axis])(), } else: - axis_vals[axis] = { + axis_key_arrays[axis] = { "this": getattr(this, overlap_params[0]), "other": getattr(other, overlap_params[0]), } - # Check if we have overlapping data - axis_inds = {} - for axis, val_arr in axis_vals.items(): + # Check if we have overlapping data both_inds, this_inds, other_inds = np.intersect1d( - val_arr["this"], val_arr["other"], return_indices=True + axis_key_arrays[axis]["this"], + axis_key_arrays[axis]["other"], + return_indices=True, ) - axis_inds[axis] = { + axis_overlap_inds[axis] = { "this": this_inds, "other": other_inds, "both": both_inds, @@ -5778,9 +5786,10 @@ def __add__( history_update_string = "" - if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]): + if np.all( + [len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds] + ): # We have overlaps, check that overlapping data is not valid - multi_axis_params = this._get_multi_axis_params() this_test = [] other_test = [] for param in multi_axis_params: @@ -5795,10 +5804,14 @@ def __add__( for ax_ind, axis in enumerate(form): expand_axes = [ax for ax in range(len(form)) if ax != ax_ind] this_index_list.append( - np.expand_dims(axis_inds[axis]["this"], axis=expand_axes) + np.expand_dims( + axis_overlap_inds[axis]["this"], axis=expand_axes + ) ) other_index_list.append( - np.expand_dims(axis_inds[axis]["other"], axis=expand_axes) + np.expand_dims( + axis_overlap_inds[axis]["other"], axis=expand_axes + ) ) this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten() @@ -5830,7 +5843,9 @@ def __add__( "These objects have overlapping data and cannot be combined." 
) - new_inds = {} + # Now actually find which axes are going to be added along + # other_inds_use have the indices in other that will be added to this + other_inds_use = {} additions = [] axis_descriptions = { "Nblts": "baseline-time", @@ -5840,30 +5855,30 @@ def __add__( # find the indices in "other" but not in "this" for axis in axes: temp = np.nonzero( - ~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"]) + ~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"]) )[0] if len(temp) > 0: - new_inds[axis] = temp + other_inds_use[axis] = temp # add params associated with the other axes to compatibility_params for axis2 in axes: if axis2 != axis: - compatibility_params.extend(axis_params_check[axis2]) + compatibility_params.extend(axis_check_params[axis2]) additions.append(axis_descriptions[axis]) else: - new_inds[axis] = [] + other_inds_use[axis] = [] # Actually check compatibility parameters for cp in compatibility_params: params_match = None - for axis, check_list in axis_params_check.items(): + for axis, check_list in axis_check_params.items(): if cp in check_list: # only check that overlapping indices match this_param = getattr(this, cp) this_param_overlap = this_param.get_from_form( - {axis: axis_inds[axis]["this"]} + {axis: axis_overlap_inds[axis]["this"]} ) other_param_overlap = getattr(other, cp).get_from_form( - {axis: axis_inds[axis]["other"]} + {axis: axis_overlap_inds[axis]["other"]} ) params_match = np.allclose( this_param_overlap, @@ -5895,7 +5910,7 @@ def __add__( "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"}, "Npols": {"method": "reorder_pols", "parameter": "order"}, } - for axis, ind_dict in axis_inds.items(): + for axis, ind_dict in axis_overlap_inds.items(): if len(ind_dict["this"]) != 0: # there is some overlap, so check sorting this_argsort = np.argsort(ind_dict["this"]) @@ -5910,16 +5925,22 @@ def __add__( getattr(this, reorder_method[axis]["method"])(**kwargs) - # start updating parameters - 
new_axis_inds = {} + # checks are all done, start updating parameters + # combined_key_arrays has the final key arrays after adding. + combined_key_arrays = {} + # order_dict has info about how to sort each axis. Initialize to None + # for axes that are not added along (so do not need sorting) order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} for axis in axes: - if len(new_inds[axis]) > 0: - new_axis_inds[axis] = np.concatenate( - (axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]]) + if len(other_inds_use[axis]) > 0: + combined_key_arrays[axis] = np.concatenate( + ( + axis_key_arrays[axis]["this"], + axis_key_arrays[axis]["other"][other_inds_use[axis]], + ) ) if axis == "Npols": - order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis])) + order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis])) elif axis == "Nfreqs" and ( np.any(np.diff(this.freq_array) < 0) or np.any(np.diff(other.freq_array) < 0) @@ -5930,38 +5951,39 @@ def __add__( np.concatenate( ( this.flex_spw_id_array, - other.flex_spw_id_array[new_inds[axis]], + other.flex_spw_id_array[other_inds_use[axis]], ) ), np.concatenate( - (this.freq_array, other.freq_array[new_inds[axis]]) + (this.freq_array, other.freq_array[other_inds_use[axis]]) ), ) else: - order_dict[axis] = np.argsort(new_axis_inds[axis]) + order_dict[axis] = np.argsort(combined_key_arrays[axis]) - # first handle parameters with a single axis - _axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis]) + # first handle parameters with a single named axis + _axis_add_helper( + this, other, axis, other_inds_use[axis], order_dict[axis] + ) # then pad out parameters with multiple axes - _axis_pad_helper(this, axis, len(new_inds[axis])) + _axis_pad_helper(this, axis, len(other_inds_use[axis])) else: - new_axis_inds[axis] = axis_vals[axis]["this"] + # no add along this axis, so it's the same as what's already on this + combined_key_arrays[axis] = axis_key_arrays[axis]["this"] # Now fill in 
multidimensional parameters + # t2o_dict has the mapping of where arrays on other get mapped into + # this after padding t2o_dict = {} - for axis, inds_dict in axis_vals.items(): + for axis, inds_dict in axis_key_arrays.items(): t2o_dict[axis] = np.nonzero( - np.isin(new_axis_inds[axis], inds_dict["other"]) + np.isin(combined_key_arrays[axis], inds_dict["other"]) )[0] - sort_axes = [] - for axis in axes: - if len(new_inds[axis]) > 0: - sort_axes.append(axis) - _fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict) + _fill_multi_helper(this, other, t2o_dict, order_dict) - if len(new_inds["Nfreqs"]) > 0: + if len(other_inds_use["Nfreqs"]) > 0: # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( @@ -6015,7 +6037,7 @@ def __add__( ) # Reset blt_order if blt axis was added to - if "Nblts" in sort_axes: + if order_dict["Nblts"] is not None: this.blt_order = ("time", "baseline") this.set_rectangularity(force=True) @@ -6231,18 +6253,18 @@ def fast_concat( # identify params that are not explicitly included in overlap calc per axis axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"} - axis_params_check = {} - axis_dict = {} + axis_check_params = {} + axis_parameters = {} for axis2, ax_shape in axis_shape.items(): - axis_dict[axis2] = this._get_param_axis(ax_shape) - axis_params_check[axis2] = [] - for param in axis_dict[axis2]: + axis_parameters[axis2] = this._get_param_axis(ax_shape) + axis_check_params[axis2] = [] + for param in axis_parameters[axis2]: if param not in this._data_params: - axis_params_check[axis2].append("_" + param) + axis_check_params[axis2].append("_" + param) for axis2 in axis_shape: if axis2 != axis: - compatibility_params.extend(axis_params_check[axis2]) + compatibility_params.extend(axis_check_params[axis2]) axis_descriptions = { "blt": "baseline-time", From 5a58143b1605dff850fac3635a86fcd90f643be7 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: 
Mon, 25 Aug 2025 15:53:58 -0700 Subject: [PATCH 13/16] Move functions out of uvdata for better reusability --- src/pyuvdata/utils/frequency.py | 39 ++++++ src/pyuvdata/utils/tools.py | 42 ++++++ src/pyuvdata/uvbase.py | 133 ++++++++++++++++++ src/pyuvdata/uvdata/uvdata.py | 240 ++------------------------------ 4 files changed, 223 insertions(+), 231 deletions(-) diff --git a/src/pyuvdata/utils/frequency.py b/src/pyuvdata/utils/frequency.py index c42d741f2..dabedf4d8 100644 --- a/src/pyuvdata/utils/frequency.py +++ b/src/pyuvdata/utils/frequency.py @@ -8,6 +8,7 @@ from . import tools from .pol import jstr2num, polstr2num +from .types import FloatArray, IntArray def _check_flex_spw_contiguous(*, spw_array, flex_spw_id_array, strict=True): @@ -549,3 +550,41 @@ def _select_freq_helper( freq_inds = freq_inds.tolist() return freq_inds, spw_inds, selections + + +def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: + """ + Get the sorting order for the frequency axis after an add. + + Sort first by spw then by channel, but don't reorder channels if they are + changing monotonically (all ascending or descending) within the spw. + + Parameters + ---------- + spw_id : np.ndarray of int + SPW id array of combined data to be sorted. + freq_arr : np.ndarray of float + Frequency array of combined data to be sorted. + + Returns + ------- + f_order : np.ndarray of int + index array giving the sort order. + + """ + spws = np.unique(spw_id) + f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) + + # With spectral windows sorted, check and see if channels within + # windows need sorting. If they are ordered in ascending or descending + # fashion, leave them be. 
If not, sort in ascending order + for spw in spws: + select_mask = spw_id[f_order] == spw + check_freqs = freq_arr[f_order[select_mask]] + if not np.all(np.diff(check_freqs) > 0) and not np.all( + np.diff(check_freqs) < 0 + ): + subsort_order = f_order[select_mask] + f_order[select_mask] = subsort_order[np.argsort(check_freqs)] + + return f_order diff --git a/src/pyuvdata/utils/tools.py b/src/pyuvdata/utils/tools.py index 039c90af6..9a7247703 100644 --- a/src/pyuvdata/utils/tools.py +++ b/src/pyuvdata/utils/tools.py @@ -9,6 +9,8 @@ import numpy as np +from .types import FloatArray, IntArray, StrArray + def _get_iterable(x): """Return iterable version of input.""" @@ -687,3 +689,43 @@ def _ntimes_to_nblts(uvd): inds.append(np.where(unique_t == i)[0][0]) return np.asarray(inds) + + +def flt_ind_str_arr( + *, + fltarr: FloatArray, + intarr: IntArray, + flt_tols: tuple[float, float], + flt_first: bool = True, +) -> StrArray: + """ + Create a string array built from float and integer arrays for matching. + + Parameters + ---------- + fltarr : np.ndarray of float + float array to be used in output string array + intarr : np.ndarray of int + integer array to be used in output string array + flt_tols : 2-tuple of float + Tolerances (relative, absolute) to use in formatting the floats as strings. + flt_first : bool + Whether to put the float first in the out put string or not (if False + the int comes first.) + + Returns + ------- + np.ndarray of str + String array that combines the float and integer values, useful for matching. 
+ + """ + prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) + prec_int = 8 + flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] + int_str_list = [str(intv).zfill(prec_int) for intv in intarr] + list_of_lists = [] + if flt_first: + list_of_lists = [flt_str_list, int_str_list] + else: + list_of_lists = [int_str_list, flt_str_list] + return np.array(["_".join(zpval) for zpval in zip(*list_of_lists, strict=True)]) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index c3c82cade..c073e3523 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -16,6 +16,7 @@ from . import __version__, parameter as uvp from .utils.tools import _get_iterable, slicify +from .utils.types import IntArray __all__ = ["UVBase"] @@ -865,3 +866,135 @@ def _select_along_param_axis(self, param_dict: dict): # here in the case of a repeated param_name in the form. attr.value = attr.get_from_form(slice_dict) attr.setter(self) + + def _axis_add_helper( + self, + other, + axis_name: str, + other_inds: IntArray, + final_order: IntArray | None = None, + ): + """ + Combine UVParameter objects with a single axis along an axis. + + Parameters + ---------- + other : UVBase + The UVBase object to be added. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + other_inds : np.ndarray of int + Indices into the other object along this axis to include. + final_order : np.ndarray of int + Final ordering array giving the sort order after concatenation. 
+ + """ + update_params = self._get_param_axis(axis_name, single_named_axis=True) + other_form_dict = {axis_name: other_inds} + for param, axis_list in update_params.items(): + axis = axis_list[0] + new_array = np.concatenate( + [ + getattr(self, param), + getattr(other, "_" + param).get_from_form(other_form_dict), + ], + axis=axis, + ) + if final_order is not None: + new_array = np.take(new_array, final_order, axis=axis) + + setattr(self, param, new_array) + + def _axis_pad_helper(self, axis_name: str, add_len: int): + """ + Pad out UVParameter objects with multiple dimensions along an axis. + + Parameters + ---------- + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + add_len : int + The extra length to be padded on for this axis. + + """ + update_params = self._get_param_axis(axis_name) + multi_axis_params = self._get_multi_axis_params() + for param, axis_list in update_params.items(): + if param not in multi_axis_params: + continue + this_param_shape = getattr(self, param).shape + this_param_type = getattr(self, "_" + param).expected_type + bool_type = this_param_type is bool or bool in this_param_type + pad_shape = list(this_param_shape) + for ax in axis_list: + pad_shape[ax] = add_len + if bool_type: + pad_array = np.ones(tuple(pad_shape), dtype=bool) + else: + pad_array = np.zeros(tuple(pad_shape)) + new_array = np.concatenate([getattr(self, param), pad_array], axis=ax) + if bool_type: + new_array = new_array.astype(np.bool_) + setattr(self, param, new_array) + + def _fill_multi_helper(self, other, t2o_dict: dict, order_dict: dict): + """ + Fill UVParameter objects with multiple dimensions from the right side object. + + Parameters + ---------- + other : UVBase + The UVBase object to be added. + t2o_dict : dict + dict giving the indices in the left object to be filled from the right + object for each axis (keys are axes, values are index arrays). 
+ order_dict : dict + dict giving the final sort indices for each axis (keys are axes, values + are index arrays for sorting). + + """ + multi_axis_params = self._get_multi_axis_params() + for param in multi_axis_params: + form = getattr(self, "_" + param).form + index_list = [] + for axis in form: + index_list.append(t2o_dict[axis]) + new_arr = getattr(self, param) + new_arr[np.ix_(*index_list)] = getattr(other, param) + setattr(self, param, new_arr) + + # Fix ordering + for axis_ind, axis in enumerate(form): + if order_dict[axis] is not None: + unique_order_diffs = np.unique(np.diff(order_dict[axis])) + if np.array_equal(unique_order_diffs, np.array([1])): + # everything is already in order + continue + setattr( + self, + param, + np.take(getattr(self, param), order_dict[axis], axis=axis_ind), + ) + + def _axis_fast_concat_helper(self, other, axis_name: str): + """ + Concatenate UVParameter objects along an axis assuming no overlap. + + Parameters + ---------- + other : iterable of UVBase + The UVBase objects to be concatenated onto this one, in order. + axis_name : str + The axis name (e.g. "Nblts", "Npols").
+ """ + update_params = self._get_param_axis(axis_name) + for param, axis_list in update_params.items(): + axis = axis_list[0] + setattr( + self, + param, + np.concatenate( + [getattr(self, param)] + [getattr(obj, param) for obj in other], + axis=axis, + ), + ) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 1505dbde3..02c285f1d 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -27,7 +27,7 @@ from ..utils import phasing as phs_utils from ..utils.io import hdf5 as hdf5_utils from ..utils.phasing import _get_focus_xyz, _get_nearfield_delay -from ..utils.types import FloatArray, IntArray, StrArray +from ..utils.types import StrArray from ..uvbase import UVBase from .initializers import new_uvdata @@ -41,228 +41,6 @@ ) -def flt_ind_str_arr( - *, - fltarr: FloatArray, - intarr: IntArray, - flt_tols: tuple[float, float], - flt_first: bool = True, -) -> StrArray: - """ - Create a string array built from float and integer arrays for matching. - - Parameters - ---------- - fltarr : np.ndarray of float - float array to be used in output string array - intarr : np.ndarray of int - integer array to be used in output string array - flt_tols : 2-tuple of float - Tolerances (relative, absolute) to use in formatting the floats as strings. - flt_first : bool - Whether to put the float first in the out put string or not (if False - the int comes first.) - - Returns - ------- - np.ndarray of str - String array that combines the float and integer values, useful for matching. 
- - """ - prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) - prec_int = 8 - flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] - int_str_list = [str(intv).zfill(prec_int) for intv in intarr] - list_of_lists = [] - if flt_first: - list_of_lists = [flt_str_list, int_str_list] - else: - list_of_lists = [int_str_list, flt_str_list] - return np.array(["_".join(zpval) for zpval in zip(*list_of_lists, strict=True)]) - - -def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: - """ - Get the sorting order for the frequency axis after an add. - - Sort first by spw then by channel, but don't reorder channels if they are - changing monotonically (all ascending or descending) within the spw. - - Parameters - ---------- - spw_id : np.ndarray of int - SPW id array of combined data to be sorted. - freq_arr : np.ndarray of float - Frequency array of combined data to be sorted. - - Returns - ------- - f_order : np.ndarray of int - index array giving the sort order. - - """ - spws = np.unique(spw_id) - f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) - - # With spectral windows sorted, check and see if channels within - # windows need sorting. If they are ordered in ascending or descending - # fashion, leave them be. If not, sort in ascending order - for spw in spws: - select_mask = spw_id[f_order] == spw - check_freqs = freq_arr[f_order[select_mask]] - if not np.all(np.diff(check_freqs) > 0) and not np.all( - np.diff(check_freqs) < 0 - ): - subsort_order = f_order[select_mask] - f_order[select_mask] = subsort_order[np.argsort(check_freqs)] - - return f_order - - -def _axis_add_helper( - this: UVData, - other: UVData, - axis_name: str, - other_inds: IntArray, - final_order: IntArray | None = None, -): - """ - Combine UVParameter objects with a single axis along an axis. - - Parameters - ---------- - this : UVData - The left UVData object in the add. 
- other : UVData - The right UVData object in the add. - axis_name : str - The axis name (e.g. "Nblts", "Npols"). - other_inds : np.ndarray of int - Indices into the other object along this axis to include. - final_order : np.ndarray of int - Final ordering array giving the sort order after concatenation. - - """ - update_params = this._get_param_axis(axis_name, single_named_axis=True) - other_form_dict = {axis_name: other_inds} - for param, axis_list in update_params.items(): - axis = axis_list[0] - new_array = np.concatenate( - [ - getattr(this, param), - getattr(other, "_" + param).get_from_form(other_form_dict), - ], - axis=axis, - ) - if final_order is not None: - new_array = np.take(new_array, final_order, axis=axis) - - setattr(this, param, new_array) - - -def _axis_pad_helper(this: UVData, axis_name: str, add_len: int): - """ - Pad out UVParameter objects with multiple dimensions along an axis. - - Parameters - ---------- - this : UVData - The left UVData object in the add. - axis_name : str - The axis name (e.g. "Nblts", "Npols"). - add_len : int - The extra length to be padded on for this axis. 
- - """ - update_params = this._get_param_axis(axis_name) - multi_axis_params = this._get_multi_axis_params() - for param, axis_list in update_params.items(): - if param not in multi_axis_params: - continue - this_param_shape = getattr(this, param).shape - this_param_type = getattr(this, "_" + param).expected_type - bool_type = this_param_type is bool or bool in this_param_type - pad_shape = list(this_param_shape) - for ax in axis_list: - pad_shape[ax] = add_len - if bool_type: - pad_array = np.ones(tuple(pad_shape), dtype=bool) - else: - pad_array = np.zeros(tuple(pad_shape)) - new_array = np.concatenate([getattr(this, param), pad_array], axis=ax) - if bool_type: - new_array = new_array.astype(np.bool_) - setattr(this, param, new_array) - - -def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict): - """ - Fill UVParameter objects with multiple dimensions from the right side object. - - Parameters - ---------- - this : UVData - The left UVData object in the add. - other : UVData - The right UVData object in the add. - t2o_dict : dict - dict giving the indices in the left object to be filled from the right - object for each axis (keys are axes, values are index arrays). - order_dict : dict - dict giving the final sort indices for each axis (keys are axes, values - are index arrays for sorting). 
- - """ - multi_axis_params = this._get_multi_axis_params() - for param in multi_axis_params: - form = getattr(this, "_" + param).form - index_list = [] - for axis in form: - index_list.append(t2o_dict[axis]) - new_arr = getattr(this, param) - new_arr[np.ix_(*index_list)] = getattr(other, param) - setattr(this, param, new_arr) - - # Fix ordering - for axis_ind, axis in enumerate(form): - if order_dict[axis] is not None: - unique_order_diffs = np.unique(np.diff(order_dict[axis])) - if np.array_equal(unique_order_diffs, np.array([1])): - # everything is already in order - continue - setattr( - this, - param, - np.take(getattr(this, param), order_dict[axis], axis=axis_ind), - ) - - -def _axis_fast_concat_helper(this: UVData, other: UVData, axis_name: str): - """ - Concatenate UVParameter objects along an axis assuming no overlap. - - Parameters - ---------- - this : UVData - The left UVData object in the add. - other : UVData - The right UVData object in the add. - axis_name : str - The axis name (e.g. "Nblts", "Npols"). - """ - update_params = this._get_param_axis(axis_name) - for param, axis_list in update_params.items(): - axis = axis_list[0] - setattr( - this, - param, - np.concatenate( - [getattr(this, param)] + [getattr(obj, param) for obj in other], - axis=axis, - ), - ) - - class UVData(UVBase): """ A class for defining a radio interferometer dataset. 
@@ -5588,7 +5366,7 @@ def fix_phase(self, *, use_ant_pos=True): def blt_str_arr(self) -> StrArray: """Create a string array with baseline and time info for matching purposes.""" - return flt_ind_str_arr( + return utils.tools.flt_ind_str_arr( fltarr=self.time_array, intarr=self.baseline_array, flt_tols=self._time_array.tols, @@ -5597,7 +5375,7 @@ def blt_str_arr(self) -> StrArray: def spw_freq_str_arr(self) -> StrArray: """Create a string array with spw and freq info for matching purposes.""" - return flt_ind_str_arr( + return utils.tools.flt_ind_str_arr( fltarr=self.freq_array, intarr=self.flex_spw_id_array, flt_tols=self._freq_array.tols, @@ -5947,7 +5725,7 @@ def __add__( ): # deal with the possibility of spws with channels in # descending order. - order_dict[axis] = _add_freq_order( + order_dict[axis] = utils.frequency._add_freq_order( np.concatenate( ( this.flex_spw_id_array, @@ -5962,12 +5740,12 @@ def __add__( order_dict[axis] = np.argsort(combined_key_arrays[axis]) # first handle parameters with a single named axis - _axis_add_helper( - this, other, axis, other_inds_use[axis], order_dict[axis] + this._axis_add_helper( + other, axis, other_inds_use[axis], order_dict[axis] ) # then pad out parameters with multiple axes - _axis_pad_helper(this, axis, len(other_inds_use[axis])) + this._axis_pad_helper(axis, len(other_inds_use[axis])) else: # no add along this axis, so it's the same as what's already on this combined_key_arrays[axis] = axis_key_arrays[axis]["this"] @@ -5981,7 +5759,7 @@ def __add__( np.isin(combined_key_arrays[axis], inds_dict["other"]) )[0] - _fill_multi_helper(this, other, t2o_dict, order_dict) + this._fill_multi_helper(other, t2o_dict, order_dict) if len(other_inds_use["Nfreqs"]) > 0: # We want to preserve per-spw information based on first appearance @@ -6312,7 +6090,7 @@ def fast_concat( this.telescope = tel_obj # update the relevant shape parameter - _axis_fast_concat_helper(this, other, axis_shape[axis]) + 
this._axis_fast_concat_helper(other, axis_shape[axis]) new_shape = sum( [getattr(this, axis_shape[axis])] + [getattr(obj, axis_shape[axis]) for obj in other] From b492cf708bc6e9361b6a792251276a8ae9065323 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 23 Sep 2025 16:27:40 -0700 Subject: [PATCH 14/16] cleanup add to use a single dict to carry all the per axis info --- src/pyuvdata/uvdata/uvdata.py | 222 ++++++++++++++++++---------------- 1 file changed, 118 insertions(+), 104 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 02c285f1d..5521b173e 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5502,61 +5502,85 @@ def __add__( # Define parameters that must be the same to add objects compatibility_params = ["_vis_units"] - axes = ["Nblts", "Nfreqs", "Npols"] - # axis_key_params defines which parameters to use as the defining - # parameters along each axis. These are used to identify overlapping data. - axis_key_params = { - "Nblts": ["time_array", "baseline_array"], - "Nfreqs": ["freq_array", "flex_spw_id_array"], - "Npols": ["polarization_array"], + # setup a dict to carry all the axis-specific info we need throughout + # the add process: + # - description is used in history string + # - key_params defines which parameters to use as the defining + # parameters along each axis. These are used to identify overlapping data. + # - key_func specifies a function to form a combined string if there are + # multiple key arrays (e.g. baseline-time, spw-freq) + # - reorder gives method & parameters for reordering along each axis + # - order has info about how to sort each axis. Initialize to None + # for axes that are not added along (so do not need sorting), + # updated later.
+ # ---added later--- + # - check_params gives parameters that should be checked if adding + # along other axes + # - key_arrays gives the arrays to use for checking for overlap per axis + # - overlap_inds has the outcomes of np.intersect1d on the key_arrays + # between this and other. So it has the both/this/other inds + # for any overlaps. + # - other_inds_use has the indices in other that will be added to this + # - combined_key_arrays has the final key arrays after adding. + # - t2o has the mapping of where arrays on other get mapped into this + # along each axis after padding + + axis_info = { + "Nblts": { + "description": "baseline-time", + "key_params": ["time_array", "baseline_array"], + "key_func": "blt_str_arr", + "reorder": {"method": "reorder_blts", "parameter": "order"}, + "order": None, + }, + "Nfreqs": { + "description": "frequency", + "key_params": ["freq_array", "flex_spw_id_array"], + "key_func": "spw_freq_str_arr", + "reorder": {"method": "reorder_freqs", "parameter": "channel_order"}, + "order": None, + }, + "Npols": { + "description": "polarization", + "key_params": ["polarization_array"], + "reorder": {"method": "reorder_pols", "parameter": "order"}, + "order": None, + }, } - # specify a function to form a combined string if there are multiple - # key arrays (e.g. baseline-time, spw-freq) - axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"} - multi_axis_params = this._get_multi_axis_params() - # axis_parameters gives parameters whose form contains each axis - axis_parameters = {} - # axis_check_params gives parameters that should be checked if adding - # along other axes - axis_check_params = {} - # axis_key_arrays gives the arrays to use for checking for overlap per axis - axis_key_arrays = {} - # axis_overlap_inds has the outcomes of np.intersect1d on the - # axis_key_arrays per axis. So it has the both/this inds/other inds - # for any overlaps. 
- axis_overlap_inds = {} - for axis, overlap_params in axis_key_params.items(): - axis_parameters[axis] = this._get_param_axis(axis) - axis_check_params[axis] = [] - for param in axis_parameters[axis]: - # get parameters for compatibility checking. Exclude parameters - # that define overlap and multidimensional parameters which are - # handled separately later. - if ( - param not in multi_axis_params - and param not in axis_key_params[axis] - ): - axis_check_params[axis].append("_" + param) + + for axis, info in axis_info.items(): + # get parameters for compatibility checking. Exclude multidimensional + # parameters which are handled separately later. + params_this_axis = this._get_param_axis(axis, single_named_axis=True) + info["check_params"] = [] + for param in params_this_axis: + # Also exclude parameters that define overlap + if param not in info["key_params"]: + info["check_params"].append("_" + param) # build this/other arrays for checking for overlap. - if len(overlap_params) > 1: - axis_key_arrays[axis] = { - "this": getattr(this, axis_key_func[axis])(), - "other": getattr(other, axis_key_func[axis])(), + # key_arrays gives the arrays to use for checking for overlap per axis + if len(info["key_params"]) > 1: + info["key_arrays"] = { + "this": getattr(this, info["key_func"])(), + "other": getattr(other, info["key_func"])(), } else: - axis_key_arrays[axis] = { - "this": getattr(this, overlap_params[0]), - "other": getattr(other, overlap_params[0]), + info["key_arrays"] = { + "this": getattr(this, info["key_params"][0]), + "other": getattr(other, info["key_params"][0]), } # Check if we have overlapping data both_inds, this_inds, other_inds = np.intersect1d( - axis_key_arrays[axis]["this"], - axis_key_arrays[axis]["other"], + info["key_arrays"]["this"], + info["key_arrays"]["other"], return_indices=True, ) - axis_overlap_inds[axis] = { + # overlap_inds has the outcomes of np.intersect1d on the + # key_arrays per axis. 
So it has the both/this inds/other inds + # for any overlaps. + info["overlap_inds"] = { "this": this_inds, "other": other_inds, "both": both_inds, @@ -5565,11 +5589,12 @@ def __add__( history_update_string = "" if np.all( - [len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds] + [len(axis_info[axis]["overlap_inds"]["both"]) > 0 for axis in axis_info] ): # We have overlaps, check that overlapping data is not valid this_test = [] other_test = [] + multi_axis_params = this._get_multi_axis_params() for param in multi_axis_params: form = getattr(this, "_" + param).form this_shape = getattr(this, param).shape @@ -5583,12 +5608,12 @@ def __add__( expand_axes = [ax for ax in range(len(form)) if ax != ax_ind] this_index_list.append( np.expand_dims( - axis_overlap_inds[axis]["this"], axis=expand_axes + axis_info[axis]["overlap_inds"]["this"], axis=expand_axes ) ) other_index_list.append( np.expand_dims( - axis_overlap_inds[axis]["other"], axis=expand_axes + axis_info[axis]["overlap_inds"]["other"], axis=expand_axes ) ) this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten() @@ -5622,41 +5647,35 @@ def __add__( ) # Now actually find which axes are going to be added along - # other_inds_use have the indices in other that will be added to this - other_inds_use = {} additions = [] - axis_descriptions = { - "Nblts": "baseline-time", - "Nfreqs": "frequency", - "Npols": "polarization", - } # find the indices in "other" but not in "this" - for axis in axes: + for axis, info in axis_info.items(): temp = np.nonzero( - ~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"]) + ~np.isin(info["key_arrays"]["other"], info["key_arrays"]["this"]) )[0] if len(temp) > 0: - other_inds_use[axis] = temp + # other_inds_use has the indices in other that will be added to this + info["other_inds_use"] = temp # add params associated with the other axes to compatibility_params - for axis2 in axes: + for axis2 in axis_info: if axis2 != axis: - 
compatibility_params.extend(axis_check_params[axis2]) - additions.append(axis_descriptions[axis]) + compatibility_params.extend(axis_info[axis2]["check_params"]) + additions.append(info["description"]) else: - other_inds_use[axis] = [] + info["other_inds_use"] = [] # Actually check compatibility parameters for cp in compatibility_params: params_match = None - for axis, check_list in axis_check_params.items(): - if cp in check_list: + for axis, info in axis_info.items(): + if cp in info["check_params"]: # only check that overlapping indices match this_param = getattr(this, cp) this_param_overlap = this_param.get_from_form( - {axis: axis_overlap_inds[axis]["this"]} + {axis: info["overlap_inds"]["this"]} ) other_param_overlap = getattr(other, cp).get_from_form( - {axis: axis_overlap_inds[axis]["other"]} + {axis: info["overlap_inds"]["other"]} ) params_match = np.allclose( this_param_overlap, @@ -5683,85 +5702,80 @@ def __add__( # Next, we want to make sure that the ordering of the _overlapping_ data is # the same, so that things can get plugged together in a sensible way. 
- reorder_method = { - "Nblts": {"method": "reorder_blts", "parameter": "order"}, - "Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"}, - "Npols": {"method": "reorder_pols", "parameter": "order"}, - } - for axis, ind_dict in axis_overlap_inds.items(): - if len(ind_dict["this"]) != 0: + for axis, info in axis_info.items(): + if len(info["overlap_inds"]["this"]) != 0: # there is some overlap, so check sorting - this_argsort = np.argsort(ind_dict["this"]) - other_argsort = np.argsort(ind_dict["other"]) + this_argsort = np.argsort(info["overlap_inds"]["this"]) + other_argsort = np.argsort(info["overlap_inds"]["other"]) if np.any(this_argsort != other_argsort): temp_ind = np.arange(getattr(this, axis)) - temp_ind[ind_dict["this"][this_argsort]] = temp_ind[ - ind_dict["this"][other_argsort] + temp_ind[info["overlap_inds"]["this"][this_argsort]] = temp_ind[ + info["overlap_inds"]["this"][other_argsort] ] - kwargs = {reorder_method[axis]["parameter"]: temp_ind} + kwargs = {info["reorder"]["parameter"]: temp_ind} - getattr(this, reorder_method[axis]["method"])(**kwargs) + getattr(this, info["reorder"]["method"])(**kwargs) # checks are all done, start updating parameters - # combined_key_arrays has the final key arrays after adding. - combined_key_arrays = {} - # order_dict has info about how to sort each axis. Initialize to None - # for axes that are not added along (so do not need sorting) - order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None} - for axis in axes: - if len(other_inds_use[axis]) > 0: - combined_key_arrays[axis] = np.concatenate( + for axis, info in axis_info.items(): + if len(info["other_inds_use"]) > 0: + # combined_key_arrays has the final key arrays after adding. 
+ info["combined_key_arrays"] = np.concatenate( ( - axis_key_arrays[axis]["this"], - axis_key_arrays[axis]["other"][other_inds_use[axis]], + info["key_arrays"]["this"], + info["key_arrays"]["other"][info["other_inds_use"]], ) ) if axis == "Npols": - order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis])) + # order has info about how to sort each axis. + info["order"] = np.argsort(np.abs(info["combined_key_arrays"])) elif axis == "Nfreqs" and ( np.any(np.diff(this.freq_array) < 0) or np.any(np.diff(other.freq_array) < 0) ): # deal with the possibility of spws with channels in # descending order. - order_dict[axis] = utils.frequency._add_freq_order( + info["order"] = utils.frequency._add_freq_order( np.concatenate( ( this.flex_spw_id_array, - other.flex_spw_id_array[other_inds_use[axis]], + other.flex_spw_id_array[info["other_inds_use"]], ) ), np.concatenate( - (this.freq_array, other.freq_array[other_inds_use[axis]]) + (this.freq_array, other.freq_array[info["other_inds_use"]]) ), ) else: - order_dict[axis] = np.argsort(combined_key_arrays[axis]) + info["order"] = np.argsort(info["combined_key_arrays"]) # first handle parameters with a single named axis this._axis_add_helper( - other, axis, other_inds_use[axis], order_dict[axis] + other, axis, info["other_inds_use"], info["order"] ) # then pad out parameters with multiple axes - this._axis_pad_helper(axis, len(other_inds_use[axis])) + this._axis_pad_helper(axis, len(info["other_inds_use"])) else: # no add along this axis, so it's the same as what's already on this - combined_key_arrays[axis] = axis_key_arrays[axis]["this"] + info["combined_key_arrays"] = info["key_arrays"]["this"] # Now fill in multidimensional parameters - # t2o_dict has the mapping of where arrays on other get mapped into + # t2o has the mapping of where arrays on other get mapped into # this after padding - t2o_dict = {} - for axis, inds_dict in axis_key_arrays.items(): - t2o_dict[axis] = np.nonzero( - np.isin(combined_key_arrays[axis], 
inds_dict["other"]) + for _, info in axis_info.items(): + info["t2o"] = np.nonzero( + np.isin(info["combined_key_arrays"], info["key_arrays"]["other"]) )[0] - this._fill_multi_helper(other, t2o_dict, order_dict) + this._fill_multi_helper( + other, + {axis: info["t2o"] for axis, info in axis_info.items()}, + {axis: info["order"] for axis, info in axis_info.items()}, + ) - if len(other_inds_use["Nfreqs"]) > 0: + if len(axis_info["Nfreqs"]["other_inds_use"]) > 0: # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( @@ -5815,7 +5829,7 @@ def __add__( ) # Reset blt_order if blt axis was added to - if order_dict["Nblts"] is not None: + if axis_info["Nblts"]["order"] is not None: this.blt_order = ("time", "baseline") this.set_rectangularity(force=True) From f3517ca23b772f0aa64c31e18c6ac626cde4fa74 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 23 Sep 2025 17:07:38 -0700 Subject: [PATCH 15/16] use a single dict in fast_concat as well --- src/pyuvdata/uvdata/uvdata.py | 62 ++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 5521b173e..ec7ae0388 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5991,9 +5991,20 @@ def fast_concat( self and other are not compatible. """ - allowed_axes = ["blt", "freq", "polarization"] - if axis not in allowed_axes: - raise ValueError("Axis must be one of: " + ", ".join(allowed_axes)) + # setup a dict to carry all the axis-specific info we need throughout + # the fast concat process: + # - description is used in history string + # - shape: the shape name parameter (e.g. 
"Nblts", "Nfreqs", "Npols") + # ---added later---- + # - check_params gives parameters that should be checked if adding + # along other axes + axis_info = { + "blt": {"description": "baseline-time", "shape": "Nblts"}, + "freq": {"description": "frequency", "shape": "Nfreqs"}, + "polarization": {"description": "polarization", "shape": "Npols"}, + } + if axis not in axis_info: + raise ValueError("Axis must be one of: " + ", ".join(axis_info)) if inplace: this = self @@ -6043,28 +6054,23 @@ def fast_concat( history_update_string = " Combined data along " - # identify params that are not explicitly included in overlap calc per axis - axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"} - axis_check_params = {} - axis_parameters = {} - for axis2, ax_shape in axis_shape.items(): - axis_parameters[axis2] = this._get_param_axis(ax_shape) - axis_check_params[axis2] = [] - for param in axis_parameters[axis2]: - if param not in this._data_params: - axis_check_params[axis2].append("_" + param) - - for axis2 in axis_shape: - if axis2 != axis: - compatibility_params.extend(axis_check_params[axis2]) + # figure out what parameters to check for compatibility -- only worry + # about single axis params + for _, info in axis_info.items(): + params_this_axis = this._get_param_axis( + info["shape"], single_named_axis=True + ) + info["check_params"] = [] + for param in params_this_axis: + info["check_params"].append("_" + param) - axis_descriptions = { - "blt": "baseline-time", - "freq": "frequency", - "polarization": "polarization", - } + for axis2, info in axis_info.items(): + if axis2 != axis: + compatibility_params.extend(info["check_params"]) - history_update_string += f" {axis_descriptions[axis]} axis using pyuvdata." + history_update_string += ( + f" {axis_info[axis]['description']} axis using pyuvdata." 
+ ) histories_match = [] for obj in other: @@ -6103,13 +6109,15 @@ def fast_concat( this.telescope = tel_obj + # actually do the concat + this._axis_fast_concat_helper(other, axis_info[axis]["shape"]) + # update the relevant shape parameter - this._axis_fast_concat_helper(other, axis_shape[axis]) new_shape = sum( - [getattr(this, axis_shape[axis])] - + [getattr(obj, axis_shape[axis]) for obj in other] + [getattr(this, axis_info[axis]["shape"])] + + [getattr(obj, axis_info[axis]["shape"]) for obj in other] ) - setattr(this, axis_shape[axis], new_shape) + setattr(this, axis_info[axis]["shape"], new_shape) if axis == "freq": # We want to preserve per-spw information based on first appearance From 26ed06f4bcc783721a49f4143da6378f8b0876f2 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Thu, 11 Dec 2025 14:29:16 -0800 Subject: [PATCH 16/16] update comments --- src/pyuvdata/uvbase.py | 6 +++--- src/pyuvdata/uvdata/uvdata.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index c073e3523..10adac083 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -780,6 +780,8 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): # find the axis number(s) with the named shape. attr = getattr(self, param) + # Only look at where form is a tuple, since that's the only case we + # can have a dynamically defined shape. if ( attr.value is not None and isinstance(attr.form, tuple) @@ -791,9 +793,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): ): continue - # Only look at where form is a tuple, since that's the only case we - # can have a dynamically defined shape. Handle a repeated - # param_name in the form. + # Handle a repeated param_name in the form. 
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0] return ret_dict diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index ec7ae0388..8de0e840c 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -5727,8 +5727,9 @@ def __add__( info["key_arrays"]["other"][info["other_inds_use"]], ) ) + # Figure out order -- how to sort each axis. if axis == "Npols": - # order has info about how to sort each axis. + # weird handling for pol integers info["order"] = np.argsort(np.abs(info["combined_key_arrays"])) elif axis == "Nfreqs" and ( np.any(np.diff(this.freq_array) < 0)