Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions nemoguardrails/llm/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Optional[Mode
# Fetch current task parameters like name, models to use, and the prompting mode
task_name = str(task.value) if isinstance(task, Task) else task

# Check if the task name contains a model specification (e.g., "content_safety_check_input $model=content_safety")
if "$model=" in task_name:
# Extract the model type from the task name
model_type = task_name.split("$model=")[-1].strip()
# Look for a model with this specific type
if config.models:
_models = [model for model in config.models if model.type == model_type]
if _models:
return _models[0]

# If no model specification or no matching model found, fall back to the original logic
if config.models:
_models = [model for model in config.models if model.type == task_name]
if not _models:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_llm_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,37 @@ def test_get_task_model_fallback_to_main():
result = get_task_model(config, "some_other_task")
assert result is not None
assert result.type == "main"


def test_get_task_model_with_model_specification():
    """Verify get_task_model resolves task names carrying a $model= suffix.

    A task name such as "content_safety_check_input $model=content_safety"
    selects the configured model whose type matches the suffix; when no model
    of that type exists, resolution falls back to the "main" model.
    """
    rails_config = RailsConfig.parse_object(
        {
            "models": [
                {"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"},
                {"type": "content_safety", "engine": "openai", "model": "gpt-4"},
            ]
        }
    )

    # A $model= suffix selects the model with the matching type.
    selected = get_task_model(
        rails_config, "content_safety_check_input $model=content_safety"
    )
    assert selected is not None
    assert (selected.type, selected.engine, selected.model) == (
        "content_safety",
        "openai",
        "gpt-4",
    )

    # An unknown model type in the suffix falls back to the main model.
    fallback = get_task_model(rails_config, "unknown_task $model=nonexistent")
    assert fallback is not None
    assert (fallback.type, fallback.engine, fallback.model) == (
        "main",
        "openai",
        "gpt-3.5-turbo",
    )