diff --git a/CHANGELOG.md b/CHANGELOG.md index a10d1a309..e9fc7938e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - BREAKING(util): make `Binarize.__call__` return `string` tracks (instead of `int`) [@benniekiss](https://github.com/benniekiss/) - feat(cli): add option to apply pipeline on a directory of audio files +- feat(pipeline): add `preload` option to base `Pipeline.__call__` to force preloading audio in memory ([@antoinelaurent](https://github.com/antoinelaurent/)) - feat(pipeline): add `Pipeline.cuda()` convenience method [@tkanarsky](https://github.com/tkanarsky/) - improve(util): make `permutate` faster thanks to vectorized cost function diff --git a/src/pyannote/audio/core/pipeline.py b/src/pyannote/audio/core/pipeline.py index c9f5bfbd7..d082858a4 100644 --- a/src/pyannote/audio/core/pipeline.py +++ b/src/pyannote/audio/core/pipeline.py @@ -407,7 +407,23 @@ def classes(self) -> List | Iterator: """ raise NotImplementedError() - def __call__(self, file: AudioFile, **kwargs): + def __call__(self, file: AudioFile, preload: bool = False, **kwargs): + """Validate file, (optionally) load it in memory, then process it + + Parameters + ---------- + file : AudioFile + File to process + preload : bool, optional + Whether to preload waveform before applying the pipeline. 
+ kwargs : keyword arguments, optional + Additional keyword arguments passed to `self.apply(...)` + + Returns + ------- + output : Any + Whatever `self.apply(...)` returns + """ fix_reproducibility(getattr(self, "device", torch.device("cpu"))) if not self.instantiated: @@ -433,9 +449,28 @@ def __call__(self, file: AudioFile, **kwargs): file = Audio.validate_file(file) + # check if the instance has preprocessors and wrap the file if so if hasattr(self, "preprocessors"): file = ProtocolFile(file, lazy=self.preprocessors) + # pre-load the audio in memory if requested + if preload: + # raise error if `waveform` is already in memory (or will be provided by a preprocessor) + if ( + "waveform" in getattr(self, "preprocessors", dict()) + or "waveform" in file + ): + raise ValueError( + "Cannot preload audio: `waveform` key is already available or will be via a preprocessor." + ) + + # load waveform in memory (and keep track of its original sample rate) + file["waveform"], file["sample_rate"] = Audio()(file) + + # the above line already took care of channel selection, + # therefore we remove the `channel` key from the file + file.pop("channel", None) + # send file duration to telemetry as well as # requested number of speakers in case of diarization track_pipeline_apply(self, file, **kwargs)