Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions nemoguardrails/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
# limitations under the License.

import contextvars
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from nemoguardrails.logging.explain import LLMCallInfo

streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)

# The object that holds additional explanation information.
explain_info_var = contextvars.ContextVar("explain_info", default=None)

# The current LLM call.
llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
llm_call_info_var: contextvars.ContextVar[
Optional["LLMCallInfo"]
] = contextvars.ContextVar("llm_call_info", default=None)

# All the generation options applicable to the current context.
generation_options_var = contextvars.ContextVar("generation_options", default=None)
Expand Down
23 changes: 14 additions & 9 deletions nemoguardrails/library/attention/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def __init__(self) -> None:
self.user_is_talking = False
self.sentence_distribution = {UNKNOWN_ATTENTION_STATE: 0.0}
self.attention_events: list[ActionEvent] = []
self.utterance_started_event = None
self.utterance_last_event = None
self.utterance_started_event: Optional[ActionEvent] = None
self.utterance_last_event: Optional[ActionEvent] = None

def reset_view(self) -> None:
"""Reset the view. Removing all attention events except for the most recent one"""
Expand All @@ -111,16 +111,18 @@ def update(self, event: ActionEvent, offsets: dict[str, float]) -> None:

Args:
event (ActionEvent): Action event to use for updating the view
offsets (dict[str, float]): You can provide static offsets in seconds for every event type to correct for known latencies of these events.
offsets (dict[str, float]): You can provide static offsets in seconds for every event type to
correct for known latencies of these events.
"""
# print(f"attention_events: {self.attention_events}")
timestamp = _get_action_timestamp(event.name, event.arguments)
if not timestamp:
return

event.corrected_datetime = timestamp + timedelta(
seconds=offsets.get(event.name, 0.0)
)
# Neither ActionEvent nor base class Event have `corrected_time` attribute
# so add it dynamically
corrected_time = timestamp + timedelta(seconds=offsets.get(event.name, 0.0))
setattr(event, "corrected_datetime", corrected_time)

if event.name == "UtteranceUserActionStarted":
self.reset_view()
Expand All @@ -144,7 +146,8 @@ def get_time_spent_percentage(self, attention_levels: list[str]) -> float:
attention_levels (list[str]): List of attention level names to consider `attentive`

Returns:
float: The percentage the user was in the attention levels provided. Returns 1.0 if no attention events have been registered.
float: The percentage the user was in the attention levels provided. Returns 1.0 if no
attention events have been registered.
"""
log_p(f"attention_events={self.attention_events}")

Expand Down Expand Up @@ -194,7 +197,8 @@ def get_time_spent_percentage(self, attention_levels: list[str]) -> float:
)
durations = compute_time_spent_in_states(state_changes)

# If the only state we observed during the duration of the utterance is UNKNOWN_ATTENTION_STATE we treat it as 1.0
# If the only state we observed during the duration of the utterance is UNKNOWN_ATTENTION_STATE
# we treat it as 1.0
if len(durations) == 1 and UNKNOWN_ATTENTION_STATE in durations:
return 1.0

Expand Down Expand Up @@ -238,6 +242,7 @@ async def get_attention_percentage_action(attention_levels: list[str]) -> float:
attention_levels : Name of attention levels for which the user is considered to be `attentive`

