Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
public static final String VERBOSE_FILTER = "verbose_filter";

private static final String DEFAULT_MAX_ITERATIONS = "10";
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
Expand Down Expand Up @@ -300,6 +301,7 @@ private void runReAct(
String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false"));
boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE));
List<String> traceFilter = parseTraceFilter(tmpParameters.get(VERBOSE_FILTER));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has any validation applied to this verbose filter value? like verboseFilter.matches("^[a-zA-Z]+(,[a-zA-Z]+)*$"). I think it's worth a check before using the filters.


// Create root interaction.
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
Expand Down Expand Up @@ -379,13 +381,15 @@ private void runReAct(
lastActionInput.set(actionInput);
lastToolSelectionResponse.set(thoughtResponse);

traceTensors
.add(
ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build()))
.build()
);
if (shouldIncludeInTrace("LLM", traceFilter)) {
traceTensors
.add(
ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build()))
.build()
);
}

saveTraceData(
conversationIndexMemory,
Expand Down Expand Up @@ -487,18 +491,20 @@ private void runReAct(

sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput));
streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId);
traceTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build()
)
)
.build()
);
if (shouldIncludeInTrace(lastAction.get(), traceFilter)) {
traceTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build()
)
)
.build()
);
}

if (finalI == maxIterations - 1) {
handleMaxIterationsReached(
Expand Down Expand Up @@ -842,6 +848,21 @@ static Map<String, String> constructLLMParams(LLMSpec llm, Map<String, String> p
return tmpParameters;
}

private static List<String> parseTraceFilter(String traceFilterParam) {
if (traceFilterParam == null || traceFilterParam.trim().isEmpty()) {
return null;
}
return List.of(traceFilterParam.split(","));
}

private static boolean shouldIncludeInTrace(String toolName, List<String> traceFilter) {
if (traceFilter == null) {
return true;
}

return traceFilter.contains(toolName);
}

public static void returnFinalResponse(
String sessionId,
ActionListener<Object> listener,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
public static final String DATETIME_FORMAT_FIELD = "datetime_format";

public static final String EXECUTOR_VERBOSE = "executor_verbose";
public static final String EXECUTOR_VERBOSE_FILTER = "executor_verbose_filter";
Comment on lines +157 to +158
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the example, this field name is "verbose" and "verbose_filter", which seems matching to the name in chatAgent. I think "verbose" and "verbose filter" is simpler.


public MLPlanExecuteAndReflectAgentRunner(
Client client,
Settings settings,
Expand Down Expand Up @@ -435,6 +438,15 @@ private void executePlanningLoop(
allParams.getOrDefault(EXECUTOR_MESSAGE_HISTORY_LIMIT, DEFAULT_EXECUTOR_MESSAGE_HISTORY_LIMIT)
);

// Pass through verbose and verbose_filter if provided
if (allParams.containsKey(EXECUTOR_VERBOSE)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider validation too.

reactParams.put(AgentUtils.VERBOSE, allParams.get(EXECUTOR_VERBOSE));
}

if (allParams.containsKey(EXECUTOR_VERBOSE_FILTER)) {
reactParams.put(MLChatAgentRunner.VERBOSE_FILTER, allParams.get(EXECUTOR_VERBOSE_FILTER));
}

AgentMLInput agentInput = AgentMLInput
.AgentMLInputBuilder()
.agentId(reActAgentId)
Expand All @@ -449,8 +461,9 @@ private void executePlanningLoop(

// Navigate through the structure to get the response
Map<String, String> results = new HashMap<>();
List<String> allResponses = new ArrayList<>();

// Process tensors in a single stream
// Process tensors to collect all responses
reactResult.getMlModelOutputs().stream().flatMap(output -> output.getMlModelTensors().stream()).forEach(tensor -> {
switch (tensor.getName()) {
case MEMORY_ID_FIELD:
Expand All @@ -459,14 +472,35 @@ private void executePlanningLoop(
case PARENT_INTERACTION_ID_FIELD:
results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult());
break;
default:
Map<String, ?> dataMap = tensor.getDataAsMap();
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
case RESPONSE_FIELD:
if (tensor.getResult() != null) {
allResponses.add(tensor.getResult());
} else {
Map<String, ?> dataMap = tensor.getDataAsMap();
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
allResponses.add((String) dataMap.get(RESPONSE_FIELD));
}
}
}
});

if (!allResponses.isEmpty()) {
StringBuilder stepResult = new StringBuilder();
stepResult.append(allResponses.getLast());
if (allResponses.size() > 1) {
stepResult.append("\n\n<step-traces>");
}

for (int i = 0; i < allResponses.size() - 1; i++) {
stepResult.append("\n\n").append(allResponses.get(i));
if (i == allResponses.size() - 2) {
stepResult.append("\n</step-traces>");
}
}

results.put(STEP_RESULT_FIELD, stepResult.toString());
}

Comment on lines +487 to +503
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (!allResponses.isEmpty()) {
    results.put(STEP_RESULT_FIELD, formatStepResultWithTraces(allResponses));
}

private String formatStepResultWithTraces(List<String> responses) {
    StringBuilder result = new StringBuilder(responses.getLast());
    
    if (responses.size() > 1) {
        List<String> traces = responses.subList(0, responses.size() - 1);
        result.append("\n\n<step-traces>\n\n")
              .append(String.join("\n\n", traces))
              .append("\n</step-traces>");
    }
    
    return result.toString();
}

This is a better version to avoid for loop.

if (!results.containsKey(STEP_RESULT_FIELD)) {
throw new IllegalStateException("No valid response found in ReAct agent output");
}
Expand Down Expand Up @@ -502,8 +536,17 @@ private void executePlanningLoop(
}, e -> log.error("Failed to update task {} with executor memory ID", taskId, e)));
}

completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute));
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1));
completedSteps
.add(
String
.format(
"\n<step-%d-result>\n%s\n</step-%d-result>\n",
stepsExecuted + 1,
results.get(STEP_RESULT_FIELD),
stepsExecuted + 1
)
);

