diff --git a/CHANGELOG.md b/CHANGELOG.md index b31d6df78..2f9b9ebd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ## [Unreleased] ### Added +- A colormap option to the UVBeam and AnalyticBeam `plot` methods, along with +better default colormaps for phase. - A NotImplementedError to `mwa_corr_fits.py` that is thrown when trying to read fringe-stopped data. diff --git a/docs/Images/hera_efield_phase.png b/docs/Images/hera_efield_phase.png index 966c4b6cb..9cd788daf 100644 Binary files a/docs/Images/hera_efield_phase.png and b/docs/Images/hera_efield_phase.png differ diff --git a/src/pyuvdata/analytic_beam.py b/src/pyuvdata/analytic_beam.py index b20b81c71..d1b788bd6 100644 --- a/src/pyuvdata/analytic_beam.py +++ b/src/pyuvdata/analytic_beam.py @@ -646,6 +646,7 @@ def plot( beam_type: str, freq: float, complex_type: str = "real", + colormap: str | None = None, logcolor: bool | None = None, plt_kwargs: dict | None = None, norm_kwargs: dict | None = None, @@ -666,6 +667,11 @@ def plot( What to plot for complex beams, options are: [real, imag, abs, phase]. Defaults to "real" for complex beams. Ignored for real beams (i.e. power beams, same feed). + colormap : str, optional + Matplotlib colormap to use. Defaults to "twlight" if complex_type="phase" + and logcolor=False, otherwise it defaults to "viridis" if the data to be + plotted is positive definite (e.g. if complex_type="abs") and "PRGn" + otherwise. logcolor : bool, optional Option to use log scaling for the color. Defaults to True for power beams and False for E-field beams. Results in using @@ -690,6 +696,7 @@ def plot( beam_type=beam_type, freq=freq, complex_type=complex_type, + colormap=colormap, logcolor=logcolor, plt_kwargs=plt_kwargs, norm_kwargs=norm_kwargs, diff --git a/src/pyuvdata/utils/plotting.py b/src/pyuvdata/utils/plotting.py index 801c8b474..53c246896 100644 --- a/src/pyuvdata/utils/plotting.py +++ b/src/pyuvdata/utils/plotting.py @@ -10,37 +10,78 @@ from ..analytic_beam import AnalyticBeam, UnpolarizedAnalyticBeam from .coordinates import _get_hpix_obj, hpx_latlon_to_zenithangle_azimuth from .pol import polnum2str +from .types import FloatArray -def beam_plot( +def get_az_za_grid(max_zenith_deg: float = 90.0): + """ + Get an azimuth, zenith angle grid. + + Parameters + ---------- + max_zenith_deg : float + Maximum zenith angle to include in the plot in degrees. Default is + 90 to go down to the horizon. + + """ + az_grid = np.deg2rad(np.arange(0, 360)) + za_grid = np.deg2rad(np.arange(0, 91)) * (max_zenith_deg / 90.0) + + return az_grid, za_grid + + +def plot_beam_arrays( + beam_vals: FloatArray, + az_array: FloatArray, + za_array: FloatArray, *, - beam_obj: UVBeam | AnalyticBeam, - freq: int | float, - beam_type: str | None = None, complex_type: str = "real", + beam_name: str = "", + beam_type_label: str = "", + feedpol_label: list[str] | None = None, + freq_label: float | None = None, + colormap: str | None = None, logcolor: bool | None = None, plt_kwargs: dict | None = None, norm_kwargs: dict | None = None, - max_zenith_deg: float = 90.0, savefile: str | None = None, ): """ - Make a pretty plot of a beam. + Make a pretty plot of a beam from arrays. + + This is usually called via UVBeam.plot or AnalyticBeam.plot but can be used + for arrays in memory if desired. Parameters ---------- - beam_obj : UVBeam or AnalyticBeam - The beam to plot. - freq : int or float - Either the index into the freq_array for UVBeam objects (int) or the - frequency to evaluate the beam at in Hz (float) for AnalyticBeam objects. - beam_type : str - Required for analytic beams to specify the beam type to plot. Ignored for - UVBeams. + beam_vals : ndarray of float + The beam values to plot. Shape depends on whether it is regularly + gridded in az/za or not. For regular grids, the shape is + (naxes_vec, nfeedpol, za_grid.size, az_grid.size). For irregular az/za, + the shape is (naxes_vec, nfeedpol, n_directions). + az_array : ndarray of float + The azimuth values. For regular grids, the shape is (za_grid.size, + az_grid.size). For irregular grids the shape is (n_directions,). + za_array : ndarray of float + The zenith angle values. For regular grids, the shape is (za_grid.size, + az_grid.size). For irregular grids the shape is (n_directions,). complex_type : str What to plot for complex beams, options are: [real, imag, abs, phase]. Defaults to "real" for complex beams. Ignored for real beams (i.e. power beams, same feed). + beam_name : str, optional + The telescope name or beam name, used to label the plots. + freq_label : float, optional + The frequency, used only in the plots title. Optional. + beam_type_label : str, optional + Used for labelling the plots. + feedpol_label: list of str, optional + Feed or polarization labels, used for plot labelling. If provided, must + be the same length as the 1st (not 0th) dimension of beam_vals. + colormap : str, optional + Matplotlib colormap to use. Defaults to "twlight" if complex_type="phase" + and logcolor=False, otherwise it defaults to "viridis" if the data to be + plotted is positive definite (e.g. if complex_type="abs") and "PRGn" otherwise. logcolor : bool, optional Option to use log scaling for the color. Defaults to True for power beams and False for E-field beams. Results in using @@ -51,9 +92,6 @@ def beam_plot( norm_kwargs : dict, optional Keywords to be passed into the norm object, typically vmin/vmax, plus linthresh for SymLogNorm. - max_zenith_deg : float - Maximum zenith angle to include in the plot in degrees. Default is - 90 to go down to the horizon. savefile : str File to save the plot to. @@ -76,109 +114,22 @@ def beam_plot( complex_func = {"real": np.real, "imag": np.imag, "abs": np.abs, "phase": np.angle} - if isinstance(beam_obj, UVBeam): - beam_type = beam_obj.beam_type - - feed_labels = np.degrees(beam_obj.feed_angle).astype(str) - feed_labels[np.isclose(beam_obj.feed_angle, 0)] = "N/S" - feed_labels[np.isclose(beam_obj.feed_angle, np.pi / 2)] = "E/W" - - if beam_type == "efield": - nfeedpol = beam_obj.Nfeeds - feedpol_label = feed_labels - if issubclass(type(beam_obj), UnpolarizedAnalyticBeam): - feedpol_label = beam_obj.feed_array - if logcolor is None: - logcolor = False + if freq_label is not None: + si_prefix = {"T": 1e12, "G": 1e9, "M": 1e6, "k": 1e3} + freq_str = f"{freq_label:.0f} Hz" + for prefix, multiplier in si_prefix.items(): + if freq_label > multiplier: + freq_str = f"{freq_label / multiplier:.0f} {prefix}Hz" + break else: - nfeedpol = beam_obj.Npols - pol_strs = polnum2str(beam_obj.polarization_array) - if np.max(beam_obj.polarization_array) <= -5 and not issubclass( - type(beam_obj), UnpolarizedAnalyticBeam - ): - # linear pols, use feed angles. - feedpol_label = [""] * nfeedpol - for col_i, polstr in enumerate(pol_strs): - feed0_ind = np.nonzero(beam_obj.feed_array == polstr[0])[0][0] - feed1_ind = np.nonzero(beam_obj.feed_array == polstr[1])[0][0] - if feed0_ind == feed1_ind: - feedpol_label[col_i] = feed_labels[feed0_ind] - else: - feedpol_label[col_i] = "-".join( - [feed_labels[feed0_ind], feed_labels[feed1_ind]] - ) - else: - feedpol_label = pol_strs - if logcolor is None: - logcolor = True - - if isinstance(beam_obj, UVBeam): - naxes_vec = beam_obj.Naxes_vec - name = beam_obj.telescope_name - freq_title = beam_obj.freq_array[freq] - - reg_grid = True - if beam_obj.pixel_coordinate_system == "healpix": - HEALPix = _get_hpix_obj() - - hpx_obj = HEALPix(nside=beam_obj.nside, order=beam_obj.ordering) - hpx_lon, hpx_lat = hpx_obj.healpix_to_lonlat(beam_obj.pixel_array) - za_array, az_array = hpx_latlon_to_zenithangle_azimuth( - hpx_lat.rad, hpx_lon.rad - ) - pts_use = np.nonzero(za_array <= np.radians(max_zenith_deg))[0] - za_array = za_array[pts_use] - az_array = az_array[pts_use] - reg_grid = False - else: - za_use = np.nonzero(beam_obj.axis2_array <= np.radians(max_zenith_deg))[0] - az_array, za_array = np.meshgrid( - beam_obj.axis1_array, beam_obj.axis2_array[za_use] - ) - - beam_vals = copy.deepcopy(beam_obj.data_array)[:, :, freq] - if reg_grid: - beam_vals = beam_vals[:, :, za_use, :] - else: - beam_vals = beam_vals[:, :, pts_use] - elif issubclass(type(beam_obj), AnalyticBeam): - name = beam_obj.__class__.__name__ - freq_title = freq - - naxes_vec = beam_obj.Naxes_vec - if beam_type == "power": - naxes_vec = 1 - reg_grid = True - - az_grid = np.deg2rad(np.arange(0, 360)) - za_grid = np.deg2rad(np.arange(0, 91)) * (max_zenith_deg / 90.0) - az_array, za_array = np.meshgrid(az_grid, za_grid) - bi = BeamInterface(beam_obj, beam_type=beam_type) - beam_vals = bi.compute_response( - az_array=az_array.flatten(), - za_array=za_array.flatten(), - freq_array=np.asarray([freq]), - ) - if issubclass(type(beam_obj), UnpolarizedAnalyticBeam): - beam_vals = beam_vals[0, :] - naxes_vec = 1 - if nfeedpol == 1: - feedpol_label = [""] - beam_vals = beam_vals.reshape(naxes_vec, nfeedpol, za_grid.size, az_grid.size) - si_prefix = {"T": 1e12, "G": 1e9, "M": 1e6, "k": 1e3} - freq_str = f"{freq_title:.0f} Hz" - for prefix, multiplier in si_prefix.items(): - if freq_title > multiplier: - freq_str = f"{freq_title / multiplier:.0f} {prefix}Hz" - break + freq_str = "" az_za_radial_val = np.sin(za_array) # get 4 radial ticks with values spaced sinusoidally (so ~linear in the plot), # rounded to the nearest 5 degrees radial_ticks_deg = ( np.round( - np.degrees(np.arcsin(np.linspace(0, np.sin(np.radians(max_zenith_deg)), 5))) - / 5 + np.degrees(np.arcsin(np.linspace(0, np.sin(np.max(za_array)), 5))) / 5 ).astype(int) * 5 )[1:] @@ -187,10 +138,49 @@ def beam_plot( beam_vals = complex_func[complex_type](beam_vals) type_label = ", " + complex_type else: + complex_type = "real" type_label = "" + if beam_vals.ndim == 4: + reg_grid = True + elif beam_vals.ndim == 3: + reg_grid = False + else: + raise ValueError("beam_vals must be 3 or 4 dimensional.") + beam_shape = beam_vals.shape + naxes_vec = beam_shape[0] + nfeedpol = beam_shape[1] + + if feedpol_label is None: + feedpol_label = np.arange(nfeedpol).astype(str).tolist() + else: + if len(feedpol_label) != nfeedpol: + raise ValueError( + "feedpol_label must have the same number of elements as the " + "1st (not 0th) dimension of beam_vals." + ) + + if reg_grid: + exp_coord_shape = (beam_shape[2], beam_shape[3]) + if az_array.shape != exp_coord_shape or za_array.shape != exp_coord_shape: + raise ValueError( + "az_array and za_array must be shaped like the last 2 dimensions " + "of beam_vals for regularly gridded beam_vals (beam_vals has 4 " + "dimensions)." + ) + else: + exp_coord_shape = (beam_shape[2],) + if az_array.shape != exp_coord_shape or za_array.shape != exp_coord_shape: + raise ValueError( + "az_array and za_array must be shaped like the last dimension " + "of beam_vals for irregular beam_vals (beam_vals has 3 dimensions)." + ) + norm_use = None - colormap = "viridis" + if complex_type == "phase": + default_colormap = "twilight" + else: + default_colormap = "viridis" if norm_kwargs is None: norm_kwargs = {} if plt_kwargs is None: @@ -208,7 +198,7 @@ def beam_plot( if key not in norm_kwargs: norm_kwargs[key] = value norm_use = SymLogNorm(**norm_kwargs) - colormap = "PRGn" + default_colormap = "PRGn" else: norm_use = LogNorm(**norm_kwargs) else: @@ -216,23 +206,27 @@ def beam_plot( for key in ["vmax", "vmin"]: if key in norm_kwargs: plt_kwargs[key] = norm_kwargs[key] - if np.min(beam_vals) < 0: - colormap = "PRGn" + if complex_type == "phase": + default_norm_kwargs = {"vmax": np.pi, "vmin": -np.pi} + elif np.min(beam_vals) < 0: + default_colormap = "PRGn" default_norm_kwargs = { "vmax": np.max(np.abs(beam_vals)), "vmin": -1 * np.max(np.abs(beam_vals)), } - for key, value in default_norm_kwargs.items(): - if key not in plt_kwargs: - plt_kwargs[key] = value + else: + default_norm_kwargs = {"vmax": np.max(np.abs(beam_vals)), "vmin": 0} + for key, value in default_norm_kwargs.items(): + if key not in plt_kwargs: + plt_kwargs[key] = value + + if colormap is None: + colormap = default_colormap if naxes_vec == 2: vec_label = ["azimuth", "zenith angle"] else: - if beam_type == "power": - vec_label = ["power"] - else: - vec_label = ["E-field"] + vec_label = [beam_type_label] nrow = naxes_vec ncol = nfeedpol @@ -276,7 +270,7 @@ def beam_plot( ax_use.set_rmax(np.max(az_za_radial_val)) _ = ax_use.set_title( - f"{feedpol_label[fp_i]} {name} {vec_label[vec_i]} " + f"{feedpol_label[fp_i]} {beam_name} {vec_label[vec_i]} " f"response ({freq_str}){type_label}", fontsize="medium", ) @@ -293,3 +287,165 @@ def beam_plot( else: plt.savefig(savefile, bbox_inches="tight") plt.close() + + +def beam_plot( + *, + beam_obj: UVBeam | AnalyticBeam, + freq: int | float, + beam_type: str | None = None, + complex_type: str = "real", + colormap: str | None = None, + logcolor: bool | None = None, + plt_kwargs: dict | None = None, + norm_kwargs: dict | None = None, + max_zenith_deg: float = 90.0, + savefile: str | None = None, +): + """ + Make a pretty plot of a beam. + + Parameters + ---------- + beam_obj : UVBeam or AnalyticBeam + The beam to plot. + freq : int or float + Either the index into the freq_array for UVBeam objects (int) or the + frequency to evaluate the beam at in Hz (float) for AnalyticBeam objects. + beam_type : str + Required for analytic beams to specify the beam type to plot. Ignored for + UVBeams. + complex_type : str + What to plot for complex beams, options are: [real, imag, abs, phase]. + Defaults to "real" for complex beams. Ignored for real beams + (i.e. power beams, same feed). + colormap : str, optional + Matplotlib colormap to use. Defaults to "twlight" if complex_type="phase" + and logcolor=False, otherwise it defaults to "viridis" if the data to be + plotted is positive definite (e.g. if complex_type="abs") and "PRGn" otherwise. + logcolor : bool, optional + Option to use log scaling for the color. Defaults to True for power + beams and False for E-field beams. Results in using + matplotlib.colors.LogNorm or matplotlib.colors.SymLogNorm if the data + have negative values. + plt_kwargs : dict, optional + Keywords to be passed into the matplotlib.pyplot.imshow call. + norm_kwargs : dict, optional + Keywords to be passed into the norm object, typically vmin/vmax, plus + linthresh for SymLogNorm. + max_zenith_deg : float + Maximum zenith angle to include in the plot in degrees. Default is + 90 to go down to the horizon. + savefile : str + File to save the plot to. + + """ + if isinstance(beam_obj, UVBeam): + beam_type = beam_obj.beam_type + + feed_labels = np.degrees(beam_obj.feed_angle).astype(str) + feed_labels[np.isclose(beam_obj.feed_angle, 0)] = "N/S" + feed_labels[np.isclose(beam_obj.feed_angle, np.pi / 2)] = "E/W" + + if beam_type == "efield": + nfeedpol = beam_obj.Nfeeds + feedpol_label = feed_labels + if issubclass(type(beam_obj), UnpolarizedAnalyticBeam): + feedpol_label = beam_obj.feed_array + if logcolor is None: + logcolor = False + else: + nfeedpol = beam_obj.Npols + pol_strs = polnum2str(beam_obj.polarization_array) + if np.max(beam_obj.polarization_array) <= -5 and not issubclass( + type(beam_obj), UnpolarizedAnalyticBeam + ): + # linear pols, use feed angles. + feedpol_label = [""] * nfeedpol + for col_i, polstr in enumerate(pol_strs): + feed0_ind = np.nonzero(beam_obj.feed_array == polstr[0])[0][0] + feed1_ind = np.nonzero(beam_obj.feed_array == polstr[1])[0][0] + if feed0_ind == feed1_ind: + feedpol_label[col_i] = feed_labels[feed0_ind] + else: + feedpol_label[col_i] = "-".join( + [feed_labels[feed0_ind], feed_labels[feed1_ind]] + ) + else: + feedpol_label = pol_strs + if logcolor is None: + logcolor = True + + if isinstance(beam_obj, UVBeam): + naxes_vec = beam_obj.Naxes_vec + name = beam_obj.telescope_name + freq_title = beam_obj.freq_array[freq] + + reg_grid = True + if beam_obj.pixel_coordinate_system == "healpix": + HEALPix = _get_hpix_obj() + + hpx_obj = HEALPix(nside=beam_obj.nside, order=beam_obj.ordering) + hpx_lon, hpx_lat = hpx_obj.healpix_to_lonlat(beam_obj.pixel_array) + za_array, az_array = hpx_latlon_to_zenithangle_azimuth( + hpx_lat.rad, hpx_lon.rad + ) + pts_use = np.nonzero(za_array <= np.radians(max_zenith_deg))[0] + za_array = za_array[pts_use] + az_array = az_array[pts_use] + reg_grid = False + else: + za_use = np.nonzero(beam_obj.axis2_array <= np.radians(max_zenith_deg))[0] + az_array, za_array = np.meshgrid( + beam_obj.axis1_array, beam_obj.axis2_array[za_use] + ) + + beam_vals = copy.deepcopy(beam_obj.data_array)[:, :, freq] + if reg_grid: + beam_vals = beam_vals[:, :, za_use, :] + else: + beam_vals = beam_vals[:, :, pts_use] + elif issubclass(type(beam_obj), AnalyticBeam): + name = beam_obj.__class__.__name__ + freq_title = freq + + naxes_vec = beam_obj.Naxes_vec + if beam_type == "power": + naxes_vec = 1 + reg_grid = True + + az_grid, za_grid = get_az_za_grid(max_zenith_deg=max_zenith_deg) + az_array, za_array = np.meshgrid(az_grid, za_grid) + bi = BeamInterface(beam_obj, beam_type=beam_type) + beam_vals = bi.compute_response( + az_array=az_array.flatten(), + za_array=za_array.flatten(), + freq_array=np.asarray([freq]), + ) + if issubclass(type(beam_obj), UnpolarizedAnalyticBeam): + beam_vals = beam_vals[0, :] + naxes_vec = 1 + if nfeedpol == 1: + feedpol_label = [""] + beam_vals = beam_vals.reshape(naxes_vec, nfeedpol, za_grid.size, az_grid.size) + + if beam_type == "power": + beam_type_label = "power" + elif beam_type == "efield": + beam_type_label = "E-field" + + plot_beam_arrays( + beam_vals, + az_array, + za_array, + complex_type=complex_type, + beam_name=name, + beam_type_label=beam_type_label, + freq_label=freq_title, + feedpol_label=feedpol_label, + colormap=colormap, + logcolor=logcolor, + plt_kwargs=plt_kwargs, + norm_kwargs=norm_kwargs, + savefile=savefile, + ) diff --git a/src/pyuvdata/uvbeam/uvbeam.py b/src/pyuvdata/uvbeam/uvbeam.py index 01dff6799..354ff7e4b 100644 --- a/src/pyuvdata/uvbeam/uvbeam.py +++ b/src/pyuvdata/uvbeam/uvbeam.py @@ -4869,6 +4869,7 @@ def plot( *, freq_ind: int = 0, complex_type: str = "real", + colormap: str | None = None, logcolor: bool | None = None, plt_kwargs: dict | None = None, norm_kwargs: dict | None = None, @@ -4886,6 +4887,11 @@ def plot( What to plot for complex beams, options are: [real, imag, abs, phase]. Defaults to "real" for complex beams. Ignored for real beams (i.e. power beams, same feed). + colormap : str, optional + Matplotlib colormap to use. Defaults to "twlight" if complex_type="phase" + and logcolor=False, otherwise it defaults to "viridis" if the data to be + plotted is positive definite (e.g. if complex_type="abs") and "PRGn" + otherwise. logcolor : bool, optional Option to use log scaling for the color. Defaults to True for power beams and False for E-field beams. Results in using @@ -4919,6 +4925,7 @@ def plot( plt_kwargs=plt_kwargs, norm_kwargs=norm_kwargs, max_zenith_deg=max_zenith_deg, + colormap=colormap, savefile=savefile, ) diff --git a/tests/test_analytic_beam.py b/tests/test_analytic_beam.py index ae6c400d9..dc9549413 100644 --- a/tests/test_analytic_beam.py +++ b/tests/test_analytic_beam.py @@ -641,9 +641,17 @@ def test_set_x_orientation_deprecation(): @pytest.mark.parametrize( - ("beam", "beam_type", "complex_type", "logcolor", "max_zenith_deg", "norm_kwargs"), + ( + "beam", + "beam_type", + "complex_type", + "logcolor", + "max_zenith_deg", + "norm_kwargs", + "colormap", + ), [ - (AiryBeam(diameter=7), "efield", "real", False, 90.0, None), + (AiryBeam(diameter=7), "efield", "real", False, 90.0, None, "inferno"), ( GaussianBeam(diameter=7, feed_array=["x"]), "efield", @@ -651,8 +659,9 @@ def test_set_x_orientation_deprecation(): False, 90.0, None, + None, ), - (AiryBeam(diameter=7), "power", "real", True, 90.0, {"vmin": 1e-9}), + (AiryBeam(diameter=7), "power", "real", True, 90.0, {"vmin": 1e-9}, None), ( GaussianBeam(diameter=7, include_cross_pols=False), "power", @@ -660,6 +669,7 @@ def test_set_x_orientation_deprecation(): True, 90.0, {}, + None, ), ( AiryBeam(diameter=7, feed_array=["x"]), @@ -668,14 +678,22 @@ def test_set_x_orientation_deprecation(): False, 90.0, {"vmin": 0, "vmax": 1}, + None, ), - (ShortDipoleBeam(), "efield", "real", None, 90.0, None), - (ShortDipoleBeam(), "power", "real", None, 90.0, None), - (UniformBeam(), "power", "real", None, 90.0, None), + (ShortDipoleBeam(), "efield", "real", None, 90.0, None, None), + (ShortDipoleBeam(), "power", "real", None, 90.0, None, None), + (UniformBeam(), "power", "real", None, 90.0, None, None), ], ) def test_plotting( - tmp_path, beam, beam_type, complex_type, logcolor, max_zenith_deg, norm_kwargs + tmp_path, + beam, + beam_type, + complex_type, + logcolor, + max_zenith_deg, + norm_kwargs, + colormap, ): """Test plotting method.""" pytest.importorskip("matplotlib") @@ -688,6 +706,7 @@ def test_plotting( beam_type=beam_type, freq=100e6, complex_type=complex_type, + colormap=colormap, logcolor=logcolor, max_zenith_deg=max_zenith_deg, norm_kwargs=norm_kwargs, diff --git a/tests/utils/test_plotting.py b/tests/utils/test_plotting.py new file mode 100644 index 000000000..aaba99bfe --- /dev/null +++ b/tests/utils/test_plotting.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Radio Astronomy Software Group +# Licensed under the 2-clause BSD License + +import numpy as np +import pytest + +from pyuvdata import ShortDipoleBeam +from pyuvdata.utils.plotting import get_az_za_grid, plot_beam_arrays + + +def test_plot_arrays(tmp_path): + pytest.importorskip("matplotlib") + import matplotlib + + matplotlib.use("Agg") # Must be before importing matplotlib.pyplot or pylab! + + dipole_beam = ShortDipoleBeam() + + az_grid, za_grid = get_az_za_grid() + az_array, za_array = np.meshgrid(az_grid, za_grid) + + beam_vals = dipole_beam.efield_eval( + az_array=az_array.flatten(), + za_array=za_array.flatten(), + freq_array=np.asarray(np.asarray([100e6])), + ) + beam_vals = beam_vals.reshape(2, 2, za_grid.size, az_grid.size) + + savefile = str(tmp_path / "test.png") + + plot_beam_arrays( + beam_vals, + az_array, + za_array, + complex_type="real", + beam_type_label="E-field", + beam_name="short dipole", + savefile=savefile, + ) + + +def test_plot_arrays_errors(): + pytest.importorskip("matplotlib") + import matplotlib + + matplotlib.use("Agg") # Must be before importing matplotlib.pyplot or pylab! + + dipole_beam = ShortDipoleBeam() + + az_grid, za_grid = get_az_za_grid() + az_array, za_array = np.meshgrid(az_grid, za_grid) + + beam_vals = dipole_beam.efield_eval( + az_array=az_array.flatten(), + za_array=za_array.flatten(), + freq_array=np.asarray(np.asarray([100e6])), + ) + + with pytest.raises( + ValueError, + match="az_array and za_array must be shaped like the last dimension " + "of beam_vals for irregular beam_vals", + ): + plot_beam_arrays(beam_vals[:, :, 0], az_array, za_array) + + beam_vals = beam_vals.reshape(2, 2, za_grid.size, az_grid.size) + + with pytest.raises(ValueError, match="beam_vals must be 3 or 4 dimensional."): + plot_beam_arrays(beam_vals[0, 0], az_array, za_array) + + with pytest.raises( + ValueError, + match="feedpol_label must have the same number of elements as the 1st", + ): + plot_beam_arrays(beam_vals, az_array, za_array, feedpol_label=["foo"]) + + with pytest.raises( + ValueError, + match="az_array and za_array must be shaped like the last 2 dimensions " + "of beam_vals for regularly gridded beam_vals", + ): + plot_beam_arrays(beam_vals, az_array[0], za_array) diff --git a/tests/uvbeam/test_uvbeam.py b/tests/uvbeam/test_uvbeam.py index c890d3f83..55c874d2e 100644 --- a/tests/uvbeam/test_uvbeam.py +++ b/tests/uvbeam/test_uvbeam.py @@ -3458,13 +3458,23 @@ def test_fix_feeds_dep_warnings(cst_power_2freq_cut, mod_params, warn_msg): "logcolor", "max_zenith_deg", "norm_kwargs", + "colormap", ), [ - ("mwa", None, [{}], "real", None, 90.0, None), - ("hera", None, [{}], "real", True, 45.0, {"linthresh": 1e-4}), - ("mwa", None, [{}], "imag", False, 20.0, {}), - ("hera", None, [{}], "phase", False, 20.0, None), - ("mwa", ["peak_normalize"], [{}], "abs", True, 45.0, {"vmin": 1e-4, "vmax": 1}), + ("mwa", None, [{}], "real", None, 90.0, None, "inferno"), + ("hera", None, [{}], "real", True, 45.0, {"linthresh": 1e-4}, None), + ("mwa", None, [{}], "imag", False, 20.0, {}, None), + ("hera", None, [{}], "phase", False, 20.0, None, None), + ( + "mwa", + ["peak_normalize"], + [{}], + "abs", + True, + 45.0, + {"vmin": 1e-4, "vmax": 1}, + None, + ), ( "hera", ["peak_normalize"], @@ -3473,6 +3483,7 @@ def test_fix_feeds_dep_warnings(cst_power_2freq_cut, mod_params, warn_msg): False, 45.0, {"vmin": -1, "vmax": 1}, + None, ), ( "mwa", @@ -3482,9 +3493,10 @@ def test_fix_feeds_dep_warnings(cst_power_2freq_cut, mod_params, warn_msg): True, 20.0, {}, + None, ), - ("hera", ["efield_to_power"], [{}], "real", True, 20.0, {}), - ("mwa", ["to_healpix"], [{}], "phase", False, 20.0, {}), + ("hera", ["efield_to_power"], [{}], "real", True, 20.0, {}, None), + ("mwa", ["to_healpix"], [{}], "phase", False, 20.0, {}, None), ( "hera", ["efield_to_power", "to_healpix"], @@ -3493,8 +3505,9 @@ def test_fix_feeds_dep_warnings(cst_power_2freq_cut, mod_params, warn_msg): True, 20.0, {}, + None, ), - ("mwa", ["efield_to_pstokes"], [{}], "real", None, 20.0, {}), + ("mwa", ["efield_to_pstokes"], [{}], "real", None, 20.0, {}, None), ], ) def test_plotting( @@ -3508,8 +3521,9 @@ def test_plotting( logcolor, max_zenith_deg, norm_kwargs, + colormap, ): - """Test plotting method.""" + """Test plotting method. This is just a smoke test to make sure it doesn't error.""" pytest.importorskip("matplotlib") import matplotlib @@ -3529,6 +3543,7 @@ def test_plotting( savefile = str(tmp_path / "test.png") beam.plot( complex_type=complex_type, + colormap=colormap, logcolor=logcolor, max_zenith_deg=max_zenith_deg, norm_kwargs=norm_kwargs,