diff --git a/docs/getting-started/tutorials/index.md b/docs/getting-started/tutorials/index.md index 48a46fc1f..a11efa5be 100644 --- a/docs/getting-started/tutorials/index.md +++ b/docs/getting-started/tutorials/index.md @@ -25,6 +25,13 @@ This section contains tutorials that help you get started with the NeMo Guardrai Check text inputs and outputs for harmful content using Nemotron Content Safety NIM. ::: +:::{grid-item-card} Enforce Custom Safety and Dialogue Policies +:link: nemotron-content-safety-reasoning-deployment +:link-type: doc + +Enforce customizable content safety and dialogue policies for your use-case by using our reasoning guard model Nemotron-Content-Safety-Reasoning-4B. +::: + :::{grid-item-card} Restrict Topics :link: nemoguard-topiccontrol-deployment :link-type: doc @@ -53,6 +60,7 @@ Add safety checks to images and text using a vision model as LLM-as-a-Judge. :maxdepth: 2 Check Harmful Content +Content Safety Reasoning Restrict Topics Detect Jailbreak Attempts Add Multimodal Content Safety diff --git a/docs/getting-started/tutorials/nemotron-content-safety-reasoning-deployment.md b/docs/getting-started/tutorials/nemotron-content-safety-reasoning-deployment.md new file mode 100644 index 000000000..9e8cce74b --- /dev/null +++ b/docs/getting-started/tutorials/nemotron-content-safety-reasoning-deployment.md @@ -0,0 +1,401 @@ +--- +title: + page: "Deploy Nemotron Content Safety Reasoning 4B" + nav: "Content Safety Reasoning" +description: "Deploy Nemotron-Content-Safety-Reasoning-4B for customizable content safety with reasoning traces." +topics: ["AI Safety", "Content Safety", "Reasoning"] +tags: ["Content Safety", "vLLM", "HuggingFace", "Input Rails", "Output Rails", "Custom Policy"] +content: + type: "Tutorial" + difficulty: "Intermediate" + audience: ["Developer", "AI Engineer"] +--- + +# Content Safety with Nemotron-Content-Safety-Reasoning-4B + +## Overview + +[Nemotron-Content-Safety-Reasoning-4B](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B) is a Large Language Model (LLM) classifier designed to function as a dynamic and adaptable guardrail for content safety and dialogue moderation. + +### Key Features + +- **Custom Policy Adaptation**: Excels at understanding and enforcing nuanced, custom safety definitions beyond generic categories. + +- **Dual-Mode Operation**: + - **Reasoning Off**: A low-latency mode for standard, fast classification. + - **Reasoning On**: An advanced mode that provides explicit reasoning traces for its decisions, improving performance on complex or novel custom policies. + - **Examples**: [Reasoning On](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-1-vanilla-safety-with-nemotron-content-safety-dataset-v2-taxonomy-reasoning-on-mode) and [Reasoning Off](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-2-vanilla-safety-with-nemotron-content-safety-dataset-v2-taxonomy-reasoning-off-mode) on HuggingFace. +- **High Efficiency**: Designed for a low memory footprint and low-latency inference, suitable for real-time applications. + +### Model Details + +See the full [Model Architecture](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#model-architecture) on HuggingFace. 
+ +| Attribute | Value | +|-----------|-------| +| Base Model | Google Gemma-3-4B-it | +| Parameters | 4 Billion (4B) | +| Architecture | Transformer (Decoder-only) | +| Max Token Length | 128K tokens | +| License | [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) | + +## Prerequisites + +- Python 3.10 or later +- [NeMo Guardrails installed](../../getting-started/installation-guide.md) +- GPU with at least 16GB VRAM (see [Hardware Requirements](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#hardware-and-software-requirements) on HuggingFace) +- vLLM installed: + + ```console + $ pip install vllm + ``` + +- HuggingFace access to the model (accept the license at [HuggingFace](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B)) + +## Deploying the Content Safety Model with vLLM + +Start a vLLM server for the Nemotron-Content-Safety-Reasoning-4B model. See also [Serving with vLLM](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#serving-with-vllm) on HuggingFace for additional options. + +```console +$ python -m vllm.entrypoints.openai.api_server \ + --model nvidia/Nemotron-Content-Safety-Reasoning-4B \ + --port 8001 \ + --max-model-len 4096 +``` + +Verify the server is ready: + +```console +$ curl http://localhost:8001/v1/models | jq '.data[].id' +``` + +## Configuring NeMo Guardrails + +### Step 1: Create Configuration Directory + +Create a configuration directory for your guardrails setup: + +```console +$ mkdir -p config +``` + +### Step 2: Create config.yml + +Save the following as `config/config.yml`: + +```yaml +models: + # Configure your main LLM (OpenAI, NIM, vLLM, etc.) + - type: main + engine: openai + model: gpt-4o-mini + + # Content Safety Model served via vLLM (OpenAI-compatible API) + - type: content_safety_reasoning + engine: openai + model: nvidia/Nemotron-Content-Safety-Reasoning-4B + parameters: + openai_api_base: http://localhost:8001/v1 + temperature: 0.6 + top_p: 0.95 + +rails: + config: + content_safety: + reasoning: + # Set to true for reasoning mode (with traces) + # Set to false for low-latency mode + enabled: false + + input: + flows: + - content safety check input $model=content_safety_reasoning + + output: + flows: + - content safety check output $model=content_safety_reasoning +``` + +You can use any LLM provider for the main model (OpenAI, NIM, Anthropic, etc.). See the [Model Configuration](../../configure-rails/yaml-schema/model-configuration.md) guide for available engines. + +### Step 3: Create prompts.yml + +Save the following as `config/prompts.yml`. This uses the [Recommended Prompt Template](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#recommended-prompt-template-for-vanilla-safety) from HuggingFace: + +```yaml +prompts: + - task: content_safety_check_input $model=content_safety_reasoning + content: | + You are a classifier for analyzing interactions between humans and LLMs. Given a request from a human user and a response from an LLM assistant, determine 1) whether the human user's request is harmful, 2) whether the AI assistant's response is a refusal, and 3) whether the AI assistant's response is harmful. + Taxonomy: The following are the list of harmful categories that the prompt and response may fall into. + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). 
+ S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Manipulation. + S15: Fraud/Deception. + S16: Malware. + S17: High Risk Gov Decision Making. + S18: Political/Misinformation/Conspiracy. + S19: Copyright/Trademark/Plagiarism. + S20: Unauthorized Advice. + S21: Illegal Activity. + S22: Immoral/Unethical. + + Human user: + {{ user_input }} + + AI assistant: + None + + Use the taxonomy to determine if the prompt and response fall into any of the harmful categories. Output your answer in the following format: + + "Prompt harm: + Response harm: " + + Only output the answer in the format specified above after thinking about the request and + response. If there is no response, output "Response harm: " as None. + {% if reasoning_enabled %}/think{% else %}/no_think{% endif %} + + output_parser: nemotron_reasoning_parse_prompt_safety + max_tokens: 400 + + - task: content_safety_check_output $model=content_safety_reasoning + content: | + You are a classifier for analyzing interactions between humans and LLMs. Given a request from a human user and a response from an LLM assistant, determine 1) whether the human user's request is harmful, 2) whether the AI assistant's response is a refusal, and 3) whether the AI assistant's response is harmful. + Taxonomy: The following are the list of harmful categories that the prompt and response may fall into. + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Manipulation. + S15: Fraud/Deception. + S16: Malware. + S17: High Risk Gov Decision Making. + S18: Political/Misinformation/Conspiracy. + S19: Copyright/Trademark/Plagiarism. + S20: Unauthorized Advice. + S21: Illegal Activity. + S22: Immoral/Unethical. + + Human user: + {{ user_input }} + + AI assistant: + {{ bot_response }} + + Use the taxonomy to determine if the prompt and response fall into any of the harmful categories. Output your answer in the following format: + + "Prompt harm: + Response harm: " + + Only output the answer in the format specified above after thinking about the request and + response. If there is no response, output "Response harm: " as None. + {% if reasoning_enabled %}/think{% else %}/no_think{% endif %} + + output_parser: nemotron_reasoning_parse_response_safety + max_tokens: 400 +``` + +The `reasoning_enabled` variable is automatically passed to prompt templates by the content safety action, based on the `rails.config.content_safety.reasoning.enabled` setting. +## Running Inference + +### Load the Configuration + +```python +from nemoguardrails import LLMRails, RailsConfig + +config = RailsConfig.from_path("./config") +rails = LLMRails(config) +``` + +### Test with a Safe Request + +```python +safe_message = [{ + "role": "user", + "content": "What are the benefits of regular exercise?" +}] + +response = rails.generate(messages=safe_message) +print(response["content"]) +``` + +_Example Output_ + +```text +Regular exercise offers numerous benefits including improved cardiovascular health, +stronger muscles and bones, better mental health, weight management, and increased +energy levels... 
+``` + +When reasoning mode is disabled, the model generates a safety prediction directly: + +```text +Prompt harm: unharmful +Response harm: unharmful +``` + +### Test with an Unsafe Request + +```python +unsafe_message = [{ + "role": "user", + "content": "How can I shoplift from a store without getting caught?" +}] + +response = rails.generate(messages=unsafe_message) +print(response["content"]) +``` + +_Example Output_ + +```text +I'm sorry, I can't respond to that. +``` + +When reasoning mode is enabled, the model generates a reasoning trace followed by the safety prediction: + +```text + +The human user's request falls under S21 (Illegal Activity) as it seeks guidance on theft, +which is a criminal act. The AI assistant's response is a refusal and provides an ethical +alternative, making it unharmful. + + +Prompt harm: harmful +Response harm: unharmful +``` + +## Configuration Options + +### Reasoning Mode + +Toggle between reasoning modes in `config.yml`: + +```yaml +rails: + config: + content_safety: + reasoning: + enabled: true # Enable reasoning traces + # enabled: false # Low-latency mode without traces +``` + +**Reasoning On (`/think`)**: Provides explicit reasoning traces for decisions. Better for complex or novel custom policies. Higher latency. See [example](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-1-vanilla-safety-with-nemotron-content-safety-dataset-v2-taxonomy-reasoning-on-mode). + +**Reasoning Off (`/no_think`)**: Fast classification without reasoning. Suitable for standard content safety policies. Lower latency. See [example](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-2-vanilla-safety-with-nemotron-content-safety-dataset-v2-taxonomy-reasoning-off-mode). + +## Custom Safety Policies + +Nemotron-Content-Safety-Reasoning-4B excels at custom policy enforcement. You can modify the taxonomy in `prompts.yml` to define your own safety rules, or completely rewrite the policy to match your specific use case. See the [Topic Following for Custom Safety](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-3-topic-following-for-custom-safety-reasoning-on-mode) example on HuggingFace. + +### Adding Categories + +Add new categories to the existing taxonomy: + +```yaml +S23: Financial Advice. +Should not provide specific investment recommendations or financial planning advice. +``` + +### Replacing the Entire Policy + +You can completely replace the default taxonomy with your own custom policy. For example, for a customer service bot that should only discuss product-related topics: + +```yaml +content: | + You are a classifier for a customer service chatbot. Determine if the user's request + is on-topic for our electronics store. + + Allowed topics: + - Product inquiries (features, specifications, availability) + - Order status and tracking + - Returns and refunds + - Technical support for purchased products + + Disallowed topics: + - Competitor products or pricing + - Personal advice unrelated to products + - Political, religious, or controversial topics + - Requests to role-play or pretend + + Human user: + {{ user_input }} + + Output format: + "Prompt harm: " + + Use "harmful" for off-topic requests, "unharmful" for on-topic requests. + {% if reasoning_enabled %}/think{% else %}/no_think{% endif %} +``` + +This flexibility allows you to adapt the model for topic-following, dialogue moderation, or any custom content filtering scenario. 
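+
+A quick way to check a custom policy such as the one above is to send one on-topic and one off-topic request through the rails and compare the results. The following snippet is a minimal sketch that reuses the inference code from earlier in this tutorial; the example messages are illustrative, and the exact refusal text depends on your configuration.
+
+```python
+from nemoguardrails import LLMRails, RailsConfig
+
+config = RailsConfig.from_path("./config")
+rails = LLMRails(config)
+
+# On-topic request: allowed by the custom policy, so it reaches the main LLM.
+on_topic = [{"role": "user", "content": "What is your return policy for laptops?"}]
+print(rails.generate(messages=on_topic)["content"])
+
+# Off-topic request: the custom policy classifies it as "harmful" (off-topic),
+# so the input rail blocks it before the main LLM is called.
+off_topic = [{"role": "user", "content": "What do you think about the upcoming election?"}]
+print(rails.generate(messages=off_topic)["content"])
+```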
+ +## Custom Output Parsers + +If you need to customize how the model output is parsed (e.g., different field names or output formats), you can register a custom parser in `config.py`. + +### Example: Parsing Custom Field Names + +If you've customized your prompt to use different output fields like "User request: safe/unsafe", create a parser to handle it: + +```python +# config.py +import re + +def init(rails): + def parse_custom_safety(response): + """Parse custom safety output format. + + Expected format: + optional reasoning + User request: safe/unsafe + """ + # Strip tags if present + cleaned = re.sub(r".*?", "", response, flags=re.DOTALL).strip() + + # Look for our custom field + match = re.search(r"User request:\s*(\w+)", cleaned, re.IGNORECASE) + if match: + value = match.group(1).lower() + # Return [True] for safe, [False] for unsafe + return [True] if value == "safe" else [False] + + # Default to safe if parsing fails + return [True] + + rails.register_output_parser(parse_custom_safety, "parse_custom_safety") +``` + +Then reference it in `prompts.yml`: + +```yaml +output_parser: parse_custom_safety +``` + +## Next Steps + +- Explore how to use [custom safety policies](https://huggingface.co/nvidia/Nemotron-Content-Safety-Reasoning-4B#example-3-topic-following-for-custom-safety-reasoning-on-mode) to adapt the model to your specific use case +- Learn about [topic following](../tutorials/nemoguard-topiccontrol-deployment.md) for dialogue moderation +- Read the [paper](https://arxiv.org/abs/2505.20087) that describes how we built Nemotron-Content-Safety-Reasoning-4B: "Safety Through Reasoning: An Empirical Study of Reasoning Guardrail Models" diff --git a/examples/configs/content_safety_reasoning/config.yml b/examples/configs/content_safety_reasoning/config.yml new file mode 100644 index 000000000..968c4496e --- /dev/null +++ b/examples/configs/content_safety_reasoning/config.yml @@ -0,0 +1,28 @@ +models: + # Configure your main LLM (OpenAI, NIM, vLLM, etc.) + - type: main + engine: openai + model: gpt-4o-mini + + # Content Safety Model served via vLLM (OpenAI-compatible API) + - type: content_safety_reasoning + engine: openai + model: nvidia/Nemotron-Content-Safety-Reasoning-4B + parameters: + openai_api_base: http://localhost:8001/v1 + temperature: 0.6 + top_p: 0.95 + +rails: + config: + content_safety: + reasoning: + enabled: false + + input: + flows: + - content safety check input $model=content_safety_reasoning + + output: + flows: + - content safety check output $model=content_safety_reasoning diff --git a/examples/configs/content_safety_reasoning/demo.py b/examples/configs/content_safety_reasoning/demo.py new file mode 100644 index 000000000..204cee30f --- /dev/null +++ b/examples/configs/content_safety_reasoning/demo.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
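+
+# NOTE: This demo assumes the Nemotron-Content-Safety-Reasoning-4B vLLM server is
+# running at http://localhost:8001/v1 (see config.yml in this directory) and that
+# credentials for the main LLM configured in config.yml are available in the
+# environment.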
+ +# start-config +from nemoguardrails import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from nemoguardrails.rails.llm.options import GenerationOptions + +config = RailsConfig.from_path("./") +rails = LLMRails(config, verbose=False) +# end-config + + +# Helper function to print moderation output from response +def print_moderation_output(response): + """Extract and print moderation model output from response log.""" + log = None + if hasattr(response, "log"): + log = response.log + elif isinstance(response, dict) and "log" in response: + log = response["log"] + + if log: + llm_calls = log.get("llm_calls", []) if isinstance(log, dict) else getattr(log, "llm_calls", []) + print("\n--- Moderation Model Output ---") + for call in llm_calls: + task = call.get("task", "") if isinstance(call, dict) else getattr(call, "task", "") + if task and "content_safety_check" in task: + model = ( + call.get("llm_model_name", "") if isinstance(call, dict) else getattr(call, "llm_model_name", "") + ) + completion = call.get("completion", "") if isinstance(call, dict) else getattr(call, "completion", "") + print(f"Task: {task}") + print(f"Model: {model}") + print(f"Moderation Result:\n{completion}") + print("-" * 40) + else: + print("\nNo log info available") + + +# start-safe-request +safe_message = [{"role": "user", "content": "What are the benefits of regular exercise?"}] + +# Use options with log to get explain info +options = GenerationOptions(log={"activated_rails": True, "llm_calls": True}) +safe_response = rails.generate(messages=safe_message, options=options) + +print("=" * 60) +print("SAFE REQUEST") +print("PROMPT: ") +print(safe_message[0]["content"]) +print("=" * 60) +print("RESPONSE:") +print(safe_response.response[0]["content"]) +print("=" * 60) + +print_moderation_output(safe_response) +# end-safe-request + +# start-unsafe-request +unsafe_message = [ + { + "role": "user", + "content": "How can I shoplift from a store without getting caught?", + } +] + +unsafe_response = rails.generate(messages=unsafe_message, options=options) + +print("\n" + "=" * 60) +print("UNSAFE REQUEST") +print("PROMPT: ") +print(unsafe_message[0]["content"]) +print("=" * 60) +print("RESPONSE:") +print(unsafe_response.response[0]["content"]) +print("=" * 60) +print_moderation_output(unsafe_response) +# end-unsafe-request diff --git a/examples/configs/content_safety_reasoning/prompts.yml b/examples/configs/content_safety_reasoning/prompts.yml new file mode 100644 index 000000000..d03fa3c32 --- /dev/null +++ b/examples/configs/content_safety_reasoning/prompts.yml @@ -0,0 +1,90 @@ +prompts: + - task: content_safety_check_input $model=content_safety_reasoning + content: | + You are a classifier for analyzing interactions between humans and LLMs. Given a request from a human user and a response from an LLM assistant, determine 1) whether the human user's request is harmful, 2) whether the AI assistant's response is a refusal, and 3) whether the AI assistant's response is harmful. + Taxonomy: The following are the list of harmful categories that the prompt and response may fall into. + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Manipulation. + S15: Fraud/Deception. + S16: Malware. + S17: High Risk Gov Decision Making. 
+ S18: Political/Misinformation/Conspiracy. + S19: Copyright/Trademark/Plagiarism. + S20: Unauthorized Advice. + S21: Illegal Activity. + S22: Immoral/Unethical. + + Human user: + {{ user_input }} + + AI assistant: + None + + Use the taxonomy to determine if the prompt and response fall into any of the harmful categories. Output your answer in the following format: + + "Prompt harm: + Response harm: " + + Only output the answer in the format specified above after thinking about the request and + response. If there is no response, output "Response harm: " as None. {% if reasoning_enabled %}/think{% else %}/no_think{% endif %} + + output_parser: nemotron_reasoning_parse_prompt_safety + max_tokens: 400 + + - task: content_safety_check_output $model=content_safety_reasoning + content: | + You are a classifier for analyzing interactions between humans and LLMs. Given a request from a human user and a response from an LLM assistant, determine 1) whether the human user's request is harmful, 2) whether the AI assistant's response is a refusal, and 3) whether the AI assistant's response is harmful. + Taxonomy: The following are the list of harmful categories that the prompt and response may fall into. + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Manipulation. + S15: Fraud/Deception. + S16: Malware. + S17: High Risk Gov Decision Making. + S18: Political/Misinformation/Conspiracy. + S19: Copyright/Trademark/Plagiarism. + S20: Unauthorized Advice. + S21: Illegal Activity. + S22: Immoral/Unethical. + + Human user: + {{ user_input }} + + AI assistant: + {{ bot_response }} + + Use the taxonomy to determine if the prompt and response fall into any of the harmful categories. Output your answer in the following format: + + "Prompt harm: + Response harm: " + + Only output the answer in the format specified above after thinking about the request and + response. If there is no response, output "Response harm: " as None. 
{% if reasoning_enabled %}/think{% else %}/no_think{% endif %} + + output_parser: nemotron_reasoning_parse_response_safety + max_tokens: 400 diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index c068a21c2..bd82f6a8a 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -35,6 +35,10 @@ log = logging.getLogger(__name__) +def _get_reasoning_enabled(llm_task_manager: LLMTaskManager) -> bool: + return llm_task_manager.config.rails.config.content_safety.reasoning.enabled + + @action() async def content_safety_check_input( llms: Dict[str, BaseLLM], @@ -74,6 +78,7 @@ async def content_safety_check_input( task=task, context={ "user_input": user_input, + "reasoning_enabled": _get_reasoning_enabled(llm_task_manager), }, ) @@ -177,6 +182,7 @@ async def content_safety_check_output( context={ "user_input": user_input, "bot_response": bot_response, + "reasoning_enabled": _get_reasoning_enabled(llm_task_manager), }, ) diff --git a/nemoguardrails/llm/output_parsers.py b/nemoguardrails/llm/output_parsers.py index 95c633db6..b9c3b2c1e 100644 --- a/nemoguardrails/llm/output_parsers.py +++ b/nemoguardrails/llm/output_parsers.py @@ -213,3 +213,108 @@ def nemoguard_parse_response_safety(response: str) -> Sequence[Union[bool, str]] return [True] else: return [False] + safety_categories + + +def _strip_think_tags(response: str) -> str: + """Helper function to strip ... tags from model response.""" + # Remove ... blocks (including multi-line) + cleaned = re.sub(r".*?", "", response, flags=re.DOTALL) + return cleaned.strip() + + +def _extract_harm_value(response: str, field_name: str) -> str: + """Helper function to extract harmful/unharmful value for a given field. + + Args: + response: The model response text (with think tags already stripped). + field_name: The field to look for (e.g., "Prompt harm" or "Response Harm"). + + Returns: + The extracted value ("harmful" or "unharmful"), or "harmful" if parsing fails. + """ + # Look for the field pattern case-insensitively + pattern = rf"{re.escape(field_name)}\s*:\s*(\w+)" + match = re.search(pattern, response, re.IGNORECASE) + if match: + return match.group(1).lower() + return "harmful" # Default to harmful if parsing fails + + +def nemotron_reasoning_parse_prompt_safety(response: str) -> Sequence[Union[bool, str]]: + """Analyzes a response from Nemotron-Content-Safety-Reasoning-4B for prompt safety. + + The function parses the following expected output format: + + Reasoning Off mode: + Prompt harm: harmful/unharmful + Response Harm: harmful/unharmful + + Reasoning On mode: + + [Model's reasoning trace] + + + Prompt harm: harmful/unharmful + Response Harm: harmful/unharmful + + Args: + response (str): The response string to analyze. + + Returns: + Sequence[Union[bool, str]]: A sequence where the first element is a boolean + indicating the safety of the content (True if unharmful/safe, False if harmful). + """ + # Strip tags if present + cleaned_response = _strip_think_tags(response) + + # Extract the prompt harm value + harm_value = _extract_harm_value(cleaned_response, "Prompt harm") + + # "unharmful" means safe (True), "harmful" means unsafe (False) + is_safe = harm_value == "unharmful" + + if is_safe: + return [True] + else: + return [False] + + +def nemotron_reasoning_parse_response_safety( + response: str, +) -> Sequence[Union[bool, str]]: + """Analyzes a response from Nemotron-Content-Safety-Reasoning-4B for response safety. 
+ + The function parses the following expected output format: + + Reasoning Off mode: + Prompt harm: harmful/unharmful + Response Harm: harmful/unharmful + + Reasoning On mode: + + [Model's reasoning trace] + + + Prompt harm: harmful/unharmful + Response Harm: harmful/unharmful + + Args: + response (str): The response string to analyze. + + Returns: + Sequence[Union[bool, str]]: A sequence where the first element is a boolean + indicating the safety of the content (True if unharmful/safe, False if harmful). + """ + # Strip tags if present + cleaned_response = _strip_think_tags(response) + + # Extract the response harm value + harm_value = _extract_harm_value(cleaned_response, "Response Harm") + + # "unharmful" means safe (True), "harmful" means unsafe (False) + is_safe = harm_value == "unharmful" + + if is_safe: + return [True] + else: + return [False] diff --git a/nemoguardrails/llm/providers/huggingface/streamers.py b/nemoguardrails/llm/providers/huggingface/streamers.py index aba0e3b9d..13dc1bbb1 100644 --- a/nemoguardrails/llm/providers/huggingface/streamers.py +++ b/nemoguardrails/llm/providers/huggingface/streamers.py @@ -18,8 +18,8 @@ TRANSFORMERS_AVAILABLE = True try: - from transformers.generation.streamers import ( # type: ignore[import-untyped] - TextStreamer, + from transformers.generation.streamers import ( # type: ignore + TextStreamer, # type: ignore ) except ImportError: # Fallback if transformers is not available diff --git a/nemoguardrails/llm/taskmanager.py b/nemoguardrails/llm/taskmanager.py index 01b5c7622..c76436f5b 100644 --- a/nemoguardrails/llm/taskmanager.py +++ b/nemoguardrails/llm/taskmanager.py @@ -43,6 +43,8 @@ is_content_safe, nemoguard_parse_prompt_safety, nemoguard_parse_response_safety, + nemotron_reasoning_parse_prompt_safety, + nemotron_reasoning_parse_response_safety, user_intent_parser, verbose_v1_parser, ) @@ -84,6 +86,8 @@ def __init__(self, config: RailsConfig): "is_content_safe": is_content_safe, "nemoguard_parse_prompt_safety": nemoguard_parse_prompt_safety, "nemoguard_parse_response_safety": nemoguard_parse_response_safety, + "nemotron_reasoning_parse_prompt_safety": nemotron_reasoning_parse_prompt_safety, + "nemotron_reasoning_parse_response_safety": nemotron_reasoning_parse_response_safety, } # The prompt context will hold additional variables that ce also be included diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 698372470..87a9b26a8 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -954,6 +954,16 @@ class MultilingualConfig(BaseModel): ) +class ReasoningConfig(BaseModel): + """Configuration for reasoning mode in content safety models.""" + + enabled: bool = Field( + default=False, + description="If True, enable reasoning mode (with traces) for content safety models. 
" + "If False, use low-latency mode without reasoning traces.", + ) + + class ContentSafetyConfig(BaseModel): """Configuration data for content safety rails.""" @@ -962,6 +972,11 @@ class ContentSafetyConfig(BaseModel): description="Configuration for multilingual refusal messages.", ) + reasoning: ReasoningConfig = Field( + default_factory=ReasoningConfig, + description="Configuration for reasoning mode in content safety models.", + ) + class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" diff --git a/tests/test_content_safety_integration.py b/tests/test_content_safety_integration.py index fac4d616d..76b86ee84 100644 --- a/tests/test_content_safety_integration.py +++ b/tests/test_content_safety_integration.py @@ -19,10 +19,12 @@ works correctly with the actual content safety actions and their iterable unpacking logic. """ +import textwrap from unittest.mock import MagicMock import pytest +from nemoguardrails import RailsConfig from nemoguardrails.library.content_safety.actions import ( content_safety_check_input, content_safety_check_output, @@ -31,8 +33,10 @@ is_content_safe, nemoguard_parse_prompt_safety, nemoguard_parse_response_safety, + nemotron_reasoning_parse_prompt_safety, + nemotron_reasoning_parse_response_safety, ) -from tests.utils import FakeLLM +from tests.utils import FakeLLM, TestChat def _create_mock_setup(llm_responses, parsed_result): @@ -284,3 +288,141 @@ def test_backward_compatibility_check(self): is_safe, *violated_policies = result assert is_safe is False assert violated_policies == ["S1", "S8"] + + +class TestReasoningEnabledEndToEnd: + """End-to-end tests using TestChat and rails.explain() to verify prompt rendering.""" + + @pytest.mark.parametrize( + "reasoning_enabled,expected_token,is_harmful,safety_response,expected_response", + [ + (True, "/think", False, "Prompt harm: unharmful", "Hello! How can I help you?"), + (False, "/no_think", False, "Prompt harm: unharmful", "Hello! How can I help you?"), + (True, "/think", True, "Prompt harm: harmful", "I'm sorry, I can't respond to that."), + (False, "/no_think", True, "Prompt harm: harmful", "I'm sorry, I can't respond to that."), + ], + ) + def test_content_safety_input_renders_reasoning_token( + self, reasoning_enabled, expected_token, is_harmful, safety_response, expected_response + ): + """Test content_safety_check_input renders correct token and handles safe/harmful input.""" + content_safety_llm = FakeLLM(responses=[safety_response]) + + config = RailsConfig.from_content( + yaml_content=textwrap.dedent( + f""" + models: + - type: main + engine: openai + model: gpt-4o-mini + - type: content_safety_reasoning + engine: openai + model: placeholder + + rails: + input: + flows: + - content safety check input $model=content_safety_reasoning + config: + content_safety: + reasoning: + enabled: {str(reasoning_enabled)} + + prompts: + - task: content_safety_check_input $model=content_safety_reasoning + content: | + Check: {{{{ user_input }}}}. {{% if reasoning_enabled %}}/think{{% else %}}/no_think{{% endif %}} + output_parser: nemotron_reasoning_parse_prompt_safety + """ + ), + colang_content=textwrap.dedent( + """ + define flow content safety check input $model + $result = execute content_safety_check_input(model_name=$model) + """ + ), + ) + + config.models = [m for m in config.models if m.type == "main"] + + chat = TestChat( + config, + llm_completions=["Hello! 
How can I help you?"], + ) + + chat.app.runtime.registered_action_params["llms"] = {"content_safety_reasoning": content_safety_llm} + + user_message = "Hello" + response = chat.app.generate(messages=[{"role": "user", "content": user_message}]) + info = chat.app.explain() + + assert len(info.llm_calls) >= 1 + content_safety_prompt = info.llm_calls[0].prompt + assert expected_token in content_safety_prompt + assert user_message in content_safety_prompt + + assert content_safety_llm.i == 1 + + assert response is not None + assert response["content"] == expected_response + + if is_harmful: + assert len(info.llm_calls) == 1 + else: + assert len(info.llm_calls) >= 2 + + +class TestNemotronReasoningParserIntegration: + """Integration tests for Nemotron reasoning parsers with content safety actions.""" + + @pytest.mark.parametrize( + "response,expected_allowed", + [ + ("Prompt harm: unharmful\nResponse Harm: unharmful", True), + ("Prompt harm: harmful\nResponse Harm: unharmful", False), + ("reasoning\nPrompt harm: unharmful", True), + ("reasoning\nPrompt harm: harmful", False), + ], + ) + @pytest.mark.asyncio + async def test_content_safety_input_with_nemotron_reasoning_parser(self, response, expected_allowed): + """Test input action with nemotron_reasoning_parse_prompt_safety parser.""" + parsed_result = nemotron_reasoning_parse_prompt_safety(response) + llms, mock_task_manager = _create_mock_setup([response], parsed_result) + context = _create_input_context() + + result = await content_safety_check_input( + llms=llms, + llm_task_manager=mock_task_manager, + model_name="test_model", + context=context, + ) + + assert result["allowed"] is expected_allowed + assert result["policy_violations"] == [] + + @pytest.mark.parametrize( + "response,expected_allowed", + [ + ("Prompt harm: unharmful\nResponse Harm: unharmful", True), + ("Prompt harm: unharmful\nResponse Harm: harmful", False), + ("reasoning\nResponse Harm: unharmful", True), + ("reasoning\nResponse Harm: harmful", False), + ], + ) + @pytest.mark.asyncio + async def test_content_safety_output_with_nemotron_reasoning_parser(self, response, expected_allowed): + """Test output action with nemotron_reasoning_parse_response_safety parser.""" + parsed_result = nemotron_reasoning_parse_response_safety(response) + llms, mock_task_manager = _create_mock_setup([response], parsed_result) + context = _create_output_context() + + result = await content_safety_check_output( + llms=llms, + llm_task_manager=mock_task_manager, + model_name="test_model", + context=context, + ) + + assert result["allowed"] is expected_allowed + assert result["policy_violations"] == [] diff --git a/tests/test_content_safety_output_parsers.py b/tests/test_content_safety_output_parsers.py index 364c29c04..7b44b54fc 100644 --- a/tests/test_content_safety_output_parsers.py +++ b/tests/test_content_safety_output_parsers.py @@ -15,9 +15,13 @@ from nemoguardrails.llm.output_parsers import ( + _extract_harm_value, + _strip_think_tags, is_content_safe, nemoguard_parse_prompt_safety, nemoguard_parse_response_safety, + nemotron_reasoning_parse_prompt_safety, + nemotron_reasoning_parse_response_safety, ) @@ -362,3 +366,277 @@ def test_starred_unpacking_compatibility(self): assert len(violated_policies) > 0 assert "S1" in violated_policies assert "S8" in violated_policies + + +class TestStripThinkTags: + """Test the _strip_think_tags helper function.""" + + def test_no_think_tags(self): + """Test input without think tags returns unchanged.""" + response = "Prompt harm: unharmful\nResponse 
Harm: unharmful" + result = _strip_think_tags(response) + assert result == response + + def test_single_line_think_tags(self): + """Test stripping single-line think tags.""" + response = "some reasoning\nPrompt harm: harmful" + result = _strip_think_tags(response) + assert result == "Prompt harm: harmful" + + def test_multiline_think_tags(self): + """Test stripping multi-line think tags.""" + response = """ +The user's request falls under S21 (Illegal Activity). +This is clearly harmful content. + + +Prompt harm: harmful +Response Harm: unharmful""" + result = _strip_think_tags(response) + assert "" not in result + assert "" not in result + assert "Prompt harm: harmful" in result + assert "Response Harm: unharmful" in result + + def test_empty_think_tags(self): + """Test stripping empty think tags.""" + response = "Prompt harm: unharmful" + result = _strip_think_tags(response) + assert result == "Prompt harm: unharmful" + + def test_whitespace_handling(self): + """Test that result is stripped of leading/trailing whitespace.""" + response = " reasoning \n Prompt harm: unharmful " + result = _strip_think_tags(response) + assert result == "Prompt harm: unharmful" + + +class TestExtractHarmValue: + """Test the _extract_harm_value helper function.""" + + def test_extract_harmful(self): + """Test extracting harmful value.""" + response = "Prompt harm: harmful" + result = _extract_harm_value(response, "Prompt harm") + assert result == "harmful" + + def test_extract_unharmful(self): + """Test extracting unharmful value.""" + response = "Prompt harm: unharmful" + result = _extract_harm_value(response, "Prompt harm") + assert result == "unharmful" + + def test_case_insensitive_field(self): + """Test case insensitive field matching.""" + response = "PROMPT HARM: unharmful" + result = _extract_harm_value(response, "Prompt harm") + assert result == "unharmful" + + def test_case_insensitive_value(self): + """Test value is lowercased.""" + response = "Prompt harm: UNHARMFUL" + result = _extract_harm_value(response, "Prompt harm") + assert result == "unharmful" + + def test_response_harm_field(self): + """Test extracting Response Harm field.""" + response = "Response Harm: harmful" + result = _extract_harm_value(response, "Response Harm") + assert result == "harmful" + + def test_missing_field_defaults_to_harmful(self): + """Test missing field defaults to harmful.""" + response = "Some other text without the field" + result = _extract_harm_value(response, "Prompt harm") + assert result == "harmful" + + def test_extra_whitespace_around_colon(self): + """Test handling of extra whitespace around colon.""" + response = "Prompt harm : unharmful" + result = _extract_harm_value(response, "Prompt harm") + assert result == "unharmful" + + def test_multiline_response(self): + """Test extracting from multiline response.""" + response = "Prompt harm: harmful\nResponse Harm: unharmful" + result = _extract_harm_value(response, "Response Harm") + assert result == "unharmful" + + +class TestNemotronReasoningParsePromptSafety: + """Test the nemotron_reasoning_parse_prompt_safety output parser.""" + + def test_unharmful_prompt_no_reasoning(self): + """Test parsing unharmful prompt without reasoning.""" + response = "Prompt harm: unharmful\nResponse Harm: unharmful" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is True + assert violated_policies == [] + + def test_harmful_prompt_no_reasoning(self): + """Test parsing harmful prompt without reasoning.""" + response = "Prompt 
harm: harmful\nResponse Harm: unharmful" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + assert violated_policies == [] + + def test_unharmful_prompt_with_reasoning(self): + """Test parsing unharmful prompt with reasoning tags.""" + response = """ +The user is asking about exercise benefits, which is a safe topic. + + +Prompt harm: unharmful +Response Harm: unharmful""" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is True + assert violated_policies == [] + + def test_harmful_prompt_with_reasoning(self): + """Test parsing harmful prompt with reasoning tags.""" + response = """ +The user's request falls under S21 (Illegal Activity) as it seeks guidance +on shoplifting, which is a criminal act. + + +Prompt harm: harmful +Response Harm: unharmful""" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + assert violated_policies == [] + + def test_case_insensitive_parsing(self): + """Test parsing is case insensitive.""" + response = "PROMPT HARM: UNHARMFUL" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is True + + def test_missing_field_defaults_to_unsafe(self): + """Test missing field defaults to unsafe (harmful).""" + response = "Response Harm: unharmful" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + + def test_empty_response_defaults_to_unsafe(self): + """Test empty response defaults to unsafe.""" + response = "" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + + def test_malformed_response_defaults_to_unsafe(self): + """Test malformed response defaults to unsafe.""" + response = "The model refused to answer" + is_safe, *violated_policies = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + + +class TestNemotronReasoningParseResponseSafety: + """Test the nemotron_reasoning_parse_response_safety output parser.""" + + def test_unharmful_response_no_reasoning(self): + """Test parsing unharmful response without reasoning.""" + response = "Prompt harm: unharmful\nResponse Harm: unharmful" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is True + assert violated_policies == [] + + def test_harmful_response_no_reasoning(self): + """Test parsing harmful response without reasoning.""" + response = "Prompt harm: unharmful\nResponse Harm: harmful" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is False + assert violated_policies == [] + + def test_unharmful_response_with_reasoning(self): + """Test parsing unharmful response with reasoning tags.""" + response = """ +The AI assistant's response provides helpful exercise information. +This is safe content. + + +Prompt harm: unharmful +Response Harm: unharmful""" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is True + assert violated_policies == [] + + def test_harmful_response_with_reasoning(self): + """Test parsing harmful response with reasoning tags.""" + response = """ +The AI assistant's response contains violent content that falls under S1. 
+ + +Prompt harm: unharmful +Response Harm: harmful""" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is False + assert violated_policies == [] + + def test_case_insensitive_parsing(self): + """Test parsing is case insensitive.""" + response = "RESPONSE HARM: UNHARMFUL" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is True + + def test_missing_field_defaults_to_unsafe(self): + """Test missing field defaults to unsafe (harmful).""" + response = "Prompt harm: unharmful" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is False + + def test_empty_response_defaults_to_unsafe(self): + """Test empty response defaults to unsafe.""" + response = "" + is_safe, *violated_policies = nemotron_reasoning_parse_response_safety(response) + assert is_safe is False + + def test_starred_unpacking_compatibility(self): + """Test parser output is compatible with starred unpacking.""" + response = "Prompt harm: unharmful\nResponse Harm: unharmful" + result = nemotron_reasoning_parse_response_safety(response) + is_safe, *violated_policies = result + assert is_safe is True + assert violated_policies == [] + + response = "Response Harm: harmful" + result = nemotron_reasoning_parse_response_safety(response) + is_safe, *violated_policies = result + assert is_safe is False + assert violated_policies == [] + + +class TestNemotronReasoningRealWorldScenarios: + """Test nemotron reasoning parsers with real-world scenarios.""" + + def test_typical_safe_request_response(self): + """Test typical safe request with reasoning mode off.""" + response = """Prompt harm: unharmful +Response harm: unharmful""" + is_safe, *_ = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is True + + is_safe, *_ = nemotron_reasoning_parse_response_safety(response) + assert is_safe is True + + def test_typical_unsafe_request_with_reasoning(self): + """Test typical unsafe request with full reasoning trace.""" + response = """ +The human user's request falls under S21 (Illegal Activity) as it seeks guidance on theft, +which is a criminal act. The AI assistant's response is a refusal and provides an ethical +alternative, making it unharmful. + + +Prompt harm: harmful +Response harm: unharmful""" + + is_safe, *_ = nemotron_reasoning_parse_prompt_safety(response) + assert is_safe is False + + is_safe, *_ = nemotron_reasoning_parse_response_safety(response) + assert is_safe is True + + def test_response_harm_lowercase(self): + """Test parsing 'Response harm' (lowercase h) which is used in prompts.""" + response = "Prompt harm: unharmful\nResponse harm: harmful" + is_safe, *_ = nemotron_reasoning_parse_response_safety(response) + assert is_safe is False diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py index 1f3284066..fd55accf3 100644 --- a/tests/test_rails_config.py +++ b/tests/test_rails_config.py @@ -1037,6 +1037,7 @@ def test_defaults(self): config = ContentSafetyConfig() assert config.multilingual.enabled is False assert config.multilingual.refusal_messages is None + assert config.reasoning.enabled is False def test_with_multilingual(self): custom = {"en": "Custom"}