diff --git a/align_system/algorithms/icl_adm_component.py b/align_system/algorithms/icl_adm_component.py index 6ebe198d..a6b4fd6c 100644 --- a/align_system/algorithms/icl_adm_component.py +++ b/align_system/algorithms/icl_adm_component.py @@ -78,6 +78,7 @@ def run(self, # Convert alignment target into kdma values (all that's needed # for building the icl engines, and need something that's # hashable for caching, dicts aren't hashable) + kdma_values = None for target_kdma_value in alignment_target_dict['kdma_values']: if attribute.kdma == target_kdma_value['kdma']: # tuple of tuples; when initializing the icl @@ -85,9 +86,8 @@ def run(self, # mutable arguments such as lists, need to use # tuple kdma_values = ((attribute.kdma, target_kdma_value['value']),) - else: - if self.predict_medical_urgency: - kdma_values = (('medical', 1.0),) + if not kdma_values and self.predict_medical_urgency: + kdma_values = (('medical', 1.0),) icl_gen = init_icl_engine_from_target( self.icl_generator_partial, diff --git a/align_system/algorithms/misc_itm_adm_components.py b/align_system/algorithms/misc_itm_adm_components.py index e9105287..4df9af61 100644 --- a/align_system/algorithms/misc_itm_adm_components.py +++ b/align_system/algorithms/misc_itm_adm_components.py @@ -1,4 +1,5 @@ from rich.highlighter import JSONHighlighter +import numpy as np from align_system.algorithms.abstracts import ADMComponent from align_system.utils import adm_utils, logging @@ -162,3 +163,57 @@ def run_returns(self): def run(self): return "Looked at scores." 
class Phase2RegressionRemoveIrrelevantAttributes(ADMComponent):
    """Drop every non-medical attribute except the most relevant one.

    Relevance of an attribute is the absolute difference between the two
    choices' mean predicted scores for that attribute; the attribute with
    the largest difference is kept.  The 'medical' attribute is never
    filtered out.
    """

    def run_returns(self):
        """Return the names of the values `run` returns, in order."""
        return ('attribute_prediction_scores',
                'alignment_target')

    def run(self,
            attribute_prediction_scores,
            alignment_target):
        """Filter predicted scores and the alignment target by relevance.

        Args:
            attribute_prediction_scores: mapping of choice -> mapping of
                attribute name -> predicted score(s).  Scores are averaged
                with numpy, so a single number or a sequence of numbers
                both work.
            alignment_target: mapping with a 'kdma_values' list whose
                entries each have a 'kdma' key; it must support `.copy()`.
                NOTE(review): a plain dict satisfies this -- confirm what
                the pipeline actually passes in.

        Returns:
            Tuple `(scores, target)`.  When there is at most one
            non-medical attribute both inputs are returned untouched;
            otherwise new objects restricted to 'medical' (if present) and
            the most relevant attribute are returned.

        Raises:
            RuntimeError: if filtering is needed but there are more than
                two choices, or if no attribute has a comparable
                (non-NaN) delta.
            ValueError: if filtering is needed but there are fewer than
                two choices (tuple unpacking of the choice keys fails).
            KeyError: if an attribute is scored for only one of the two
                choices.
        """
        # Collect attribute names in first-seen order.  A plain set would
        # make the iteration order (and therefore tie-breaking between
        # equal deltas) depend on string hash randomisation.
        attributes = list(dict.fromkeys(
            key
            for inner in attribute_prediction_scores.values()
            for key in inner))

        # Ignore / don't filter out medical
        keep_attributes = []
        if 'medical' in attributes:
            attributes.remove('medical')
            keep_attributes.append('medical')

        # Zero or one attribute aside from medical -> filtering not needed.
        # (The zero case previously fell through and crashed on an unbound
        # `relevant_attribute`.)
        if len(attributes) <= 1:
            return attribute_prediction_scores, alignment_target

        # Two or more attributes -> keep the one with largest delta
        if len(attribute_prediction_scores) > 2:
            raise RuntimeError("Relevance filtering not implemented for more than two choices.")

        choice_a, choice_b = list(attribute_prediction_scores.keys())
        scores_a = attribute_prediction_scores[choice_a]
        scores_b = attribute_prediction_scores[choice_b]

        # Determine most relevant attribute to keep.  Strict '>' means the
        # first attribute seen wins a tie.
        relevant_attribute = None
        max_delta = -np.inf
        for attr in attributes:
            delta = abs(np.array(scores_a[attr]).mean()
                        - np.array(scores_b[attr]).mean())
            if delta > max_delta:
                max_delta = delta
                relevant_attribute = attr

        # A NaN delta never compares greater, so nothing may be selected.
        if relevant_attribute is None:
            raise RuntimeError(
                "Could not determine the most relevant attribute; no "
                "comparable score delta among: {}".format(attributes))
        keep_attributes.append(relevant_attribute)

        # Update predicted scores to only have more relevant attribute
        filtered_attribute_prediction_scores = {
            choice: {keep_attr: attribute_prediction_scores[choice][keep_attr]
                     for keep_attr in keep_attributes}
            for choice in (choice_a, choice_b)
        }

        log.info("[bold]*FILTERING OUT ATTRIBUTES EXCEPT MOST RELEVANT: {}*[/bold]".format(relevant_attribute), extra={"markup": True})
        log.info("Retained:{}".format(filtered_attribute_prediction_scores), extra={"highlighter": JSON_HIGHLIGHTER})

        # Update target to only include relevant attribute.  The copy keeps
        # the caller's target object unmodified; only the 'kdma_values'
        # entry is replaced on the copy.
        filtered_alignment_target = alignment_target.copy()
        filtered_alignment_target['kdma_values'] = [
            entry for entry in alignment_target['kdma_values']
            if entry['kdma'] in keep_attributes
        ]
        log.info("[bold]*UPDATING ALIGNMENT TARGET*[/bold]", extra={"markup": True})
        log.info("{}".format(filtered_alignment_target), extra={"highlighter": JSON_HIGHLIGHTER})

        return filtered_attribute_prediction_scores, filtered_alignment_target
/adm_component/misc@step_definitions.remove_irrelevant_attributes: phase2_regression_remove_irrelevant_attributes + - /adm_component/alignment@step_definitions.scalar_alignment: medical_urgency_scalar + - /adm_component/misc@step_definitions.justification_from_reasonings: justification_from_reasonings + - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action + - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info + # Use definitions in this file to override defaults defined above + - _self_ + +attribute_definitions: + medical: ${adm.mu} + affiliation: ${adm.af} + merit: ${adm.mf} + search: ${adm.ss} + personal_safety: ${adm.ps} + +step_definitions: + regression_icl: + scenario_description_template: ${ref:adm.scenario_description_template} + attributes: ${adm.attribute_definitions} + prompt_template: ${ref:adm.prompt_template} + + comparative_regression: + scenario_description_template: ${ref:adm.scenario_description_template} + prompt_template: ${ref:adm.prompt_template} + score_schema_template: ${adm.comparative_regression_choice_schema} + +instance: + _target_: align_system.algorithms.pipeline_adm.PipelineADM + + steps: + # Reference the step instances we want to use in order + - ${ref:adm.step_definitions.format_choices} + - ${ref:adm.step_definitions.regression_icl} + - ${ref:adm.step_definitions.comparative_regression} + - ${ref:adm.step_definitions.regression_rule_based_correction} + - ${ref:adm.step_definitions.remove_irrelevant_attributes} + - ${ref:adm.step_definitions.scalar_alignment} + - ${ref:adm.step_definitions.justification_from_reasonings} + - ${ref:adm.step_definitions.ensure_chosen_action} + - ${ref:adm.step_definitions.populate_choice_info} diff --git a/align_system/configs/adm/phase2_pipeline_zeroshot_comparative_regression_delta_relevance.yaml b/align_system/configs/adm/phase2_pipeline_zeroshot_comparative_regression_delta_relevance.yaml new file mode 100644 index 00000000..4e77bf42 --- 
/dev/null +++ b/align_system/configs/adm/phase2_pipeline_zeroshot_comparative_regression_delta_relevance.yaml @@ -0,0 +1,54 @@ +name: phase2_pipeline_zeroshot_comparative_regression_delta_relevance + +defaults: + # Import defaults into this namespace (adm) as @name, for further + # customization + + # Shared variables / components + - /attribute@mu: medical_urgency + - /attribute@af: affiliation_focus + - /attribute@mf: merit_focus + - /attribute@ss: search_or_stay + - /attribute@ps: personal_safety + - /inference_engine@structured_inference_engine: outlines_structured_greedy + - /template/scenario_description@scenario_description_template: phase2 + - /template/prompt@prompt_template: phase2_comparative_regression + - /template/output_schema@comparative_regression_choice_schema: phase2_comparative_regression_choice + # ADM components to be used in "steps" + - /adm_component/misc@step_definitions.format_choices: itm_format_choices + - /adm_component/regression@step_definitions.comparative_regression: phase2_comparative + - /adm_component/misc@step_definitions.regression_rule_based_correction: phase2_regression_rule_based_correction + - /adm_component/misc@step_definitions.remove_irrelevant_attributes: phase2_regression_remove_irrelevant_attributes + - /adm_component/alignment@step_definitions.scalar_alignment: medical_urgency_scalar + - /adm_component/misc@step_definitions.justification_from_reasonings: justification_from_reasonings + - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action + - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info + # Use definitions in this file to override defaults defined above + - _self_ + +attribute_definitions: + medical: ${adm.mu} + affiliation: ${adm.af} + merit: ${adm.mf} + search: ${adm.ss} + personal_safety: ${adm.ps} + +step_definitions: + comparative_regression: + scenario_description_template: ${ref:adm.scenario_description_template} + prompt_template: 
${ref:adm.prompt_template} + score_schema_template: ${adm.comparative_regression_choice_schema} + +instance: + _target_: align_system.algorithms.pipeline_adm.PipelineADM + + steps: + # Reference the step instances we want to use in order + - ${ref:adm.step_definitions.format_choices} + - ${ref:adm.step_definitions.comparative_regression} + - ${ref:adm.step_definitions.regression_rule_based_correction} + - ${ref:adm.step_definitions.remove_irrelevant_attributes} + - ${ref:adm.step_definitions.scalar_alignment} + - ${ref:adm.step_definitions.justification_from_reasonings} + - ${ref:adm.step_definitions.ensure_chosen_action} + - ${ref:adm.step_definitions.populate_choice_info} diff --git a/align_system/configs/adm_component/misc/phase2_regression_remove_irrelevant_attributes.yaml b/align_system/configs/adm_component/misc/phase2_regression_remove_irrelevant_attributes.yaml new file mode 100644 index 00000000..af97247e --- /dev/null +++ b/align_system/configs/adm_component/misc/phase2_regression_remove_irrelevant_attributes.yaml @@ -0,0 +1 @@ +_target_: align_system.algorithms.misc_itm_adm_components.Phase2RegressionRemoveIrrelevantAttributes diff --git a/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_fewshot_comparative_regression_loo.yaml b/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_fewshot_comparative_regression_loo.yaml new file mode 100644 index 00000000..d6dfe295 --- /dev/null +++ b/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_fewshot_comparative_regression_loo.yaml @@ -0,0 +1,25 @@ +# @package _global_ +defaults: + - override /adm: phase2_pipeline_fewshot_comparative_regression_delta_relevance + - override /interface: ta3 + +interface: + session_type: adept + training_session: full + username: "pipeline_fewshot_comp_reg_multi_attribute_test" + domain: "p2triage" + scenario_ids: + - June2025-AF-train + - June2025-MF-train + +# LOO - Remove for eval +adm: + 
step_definitions: + regression_icl: + icl_generator_partial: + incontext_settings: + leave_one_out_strategy: 'scenario_description' + +apply_action_filtering: false +force_determinism: true +align_to_target: true \ No newline at end of file diff --git a/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_zeroshot_comparative_regression.yaml b/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_zeroshot_comparative_regression.yaml new file mode 100644 index 00000000..0331247c --- /dev/null +++ b/align_system/configs/experiment/phase2_june_collab/multi_attribute_pipeline_zeroshot_comparative_regression.yaml @@ -0,0 +1,25 @@ +# @package _global_ +defaults: + - override /adm: phase2_pipeline_zeroshot_comparative_regression_delta_relevance + - override /interface: ta3 + +interface: + session_type: adept + training_session: full + username: "pipeline_zeroshot_comp_reg_multi_attribute_test" + domain: "p2triage" + scenario_ids: + - June2025-AF-train + - June2025-MF-train + +# LOO - Remove for eval +adm: + step_definitions: + regression_icl: + icl_generator_partial: + incontext_settings: + leave_one_out_strategy: 'scenario_description' + +apply_action_filtering: false +force_determinism: true +align_to_target: true \ No newline at end of file