Skip to content

Commit ef3f888

Browse files
aymeric-roucherbaptistecolle
authored andcommitted
Also store final outputs
1 parent 9cdf0d9 commit ef3f888

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

scripts/generate_agent_traces.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ class ModifiedFinalAnswerTool(Tool):
3636
output_type = "string"
3737

3838
def forward(self, answer_function: Any) -> str:
39-
print("USING MODIFIED FINAL ANSWER TOOL")
40-
return inspect.getsource(answer_function)
39+
source_code = inspect.getsource(answer_function)
40+
print("USING MODIFIED FINAL ANSWER TOOL, got source code:\n", source_code)
41+
return source_code
4142

4243
def __init__(self, *args, **kwargs):
4344
self.is_initialized = False
@@ -110,19 +111,18 @@ def model(messages, stop_sequences = None):
110111
max_steps=10,
111112
verbosity_level=2
112113
)
113-
114+
114115
try:
115116
output = agent.run(task)
116-
print("GOT OUTPUT:", output)
117-
return agent.write_memory_to_messages()
117+
return agent.write_memory_to_messages(), output
118118
except Exception as e:
119119
print(f"Error when generating agentic trace: {e}")
120120
return None
121121

122122
def process_example(example, session, args, output_file, pbar=None):
123123
prompt = f"""Here is a task to solve using a function:
124124
{example[args.prompt_column]}
125-
125+
126126
Now write a function that solves the problem, test it and return it using final_answer(your_function).
127127
The function should take the inputs described in the task above, using them in this way: the function will be passed the 'lines' described in the task as different arguments.
128128
For instance:
@@ -132,29 +132,28 @@ def process_example(example, session, args, output_file, pbar=None):
132132
ALWAYS RUN THE FUNCTION IN A CODE SNIPPET WITH TEST CASES BEFORE RETURNING IT.
133133
"""
134134
try:
135-
agent_runs = []
135+
agent_outputs, agent_memories = [], []
136136
for _ in range(args.num_generations):
137-
agent_run = get_agent_run(session, prompt, args)
138-
agent_runs.append(agent_run)
137+
agent_output, agent_memory = get_agent_run(session, prompt, args)
138+
agent_outputs.append(agent_output)
139+
agent_memories.append(agent_memory)
139140

140-
if any(agent_run is None for agent_run in agent_runs):
141+
if any(agent_output is None for agent_output in agent_outputs):
141142
print("Error processing example")
142143
if pbar:
143144
pbar.update(1)
144145
return None
145146

146-
generations = []
147147
finish_reasons = []
148148
api_metadata = []
149149

150-
for agent_run in agent_runs:
151-
generations.append(agent_run)
150+
for agent_run in agent_output:
152151
finish_reasons.append(None)
153152
api_metadata.append(None)
154153

155154
# Convert agent_run to a serializable format
156155
serializable_generations = []
157-
for generation in generations:
156+
for generation in agent_memories:
158157
if generation is not None:
159158
# Convert to a simple list of dictionaries if it's not already
160159
if isinstance(generation, list):
@@ -167,11 +166,12 @@ def process_example(example, session, args, output_file, pbar=None):
167166
serializable_generations.append(str(generation))
168167
else:
169168
serializable_generations.append(None)
170-
169+
171170
# Combine original dataset fields with generations
172171
result = {
173172
**example, # Preserve all original dataset fields
174173
"generations": serializable_generations,
174+
"final_outputs": agent_outputs,
175175
"finish_reasons": finish_reasons,
176176
"api_metadata": api_metadata,
177177
}

0 commit comments

Comments
 (0)