-
Notifications
You must be signed in to change notification settings - Fork 5
Delta based relevance for comp reg alignment to multi-attribute targets #192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
26da5b0
73f5854
3700e31
7903e17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||||||||||||||||
| from rich.highlighter import JSONHighlighter | ||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from align_system.algorithms.abstracts import ADMComponent | ||||||||||||||||||||||||||||
| from align_system.utils import adm_utils, logging | ||||||||||||||||||||||||||||
|
|
@@ -162,3 +163,57 @@ def run_returns(self): | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def run(self): | ||||||||||||||||||||||||||||
| return "Looked at scores." | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class Phase2RegressionRemoveIrrelevantAttributes(ADMComponent): | ||||||||||||||||||||||||||||
| def run_returns(self): | ||||||||||||||||||||||||||||
| return ('attribute_prediction_scores', | ||||||||||||||||||||||||||||
| 'alignment_target') | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def run(self, | ||||||||||||||||||||||||||||
| attribute_prediction_scores, | ||||||||||||||||||||||||||||
| alignment_target): | ||||||||||||||||||||||||||||
| # If there are two non-medical attributes, removes the one with smaller delta | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| attributes = list({key for inner in attribute_prediction_scores.values() for key in inner}) | ||||||||||||||||||||||||||||
| # Ignore / don't filter out medical | ||||||||||||||||||||||||||||
| keep_attributes = [] | ||||||||||||||||||||||||||||
| if 'medical' in attributes: | ||||||||||||||||||||||||||||
| attributes.remove('medical') | ||||||||||||||||||||||||||||
| keep_attributes.append('medical') | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Only one attribute aside from medical -> filtering not needed | ||||||||||||||||||||||||||||
| if len(attributes) == 1: | ||||||||||||||||||||||||||||
| return attribute_prediction_scores, alignment_target | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Two or more attributes -> keep the one with largest delta | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| if len(attribute_prediction_scores.keys()) > 2: | ||||||||||||||||||||||||||||
| raise RuntimeError("Relevance filtering not implemented for more than two choices.") | ||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: David recommended I raise |
||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| # Determine most relevant attribute to keep | ||||||||||||||||||||||||||||
| choiceA, choiceB = list(attribute_prediction_scores.keys()) | ||||||||||||||||||||||||||||
| max_delta = -np.inf | ||||||||||||||||||||||||||||
| for attr in attributes: | ||||||||||||||||||||||||||||
| delta = abs(np.array(attribute_prediction_scores[choiceA][attr]).mean() - np.array(attribute_prediction_scores[choiceB][attr]).mean()) | ||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I know at one point we potentially had to handle both a single prediction and a list of predictions. Do we still need to do that? (If not yayyyyy simpler code :)) |
||||||||||||||||||||||||||||
| if delta > max_delta: | ||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will only keep the first one if there's a tie. I think that's fine because I don't know what we'd actually do if there was a tie (plus that seems unlikely), but just double-checking you didn't have a different set of assumptions. |
||||||||||||||||||||||||||||
| max_delta = delta | ||||||||||||||||||||||||||||
| relevant_attribute = attr | ||||||||||||||||||||||||||||
| keep_attributes.append(relevant_attribute) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Update predicted scores to only have more relevant attribute | ||||||||||||||||||||||||||||
| filtered_attribute_prediction_scores = {choiceA:{}, choiceB:{}} | ||||||||||||||||||||||||||||
| for keep_attr in keep_attributes: | ||||||||||||||||||||||||||||
| filtered_attribute_prediction_scores[choiceA][keep_attr] = attribute_prediction_scores[choiceA][keep_attr] | ||||||||||||||||||||||||||||
| filtered_attribute_prediction_scores[choiceB][keep_attr] = attribute_prediction_scores[choiceB][keep_attr] | ||||||||||||||||||||||||||||
|
Comment on lines
+204
to
+208
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit)
Suggested change
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| log.info("[bold]*FILTERING OUT ATTRIBUTES EXCEPT MOST RELEVANT: {}*[/bold]".format(relevant_attribute), extra={"markup": True}) | ||||||||||||||||||||||||||||
| log.info("Retained:{}".format(filtered_attribute_prediction_scores), extra={"highlighter": JSON_HIGHLIGHTER}) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Update target to only include relevant attribute | ||||||||||||||||||||||||||||
| filtered_alignment_target = alignment_target.copy() | ||||||||||||||||||||||||||||
| filtered_alignment_target['kdma_values'] = [entry for entry in alignment_target['kdma_values'] if entry['kdma'] in keep_attributes] | ||||||||||||||||||||||||||||
| log.info("[bold]*UPDATING ALIGNMENT TARGET*[/bold]", extra={"markup": True}) | ||||||||||||||||||||||||||||
| log.info("{}".format(filtered_alignment_target), extra={"highlighter": JSON_HIGHLIGHTER}) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return filtered_attribute_prediction_scores, filtered_alignment_target | ||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| name: phase2_pipeline_zeroshot_comparative_regression_delta_relevance | ||
|
|
||
| defaults: | ||
| # Import defaults into this namespace (adm) as @name, for further | ||
| # customization | ||
|
|
||
| # Shared variables / components | ||
| - /attribute@mu: medical_urgency | ||
| - /attribute@af: affiliation_focus | ||
| - /attribute@mf: merit_focus | ||
| - /attribute@ss: search_or_stay | ||
| - /attribute@ps: personal_safety | ||
| - /inference_engine@structured_inference_engine: outlines_structured_greedy | ||
| - /template/scenario_description@scenario_description_template: phase2 | ||
| - /template/prompt@prompt_template: phase2_comparative_regression | ||
| - /template/output_schema@comparative_regression_choice_schema: phase2_comparative_regression_choice | ||
| # ADM components to be used in "steps" | ||
| - /adm_component/misc@step_definitions.format_choices: itm_format_choices | ||
| - /adm_component/icl@step_definitions.regression_icl: phase2_comparative | ||
| - /adm_component/regression@step_definitions.comparative_regression: phase2_comparative_no_template | ||
| - /adm_component/misc@step_definitions.regression_rule_based_correction: phase2_regression_rule_based_correction | ||
| - /adm_component/misc@step_definitions.remove_irrelevant_attributes: phase2_regression_remove_irrelevant_attributes | ||
| - /adm_component/alignment@step_definitions.scalar_alignment: medical_urgency_scalar | ||
| - /adm_component/misc@step_definitions.justification_from_reasonings: justification_from_reasonings | ||
| - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action | ||
| - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info | ||
| # Use definitions in this file to override defaults defined above | ||
| - _self_ | ||
|
|
||
| attribute_definitions: | ||
| medical: ${adm.mu} | ||
| affiliation: ${adm.af} | ||
| merit: ${adm.mf} | ||
| search: ${adm.ss} | ||
| personal_safety: ${adm.ps} | ||
|
|
||
| step_definitions: | ||
| regression_icl: | ||
| scenario_description_template: ${ref:adm.scenario_description_template} | ||
| attributes: ${adm.attribute_definitions} | ||
| prompt_template: ${ref:adm.prompt_template} | ||
|
|
||
| comparative_regression: | ||
| scenario_description_template: ${ref:adm.scenario_description_template} | ||
| prompt_template: ${ref:adm.prompt_template} | ||
| score_schema_template: ${adm.comparative_regression_choice_schema} | ||
|
|
||
| instance: | ||
| _target_: align_system.algorithms.pipeline_adm.PipelineADM | ||
|
|
||
| steps: | ||
| # Reference the step instances we want to use in order | ||
| - ${ref:adm.step_definitions.format_choices} | ||
| - ${ref:adm.step_definitions.regression_icl} | ||
| - ${ref:adm.step_definitions.comparative_regression} | ||
| - ${ref:adm.step_definitions.regression_rule_based_correction} | ||
| - ${ref:adm.step_definitions.remove_irrelevant_attributes} | ||
| - ${ref:adm.step_definitions.scalar_alignment} | ||
| - ${ref:adm.step_definitions.justification_from_reasonings} | ||
| - ${ref:adm.step_definitions.ensure_chosen_action} | ||
| - ${ref:adm.step_definitions.populate_choice_info} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| name: phase2_pipeline_zeroshot_comparative_regression_delta_relevance | ||
|
|
||
| defaults: | ||
| # Import defaults into this namespace (adm) as @name, for further | ||
| # customization | ||
|
|
||
| # Shared variables / components | ||
| - /attribute@mu: medical_urgency | ||
| - /attribute@af: affiliation_focus | ||
| - /attribute@mf: merit_focus | ||
| - /attribute@ss: search_or_stay | ||
| - /attribute@ps: personal_safety | ||
| - /inference_engine@structured_inference_engine: outlines_structured_greedy | ||
| - /template/scenario_description@scenario_description_template: phase2 | ||
| - /template/prompt@prompt_template: phase2_comparative_regression | ||
| - /template/output_schema@comparative_regression_choice_schema: phase2_comparative_regression_choice | ||
| # ADM components to be used in "steps" | ||
| - /adm_component/misc@step_definitions.format_choices: itm_format_choices | ||
| - /adm_component/regression@step_definitions.comparative_regression: phase2_comparative | ||
| - /adm_component/misc@step_definitions.regression_rule_based_correction: phase2_regression_rule_based_correction | ||
| - /adm_component/misc@step_definitions.remove_irrelevant_attributes: phase2_regression_remove_irrelevant_attributes | ||
| - /adm_component/alignment@step_definitions.scalar_alignment: medical_urgency_scalar | ||
| - /adm_component/misc@step_definitions.justification_from_reasonings: justification_from_reasonings | ||
| - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action | ||
| - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info | ||
| # Use definitions in this file to override defaults defined above | ||
| - _self_ | ||
|
|
||
| attribute_definitions: | ||
| medical: ${adm.mu} | ||
| affiliation: ${adm.af} | ||
| merit: ${adm.mf} | ||
| search: ${adm.ss} | ||
| personal_safety: ${adm.ps} | ||
|
|
||
| step_definitions: | ||
| comparative_regression: | ||
| scenario_description_template: ${ref:adm.scenario_description_template} | ||
| prompt_template: ${ref:adm.prompt_template} | ||
| score_schema_template: ${adm.comparative_regression_choice_schema} | ||
|
|
||
| instance: | ||
| _target_: align_system.algorithms.pipeline_adm.PipelineADM | ||
|
|
||
| steps: | ||
| # Reference the step instances we want to use in order | ||
| - ${ref:adm.step_definitions.format_choices} | ||
| - ${ref:adm.step_definitions.comparative_regression} | ||
| - ${ref:adm.step_definitions.regression_rule_based_correction} | ||
| - ${ref:adm.step_definitions.remove_irrelevant_attributes} | ||
| - ${ref:adm.step_definitions.scalar_alignment} | ||
| - ${ref:adm.step_definitions.justification_from_reasonings} | ||
| - ${ref:adm.step_definitions.ensure_chosen_action} | ||
| - ${ref:adm.step_definitions.populate_choice_info} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| _target_: align_system.algorithms.misc_itm_adm_components.Phase2RegressionRemoveIrrelevantAttributes |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # @package _global_ | ||
| defaults: | ||
| - override /adm: phase2_pipeline_fewshot_comparative_regression_delta_relevance | ||
| - override /interface: ta3 | ||
|
|
||
| interface: | ||
| session_type: adept | ||
| training_session: full | ||
| username: "pipeline_fewshot_comp_reg_multi_attribute_test" | ||
| domain: "p2triage" | ||
| scenario_ids: | ||
| - June2025-AF-train | ||
| - June2025-MF-train | ||
|
|
||
| # LOO - Remove for eval | ||
| adm: | ||
| step_definitions: | ||
| regression_icl: | ||
| icl_generator_partial: | ||
| incontext_settings: | ||
| leave_one_out_strategy: 'scenario_description' | ||
|
|
||
| apply_action_filtering: false | ||
| force_determinism: true | ||
| align_to_target: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # @package _global_ | ||
| defaults: | ||
| - override /adm: phase2_pipeline_zeroshot_comparative_regression_delta_relevance | ||
| - override /interface: ta3 | ||
|
|
||
| interface: | ||
| session_type: adept | ||
| training_session: full | ||
| username: "pipeline_zeroshot_comp_reg_multi_attribute_test" | ||
| domain: "p2triage" | ||
| scenario_ids: | ||
| - June2025-AF-train | ||
| - June2025-MF-train | ||
|
|
||
| # LOO - Remove for eval | ||
| adm: | ||
| step_definitions: | ||
| regression_icl: | ||
| icl_generator_partial: | ||
| incontext_settings: | ||
| leave_one_out_strategy: 'scenario_description' | ||
|
|
||
| apply_action_filtering: false | ||
| force_determinism: true | ||
| align_to_target: true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume we're looping over all choice predictions in case different choices have different predictions? Do we think that'll actually happen?