Returns:
float: The percentage the user was in the attention levels provided. Returns 1.0 if no attention events have been registered.
float: The percentage the user was in the attention levels provided. Returns 1.0 if no
attention events have been registered.
"""
return _attention_view.get_time_spent_percentage(attention_levels)
38 changes: 30 additions & 8 deletions nemoguardrails/library/autoalign/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def process_autoalign_output(responses: List[Any], show_toxic_phrases: bool = Fa
response_dict["combined_response"] = ", ".join(prefixes) + " detected."
if (
"toxicity_detection" in response_dict.keys()
and response_dict["toxicity_detection"]["guarded"]
and isinstance(response_dict["toxicity_detection"], dict)
and response_dict["toxicity_detection"].get("guarded", False)
and show_toxic_phrases
):
response_dict["combined_response"] += suffix
Expand All @@ -173,11 +174,12 @@ async def autoalign_infer(
headers = {"x-api-key": api_key}
config = copy.deepcopy(DEFAULT_CONFIG)
# enable the select guardrail
for task in task_config.keys():
if task != "factcheck":
config[task]["mode"] = "DETECT"
if task_config[task]:
config[task].update(task_config[task])
if task_config is not None:
for task in task_config.keys():
if task != "factcheck" and isinstance(config.get(task), dict):
config[task]["mode"] = "DETECT"
if task_config[task] and isinstance(config.get(task), dict):
config[task].update(task_config[task])
request_body = {"prompt": text, "config": config, "multi_language": multi_language}

guardrails_configured = []
Expand Down Expand Up @@ -287,7 +289,12 @@ async def autoalign_input_api(
**kwargs,
):
"""Calls AutoAlign API for the user message and guardrail configuration provided"""
if not context:
raise ValueError("Context is required")
user_message = context.get("user_message")
if not user_message:
raise ValueError("user_message is required in context")

autoalign_config = llm_task_manager.config.rails.config.autoalign
autoalign_api_url = autoalign_config.parameters.get("endpoint")
multi_language = autoalign_config.parameters.get("multi_language", False)
Expand Down Expand Up @@ -327,7 +334,12 @@ async def autoalign_output_api(
**kwargs,
):
"""Calls AutoAlign API for the bot message and guardrail configuration provided"""
if context is None:
raise ValueError("Context is required")
bot_message = context.get("bot_message")
if bot_message is None:
raise ValueError("bot_message is required in context")

autoalign_config = llm_task_manager.config.rails.config.autoalign
autoalign_api_url = autoalign_config.parameters.get("endpoint")
multi_language = autoalign_config.parameters.get("multi_language", False)
Expand Down Expand Up @@ -366,8 +378,12 @@ async def autoalign_groundedness_output_api(
):
"""Calls AutoAlign groundedness check API and checks whether the bot message is factually grounded according to given
documents"""

if context is None:
raise ValueError("Context is required")
bot_message = context.get("bot_message")
if bot_message is None:
raise ValueError("bot_message is required in context")

documents = context.get("relevant_chunks_sep", [])

autoalign_config = llm_task_manager.config.rails.config.autoalign
Expand Down Expand Up @@ -404,9 +420,15 @@ async def autoalign_factcheck_output_api(
show_autoalign_message: bool = True,
):
"""Calls Autoalign Factchecker API and checks if the user message is factually answered by the bot message"""

if context is None:
raise ValueError("Context is required")
user_message = context.get("user_message")
if user_message is None:
raise ValueError("user_message is required in context")
bot_message = context.get("bot_message")
if bot_message is None:
raise ValueError("bot_message is required in context")

autoalign_config = llm_task_manager.config.rails.config.autoalign
autoalign_factcheck_api_url = autoalign_config.parameters.get("fact_check_endpoint")
multi_language = autoalign_config.parameters.get("multi_language", False)
Expand Down
7 changes: 4 additions & 3 deletions nemoguardrails/library/clavata/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@ def get_policy_id(
pass

# Not a valid UUID, try to match the provided alias to a policy ID and return that
policy_id = config.policies.get(policy)
if policy_id is None:
raise ClavataPluginValueError(f"Policy with alias '{policy}' not found.")

try:
policy_id = config.policies.get(policy)
if policy_id is None:
raise ClavataPluginValueError(f"Policy with alias '{policy}' not found.")
return uuid.UUID(policy_id)
except ValueError as e:
# Specifically catch the ValueError for badly formed UUIDs so we can provide a more helpful error message
Expand Down
6 changes: 4 additions & 2 deletions nemoguardrails/library/cleanlab/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from typing import Dict, Optional, Union
Expand All @@ -38,13 +37,16 @@ async def call_cleanlab_api(
context: Optional[dict] = None,
**kwargs,
) -> Union[ValueError, ImportError, Dict]:
if context is None:
raise ValueError("Context is required")

api_key = os.environ.get("CLEANLAB_API_KEY")

if api_key is None:
raise ValueError("CLEANLAB_API_KEY environment variable not set.")

try:
from cleanlab_studio import Studio
from cleanlab_studio import Studio # type: ignore
except ImportError:
raise ImportError(
"Please install cleanlab-studio using 'pip install --upgrade cleanlab-studio' command"
Expand Down
3 changes: 3 additions & 0 deletions nemoguardrails/library/factchecking/align_score/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ async def alignscore_check_facts(
**kwargs,
):
"""Checks the facts for the bot response using an information alignment score."""
if context is None:
raise ValueError("Context is required")

fact_checking_config = llm_task_manager.config.rails.config.fact_checking
fallback_to_self_check = fact_checking_config.fallback_to_self_check

Expand Down
20 changes: 17 additions & 3 deletions nemoguardrails/library/factchecking/align_score/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@
from functools import lru_cache
from typing import List

import nltk
try:
import nltk # type: ignore
except ImportError:
nltk = None

try:
from alignscore import AlignScore # type: ignore
except ImportError:
AlignScore = None

import typer
import uvicorn
from alignscore import AlignScore
from fastapi import FastAPI
from pydantic import BaseModel

# Make sure we have the punkt tokenizer downloaded.
nltk.download("punkt")
if nltk is not None:
nltk.download("punkt")

models_path = os.environ.get("ALIGN_SCORE_PATH")

Expand All @@ -47,6 +56,11 @@ def get_model(model: str):
Args
model: The type of the model to be loaded, i.e. "base", "large".
"""
if models_path is None:
raise ValueError("ALIGN_SCORE_PATH environment variable not set")
if AlignScore is None:
raise ImportError("alignscore package not available")

