Merged
29 changes: 27 additions & 2 deletions nemoguardrails/library/content_safety/actions.py
@@ -25,6 +25,7 @@
from nemoguardrails.llm.cache.utils import (
CacheEntry,
create_normalized_cache_key,
extract_llm_metadata_for_cache,
extract_llm_stats_for_cache,
get_from_cache_and_restore_stats,
)
@@ -110,6 +111,7 @@ async def content_safety_check_input(
cache_entry: CacheEntry = {
"result": final_result,
"llm_stats": extract_llm_stats_for_cache(),
"llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(cache_key, cache_entry)
log.debug(f"Content safety result cached for model '{model_name}'")
@@ -139,6 +141,7 @@ async def content_safety_check_output(
llm_task_manager: LLMTaskManager,
model_name: Optional[str] = None,
context: Optional[dict] = None,
model_caches: Optional[Dict[str, CacheInterface]] = None,
**kwargs,
) -> dict:
_MAX_TOKENS = 3
@@ -176,12 +179,22 @@
"bot_response": bot_response,
},
)

stop = llm_task_manager.get_stop_tokens(task=task)
max_tokens = llm_task_manager.get_max_tokens(task=task)

max_tokens = max_tokens or _MAX_TOKENS

llm_call_info_var.set(LLMCallInfo(task=task))
cache = model_caches.get(model_name) if model_caches else None

if cache:
cache_key = create_normalized_cache_key(check_output_prompt)
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
if cached_result is not None:
log.debug(f"Content safety output cache hit for model '{model_name}'")
return cached_result

result = await llm_call(
llm,
@@ -194,4 +207,16 @@

is_safe, *violated_policies = result

return {"allowed": is_safe, "policy_violations": violated_policies}
final_result = {"allowed": is_safe, "policy_violations": violated_policies}

if cache:
cache_key = create_normalized_cache_key(check_output_prompt)
cache_entry: CacheEntry = {
"result": final_result,
"llm_stats": extract_llm_stats_for_cache(),
"llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(cache_key, cache_entry)
log.debug(f"Content safety output result cached for model '{model_name}'")

return final_result
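
A minimal sketch of the read-through pattern that content_safety_check_output now follows: derive a normalized key from the rendered prompt, return the cached result (restoring its stats and metadata) on a hit, otherwise call the model and store the result together with the extracted stats and metadata. The helper names come from the imports above; cached_check and call_model are hypothetical stand-ins for the action and the real llm_call.

from nemoguardrails.llm.cache.utils import (
    CacheEntry,
    create_normalized_cache_key,
    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
    get_from_cache_and_restore_stats,
)


async def cached_check(check_output_prompt, call_model, cache):
    # Key the cache on the fully rendered prompt so identical inputs reuse one entry.
    cache_key = create_normalized_cache_key(check_output_prompt)

    # Hit: return the stored result; stats/metadata are restored as a side effect.
    cached_result = get_from_cache_and_restore_stats(cache, cache_key)
    if cached_result is not None:
        return cached_result

    # Miss: run the real check, then persist result, stats, and metadata together.
    final_result = await call_model(check_output_prompt)
    cache_entry: CacheEntry = {
        "result": final_result,
        "llm_stats": extract_llm_stats_for_cache(),
        "llm_metadata": extract_llm_metadata_for_cache(),
    }
    cache.put(cache_key, cache_entry)
    return final_result
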
32 changes: 31 additions & 1 deletion nemoguardrails/library/topic_safety/actions.py
@@ -21,6 +21,14 @@
from nemoguardrails.actions.actions import action
from nemoguardrails.actions.llm.utils import llm_call
from nemoguardrails.context import llm_call_info_var
from nemoguardrails.llm.cache import CacheInterface
from nemoguardrails.llm.cache.utils import (
CacheEntry,
create_normalized_cache_key,
extract_llm_metadata_for_cache,
extract_llm_stats_for_cache,
get_from_cache_and_restore_stats,
)
from nemoguardrails.llm.filters import to_chat_messages
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.logging.explain import LLMCallInfo
Expand All @@ -35,6 +43,7 @@ async def topic_safety_check_input(
model_name: Optional[str] = None,
context: Optional[dict] = None,
events: Optional[List[dict]] = None,
model_caches: Optional[Dict[str, CacheInterface]] = None,
**kwargs,
) -> dict:
_MAX_TOKENS = 10
@@ -102,11 +111,32 @@ async def topic_safety_check_input(
messages.extend(conversation_history)
messages.append({"type": "user", "content": user_input})

cache = model_caches.get(model_name) if model_caches else None

if cache:
cache_key = create_normalized_cache_key(messages)
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
if cached_result is not None:
log.debug(f"Topic safety cache hit for model '{model_name}'")
return cached_result

result = await llm_call(llm, messages, stop=stop, llm_params={"temperature": 0.01})

if result.lower().strip() == "off-topic":
on_topic = False
else:
on_topic = True

return {"on_topic": on_topic}
final_result = {"on_topic": on_topic}

if cache:
cache_key = create_normalized_cache_key(messages)
cache_entry: CacheEntry = {
"result": final_result,
"llm_stats": extract_llm_stats_for_cache(),
"llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(cache_key, cache_entry)
log.debug(f"Topic safety result cached for model '{model_name}'")

return final_result
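
For the topic safety action the cache key is built from the full message list (rendered system prompt, recent conversation turns, and the current user input) rather than a single prompt string, so a hit requires the same conversation prefix. A small illustrative sketch, with made-up message contents:

from nemoguardrails.llm.cache.utils import create_normalized_cache_key

# Hypothetical conversation; the real action assembles this from the rendered
# system prompt, the recent conversation history, and the current user input.
messages = [
    {"type": "system", "content": "You are a customer-support assistant."},
    {"type": "user", "content": "Tell me about your pricing."},
]
cache_key = create_normalized_cache_key(messages)

# The raw model answer is mapped onto a boolean flag, mirroring the action above.
raw_answer = "off-topic"
final_result = {"on_topic": raw_answer.lower().strip() != "off-topic"}
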
3 changes: 3 additions & 0 deletions nemoguardrails/llm/cache/utils.py
@@ -179,6 +179,9 @@ def get_from_cache_and_restore_stats(
if cached_metadata:
restore_llm_metadata_from_cache(cached_metadata)

processing_log = processing_log_var.get()
if processing_log is not None:
llm_call_info = llm_call_info_var.get()
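
For reference, the two entry shapes this helper now handles; the values are made up, and the "llm_metadata" payload is a hypothetical example since its exact contents depend on extract_llm_metadata_for_cache. The updated tests below store None for it, in which case the new guarded restore is simply skipped.

from nemoguardrails.llm.cache.utils import CacheEntry

entry_without_metadata: CacheEntry = {
    "result": {"allowed": True, "policy_violations": []},
    "llm_stats": {"total_tokens": 50, "prompt_tokens": 40, "completion_tokens": 10},
    "llm_metadata": None,  # metadata restore is skipped
}

entry_with_metadata: CacheEntry = {
    "result": {"allowed": False, "policy_violations": ["unsafe_output"]},
    "llm_stats": {"total_tokens": 75, "prompt_tokens": 60, "completion_tokens": 15},
    "llm_metadata": {"model_name": "example-model"},  # hypothetical payload; restored on a hit
}
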
111 changes: 110 additions & 1 deletion tests/test_content_safety_cache.py
@@ -18,7 +18,10 @@
import pytest

from nemoguardrails.context import llm_call_info_var, llm_stats_var
from nemoguardrails.library.content_safety.actions import content_safety_check_input
from nemoguardrails.library.content_safety.actions import (
content_safety_check_input,
content_safety_check_output,
)
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import create_normalized_cache_key
from nemoguardrails.logging.explain import LLMCallInfo
@@ -95,6 +98,7 @@ async def test_content_safety_cache_retrieves_result_and_restores_stats(
"prompt_tokens": 80,
"completion_tokens": 20,
},
"llm_metadata": None,
}
cache_key = create_normalized_cache_key("test prompt")
cache.put(cache_key, cache_entry)
@@ -138,6 +142,7 @@ async def test_content_safety_cache_duration_reflects_cache_read_time(
"prompt_tokens": 40,
"completion_tokens": 10,
},
"llm_metadata": None,
}
cache_key = create_normalized_cache_key("test prompt")
cache.put(cache_key, cache_entry)
@@ -191,6 +196,7 @@ async def test_content_safety_cache_handles_missing_stats_gracefully(
cache_entry = {
"result": {"allowed": True, "policy_violations": []},
"llm_stats": None,
"llm_metadata": None,
}
cache_key = create_normalized_cache_key("test_key")
cache.put(cache_key, cache_entry)
@@ -213,3 +219,106 @@

assert result["allowed"] is True
assert llm_stats.get_stat("total_calls") == 0


@pytest.mark.asyncio
async def test_content_safety_check_output_cache_stores_result(
fake_llm_with_stats, mock_task_manager
):
cache = LFUCache(maxsize=10)
mock_task_manager.parse_task_output.return_value = [True, "policy2"]

result = await content_safety_check_output(
llms=fake_llm_with_stats,
llm_task_manager=mock_task_manager,
model_name="test_model",
context={"user_message": "test user input", "bot_message": "test bot response"},
model_caches={"test_model": cache},
)

assert result["allowed"] is True
assert result["policy_violations"] == ["policy2"]
assert cache.size() == 1


@pytest.mark.asyncio
async def test_content_safety_check_output_cache_hit(
fake_llm_with_stats, mock_task_manager
):
cache = LFUCache(maxsize=10)

cache_entry = {
"result": {"allowed": False, "policy_violations": ["unsafe_output"]},
"llm_stats": {
"total_tokens": 75,
"prompt_tokens": 60,
"completion_tokens": 15,
},
"llm_metadata": None,
}
cache_key = create_normalized_cache_key("test output prompt")
cache.put(cache_key, cache_entry)

mock_task_manager.render_task_prompt.return_value = "test output prompt"

llm_stats = LLMStats()
llm_stats_var.set(llm_stats)

result = await content_safety_check_output(
llms=fake_llm_with_stats,
llm_task_manager=mock_task_manager,
model_name="test_model",
context={"user_message": "user", "bot_message": "bot"},
model_caches={"test_model": cache},
)

assert result["allowed"] is False
assert result["policy_violations"] == ["unsafe_output"]
assert llm_stats.get_stat("total_calls") == 1
assert llm_stats.get_stat("total_tokens") == 75

llm_call_info = llm_call_info_var.get()
assert llm_call_info.from_cache is True


@pytest.mark.asyncio
async def test_content_safety_check_output_cache_miss(
fake_llm_with_stats, mock_task_manager
):
cache = LFUCache(maxsize=10)

cache_entry = {
"result": {"allowed": True, "policy_violations": []},
"llm_stats": {
"total_tokens": 50,
"prompt_tokens": 40,
"completion_tokens": 10,
},
"llm_metadata": None,
}
cache_key = create_normalized_cache_key("different prompt")
cache.put(cache_key, cache_entry)

mock_task_manager.render_task_prompt.return_value = "new output prompt"
mock_task_manager.parse_task_output.return_value = [True, "policy2"]

llm_stats = LLMStats()
llm_stats_var.set(llm_stats)

llm_call_info = LLMCallInfo(task="content_safety_check_output $model=test_model")
llm_call_info_var.set(llm_call_info)

result = await content_safety_check_output(
llms=fake_llm_with_stats,
llm_task_manager=mock_task_manager,
model_name="test_model",
context={"user_message": "new user input", "bot_message": "new bot response"},
model_caches={"test_model": cache},
)

assert result["allowed"] is True
assert result["policy_violations"] == ["policy2"]
assert cache.size() == 2

llm_call_info = llm_call_info_var.get()
assert llm_call_info.from_cache is False
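
A standalone round trip of the hit path, mirroring how the tests above arrange the call context before the lookup; the LLMStats import path is an assumption based on the existing test module rather than something shown in this diff.

from nemoguardrails.context import llm_call_info_var, llm_stats_var
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import (
    create_normalized_cache_key,
    get_from_cache_and_restore_stats,
)
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.stats import LLMStats  # assumed import path

# Arrange the per-call context the way the actions do before the cache lookup.
llm_stats_var.set(LLMStats())
llm_call_info_var.set(LLMCallInfo(task="content_safety_check_output $model=test_model"))

cache = LFUCache(maxsize=10)
key = create_normalized_cache_key("test output prompt")
cache.put(
    key,
    {
        "result": {"allowed": False, "policy_violations": ["unsafe_output"]},
        "llm_stats": {"total_tokens": 75, "prompt_tokens": 60, "completion_tokens": 15},
        "llm_metadata": None,
    },
)

# Hit: the stored result comes back and the stored stats are restored into the
# current LLMStats; a miss would return None.
restored = get_from_cache_and_restore_stats(cache, key)
assert restored is not None and restored["allowed"] is False
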