Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
Expand Down Expand Up @@ -460,9 +461,9 @@ private void executePlanningLoop(
results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult());
break;
default:
Map<String, ?> dataMap = tensor.getDataAsMap();
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
String stepResult = parseTensorDataMap(tensor);
if (stepResult != null) {
results.put(STEP_RESULT_FIELD, stepResult);
}
}
});
Expand Down Expand Up @@ -502,8 +503,17 @@ private void executePlanningLoop(
}, e -> log.error("Failed to update task {} with executor memory ID", taskId, e)));
}

completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute));
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1));
completedSteps
.add(
String
.format(
"\n<step-%d-result>\n%s\n</step-%d-result>\n",
stepsExecuted + 1,
results.get(STEP_RESULT_FIELD),
stepsExecuted + 1
)
);

saveTraceData(
(ConversationIndexMemory) memory,
Expand Down Expand Up @@ -544,6 +554,39 @@ private void executePlanningLoop(
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
}

@VisibleForTesting
String parseTensorDataMap(ModelTensor tensor) {
Map<String, ?> dataMap = tensor.getDataAsMap();
if (dataMap == null) {
return null;
}

StringBuilder stepResult = new StringBuilder();
if (dataMap.containsKey(RESPONSE_FIELD)) {
stepResult.append((String) dataMap.get(RESPONSE_FIELD));
}

if (dataMap.containsKey(INTERACTIONS_ADDITIONAL_INFO_FIELD)) {
stepResult.append("\n<step-traces>\n");
((Map<String, Object>) dataMap.get(INTERACTIONS_ADDITIONAL_INFO_FIELD))
.forEach(
(key, value) -> stepResult
.append("<")
.append(key)
.append(">")
.append("\n")
.append(value)
.append("\n")
.append("</")
.append(key)
.append(">")
);
stepResult.append("\n</step-traces>\n");
}

return stepResult.toString();
}

@VisibleForTesting
Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
Map<String, Object> modelOutput = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ public class PromptTemplate {
+ "${parameters."
+ PLANNER_PROMPT_FIELD
+ "} \n"
+ "Objective: ${parameters."
+ "Objective: ```${parameters."
+ USER_PROMPT_FIELD
+ "} \n\nRemember: Respond only in JSON format following the required schema.";
+ "}``` \n\nRemember: Respond only in JSON format following the required schema.";

public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
+ DEFAULT_PROMPT_TOOLS_FIELD
Expand All @@ -41,10 +41,10 @@ public class PromptTemplate {
+ "Objective: ```${parameters."
+ USER_PROMPT_FIELD
+ "}```\n\n"
+ "Original plan:\n[${parameters."
+ "Previous plan:\n[${parameters."
+ STEPS_FIELD
+ "}] \n\n"
+ "You have currently executed the following steps from the original plan: \n[${parameters."
+ "You have currently executed the following steps: \n[${parameters."
+ COMPLETED_STEPS_FIELD
+ "}] \n\n"
+ "${parameters."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -677,6 +678,42 @@ public void testSaveAndReturnFinalResult() {
assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response"));
}

@Test
public void testParseTensorDataMap() {
// Test with response only
Map<String, Object> dataMap = new HashMap<>();
dataMap.put("response", "test response");
ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build();

String result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
assertEquals("test response", result);

// Test with additional info
Map<String, Object> additionalInfo = new HashMap<>();
additionalInfo.put("trace1", "content1");
additionalInfo.put("trace2", "content2");
dataMap.put("additional_info", additionalInfo);

result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
assertTrue(result.contains("test response"));
assertTrue(result.contains("<step-traces>"));
assertTrue(result.contains("<trace1>\ncontent1\n</trace1>"));
assertTrue(result.contains("<trace2>\ncontent2\n</trace2>"));
assertTrue(result.contains("</step-traces>"));

// Test with null dataMap
ModelTensor nullTensor = ModelTensor.builder().build();
assertNull(mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(nullTensor));

// No response field
Map<String, Object> noResponseMap = new HashMap<>();
noResponseMap.put("additional_info", additionalInfo);
ModelTensor noResponseTensor = ModelTensor.builder().dataAsMap(noResponseMap).build();
result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(noResponseTensor);
assertTrue(result.contains("<step-traces>"));
assertFalse(result.contains("test response"));
}

@Test
public void testUpdateTaskWithExecutorAgentInfo() {
MLAgent mlAgent = createMLAgentWithTools();
Expand Down Expand Up @@ -765,4 +802,57 @@ public void testUpdateTaskWithExecutorAgentInfo() {
mlTaskUtilsMockedStatic.verify(() -> MLTaskUtils.updateMLTaskDirectly(eq(taskId), eq(taskUpdates), eq(client), any()));
}
}

@Test
public void testExecutionWithNullStepResult() {
MLAgent mlAgent = createMLAgentWithTools();

// Setup LLM response for planning phase - returns steps to execute
doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(2);
ModelTensor modelTensor = ModelTensor
.builder()
.dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"\"}"))
.build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
listener.onResponse(mlTaskResponse);
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());

// Setup executor response with tensor that has null dataMap - this will hit line 465
doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(2);
ModelTensor memoryIdTensor = ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result("test_memory_id").build();
ModelTensor parentIdTensor = ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result("test_parent_id").build();
// This tensor will return null from parseTensorDataMap, hitting the stepResult != null check
ModelTensor nullDataTensor = ModelTensor.builder().name("other").build();
ModelTensors modelTensors = ModelTensors
.builder()
.mlModelTensors(Arrays.asList(memoryIdTensor, parentIdTensor, nullDataTensor))
.build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
listener.onResponse(mlExecuteTaskResponse);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any());

Map<String, String> params = new HashMap<>();
params.put("question", "test question");
params.put("parent_interaction_id", "test_parent_interaction_id");

// Capture the exception in the listener
doAnswer(invocation -> {
Exception e = invocation.getArgument(0);
assertTrue(e instanceof IllegalStateException);
assertEquals("No valid response found in ReAct agent output", e.getMessage());
return null;
}).when(agentActionListener).onFailure(any());

mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that onFailure was called with the expected exception
verify(agentActionListener).onFailure(any(IllegalStateException.class));
}
}
Loading