@@ -1155,3 +1155,88 @@ async def test_stream_usage_enabled_for_all_providers_when_streaming(
1155
1155
1156
1156
# stream_usage should be set for all providers when streaming is enabled
1157
1157
assert kwargs .get ("stream_usage" ) is True
1158
+
1159
+
1160
+ def test_register_methods_return_self ():
1161
+ """Test that all register_* methods return self for method chaining."""
1162
+ config = RailsConfig .from_content (config = {"models" : []})
1163
+ rails = LLMRails (config = config , llm = FakeLLM (responses = []))
1164
+
1165
+ def dummy_action ():
1166
+ pass
1167
+
1168
+ result = rails .register_action (dummy_action , "test_action" )
1169
+ assert result is rails , "register_action should return self"
1170
+
1171
+ result = rails .register_action_param ("test_param" , "test_value" )
1172
+ assert result is rails , "register_action_param should return self"
1173
+
1174
+ def dummy_filter (text ):
1175
+ return text
1176
+
1177
+ result = rails .register_filter (dummy_filter , "test_filter" )
1178
+ assert result is rails , "register_filter should return self"
1179
+
1180
+ def dummy_parser (text ):
1181
+ return text
1182
+
1183
+ result = rails .register_output_parser (dummy_parser , "test_parser" )
1184
+ assert result is rails , "register_output_parser should return self"
1185
+
1186
+ result = rails .register_prompt_context ("test_context" , "test_value" )
1187
+ assert result is rails , "register_prompt_context should return self"
1188
+
1189
+ from nemoguardrails .embeddings .index import EmbeddingsIndex
1190
+
1191
+ class DummyEmbeddingProvider (EmbeddingsIndex ):
1192
+ def __init__ (self , ** kwargs ):
1193
+ pass
1194
+
1195
+ def build (self ):
1196
+ pass
1197
+
1198
+ def search (self , text , max_results = 5 ):
1199
+ return []
1200
+
1201
+ result = rails .register_embedding_search_provider (
1202
+ "dummy_provider" , DummyEmbeddingProvider
1203
+ )
1204
+ assert result is rails , "register_embedding_search_provider should return self"
1205
+
1206
+ from nemoguardrails .embeddings .providers .base import EmbeddingModel
1207
+
1208
+ class DummyEmbeddingModel (EmbeddingModel ):
1209
+ def encode (self , texts ):
1210
+ return []
1211
+
1212
+ result = rails .register_embedding_provider (DummyEmbeddingModel , "dummy_embedding" )
1213
+ assert result is rails , "register_embedding_provider should return self"
1214
+
1215
+
1216
+ def test_method_chaining ():
1217
+ """Test that method chaining works correctly with register_* methods."""
1218
+ config = RailsConfig .from_content (config = {"models" : []})
1219
+ rails = LLMRails (config = config , llm = FakeLLM (responses = []))
1220
+
1221
+ def dummy_action ():
1222
+ return "action_result"
1223
+
1224
+ def dummy_filter (text ):
1225
+ return text .upper ()
1226
+
1227
+ def dummy_parser (text ):
1228
+ return {"parsed" : text }
1229
+
1230
+ result = (
1231
+ rails .register_action (dummy_action , "chained_action" )
1232
+ .register_action_param ("chained_param" , "param_value" )
1233
+ .register_filter (dummy_filter , "chained_filter" )
1234
+ .register_output_parser (dummy_parser , "chained_parser" )
1235
+ .register_prompt_context ("chained_context" , "context_value" )
1236
+ )
1237
+
1238
+ assert result is rails , "Method chaining should return the same rails instance"
1239
+
1240
+ assert "chained_action" in rails .runtime .action_dispatcher .registered_actions
1241
+ assert "chained_param" in rails .runtime .registered_action_params
1242
+ assert rails .runtime .registered_action_params ["chained_param" ] == "param_value"
0 commit comments