diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index 3b0f2ec..625d7da 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -2,9 +2,15 @@ Changelog ######### +Version 4.0.0rc2 (2025-02-23) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- feat(optimize): add option to pass keyword arguments to pipeline during optimization + Version 4.0.0rc1 (2025-02-11) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- feat(optimize): add option to pass keyword arguments to pipeline during optimization - BREAKING: drop support for `Python` < 3.10 - BREAKING: switch to native namespace package - BREAKING: remove `pyannote.pipeline.blocks` submodule diff --git a/src/pyannote/pipeline/optimizer.py b/src/pyannote/pipeline/optimizer.py index 4b4de05..1057093 100644 --- a/src/pyannote/pipeline/optimizer.py +++ b/src/pyannote/pipeline/optimizer.py @@ -30,7 +30,7 @@ import time import warnings from pathlib import Path -from typing import Iterable, Optional, Callable, Generator, Union, Dict +from typing import Iterable, Optional, Callable, Generator, Mapping, Union, Dict import numpy as np import optuna.logging @@ -231,7 +231,16 @@ def objective(trial: Trial) -> float: # process input with pipeline # (and keep track of processing time) before_processing = time.time() - output = pipeline(input) + + # get optional kwargs to be passed to the pipeline + # (e.g. num_speakers for speaker diarization). they + # must be stored in a 'pipeline_kwargs' key in the + # `input` dictionary. + if isinstance(input, Mapping): + pipeline_kwargs = input.get("pipeline_kwargs", {}) + else: + pipeline_kwargs = {} + output = pipeline(input, **pipeline_kwargs) after_processing = time.time() processing_time.append(after_processing - before_processing)