Skip to content

Commit d5738a9

Browse files
fix: default to brute-force computation when mud outside of range
1 parent 60f9f3f commit d5738a9

File tree

3 files changed

+59
-82
lines changed

3 files changed

+59
-82
lines changed

news/fastcalto7.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
**Added:**
22

3-
* Fast calculation support up to muD = 7
3+
* Fast calculation supports values up to muD = 7
44

55
**Changed:**
66

7-
* Clarified error message for fast calculation to explicitly states the invalid muD value
7+
* Default to brute-force computation when muD < 0.5 or > 7.
8+
* Print a warning message instead of error, explicitly stating the input muD value
89

910
**Deprecated:**
1011

src/diffpy/labpdfproc/functions.py

+14-28
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import warnings
23
from pathlib import Path
34

45
import numpy as np
@@ -16,7 +17,7 @@
1617
CVE_METHODS = ["brute_force", "polynomial_interpolation"]
1718

1819
# Pre-computed datasets for polynomial interpolation (fast calculation)
19-
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6, 7]
20+
MUD_LIST = np.array([0.5, 1, 2, 3, 4, 5, 6, 7])
2021
CWD = Path(__file__).parent.resolve()
2122
MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy")
2223
COEFFICIENT_LIST = np.array(
@@ -74,7 +75,6 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
7475
----------
7576
coordinate : tuple of floats
7677
The coordinates of the grid point.
77-
7878
angle : float
7979
The angle in degrees.
8080
@@ -90,9 +90,7 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
9090
angle = math.radians(angle)
9191
xgrid = coordinate[0]
9292
ygrid = coordinate[1]
93-
9493
entry_point = (-math.sqrt(self.radius**2 - ygrid**2), ygrid)
95-
9694
if not math.isclose(angle, math.pi / 2, abs_tol=epsilon):
9795
b = ygrid - xgrid * math.tan(angle)
9896
a = math.tan(angle)
@@ -107,7 +105,6 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
107105
exit_point = (xexit_root1, yexit_root1)
108106
else:
109107
exit_point = (xgrid, math.sqrt(self.radius**2 - xgrid**2))
110-
111108
return entry_point, exit_point
112109

113110
def _get_path_length(self, grid_point, angle):
@@ -119,7 +116,6 @@ def _get_path_length(self, grid_point, angle):
119116
----------
120117
grid_point : double of floats
121118
The coordinate inside the circle.
122-
123119
angle : float
124120
The angle of the output beam in degrees.
125121
@@ -129,7 +125,6 @@ def _get_path_length(self, grid_point, angle):
129125
The tuple containing three floats,
130126
which are the total distance, entry distance and exit distance.
131127
"""
132-
133128
# move angle a tad above zero if it is zero
134129
# to avoid it having the wrong sign due to some rounding error
135130
angle_delta = 0.000001
@@ -181,7 +176,9 @@ def _cve_brute_force(input_pattern, mud):
181176
Assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1.
182177
"""
183178
mu_sample_invmm = mud / 2
184-
abs_correction = Gridded_circle(mu=mu_sample_invmm)
179+
abs_correction = Gridded_circle(
180+
n_points_on_diameter=N_POINTS_ON_DIAMETER, mu=mu_sample_invmm
181+
)
185182
distances, muls = [], []
186183
for angle in TTH_GRID:
187184
abs_correction.set_distances_at_angle(angle)
@@ -191,7 +188,6 @@ def _cve_brute_force(input_pattern, mud):
191188
distances = np.array(distances) / abs_correction.total_points_in_grid
192189
muls = np.array(muls) / abs_correction.total_points_in_grid
193190
cve = 1 / muls
194-
195191
cve_do = DiffractionObject(
196192
xarray=TTH_GRID,
197193
yarray=cve,
@@ -206,31 +202,21 @@ def _cve_brute_force(input_pattern, mud):
206202

207203
def _cve_polynomial_interpolation(input_pattern, mud):
208204
"""Compute cve using polynomial interpolation method,
209-
raise an error if the mu*D value is out of the range (0.5 to 7).
205+
default to brute-force computation if mu*D is
206+
out of the range (0.5 to 7).
210207
"""
211208
if mud > 7 or mud < 0.5:
212-
raise ValueError(
209+
warnings.warn(
213210
f"Input mu*D = {mud} is out of the acceptable range "
214-
f"({min(MUD_LIST)} to {max(MUD_LIST)}) "
211+
f"({np.min(MUD_LIST)} to {np.max(MUD_LIST)}) "
215212
f"for polynomial interpolation. "
216-
f"Please rerun with a value within this range "
217-
f"or specifying another method from {*CVE_METHODS, }."
213+
f"Proceeding with brute-force computation. "
218214
)
219-
coef1, coef2, coef3, coef4, coef5, coef6, coef7 = [
220-
interpolation_function(mud)
221-
for interpolation_function in INTERPOLATION_FUNCTIONS
222-
]
223-
muls = np.array(
224-
coef1 * MULS**6
225-
+ coef2 * MULS**5
226-
+ coef3 * MULS**4
227-
+ coef4 * MULS**3
228-
+ coef5 * MULS**2
229-
+ coef6 * MULS
230-
+ coef7
231-
)
232-
cve = 1 / muls
215+
return _cve_brute_force(input_pattern, mud)
233216

217+
coeffs = np.array([f(mud) for f in INTERPOLATION_FUNCTIONS])
218+
muls = np.polyval(coeffs, MULS)
219+
cve = 1 / muls
234220
cve_do = DiffractionObject(
235221
xarray=TTH_GRID,
236222
yarray=cve,

tests/test_functions.py

+42-52
Original file line numberDiff line numberDiff line change
@@ -105,49 +105,55 @@ def test_set_muls_at_angle(input_mu, expected_muls):
105105

106106

107107
@pytest.mark.parametrize(
108-
"input_xtype, expected",
109-
[
110-
(
111-
"tth",
108+
"input_diffraction_data, input_cve_params",
109+
[ # Test that cve diffraction object contains the expected info
110+
# Note that all cve values are interpolated to 0.5
111+
# cve do should contain the same input xarray, xtype,
112+
# wavelength, and metadata
113+
( # C1: User did not specify method, default to fast calculation
112114
{
113115
"xarray": np.array([90, 90.1, 90.2]),
114-
"yarray": np.array([0.5, 0.5, 0.5]),
115-
"xtype": "tth",
116+
"yarray": np.array([2, 2, 2]),
116117
},
118+
{"mud": 1, "xtype": "tth"},
117119
),
118-
(
119-
"q",
120+
( # C2: User specified brute-force computation method
120121
{
121-
"xarray": np.array([5.76998, 5.77501, 5.78004]),
122-
"yarray": np.array([0.5, 0.5, 0.5]),
123-
"xtype": "q",
122+
"xarray": np.array([5.1, 5.2, 5.3]),
123+
"yarray": np.array([2, 2, 2]),
124124
},
125+
{"mud": 1, "method": "brute_force", "xtype": "q"},
126+
),
127+
( # C3: User specified mu*D outside the fast calculation range,
128+
# default to brute-force computation
129+
{
130+
"xarray": np.array([5.1, 5.2, 5.3]),
131+
"yarray": np.array([2, 2, 2]),
132+
},
133+
{"mud": 20, "xtype": "q"},
125134
),
126135
],
127136
)
128-
def test_compute_cve(input_xtype, expected, mocker):
129-
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
137+
def test_compute_cve(mocker, input_diffraction_data, input_cve_params):
138+
expected_xarray = input_diffraction_data["xarray"]
130139
expected_cve = np.array([0.5, 0.5, 0.5])
140+
expected_xtype = input_cve_params["xtype"]
141+
mocker.patch("diffpy.labpdfproc.functions.N_POINTS_ON_DIAMETER", 4)
131142
mocker.patch("numpy.interp", return_value=expected_cve)
132143
input_pattern = DiffractionObject(
133-
xarray=xarray,
134-
yarray=yarray,
135-
xtype="tth",
144+
xarray=input_diffraction_data["xarray"],
145+
yarray=input_diffraction_data["yarray"],
146+
xtype=input_cve_params["xtype"],
136147
wavelength=1.54,
137148
scat_quantity="x-ray",
138149
name="test",
139150
metadata={"thing1": 1, "thing2": "thing2"},
140151
)
141-
actual_cve_do = compute_cve(
142-
input_pattern,
143-
mud=1,
144-
method="polynomial_interpolation",
145-
xtype=input_xtype,
146-
)
152+
actual_cve_do = compute_cve(input_pattern, **input_cve_params)
147153
expected_cve_do = DiffractionObject(
148-
xarray=expected["xarray"],
149-
yarray=expected["yarray"],
150-
xtype=expected["xtype"],
154+
xarray=expected_xarray,
155+
yarray=expected_cve,
156+
xtype=expected_xtype,
151157
wavelength=1.54,
152158
scat_quantity="cve",
153159
name="absorption correction, cve, for test",
@@ -156,32 +162,9 @@ def test_compute_cve(input_xtype, expected, mocker):
156162
assert actual_cve_do == expected_cve_do
157163

158164

159-
@pytest.mark.parametrize(
160-
"inputs, msg",
161-
[
162-
(
163-
{"mud": 10, "method": "polynomial_interpolation"},
164-
f"mu*D = 10 is out of the acceptable range (0.5 to 7) "
165-
f"for polynomial interpolation. "
166-
f"Please rerun with a value within this range "
167-
f"or specifying another method from {*CVE_METHODS, }.",
168-
),
169-
(
170-
{"mud": 1, "method": "invalid_method"},
171-
f"Unknown method: invalid_method. "
172-
f"Allowed methods are {*CVE_METHODS, }.",
173-
),
174-
(
175-
{"mud": 7, "method": "invalid_method"},
176-
f"Unknown method: invalid_method. "
177-
f"Allowed methods are {*CVE_METHODS, }.",
178-
),
179-
],
180-
)
181-
def test_compute_cve_bad(mocker, inputs, msg):
165+
def test_compute_cve_bad(mocker):
182166
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
183167
expected_cve = np.array([0.5, 0.5, 0.5])
184-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
185168
mocker.patch("numpy.interp", return_value=expected_cve)
186169
input_pattern = DiffractionObject(
187170
xarray=xarray,
@@ -192,14 +175,21 @@ def test_compute_cve_bad(mocker, inputs, msg):
192175
name="test",
193176
metadata={"thing1": 1, "thing2": "thing2"},
194177
)
195-
with pytest.raises(ValueError, match=re.escape(msg)):
196-
compute_cve(input_pattern, mud=inputs["mud"], method=inputs["method"])
178+
# Test that the function raises a ValueError
179+
# when an invalid method is provided
180+
with pytest.raises(
181+
ValueError,
182+
match=re.escape(
183+
f"Unknown method: invalid_method. "
184+
f"Allowed methods are {*CVE_METHODS, }."
185+
),
186+
):
187+
compute_cve(input_pattern, mud=1, method="invalid_method")
197188

198189

199190
def test_apply_corr(mocker):
200191
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
201192
expected_cve = np.array([0.5, 0.5, 0.5])
202-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
203193
mocker.patch("numpy.interp", return_value=expected_cve)
204194
input_pattern = DiffractionObject(
205195
xarray=xarray,

0 commit comments

Comments
 (0)