return AlignScore(
model="roberta-base",
batch_size=32,
Expand Down
12 changes: 12 additions & 0 deletions nemoguardrails/library/fiddler/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ async def call_fiddler_guardrail(

@action(name="call fiddler safety on user message", is_system_action=True)
async def call_fiddler_safety_user(config: RailsConfig, context: Optional[dict] = None):
if context is None:
log.error("Context is required for Fiddler Jailbreak Guardrails")
return False

fiddler_config: FiddlerGuardrails = getattr(config.rails.config, "fiddler")
base_url = fiddler_config.fiddler_endpoint

Expand Down Expand Up @@ -114,6 +118,10 @@ async def call_fiddler_safety_user(config: RailsConfig, context: Optional[dict]

@action(name="call fiddler safety on bot message", is_system_action=True)
async def call_fiddler_safety_bot(config: RailsConfig, context: Optional[dict] = None):
if context is None:
log.error("Context is required for Fiddler Safety Guardrails")
return False

fiddler_config: FiddlerGuardrails = getattr(config.rails.config, "fiddler")
base_url = fiddler_config.fiddler_endpoint

Expand Down Expand Up @@ -144,6 +152,10 @@ async def call_fiddler_safety_bot(config: RailsConfig, context: Optional[dict] =
async def call_fiddler_faithfulness(
config: RailsConfig, context: Optional[dict] = None
):
if context is None:
log.error("Context is required for Fiddler Faithfulness Guardrails")
return False

fiddler_config: FiddlerGuardrails = getattr(config.rails.config, "fiddler")
base_url = fiddler_config.fiddler_endpoint

Expand Down
9 changes: 6 additions & 3 deletions nemoguardrails/library/gcp_moderate_text/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from typing import Optional

try:
from google.cloud import language_v2
from google.cloud import language_v2 # type: ignore
except ImportError:
# The exception about installing google-cloud-language will be on the first call to the moderation api
pass
language_v2 = None


from nemoguardrails.actions import action
Expand Down Expand Up @@ -115,8 +115,11 @@ async def call_gcp_text_moderation_api(

For more information check https://cloud.google.com/docs/authentication/application-default-credentials
"""
if context is None:
raise ValueError("Context is required")

try:
from google.cloud import language_v2
from google.cloud import language_v2 # type: ignore

except ImportError:
raise ImportError(
Expand Down
28 changes: 22 additions & 6 deletions nemoguardrails/library/guardrails_ai/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from typing import Any, Dict, Optional, Type

try:
from guardrails import Guard
from guardrails import Guard # type: ignore
except ImportError:
# Mock Guard class for when guardrails is not available
class Guard:
class Guard: # type: ignore
def __init__(self):
pass

Expand Down Expand Up @@ -110,11 +110,18 @@ def validate_guardrails_ai_input(
Dict with validation_result
"""

text = text or context.get("user_message", "")
text = text or (context.get("user_message", "") if context else "")
if not text:
raise ValueError("Either 'text' or 'context' must be provided.")

validator_config = config.rails.config.guardrails_ai.get_validator_config(validator)
guardrails_ai_config = config.rails.config.guardrails_ai
if guardrails_ai_config is None:
raise ValueError("Guardrails AI config is not configured")

validator_config = guardrails_ai_config.get_validator_config(validator)
if validator_config is None:
raise ValueError(f"Validator config for '{validator}' not found")

parameters = validator_config.parameters or {}
metadata = validator_config.metadata or {}

Expand Down Expand Up @@ -149,11 +156,20 @@ def validate_guardrails_ai_output(
Dict with validation_result
"""

text = text or context.get("bot_message", "")
text = text or (context.get("bot_message", "") if context else "")
if not text:
raise ValueError("Either 'text' or 'context' must be provided.")
if config is None:
raise ValueError("Config is required")

guardrails_ai_config = config.rails.config.guardrails_ai
if guardrails_ai_config is None:
raise ValueError("Guardrails AI config is not configured")

validator_config = guardrails_ai_config.get_validator_config(validator)
if validator_config is None:
raise ValueError(f"Validator config for '{validator}' not found")

validator_config = config.rails.config.guardrails_ai.get_validator_config(validator)
parameters = validator_config.parameters or {}
metadata = validator_config.metadata or {}

Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/guardrails_ai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

try:
from guardrails.errors import ValidationError
from guardrails.errors import ValidationError # type: ignore

GuardrailsAIValidationError = ValidationError
except ImportError:
Expand Down
4 changes: 3 additions & 1 deletion nemoguardrails/library/guardrails_ai/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def get_validator_info(validator_path: str) -> Dict[str, str]:
# not in registry, try to fetch from hub
try:
try:
from guardrails.hub.validator_package_service import get_validator_manifest
from guardrails.hub.validator_package_service import ( # type: ignore
get_validator_manifest,
)
except ImportError:
raise GuardrailsAIConfigError(
"Could not import get_validator_manifest. "
Expand Down
Loading