diff --git a/phys2denoise/metrics/utils.py b/phys2denoise/metrics/utils.py index 7426a5c..7fb7a23 100644 --- a/phys2denoise/metrics/utils.py +++ b/phys2denoise/metrics/utils.py @@ -1,5 +1,6 @@ """Miscellaneous utility functions for metric calculation.""" import logging +import warnings import numpy as np @@ -45,11 +46,23 @@ def mirrorpad_1d(arr, buffer=250): ------- arr_out """ + mirror = np.flip(arr, axis=0) - idx = range(arr.shape[0] - buffer, arr.shape[0]) - pre_mirror = np.take(mirror, idx, axis=0) - idx = range(0, buffer) - post_mirror = np.take(mirror, idx, axis=0) + # If buffer is too long, fix it and issue a warning + try: + idx = range(arr.shape[0] - buffer, arr.shape[0]) + pre_mirror = np.take(mirror, idx, axis=0) + idx = range(0, buffer) + post_mirror = np.take(mirror, idx, axis=0) + except IndexError: + fixed_buffer = len(arr) + warnings.warn('Requested buffer size ({}) is longer than arr length ({}). Fixing buffer size to {}.'.format( + buffer, len(arr), fixed_buffer + )) + idx = range(arr.shape[0] - fixed_buffer, arr.shape[0]) + pre_mirror = np.take(mirror, idx, axis=0) + idx = range(0, fixed_buffer) + post_mirror = np.take(mirror, idx, axis=0) arr_out = np.concatenate((pre_mirror, arr, post_mirror), axis=0) return arr_out