Merged
Changes from 5 commits
1 change: 1 addition & 0 deletions .github/workflows/pr_qc.yml
@@ -16,4 +16,5 @@ jobs:
       - name: Run qa
         run: |
           pip install ".[dev]"
+          python validator/post-install.py
           make qa
4 changes: 2 additions & 2 deletions Makefile
@@ -5,12 +5,12 @@ lint:
 	ruff check .
 
 test:
-	pytest ./tests
+	pytest tests/*
 
 type:
 	pyright validator
 
 qa:
 	make lint
 	make type
-	make tests
+	make test
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -14,7 +14,8 @@ dependencies = [
     "tf-keras",
     "sentencepiece",
     "tensorflow>=2.16.0", # Required for the dbias model, but not as a direct dependency.
-    "sentence-splitter>=1.4"
+    "sentence-splitter>=1.4",
+    "torch"
 ]
 
 [project.optional-dependencies]
2 changes: 1 addition & 1 deletion tests/test_validator.py
@@ -23,6 +23,6 @@ def test_failure_case():
 def test_sentence_fix():
     v = BiasCheck(on_fail='fix', threshold=0.9)
     input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
-    out = v.validate(input_text)
+    out = v.validate(input_text, {})
     assert isinstance(out, FailResult)
     assert out.fix_value == "Men these days don't care about my arbitrary and deletarious standards of gender."
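
For context, a minimal usage sketch of the updated call signature, assuming guardrails is installed, the model weights are already cached (e.g. via validator/post-install.py), and BiasCheck is importable from validator.main as in this repo's layout:

    from guardrails.validator_base import FailResult, PassResult
    from validator.main import BiasCheck  # import path assumed from the repo layout

    v = BiasCheck(on_fail='fix', threshold=0.9)
    # The public validate() now takes (value, metadata); pass {} when there is
    # no metadata, as the updated test does.
    out = v.validate("A plainly neutral sentence.", {})
    assert isinstance(out, (PassResult, FailResult))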
32 changes: 18 additions & 14 deletions validator/main.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, TypedDict
 
 from guardrails.validator_base import (
     FailResult,
@@ -9,7 +9,7 @@
 )
 from guardrails.types import OnFailAction
 from sentence_splitter import split_text_into_sentences
-from transformers import pipeline
+from transformers.pipelines import pipeline
 
 
 @register_validator(name="guardrails/bias_check", data_type="string")
@@ -33,13 +33,9 @@ def __init__(
         self,
         threshold: float = 0.9,
-        on_fail: Optional[Union[str, Callable]] = None,
+        **kwargs,
     ):
-        super().__init__(on_fail=on_fail) # type: ignore
-        valid_on_fail_operations = {"fix", "noop", "exception"}
-        if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations:
-            raise Exception(
-                f"on_fail value ({on_fail}) not in list of allowable operations: {valid_on_fail_operations}"
-            )
+        super().__init__(**kwargs)
         self.threshold = threshold
 
         # There are some spurious loading complaints with TFDistilBert models.
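
The ad-hoc on_fail validation is gone: **kwargs is forwarded to the Guardrails base class, which is assumed to enforce allowable on_fail values itself. A hedged construction sketch (OnFailAction is already imported in this module; its members mirror the old {"fix", "noop", "exception"} strings):

    from guardrails.types import OnFailAction
    from validator.main import BiasCheck  # import path assumed

    # on_fail now flows through **kwargs to the base Validator; both the enum
    # and its string form are assumed to be accepted upstream.
    v1 = BiasCheck(threshold=0.9, on_fail=OnFailAction.NOOP)
    v2 = BiasCheck(threshold=0.9, on_fail='exception')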
@@ -50,7 +46,10 @@ def __init__(
             tokenizer="d4data/bias-detection-model",
         )
 
-    def validate(
+    def validate(self, value: Any, metadata: Dict[str, Any] = {}) -> ValidationResult:
+        return super().validate(value, metadata)
+
+    def _validate(
         self,
         value: Union[str, List[str]],
         metadata: Optional[Dict] = None
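
The override now follows the usual Guardrails split: the public validate() defers to the framework, and the validator's own logic moves into _validate(). An illustrative sketch of that pattern, not the real base class:

    from typing import Any, Dict

    class ValidatorSketch:
        """Stand-in for guardrails' Validator base class, for illustration only."""

        def validate(self, value: Any, metadata: Dict[str, Any]) -> Any:
            # The framework entry point can wrap bookkeeping (on_fail handling,
            # telemetry, ...) around the subclass hook.
            return self._validate(value, metadata)

        def _validate(self, value: Any, metadata: Dict[str, Any]) -> Any:
            raise NotImplementedError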
@@ -61,7 +60,7 @@ def validate(
         single_sentence_passed = True
         value = [value,] # Ensure we're always passing lists of strings into the classifier.
 
-        scores = self._inference(value)
+        scores = self._inference_local(value)
         passing_outputs = list()
         passing_scores = list()
         failing_outputs = list()
@@ -106,7 +105,7 @@ def fix_passage(self, text: str) -> str:
         then recombine them and return a new paragraph. May not preserve whitespace
         between sentences."""
         sentences = split_text_into_sentences(text, language='en')
-        scores = self._inference(sentences)
+        scores = self._inference_local(sentences)
         unbiased_sentences = list()
         for score, sentence in zip(scores, sentences):
             if score < self.threshold:
@@ -117,10 +116,10 @@ def fix_passage(self, text: str) -> str:
     # Remote inference is unsupported for this model on account of the NER.
     def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore
         scores = list()
-        predictions = self.classification_model(sentences)
+        predictions: List[PipelinePrediction] = self.classification_model(sentences) # type: ignore
         for pred in predictions:
-            label = pred['label'] # type: ignore
-            score = pred['score'] # type: ignore
+            label = pred['label']
+            score = pred['score']
             if label == 'Biased':
                 scores.append(score)
             elif label == 'Non-biased':
@@ -129,3 +128,8 @@ def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore
                 # This should never happen:
                 raise Exception("Unexpected prediction label: {}".format(label))
         return scores
+
+# Define the type for pipeline predictions
+class PipelinePrediction(TypedDict):
+    label: str
+    score: float
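
A hedged sketch of why this TypedDict fits: a transformers text-classification pipeline returns a list of dicts with 'label' and 'score' keys. The model name is taken from the tokenizer argument above; running this downloads the weights:

    from typing import List, TypedDict

    from transformers.pipelines import pipeline

    class PipelinePrediction(TypedDict):
        label: str
        score: float

    classifier = pipeline(
        "text-classification",
        model="d4data/bias-detection-model",
        tokenizer="d4data/bias-detection-model",
    )
    # Each prediction looks like {'label': 'Biased' | 'Non-biased', 'score': 0.97}.
    preds: List[PipelinePrediction] = classifier(["The weather was mild today."])  # type: ignore
    for p in preds:
        print(p["label"], p["score"])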
2 changes: 1 addition & 1 deletion validator/post-install.py
@@ -1,4 +1,4 @@
-from transformers import pipeline
+from transformers.pipelines import pipeline
 print("post-install starting...")
 _ = pipeline(
     'text-classification',