Skip to content

Add sentence similarity metric using Sentence Transformers #37

@NISH1001

Description

@NISH1001

The initial code looks like this:

import numpy as np
from sentence_transformers import CrossEncoder, SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from evalem._base.structures import EvaluationPredictionInstance, EvaluationReferenceInstance
from evalem._base.structures import MetricResult
from evalem.nlp.metrics import SemanticMetric

class SentenceSimilarity(SemanticMetric):
    """Mean cosine similarity between prediction and reference embeddings.

    Each prediction is paired with the reference at the same index. Both
    sides are embedded with a ``SentenceTransformer`` model and every pair is
    scored by the cosine similarity of its two embeddings.
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2") -> None:
        """Load the embedding model.

        Args:
            model: Identifier handed to ``SentenceTransformer``.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm whether SemanticMetric requires super().__init__().
        self.model = SentenceTransformer(model)

    def compute(
        self,
        predictions: EvaluationPredictionInstance,
        references: EvaluationReferenceInstance,
        **kwargs,
    ) -> MetricResult:
        """Score each prediction against the reference at the same index.

        Args:
            predictions: Predicted sentences (a flattened list).
            references: Reference sentences, aligned index-by-index with
                ``predictions``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``MetricResult`` whose ``score`` is the mean of the per-pair
            cosine similarities; the individual similarities are exposed
            under ``extra["scores"]``.

        Raises:
            ValueError: If the inputs are empty or differ in length.
        """
        # Pairs are matched by position, so a length mismatch would silently
        # drop items and make ``total_items`` disagree with the scores.
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length, got "
                f"{len(predictions)} and {len(references)}"
            )
        # An empty batch has no meaningful mean.
        if len(predictions) == 0:
            raise ValueError("predictions and references must not be empty")

        embeddings_preds = np.asarray(self.model.encode(predictions))
        embeddings_refs = np.asarray(self.model.encode(references))

        # Row-wise cosine similarity: only the N matching pairs are needed,
        # so avoid building the full N x N pairwise matrix.
        dots = np.einsum("ij,ij->i", embeddings_preds, embeddings_refs)
        norms = np.linalg.norm(embeddings_preds, axis=1) * np.linalg.norm(
            embeddings_refs, axis=1
        )
        # A zero-norm embedding yields a similarity of 0 rather than NaN.
        scores = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

        return MetricResult(
            score=np.mean(scores),
            metric_name="SentenceSimilarity",
            total_items=len(predictions),
            extra=dict(scores=scores),
        )

class CrossEncoderSentenceSimilarity(SemanticMetric):
    """Mean cross-encoder similarity between references and predictions.

    Each (reference, prediction) pair at the same index is scored jointly by
    a ``CrossEncoder`` model instead of comparing independent embeddings.
    """

    def __init__(self, model: str = "cross-encoder/stsb-distilroberta-base") -> None:
        """Load the cross-encoder model.

        Args:
            model: Identifier handed to ``CrossEncoder``; also reported in
                the result's ``extra["model"]``.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm whether SemanticMetric requires super().__init__().
        self.model_name = model
        self.model = CrossEncoder(model)

    def compute(
        self,
        predictions: EvaluationPredictionInstance,
        references: EvaluationReferenceInstance,
        **kwargs,
    ) -> MetricResult:
        """Score each prediction against the reference at the same index.

        Args:
            predictions: Predicted sentences (a flattened list).
            references: Reference sentences, aligned index-by-index with
                ``predictions``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``MetricResult`` whose ``score`` is the mean of the per-pair
            model scores; the individual scores and the model name are
            exposed under ``extra``.

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in
                length.
        """
        # ``zip`` would silently truncate to the shorter input, making
        # ``total_items`` disagree with the number of scored pairs.
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length, got "
                f"{len(predictions)} and {len(references)}"
            )

        sentences = list(zip(references, predictions))
        scores = self.model.predict(sentences)

        return MetricResult(
            score=np.mean(scores),
            metric_name="CrossEncoderSentenceSimilarity",
            total_items=len(predictions),
            extra=dict(scores=scores, model=self.model_name),
        )

# Example usage: build each metric, then call it with keyword arguments.
bi_encoder_metric = SentenceSimilarity()
result = bi_encoder_metric(
    references=[...],  # flattened list
    predictions=[...],  # flattened list
)  # gives an object of MetricResult

cross_encoder_metric = CrossEncoderSentenceSimilarity()
result = cross_encoder_metric(
    references=[...],  # flattened list
    predictions=[...],  # flattened list
)  # gives an object of MetricResult

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions