Skip to content

Commit 47b2934

Browse files
committed
Update default dtype parameters in check_1d_arrays decorator, and allow None type
1 parent 8505c15 commit 47b2934

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

sigima/tools/checks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def check_1d_arrays(
1616
func: Callable[..., Any] | None = None,
1717
*,
1818
x_1d: bool = True,
19-
x_dtype: type = np.floating,
19+
x_dtype: type | None = np.floating, # Default to floating point types
2020
x_sorted: bool = False,
2121
x_evenly_spaced: bool = False,
2222
y_1d: bool = True,
23-
y_dtype: type = np.floating,
23+
y_dtype: type | None = np.inexact, # Default to inexact types (float or complex)
2424
x_y_same_size: bool = True,
2525
rtol: float = 1e-5,
2626
) -> Callable:
@@ -65,7 +65,7 @@ def wrapper(x: np.ndarray, y: np.ndarray, *args: Any, **kwargs: Any) -> Any:
6565
# === Check x array
6666
if x_1d and x.ndim != 1:
6767
raise ValueError("x must be 1-D.")
68-
if not np.issubdtype(x.dtype, x_dtype):
68+
if x_dtype is not None and not np.issubdtype(x.dtype, x_dtype):
6969
raise TypeError(f"x must be of type {x_dtype}, but got {x.dtype}.")
7070
if x_sorted and x.size > 1 and not np.all(np.diff(x) >= 0.0):
7171
raise ValueError("x must be sorted in ascending order.")
@@ -76,7 +76,7 @@ def wrapper(x: np.ndarray, y: np.ndarray, *args: Any, **kwargs: Any) -> Any:
7676
# === Check y array
7777
if y_1d and y.ndim != 1:
7878
raise ValueError("y must be 1-D.")
79-
if not np.issubdtype(y.dtype, y_dtype):
79+
if y_dtype is not None and not np.issubdtype(y.dtype, y_dtype):
8080
raise TypeError(f"y must be of type {y_dtype}, but got {y.dtype}.")
8181
if x_y_same_size and x.size != y.size:
8282
raise ValueError("x and y must have the same size.")

sigima/tools/signal/fourier.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def zero_padding(
4646
return xnew, ynew
4747

4848

49-
@check_1d_arrays(x_dtype=np.inexact, x_evenly_spaced=True, y_dtype=np.inexact)
49+
@check_1d_arrays(x_evenly_spaced=True)
5050
def fft1d(
5151
x: np.ndarray, y: np.ndarray, shift: bool = True
5252
) -> tuple[np.ndarray, np.ndarray]:
@@ -70,12 +70,7 @@ def fft1d(
7070
return f, sp
7171

7272

73-
@check_1d_arrays(
74-
x_dtype=np.inexact,
75-
x_sorted=False,
76-
x_evenly_spaced=False,
77-
y_dtype=np.complexfloating,
78-
)
73+
@check_1d_arrays(x_evenly_spaced=False, x_sorted=False, y_dtype=np.complexfloating)
7974
def ifft1d(
8075
f: np.ndarray, sp: np.ndarray, initial: float = 0.0
8176
) -> tuple[np.ndarray, np.ndarray]:
@@ -116,7 +111,7 @@ def ifft1d(
116111
return x, y.real
117112

118113

119-
@check_1d_arrays(x_dtype=np.inexact, x_evenly_spaced=True, y_dtype=np.inexact)
114+
@check_1d_arrays(x_evenly_spaced=True)
120115
def magnitude_spectrum(
121116
x: np.ndarray, y: np.ndarray, log_scale: bool = False
122117
) -> tuple[np.ndarray, np.ndarray]:
@@ -138,7 +133,7 @@ def magnitude_spectrum(
138133
return x1, y_mag
139134

140135

141-
@check_1d_arrays(x_dtype=np.inexact, x_evenly_spaced=True, y_dtype=np.inexact)
136+
@check_1d_arrays(x_evenly_spaced=True)
142137
def phase_spectrum(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
143138
"""Compute phase spectrum.
144139
@@ -154,7 +149,7 @@ def phase_spectrum(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray
154149
return x1, y_phase
155150

156151

157-
@check_1d_arrays(x_dtype=np.inexact, x_evenly_spaced=True, y_dtype=np.inexact)
152+
@check_1d_arrays(x_evenly_spaced=True)
158153
def psd(
159154
x: np.ndarray, y: np.ndarray, log_scale: bool = False
160155
) -> tuple[np.ndarray, np.ndarray]:
@@ -189,6 +184,7 @@ def sort_frequencies(x: np.ndarray, y: np.ndarray) -> np.ndarray:
189184
return freqs[np.argsort(fourier)]
190185

191186

187+
@check_1d_arrays(x_evenly_spaced=True)
192188
def brickwall_filter(
193189
x: np.ndarray,
194190
y: np.ndarray,

0 commit comments

Comments
 (0)