Skip to content

Commit 44328a2

Browse files
authored
Merge pull request #356 from ycexiao/fix-get-array-index
Fix: let `DiffractionObject.get_array_index` to use `xtype` from its inputs.
2 parents 93bc105 + eaeafe5 commit 44328a2

File tree

3 files changed

+69
-13
lines changed

3 files changed

+69
-13
lines changed

news/fix-get-array-index.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* <news item>
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* Let ``DiffractionObject.get_array_index`` to use the ``xtype`` from its inputs.
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/utils/diffraction_objects.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,29 @@ def uuid(self):
409409
def uuid(self, _):
410410
raise AttributeError(_setter_wmsg("uuid"))
411411

412-
def get_array_index(self, xtype, xvalue):
413-
"""Return the index of the closest value in the array associated with
412+
def get_array_index(self, xvalue, xtype=None):
413+
f"""Return the index of the closest value in the array associated with
414414
the specified xtype and the value provided.
415415
416416
Parameters
417417
----------
418-
xtype : str
419-
The type of the independent variable in `xarray`. Must be one
420-
of {*XQUANTITIES}.
421418
xvalue : float
422419
The value of the xtype to find the closest index for.
420+
xtype : str, optional
421+
The type of the independent variable in `xarray`. Must be one
422+
of {*XQUANTITIES, }. Default is {self._input_xtype}
423423
424424
Returns
425425
-------
426426
index : int
427427
The index of the closest value in the array associated with the
428428
specified xtype and the value provided.
429429
"""
430-
431-
xtype = self._input_xtype
430+
if xtype is None:
431+
xtype = self._input_xtype
432+
else:
433+
if xtype not in XQUANTITIES:
434+
raise ValueError(_xtype_wmsg(xtype))
432435
xarray = self.on_xtype(xtype)[0]
433436
if len(xarray) == 0:
434437
raise ValueError(

tests/test_diffraction_objects.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,8 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
376376
},
377377
0,
378378
),
379-
( # C2: Target value lies in the array, expect the (first) closest
380-
# index
379+
# C2: Target value lies in the array, expect the closest index
380+
( # 1. xtype(tth) is equal to self._input_xtype(tth)
381381
{
382382
"wavelength": 4 * np.pi,
383383
"xarray": np.array([30, 60]),
@@ -390,19 +390,32 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
390390
},
391391
0,
392392
),
393-
(
393+
( # 2. use default xtype(equal to self._input_xtype)
394394
{
395395
"wavelength": 4 * np.pi,
396396
"xarray": np.array([30, 60]),
397397
"yarray": np.array([1, 2]),
398398
"xtype": "tth",
399399
},
400400
{
401-
"xtype": "q",
402-
"value": 0.25,
401+
"xtype": None,
402+
"value": 45,
403403
},
404404
0,
405405
),
406+
( # 3. xtype(q) is different from self._input_xtype(tth)
407+
{
408+
"wavelength": 4 * np.pi,
409+
"xarray": np.array([30, 60]),
410+
"yarray": np.array([1, 2]),
411+
"xtype": "tth",
412+
},
413+
{
414+
"xtype": "q",
415+
"value": 0.5,
416+
},
417+
1,
418+
),
406419
# C3: Target value out of the range, expect the closest index
407420
( # 1. Test with xtype of "q"
408421
{
@@ -435,12 +448,13 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
435448
def test_get_array_index(do_args, get_array_index_inputs, expected_index):
436449
do = DiffractionObject(**do_args)
437450
actual_index = do.get_array_index(
438-
get_array_index_inputs["xtype"], get_array_index_inputs["value"]
451+
get_array_index_inputs["value"], get_array_index_inputs["xtype"]
439452
)
440453
assert actual_index == expected_index
441454

442455

443456
def test_get_array_index_bad():
457+
# empty array in DiffractionObject
444458
do = DiffractionObject(
445459
wavelength=2 * np.pi,
446460
xarray=np.array([]),
@@ -454,6 +468,22 @@ def test_get_array_index_bad():
454468
),
455469
):
456470
do.get_array_index(xtype="tth", xvalue=30)
471+
# non-existing xtype
472+
do = DiffractionObject(
473+
wavelength=4 * np.pi,
474+
xarray=np.array([30, 60]),
475+
yarray=np.array([1, 2]),
476+
xtype="tth",
477+
)
478+
non_existing_xtype = "non_existing_xtype"
479+
with pytest.raises(
480+
ValueError,
481+
match=re.escape(
482+
f"I don't know how to handle the xtype, '{non_existing_xtype}'. "
483+
f"Please rerun specifying an xtype from {*XQUANTITIES, }"
484+
),
485+
):
486+
do.get_array_index(xtype=non_existing_xtype, xvalue=30)
457487

458488

459489
def test_dump(tmp_path, mocker):

0 commit comments

Comments
 (0)