Skip to content

Commit dde1553

Browse files
committed
Add tests asserting that register methods return self
1 parent 01088f0 commit dde1553

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

tests/test_llmrails.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,3 +1155,88 @@ async def test_stream_usage_enabled_for_all_providers_when_streaming(
11551155

11561156
# stream_usage should be set for all providers when streaming is enabled
11571157
assert kwargs.get("stream_usage") is True
1158+
1159+
1160+
def test_register_methods_return_self():
1161+
"""Test that all register_* methods return self for method chaining."""
1162+
config = RailsConfig.from_content(config={"models": []})
1163+
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))
1164+
1165+
def dummy_action():
1166+
pass
1167+
1168+
result = rails.register_action(dummy_action, "test_action")
1169+
assert result is rails, "register_action should return self"
1170+
1171+
result = rails.register_action_param("test_param", "test_value")
1172+
assert result is rails, "register_action_param should return self"
1173+
1174+
def dummy_filter(text):
1175+
return text
1176+
1177+
result = rails.register_filter(dummy_filter, "test_filter")
1178+
assert result is rails, "register_filter should return self"
1179+
1180+
def dummy_parser(text):
1181+
return text
1182+
1183+
result = rails.register_output_parser(dummy_parser, "test_parser")
1184+
assert result is rails, "register_output_parser should return self"
1185+
1186+
result = rails.register_prompt_context("test_context", "test_value")
1187+
assert result is rails, "register_prompt_context should return self"
1188+
1189+
from nemoguardrails.embeddings.index import EmbeddingsIndex
1190+
1191+
class DummyEmbeddingProvider(EmbeddingsIndex):
1192+
def __init__(self, **kwargs):
1193+
pass
1194+
1195+
def build(self):
1196+
pass
1197+
1198+
def search(self, text, max_results=5):
1199+
return []
1200+
1201+
result = rails.register_embedding_search_provider(
1202+
"dummy_provider", DummyEmbeddingProvider
1203+
)
1204+
assert result is rails, "register_embedding_search_provider should return self"
1205+
1206+
from nemoguardrails.embeddings.providers.base import EmbeddingModel
1207+
1208+
class DummyEmbeddingModel(EmbeddingModel):
1209+
def encode(self, texts):
1210+
return []
1211+
1212+
result = rails.register_embedding_provider(DummyEmbeddingModel, "dummy_embedding")
1213+
assert result is rails, "register_embedding_provider should return self"
1214+
1215+
1216+
def test_method_chaining():
1217+
"""Test that method chaining works correctly with register_* methods."""
1218+
config = RailsConfig.from_content(config={"models": []})
1219+
rails = LLMRails(config=config, llm=FakeLLM(responses=[]))
1220+
1221+
def dummy_action():
1222+
return "action_result"
1223+
1224+
def dummy_filter(text):
1225+
return text.upper()
1226+
1227+
def dummy_parser(text):
1228+
return {"parsed": text}
1229+
1230+
result = (
1231+
rails.register_action(dummy_action, "chained_action")
1232+
.register_action_param("chained_param", "param_value")
1233+
.register_filter(dummy_filter, "chained_filter")
1234+
.register_output_parser(dummy_parser, "chained_parser")
1235+
.register_prompt_context("chained_context", "context_value")
1236+
)
1237+
1238+
assert result is rails, "Method chaining should return the same rails instance"
1239+
1240+
assert "chained_action" in rails.runtime.action_dispatcher.registered_actions
1241+
assert "chained_param" in rails.runtime.registered_action_params
1242+
assert rails.runtime.registered_action_params["chained_param"] == "param_value"

0 commit comments

Comments
 (0)