Skip to content

Commit 41d711b

Browse files
committed
conditionally use optimized method
1 parent e4846ff commit 41d711b

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

src/pyannote/audio/utils/signal.py

Lines changed: 47 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,25 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
269269
frames = scores.sliding_window
270270
timestamps = [frames[i].middle for i in range(num_frames)]
271271

272-
# annotation meant to store 'active' regions
272+
if self.onset == self.offset:
273+
active = self._opt_binarize(scores, timestamps)
274+
else:
275+
active = self._binarize(scores, timestamps)
276+
277+
# because of padding, some active regions might be overlapping: merge them.
278+
# also: fill same speaker gaps shorter than min_duration_off
279+
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
280+
active = active.support(collar=self.min_duration_off)
281+
282+
# remove tracks shorter than min_duration_on
283+
if self.min_duration_on > 0:
284+
for segment, track in list(active.itertracks()):
285+
if segment.duration < self.min_duration_on:
286+
del active[segment, track]
287+
288+
return active
289+
290+
def _binarize(self, scores, timestamps):
273291
active = Annotation()
274292
track_generator = string_generator()
275293

@@ -282,7 +300,6 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
282300
is_active = k_scores[0] > self.onset
283301

284302
for t, y in zip(timestamps[1:], k_scores[1:]):
285-
286303
# currently active
287304
if is_active:
288305
# switching from active to inactive
@@ -304,19 +321,37 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
304321
region = Segment(start - self.pad_onset, t + self.pad_offset)
305322
active[region, track] = label
306323

307-
# because of padding, some active regions might be overlapping: merge them.
308-
# also: fill same speaker gaps shorter than min_duration_off
309-
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
310-
active = active.support(collar=self.min_duration_off)
324+
return active
311325

312-
# remove tracks shorter than min_duration_on
313-
if self.min_duration_on > 0:
314-
for segment, track in list(active.itertracks()):
315-
if segment.duration < self.min_duration_on:
316-
del active[segment, track]
326+
def _opt_binarize(self, scores, timestamps):
    """Vectorized binarization using a single threshold (``self.onset``).

    For every column of ``scores.data`` the frames whose score is strictly
    greater than ``self.onset`` are marked active; each run of active frames
    becomes one segment of the returned annotation.

    Parameters
    ----------
    scores
        Object exposing ``data`` (iterated column by column through
        ``data.T``) and ``labels`` (``None`` or indexable by column).
        NOTE(review): presumably a ``SlidingWindowFeature`` -- confirm
        against the caller.
    timestamps
        Sequence indexed by frame position, giving the time of each frame.

    Returns
    -------
    Annotation
        One ``(segment, k)`` track per active run, labelled with
        ``scores.labels[k]`` (or ``k`` itself when there are no labels).
        Segments are widened by ``self.pad_onset`` on the left and
        ``self.pad_offset`` on the right.
    """
    annotation = Annotation()
    threshold = self.onset

    for k, column in enumerate(scores.data.T):
        label = scores.labels[k] if scores.labels is not None else k

        above = column > threshold

        # +1 marks an inactive -> active step, -1 an active -> inactive one;
        # adding 1 converts the diff position into the index of the frame
        # that comes right after the step.
        steps = np.diff(above.astype(int))
        run_starts = np.flatnonzero(steps == 1) + 1
        run_ends = np.flatnonzero(steps == -1) + 1

        # A run already in progress on the first frame has no rising step.
        if above[0]:
            run_starts = np.insert(run_starts, 0, 0)

        # A run still in progress on the last frame has no falling step:
        # close it on the last frame.
        if above[-1]:
            run_ends = np.append(run_ends, len(above) - 1)

        for first, last in zip(run_starts, run_ends):
            left = timestamps[first] - self.pad_onset
            right = timestamps[last] + self.pad_offset
            annotation[Segment(left, right), k] = label

    return annotation
320355

321356
class Peak:
322357
"""Peak detection

0 commit comments

Comments
 (0)