Skip to content

Commit 945b3a6

Browse files
committed
use a single dict in fast_concat as well
1 parent 1f2807a commit 945b3a6

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

src/pyuvdata/uvdata/uvdata.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5970,9 +5970,20 @@ def fast_concat(
59705970
self and other are not compatible.
59715971

59725972
"""
5973-
allowed_axes = ["blt", "freq", "polarization"]
5974-
if axis not in allowed_axes:
5975-
raise ValueError("Axis must be one of: " + ", ".join(allowed_axes))
5973+
# setup a dict to carry all the axis-specific info we need throughout
5974+
# the fast concat process:
5975+
# - description is used in history string
5976+
# - shape: the shape name parameter (e.g. "Nblts", "Nfreqs", "Npols")
5977+
# ---added later----
5978+
# - check_params gives parameters that should be checked if adding
5979+
# along other axes
5980+
axis_info = {
5981+
"blt": {"description": "baseline-time", "shape": "Nblts"},
5982+
"freq": {"description": "frequency", "shape": "Nfreqs"},
5983+
"polarization": {"description": "polarization", "shape": "Npols"},
5984+
}
5985+
if axis not in axis_info:
5986+
raise ValueError("Axis must be one of: " + ", ".join(axis_info))
59765987

59775988
if inplace:
59785989
this = self
@@ -6022,28 +6033,23 @@ def fast_concat(
60226033

60236034
history_update_string = " Combined data along "
60246035

6025-
# identify params that are not explicitly included in overlap calc per axis
6026-
axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6027-
axis_check_params = {}
6028-
axis_parameters = {}
6029-
for axis2, ax_shape in axis_shape.items():
6030-
axis_parameters[axis2] = this._get_param_axis(ax_shape)
6031-
axis_check_params[axis2] = []
6032-
for param in axis_parameters[axis2]:
6033-
if param not in this._data_params:
6034-
axis_check_params[axis2].append("_" + param)
6035-
6036-
for axis2 in axis_shape:
6037-
if axis2 != axis:
6038-
compatibility_params.extend(axis_check_params[axis2])
6036+
# figure out what parameters to check for compatibility -- only worry
6037+
# about single axis params
6038+
for _, info in axis_info.items():
6039+
params_this_axis = this._get_param_axis(
6040+
info["shape"], single_named_axis=True
6041+
)
6042+
info["check_params"] = []
6043+
for param in params_this_axis:
6044+
info["check_params"].append("_" + param)
60396045

6040-
axis_descriptions = {
6041-
"blt": "baseline-time",
6042-
"freq": "frequency",
6043-
"polarization": "polarization",
6044-
}
6046+
for axis2, info in axis_info.items():
6047+
if axis2 != axis:
6048+
compatibility_params.extend(info["check_params"])
60456049

6046-
history_update_string += f" {axis_descriptions[axis]} axis using pyuvdata."
6050+
history_update_string += (
6051+
f" {axis_info[axis]['description']} axis using pyuvdata."
6052+
)
60476053

60486054
histories_match = []
60496055
for obj in other:
@@ -6082,13 +6088,15 @@ def fast_concat(
60826088

60836089
this.telescope = tel_obj
60846090

6091+
# actually do the concat
6092+
this._axis_fast_concat_helper(other, axis_info[axis]["shape"])
6093+
60856094
# update the relevant shape parameter
6086-
this._axis_fast_concat_helper(other, axis_shape[axis])
60876095
new_shape = sum(
6088-
[getattr(this, axis_shape[axis])]
6089-
+ [getattr(obj, axis_shape[axis]) for obj in other]
6096+
[getattr(this, axis_info[axis]["shape"])]
6097+
+ [getattr(obj, axis_info[axis]["shape"]) for obj in other]
60906098
)
6091-
setattr(this, axis_shape[axis], new_shape)
6099+
setattr(this, axis_info[axis]["shape"], new_shape)
60926100

60936101
if axis == "freq":
60946102
# We want to preserve per-spw information based on first appearance

0 commit comments

Comments
 (0)