@@ -269,19 +269,35 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
269
269
frames = scores .sliding_window
270
270
timestamps = [frames [i ].middle for i in range (num_frames )]
271
271
272
- # annotation meant to store 'active' regions
272
+ if self .onset == self .offset :
273
+ active = self ._opt_binarize (scores , timestamps )
274
+ else :
275
+ active = self ._binarize (scores , timestamps )
276
+
277
+ # because of padding, some active regions might be overlapping: merge them.
278
+ # also: fill same speaker gaps shorter than min_duration_off
279
+ if self .pad_offset > 0.0 or self .pad_onset > 0.0 or self .min_duration_off > 0.0 :
280
+ active = active .support (collar = self .min_duration_off )
281
+
282
+ # remove tracks shorter than min_duration_on
283
+ if self .min_duration_on > 0 :
284
+ for segment , track in list (active .itertracks ()):
285
+ if segment .duration < self .min_duration_on :
286
+ del active [segment , track ]
287
+
288
+ return active
289
+
290
+ def _binarize (self , scores , timestamps ):
273
291
active = Annotation ()
274
292
275
293
for k , k_scores in enumerate (scores .data .T ):
276
-
277
294
label = k if scores .labels is None else scores .labels [k ]
278
295
279
296
# initial state
280
297
start = timestamps [0 ]
281
298
is_active = k_scores [0 ] > self .onset
282
299
283
300
for t , y in zip (timestamps [1 :], k_scores [1 :]):
284
-
285
301
# currently active
286
302
if is_active :
287
303
# switching from active to inactive
@@ -303,19 +319,37 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
303
319
region = Segment (start - self .pad_onset , t + self .pad_offset )
304
320
active [region , k ] = label
305
321
306
- # because of padding, some active regions might be overlapping: merge them.
307
- # also: fill same speaker gaps shorter than min_duration_off
308
- if self .pad_offset > 0.0 or self .pad_onset > 0.0 or self .min_duration_off > 0.0 :
309
- active = active .support (collar = self .min_duration_off )
322
+ return active
310
323
311
- # remove tracks shorter than min_duration_on
312
- if self .min_duration_on > 0 :
313
- for segment , track in list (active .itertracks ()):
314
- if segment .duration < self .min_duration_on :
315
- del active [segment , track ]
324
def _opt_binarize(self, scores, timestamps):
    """Binarize scores against a single threshold using NumPy edge detection.

    `__call__` dispatches here when `self.onset == self.offset`; only
    `self.onset` is used as the threshold in this method.

    Parameters
    ----------
    scores
        Object exposing `data` and `labels`. Each row of `scores.data.T`
        is processed as one independent track (presumably `data` is
        laid out as (num_frames, num_classes) -- TODO confirm).
    timestamps
        Sequence indexed by frame number, giving the time of each frame.

    Returns
    -------
    Annotation
        One region per run of consecutive frames whose score is strictly
        greater than `self.onset`. A region starts at the timestamp of the
        first active frame minus `self.pad_onset`, and ends at the
        timestamp of the first inactive frame following the run (or of the
        last frame when the run reaches the end) plus `self.pad_offset`.
        The track key is the column index; the label is that index when
        `scores.labels` is None, else `scores.labels[index]`.
    """
    annotation = Annotation()

    for track, column in enumerate(scores.data.T):
        if scores.labels is None:
            name = track
        else:
            name = scores.labels[track]

        # True where the score is strictly above the threshold
        above = column > self.onset

        # steps[i] compares frame i + 1 with frame i, hence the shift by
        # one below to get the index of the frame *after* each edge
        steps = np.diff(above.astype(int))
        rising = np.nonzero(steps == 1)[0] + 1
        falling = np.nonzero(steps == -1)[0] + 1

        # a run already open on the first frame has no rising edge
        if above[0]:
            rising = np.concatenate(([0], rising))

        # a run still open on the last frame is closed on that frame
        if above[-1]:
            falling = np.concatenate((falling, [len(above) - 1]))

        # rising and falling edges pair up in order: one region per run
        for first, last in zip(rising, falling):
            region_start = timestamps[first] - self.pad_onset
            region_end = timestamps[last] + self.pad_offset
            annotation[Segment(region_start, region_end), track] = name

    return annotation
319
353
320
354
class Peak :
321
355
"""Peak detection
0 commit comments