Skip to content

Commit 949e422

Browse files
authored
feat(llmrails): support method chaining by returning self from LLMRails.register_* methods (#1296)
1 parent 52ac7ed commit 949e422

File tree

2 files changed

+112
-7
lines changed

2 files changed

+112
-7
lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from langchain_core.language_models import BaseChatModel
3030
from langchain_core.language_models.llms import BaseLLM
31+
from typing_extensions import Self
3132

3233
from nemoguardrails.actions.llm.generation import LLMGenerationActions
3334
from nemoguardrails.actions.llm.utils import (
@@ -1407,33 +1408,38 @@ def process_events(
14071408
self.process_events_async(events, state, blocking)
14081409
)
14091410

1410-
def register_action(self, action: callable, name: Optional[str] = None):
1411+
def register_action(self, action: callable, name: Optional[str] = None) -> Self:
14111412
"""Register a custom action for the rails configuration."""
14121413
self.runtime.register_action(action, name)
1414+
return self
14131415

1414-
def register_action_param(self, name: str, value: Any):
1416+
def register_action_param(self, name: str, value: Any) -> Self:
14151417
"""Registers a custom action parameter."""
14161418
self.runtime.register_action_param(name, value)
1419+
return self
14171420

1418-
def register_filter(self, filter_fn: callable, name: Optional[str] = None):
1421+
def register_filter(self, filter_fn: callable, name: Optional[str] = None) -> Self:
14191422
"""Register a custom filter for the rails configuration."""
14201423
self.runtime.llm_task_manager.register_filter(filter_fn, name)
1424+
return self
14211425

1422-
def register_output_parser(self, output_parser: callable, name: str):
1426+
def register_output_parser(self, output_parser: callable, name: str) -> Self:
14231427
"""Register a custom output parser for the rails configuration."""
14241428
self.runtime.llm_task_manager.register_output_parser(output_parser, name)
1429+
return self
14251430

1426-
def register_prompt_context(self, name: str, value_or_fn: Any):
1431+
def register_prompt_context(self, name: str, value_or_fn: Any) -> Self:
14271432
"""Register a value to be included in the prompt context.
14281433
14291434
:name: The name of the variable or function that will be used.
14301435
:value_or_fn: The value or function that will be used to generate the value.
14311436
"""
14321437
self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn)
1438+
return self
14331439

14341440
def register_embedding_search_provider(
14351441
self, name: str, cls: Type[EmbeddingsIndex]
1436-
) -> None:
1442+
) -> Self:
14371443
"""Register a new embedding search provider.
14381444
14391445
Args:
@@ -1442,10 +1448,11 @@ def register_embedding_search_provider(
14421448
"""
14431449

14441450
self.embedding_search_providers[name] = cls
1451+
return self
14451452

14461453
def register_embedding_provider(
14471454
self, cls: Type[EmbeddingModel], name: Optional[str] = None
1448-
) -> None:
1455+
) -> Self:
14491456
"""Register a custom embedding provider.
14501457
14511458
Args:
@@ -1457,6 +1464,7 @@ def register_embedding_provider(
14571464
ValueError: If the model does not have 'encode' or 'encode_async' methods.
14581465
"""
14591466
register_embedding_provider(engine_name=name, model=cls)
1467+
return self
14601468

14611469
def explain(self) -> ExplainInfo:
14621470
"""Helper function to return the latest ExplainInfo object."""

tests/test_llmrails.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,3 +1155,100 @@ async def test_stream_usage_enabled_for_all_providers_when_streaming(
11551155

11561156
# stream_usage should be set for all providers when streaming is enabled
11571157
assert kwargs.get("stream_usage") is True
1158+
1159+
1160+
# Add this test after the existing tests, around line 1100+
1161+
1162+
1163+
def test_register_methods_return_self():
1164+
"""Test that all register_* methods return self for method chaining."""
1165+
config = RailsConfig.from_content(config={"models": []})
1166+
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))
1167+
1168+
# Test register_action returns self
1169+
def dummy_action():
1170+
pass
1171+
1172+
result = rails.register_action(dummy_action, "test_action")
1173+
assert result is rails, "register_action should return self"
1174+
1175+
# Test register_action_param returns self
1176+
result = rails.register_action_param("test_param", "test_value")
1177+
assert result is rails, "register_action_param should return self"
1178+
1179+
# Test register_filter returns self
1180+
def dummy_filter(text):
1181+
return text
1182+
1183+
result = rails.register_filter(dummy_filter, "test_filter")
1184+
assert result is rails, "register_filter should return self"
1185+
1186+
# Test register_output_parser returns self
1187+
def dummy_parser(text):
1188+
return text
1189+
1190+
result = rails.register_output_parser(dummy_parser, "test_parser")
1191+
assert result is rails, "register_output_parser should return self"
1192+
1193+
# Test register_prompt_context returns self
1194+
result = rails.register_prompt_context("test_context", "test_value")
1195+
assert result is rails, "register_prompt_context should return self"
1196+
1197+
# Test register_embedding_search_provider returns self
1198+
from nemoguardrails.embeddings.index import EmbeddingsIndex
1199+
1200+
class DummyEmbeddingProvider(EmbeddingsIndex):
1201+
def __init__(self, **kwargs):
1202+
pass
1203+
1204+
def build(self):
1205+
pass
1206+
1207+
def search(self, text, max_results=5):
1208+
return []
1209+
1210+
result = rails.register_embedding_search_provider(
1211+
"dummy_provider", DummyEmbeddingProvider
1212+
)
1213+
assert result is rails, "register_embedding_search_provider should return self"
1214+
1215+
# Test register_embedding_provider returns self
1216+
from nemoguardrails.embeddings.providers.base import EmbeddingModel
1217+
1218+
class DummyEmbeddingModel(EmbeddingModel):
1219+
def encode(self, texts):
1220+
return []
1221+
1222+
result = rails.register_embedding_provider(DummyEmbeddingModel, "dummy_embedding")
1223+
assert result is rails, "register_embedding_provider should return self"
1224+
1225+
1226+
def test_method_chaining():
1227+
"""Test that method chaining works correctly with register_* methods."""
1228+
config = RailsConfig.from_content(config={"models": []})
1229+
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))
1230+
1231+
def dummy_action():
1232+
return "action_result"
1233+
1234+
def dummy_filter(text):
1235+
return text.upper()
1236+
1237+
def dummy_parser(text):
1238+
return {"parsed": text}
1239+
1240+
# Test chaining multiple register methods
1241+
result = (
1242+
rails.register_action(dummy_action, "chained_action")
1243+
.register_action_param("chained_param", "param_value")
1244+
.register_filter(dummy_filter, "chained_filter")
1245+
.register_output_parser(dummy_parser, "chained_parser")
1246+
.register_prompt_context("chained_context", "context_value")
1247+
)
1248+
1249+
assert result is rails, "Method chaining should return the same rails instance"
1250+
1251+
# Verify that all registrations actually worked
1252+
assert "chained_action" in rails.runtime.action_dispatcher.registered_actions
1253+
assert "chained_param" in rails.runtime.registered_action_params
1254+
assert rails.runtime.registered_action_params["chained_param"] == "param_value"

0 commit comments

Comments
 (0)