-
Notifications
You must be signed in to change notification settings - Fork 266
Fix: Add pad_token handling and batch fallback for HuggingFaceCrossEn… #331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,64 +1,91 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel, ConfigDict, Field | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from langchain_community.cross_encoders.base import BaseCrossEncoder | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """HuggingFace cross encoder models. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Example: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .. code-block:: python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_name = "BAAI/bge-reranker-base" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_kwargs = {'device': 'cpu'} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf = HuggingFaceCrossEncoder( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_name=model_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_kwargs=model_kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| client: Any = None #: :meta private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer: Any = None #: :meta private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_name: str = DEFAULT_MODEL_NAME | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Model name to use.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_kwargs: Dict[str, Any] = Field(default_factory=dict) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Keyword arguments to pass to the model.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, **kwargs: Any): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Initialize the sentence_transformer.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(**kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import sentence_transformers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except ImportError as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Could not import sentence_transformers python package. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Please install it with `pip install sentence-transformers`." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client = sentence_transformers.CrossEncoder( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model_name, **self.model_kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Store tokenizer reference for pad token management | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = self.client.tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ensure padding token is available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._ensure_padding_token() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _ensure_padding_token(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Ensure that a padding token is available for the tokenizer.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.tokenizer.pad_token is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.tokenizer.eos_token: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer.pad_token = self.tokenizer.eos_token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client.config.pad_token_id = self.tokenizer.eos_token_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hasattr(self.tokenizer, 'unk_token') and self.tokenizer.unk_token: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer.pad_token = self.tokenizer.unk_token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client.config.pad_token_id = self.tokenizer.unk_token_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client.resize_token_embeddings(len(self.tokenizer)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential AttributeError: The code calls
Suggested change
Spotted by Diamond |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.client.config.pad_token_id = self.tokenizer.pad_token_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+49
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new method
Suggested change
Spotted by Diamond (based on custom rule: Code quality) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_config = ConfigDict(extra="forbid", protected_namespaces=()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Compute similarity scores using a HuggingFace transformer model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| text_pairs: The list of text text_pairs to score the similarity. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| List of scores, one for each pair. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores = self.client.predict(text_pairs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores = self.client.predict(text_pairs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except ValueError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "Cannot handle batch sizes > 1" in str(e) or "pad_token" in str(e).lower(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Fallback to processing pairs individually for models without pad_token | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for pair in text_pairs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| score = self.client.predict([pair]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores.extend(score if isinstance(score, list) else [score]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Some models e.g bert-multilingual-passage-reranking-msmarco | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # gives two score not_relevant and relevant as compare with the query. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(scores.shape) > 1: # we are going to get the relevant scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores = map(lambda x: x[1], scores) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(scores, 'shape') and len(scores.shape) > 1: # we are going to get the relevant scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores = [x[1] for x in scores] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential AttributeError: The code assumes
self.client.configexists, but CrossEncoder instances may not have aconfigattribute. This will cause a runtime error when trying to setpad_token_id. Should check if the config attribute exists before accessing it, or handle the AttributeError exception.Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.