diff --git a/src/pyuvdata/utils/frequency.py b/src/pyuvdata/utils/frequency.py index c42d741f2..dabedf4d8 100644 --- a/src/pyuvdata/utils/frequency.py +++ b/src/pyuvdata/utils/frequency.py @@ -8,6 +8,7 @@ from . import tools from .pol import jstr2num, polstr2num +from .types import FloatArray, IntArray def _check_flex_spw_contiguous(*, spw_array, flex_spw_id_array, strict=True): @@ -549,3 +550,41 @@ def _select_freq_helper( freq_inds = freq_inds.tolist() return freq_inds, spw_inds, selections + + +def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray: + """ + Get the sorting order for the frequency axis after an add. + + Sort first by spw then by channel, but don't reorder channels if they are + changing monotonically (all ascending or descending) within the spw. + + Parameters + ---------- + spw_id : np.ndarray of int + SPW id array of combined data to be sorted. + freq_arr : np.ndarray of float + Frequency array of combined data to be sorted. + + Returns + ------- + f_order : np.ndarray of int + index array giving the sort order. + + """ + spws = np.unique(spw_id) + f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)]) + + # With spectral windows sorted, check and see if channels within + # windows need sorting. If they are ordered in ascending or descending + # fashion, leave them be. If not, sort in ascending order + for spw in spws: + select_mask = spw_id[f_order] == spw + check_freqs = freq_arr[f_order[select_mask]] + if not np.all(np.diff(check_freqs) > 0) and not np.all( + np.diff(check_freqs) < 0 + ): + subsort_order = f_order[select_mask] + f_order[select_mask] = subsort_order[np.argsort(check_freqs)] + + return f_order diff --git a/src/pyuvdata/utils/tools.py b/src/pyuvdata/utils/tools.py index 039c90af6..9a7247703 100644 --- a/src/pyuvdata/utils/tools.py +++ b/src/pyuvdata/utils/tools.py @@ -9,6 +9,8 @@ import numpy as np +from .types import FloatArray, IntArray, StrArray + def _get_iterable(x): """Return iterable version of input.""" @@ -687,3 +689,43 @@ def _ntimes_to_nblts(uvd): inds.append(np.where(unique_t == i)[0][0]) return np.asarray(inds) + + +def flt_ind_str_arr( + *, + fltarr: FloatArray, + intarr: IntArray, + flt_tols: tuple[float, float], + flt_first: bool = True, +) -> StrArray: + """ + Create a string array built from float and integer arrays for matching. + + Parameters + ---------- + fltarr : np.ndarray of float + float array to be used in output string array + intarr : np.ndarray of int + integer array to be used in output string array + flt_tols : 2-tuple of float + Tolerances (relative, absolute) to use in formatting the floats as strings. + flt_first : bool + Whether to put the float first in the out put string or not (if False + the int comes first.) + + Returns + ------- + np.ndarray of str + String array that combines the float and integer values, useful for matching. + + """ + prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int) + prec_int = 8 + flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr] + int_str_list = [str(intv).zfill(prec_int) for intv in intarr] + list_of_lists = [] + if flt_first: + list_of_lists = [flt_str_list, int_str_list] + else: + list_of_lists = [int_str_list, flt_str_list] + return np.array(["_".join(zpval) for zpval in zip(*list_of_lists, strict=True)]) diff --git a/src/pyuvdata/uvbase.py b/src/pyuvdata/uvbase.py index 7bb3d9a81..10adac083 100644 --- a/src/pyuvdata/uvbase.py +++ b/src/pyuvdata/uvbase.py @@ -16,6 +16,7 @@ from . import __version__, parameter as uvp from .utils.tools import _get_iterable, slicify +from .utils.types import IntArray __all__ = ["UVBase"] @@ -754,6 +755,63 @@ def copy(self): """ return copy.deepcopy(self) + def _get_param_axis(self, axis_name: str, single_named_axis: bool = False): + """ + Get a mapping of parameters that have a given axis to the axis number. + + Parameters + ---------- + axis_name : str + A named parameter within the object (e.g., "Nblts", "Ntimes", "Nants"). + single_named_axis : bool + Option to only include parameters with a single named axis. + + Returns + ------- + dict + The keys are UVParameter names that have an axis with axis_name + (axis_name appears in their form). The values are a list of the axis + indices where axis_name appears in their form. + + """ + ret_dict = {} + for param in self: + # For each attribute, if the value is None, then bail, otherwise + # find the axis number(s) with the named shape. + + attr = getattr(self, param) + # Only look at where form is a tuple, since that's the only case we + # can have a dynamically defined shape. + if ( + attr.value is not None + and isinstance(attr.form, tuple) + and axis_name in attr.form + ): + if ( + single_named_axis + and sum([isinstance(entry, str) for entry in attr.form]) > 1 + ): + continue + + # Handle a repeated param_name in the form. + ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0] + return ret_dict + + def _get_multi_axis_params(self) -> list[str]: + """Get a list of all multidimensional parameters.""" + ret_list = [] + for param in self: + # For each attribute, if the value is None, then bail, otherwise + # check if it's multidimensional + attr = getattr(self, param) + if ( + attr.value is not None + and isinstance(attr.form, tuple) + and sum([isinstance(entry, str) for entry in attr.form]) > 1 + ): + ret_list.append(attr.name) + return ret_list + def _select_along_param_axis(self, param_dict: dict): """ Downselect values along a parameterized axis. @@ -808,3 +866,135 @@ def _select_along_param_axis(self, param_dict: dict): # here in the case of a repeated param_name in the form. attr.value = attr.get_from_form(slice_dict) attr.setter(self) + + def _axis_add_helper( + self, + other, + axis_name: str, + other_inds: IntArray, + final_order: IntArray | None = None, + ): + """ + Combine UVParameter objects with a single axis along an axis. + + Parameters + ---------- + other : UVBase + The UVBase object to be added. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + other_inds : np.ndarray of int + Indices into the other object along this axis to include. + final_order : np.ndarray of int + Final ordering array giving the sort order after concatenation. + + """ + update_params = self._get_param_axis(axis_name, single_named_axis=True) + other_form_dict = {axis_name: other_inds} + for param, axis_list in update_params.items(): + axis = axis_list[0] + new_array = np.concatenate( + [ + getattr(self, param), + getattr(other, "_" + param).get_from_form(other_form_dict), + ], + axis=axis, + ) + if final_order is not None: + new_array = np.take(new_array, final_order, axis=axis) + + setattr(self, param, new_array) + + def _axis_pad_helper(self, axis_name: str, add_len: int): + """ + Pad out UVParameter objects with multiple dimensions along an axis. + + Parameters + ---------- + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + add_len : int + The extra length to be padded on for this axis. + + """ + update_params = self._get_param_axis(axis_name) + multi_axis_params = self._get_multi_axis_params() + for param, axis_list in update_params.items(): + if param not in multi_axis_params: + continue + this_param_shape = getattr(self, param).shape + this_param_type = getattr(self, "_" + param).expected_type + bool_type = this_param_type is bool or bool in this_param_type + pad_shape = list(this_param_shape) + for ax in axis_list: + pad_shape[ax] = add_len + if bool_type: + pad_array = np.ones(tuple(pad_shape), dtype=bool) + else: + pad_array = np.zeros(tuple(pad_shape)) + new_array = np.concatenate([getattr(self, param), pad_array], axis=ax) + if bool_type: + new_array = new_array.astype(np.bool_) + setattr(self, param, new_array) + + def _fill_multi_helper(self, other, t2o_dict: dict, order_dict: dict): + """ + Fill UVParameter objects with multiple dimensions from the right side object. + + Parameters + ---------- + other : UVBase + The UVBase object to be added. + t2o_dict : dict + dict giving the indices in the left object to be filled from the right + object for each axis (keys are axes, values are index arrays). + order_dict : dict + dict giving the final sort indices for each axis (keys are axes, values + are index arrays for sorting). + + """ + multi_axis_params = self._get_multi_axis_params() + for param in multi_axis_params: + form = getattr(self, "_" + param).form + index_list = [] + for axis in form: + index_list.append(t2o_dict[axis]) + new_arr = getattr(self, param) + new_arr[np.ix_(*index_list)] = getattr(other, param) + setattr(self, param, new_arr) + + # Fix ordering + for axis_ind, axis in enumerate(form): + if order_dict[axis] is not None: + unique_order_diffs = np.unique(np.diff(order_dict[axis])) + if np.array_equal(unique_order_diffs, np.array([1])): + # everything is already in order + continue + setattr( + self, + param, + np.take(getattr(self, param), order_dict[axis], axis=axis_ind), + ) + + def _axis_fast_concat_helper(self, other, axis_name: str): + """ + Concatenate UVParameter objects along an axis assuming no overlap. + + Parameters + ---------- + other : UVBase + The UVBase object to be added. + axis_name : str + The axis name (e.g. "Nblts", "Npols"). + """ + update_params = self._get_param_axis(axis_name) + for param, axis_list in update_params.items(): + axis = axis_list[0] + setattr( + self, + param, + np.concatenate( + [getattr(self, param)] + [getattr(obj, param) for obj in other], + axis=axis, + ), + ) diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index b55c1e1ce..8de0e840c 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -27,6 +27,7 @@ from ..utils import phasing as phs_utils from ..utils.io import hdf5 as hdf5_utils from ..utils.phasing import _get_focus_xyz, _get_nearfield_delay +from ..utils.types import StrArray from ..uvbase import UVBase from .initializers import new_uvdata @@ -5363,6 +5364,28 @@ def fix_phase(self, *, use_ant_pos=True): use_ant_pos=False, ) + def blt_str_arr(self) -> StrArray: + """Create a string array with baseline and time info for matching purposes.""" + return utils.tools.flt_ind_str_arr( + fltarr=self.time_array, + intarr=self.baseline_array, + flt_tols=self._time_array.tols, + flt_first=True, + ) + + def spw_freq_str_arr(self) -> StrArray: + """Create a string array with spw and freq info for matching purposes.""" + return utils.tools.flt_ind_str_arr( + fltarr=self.freq_array, + intarr=self.flex_spw_id_array, + flt_tols=self._freq_array.tols, + flt_first=False, + ) + + def flexpol_dict(self) -> dict: + """Create a dict with flexpol information for comparison.""" + return dict(zip(self.spw_array, self.flex_spw_polarization_array, strict=True)) + def __add__( self, other, @@ -5447,91 +5470,172 @@ def __add__( strict_uvw_antpos_check=strict_uvw_antpos_check, ) + if ( + this.flex_spw_polarization_array is not None + or other.flex_spw_polarization_array is not None + ): + # special checking for flex_spw + if (this.flex_spw_polarization_array is None) != ( + other.flex_spw_polarization_array is None + ): + raise ValueError( + "Cannot add a flex-pol and non-flex-pol UVData objects. Use " + "the `remove_flex_pol` method to convert the objects to " + "have a regular polarization axis." + ) + elif this.flex_spw_polarization_array is not None: + this_flexpol_dict = this.flexpol_dict() + other_flexpol_dict = other.flexpol_dict() + for key in other_flexpol_dict: + try: + if this_flexpol_dict[key] != other_flexpol_dict[key]: + raise ValueError( + "Cannot add a flex-pol UVData objects where " + "the same spectral window contains different " + "polarizations. Use the `remove_flex_pol` " + "method to convert the objects to have a " + "regular polarization axis." + ) + except KeyError: + this_flexpol_dict[key] = other_flexpol_dict[key] + # Define parameters that must be the same to add objects compatibility_params = ["_vis_units"] - # Build up history string - history_update_string = " Combined data along " - n_axes = 0 + # setup a dict to carry all the axis-specific info we need throughout + # the add process: + # - description is used in history string + # - key_params defines which parameters to use as the defining + # parameters along each axis. These are used to identify overlapping data. + # - key_func specifies a function to form a combined string if there are + # multiple key arrays (e.g. baseline-time, spw-freq) + # - reorder gives method & parameters for reording along each axis + # - order has info about how to sort each axis. Initialize to None + # for axes that are not added along (so do not need sorting), + # updated later. + # ---added later--- + # - check_params gives parameters that should be checked if adding + # along other axes + # - key_arrays gives the arrays to use for checking for overlap per axis + # - overlap_inds has the outcomes of np.intersect1d on the key_arrays + # between this and other. So it has the both/this/other inds + # for any overlaps. + # - other_inds_use has the indices in other that will be added to this + # - combined_key_arrays has the final key arrays after adding. + # - t2o has the mapping of where arrays on other get mapped into this + # along each axis after padding + + axis_info = { + "Nblts": { + "description": "baseline-time", + "key_params": ["time_array", "baseline_array"], + "key_func": "blt_str_arr", + "reorder": {"method": "reorder_blts", "parameter": "order"}, + "order": None, + }, + "Nfreqs": { + "description": "frequency", + "key_params": ["freq_array", "flex_spw_id_array"], + "key_func": "spw_freq_str_arr", + "reorder": {"method": "reorder_freqs", "parameter": "channel_order"}, + "order": None, + }, + "Npols": { + "description": "polarization", + "key_params": ["polarization_array"], + "reorder": {"method": "reorder_pols", "parameter": "order"}, + "order": None, + }, + } - # Create blt arrays for convenience - prec_t = -2 * np.floor(np.log10(this._time_array.tols[-1])).astype(int) - prec_b = 8 - this_blts = np.array( - [ - "_".join( - ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] - ) - for blt in zip(this.time_array, this.baseline_array, strict=True) - ] - ) - other_blts = np.array( - [ - "_".join( - ["{1:.{0}f}".format(prec_t, blt[0]), str(blt[1]).zfill(prec_b)] - ) - for blt in zip(other.time_array, other.baseline_array, strict=True) - ] - ) - # Check we don't have overlapping data - both_pol, this_pol_ind, other_pol_ind = np.intersect1d( - this.polarization_array, other.polarization_array, return_indices=True - ) - - # If we have a flexible spectral window, the handling here becomes a bit funky, - # because we are allowed to have channels with the same frequency *if* they - # belong to different spectral windows (one real-life example: you might want - # to preserve guard bands in the correlator, which can have overlaping RF - # frequency channels) - this_freq_ind = np.array([], dtype=np.int64) - other_freq_ind = np.array([], dtype=np.int64) - both_freq = np.array([], dtype=float) - both_spw = np.intersect1d(this.spw_array, other.spw_array) - for idx in both_spw: - this_mask = np.where(this.flex_spw_id_array == idx)[0] - other_mask = np.where(other.flex_spw_id_array == idx)[0] - both_spw_freq, this_spw_ind, other_spw_ind = np.intersect1d( - this.freq_array[this_mask], - other.freq_array[other_mask], + for axis, info in axis_info.items(): + # get parameters for compatibility checking. Exclude multidimensional + # parameters which are handled separately later. + params_this_axis = this._get_param_axis(axis, single_named_axis=True) + info["check_params"] = [] + for param in params_this_axis: + # Also exclude parameters that define overlap + if param not in info["key_params"]: + info["check_params"].append("_" + param) + + # build this/other arrays for checking for overlap. + # key_arrays gives the arrays to use for checking for overlap per axis + if len(info["key_params"]) > 1: + info["key_arrays"] = { + "this": getattr(this, info["key_func"])(), + "other": getattr(other, info["key_func"])(), + } + else: + info["key_arrays"] = { + "this": getattr(this, info["key_params"][0]), + "other": getattr(other, info["key_params"][0]), + } + + # Check if we have overlapping data + both_inds, this_inds, other_inds = np.intersect1d( + info["key_arrays"]["this"], + info["key_arrays"]["other"], return_indices=True, ) - this_freq_ind = np.append(this_freq_ind, this_mask[this_spw_ind]) - other_freq_ind = np.append(other_freq_ind, other_mask[other_spw_ind]) - both_freq = np.append(both_freq, both_spw_freq) + # overlap_inds has the outcomes of np.intersect1d on the + # key_arrays per axis. So it has the both/this inds/other inds + # for any overlaps. + info["overlap_inds"] = { + "this": this_inds, + "other": other_inds, + "both": both_inds, + } - both_blts, this_blts_ind, other_blts_ind = np.intersect1d( - this_blts, other_blts, return_indices=True - ) - if not self.metadata_only and ( - len(both_pol) > 0 and len(both_freq) > 0 and len(both_blts) > 0 + history_update_string = "" + + if np.all( + [len(axis_info[axis]["overlap_inds"]["both"]) > 0 for axis in axis_info] ): - # check that overlapping data is not valid - this_inds = np.ravel_multi_index( - ( - this_blts_ind[:, np.newaxis, np.newaxis], - this_freq_ind[np.newaxis, :, np.newaxis], - this_pol_ind[np.newaxis, np.newaxis, :], - ), - this.data_array.shape, - ).flatten() - other_inds = np.ravel_multi_index( - ( - other_blts_ind[:, np.newaxis, np.newaxis], - other_freq_ind[np.newaxis, :, np.newaxis], - other_pol_ind[np.newaxis, np.newaxis, :], - ), - other.data_array.shape, - ).flatten() - this_all_zero = np.all(this.data_array.flatten()[this_inds] == 0) - this_all_flag = np.all(this.flag_array.flatten()[this_inds]) - other_all_zero = np.all(other.data_array.flatten()[other_inds] == 0) - other_all_flag = np.all(other.flag_array.flatten()[other_inds]) - - if this_all_zero and this_all_flag: + # We have overlaps, check that overlapping data is not valid + this_test = [] + other_test = [] + multi_axis_params = this._get_multi_axis_params() + for param in multi_axis_params: + form = getattr(this, "_" + param).form + this_shape = getattr(this, param).shape + other_shape = getattr(other, param).shape + this_param_type = getattr(this, "_" + param).expected_type + bool_type = this_param_type is bool or bool in this_param_type + + this_index_list = [] + other_index_list = [] + for ax_ind, axis in enumerate(form): + expand_axes = [ax for ax in range(len(form)) if ax != ax_ind] + this_index_list.append( + np.expand_dims( + axis_info[axis]["overlap_inds"]["this"], axis=expand_axes + ) + ) + other_index_list.append( + np.expand_dims( + axis_info[axis]["overlap_inds"]["other"], axis=expand_axes + ) + ) + this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten() + + other_inds = np.ravel_multi_index( + other_index_list, other_shape + ).flatten() + + this_arr = getattr(this, param).flatten()[this_inds] + other_arr = getattr(other, param).flatten()[other_inds] + + if bool_type: + this_test.append(np.all(this_arr)) + other_test.append(np.all(other_arr)) + else: + this_test.append(np.all(this_arr == 0)) + other_test.append(np.all(other_arr == 0)) + + if np.all(this_test): # we're fine to overwrite; update history accordingly history_update_string = " Overwrote invalid data using pyuvdata." - this.history += history_update_string - elif other_all_zero and other_all_flag: + elif np.all(other_test): raise ValueError( "To combine these data, please run the add operation again, " "but with the object whose data is to be overwritten as the " @@ -5542,138 +5646,44 @@ def __add__( "These objects have overlapping data and cannot be combined." ) - # find the blt indices in "other" but not in "this" - temp = np.nonzero(~np.isin(other_blts, this_blts))[0] - if len(temp) > 0: - bnew_inds = temp - new_blts = other_blts[temp] - history_update_string += "baseline-time" - n_axes += 1 - else: - bnew_inds, new_blts = ([], []) - - # if there's any overlap in blts, check extra params - temp = np.nonzero(np.isin(other_blts, this_blts))[0] - if len(temp) > 0: - # add metadata to be checked to compatibility params - extra_params = [ - "_integration_time", - "_lst_array", - "_phase_center_catalog", - "_phase_center_id_array", - "_phase_center_app_ra", - "_phase_center_app_dec", - "_phase_center_frame_pa", - "_Nphase", - "_uvw_array", - ] - compatibility_params.extend(extra_params) - - # find the freq indices in "other" but not in "this" - if (this.flex_spw_polarization_array is None) != ( - other.flex_spw_polarization_array is None - ): - raise ValueError( - "Cannot add a flex-pol and non-flex-pol UVData objects. Use " - "the `remove_flex_pol` method to convert the objects to " - "have a regular polarization axis." - ) - elif this.flex_spw_polarization_array is not None: - this_flexpol_dict = dict( - zip(this.spw_array, this.flex_spw_polarization_array, strict=True) - ) - other_flexpol_dict = dict( - zip(other.spw_array, other.flex_spw_polarization_array, strict=True) - ) - for key in other_flexpol_dict: - try: - if this_flexpol_dict[key] != other_flexpol_dict[key]: - raise ValueError( - "Cannot add a flex-pol UVData objects where the same " - "spectral window contains different polarizations. Use " - "the `remove_flex_pol` method to convert the objects " - "to have a regular polarization axis." - ) - except KeyError: - this_flexpol_dict[key] = other_flexpol_dict[key] - - other_mask = np.ones_like(other.flex_spw_id_array, dtype=bool) - for idx in np.intersect1d(this.spw_array, other.spw_array): - other_mask[other.flex_spw_id_array == idx] = np.isin( - other.freq_array[other.flex_spw_id_array == idx], - this.freq_array[this.flex_spw_id_array == idx], - invert=True, - ) - temp = np.where(other_mask)[0] - if len(temp) > 0: - fnew_inds = temp - if n_axes > 0: - history_update_string += ", frequency" - else: - history_update_string += "frequency" - n_axes += 1 - else: - fnew_inds = [] - - # if channel width is an array and there's any overlap in freqs, - # check extra params - temp = np.nonzero(np.isin(other.freq_array, this.freq_array))[0] - if len(temp) > 0: - # add metadata to be checked to compatibility params - extra_params = ["_channel_width"] - compatibility_params.extend(extra_params) - - # find the pol indices in "other" but not in "this" - temp = np.nonzero(~np.isin(other.polarization_array, this.polarization_array))[ - 0 - ] - if len(temp) > 0: - pnew_inds = temp - if n_axes > 0: - history_update_string += ", polarization" + # Now actually find which axes are going to be added along + additions = [] + # find the indices in "other" but not in "this" + for axis, info in axis_info.items(): + temp = np.nonzero( + ~np.isin(info["key_arrays"]["other"], info["key_arrays"]["this"]) + )[0] + if len(temp) > 0: + # other_inds_use has the indices in other that will be added to this + info["other_inds_use"] = temp + # add params associated with the other axes to compatibility_params + for axis2 in axis_info: + if axis2 != axis: + compatibility_params.extend(axis_info[axis2]["check_params"]) + additions.append(info["description"]) else: - history_update_string += "polarization" - n_axes += 1 - else: - pnew_inds = [] + info["other_inds_use"] = [] # Actually check compatibility parameters - blt_inds_params = [ - "_integration_time", - "_lst_array", - "_phase_center_app_ra", - "_phase_center_app_dec", - "_phase_center_frame_pa", - "_phase_center_id_array", - ] for cp in compatibility_params: - if cp in blt_inds_params: - # only check that overlapping blt indices match - this_param = getattr(this, cp) - other_param = getattr(other, cp) - params_match = np.allclose( - this_param.value[this_blts_ind], - other_param.value[other_blts_ind], - rtol=this_param.tols[0], - atol=this_param.tols[1], - ) - elif cp == "_uvw_array": - # only check that overlapping blt indices match - params_match = np.allclose( - this.uvw_array[this_blts_ind, :], - other.uvw_array[other_blts_ind, :], - rtol=this._uvw_array.tols[0], - atol=this._uvw_array.tols[1], - ) - elif cp == "_channel_width": - # only check that overlapping freq indices match - params_match = np.allclose( - this.channel_width[this_freq_ind], - other.channel_width[other_freq_ind], - rtol=this._channel_width.tols[0], - atol=this._channel_width.tols[1], - ) - else: + params_match = None + for axis, info in axis_info.items(): + if cp in info["check_params"]: + # only check that overlapping indices match + this_param = getattr(this, cp) + this_param_overlap = this_param.get_from_form( + {axis: info["overlap_inds"]["this"]} + ) + other_param_overlap = getattr(other, cp).get_from_form( + {axis: info["overlap_inds"]["other"]} + ) + params_match = np.allclose( + this_param_overlap, + other_param_overlap, + rtol=this_param.tols[0], + atol=this_param.tols[1], + ) + if params_match is None: params_match = getattr(this, cp) == getattr(other, cp) if not params_match: msg = ( @@ -5692,101 +5702,81 @@ def __add__( # Next, we want to make sure that the ordering of the _overlapping_ data is # the same, so that things can get plugged together in a sensible way. - if len(this_blts_ind) != 0: - this_argsort = np.argsort(this_blts_ind) - other_argsort = np.argsort(other_blts_ind) - - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Nblts) - temp_ind[this_blts_ind[this_argsort]] = temp_ind[ - this_blts_ind[other_argsort] - ] - - this.reorder_blts(order=temp_ind) - - if len(this_freq_ind) != 0: - this_argsort = np.argsort(this_freq_ind) - other_argsort = np.argsort(other_freq_ind) - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Nfreqs) - temp_ind[this_freq_ind[this_argsort]] = temp_ind[ - this_freq_ind[other_argsort] - ] - - this.reorder_freqs(channel_order=temp_ind) + for axis, info in axis_info.items(): + if len(info["overlap_inds"]["this"]) != 0: + # there is some overlap, so check sorting + this_argsort = np.argsort(info["overlap_inds"]["this"]) + other_argsort = np.argsort(info["overlap_inds"]["other"]) + + if np.any(this_argsort != other_argsort): + temp_ind = np.arange(getattr(this, axis)) + temp_ind[info["overlap_inds"]["this"][this_argsort]] = temp_ind[ + info["overlap_inds"]["this"][other_argsort] + ] + kwargs = {info["reorder"]["parameter"]: temp_ind} + + getattr(this, info["reorder"]["method"])(**kwargs) + + # checks are all done, start updating parameters + for axis, info in axis_info.items(): + if len(info["other_inds_use"]) > 0: + # combined_key_arrays has the final key arrays after adding. + info["combined_key_arrays"] = np.concatenate( + ( + info["key_arrays"]["this"], + info["key_arrays"]["other"][info["other_inds_use"]], + ) + ) + # Figure out order -- how to sort each axis. + if axis == "Npols": + # weird handling for pol integers + info["order"] = np.argsort(np.abs(info["combined_key_arrays"])) + elif axis == "Nfreqs" and ( + np.any(np.diff(this.freq_array) < 0) + or np.any(np.diff(other.freq_array) < 0) + ): + # deal with the possibility of spws with channels in + # descending order. + info["order"] = utils.frequency._add_freq_order( + np.concatenate( + ( + this.flex_spw_id_array, + other.flex_spw_id_array[info["other_inds_use"]], + ) + ), + np.concatenate( + (this.freq_array, other.freq_array[info["other_inds_use"]]) + ), + ) + else: + info["order"] = np.argsort(info["combined_key_arrays"]) - if len(this_pol_ind) != 0: - this_argsort = np.argsort(this_pol_ind) - other_argsort = np.argsort(other_pol_ind) - if np.any(this_argsort != other_argsort): - temp_ind = np.arange(this.Npols) - temp_ind[this_pol_ind[this_argsort]] = temp_ind[ - this_pol_ind[other_argsort] - ] + # first handle parameters with a single named axis + this._axis_add_helper( + other, axis, info["other_inds_use"], info["order"] + ) - this.reorder_pols(temp_ind) + # then pad out parameters with multiple axes + this._axis_pad_helper(axis, len(info["other_inds_use"])) + else: + # no add along this axis, so it's the same as what's already on this + info["combined_key_arrays"] = info["key_arrays"]["this"] + + # Now fill in multidimensional parameters + # t2o has the mapping of where arrays on other get mapped into + # this after padding + for _, info in axis_info.items(): + info["t2o"] = np.nonzero( + np.isin(info["combined_key_arrays"], info["key_arrays"]["other"]) + )[0] - # Pad out self to accommodate new data - blt_order = None - if len(bnew_inds) > 0: - this_blts = np.concatenate((this_blts, new_blts)) - blt_order = np.argsort(this_blts) - if not self.metadata_only: - zero_pad = np.zeros((len(bnew_inds), this.Nfreqs, this.Npols)) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=0) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=0 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=0 - ).astype(np.bool_) - this.uvw_array = np.concatenate( - [this.uvw_array, other.uvw_array[bnew_inds, :]], axis=0 - )[blt_order, :] - this.time_array = np.concatenate( - [this.time_array, other.time_array[bnew_inds]] - )[blt_order] - this.integration_time = np.concatenate( - [this.integration_time, other.integration_time[bnew_inds]] - )[blt_order] - this.lst_array = np.concatenate( - [this.lst_array, other.lst_array[bnew_inds]] - )[blt_order] - this.ant_1_array = np.concatenate( - [this.ant_1_array, other.ant_1_array[bnew_inds]] - )[blt_order] - this.ant_2_array = np.concatenate( - [this.ant_2_array, other.ant_2_array[bnew_inds]] - )[blt_order] - this.baseline_array = np.concatenate( - [this.baseline_array, other.baseline_array[bnew_inds]] - )[blt_order] - this.phase_center_app_ra = np.concatenate( - [this.phase_center_app_ra, other.phase_center_app_ra[bnew_inds]] - )[blt_order] - this.phase_center_app_dec = np.concatenate( - [this.phase_center_app_dec, other.phase_center_app_dec[bnew_inds]] - )[blt_order] - this.phase_center_frame_pa = np.concatenate( - [this.phase_center_frame_pa, other.phase_center_frame_pa[bnew_inds]] - )[blt_order] - this.phase_center_id_array = np.concatenate( - [this.phase_center_id_array, other.phase_center_id_array[bnew_inds]] - )[blt_order] - - f_order = None - if len(fnew_inds) > 0: - this.freq_array = np.concatenate( - [this.freq_array, other.freq_array[fnew_inds]] - ) - this.channel_width = np.concatenate( - [this.channel_width, other.channel_width[fnew_inds]] - ) + this._fill_multi_helper( + other, + {axis: info["t2o"] for axis, info in axis_info.items()}, + {axis: info["order"] for axis, info in axis_info.items()}, + ) - this.flex_spw_id_array = np.concatenate( - [this.flex_spw_id_array, other.flex_spw_id_array[fnew_inds]] - ) - this.spw_array = np.concatenate([this.spw_array, other.spw_array]) + if len(axis_info["Nfreqs"]["other_inds_use"]) > 0: # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( @@ -5799,108 +5789,13 @@ def __add__( this.flex_spw_polarization_array = np.array( [this_flexpol_dict[key] for key in this.spw_array] ) - # Need to sort out the order of the individual windows first. - f_order = np.concatenate( - [ - np.where(this.flex_spw_id_array == idx)[0] - for idx in sorted(this.spw_array) - ] - ) - - # With spectral windows sorted, check and see if channels within - # windows need sorting. If they are ordered in ascending or descending - # fashion, leave them be. If not, sort in ascending order - for idx in this.spw_array: - select_mask = this.flex_spw_id_array[f_order] == idx - check_freqs = this.freq_array[f_order[select_mask]] - if (not np.all(check_freqs[1:] > check_freqs[:-1])) and ( - not np.all(check_freqs[1:] < check_freqs[:-1]) - ): - subsort_order = f_order[select_mask] - f_order[select_mask] = subsort_order[np.argsort(check_freqs)] - - if not self.metadata_only: - zero_pad = np.zeros( - (this.data_array.shape[0], len(fnew_inds), this.Npols) - ) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=1) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=1 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=1 - ).astype(np.bool_) - - p_order = None - if len(pnew_inds) > 0: - this.polarization_array = np.concatenate( - [this.polarization_array, other.polarization_array[pnew_inds]] - ) - p_order = np.argsort(np.abs(this.polarization_array)) - if not self.metadata_only: - zero_pad = np.zeros( - (this.data_array.shape[0], this.data_array.shape[1], len(pnew_inds)) - ) - this.data_array = np.concatenate([this.data_array, zero_pad], axis=2) - this.nsample_array = np.concatenate( - [this.nsample_array, zero_pad], axis=2 - ) - this.flag_array = np.concatenate( - [this.flag_array, 1 - zero_pad], axis=2 - ).astype(np.bool_) - - # Now populate the data - pol_t2o = np.nonzero( - np.isin(this.polarization_array, other.polarization_array) - )[0] - this_freqs = this.freq_array - other_freqs = other.freq_array - - freq_t2o = np.zeros(this_freqs.shape, dtype=bool) - for spw_id in set(this.spw_array).intersection(other.spw_array): - mask = this.flex_spw_id_array == spw_id - freq_t2o[mask] |= np.isin( - this_freqs[mask], other_freqs[other.flex_spw_id_array == spw_id] - ) - freq_t2o = np.nonzero(freq_t2o)[0] - blt_t2o = np.nonzero(np.isin(this_blts, other_blts))[0] - - if not self.metadata_only: - this.data_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.data_array - this.nsample_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.nsample_array - this.flag_array[np.ix_(blt_t2o, freq_t2o, pol_t2o)] = other.flag_array - - # Fix ordering - axis_dict = { - 0: {"inds": bnew_inds, "order": blt_order}, - 1: {"inds": fnew_inds, "order": f_order}, - 2: {"inds": pnew_inds, "order": p_order}, - } - for axis, subdict in axis_dict.items(): - for name, param in zip( - this._data_params, this.data_like_parameters, strict=True - ): - if len(subdict["inds"]) > 0: - unique_order_diffs = np.unique(np.diff(subdict["order"])) - if np.array_equal(unique_order_diffs, np.array([1])): - # everything is already in order - continue - setattr(this, name, np.take(param, subdict["order"], axis=axis)) - - if len(fnew_inds) > 0: - this.freq_array = this.freq_array[f_order] - this.channel_width = this.channel_width[f_order] - this.flex_spw_id_array = this.flex_spw_id_array[f_order] - - if len(pnew_inds) > 0: - this.polarization_array = this.polarization_array[p_order] # Update N parameters (e.g. Npols) this.Ntimes = len(np.unique(this.time_array)) this.Nbls = len(np.unique(this.baseline_array)) - this.Nblts = this.uvw_array.shape[0] + this.Nblts = this.baseline_array.size this.Nfreqs = this.freq_array.size - this.Npols = this.polarization_array.shape[0] + this.Npols = this.polarization_array.size this.Nants_data = this._calc_nants_data() # Update filename parameter @@ -5908,9 +5803,14 @@ def __add__( if this.filename is not None: this._filename.form = (len(this.filename),) - if n_axes > 0: - history_update_string += " axis using pyuvdata." + if len(additions) > 0: + # Build up history string + history_update_string += ( + " Combined data along " + ", ".join(additions) + " axis using pyuvdata." + ) + if len(history_update_string) > 0: + # this can be true even if len(additions)=0 b/c of filling in invalid data. histories_match = utils.history._check_histories( this.history, other.history ) @@ -5929,9 +5829,9 @@ def __add__( + extra_history ) - # Reset blt_order if blt axis was added to and it is set - if len(blt_t2o) > 0: - this.blt_order = None + # Reset blt_order if blt axis was added to + if axis_info["Nblts"]["order"] is not None: + this.blt_order = ("time", "baseline") this.set_rectangularity(force=True) @@ -6092,9 +5992,20 @@ def fast_concat( self and other are not compatible. """ - allowed_axes = ["blt", "freq", "polarization"] - if axis not in allowed_axes: - raise ValueError("Axis must be one of: " + ", ".join(allowed_axes)) + # setup a dict to carry all the axis-specific info we need throughout + # the fast concat process: + # - description is used in history string + # - shape: the shape name parameter (e.g. "Nblts", "Nfreqs", "Npols") + # ---added later---- + # - check_params gives parameters that should be checked if adding + # along other axes + axis_info = { + "blt": {"description": "baseline-time", "shape": "Nblts"}, + "freq": {"description": "frequency", "shape": "Nfreqs"}, + "polarization": {"description": "polarization", "shape": "Npols"}, + } + if axis not in axis_info: + raise ValueError("Axis must be one of: " + ", ".join(axis_info)) if inplace: this = self @@ -6144,39 +6055,23 @@ def fast_concat( history_update_string = " Combined data along " - if axis == "freq": - history_update_string += "frequency" - compatibility_params += [ - "_polarization_array", - "_ant_1_array", - "_ant_2_array", - "_integration_time", - "_uvw_array", - "_lst_array", - "_phase_center_id_array", - ] - elif axis == "polarization": - history_update_string += "polarization" - compatibility_params += [ - "_freq_array", - "_channel_width", - "_flex_spw_id_array", - "_ant_1_array", - "_ant_2_array", - "_integration_time", - "_uvw_array", - "_lst_array", - "_phase_center_id_array", - ] - elif axis == "blt": - history_update_string += "baseline-time" - compatibility_params += [ - "_freq_array", - "_polarization_array", - "_flex_spw_id_array", - ] + # figure out what parameters to check for compatibility -- only worry + # about single axis params + for _, info in axis_info.items(): + params_this_axis = this._get_param_axis( + info["shape"], single_named_axis=True + ) + info["check_params"] = [] + for param in params_this_axis: + info["check_params"].append("_" + param) - history_update_string += " axis using pyuvdata." + for axis2, info in axis_info.items(): + if axis2 != axis: + compatibility_params.extend(info["check_params"]) + + history_update_string += ( + f" {axis_info[axis]['description']} axis using pyuvdata." + ) histories_match = [] for obj in other: @@ -6215,106 +6110,28 @@ def fast_concat( this.telescope = tel_obj + # actually do the concat + this._axis_fast_concat_helper(other, axis_info[axis]["shape"]) + + # update the relevant shape parameter + new_shape = sum( + [getattr(this, axis_info[axis]["shape"])] + + [getattr(obj, axis_info[axis]["shape"]) for obj in other] + ) + setattr(this, axis_info[axis]["shape"], new_shape) + if axis == "freq": - this.Nfreqs = sum([this.Nfreqs] + [obj.Nfreqs for obj in other]) - this.freq_array = np.concatenate( - [this.freq_array] + [obj.freq_array for obj in other] - ) - this.channel_width = np.concatenate( - [this.channel_width] + [obj.channel_width for obj in other] - ) - this.flex_spw_id_array = np.concatenate( - [this.flex_spw_id_array] + [obj.flex_spw_id_array for obj in other] - ) - this.spw_array = np.concatenate( - [this.spw_array] + [obj.spw_array for obj in other] - ) # We want to preserve per-spw information based on first appearance # in the concatenated array. unique_index = np.sort( np.unique(this.flex_spw_id_array, return_index=True)[1] ) this.spw_array = this.flex_spw_id_array[unique_index] - this.Nspws = len(this.spw_array) - - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=1 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=1 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=1 - ) - elif axis == "polarization": - this.polarization_array = np.concatenate( - [this.polarization_array] + [obj.polarization_array for obj in other] - ) - this.Npols = sum([this.Npols] + [obj.Npols for obj in other]) - - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=2 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=2 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=2 - ) elif axis == "blt": - this.Nblts = sum([this.Nblts] + [obj.Nblts for obj in other]) - this.ant_1_array = np.concatenate( - [this.ant_1_array] + [obj.ant_1_array for obj in other] - ) - this.ant_2_array = np.concatenate( - [this.ant_2_array] + [obj.ant_2_array for obj in other] - ) this.Nants_data = this._calc_nants_data() - this.uvw_array = np.concatenate( - [this.uvw_array] + [obj.uvw_array for obj in other], axis=0 - ) - this.time_array = np.concatenate( - [this.time_array] + [obj.time_array for obj in other] - ) this.Ntimes = len(np.unique(this.time_array)) - this.lst_array = np.concatenate( - [this.lst_array] + [obj.lst_array for obj in other] - ) - this.baseline_array = np.concatenate( - [this.baseline_array] + [obj.baseline_array for obj in other] - ) this.Nbls = len(np.unique(this.baseline_array)) - this.integration_time = np.concatenate( - [this.integration_time] + [obj.integration_time for obj in other] - ) - this.phase_center_app_ra = np.concatenate( - [this.phase_center_app_ra] + [obj.phase_center_app_ra for obj in other] - ) - this.phase_center_app_dec = np.concatenate( - [this.phase_center_app_dec] - + [obj.phase_center_app_dec for obj in other] - ) - this.phase_center_frame_pa = np.concatenate( - [this.phase_center_frame_pa] - + [obj.phase_center_frame_pa for obj in other] - ) - this.phase_center_id_array = np.concatenate( - [this.phase_center_id_array] - + [obj.phase_center_id_array for obj in other] - ) - if not self.metadata_only: - this.data_array = np.concatenate( - [this.data_array] + [obj.data_array for obj in other], axis=0 - ) - this.nsample_array = np.concatenate( - [this.nsample_array] + [obj.nsample_array for obj in other], axis=0 - ) - this.flag_array = np.concatenate( - [this.flag_array] + [obj.flag_array for obj in other], axis=0 - ) # update filename attribute for obj in other: diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index abf7fe291..6ad1c62de 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -3385,6 +3385,7 @@ def test_sum_vis_errors(hera_uvh5, attr_to_get, attr_to_set, arg_dict, msg): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_freq(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv1 = uv_full.select(freq_chans=np.arange(0, 32), inplace=False) uv2 = uv_full.select(freq_chans=np.arange(32, 64), inplace=False) @@ -3415,6 +3416,7 @@ def test_add_freq(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_pols(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv1 = uv_full.select(polarizations=uv_full.polarization_array[0:2], inplace=False) uv2 = uv_full.select(polarizations=uv_full.polarization_array[2:4], inplace=False) @@ -3453,6 +3455,7 @@ def test_add_pols(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_times(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) times = np.unique(uv_full.time_array) uv1 = uv_full.select(times=times[0 : len(times) // 2], inplace=False) @@ -3474,6 +3477,8 @@ def test_add_times(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_bls(casa_uvfits): uv_full = casa_uvfits + uv_full.reorder_blts() + uv_full.scan_number_array = np.arange(uv_full.Nblts) ant_list = list(range(15)) # Roughly half the antennas in the data # All blts where ant_1 is in list @@ -3514,6 +3519,7 @@ def test_add_bls(casa_uvfits): uv3.ant_1_array = uv3.ant_1_array[-1::-1] uv3.ant_2_array = uv3.ant_2_array[-1::-1] uv3.baseline_array = uv3.baseline_array[-1::-1] + uv3.scan_number_array = uv3.scan_number_array[-1::-1] uv1 += uv3 uv1 += uv2 assert utils.history._check_histories( @@ -3532,6 +3538,7 @@ def test_add_bls(casa_uvfits): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_add_multi_axis(casa_uvfits): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) uv_ref = uv_full.copy() times = np.unique(uv_full.time_array) @@ -3976,6 +3983,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): if np.any(np.logical_and(screen1, screen2)): flag_screen = screen2[screen1] uv1.data_array[:, flag_screen] = 0.0 + uv1.nsample_array[:, flag_screen] = 0.0 uv1.flag_array[:, flag_screen] = True uv_recomb = getattr(uv1, add_method[0])(uv2, **add_method[1]) @@ -4005,7 +4013,7 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): [ [], [["unproject_phase", {}], ["select", {"freq_chans": np.arange(32, 64)}]], - "UVParameter phase_center_catalog does not match. Cannot combine objects.", + "UVParameter phase_center_app_dec does not match. Cannot combine objects.", ], [ [["vis_units", "Jy"]], @@ -4017,6 +4025,11 @@ def test_flex_spw_add_concat(sma_mir, add_method, screen1, screen2): [["select", {"freq_chans": np.arange(32, 64)}]], "UVParameter integration_time does not match.", ], + [ + [["scan_number_array", np.arange(1360)]], + [["select", {"freq_chans": np.arange(32, 64)}]], + "UVParameter scan_number_array does not match.", + ], ], ) def test_break_add(casa_uvfits, attr_to_set, attr_to_get, msg): @@ -4026,6 +4039,7 @@ def test_break_add(casa_uvfits, attr_to_set, attr_to_get, msg): """ # Test failure modes of add function uv1 = casa_uvfits + uv1._set_scan_numbers() uv2 = uv1.copy() uv1.select(freq_chans=np.arange(0, 32)) @@ -4170,6 +4184,7 @@ def test_fast_concat_times(casa_uvfits): @pytest.mark.parametrize("in_order", [True, False]) def test_fast_concat_bls(casa_uvfits, in_order): uv_full = casa_uvfits + uv_full.scan_number_array = np.arange(uv_full.Nblts) if in_order: # divide in half to keep in order