     calculate_votes,
     filter_votes_to_responses,
 )
-from align_system.utils.hydrate_state import hydrate_scenario_state
 from align_system.algorithms.abstracts import ActionBasedADM
 from align_system.prompt_engineering.outlines_prompts import (
     baseline_system_prompt,
@@ -41,29 +40,33 @@
     followup_clarify_treatment_from_list,
     followup_clarify_tag,
     action_choice_json_schema,
+    action_choice_json_schema_untrimmed,
     aid_choice_json_schema,
     character_choice_json_schema,
     tag_choice_json_schema,
     treatment_choice_json_schema,
     treatment_choice_from_list_json_schema,
-    detailed_unstructured_treatment_action_text,
-    detailed_unstructured_tagging_action_text
+
 )
 
 log = logging.getLogger(__name__)
 JSON_HIGHLIGHTER = JSONHighlighter()
 
 
 class OutlinesTransformersADM(ActionBasedADM):
-    def __init__(self,
-                 model_name,
-                 device='auto',
-                 baseline=False,
-                 sampler=MultinomialSampler(),
-                 **kwargs):
+    def __init__(
+        self,
+        model_name,
+        device='auto',
+        baseline=False,
+        mode='eval',
+        sampler=MultinomialSampler(),
+        **kwargs
+    ):
         self.baseline = baseline
 
         model_kwargs = kwargs.get('model_kwargs', {})
+        self.mode = mode
        if 'precision' in kwargs:
            if kwargs['precision'] == 'half':
                torch_dtype = torch.float16
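A quick usage sketch of the new constructor arguments; the model name and prompt text here are hypothetical placeholders, not values from this commit:

```python
# Hypothetical usage of the new `mode` kwarg and the demo prompt override.
adm = OutlinesTransformersADM(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
    device="cuda",
    baseline=False,
    mode="demo",  # 'eval' (the default) keeps the original behavior
)
# Demo-only override; set via the property added further down in this diff.
adm.system_ui_prompt = "You are an assistant specialized in triage..."
```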
@@ -87,6 +90,8 @@ def __init__(self,
         # the sampler itself (which defaults to 1); setting the number
         # of samples in the sampler may result in unexpected behavior
         self.sampler = sampler
+        # System prompt edited via the Demo interface (None until set)
+        self._system_ui_prompt = None
 
     def dialog_to_prompt(self, dialog):
         tokenizer = self.model.tokenizer.tokenizer
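The body of `dialog_to_prompt` is elided from this diff; it presumably flattens the chat-format message list into a single prompt string with the Hugging Face chat template, roughly like this sketch (the model name is a placeholder):

```python
from transformers import AutoTokenizer

# Rough sketch of what dialog_to_prompt plausibly does: flatten a
# [{'role': ..., 'content': ...}, ...] dialog into one prompt string.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder
dialog = [
    {"role": "system", "content": "You are a medical triage assistant."},
    {"role": "user", "content": "Which casualty should be treated first?"},
]
prompt_text = tokenizer.apply_chat_template(
    dialog,
    tokenize=False,              # return a string rather than token ids
    add_generation_prompt=True,  # append the assistant-turn header
)
```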
@@ -152,25 +157,35 @@ def batched(cls, iterable, n):
             yield batch
 
     @classmethod
-    def run_in_batches(cls, inference_function, inputs, batch_size):
+    def run_in_batches(cls, inference_function, inputs, batch_size,
+                       max_tokens=None, seed=None):
         '''Batch inference to avoid out-of-memory errors'''
+        # torch.cuda.manual_seed() returns None rather than a generator, so
+        # build an explicit torch.Generator (the pattern outlines documents
+        # for seeded sampling) when a reproducible seed is requested
+        rng = None
+        if seed is not None:
+            rng = torch.Generator(device='cuda')
+            rng.manual_seed(seed)
         outputs = []
         for batch in cls.batched(inputs, batch_size):
-            output = inference_function(list(batch))
+            output = inference_function(
+                list(batch),
+                max_tokens=max_tokens,
+                rng=rng
+            )
             if not isinstance(output, list):
                 output = [output]
             outputs.extend(output)
         return outputs
 
-    def top_level_choose_action(self,
-                                scenario_state,
-                                available_actions,
-                                alignment_target,
-                                num_positive_samples=1,
-                                num_negative_samples=0,
-                                generator_batch_size=5,
-                                kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
-                                **kwargs):
+    @property
+    def system_ui_prompt(self) -> str:
+        return self._system_ui_prompt
+
+    @system_ui_prompt.setter
+    def system_ui_prompt(self, edited_system_prompt: str):
+        self._system_ui_prompt = edited_system_prompt
+
+    def get_dialog_texts(self,
+                         scenario_state,
+                         available_actions,
+                         alignment_target,
+                         num_positive_samples=1,
+                         num_negative_samples=0,
+                         kdma_descriptions_map="align_system/prompt_engineering/kdma_descriptions.yml",
+                         **kwargs):
         if self.baseline and num_negative_samples > 0:
             raise RuntimeError("No notion of negative samples for baseline run")
         if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
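A usage sketch of the updated `run_in_batches` signature; `generator` stands in for an outlines generator built as in `top_level_choose_action` below:

```python
# `generator` is assumed to be an outlines JSON generator; the prompts
# are placeholders purely to illustrate the batching arithmetic.
prompts = [f"prompt {i}" for i in range(12)]
responses = OutlinesTransformersADM.run_in_batches(
    generator,
    prompts,
    batch_size=5,    # 12 prompts -> batches of 5, 5, and 2
    max_tokens=512,  # cap on generated tokens per response
    seed=42,         # same seed -> reproducible sampling across runs
)
assert len(responses) == len(prompts)
```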
@@ -185,6 +200,8 @@ def top_level_choose_action(self,
             available_actions,
             scenario_state
         )
+        # Sort the choices for a deterministic ordering across runs
+        choices = sorted(choices)
 
         positive_icl_examples = []
         negative_icl_examples = []
@@ -224,8 +241,11 @@ def top_level_choose_action(self,
 
             # Create positive ICL example generators
             positive_target = {'kdma': kdma, 'name': name, 'value': value}
-            positive_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
-                                                                                               [positive_target])
+            positive_icl_example_generator = (
+                incontext_utils.BaselineIncontextExampleGenerator(
+                    incontext_settings, [positive_target]
+                )
+            )
             # Get the subset of relevant examples
             positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
                 sys_kdma_name=kdma,
@@ -244,8 +264,11 @@ def top_level_choose_action(self,
             if num_negative_samples > 0:
                 # Create negative ICL example generators
                 negative_target = {'kdma': kdma, 'name': name, 'value': negative_value}
-                negative_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
-                                                                                                   [negative_target])
+                negative_icl_example_generator = (
+                    incontext_utils.BaselineIncontextExampleGenerator(
+                        incontext_settings, [negative_target]
+                    )
+                )
                 # Get the subset of relevant examples
                 negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
                     sys_kdma_name=kdma,
@@ -265,10 +288,13 @@ def top_level_choose_action(self,
             if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
                 raise RuntimeError("No notion of incontext examples for baseline run")
 
+        shuffled_choices = choices
         positive_dialogs = []
         for _ in range(num_positive_samples):
-            shuffled_choices = random.sample(choices, len(choices))
-
+            # Shuffle by default (the pre-existing eval behavior) unless the
+            # demo toggle disables it; .get() avoids a KeyError when
+            # 'demo_kwargs' is absent in eval mode
+            if kwargs.get("demo_kwargs", {}).get("shuffle_choices", True):
+                shuffled_choices = random.sample(choices, len(choices))
+            if self.system_ui_prompt is not None and self.mode == "demo":
+                positive_system_prompt = self.system_ui_prompt
             prompt = action_selection_prompt(scenario_description, shuffled_choices)
             dialog = [{'role': 'system', 'content': positive_system_prompt}]
             dialog.extend(positive_icl_examples)
@@ -278,8 +304,8 @@ def top_level_choose_action(self,
 
         negative_dialogs = []
         for _ in range(num_negative_samples):
-            shuffled_choices = random.sample(choices, len(choices))
-
+            if kwargs.get("demo_kwargs", {}).get("shuffle_choices", True):
+                shuffled_choices = random.sample(choices, len(choices))
             prompt = action_selection_prompt(scenario_description, shuffled_choices)
             dialog = [{'role': 'system', 'content': negative_system_prompt}]
             dialog.extend(negative_icl_examples)
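For reference, each dialog assembled above is a plain chat-format message list; a minimal sketch of the shape handed to `dialog_to_prompt` (the in-context turn contents are illustrative assumptions, not the example generators' actual output):

```python
dialog = [
    {"role": "system", "content": "You are an assistant specialized in triage..."},
    # zero or more in-context example turns, assumed user/assistant pairs
    {"role": "user", "content": "Scenario: ...\nActions: ..."},
    {"role": "assistant", "content": '{"action_choice": "Treat Casualty A"}'},
    # the live prompt built by action_selection_prompt(...)
    {"role": "user", "content": "Scenario: ...\nCHOOSE ONE ACTION: ..."},
]
```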
@@ -290,20 +316,66 @@ def top_level_choose_action(self,
         # Need to set the whitespace_pattern to prevent the state
         # machine from looping indefinitely in some cases, see:
         # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
-        generator = outlines.generate.json(
-            self.model,
-            action_choice_json_schema(json.dumps(choices)),
-            sampler=self.sampler,
-            whitespace_pattern=r"[ ]?")
 
-        dialog_texts = [self.dialog_to_prompt(d) for d in
-                        itertools.chain(positive_dialogs, negative_dialogs)]
+        dialog_texts = [
+            self.dialog_to_prompt(d)
+            for d in itertools.chain(positive_dialogs, negative_dialogs)
+        ]
 
-        log.info("[bold]*DIALOG PROMPT*[/bold]",
-                 extra={"markup": True})
+        log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True})
         log.info(dialog_texts[0])
 
-        responses = self.run_in_batches(generator, dialog_texts, generator_batch_size)
+        return dialog_texts, positive_dialogs
+
+    def top_level_choose_action(self,
+                                scenario_state,
+                                available_actions,
+                                alignment_target,
+                                num_positive_samples=1,
+                                num_negative_samples=0,
+                                generator_batch_size=5,
+                                kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
+                                **kwargs):
+        if self.mode == "demo" and 'demo_kwargs' not in kwargs:
+            raise ValueError("Demo mode requires a 'demo_kwargs' configuration")
+
+        # 'demo_kwargs' may be absent in eval mode, so fall back to an
+        # empty dict rather than raising a KeyError
+        demo_kwargs = kwargs.get("demo_kwargs", {})
+        choices = adm_utils.format_choices(
+            [a.unstructured for a in available_actions],
+            available_actions,
+            scenario_state,
+        )
+
+        dialog_texts, positive_dialogs = self.get_dialog_texts(
+            scenario_state,
+            available_actions,
+            alignment_target,
+            num_positive_samples=num_positive_samples,
+            num_negative_samples=num_negative_samples,
+            generator_batch_size=generator_batch_size,
+            kdma_descriptions_map=kdma_descriptions_map,
+            **kwargs
+        )
+
+        # Demo mode uses the untrimmed schema variant, presumably so more
+        # of the model's output can be surfaced in the Demo interface
+        schema_fn = (action_choice_json_schema if self.mode == "eval"
+                     else action_choice_json_schema_untrimmed)
+        generator = outlines.generate.json(
+            self.model,
+            schema_fn(json.dumps(choices)),
+            sampler=self.sampler,
+            whitespace_pattern=r"[ ]?",
+        )
+
+        max_generator_tokens = demo_kwargs.get("max_generator_tokens")
+        generator_seed = demo_kwargs.get("generator_seed")
+        responses = self.run_in_batches(
+            inference_function=generator,
+            inputs=dialog_texts,
+            batch_size=generator_batch_size,
+            max_tokens=int(max_generator_tokens) if max_generator_tokens is not None else None,
+            seed=int(generator_seed) if generator_seed is not None else None,
+        )
         positive_responses_choices = \
             [r['action_choice'] for r in
              responses[0:num_positive_samples]]
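For readers unfamiliar with outlines' structured generation: `outlines.generate.json` constrains decoding to a JSON schema, which is why `r['action_choice']` above is guaranteed to be one of the formatted choices. A hedged sketch of what a choice-constrained schema such as `action_choice_json_schema(json.dumps(choices))` plausibly returns; the exact fields are assumptions inferred from that access, and the real function may include additional properties:

```python
import json

def action_choice_json_schema_sketch(choices_json: str) -> str:
    # Hypothetical reconstruction: constrain "action_choice" to the
    # formatted choice strings so the model must pick a valid action.
    return json.dumps({
        "type": "object",
        "properties": {
            "action_choice": {"type": "string", "enum": json.loads(choices_json)},
        },
        "required": ["action_choice"],
    })
```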