diff --git a/evalem/_base/metrics.py b/evalem/_base/metrics.py
index 476a661..d8c6563 100755
--- a/evalem/_base/metrics.py
+++ b/evalem/_base/metrics.py
@@ -14,8 +14,12 @@
     EvaluationPredictionInstance,
     EvaluationReferenceInstance,
     MetricResult,
+    MultiplePredictionInstance,
+    MultipleReferenceInstance,
+    PredictionInstance,
     SequenceType,
     SinglePredictionInstance,
+    SingleReferenceInstance,
 )
 
 
@@ -105,7 +109,64 @@ def __call__(
         )
 
     @staticmethod
-    def _flatten_references(
+    def _is_single_prediction_multi_reference(predictions, references) -> bool:
+        return isinstance(predictions, PredictionInstance) and isinstance(
+            references,
+            SequenceType,
+        )
+
+    @staticmethod
+    def _is_multi_prediction_single_reference(predictions, references) -> bool:
+        return isinstance(predictions, SequenceType) and isinstance(
+            references,
+            PredictionInstance,
+        )
+
+    @staticmethod
+    def _is_multi_prediction_multi_reference(predictions, references) -> bool:
+        return isinstance(predictions, SequenceType) and isinstance(
+            references,
+            SequenceType,
+        )
+
+    def _flatten_single_prediction_multi_reference(
+        self,
+        predictions: SinglePredictionInstance,
+        references: MultipleReferenceInstance,
+    ) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
+        res = []
+        for preds, refs in zip(predictions, references):
+            if Metric._is_single_prediction_multi_reference(preds, refs):
+                res.extend(list(map(lambda r: (preds, r), refs)))
+            else:
+                res.append((preds, refs))
+        predictions, references = zip(*res)
+        return predictions, references
+
+    def _flatten_multi_prediction_single_reference(
+        self,
+        predictions: MultiplePredictionInstance,
+        references: SingleReferenceInstance,
+    ) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
+        res = []
+        for preds, refs in zip(predictions, references):
+            if Metric._is_multi_prediction_single_reference(preds, refs):
+                res.extend(list(map(lambda p: (p, refs), preds)))
+            else:
+                res.append((preds, refs))
+        predictions, references = zip(*res)
+        return predictions, references
+
+    def _flatten_multi_prediction_multi_reference(
+        self,
+        predictions: MultiplePredictionInstance,
+        references: MultipleReferenceInstance,
+    ) -> Tuple[MultiplePredictionInstance, MultipleReferenceInstance]:
+        # No-op
+        return predictions, references
+
+    def _flatten_instances(
+        self,
         predictions: EvaluationPredictionInstance,
         references: EvaluationReferenceInstance,
     ) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
@@ -124,17 +185,14 @@ def _flatten_references(
         Returns:
             Tuple of flattened lists (predictions, references)
         """
-        res = []
-        for pred, ref in zip(predictions, references):
-            # if multiple predictions, skip for now
-            if isinstance(pred, SequenceType) and not isinstance(pred, str):
-                raise TypeError("Cannot handle multiple prediction instance")
-            # if multiple references
-            elif isinstance(ref, SequenceType) and not isinstance(ref, str):
-                res.extend(list(map(lambda r: (pred, r), ref)))
-            else:
-                res.append((pred, ref))
-        predictions, references = zip(*res)
+        predictions, references = self._flatten_multi_prediction_single_reference(
+            predictions,
+            references,
+        )
+        predictions, references = self._flatten_single_prediction_multi_reference(
+            predictions,
+            references,
+        )
         return predictions, references
 
 
@@ -266,7 +324,7 @@ def compute(
             references,
         )
 
-        predictions, references = self._flatten_references(predictions, references)
+        predictions, references = self._flatten_instances(predictions, references)
         labels = self.__get_labels(predictions, references)
 
         return MetricResult.from_dict(
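Note on the hunk above: the old `_flatten_references` is split into case-specific helpers that `_flatten_instances` applies in sequence (multi-prediction/single-reference first, then single-prediction/multi-reference). The sketch below is a standalone illustration of the resulting behaviour, not library code: it collapses the two passes into a single loop and uses plain `str`/`list` checks in place of evalem's type aliases.

```python
from typing import List, Tuple, Union

Item = Union[str, List[str]]

def flatten_instances(
    predictions: List[Item],
    references: List[Item],
) -> Tuple[tuple, tuple]:
    # Expand each (prediction, reference) pair into 1:1 pairs, mirroring
    # what Metric._flatten_instances produces for the common cases.
    res = []
    for preds, refs in zip(predictions, references):
        if isinstance(preds, list) and isinstance(refs, str):
            # multiple predictions, single reference
            res.extend((p, refs) for p in preds)
        elif isinstance(preds, str) and isinstance(refs, list):
            # single prediction, multiple references
            res.extend((preds, r) for r in refs)
        else:
            # already 1:1 (or multi/multi, which passes through untouched)
            res.append((preds, refs))
    return tuple(zip(*res))

preds = ["apple", ["cat", "dog"]]
refs = [["apple", "apples"], "dog"]
print(flatten_instances(preds, refs))
# (('apple', 'apple', 'cat', 'dog'), ('apple', 'apples', 'dog', 'dog'))
```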
diff --git a/evalem/_base/structures.py b/evalem/_base/structures.py
index f784333..0e6d99b 100755
--- a/evalem/_base/structures.py
+++ b/evalem/_base/structures.py
@@ -84,14 +84,14 @@ def __hash__(self) -> str:
 # Represents type instance for any single downstream prediction
 PredictionInstance = Union[
     str,
-    Type[PredictionDTO],
+    PredictionDTO,
     dict,
     ImageTensor,
-    Type[ClassificationDTO],
+    ClassificationDTO,
 ]
 
 # Represents type instance for any single downstream reference/ground-truth
-ReferenceInstance = Union[str, Type[ReferenceDTO]]
+ReferenceInstance = Union[str, ReferenceDTO]
 
 SinglePredictionInstance = List[PredictionInstance]
 MultiplePredictionInstance = List[List[PredictionInstance]]
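The `structures.py` hunk above drops `Type[...]` from the aliases: `Type[X]` annotates the class object itself, whereas `PredictionInstance`/`ReferenceInstance` are meant to describe prediction and reference *values*. A small illustrative snippet (the `PredictionDTO` here is a hypothetical stand-in, not evalem's actual class):

```python
from dataclasses import dataclass
from typing import Type, Union

@dataclass
class PredictionDTO:
    # Hypothetical stand-in for evalem's PredictionDTO, for illustration only.
    text: str

# Annotates a value: a string or a DTO *instance*.
InstanceAlias = Union[str, PredictionDTO]

# Annotates the class object itself (e.g. a factory argument), not a value.
ClassAlias = Union[str, Type[PredictionDTO]]

def first_text(pred: InstanceAlias) -> str:
    # Instances are expected here; with Type[...] the alias would instead
    # describe the PredictionDTO class rather than PredictionDTO(...) objects.
    return pred.text if isinstance(pred, PredictionDTO) else pred

print(first_text(PredictionDTO(text="apple")))  # apple
print(first_text("apple"))                      # apple
```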
diff --git a/evalem/nlp/metrics/basics.py b/evalem/nlp/metrics/basics.py
index 68a1f9d..b2d9da1 100755
--- a/evalem/nlp/metrics/basics.py
+++ b/evalem/nlp/metrics/basics.py
@@ -1,9 +1,11 @@
 #!/usr/bin/env python3
 
 import dataclasses
+from typing import Tuple
 
 from ..._base.metrics import JuryBasedMetric
 from ..._base.structures import (
+    EvaluationPredictionInstance,
     EvaluationReferenceInstance,
     MetricResult,
     SinglePredictionInstance,
@@ -15,6 +17,13 @@ class ExactMatchMetric(JuryBasedMetric, NLPMetric):
     def __init__(self) -> None:
         super().__init__(metrics="exact_match")
 
+    def _flatten_multi_prediction_multi_reference(
+        self,
+        predictions: EvaluationPredictionInstance,
+        references: EvaluationReferenceInstance,
+    ) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
+        raise NotImplementedError()
+
     def compute(
         self,
         predictions: SinglePredictionInstance,
@@ -24,7 +33,7 @@ def compute(
         # This metric doesn't support multi-reference format.
         # So, we flatten everything:
         # Single Prediction, Multi-Ref -> Single Prediction, Single-Ref
-        predictions, references = self._flatten_references(predictions, references)
+        predictions, references = self._flatten_instances(predictions, references)
         result = super().compute(
             predictions=predictions,
             references=references,
diff --git a/evalem/nlp/metrics/llm.py b/evalem/nlp/metrics/llm.py
index e09b79b..8cf7d67 100755
--- a/evalem/nlp/metrics/llm.py
+++ b/evalem/nlp/metrics/llm.py
@@ -54,6 +54,15 @@ class LLMAsJudgeMetric(NLPMetric):
         ```aggregation_type```: ```Optional[AggregationType]```
             Decides how to aggregate scores from the multiple judgement tries.
             Defaults to `AggregationType.MEAN` if not provided.
+        ```max_n```: ```Optional[int]```
+            If set, the number of references or predictions per item is
+            truncated to at most `max_n`:
+            - If single reference, multiple predictions, the number of
+              predictions will be truncated
+            - If multiple references, single prediction, the number of
+              references will be truncated
+            This reduces the number of LLM calls and thus scoring time.
+            Default behaviour is no truncation when set to `None` or less than 1.
         ```debug```:```bool```
             Boolean flag for debug-mode outputs
 
@@ -103,21 +112,29 @@ def __init__(
         temperature: float = 0.0,
         prompt: Optional[str] = None,
         aggregation_type: Optional[List[AggregationType]] = None,
+        max_n: Optional[int] = None,
         debug: bool = False,
     ) -> None:
         super().__init__(debug=debug)
+
+        model = self.__clean_model(model)
+        api_base = self.__clean_url(api_base)
         self.model = outlines.models.openai(
-            self.__clean_model(model),
+            model,
             base_url=api_base,
             api_key=api_key,
             config=OpenAIConfig(temperature=temperature),
         )
-        self.api_base = self.__clean_url(api_base)
+        self.api_base = api_base
         self.n_tries = n_tries or 1
         self.prompt = prompt or LLMAsJudgeMetric._prompt
         self.aggregation_type = aggregation_type or AggregationType.MEAN
-        self._sanity_check_prmopt(self.prompt)
+        self.max_n = max_n or None
+        if self.max_n:
+            logger.warning(
+                f"Total number of predictions/references per item will be truncated based on `max_n` value.",
+            )
 
     def _sanity_check_prmopt(self, prompt: str) -> bool:
         if "{prediction}" not in prompt or "{reference}" not in prompt:
@@ -133,24 +150,25 @@ def __clean_model(self, model: str) -> str:
     def __clean_url(self, url: str) -> str:
         if not url.endswith("/v1"):
-            url = urljoin(url, "/v1")
+            url = urljoin(url, "v1")
         return url
 
-    @staticmethod
-    def _flatten_references(
+    def _flatten_instances(
+        self,
         predictions,
         references,
+        max_n: Optional[int] = None,
     ) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
+        if max_n is not None and max_n < 1:
+            max_n = None
         res = []
         for preds, refs in zip(predictions, references):
             # multiple predictions, single reference
-            if isinstance(preds, SequenceType) and isinstance(refs, str):
-                res.extend(list(map(lambda p: (p, refs), preds)))
-
+            if self._is_multi_prediction_single_reference(preds, refs):
+                res.extend(list(map(lambda p: (p, refs), preds[slice(max_n)])))
             # single prediction, multiple references
-            elif isinstance(preds, str) and isinstance(refs, SequenceType):
-                res.extend(list(map(lambda r: (preds, r), refs)))
-
+            elif self._is_single_prediction_multi_reference(preds, refs):
+                res.extend(list(map(lambda r: (preds, r), refs[slice(max_n)])))
             # single prediction, single reference
             else:
                 res.append((preds, refs))
         predictions, references = zip(*res)
@@ -165,7 +183,11 @@ def compute(
         **kwargs,
     ) -> MetricResult:
         # make sure to flatten
-        predictions, references = self._flatten_references(predictions, references)
+        predictions, references = self._flatten_instances(
+            predictions,
+            references,
+            max_n=self.max_n,
+        )
         if self.debug:
             logger.debug(f"Evaluating for {len(predictions)} predictions.")
         generator = outlines.generate.choice(self.model, ["0", "1"])
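On the `max_n` truncation in `LLMAsJudgeMetric._flatten_instances` above: `seq[slice(max_n)]` is the same as `seq[:max_n]`, and `slice(None)` keeps the whole sequence, which is why values of `None` or less than 1 are normalised to `None` before slicing. A small standalone snippet (not evalem code) showing the effect:

```python
refs = ["r1", "r2", "r3"]

# slice(None) keeps everything; slice(k) is equivalent to [:k].
assert refs[slice(None)] == refs
assert refs[slice(2)] == refs[:2] == ["r1", "r2"]

# Mirror of the normalisation done before slicing: a disabled or invalid
# cap (None, 0, negative) falls through to the "keep everything" path.
for max_n in (None, 0, 2):
    capped = None if (max_n is not None and max_n < 1) else max_n
    print(max_n, refs[slice(capped)])
# None ['r1', 'r2', 'r3']
# 0 ['r1', 'r2', 'r3']
# 2 ['r1', 'r2']
```

With `max_n=2`, for example, a single prediction paired with five references costs two judge calls instead of five.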