diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ae537c1df4..53f66ce384 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -365,10 +365,6 @@ public T createPayload(String action, Map parameters) { jsonObject.addProperty("stream", true); payload = jsonObject.toString(); } - // Log payload for debugging - - log.info("=== PAYLOAD DEBUG === Action: {} | Payload: {}", action, payload); - return (T) payload; } return (T) parameters.get("http_body"); diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java index 811449002b..c4bf694f03 100644 --- a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java @@ -65,7 +65,7 @@ public class ContextManagerContext { * Additional parameters for context processing */ @Builder.Default - private Map parameters = new HashMap<>(); + private Map parameters = new HashMap<>(); /** * Get the total token count for the current context. 
@@ -174,7 +174,7 @@ public Object getParameter(String key) { * @param key the parameter key * @param value the parameter value */ - public void setParameter(String key, Object value) { + public void setParameter(String key, String value) { if (parameters == null) { parameters = new HashMap<>(); } diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index c7e29af391..c92ca43846 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -110,7 +110,12 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE tenantId = parser.textOrNull(); break; case PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); + Map parameterObjs = parser.map(); + Map parameters = StringUtils.getParameterMap(parameterObjs); + // Extract context_management from parameters + if (parameterObjs.containsKey("context_management")) { + contextManagementName = (String) parameterObjs.get("context_management"); + } inputDataset = new RemoteInferenceInputDataSet(parameters); break; case ASYNC_FIELD: diff --git a/common/src/test/java/org/opensearch/ml/common/contextmanager/ToolsOutputTruncateManagerTest.java b/common/src/test/java/org/opensearch/ml/common/contextmanager/ToolsOutputTruncateManagerTest.java deleted file mode 100644 index 1c02aa4aa6..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/contextmanager/ToolsOutputTruncateManagerTest.java +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.contextmanager; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -/** - * 
Unit tests for ToolsOutputTruncateManager. - */ -public class ToolsOutputTruncateManagerTest { - - private ToolsOutputTruncateManager manager; - private ContextManagerContext context; - - @Before - public void setUp() { - manager = new ToolsOutputTruncateManager(); - context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).build(); - } - - @Test - public void testGetType() { - Assert.assertEquals("ToolsOutputTruncateManager", manager.getType()); - } - - @Test - public void testInitializeWithDefaults() { - Map config = new HashMap<>(); - manager.initialize(config); - - // Should initialize with default values without throwing exceptions - Assert.assertNotNull(manager); - } - - @Test - public void testInitializeWithCustomConfig() { - Map config = new HashMap<>(); - config.put("max_tokens", 1000); - config.put("truncation_strategy", "preserve_end"); - config.put("truncation_marker", "... [TRUNCATED]"); - - manager.initialize(config); - - // Should initialize without throwing exceptions - Assert.assertNotNull(manager); - } - - @Test - public void testInitializeWithActivationRules() { - Map config = new HashMap<>(); - Map activation = new HashMap<>(); - activation.put("tokens_exceed", 5000); - config.put("activation", activation); - - manager.initialize(config); - - // Should initialize without throwing exceptions - Assert.assertNotNull(manager); - } - - @Test - public void testShouldActivateWithNoRules() { - Map config = new HashMap<>(); - manager.initialize(config); - - // Should always activate when no rules are defined - Assert.assertTrue(manager.shouldActivate(context)); - } - - @Test - public void testShouldActivateWithTokensExceedRule() { - Map config = new HashMap<>(); - Map activation = new HashMap<>(); - activation.put("tokens_exceed", 100); - config.put("activation", activation); - - manager.initialize(config); - - // Create context with small tool output (should not activate) - Map interaction = new HashMap<>(); - 
interaction.put("output", "Small output"); - context.getToolInteractions().add(interaction); - - Assert.assertFalse(manager.shouldActivate(context)); - - // Create context with large tool output (should activate) - String largeOutput = "This is a very long output that should exceed the token limit. ".repeat(50); - interaction.put("output", largeOutput); - - Assert.assertTrue(manager.shouldActivate(context)); - } - - @Test - public void testExecuteWithNoToolInteractions() { - Map config = new HashMap<>(); - manager.initialize(config); - - // Should handle empty tool interactions gracefully - manager.execute(context); - - Assert.assertTrue(context.getToolInteractions().isEmpty()); - } - - @Test - public void testExecuteWithSmallToolOutput() { - Map config = new HashMap<>(); - config.put("max_tokens", 1000); - manager.initialize(config); - - // Add small tool output - Map interaction = new HashMap<>(); - interaction.put("output", "Small output that should not be truncated"); - context.getToolInteractions().add(interaction); - - String originalOutput = (String) interaction.get("output"); - manager.execute(context); - - // Output should remain unchanged - Assert.assertEquals(originalOutput, interaction.get("output")); - } - - @Test - public void testExecuteWithLargeToolOutput() { - Map config = new HashMap<>(); - config.put("max_tokens", 50); - config.put("truncation_strategy", "preserve_beginning"); - config.put("truncation_marker", "... [TRUNCATED]"); - manager.initialize(config); - - // Add large tool output - String largeOutput = "This is a very long output that should definitely be truncated because it exceeds the token limit. 
" - .repeat(10); - Map interaction = new HashMap<>(); - interaction.put("output", largeOutput); - context.getToolInteractions().add(interaction); - - manager.execute(context); - - String truncatedOutput = (String) interaction.get("output"); - - // Output should be truncated and contain the marker - Assert.assertNotEquals(largeOutput, truncatedOutput); - Assert.assertTrue(truncatedOutput.contains("... [TRUNCATED]")); - Assert.assertTrue(truncatedOutput.length() < largeOutput.length()); - } - - @Test - public void testExecuteWithMultipleToolOutputs() { - Map config = new HashMap<>(); - config.put("max_tokens", 50); - config.put("truncation_marker", "... [TRUNCATED]"); - manager.initialize(config); - - // Add multiple tool outputs - some large, some small - String smallOutput = "Small output"; - String largeOutput = "This is a very long output that should be truncated. ".repeat(10); - - Map interaction1 = new HashMap<>(); - interaction1.put("output", smallOutput); - context.getToolInteractions().add(interaction1); - - Map interaction2 = new HashMap<>(); - interaction2.put("output", largeOutput); - context.getToolInteractions().add(interaction2); - - Map interaction3 = new HashMap<>(); - interaction3.put("output", smallOutput); - context.getToolInteractions().add(interaction3); - - manager.execute(context); - - // First and third outputs should remain unchanged - Assert.assertEquals(smallOutput, interaction1.get("output")); - Assert.assertEquals(smallOutput, interaction3.get("output")); - - // Second output should be truncated - String truncatedOutput = (String) interaction2.get("output"); - Assert.assertNotEquals(largeOutput, truncatedOutput); - Assert.assertTrue(truncatedOutput.contains("... 
[TRUNCATED]")); - } - - @Test - public void testExecuteWithNonStringOutput() { - Map config = new HashMap<>(); - manager.initialize(config); - - // Add non-string tool output - Map interaction = new HashMap<>(); - interaction.put("output", 12345); - context.getToolInteractions().add(interaction); - - // Should handle non-string outputs gracefully - manager.execute(context); - - // Output should remain unchanged - Assert.assertEquals(12345, interaction.get("output")); - } - - @Test - public void testTruncationStrategies() { - // Test preserve_beginning strategy - testTruncationStrategy("preserve_beginning"); - - // Test preserve_end strategy - testTruncationStrategy("preserve_end"); - - // Test preserve_middle strategy - testTruncationStrategy("preserve_middle"); - } - - private void testTruncationStrategy(String strategy) { - ToolsOutputTruncateManager testManager = new ToolsOutputTruncateManager(); - Map config = new HashMap<>(); - config.put("max_tokens", 50); - config.put("truncation_strategy", strategy); - config.put("truncation_marker", "... [TRUNCATED]"); - testManager.initialize(config); - - ContextManagerContext testContext = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).build(); - - String largeOutput = "This is a very long output that should be truncated according to the specified strategy. ".repeat(10); - Map interaction = new HashMap<>(); - interaction.put("output", largeOutput); - testContext.getToolInteractions().add(interaction); - - testManager.execute(testContext); - - String truncatedOutput = (String) interaction.get("output"); - - // Output should be truncated and contain the marker - Assert.assertNotEquals(largeOutput, truncatedOutput); - Assert.assertTrue(truncatedOutput.contains("... 
[TRUNCATED]")); - Assert.assertTrue(truncatedOutput.length() < largeOutput.length()); - } - - @Test - public void testInvalidTruncationStrategy() { - Map config = new HashMap<>(); - config.put("truncation_strategy", "invalid_strategy"); - - // Should handle invalid strategy gracefully and use default - manager.initialize(config); - - Assert.assertNotNull(manager); - } - - @Test - public void testInvalidMaxTokensConfig() { - Map config = new HashMap<>(); - config.put("max_tokens", "invalid_number"); - - // Should handle invalid config gracefully and use default - manager.initialize(config); - - Assert.assertNotNull(manager); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java new file mode 100644 index 0000000000..14f87204da --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -0,0 +1,199 @@ +package org.opensearch.ml.engine.agents; + +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.SYSTEM_PROMPT_FIELD; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PreLLMEvent; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.utils.StringUtils; +import 
org.opensearch.ml.engine.memory.ConversationIndexMemory; + +public class AgentContextUtil { + private static final Logger log = LogManager.getLogger(AgentContextUtil.class); + + public static ContextManagerContext buildContextManagerContextForToolOutput( + String toolOutput, + Map parameters, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + contextParameters.put("_current_tool_output", toolOutput); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object extractProcessedToolOutput(ContextManagerContext context) { + if (context.getParameters() != null) { + return context.getParameters().get("_current_tool_output"); + } + return null; + } + + public static Object extractFromContext(ContextManagerContext context, String key) { + if (context.getParameters() != null) { + return context.getParameters().get(key); + } + return null; + } + + public static ContextManagerContext buildContextManagerContext( + Map parameters, + List interactions, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (memory instanceof ConversationIndexMemory) { + String chatHistory = parameters.get(CHAT_HISTORY); + // TODO to add 
chatHistory into context, currently there is no context manager working on chat_history + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + List> toolInteractions = new ArrayList<>(); + if (interactions != null) { + for (String interaction : interactions) { + Map toolInteraction = new HashMap<>(); + toolInteraction.put("output", interaction); + toolInteractions.add(toolInteraction); + } + } + builder.toolInteractions(toolInteractions); + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object emitPostToolHook( + Object toolOutput, + Map parameters, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + if (hookRegistry != null) { + try { + if (toolOutput == null) { + log.warn("Tool output is null, skipping POST_TOOL hook"); + return null; + } + ContextManagerContext context = buildContextManagerContextForToolOutput( + StringUtils.toJson(toolOutput), + parameters, + toolSpecs, + memory + ); + EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); + hookRegistry.emit(event); + + Object processedOutput = extractProcessedToolOutput(context); + return processedOutput != null ? 
processedOutput : toolOutput; + } catch (Exception e) { + log.error("Failed to emit POST_TOOL hook event", e); + return toolOutput; + } + } + return toolOutput; + } + + public static ContextManagerContext emitPreLLMHook( + Map parameters, + List interactions, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory); + try { + PreLLMEvent event = new PreLLMEvent(context, new HashMap<>()); + hookRegistry.emit(event); + log.debug("Emitted PRE_LLM hook event and updated context"); + return context; + + } catch (Exception e) { + log.error("Failed to emit PRE_LLM hook event", e); + return context; + } + } + + public static void updateParametersFromContext(Map parameters, ContextManagerContext context) { + if (context.getSystemPrompt() != null) { + parameters.put(SYSTEM_PROMPT_FIELD, context.getSystemPrompt()); + } + + if (context.getUserPrompt() != null) { + parameters.put(QUESTION, context.getUserPrompt()); + } + + if (context.getChatHistory() != null && !context.getChatHistory().isEmpty()) { + } + + if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) { + List updatedInteractions = new ArrayList<>(); + for (Map toolInteraction : context.getToolInteractions()) { + Object output = toolInteraction.get("output"); + if (output instanceof String) { + updatedInteractions.add((String) output); + } + } + if (!updatedInteractions.isEmpty()) { + parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); + } + } + + if (context.getParameters() != null) { + for (Map.Entry entry : context.getParameters().entrySet()) { + parameters.put(entry.getKey(), entry.getValue()); + + } + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 464c7af78f..51f3c9d869 100644 
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -665,7 +665,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent, HookRegistry hookRegistr toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 63478ee6cb..2c03fb9f87 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -62,9 +62,7 @@ import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; import org.opensearch.ml.common.hooks.HookRegistry; -import org.opensearch.ml.common.hooks.PreLLMEvent; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -73,6 +71,7 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; @@ -541,11 +540,16 @@ private void runReAct( return; } // Emit PRE_LLM hook event - List currentToolSpecs = new 
ArrayList<>(toolSpecMap.values()); - emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory); - - ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); - streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); + if (hookRegistry != null) { + List currentToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory, hookRegistry); + ActionRequest request = streamingWrapper.createPredictionRequest(llm, contextAfterEvent.getParameters(), tenantId); + streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); + } else { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); + } } }, e -> { log.error("Failed to run chat agent", e); @@ -559,10 +563,16 @@ private void runReAct( // Emit PRE_LLM hook event for initial LLM call List initialToolSpecs = new ArrayList<>(toolSpecMap.values()); tmpParameters.put("_llm_model_id", llm.getModelId()); - emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory); + if (hookRegistry != null) { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory, hookRegistry); + ActionRequest request = streamingWrapper.createPredictionRequest(llm, contextAfterEvent.getParameters(), tenantId); + streamingWrapper.executeRequest(request, firstListener); + } else { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, firstListener); + } - ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); - streamingWrapper.executeRequest(request, firstListener); } private static List createFinalAnswerTensors(List 
sessionId, List lastThought) { @@ -640,7 +650,9 @@ private static void runTool( // Emit POST_TOOL hook event after tool execution and process current tool output List postToolSpecs = new ArrayList<>(toolSpecMap.values()); - String outputResponseAfterHook = emitPostToolHook(outputResponse, tmpParameters, postToolSpecs, null).toString(); + String outputResponseAfterHook = AgentContextUtil + .emitPostToolHook(outputResponse, tmpParameters, postToolSpecs, null, hookRegistry) + .toString(); List> toolResults = List .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponseAfterHook))); @@ -650,7 +662,7 @@ private static void runTool( } else { // Emit POST_TOOL hook event for non-function calling path List postToolSpecs = new ArrayList<>(toolSpecMap.values()); - Object processedOutput = emitPostToolHook(r, tmpParameters, postToolSpecs, null); + Object processedOutput = AgentContextUtil.emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); interactions .add( substitute( @@ -975,240 +987,4 @@ private void saveMessage( } } - /** - * Build ContextManagerContext for current tool output - */ - private static ContextManagerContext buildContextManagerContextForToolOutput( - Object toolOutput, - Map parameters, - List toolSpecs, - Memory memory - ) { - ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); - - // Set system prompt - String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); - if (systemPrompt != null) { - builder.systemPrompt(systemPrompt); - } - - // Set user prompt - String userPrompt = parameters.get(MLAgentExecutor.QUESTION); - if (userPrompt != null) { - builder.userPrompt(userPrompt); - } - - // Set tool configurations - if (toolSpecs != null) { - builder.toolConfigs(toolSpecs); - } - - // Set current tool output as parameter for context managers to process - Map contextParameters = new HashMap<>(); - contextParameters.putAll(parameters); - 
contextParameters.put("_current_tool_output", toolOutput); - builder.parameters(contextParameters); - - return builder.build(); - } - - /** - * Extract processed tool output from context - */ - private static Object extractProcessedToolOutput(ContextManagerContext context) { - if (context.getParameters() != null) { - return context.getParameters().get("_current_tool_output"); - } - return null; - } - - /** - * Build ContextManagerContext from current agent execution state - */ - private ContextManagerContext buildContextManagerContext( - Map parameters, - List interactions, - List toolSpecs, - Memory memory - ) { - ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); - - // Set system prompt - String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); - if (systemPrompt != null) { - builder.systemPrompt(systemPrompt); - } - - // Set user prompt - String userPrompt = parameters.get(MLAgentExecutor.QUESTION); - if (userPrompt != null) { - builder.userPrompt(userPrompt); - } - - // Set chat history from memory - if (memory instanceof ConversationIndexMemory) { - // For now, we'll use the chat history that's already been processed - // In a more complete implementation, we might want to fetch fresh history - String chatHistory = parameters.get(CHAT_HISTORY); - if (chatHistory != null) { - // Convert chat history string back to interactions - // This is a simplified approach - in practice, you might want to store - // the original interactions list - List chatHistoryList = new ArrayList<>(); - // For now, we'll leave this empty and rely on the existing chat history processing - builder.chatHistory(chatHistoryList); - } - } - - // Set tool configurations - if (toolSpecs != null) { - builder.toolConfigs(toolSpecs); - } - - // Set tool interactions - List> toolInteractions = new ArrayList<>(); - if (interactions != null) { - for (String interaction : interactions) { - Map toolInteraction = new HashMap<>(); - 
toolInteraction.put("output", interaction); - toolInteractions.add(toolInteraction); - } - } - builder.toolInteractions(toolInteractions); - - // Set additional parameters - Map contextParameters = new HashMap<>(); - contextParameters.putAll(parameters); - builder.parameters(contextParameters); - - return builder.build(); - } - - /** - * Emit POST_TOOL hook event and process current tool output - */ - private static Object emitPostToolHook(Object toolOutput, Map parameters, List toolSpecs, Memory memory) { - log.info("MLChatAgentRunner.emitPostToolHook() called with hookRegistry: {}", hookRegistry != null ? "present" : "null"); - if (hookRegistry != null) { - try { - // Create context with current tool output - ContextManagerContext context = buildContextManagerContextForToolOutput(toolOutput, parameters, toolSpecs, memory); - EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); - log - .info( - "Emitting POST_TOOL hook event with context containing {} tool interactions", - context.getToolInteractions() != null ? context.getToolInteractions().size() : 0 - ); - hookRegistry.emit(event); - - // Extract processed tool output from context - Object processedOutput = extractProcessedToolOutput(context); - log - .info( - "POST_TOOL hook processing completed. Original output length: {}, Processed output length: {}", - String.valueOf(toolOutput).length(), - processedOutput != null ? String.valueOf(processedOutput).length() : "null" - ); - return processedOutput != null ? 
processedOutput : toolOutput; - } catch (Exception e) { - log.error("Failed to emit POST_TOOL hook event", e); - return toolOutput; // Return original output on error - } - } - log.warn("No hook registry available, returning original tool output"); - return toolOutput; // Return original output if no hook registry - } - - /** - * Emit PRE_LLM hook event and update context - */ - private void emitPreLLMHook(Map parameters, List interactions, List toolSpecs, Memory memory) { - if (hookRegistry != null) { - try { - - ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory); - PreLLMEvent event = new PreLLMEvent(context, new HashMap<>()); - hookRegistry.emit(event); - - // Update parameters with any changes made by context managers - updateParametersFromContext(parameters, context); - log.debug("Emitted PRE_LLM hook event and updated context"); - } catch (Exception e) { - log.error("Failed to emit PRE_LLM hook event", e); - // Continue execution even if hook fails - } - } - } - - /** - * Update interactions list with processed results from context - */ - private void updateInteractionsFromContext(List interactions, ContextManagerContext context) { - if (context.getToolInteractions() != null) { - interactions.clear(); - for (Map toolInteraction : context.getToolInteractions()) { - Object output = toolInteraction.get("output"); - if (output instanceof String) { - interactions.add((String) output); - } - } - } - } - - /** - * Update parameters from transformed context - */ - private void updateParametersFromContext(Map parameters, ContextManagerContext context) { - // Update system prompt if changed - if (context.getSystemPrompt() != null) { - parameters.put(SYSTEM_PROMPT_FIELD, context.getSystemPrompt()); - } - - // Update user prompt if changed - if (context.getUserPrompt() != null) { - parameters.put(MLAgentExecutor.QUESTION, context.getUserPrompt()); - } - - // Update chat history if changed - if (context.getChatHistory() 
!= null && !context.getChatHistory().isEmpty()) { - // Convert interactions back to chat history string - // TODO this need more consideration with memory index - // StringBuilder chatHistoryBuilder = new StringBuilder(); - // String chatHistoryPrefix = parameters.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); - // chatHistoryBuilder.append(chatHistoryPrefix); - // - // for (Interaction interaction : context.getChatHistory()) { - // if (interaction.getInput() != null && interaction.getResponse() != null) { - // chatHistoryBuilder.append("Human: ").append(interaction.getInput()).append("\n"); - // chatHistoryBuilder.append("Assistant: ").append(interaction.getResponse()).append("\n"); - // } - // } - // parameters.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - } - - // Update tool interactions if changed by context management - if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) { - List updatedInteractions = new ArrayList<>(); - for (Map toolInteraction : context.getToolInteractions()) { - Object output = toolInteraction.get("output"); - if (output instanceof String) { - updatedInteractions.add((String) output); - } - } - if (!updatedInteractions.isEmpty()) { - // Update the _interactions parameter with processed tool outputs - parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); - } - } - - // Update any additional parameters - if (context.getParameters() != null) { - for (Map.Entry entry : context.getParameters().entrySet()) { - if (entry.getValue() instanceof String) { - parameters.put(entry.getKey(), (String) entry.getValue()); - } - } - } - } - } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index b8b89d8aa2..4065ad0929 100644 --- 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -20,6 +20,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.MAX_ITERATION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData; @@ -55,6 +56,7 @@ import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -68,6 +70,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; @@ -91,6 +94,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { private final Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private HookRegistry hookRegistry; // flag to 
track if task has been updated with executor memory ids or not private boolean taskUpdated = false; private final Map taskUpdates = new HashMap<>(); @@ -162,7 +166,8 @@ public MLPlanExecuteAndReflectAgentRunner( Map toolFactories, Map memoryFactoryMap, SdkClient sdkClient, - Encryptor encryptor + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -172,6 +177,7 @@ public MLPlanExecuteAndReflectAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; this.plannerPrompt = DEFAULT_PLANNER_PROMPT; this.plannerPromptTemplate = DEFAULT_PLANNER_PROMPT_TEMPLATE; this.reflectPrompt = DEFAULT_REFLECT_PROMPT; @@ -289,7 +295,7 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { + .create(allParams.get(USER_PROMPT_FIELD), memoryId, appType, ActionListener.wrap(memory -> { memory.getMessages(ActionListener.>wrap(interactions -> { List completedSteps = new ArrayList<>(); for (Interaction interaction : interactions) { @@ -365,6 +371,7 @@ private void executePlanningLoop( // completedSteps stores the step and its result, hence divide by 2 to find total steps completed // on reaching max iteration, update parent interaction question with last executed step rather than task to allow continue using // memory_id + // emit PRE_LLM hook for planner agent if (stepsExecuted >= maxSteps) { String finalResult = String .format( @@ -385,13 +392,33 @@ private void executePlanningLoop( ); return; } + MLPredictionTaskRequest request; + // Planner agent doesn't use INTERACTIONS for now, reusing the INTERACTIONS to pass over + // completedSteps to context management. + // TODO should refactor the completed steps as message array format, similar to chat agent. 
+ + Map requestParams = new HashMap<>(allParams); + + if (hookRegistry != null && !completedSteps.isEmpty()) { + requestParams.put("_llm_model_id", llm.getModelId()); + requestParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps)); + try { + AgentContextUtil.emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry); + } catch (Exception e) { + log.error("Failed to emit pre-LLM hook", e); + } + if (requestParams.get(INTERACTIONS) != null && !requestParams.get(INTERACTIONS).isEmpty()) { + requestParams.put(COMPLETED_STEPS_FIELD, StringUtils.toJson(requestParams.get(INTERACTIONS))); + requestParams.put(INTERACTIONS, ""); + } + } - MLPredictionTaskRequest request = new MLPredictionTaskRequest( + request = new MLPredictionTaskRequest( llm.getModelId(), RemoteInferenceMLInput .builder() .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build()) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(requestParams).build()) .build(), null, allParams.get(TENANT_ID_FIELD) @@ -442,6 +469,9 @@ private void executePlanningLoop( .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build()) .build(); + // Pass hookRegistry to internal agent execution + agentInput.setHookRegistry(hookRegistry); + MLExecuteTaskRequest executeRequest = new MLExecuteTaskRequest(FunctionName.AGENT, agentInput); client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java index 64f75191d3..80ad461f28 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java @@
-116,7 +116,7 @@ public void execute(ContextManagerContext context) { context.setToolInteractions(updatedToolInteractions); // Update the _interactions parameter with smaller size of updated interactions - Map parameters = context.getParameters(); + Map parameters = context.getParameters(); if (parameters == null) { parameters = new HashMap<>(); context.setParameters(parameters); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java index b9a4cc4ca8..85f8449881 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.contextmanager; import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; import java.util.ArrayList; import java.util.HashMap; @@ -51,7 +52,7 @@ public class SummarizationManager implements ContextManager { private static final double DEFAULT_SUMMARY_RATIO = 0.3; private static final int DEFAULT_PRESERVE_RECENT_MESSAGES = 10; private static final String DEFAULT_SUMMARIZATION_PROMPT = - "You are a tool interactions summarization agent. Summarize the provided tool interactions concisely while preserving key information and context."; + "You are an interactions summarization agent. 
Summarize the provided interactions concisely while preserving key information and context."; protected double summaryRatio; protected int preserveRecentMessages; @@ -150,7 +151,7 @@ public void execute(ContextManagerContext context) { // Get model ID String modelId = summarizationModelId; if (modelId == null) { - Map parameters = context.getParameters(); + Map parameters = context.getParameters(); if (parameters != null) { modelId = (String) parameters.get("_llm_model_id"); } @@ -163,7 +164,7 @@ public void execute(ContextManagerContext context) { // Prepare summarization parameters Map summarizationParameters = new HashMap<>(); - summarizationParameters.put("prompt", StringUtils.toJson(String.join("\n", messagesToSummarize))); + summarizationParameters.put("prompt", "Help summarize the following" + StringUtils.toJson(String.join(",", messagesToSummarize))); summarizationParameters.put("system_prompt", summarizationSystemPrompt); executeSummarization(context, modelId, summarizationParameters, messagesToSummarizeCount, remainingMessages, toolInteractions); @@ -193,7 +194,6 @@ protected void executeSummarization( String summary = extractSummaryFromResponse(response); processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalToolInteractions); } catch (Exception e) { - log.error("Failed to process summarization response", e); // Fallback to default behavior processSummarizationResult( context, @@ -204,7 +204,6 @@ protected void executeSummarization( ); } }, e -> { - log.error("Summarization prediction failed", e); // Fallback to default behavior processSummarizationResult( context, @@ -218,7 +217,6 @@ protected void executeSummarization( client.execute(MLPredictionTaskAction.INSTANCE, request, listener); } catch (Exception e) { - log.error("Failed to execute summarization", e); // Fallback to default behavior processSummarizationResult( context, @@ -262,12 +260,12 @@ protected void processSummarizationResult( 
context.setToolInteractions(updatedToolInteractions); // Update parameters - Map parameters = context.getParameters(); + Map parameters = context.getParameters(); if (parameters == null) { parameters = new HashMap<>(); - context.setParameters(parameters); } - parameters.put("_interactions", ", " + String.join(", ", updatedInteractions)); + parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); + context.setParameters(parameters); log .info( @@ -294,12 +292,6 @@ private String extractSummaryFromResponse(MLTaskResponse response) { Map dataAsMap = tensors.get(0).getDataAsMap(); // TODO need to parse LLM response output, maybe reused how filtered output from chatAgentRunner return StringUtils.toJson(dataAsMap); - // if (dataAsMap.containsKey("response")) { - // return dataAsMap.get("response").toString(); - // } - // if (dataAsMap.containsKey("result")) { - // return dataAsMap.get("result").toString(); - // } } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java index b5515ed56e..4fa97c156d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java @@ -28,7 +28,7 @@ public class ToolsOutputTruncateManager implements ContextManager { private static final String MAX_OUTPUT_LENGTH_KEY = "max_output_length"; // Default values - private static final int DEFAULT_MAX_OUTPUT_LENGTH = 2000; + private static final int DEFAULT_MAX_OUTPUT_LENGTH = 40000; private int maxOutputLength; private List activationRules; @@ -76,7 +76,7 @@ public boolean shouldActivate(ContextManagerContext context) { @Override public void execute(ContextManagerContext context) { // Process current tool output from 
parameters - Map parameters = context.getParameters(); + Map parameters = context.getParameters(); if (parameters == null) { log.debug("No parameters available for tool output truncation"); return; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index 67a1dc0db3..5100e3c556 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -668,7 +668,7 @@ public void test_Regenerate_OriginalInteraction_NotExist() throws IOException { @Test public void test_CreateFlowAgent() { MLAgent mlAgent = MLAgent.builder().name("test_agent").type("flow").build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent, Mockito.any()); + MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent, null); Assert.assertTrue(mlAgentRunner instanceof MLFlowAgentRunner); } @@ -676,7 +676,7 @@ public void test_CreateFlowAgent() { public void test_CreateChatAgent() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); MLAgent mlAgent = MLAgent.builder().name("test_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent, Mockito.any()); + MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent, null); Assert.assertTrue(mlAgentRunner instanceof MLChatAgentRunner); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..c63db9df4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ 
b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. 
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 5245ccc320..c11dd72649 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -135,6 +135,7 @@ public void setup() { // memory mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + when(memoryMap.get(ConversationIndexMemory.TYPE)).thenReturn(memoryFactory); when(memoryMap.get(anyString())).thenReturn(memoryFactory); when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); @@ -170,7 +171,8 @@ public void setup() { toolFactories, memoryMap, sdkClient, - encryptor + encryptor, + null ); // Setup tools diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index 92092a6cb1..6b293595c6 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -79,18 +79,6 @@ public List routes() { public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { 
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request); - // Extract context_management query parameter for agent execution - String uri = request.getHttpRequest().uri(); - if (uri.startsWith(ML_BASE_URI + "/agents/")) { - String contextManagementName = request.param("context_management"); - // Store context management name in the agent input - if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { - if (mlExecuteTaskRequest.getInput() instanceof AgentMLInput) { - ((AgentMLInput) mlExecuteTaskRequest.getInput()).setContextManagementName(contextManagementName); - } - } - } - return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new ActionListener<>() { @Override public void onResponse(MLExecuteTaskResponse response) {