diff --git a/Snakefile b/Snakefile
index 1daa8dee9..f009821c1 100644
--- a/Snakefile
+++ b/Snakefile
@@ -523,9 +523,12 @@ rule evaluation_ensemble_pr_curve:
         pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.png']),
         pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.txt']),
     run:
+        input_dataset = Evaluation.from_file(input.dataset_file)
+        input_interactome = input_dataset.get_interactome()
+        input_nodes = input_dataset.get_interesting_input_nodes()
         node_table = Evaluation.from_file(input.gold_standard_file).node_table
-        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file)
-        Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file)
+        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input_interactome)
+        Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, input_nodes, output.pr_curve_png, output.pr_curve_file)

 # Returns list of algorithm specific ensemble files per dataset
 def collect_ensemble_per_algo_per_dataset(wildcards):
@@ -542,9 +545,12 @@ rule evaluation_per_algo_ensemble_pr_curve:
         pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.png']),
         pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.txt']),
     run:
+        input_dataset = Evaluation.from_file(input.dataset_file)
+        input_interactome = input_dataset.get_interactome()
+        input_nodes = input_dataset.get_interesting_input_nodes()
         node_table = Evaluation.from_file(input.gold_standard_file).node_table
-        node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file)
-        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)
+        node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input_interactome)
+        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, input_nodes, output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)

 # Remove the output directory
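For context, the refactored run: blocks now load the pickled dataset once and pass its interactome and input nodes explicitly instead of a dataset file path. Outside Snakemake, the same flow looks roughly like the sketch below (file paths are hypothetical placeholders; the Evaluation calls mirror the diff above):

    # Standalone sketch of the refactored rule body; the paths are made-up placeholders.
    from spras.evaluation import Evaluation

    input_dataset = Evaluation.from_file("output/dataset1-merged.pickle")      # pickled Dataset
    input_interactome = input_dataset.get_interactome()                        # edge table of the input network
    input_nodes = input_dataset.get_interesting_input_nodes()                  # nodes with sources/targets/prize/active

    node_table = Evaluation.from_file("output/gs1-merged.pickle").node_table   # gold standard nodes
    node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(
        node_table, ["output/dataset1-ml/ensemble-pathway.txt"], input_interactome)
    Evaluation.precision_recall_curve_node_ensemble(
        node_ensemble_dict, node_table, input_nodes,
        "output/dataset1-gs1-eval/pr-curve-ensemble-nodes.png",
        "output/dataset1-gs1-eval/pr-curve-ensemble-nodes.txt")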
diff --git a/spras/dataset.py b/spras/dataset.py
index 1346750e3..8b501733e 100644
--- a/spras/dataset.py
+++ b/spras/dataset.py
@@ -171,6 +171,15 @@ def contains_node_columns(self, col_names: list[str] | str):
                 return False
         return True

+    def get_interesting_input_nodes(self) -> pd.DataFrame:
+        """
+        Returns: a table listing the input nodes considered as starting points for pathway reconstruction algorithms,
+        restricted to nodes that have at least one of the specified attributes.
+        """
+        interesting_input_node_columns = ["sources", "targets", "prize", "active"]
+        interesting_input_nodes = Dataset.get_node_columns(self, col_names=interesting_input_node_columns)
+        return interesting_input_nodes
+
     def get_other_files(self):
         return self.other_files.copy()

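Conceptually, get_interesting_input_nodes keeps the rows of the node table that carry at least one of the listed attributes; the real method delegates that filtering to Dataset.get_node_columns. A self-contained pandas sketch of the idea, using a made-up node table:

    # Toy illustration of the "interesting input nodes" filter; the table below is made up.
    import pandas as pd

    node_table = pd.DataFrame({
        "NODEID": ["A", "B", "C"],
        "sources": [True, None, None],
        "targets": [None, None, True],
        "prize": [5.0, None, None],
        "active": [None, None, True],
    })

    attributes = ["sources", "targets", "prize", "active"]
    # keep nodes that have at least one of the attributes set
    interesting = node_table[node_table[attributes].notna().any(axis=1)]
    print(interesting["NODEID"].tolist())  # ['A', 'C'] for this toy table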
diff --git a/spras/evaluation.py b/spras/evaluation.py
index 732e58576..313727fb8 100644
--- a/spras/evaluation.py
+++ b/spras/evaluation.py
@@ -288,7 +288,7 @@ def pca_chosen_pathway(coordinates_files: list[Union[str, PathLike]], pathway_su
         return rep_pathways

     @staticmethod
-    def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[Union[str, PathLike]], dataset_file: str) -> dict:
+    def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[Union[str, PathLike]], input_interactome: pd.DataFrame) -> dict:
         """
         Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files.
         A list of ensemble files can contain an aggregated ensemble or algorithm-specific ensembles per dataset
@@ -308,28 +308,25 @@ def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[

         @param node_table: dataFrame of gold standard nodes (column: NODEID)
         @param ensemble_files: list of file paths containing edge ensemble outputs
-        @param dataset_file: path to the dataset file used to load the interactome
+        @param input_interactome: the input interactome used for a specific dataset
         @return: dictionary mapping each ensemble source to its node ensemble DataFrame
         """
         node_ensembles_dict = dict()

-        pickle = Evaluation.from_file(dataset_file)
-        interactome = pickle.get_interactome()
-
-        if interactome.empty:
+        if input_interactome.empty:
             raise ValueError(
-                f"Cannot compute PR curve or generate node ensemble. Input network for dataset \"{dataset_file.split('-')[0]}\" is empty."
+                f"Cannot compute PR curve or generate node ensemble. The input network is empty."
             )

         if node_table.empty:
             raise ValueError(
-                f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset \"{dataset_file.split('-')[0]}\" is empty."
+                f"Cannot compute PR curve or generate node ensemble. The gold standard is empty."
             )

         # set the initial default frequencies to 0 for all interactome and gold standard nodes
-        node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
+        node1_interactome = input_interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
         node1_interactome['Frequency'] = 0.0
-        node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
+        node2_interactome = input_interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
         node2_interactome['Frequency'] = 0.0
         gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'})
         gs_nodes['Frequency'] = 0.0
@@ -354,7 +351,7 @@ def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[
         return node_ensembles_dict

     @staticmethod
-    def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str | PathLike,
+    def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, input_nodes: pd.DataFrame, output_png: str | PathLike,
                                              output_file: str | PathLike, aggregate_per_algorithm: bool = False):
         """
         Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard.
@@ -365,6 +362,7 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da

         @param node_ensembles: dict of the pre-computed node_ensemble(s)
         @param node_table: gold standard nodes
+        @param input_nodes: the input nodes (sources, targets, prizes, actives) used for a specific dataset
         @param output_png: filename to save the precision and recall curves as a .png image
         @param output_file: filename to save the precision, recall, threshold values, average precision, and baseline average precision

@@ -380,13 +378,41 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da

         prc_dfs = []
         metric_dfs = []
-
+        prc_input_nodes_baseline_df = None
         baseline = None

         for label, node_ensemble in node_ensembles.items():
             if not node_ensemble.empty:
                 y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
                 y_scores = node_ensemble['Frequency'].tolist()
+
+                # input nodes (sources, targets, prizes, actives) may be easier to recover but are still valid gold standard nodes;
+                # the Input_Nodes_Baseline PR curve highlights their overlap with the gold standard.
+                if prc_input_nodes_baseline_df is None:
+                    input_nodes_set = set(input_nodes['NODEID'])
+                    input_nodes_gold_intersection = input_nodes_set & gold_standard_nodes  # TODO should this be all input nodes or the intersection with the gold standard for this baseline? I think it should be the intersection
+                    input_nodes_ensemble_df = node_ensemble.copy()
+
+                    # set 'Frequency' to 1.0 if the input node is also in the gold standard intersection, else set to 0.0
+                    input_nodes_ensemble_df["Frequency"] = (
+                        input_nodes_ensemble_df["Node"].isin(input_nodes_gold_intersection).astype(float)
+                    )
+
+                    y_scores_input_nodes = input_nodes_ensemble_df['Frequency'].tolist()
+
+                    precision_input_nodes, recall_input_nodes, thresholds_input_nodes = precision_recall_curve(y_true, y_scores_input_nodes)
+                    plt.plot(recall_input_nodes, precision_input_nodes, color='black', marker='o', linestyle='--', label=f'Input Nodes Baseline')
+
+                    # Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds
+                    prc_input_nodes_baseline_data = {
+                        'Threshold': thresholds_input_nodes,
+                        'Precision': precision_input_nodes[:-1],
+                        'Recall': recall_input_nodes[:-1],
+                    }
+
+                    prc_input_nodes_baseline_data = {'Ensemble_Source': ["Input_Nodes_Baseline"] * len(thresholds_input_nodes), **prc_input_nodes_baseline_data}
+                    prc_input_nodes_baseline_df = pd.DataFrame.from_dict(prc_input_nodes_baseline_data)
+
                 precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
                 # avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold
                 avg_precision = average_precision_score(y_true, y_scores)
@@ -446,7 +472,19 @@
         combined_metrics_df['Baseline'] = baseline

         # merge dfs and NaN out metric values except for first row of each Ensemble_Source
-        complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left')
+        complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left').merge(prc_input_nodes_baseline_df, on=['Ensemble_Source', 'Threshold', 'Precision', 'Recall'], how='outer')
+
+        # for each Ensemble_Source, remove Average_Precision and Baseline in all but the first row
         not_last_rows = complete_df.duplicated(subset='Ensemble_Source', keep='first')
         complete_df.loc[not_last_rows, ['Average_Precision', 'Baseline']] = None
+
+        # move Input_Nodes_Baseline to the top of the df
+        complete_df.sort_values(
+            by='Ensemble_Source',
+            # x.ne('Input_Nodes_Baseline'): returns a Series of booleans; True for all rows except Input_Nodes_Baseline.
+            # Since False < True, baseline rows sort to the top.
+            key=lambda x: x.ne('Input_Nodes_Baseline'),
+            inplace=True
+        )
+
         complete_df.to_csv(output_file, index=False, sep='\t')
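The Input_Nodes_Baseline added above reduces to scoring each ensemble node 1.0 if it is an input node that also appears in the gold standard and 0.0 otherwise, then handing those binary scores to scikit-learn's precision_recall_curve. A condensed, self-contained sketch with toy data (not the SPRAS tables):

    # Toy illustration of the Input_Nodes_Baseline scoring; all data here is made up.
    import pandas as pd
    from sklearn.metrics import precision_recall_curve

    node_ensemble = pd.DataFrame({"Node": ["A", "B", "C", "D"],
                                  "Frequency": [0.9, 0.5, 0.1, 0.0]})
    gold_standard_nodes = {"A", "B"}
    input_nodes = {"A", "D"}

    y_true = [1 if n in gold_standard_nodes else 0 for n in node_ensemble["Node"]]
    # binary baseline scores: 1.0 for input nodes that are also gold standard nodes, else 0.0
    baseline_scores = node_ensemble["Node"].isin(input_nodes & gold_standard_nodes).astype(float)

    precision, recall, thresholds = precision_recall_curve(y_true, baseline_scores)
    # scikit-learn appends a final (precision=1, recall=0) point with no threshold,
    # which is why the diff drops the last precision/recall entries when tabulating
    print(thresholds, precision[:-1], recall[:-1])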
diff --git a/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt b/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt
index c9f6561c6..4517ceb0e 100644
--- a/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt
+++ b/test/evaluate/expected/expected-pr-curve-ensemble-nodes-empty.txt
@@ -1,2 +1,4 @@
 Ensemble_Source Threshold Precision Recall Average_Precision Baseline
+Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
+Input_Nodes_Baseline 1.0 1.0 0.25
 Aggregated 0.0 0.15384615384615385 1.0 0.15384615384615385 0.15384615384615385
diff --git a/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt b/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt
index b0e50594e..9a78dbe77 100644
--- a/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt
+++ b/test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt
@@ -1,4 +1,6 @@
 Ensemble_Source Threshold Precision Recall Average_Precision Baseline
+Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
+Input_Nodes_Baseline 1.0 1.0 0.25
 Aggregated 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385
 Aggregated 0.01 0.6666666666666666 1.0
 Aggregated 0.5 0.75 0.75
diff --git a/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt b/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt
index 630a89ceb..483e972d8 100644
--- a/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt
+++ b/test/evaluate/expected/expected-pr-curve-multiple-ensemble-nodes.txt
@@ -1,4 +1,6 @@
 Ensemble_Source Threshold Precision Recall Average_Precision Baseline
+Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
+Input_Nodes_Baseline 1.0 1.0 0.25
 Ensemble1 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385
 Ensemble1 0.01 0.6666666666666666 1.0
 Ensemble1 0.5 0.75 0.75
diff --git a/test/evaluate/input/input-nodes.txt b/test/evaluate/input/input-nodes.txt
index 1ffd8bdff..ec0e4058f 100644
--- a/test/evaluate/input/input-nodes.txt
+++ b/test/evaluate/input/input-nodes.txt
@@ -1,3 +1,4 @@
 NODEID prize active dummy sources targets
-N
-C 5.7 True True
+N
+C 5.7 True True
+A 5 True True
\ No newline at end of file
diff --git a/test/evaluate/test_evaluate.py b/test/evaluate/test_evaluate.py
index ce50350e5..422660bd7 100644
--- a/test/evaluate/test_evaluate.py
+++ b/test/evaluate/test_evaluate.py
@@ -126,8 +126,9 @@ def test_node_ensemble(self):
         out_path_file = Path(OUT_DIR + 'node-ensemble.csv')
         out_path_file.unlink(missing_ok=True)
         ensemble_network = [INPUT_DIR + 'ensemble-network.tsv']
-        input_network = OUT_DIR + 'data.pickle'
-        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_network)
+        input_data = OUT_DIR + 'data.pickle'
+        input_interactome = Evaluation.from_file(input_data).get_interactome()
+        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_interactome)
         node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)

@@ -135,9 +136,9 @@ def test_empty_node_ensemble(self):
         out_path_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
         out_path_file.unlink(missing_ok=True)
         empty_ensemble_network = [INPUT_DIR + 'empty-ensemble-network.tsv']
-        input_network = OUT_DIR + 'data.pickle'
-        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network,
-                                                                     input_network)
+        input_data = OUT_DIR + 'data.pickle'
+        input_interactome = Evaluation.from_file(input_data).get_interactome()
+        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network, input_interactome)
         node_ensemble_dict['empty'].to_csv(out_path_file, sep='\t', index=False)
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False)

@@ -147,8 +148,9 @@ def test_multiple_node_ensemble(self):
         out_path_empty_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
         out_path_empty_file.unlink(missing_ok=True)
         ensemble_networks = [INPUT_DIR + 'ensemble-network.tsv', INPUT_DIR + 'empty-ensemble-network.tsv']
-        input_network = OUT_DIR + 'data.pickle'
-        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_network)
+        input_data = OUT_DIR + 'data.pickle'
+        input_interactome = Evaluation.from_file(input_data).get_interactome()
+        node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_interactome)
         node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)
         node_ensemble_dict['empty'].to_csv(out_path_empty_file, sep='\t', index=False)
@@ -159,9 +161,11 @@ def test_precision_recall_curve_ensemble_nodes(self):
         out_path_png.unlink(missing_ok=True)
         out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes.txt')
         out_path_file.unlink(missing_ok=True)
+        input_data = OUT_DIR + 'data.pickle'
+        input_nodes = Evaluation.from_file(input_data).get_interesting_input_nodes()
         ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
         node_ensembles_dict = {'ensemble': ensemble_file}
-        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
+        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
                                                         out_path_file)
         assert out_path_png.exists()
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes.txt', shallow=False)
@@ -171,9 +175,11 @@ def test_precision_recall_curve_ensemble_nodes_empty(self):
         out_path_png.unlink(missing_ok=True)
         out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes-empty.txt')
         out_path_file.unlink(missing_ok=True)
+        input_data = OUT_DIR + 'data.pickle'
+        input_nodes = Evaluation.from_file(input_data).get_interesting_input_nodes()
         empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
         node_ensembles_dict = {'ensemble': empty_ensemble_file}
-        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
+        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
                                                         out_path_file)
         assert out_path_png.exists()
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes-empty.txt', shallow=False)
@@ -183,10 +189,12 @@ def test_precision_recall_curve_multiple_ensemble_nodes(self):
         out_path_png.unlink(missing_ok=True)
         out_path_file = Path(OUT_DIR + 'pr-curve-multiple-ensemble-nodes.txt')
         out_path_file.unlink(missing_ok=True)
+        input_data = OUT_DIR + 'data.pickle'
+        input_nodes = Evaluation.from_file(input_data).get_interesting_input_nodes()
         ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
         empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
         node_ensembles_dict = {'ensemble1': ensemble_file, 'ensemble2': ensemble_file, 'ensemble3': empty_ensemble_file}
-        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
+        Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
                                                         out_path_file, True)
         assert out_path_png.exists()
         assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-multiple-ensemble-nodes.txt', shallow=False)
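As the expected files above show, the output table now lists the Input_Nodes_Baseline rows first, with empty Average_Precision and Baseline columns. One way to spot-check an output table of this shape (the path is a hypothetical placeholder):

    # Inspect a PR-curve output table; the path below is a made-up placeholder.
    import pandas as pd

    df = pd.read_csv("output/dataset1-gs1-eval/pr-curve-ensemble-nodes.txt", sep="\t")
    print(df["Ensemble_Source"].unique())  # Input_Nodes_Baseline appears first
    print(df.head())                       # baseline rows have NaN metric columns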