Skip to content

Commit efa604d

Browse files
Update outlines transformers adm to enable demo mode workflow
1 parent e92eac0 commit efa604d

File tree

2 files changed

+125
-38
lines changed

2 files changed

+125
-38
lines changed

align_system/algorithms/outlines_adm.py

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
calculate_votes,
2626
filter_votes_to_responses,
2727
)
28-
from align_system.utils.hydrate_state import hydrate_scenario_state
2928
from align_system.algorithms.abstracts import ActionBasedADM
3029
from align_system.prompt_engineering.outlines_prompts import (
3130
baseline_system_prompt,
@@ -41,29 +40,33 @@
4140
followup_clarify_treatment_from_list,
4241
followup_clarify_tag,
4342
action_choice_json_schema,
43+
action_choice_json_schema_untrimmed,
4444
aid_choice_json_schema,
4545
character_choice_json_schema,
4646
tag_choice_json_schema,
4747
treatment_choice_json_schema,
4848
treatment_choice_from_list_json_schema,
49-
detailed_unstructured_treatment_action_text,
50-
detailed_unstructured_tagging_action_text
49+
5150
)
5251

5352
log = logging.getLogger(__name__)
5453
JSON_HIGHLIGHTER = JSONHighlighter()
5554

5655

5756
class OutlinesTransformersADM(ActionBasedADM):
58-
def __init__(self,
59-
model_name,
60-
device='auto',
61-
baseline=False,
62-
sampler=MultinomialSampler(),
63-
**kwargs):
57+
def __init__(
58+
self,
59+
model_name,
60+
device='auto',
61+
baseline=False,
62+
mode='eval',
63+
sampler=MultinomialSampler(),
64+
**kwargs
65+
):
6466
self.baseline = baseline
6567

