From 695447062e4901177955a62a0bb36cad0c0772b0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 31 Jul 2025 12:39:51 +0200 Subject: [PATCH 1/2] improve benchmar detection --- .../benchmark/benchmark_peak_detection.py | 74 +++++++++++-------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py index 9f6129d838..c37917d468 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -57,7 +57,7 @@ def compute_result(self, **result_params): spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids ) - self.result["gt_comparison"] = GroundTruthComparison( + self.result["gt_comparison_by_channels"] = GroundTruthComparison( self.result["gt_on_channels"], self.result["peak_on_channels"], exhaustive_gt=self.exhaustive_gt ) @@ -82,35 +82,34 @@ def compute_result(self, **result_params): sorting["segment_index"] = peaks[detected_matches]["segment_index"] order = np.lexsort((sorting["sample_index"], sorting["segment_index"])) sorting = sorting[order] - self.result["sliced_gt_sorting"] = NumpySorting( + self.result["matched_sorting"] = NumpySorting( sorting, self.recording.sampling_frequency, self.gt_sorting.unit_ids ) - self.result["sliced_gt_comparison"] = GroundTruthComparison( - self.gt_sorting, self.result["sliced_gt_sorting"], exhaustive_gt=self.exhaustive_gt + self.result["gt_comparison"] = GroundTruthComparison( + self.gt_sorting, self.result["matched_sorting"], exhaustive_gt=self.exhaustive_gt ) ratio = 100 * len(gt_matches) / len(times2) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) sorting_analyzer = create_sorting_analyzer( - self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs + self.result["matched_sorting"], self.recording, format="memory", sparse=False, **job_kwargs ) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - self.result["templates"] = sorting_analyzer.get_extension("templates").get_data() + self.result["matched_templates"] = sorting_analyzer.get_extension("templates").get_data() _run_key_saved = [("peaks", "npy")] _result_key_saved = [ + ("gt_comparison_by_channels", "pickle"), + ("matched_sorting", "sorting"), ("gt_comparison", "pickle"), - ("sliced_gt_sorting", "sorting"), - ("sliced_gt_comparison", "pickle"), - ("sliced_gt_sorting", "sorting"), ("peak_on_channels", "sorting"), ("gt_on_channels", "sorting"), ("matches", "pickle"), - ("templates", "npy"), + ("matched_templates", "npy"), ("gt_amplitudes", "npy"), ("gt_templates", "npy"), ] @@ -127,6 +126,12 @@ def create_benchmark(self, key): init_kwargs = self.cases[key]["init_kwargs"] benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark + + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr + + return plot_performances_vs_snr(self, **kwargs) + def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)): if case_keys is None: @@ -138,7 +143,7 @@ def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)): for count, key in enumerate(case_keys): ax = axs[0, count] ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + plot_agreement_matrix(self.get_result(key)["gt_comparison_by_channels"], ax=ax) def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)): if case_keys is None: @@ -150,37 +155,48 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)): for count, key in enumerate(case_keys): ax = axs[0, count] ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None, axs=None): + def plot_detected_amplitude_distributions(self, case_keys=None, show_legend=True, detect_threshold=None, figsize=(15, 5),ax=None): if case_keys is None: case_keys = list(self.cases.keys()) import matplotlib.pyplot as plt - if axs is None: - fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False) + if ax is None: + fig, ax = plt.subplots(figsize=figsize, squeeze=False) else: - fig = axs[0].get_figure() - assert len(axs) == len(case_keys), "axs should be the same length as case_keys" + fig = ax.get_figure() + + + # plot only the first key for gt amplitude + # TODO make a loop for all of then + key0 = case_keys[0] + data2 = self.get_result(key0)["gt_amplitudes"] + bins = np.linspace(data2.min(), data2.max(), 100) + ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k") for count, key in enumerate(case_keys): - ax = axs[count] despine(ax) data1 = self.get_result(key)["peaks"]["amplitude"] - data2 = self.get_result(key)["gt_amplitudes"] + color = self.get_colors()[key] - bins = np.linspace(data2.min(), data2.max(), 100) - ax.hist(data1, bins=bins, label="detected", histtype="step", color=color, linewidth=2) - ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k") - ax.set_yscale("log") + + label = self.cases[key]["label"] + ax.hist(data1, bins=bins, label=label, histtype="step", color=color, linewidth=2) + # ax.set_title(self.cases[key]["label"]) + + ax.set_yscale("log") + + if detect_threshold is not None: + noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean() + ymin, ymax = ax.get_ylim() + abs_threshold = -detect_threshold * noise_levels + ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") + + if show_legend: ax.legend() - if detect_threshold is not None: - noise_levels = get_noise_levels(self.benchmarks[key].recording, return_in_uV=False).mean() - ymin, ymax = ax.get_ylim() - abs_threshold = -detect_threshold * noise_levels - ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") return fig @@ -266,7 +282,7 @@ def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5 import sklearn.metrics gt_templates = self.get_result(key)["gt_templates"] - found_templates = self.get_result(key)["templates"] + found_templates = self.get_result(key)["matched_templates"] num_templates = len(gt_templates) distances = np.zeros(num_templates) From adb8b4160916fbcabdeb3112dd652f57a8a8c7f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:02:24 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_peak_detection.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py index c37917d468..dcff0d8907 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -126,13 +126,12 @@ def create_benchmark(self, key): init_kwargs = self.cases[key]["init_kwargs"] benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark - + def plot_performances_vs_snr(self, **kwargs): from .benchmark_plot_tools import plot_performances_vs_snr return plot_performances_vs_snr(self, **kwargs) - def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -157,7 +156,9 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)): ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - def plot_detected_amplitude_distributions(self, case_keys=None, show_legend=True, detect_threshold=None, figsize=(15, 5),ax=None): + def plot_detected_amplitude_distributions( + self, case_keys=None, show_legend=True, detect_threshold=None, figsize=(15, 5), ax=None + ): if case_keys is None: case_keys = list(self.cases.keys()) @@ -168,7 +169,6 @@ def plot_detected_amplitude_distributions(self, case_keys=None, show_legend=True else: fig = ax.get_figure() - # plot only the first key for gt amplitude # TODO make a loop for all of then key0 = case_keys[0] @@ -179,12 +179,12 @@ def plot_detected_amplitude_distributions(self, case_keys=None, show_legend=True for count, key in enumerate(case_keys): despine(ax) data1 = self.get_result(key)["peaks"]["amplitude"] - + color = self.get_colors()[key] - + label = self.cases[key]["label"] ax.hist(data1, bins=bins, label=label, histtype="step", color=color, linewidth=2) - + # ax.set_title(self.cases[key]["label"]) ax.set_yscale("log")