61 changes: 44 additions & 17 deletions libs/community/langchain_community/cross_encoders/huggingface.py
@@ -1,64 +1,91 @@
from typing import Any, Dict, List, Tuple

from pydantic import BaseModel, ConfigDict, Field

from langchain_community.cross_encoders.base import BaseCrossEncoder

DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"


class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
"""HuggingFace cross encoder models.
Example:
.. code-block:: python
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
model_name = "BAAI/bge-reranker-base"
model_kwargs = {'device': 'cpu'}
hf = HuggingFaceCrossEncoder(
model_name=model_name,
model_kwargs=model_kwargs
)
"""

client: Any = None #: :meta private:
tokenizer: Any = None #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""

def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers

except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc

self.client = sentence_transformers.CrossEncoder(
self.model_name, **self.model_kwargs
)


# Store tokenizer reference for pad token management
self.tokenizer = self.client.tokenizer

# Ensure padding token is available
self._ensure_padding_token()

def _ensure_padding_token(self):
"""Ensure that a padding token is available for the tokenizer."""
if self.tokenizer.pad_token is None:
if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.client.config.pad_token_id = self.tokenizer.eos_token_id
Contributor (Spotted by Diamond):

Potential AttributeError: The code assumes self.client.config exists, but CrossEncoder instances may not have a config attribute. This will cause a runtime error when trying to set pad_token_id. Should check if the config attribute exists before accessing it, or handle the AttributeError exception.

Suggested change
self.client.config.pad_token_id = self.tokenizer.eos_token_id
if hasattr(self.client, "config"):
self.client.config.pad_token_id = self.tokenizer.eos_token_id

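A related note on that guard: recent sentence-transformers versions expose the wrapped transformers model on the CrossEncoder as client.model, so a config can often be reached there even when the wrapper itself has none. A minimal sketch under that assumption (the helper name and the client.model attribute are illustrative, not part of this PR):

from typing import Any

def _set_pad_token_id(client: Any, token_id: int) -> None:
    # Prefer a config on the wrapper; otherwise look on the wrapped
    # transformers model. `client.model` is an assumption about
    # sentence-transformers internals; verify for the installed version.
    config = getattr(client, "config", None)
    if config is None:
        config = getattr(getattr(client, "model", None), "config", None)
    if config is not None:
        config.pad_token_id = token_id
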

elif hasattr(self.tokenizer, 'unk_token') and self.tokenizer.unk_token:
self.tokenizer.pad_token = self.tokenizer.unk_token
self.client.config.pad_token_id = self.tokenizer.unk_token_id
else:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.client.resize_token_embeddings(len(self.tokenizer))
Contributor (Spotted by Diamond):

Potential AttributeError: The code calls self.client.resize_token_embeddings() but CrossEncoder instances may not have this method. This is typically a method on transformer models, not CrossEncoder wrappers. This will cause a runtime error when executed.

Suggested change
self.client.resize_token_embeddings(len(self.tokenizer))
if hasattr(self.client, "resize_token_embeddings"):
self.client.resize_token_embeddings(len(self.tokenizer))

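Along the same lines, resize_token_embeddings is a transformers PreTrainedModel method, so the call belongs on the wrapped model rather than on the CrossEncoder wrapper. A guarded sketch (again assuming the wrapper exposes the model as client.model; the helper name is illustrative):

from typing import Any

def _resize_embeddings(client: Any, new_vocab_size: int) -> None:
    # The embedding matrix lives on the wrapped transformers model,
    # so resize there if the attribute and method exist.
    model = getattr(client, "model", None)
    if model is not None and hasattr(model, "resize_token_embeddings"):
        model.resize_token_embeddings(new_vocab_size)
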

self.client.config.pad_token_id = self.tokenizer.pad_token_id
Comment on lines +49 to +61
Contributor (Spotted by Diamond, based on custom rule: Code quality):

The new method _ensure_padding_token violates the 'Use Google-Style Docstrings (with Args section)' rule. While it has a basic docstring, it lacks the proper Google-style format with an Args section. Since this is a private method (indicated by the underscore prefix), it should still follow the docstring guidelines for consistency. The method should include a proper docstring with Args section describing any parameters it uses (even though it currently has none, the format should be established for future maintainability).

Suggested change
def _ensure_padding_token(self):
"""Ensure that a padding token is available for the tokenizer."""
if self.tokenizer.pad_token is None:
if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.client.config.pad_token_id = self.tokenizer.eos_token_id
elif hasattr(self.tokenizer, 'unk_token') and self.tokenizer.unk_token:
self.tokenizer.pad_token = self.tokenizer.unk_token
self.client.config.pad_token_id = self.tokenizer.unk_token_id
else:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.client.resize_token_embeddings(len(self.tokenizer))
self.client.config.pad_token_id = self.tokenizer.pad_token_id

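For reference, a Google-style docstring for this parameterless method could look like the following; this is a sketch of the convention the rule asks for, not text from the PR:

def _ensure_padding_token(self) -> None:
    """Ensure the tokenizer has a padding token configured.

    Falls back to the EOS token, then the UNK token, and finally adds a
    new ``[PAD]`` special token when neither is available.

    Args:
        None: the method mutates ``self.tokenizer`` and ``self.client``
            in place.
    """
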


model_config = ConfigDict(extra="forbid", protected_namespaces=())

def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""Compute similarity scores using a HuggingFace transformer model.
Args:
            text_pairs: The list of text pairs to score for similarity.
Returns:
List of scores, one for each pair.
"""
        try:
            scores = self.client.predict(text_pairs)
        except ValueError as e:
            message = str(e)
            if (
                "Cannot handle batch sizes > 1" in message
                or "pad_token" in message.lower()
            ):
                # Fall back to scoring pairs one at a time for models
                # without a pad_token.
                scores = []
                for pair in text_pairs:
                    score = self.client.predict([pair])
                    # predict returns an array-like of one score;
                    # flatten it to plain floats.
                    scores.extend(
                        score.tolist() if hasattr(score, "tolist") else [score]
                    )
            else:
                raise

        # Some models, e.g. bert-multilingual-passage-reranking-msmarco,
        # return two scores (not_relevant, relevant) per pair; keep only
        # the relevant score.
        if hasattr(scores, "shape") and len(scores.shape) > 1:
            scores = [x[1] for x in scores]

return scores
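
To round out the change, a minimal usage sketch of the updated scoring path; the query and passages are made-up examples, and the default BAAI/bge-reranker-base model is assumed to be downloadable:

from langchain_community.cross_encoders import HuggingFaceCrossEncoder

encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
pairs = [
    ("what is a cross encoder?",
     "A cross encoder scores a query and a document jointly."),
    ("what is a cross encoder?",
     "The weather in Paris is mild in spring."),
]
scores = encoder.score(pairs)
# Higher means more relevant; the first pair should outscore the second.
for (_, doc), s in zip(pairs, scores):
    print(f"{s:.3f}  {doc}")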