6668
model_kwargs = kwargs.get('model_kwargs', {})
69+
self.mode = mode
6770
if 'precision' in kwargs:
6871
if kwargs['precision'] == 'half':
6972
torch_dtype = torch.float16
@@ -87,6 +90,8 @@ def __init__(self,
8790
# the sampler itself (which defaults to 1); setting the number
8891
# of samples in the sampler may result in unexpected behavior
8992
self.sampler = sampler
93+
# Edited prompt from the Demo interface
94+
self._system_ui_prompt = None
9095

9196
def dialog_to_prompt(self, dialog):
9297
tokenizer = self.model.tokenizer.tokenizer
@@ -152,25 +157,35 @@ def batched(cls, iterable, n):
152157
yield batch
153158

154159
@classmethod
155-
def run_in_batches(cls, inference_function, inputs, batch_size):
160+
def run_in_batches(cls, inference_function, inputs, batch_size, max_tokens, seed):
156161
''' Batch inference to avoid out of memory error'''
157162
outputs = []
158163
for batch in cls.batched(inputs, batch_size):
159-
output = inference_function(list(batch))
164+
output = inference_function(
165+
list(batch),
166+
max_tokens=max_tokens,
167+
rng=torch.cuda.manual_seed(seed)
168+
)
160169
if not isinstance(output, list):
161170
output = [output]
162171
outputs.extend(output)
163172
return outputs
164173

165-
def top_level_choose_action(self,
166-
scenario_state,
167-
available_actions,
168-
alignment_target,
169-
num_positive_samples=1,
170-
num_negative_samples=0,
171-
generator_batch_size=5,
172-
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
173-
**kwargs):
174+
@property
175+
def system_ui_prompt(self) -> str:
176+
return self._system_ui_prompt
177+
178+
@system_ui_prompt.setter
179+
def system_ui_prompt(self, edited_system_prompt: str):
180+
self._system_ui_prompt = edited_system_prompt
181+
182+
def get_dialog_texts(self, scenario_state,
183+
available_actions,
184+
alignment_target,
185+
num_positive_samples=1,
186+
num_negative_samples=0,
187+
kdma_descriptions_map="align_system/prompt_engineering/kdma_descriptions.yml",
188+
**kwargs):
174189
if self.baseline and num_negative_samples > 0:
175190
raise RuntimeError("No notion of negative samples for baseline run")
176191
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
@@ -185,6 +200,8 @@ def top_level_choose_action(self,
185200
available_actions,
186201
scenario_state
187202
)
203+
# Sort the choices
204+
choices = sorted(choices)
188205

189206
positive_icl_examples = []
190207
negative_icl_examples = []
@@ -224,8 +241,11 @@ def top_level_choose_action(self,
224241

225242
# Create positive ICL example generators
226243
positive_target = {'kdma': kdma, 'name': name, 'value': value}
227-
positive_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
228-
[positive_target])
244+
positive_icl_example_generator = (
245+
incontext_utils.BaselineIncontextExampleGenerator(
246+
incontext_settings, [positive_target]
247+
)
248+
)
229249
# Get subset of relevant of examples
230250
positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
231251
sys_kdma_name=kdma,
@@ -244,8 +264,11 @@ def top_level_choose_action(self,
244264
if num_negative_samples > 0:
245265
# Create negative ICL example generators
246266
negative_target = {'kdma': kdma, 'name': name, 'value': negative_value}
247-
negative_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
248-
[negative_target])
267+
negative_icl_example_generator = (
268+
incontext_utils.BaselineIncontextExampleGenerator(
269+
incontext_settings, [negative_target]
270+
)
271+
)
249272
# Get subset of relevant of examples
250273
negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
251274
sys_kdma_name=kdma,
@@ -265,10 +288,13 @@ def top_level_choose_action(self,
265288
if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
266289
raise RuntimeError("No notion of incontext examples for baseline run")
267290

291+
shuffled_choices = choices
268292
positive_dialogs = []
269293
for _ in range(num_positive_samples):
270-
shuffled_choices = random.sample(choices, len(choices))
271-
294+
if kwargs["demo_kwargs"]["shuffle_choices"]:
295+
shuffled_choices = random.sample(choices, len(choices))
296+
if self.system_ui_prompt is not None and self.mode == "demo":
297+
positive_system_prompt = self.system_ui_prompt
272298
prompt = action_selection_prompt(scenario_description, shuffled_choices)
273299
dialog = [{'role': 'system', 'content': positive_system_prompt}]
274300
dialog.extend(positive_icl_examples)
@@ -278,8 +304,8 @@ def top_level_choose_action(self,
278304

279305
negative_dialogs = []
280306
for _ in range(num_negative_samples):
281-
shuffled_choices = random.sample(choices, len(choices))
282-
307+
if kwargs["demo_kwargs"]["shuffle_choices"]:
308+
shuffled_choices = random.sample(choices, len(choices))
283309
prompt = action_selection_prompt(scenario_description, shuffled_choices)
284310
dialog = [{'role': 'system', 'content': negative_system_prompt}]
285311
dialog.extend(negative_icl_examples)
@@ -290,20 +316,66 @@ def top_level_choose_action(self,
290316
# Need to set the whitespace_pattern to prevent the state
291317
# machine from looping indefinitely in some cases, see:
292318
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
293-
generator = outlines.generate.json(
294-
self.model,
295-
action_choice_json_schema(json.dumps(choices)),
296-
sampler=self.sampler,
297-
whitespace_pattern=r"[ ]?")
298319

299-
dialog_texts = [self.dialog_to_prompt(d) for d in
300-
itertools.chain(positive_dialogs, negative_dialogs)]
320+
dialog_texts = [
321+
self.dialog_to_prompt(d)
322+
for d in itertools.chain(positive_dialogs, negative_dialogs)
323+
]
301324

302-
log.info("[bold]*DIALOG PROMPT*[/bold]",
303-
extra={"markup": True})
325+
log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True})
304326
log.info(dialog_texts[0])
305327

306-
responses = self.run_in_batches(generator, dialog_texts, generator_batch_size)
328+
return dialog_texts, positive_dialogs
329+
330+
def top_level_choose_action(self,
331+
scenario_state,
332+
available_actions,
333+
alignment_target,
334+
num_positive_samples=1,
335+
num_negative_samples=0,
336+
generator_batch_size=5,
337+
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
338+
**kwargs):
339+
if 'demo_kwargs' not in kwargs and self.mode == "demo":
340+
raise ValueError('Demo configuration missing')
341+
342+
demo_kwargs = kwargs["demo_kwargs"]
343+
choices = adm_utils.format_choices(
344+
[a.unstructured for a in available_actions],
345+
available_actions,
346+
scenario_state,
347+
)
348+
349+
dialog_texts, positive_dialogs = self.get_dialog_texts(
350+
scenario_state,
351+
available_actions,
352+
alignment_target,
353+
num_positive_samples=num_positive_samples,
354+
num_negative_samples=num_negative_samples,
355+
generator_batch_size=generator_batch_size,
356+
kdma_descriptions_map=kdma_descriptions_map,
357+
**kwargs
358+
)
359+
if self.mode == "eval":
360+
generator = outlines.generate.json(
361+
self.model,
362+
action_choice_json_schema(json.dumps(choices)),
363+
sampler=self.sampler,
364+
whitespace_pattern=r"[ ]?")
365+
else:
366+
generator = outlines.generate.json(
367+
self.model,
368+
action_choice_json_schema_untrimmed(json.dumps(choices)),
369+
sampler=self.sampler,
370+
whitespace_pattern=r"[ ]?",
371+
)
372+
responses = self.run_in_batches(
373+
inference_function=generator,
374+
inputs=dialog_texts,
375+
batch_size=generator_batch_size,
376+
max_tokens=int(demo_kwargs["max_generator_tokens"]),
377+
seed=int(demo_kwargs["generator_seed"]),
378+
)
307379
positive_responses_choices =\
308380
[r['action_choice'] for r in
309381
responses[0:num_positive_samples]]

align_system/prompt_engineering/outlines_prompts.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,21 @@ def action_choice_json_schema(choices_json_str):
347347
'''
348348

349349

350+
@outlines.prompt
351+
def action_choice_json_schema_untrimmed(choices_json_str):
352+
'''
353+
{"$defs": {"ActionChoice": {"enum": {{ choices_json_str }},
354+
"title": "ActionChoice",
355+
"type": "string"}},
356+
"properties": {"detailed_reasoning": {"title": "Detailed Reasoning",
357+
"type": "string", "minLength": 1},
358+
"action_choice": {"$ref": "#/$defs/ActionChoice"}},
359+
"required": ["detailed_reasoning", "action_choice"],
360+
"title": "ActionSelection",
361+
"type": "object"}
362+
'''
363+
364+
350365
@outlines.prompt
351366
def character_choice_json_schema(choices_json_str):
352367
'''

0 commit comments

Comments
 (0)