saveTraceData(
(ConversationIndexMemory) memory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ public class PromptTemplate {
+ "${parameters."
+ PLANNER_PROMPT_FIELD
+ "} \n"
+ "Objective: ${parameters."
+ "Objective: ```${parameters."
+ USER_PROMPT_FIELD
+ "} \n\nRemember: Respond only in JSON format following the required schema.";
+ "}``` \n\nRemember: Respond only in JSON format following the required schema.";

public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
+ DEFAULT_PROMPT_TOOLS_FIELD
Expand All @@ -41,10 +41,10 @@ public class PromptTemplate {
+ "Objective: ```${parameters."
+ USER_PROMPT_FIELD
+ "}```\n\n"
+ "Original plan:\n[${parameters."
+ "Previous plan:\n[${parameters."
+ STEPS_FIELD
+ "}] \n\n"
+ "You have currently executed the following steps from the original plan: \n[${parameters."
+ "You have currently executed the following steps: \n[${parameters."
+ COMPLETED_STEPS_FIELD
+ "}] \n\n"
+ "${parameters."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void invokeRemoteService(
SdkHttpFullRequest request;
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
log.debug("\n\n\noriginal payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
break;
case "GET":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void invokeRemoteService(
SdkHttpFullRequest request;
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
log.debug("\n\n\noriginal payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
break;
case "GET":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1171,4 +1171,96 @@ public void testConstructLLMParams_DefaultValues() {
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
}

@Test
public void testVerboseFilterWithSpecificFields() {
// Create an MLAgent and run with verbose_filter
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
params.put("verbose_filter", "firstTool");

mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);

// Capture the response
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

// Count response fields across all outputs
int responseFieldCount = 0;
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
for (ModelTensor tensor : output.getMlModelTensors()) {
if ("response".equals(tensor.getName())) {
responseFieldCount++;
}
}
}

// Verify there is more than one response field
assertEquals(2, responseFieldCount);
}

@Test
public void testVerboseFilterWithInvalidPath() {
// Create an MLAgent and run with invalid verbose_filter
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
params.put("verbose_filter", "RandomTool");

mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);

// Should still work but filter nothing
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
int responseFieldCount = 0;
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
for (ModelTensor tensor : output.getMlModelTensors()) {
if ("response".equals(tensor.getName())) {
responseFieldCount++;
}
}
}

assertEquals(1, responseFieldCount);
}

@Test
public void testVerboseFilterWithoutVerbose() {
// Create an MLAgent and run with verbose_filter but verbose=false
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "false");

mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);

// verbose_filter should be ignored when verbose=false
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
int responseFieldCount = 0;
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
for (ModelTensor tensor : output.getMlModelTensors()) {
if ("response".equals(tensor.getName())) {
responseFieldCount++;
}
}
}

assertEquals(1, responseFieldCount);
}
}
Loading
Loading