diff --git a/news/fix-get-array-index.rst b/news/fix-get-array-index.rst new file mode 100644 index 0000000..67814f5 --- /dev/null +++ b/news/fix-get-array-index.rst @@ -0,0 +1,23 @@ +**Added:** + +* + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* Let ``DiffractionObject.get_array_index`` to use the ``xtype`` from its inputs. + +**Security:** + +* diff --git a/src/diffpy/utils/diffraction_objects.py b/src/diffpy/utils/diffraction_objects.py index 25f32f4..ec09b41 100644 --- a/src/diffpy/utils/diffraction_objects.py +++ b/src/diffpy/utils/diffraction_objects.py @@ -409,17 +409,17 @@ def uuid(self): def uuid(self, _): raise AttributeError(_setter_wmsg("uuid")) - def get_array_index(self, xtype, xvalue): - """Return the index of the closest value in the array associated with + def get_array_index(self, xvalue, xtype=None): + f"""Return the index of the closest value in the array associated with the specified xtype and the value provided. Parameters ---------- - xtype : str - The type of the independent variable in `xarray`. Must be one - of {*XQUANTITIES}. xvalue : float The value of the xtype to find the closest index for. + xtype : str, optional + The type of the independent variable in `xarray`. Must be one + of {*XQUANTITIES, }. Default is {self._input_xtype} Returns ------- @@ -427,8 +427,11 @@ def get_array_index(self, xtype, xvalue): The index of the closest value in the array associated with the specified xtype and the value provided. """ - - xtype = self._input_xtype + if xtype is None: + xtype = self._input_xtype + else: + if xtype not in XQUANTITIES: + raise ValueError(_xtype_wmsg(xtype)) xarray = self.on_xtype(xtype)[0] if len(xarray) == 0: raise ValueError( diff --git a/tests/test_diffraction_objects.py b/tests/test_diffraction_objects.py index 8ed1dd3..59ee593 100644 --- a/tests/test_diffraction_objects.py +++ b/tests/test_diffraction_objects.py @@ -376,8 +376,8 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs): }, 0, ), - ( # C2: Target value lies in the array, expect the (first) closest - # index + # C2: Target value lies in the array, expect the closest index + ( # 1. xtype(tth) is equal to self._input_xtype(tth) { "wavelength": 4 * np.pi, "xarray": np.array([30, 60]), @@ -390,7 +390,7 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs): }, 0, ), - ( + ( # 2. use default xtype(equal to self._input_xtype) { "wavelength": 4 * np.pi, "xarray": np.array([30, 60]), @@ -398,11 +398,24 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs): "xtype": "tth", }, { - "xtype": "q", - "value": 0.25, + "xtype": None, + "value": 45, }, 0, ), + ( # 3. xtype(q) is different from self._input_xtype(tth) + { + "wavelength": 4 * np.pi, + "xarray": np.array([30, 60]), + "yarray": np.array([1, 2]), + "xtype": "tth", + }, + { + "xtype": "q", + "value": 0.5, + }, + 1, + ), # C3: Target value out of the range, expect the closest index ( # 1. Test with xtype of "q" { @@ -435,12 +448,13 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs): def test_get_array_index(do_args, get_array_index_inputs, expected_index): do = DiffractionObject(**do_args) actual_index = do.get_array_index( - get_array_index_inputs["xtype"], get_array_index_inputs["value"] + get_array_index_inputs["value"], get_array_index_inputs["xtype"] ) assert actual_index == expected_index def test_get_array_index_bad(): + # empty array in DiffractionObject do = DiffractionObject( wavelength=2 * np.pi, xarray=np.array([]), @@ -454,6 +468,22 @@ def test_get_array_index_bad(): ), ): do.get_array_index(xtype="tth", xvalue=30) + # non-existing xtype + do = DiffractionObject( + wavelength=4 * np.pi, + xarray=np.array([30, 60]), + yarray=np.array([1, 2]), + xtype="tth", + ) + non_existing_xtype = "non_existing_xtype" + with pytest.raises( + ValueError, + match=re.escape( + f"I don't know how to handle the xtype, '{non_existing_xtype}'. " + f"Please rerun specifying an xtype from {*XQUANTITIES, }" + ), + ): + do.get_array_index(xtype=non_existing_xtype, xvalue=30) def test_dump(tmp_path, mocker):