72 changes: 64 additions & 8 deletions align_system/algorithms/llama_2_single_kdma_adm.py
@@ -3,6 +3,7 @@
import random
import os
import pathlib
from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker

from jinja2.exceptions import TemplateError
@@ -113,6 +114,7 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec
self.hf_model = hf_model
self.temperature = temperature
self.chat_template = kwargs.get('chat_template', None)
self.dataset = []

assert precision in ['full', 'half'], "precision must be either 'full' or 'half'."
self.precision = torch.float32 if precision == 'full' else torch.float16
@@ -124,11 +126,11 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec
def load_model(self, model=None, tokenizer=None):
assert (model is None) == (tokenizer is None), "model and tokenizer must both be None or both be not None."
if model is not None:
print('Loading model and tokenizer from provided objects.')
log.info('Loading model and tokenizer from provided objects.')
self.model = model
self.tokenizer = tokenizer
else:
print('Loading model:', self.hf_model)
log.info('Loading model: %s', self.hf_model)
if self.device == 'auto':
self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision, device_map='auto')
else:
@@ -282,7 +284,7 @@ def respond_to_dialog(self, dialog, prefix=None):
else:
new_dialog.append(message)
dialog = new_dialog
print('INPUT\n', dialog)
log.info('INPUT\n %s', dialog)
prompt_tokens = [self.tokenizer.apply_chat_template(dialog, tokenize=True)]
inference_pair['input'] = self.tokenizer.apply_chat_template(dialog, tokenize=False)

@@ -298,11 +300,11 @@ def respond_to_dialog(self, dialog, prefix=None):

outputs = self.model.generate(prompt_tokens, return_dict_in_generate=True, output_scores=True, max_new_tokens=512, temperature=self.temperature, do_sample=True)

# Print the generated model output
        # Log the generated model output
generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:])
inference_pair['output'] = generated_output

print('INFERENCE PAIR\n', inference_pair)
log.info('INFERENCE PAIR\n %s', inference_pair)

return generated_output, inference_pair

@@ -402,6 +404,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam
shuffled_choices,
system_message=system_message)


if not logged_aligned_dialog:
log.debug("[bold]*ALIGNED DIALOG*[/bold]",
extra={"markup": True})
@@ -422,7 +425,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam
if not good_parse:
reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(high_response, shuffled_choices)

print('CHOSEN ANSWER IDX', answer_idx, shuffled_choices)
log.explain('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices)
        assert answer_idx is not None, f'Failed to parse answer index from generated output: {high_response}'

responses.append({
@@ -594,10 +597,10 @@ def parse_generated_output(generated_output, n_choices):

@staticmethod
def bert_similarity_parse(generated_output, choices):
print('BERT SIMILARITY PARSE')
log.info('BERT SIMILARITY PARSE')
force_choice_func = build_force_choice_func('bert')
answer_idx, _ = force_choice_func(generated_output, choices)
print('ANSWER IDX', answer_idx, type(answer_idx))
log.info('ANSWER IDX %s %s', answer_idx, type(answer_idx))
return generated_output, answer_idx, 'bert_similarity'

@staticmethod
@@ -749,12 +752,65 @@ def run_aligned_decision_maker_with_voting(
break

return reasoning, answer_idx, responses, inference_pairs

    def format_single_incontext_prompt(self, sample, labels):
        """Render one labeled sample as an in-context example."""
        prompt = sample['scenario']
        if sample['state'] is not None:
            prompt += f'\n{sample["state"]}'

        # Each label is a single-entry dict mapping a KDMA name to a numeric
        # value; values above 5 are rendered as 'high', the rest as 'low'.
        for choice, label in zip(sample['choices'], labels):
            level = 'high' if list(label.values())[0] > 5 else 'low'
            attribute = list(label.keys())[0].replace('_', ' ')
            prompt += f' If you had a {level} {attribute}, you would select {choice}.'

        return prompt
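For illustration, a minimal sketch of what this helper returns; the field names mirror the code above, while the scenario text, choices, and the 'risk_aversion' KDMA values are hypothetical:

```python
# Hypothetical inputs; field names mirror format_single_incontext_prompt above.
sample = {
    'scenario': 'Two casualties after a blast.',
    'state': 'You have one tourniquet.',
    'choices': ['Treat casualty A', 'Treat casualty B'],
}
labels = [{'risk_aversion': 8}, {'risk_aversion': 2}]

# format_single_incontext_prompt(sample, labels) would return:
# "Two casualties after a blast.
#  You have one tourniquet.
#   If you had a high risk aversion, you would select Treat casualty A.
#   If you had a low risk aversion, you would select Treat casualty B."
```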


    # TODO: add prompt completion here for the choices as well.


    def __call__(self, sample, target_kdma_values, **kwargs):
        """Build the prompt and send it to the LLM to ask about a single KDMA."""
        prompt = sample['scenario']
        if sample['state'] is not None:
            prompt += f'\n{sample["state"]}'

        if 'incontext' in kwargs:
            # Each dataset entry is a pair: (sample info, labels). Exclude the
            # probe currently being answered so it cannot serve as its own
            # in-context example.
            possible_samples = []
            for sam in kwargs['dataset']:
                if sam[0]['probe_id'] != sample['probe_id']:
                    possible_samples.append(sam)

            if len(possible_samples) < kwargs['incontext']['number']:
                raise RuntimeError(f'Not enough possible incontext samples to learn from here. '
                                   f'Only {len(possible_samples)} samples while asking for '
                                   f'{kwargs["incontext"]["number"]} in-context samples')

            if kwargs['incontext']['method'] == 'random':
                chosen_samples = random.sample(possible_samples, kwargs['incontext']['number'])
            else:
                raise ValueError(f'"{kwargs["incontext"]["method"]}" is not a valid '
                                 f'incontext method. Please use "random".')

            extra_prompts = [' Here are some examples of similar problems with their attributes. ']
            for ci, (cs, cl) in enumerate(chosen_samples, start=1):
                extra_prompts.append(f' Example {ci}' + self.format_single_incontext_prompt(cs, cl))
            extra_prompts.append(' Given these similar examples, please answer the question '
                                 'for the following scenario. ')

            prompt = ''.join(extra_prompts) + prompt

if 'retriever' in kwargs:
# retriever_prompt = "How would you treat the following injuries: {}".format(prompt)
retriever_prompt = "{} {}".format(prompt, sample['probe'])
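Taken together, the new in-context path can be exercised as below. This is a usage sketch, not code from the PR: the adm, sample, label, and dataset objects are assumed setup, and only the 'incontext' kwarg structure ('number', 'method') comes from the __call__ implementation above:

```python
# Usage sketch (hypothetical objects). Only the 'incontext' kwarg structure
# is taken from __call__ above; everything else is assumed setup.
output = adm(
    sample,                                        # the probe being answered
    target_kdma_values,
    labels=label,                                  # passed by the evaluator
    dataset=dataset,                               # (sample_info, labels) pairs
    incontext={'number': 3, 'method': 'random'},   # 3 random in-context examples
)
```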
2 changes: 1 addition & 1 deletion align_system/evaluation/adm_evaluator.py
@@ -11,7 +11,7 @@ def generate_outputs(dataset, adm, target_kdma_values, **kwargs):
})
continue

outputs.append(adm(input_, target_kdma_values, labels=label, **kwargs))
outputs.append(adm(input_, target_kdma_values, labels=label, dataset=dataset, **kwargs))

return outputs
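Because __call__ indexes each dataset entry as (sample_info, labels) and filters on probe_id, the dataset threaded through here is assumed to look roughly like the sketch below; the concrete field values are hypothetical:

```python
# Assumed dataset shape; __call__ reads entry[0]['probe_id'] and
# format_single_incontext_prompt reads scenario/state/choices.
dataset = [
    ({'probe_id': 'probe-001',
      'scenario': 'Two casualties after a blast.',
      'state': None,
      'choices': ['Treat casualty A', 'Treat casualty B'],
      'probe': 'Who do you treat first?'},
     [{'risk_aversion': 8}, {'risk_aversion': 2}]),
    # ... more (sample_info, labels) pairs
]
```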
