Skip to content

Commit 4e5dc5a

Browse files
committed
Fix: revert config changes + better implementation
1 parent 0bc1098 commit 4e5dc5a

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

cfg/config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
defaults:
22
- _self_
33
- problem: tsp_aco
4-
- llm_client: siliconflow
4+
- llm_client: openai
55
# [Optional] set different clients for operators
6-
- llm_client@llm_long_ref: siliconflow
6+
- llm_client@llm_long_ref: null
77
- llm_client@llm_short_ref: null
88
- llm_client@llm_crossover: null
9-
- llm_client@llm_mutation: null
9+
- llm_client@llm_mutation: null
1010
- override hydra/output: local
1111

1212
hydra:

reevo.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ def __init__(
2828
self.generator_llm = generator_llm
2929
self.reflector_llm = reflector_llm or generator_llm
3030

31-
self.short_reflector_llm = short_reflector_llm
32-
self.long_reflector_llm = long_reflector_llm
33-
self.crossover_llm = crossover_llm
34-
self.mutataion_llm = mutation_llm
31+
self.short_reflector_llm = short_reflector_llm or self.reflector_llm
32+
self.long_reflector_llm = long_reflector_llm or self.reflector_llm
33+
self.crossover_llm = crossover_llm or generator_llm
34+
self.mutation_llm = mutation_llm or generator_llm
3535

3636
self.root_dir = root_dir
3737

@@ -83,7 +83,7 @@ def init_prompt(self) -> None:
8383
self.user_reflector_st_prompt = file_to_string(f'{self.prompt_dir}/common/user_reflector_st.txt') if self.problem_type != "black_box" else file_to_string(f'{self.prompt_dir}/common/user_reflector_st_black_box.txt') # shrot-term reflection
8484
self.user_reflector_lt_prompt = file_to_string(f'{self.prompt_dir}/common/user_reflector_lt.txt') # long-term reflection
8585
self.crossover_prompt = file_to_string(f'{self.prompt_dir}/common/crossover.txt')
86-
self.mutataion_prompt = file_to_string(f'{self.prompt_dir}/common/mutation.txt')
86+
self.mutation_prompt = file_to_string(f'{self.prompt_dir}/common/mutation.txt')
8787
self.user_generator_prompt = file_to_string(f'{self.prompt_dir}/common/user_generator.txt').format(
8888
func_name=self.func_name,
8989
problem_desc=self.problem_desc,
@@ -373,7 +373,7 @@ def short_term_reflection(self, population: list[dict]) -> tuple[list[list[dict]
373373
better_code_lst.append(better_code)
374374

375375
# Asynchronously generate responses
376-
response_lst = self.reflector_llm.multi_chat_completion(messages_lst) if self.short_reflector_llm is None else self.short_reflector_llm.multi_chat_completion(messages_lst)
376+
response_lst = self.short_reflector_llm.multi_chat_completion(messages_lst)
377377
return response_lst, worse_code_lst, better_code_lst
378378

379379
def long_term_reflection(self, short_term_reflections: list[str]) -> None:
@@ -392,7 +392,7 @@ def long_term_reflection(self, short_term_reflections: list[str]) -> None:
392392
logging.info("Long-term Reflection Prompt: \nSystem Prompt: \n" + system + "\nUser Prompt: \n" + user)
393393
self.print_long_term_reflection_prompt = False
394394

395-
self.long_term_reflection_str = self.reflector_llm.multi_chat_completion([messages])[0] if self.long_reflector_llm is None else self.long_reflector_llm.multi_chat_completion([messages])[0]
395+
self.long_term_reflection_str = self.long_reflector_llm.multi_chat_completion([messages])[0]
396396

397397
# Write reflections to file
398398
file_name = f"problem_iter{self.iteration}_short_term_reflections.txt"
@@ -430,8 +430,7 @@ def crossover(self, short_term_reflection_tuple: tuple[list[list[dict]], list[st
430430
self.print_crossover_prompt = False
431431

432432
# Asynchronously generate responses
433-
# use crossover_llm if it's not None
434-
response_lst = self.generator_llm.multi_chat_completion(messages_lst) if self.crossover_llm is None else self.crossover_llm.multi_chat_completion(messages_lst)
433+
response_lst = self.crossover_llm.multi_chat_completion(messages_lst)
435434
crossed_population = [self.response_to_individual(response, response_id) for response_id, response in enumerate(response_lst)]
436435

437436
assert len(crossed_population) == self.cfg.pop_size
@@ -442,7 +441,7 @@ def mutate(self) -> list[dict]:
442441
"""Elitist-based mutation. We only mutate the best individual to generate n_pop new individuals."""
443442
system = self.system_generator_prompt
444443
func_signature1 = self.func_signature.format(version=1)
445-
user = self.mutataion_prompt.format(
444+
user = self.mutation_prompt.format(
446445
user_generator = self.user_generator_prompt,
447446
reflection = self.long_term_reflection_str + self.external_knowledge,
448447
func_signature1 = func_signature1,
@@ -453,10 +452,7 @@ def mutate(self) -> list[dict]:
453452
if self.print_mutate_prompt:
454453
logging.info("Mutation Prompt: \nSystem Prompt: \n" + system + "\nUser Prompt: \n" + user)
455454
self.print_mutate_prompt = False
456-
if self.mutataion_llm is None:
457-
responses = self.generator_llm.multi_chat_completion([messages], int(self.cfg.pop_size * self.mutation_rate))
458-
else:
459-
responses = self.mutataion_llm.multi_chat_completion([messages], int(self.cfg.pop_size * self.mutation_rate))
455+
responses = self.mutation_llm.multi_chat_completion([messages], int(self.cfg.pop_size * self.mutation_rate))
460456
population = [self.response_to_individual(response, response_id) for response_id, response in enumerate(responses)]
461457
return population
462458

0 commit comments

Comments
 (0)