@@ -28,10 +28,10 @@ def __init__(
28
28
self .generator_llm = generator_llm
29
29
self .reflector_llm = reflector_llm or generator_llm
30
30
31
- self .short_reflector_llm = short_reflector_llm
32
- self .long_reflector_llm = long_reflector_llm
33
- self .crossover_llm = crossover_llm
34
- self .mutataion_llm = mutation_llm
31
+ self .short_reflector_llm = short_reflector_llm or self . reflector_llm
32
+ self .long_reflector_llm = long_reflector_llm or self . reflector_llm
33
+ self .crossover_llm = crossover_llm or generator_llm
34
+ self .mutation_llm = mutation_llm or generator_llm
35
35
36
36
self .root_dir = root_dir
37
37
@@ -83,7 +83,7 @@ def init_prompt(self) -> None:
83
83
self .user_reflector_st_prompt = file_to_string (f'{ self .prompt_dir } /common/user_reflector_st.txt' ) if self .problem_type != "black_box" else file_to_string (f'{ self .prompt_dir } /common/user_reflector_st_black_box.txt' ) # shrot-term reflection
84
84
self .user_reflector_lt_prompt = file_to_string (f'{ self .prompt_dir } /common/user_reflector_lt.txt' ) # long-term reflection
85
85
self .crossover_prompt = file_to_string (f'{ self .prompt_dir } /common/crossover.txt' )
86
- self .mutataion_prompt = file_to_string (f'{ self .prompt_dir } /common/mutation.txt' )
86
+ self .mutation_prompt = file_to_string (f'{ self .prompt_dir } /common/mutation.txt' )
87
87
self .user_generator_prompt = file_to_string (f'{ self .prompt_dir } /common/user_generator.txt' ).format (
88
88
func_name = self .func_name ,
89
89
problem_desc = self .problem_desc ,
@@ -373,7 +373,7 @@ def short_term_reflection(self, population: list[dict]) -> tuple[list[list[dict]
373
373
better_code_lst .append (better_code )
374
374
375
375
# Asynchronously generate responses
376
- response_lst = self .reflector_llm . multi_chat_completion ( messages_lst ) if self . short_reflector_llm is None else self . short_reflector_llm .multi_chat_completion (messages_lst )
376
+ response_lst = self .short_reflector_llm .multi_chat_completion (messages_lst )
377
377
return response_lst , worse_code_lst , better_code_lst
378
378
379
379
def long_term_reflection (self , short_term_reflections : list [str ]) -> None :
@@ -392,7 +392,7 @@ def long_term_reflection(self, short_term_reflections: list[str]) -> None:
392
392
logging .info ("Long-term Reflection Prompt: \n System Prompt: \n " + system + "\n User Prompt: \n " + user )
393
393
self .print_long_term_reflection_prompt = False
394
394
395
- self .long_term_reflection_str = self .reflector_llm . multi_chat_completion ([ messages ])[ 0 ] if self . long_reflector_llm is None else self . long_reflector_llm .multi_chat_completion ([messages ])[0 ]
395
+ self .long_term_reflection_str = self .long_reflector_llm .multi_chat_completion ([messages ])[0 ]
396
396
397
397
# Write reflections to file
398
398
file_name = f"problem_iter{ self .iteration } _short_term_reflections.txt"
@@ -430,8 +430,7 @@ def crossover(self, short_term_reflection_tuple: tuple[list[list[dict]], list[st
430
430
self .print_crossover_prompt = False
431
431
432
432
# Asynchronously generate responses
433
- # use crossover_llm if it's not None
434
- response_lst = self .generator_llm .multi_chat_completion (messages_lst ) if self .crossover_llm is None else self .crossover_llm .multi_chat_completion (messages_lst )
433
+ response_lst = self .crossover_llm .multi_chat_completion (messages_lst )
435
434
crossed_population = [self .response_to_individual (response , response_id ) for response_id , response in enumerate (response_lst )]
436
435
437
436
assert len (crossed_population ) == self .cfg .pop_size
@@ -442,7 +441,7 @@ def mutate(self) -> list[dict]:
442
441
"""Elitist-based mutation. We only mutate the best individual to generate n_pop new individuals."""
443
442
system = self .system_generator_prompt
444
443
func_signature1 = self .func_signature .format (version = 1 )
445
- user = self .mutataion_prompt .format (
444
+ user = self .mutation_prompt .format (
446
445
user_generator = self .user_generator_prompt ,
447
446
reflection = self .long_term_reflection_str + self .external_knowledge ,
448
447
func_signature1 = func_signature1 ,
@@ -453,10 +452,7 @@ def mutate(self) -> list[dict]:
453
452
if self .print_mutate_prompt :
454
453
logging .info ("Mutation Prompt: \n System Prompt: \n " + system + "\n User Prompt: \n " + user )
455
454
self .print_mutate_prompt = False
456
- if self .mutataion_llm is None :
457
- responses = self .generator_llm .multi_chat_completion ([messages ], int (self .cfg .pop_size * self .mutation_rate ))
458
- else :
459
- responses = self .mutataion_llm .multi_chat_completion ([messages ], int (self .cfg .pop_size * self .mutation_rate ))
455
+ responses = self .mutation_llm .multi_chat_completion ([messages ], int (self .cfg .pop_size * self .mutation_rate ))
460
456
population = [self .response_to_individual (response , response_id ) for response_id , response in enumerate (responses )]
461
457
return population
462
458
0 commit comments