12 changes: 8 additions & 4 deletions Snakefile
@@ -524,8 +524,10 @@ rule evaluation_ensemble_pr_curve:
pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes.txt']),
run:
node_table = Evaluation.from_file(input.gold_standard_file).node_table
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file)
input_interactome = Evaluation.from_file(input.dataset_file).get_interactome()
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input_interactome)
input_nodes = Evaluation.from_file(input.dataset_file).get_node_columns(["sources", "targets", "prize", "active"])
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, input_nodes, output.pr_curve_png, output.pr_curve_file)

# Returns list of algorithm specific ensemble files per dataset
def collect_ensemble_per_algo_per_dataset(wildcards):
@@ -543,8 +545,10 @@ rule evaluation_per_algo_ensemble_pr_curve:
pr_curve_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'pr-curve-ensemble-nodes-per-algorithm.txt']),
run:
node_table = Evaluation.from_file(input.gold_standard_file).node_table
node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)
input_interactome = Evaluation.from_file(input.dataset_file).get_interactome()
node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input_interactome)
input_nodes = Evaluation.from_file(input.dataset_file).get_node_columns(["sources", "targets", "prize", "active"])
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, input_nodes, output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)


# Remove the output directory
67 changes: 55 additions & 12 deletions spras/evaluation.py
@@ -288,7 +288,7 @@ def pca_chosen_pathway(coordinates_files: list[Union[str, PathLike]], pathway_su
return rep_pathways

@staticmethod
def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[Union[str, PathLike]], dataset_file: str) -> dict:
def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[Union[str, PathLike]], input_interactome: pd.DataFrame) -> dict:
"""
Generates a dictionary of node ensembles using edge frequency data from a list of ensemble files.
A list of ensemble files can contain an aggregated ensemble or algorithm-specific ensembles per dataset
@@ -308,28 +308,25 @@ def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[

@param node_table: dataFrame of gold standard nodes (column: NODEID)
@param ensemble_files: list of file paths containing edge ensemble outputs
@param dataset_file: path to the dataset file used to load the interactome
@param input_interactome: the input interactome used for a specific dataset
@return: dictionary mapping each ensemble source to its node ensemble DataFrame
"""

node_ensembles_dict = dict()

pickle = Evaluation.from_file(dataset_file)
interactome = pickle.get_interactome()

if interactome.empty:
if input_interactome.empty:
raise ValueError(
f"Cannot compute PR curve or generate node ensemble. Input network for dataset \"{dataset_file.split('-')[0]}\" is empty."
f"Cannot compute PR curve or generate node ensemble. The input network is empty."
)
if node_table.empty:
raise ValueError(
f"Cannot compute PR curve or generate node ensemble. Gold standard associated with dataset \"{dataset_file.split('-')[0]}\" is empty."
f"Cannot compute PR curve or generate node ensemble. The gold standard is empty."
)
Comment on lines 318 to 324
Collaborator:
We should still supply the dataset name in this function to preserve the error information.
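
A minimal sketch of how the dataset label could be threaded back into these checks; the helper name, the dataset_name parameter, and its default are assumptions for illustration, not code from this PR:

```python
import pandas as pd

def check_evaluation_inputs(node_table: pd.DataFrame, input_interactome: pd.DataFrame,
                            dataset_name: str = "unknown") -> None:
    # Same emptiness checks as above, but the dataset label is kept in the error text.
    if input_interactome.empty:
        raise ValueError(
            f"Cannot compute PR curve or generate node ensemble. "
            f"Input network for dataset \"{dataset_name}\" is empty."
        )
    if node_table.empty:
        raise ValueError(
            f"Cannot compute PR curve or generate node ensemble. "
            f"Gold standard associated with dataset \"{dataset_name}\" is empty."
        )
```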


# set the initial default frequencies to 0 for all interactome and gold standard nodes
node1_interactome = interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
node1_interactome = input_interactome[['Interactor1']].rename(columns={'Interactor1': 'Node'})
node1_interactome['Frequency'] = 0.0
node2_interactome = interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
node2_interactome = input_interactome[['Interactor2']].rename(columns={'Interactor2': 'Node'})
node2_interactome['Frequency'] = 0.0
gs_nodes = node_table[[Evaluation.NODE_ID]].rename(columns={Evaluation.NODE_ID: 'Node'})
gs_nodes['Frequency'] = 0.0
Expand All @@ -354,7 +351,7 @@ def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[
return node_ensembles_dict

@staticmethod
def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str | PathLike,
def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, input_nodes: pd.DataFrame, output_png: str | PathLike,
output_file: str | PathLike, aggregate_per_algorithm: bool = False):
"""
Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard.
@@ -365,6 +362,7 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da

@param node_ensembles: dict of the pre-computed node_ensemble(s)
@param node_table: gold standard nodes
@param input_nodes: the input nodes (sources, targets, prizes, actives) used for a specific dataset
Collaborator:
Suggested change
@param input_nodes: the input nodes (sources, targets, prizes, actives) used for a specific dataset
@param input_nodes: the input nodes (usually from `Dataset#get_interesting_input_nodes`) used for a specific dataset

@param output_png: filename to save the precision and recall curves as a .png image
@param output_file: filename to save the precision, recall, threshold values, average precision, and baseline
average precision
@@ -380,13 +378,46 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da

prc_dfs = []
metric_dfs = []
prc_input_nodes_baseline_df = None

baseline = None
precision_input_nodes = None
recall_input_nodes = None
thresholds_input_nodes = None


for label, node_ensemble in node_ensembles.items():
if not node_ensemble.empty:
y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
y_scores = node_ensemble['Frequency'].tolist()

# input nodes (sources, targets, prizes, actives) may be easier to recover but are still valid gold standard nodes;
# the Input_Nodes_Baseline PR curve highlights their overlap with the gold standard.
if prc_input_nodes_baseline_df is None:
input_nodes_set = set(input_nodes['NODEID'])
input_nodes_gold_intersection = input_nodes_set & gold_standard_nodes # TODO should this be all inputs nodes or the intersection with the gold standard for this baseline? I think it should be the intersection
Collaborator:

Very good question. For a synthetic dataset like Panther pathways, the inputs are sampled from the pathway nodes that make up the gold standard so it doesn't matter.

For an omics input like EGFR, it matters substantially. What makes you prefer the intersection? I was inclined to say all input nodes because we cannot have a baseline algorithm that makes use of gold standard information as part of its ranking. I could create a valid pathway reconstruction algorithm that takes the input nodes and simply returns those. I can't use a gold standard in a valid pathway reconstruction algorithm, however.

Collaborator (author):

In my opinion, the only input nodes that matter for evaluation are those that overlap with the gold standard. While it's true that an algorithm could trivially return all input nodes, our precision-recall evaluation is only defined with respect to the gold standard. Input nodes that aren't in the gold standard don't contribute to true positives, so including them in the baseline wouldn't be meaningful.

That said, I also see the case for using all input nodes, as it represents a valid baseline in which an algorithm simply returns the given inputs without doing any reconstruction.

Maybe the difference is that the intersection provides an upper bound, while all input nodes provide a lower bound, on what you could do without doing any reconstruction, and we should provide both?

Collaborator (author) @ntalluri, Aug 28, 2025:

Think through three baselines and explain why each of these can be used/needed:

  1. No intersection: the input nodes by themselves
  2. The intersection of the gold standard and the input nodes
  3. The gold standard by itself (which I think is what the current baseline is)
  • Or do we want to make an ensemble with the gold standard nodes scored 1.0 and everything else 0.0 and compute a PR curve?

Collaborator:

Deciding this is the last point of feedback and last decision to finalize. Then I can do a last careful review and we should be ready to merge.

In our meeting, we discussed how options 1 and 2 will give the same recall. The only difference is that one will have some precision value and the other always has precision of 1.0.
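
To make the same-recall/different-precision point concrete, here is a toy run of scikit-learn's precision_recall_curve; the node names, gold standard, and input sets below are made up for illustration and are not the test fixtures:

```python
from sklearn.metrics import precision_recall_curve

# Hypothetical ensemble universe, gold standard, and input nodes.
nodes = ['A', 'B', 'C', 'D', 'E', 'F']
gold = {'A', 'B', 'C'}
inputs = {'A', 'B', 'E'}  # E is an input node that is not in the gold standard

y_true = [1 if n in gold else 0 for n in nodes]

# Option 1: score every input node 1.0.
scores_all_inputs = [1.0 if n in inputs else 0.0 for n in nodes]
# Option 2: score only the input/gold-standard intersection 1.0.
scores_intersection = [1.0 if n in (inputs & gold) else 0.0 for n in nodes]

p1, r1, _ = precision_recall_curve(y_true, scores_all_inputs)
p2, r2, _ = precision_recall_curve(y_true, scores_intersection)

# At the 1.0 threshold both options recover A and B, so recall is 2/3 either way,
# but option 1's precision is 2/3 (E is a false positive) while option 2's is 1.0.
print(r1.tolist(), r2.tolist())  # identical recall values
print(p1.tolist(), p2.tolist())  # precision differs at the top threshold
```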

input_nodes_ensemble_df = node_ensemble.copy()

input_nodes_ensemble_df["Frequency"] = (
input_nodes_ensemble_df["Node"].isin(input_nodes_gold_intersection).astype(float)
)

y_scores_input_nodes = input_nodes_ensemble_df['Frequency'].tolist()

precision_input_nodes, recall_input_nodes, thresholds_input_nodes = precision_recall_curve(y_true, y_scores_input_nodes)
plt.plot(recall_input_nodes, precision_input_nodes, color='black', marker='o', linestyle='--',
label='Input Nodes Baseline')

# Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds
prc_input_nodes_baseline_data = {
Comment on lines +406 to +407
Collaborator:

(The comment should be moved up)

Suggested change
# Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds
prc_input_nodes_baseline_data = {
# Dropping last elements because scikit-learn adds (1, 0) to precision/recall for plotting, not tied to real thresholds
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html#sklearn.metrics.precision_recall_curve:~:text=Returns%3A-,precision,predictions%20with%20score%20%3E%3D%20thresholds%5Bi%5D%20and%20the%20last%20element%20is%200.,-thresholds
prc_input_nodes_baseline_data = {

'Threshold': thresholds_input_nodes,
'Precision': precision_input_nodes[:-1],
'Recall': recall_input_nodes[:-1],
}

prc_input_nodes_baseline_data = {'Ensemble_Source': ["Input_Nodes_Baseline"] * len(thresholds_input_nodes), **prc_input_nodes_baseline_data}
prc_input_nodes_baseline_df = pd.DataFrame.from_dict(prc_input_nodes_baseline_data)

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
# avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold
avg_precision = average_precision_score(y_true, y_scores)
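
As a quick check of the appended point mentioned in the suggestion above, this small sketch (arbitrary labels and scores) shows why precision and recall are one element longer than thresholds and why the last element is dropped:

```python
from sklearn.metrics import precision_recall_curve

# scikit-learn appends a final (precision=1, recall=0) point that has no matching threshold.
precision, recall, thresholds = precision_recall_curve([1, 0, 1, 1], [0.9, 0.4, 0.35, 0.8])
assert len(precision) == len(recall) == len(thresholds) + 1
assert precision[-1] == 1.0 and recall[-1] == 0.0
```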
@@ -446,7 +477,19 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da
combined_metrics_df['Baseline'] = baseline

# merge dfs and NaN out metric values except for first row of each Ensemble_Source
complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left')
complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left').merge(prc_input_nodes_baseline_df, on=['Ensemble_Source', 'Threshold', 'Precision', 'Recall'], how='outer')

# for each Ensemble_Source, remove Average_Precision and Baseline in all but the first row
not_first_rows = complete_df.duplicated(subset='Ensemble_Source', keep='first')
complete_df.loc[not_first_rows, ['Average_Precision', 'Baseline']] = None

# move Input_Nodes_Baseline to the top of the df
complete_df.sort_values(
by='Ensemble_Source',
# x.ne('Input_Nodes_Baseline'): returns a Series of booleans; True for all rows except Input_Nodes_Baseline.
# Since False < True, baseline rows sort to the top.
key=lambda x: x.ne('Input_Nodes_Baseline'),
inplace=True
)

complete_df.to_csv(output_file, index=False, sep='\t')
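
As a small aside on the sorting step above, here is a standalone check of the sort_values key trick (made-up rows; kind='stable' is used only so the printed order is deterministic):

```python
import pandas as pd

# x.ne('Input_Nodes_Baseline') is False for baseline rows and True otherwise;
# since False < True, the baseline rows sort to the top.
df = pd.DataFrame({'Ensemble_Source': ['Aggregated', 'Input_Nodes_Baseline', 'Ensemble1'],
                   'Precision': [0.15, 1.0, 0.67]})
df = df.sort_values(by='Ensemble_Source', key=lambda x: x.ne('Input_Nodes_Baseline'), kind='stable')
print(df['Ensemble_Source'].tolist())  # ['Input_Nodes_Baseline', 'Aggregated', 'Ensemble1']
```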
@@ -1,2 +1,4 @@
Ensemble_Source Threshold Precision Recall Average_Precision Baseline
Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
Input_Nodes_Baseline 1.0 1.0 0.25
Aggregated 0.0 0.15384615384615385 1.0 0.15384615384615385 0.15384615384615385
2 changes: 2 additions & 0 deletions test/evaluate/expected/expected-pr-curve-ensemble-nodes.txt
@@ -1,4 +1,6 @@
Ensemble_Source Threshold Precision Recall Average_Precision Baseline
Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
Input_Nodes_Baseline 1.0 1.0 0.25
Aggregated 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385
Aggregated 0.01 0.6666666666666666 1.0
Aggregated 0.5 0.75 0.75
@@ -1,4 +1,6 @@
Ensemble_Source Threshold Precision Recall Average_Precision Baseline
Input_Nodes_Baseline 0.0 0.15384615384615385 1.0
Input_Nodes_Baseline 1.0 1.0 0.25
Ensemble1 0.0 0.15384615384615385 1.0 0.6666666666666666 0.15384615384615385
Ensemble1 0.01 0.6666666666666666 1.0
Ensemble1 0.5 0.75 0.75
5 changes: 3 additions & 2 deletions test/evaluate/input/input-nodes.txt
@@ -1,3 +1,4 @@
NODEID prize active dummy sources targets
N
C 5.7 True True
N
C 5.7 True True
A 5 True True
28 changes: 18 additions & 10 deletions test/evaluate/test_evaluate.py
@@ -126,18 +126,19 @@ def test_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
ensemble_network = [INPUT_DIR + 'ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_network)
input_data = OUT_DIR + 'data.pickle'
input_interactome = Evaluation.from_file(input_data).get_interactome()
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_interactome)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)

def test_empty_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
empty_ensemble_network = [INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network,
input_network)
input_data = OUT_DIR + 'data.pickle'
input_interactome = Evaluation.from_file(input_data).get_interactome()
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network, input_interactome)
node_ensemble_dict['empty'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False)

@@ -147,8 +148,9 @@ def test_multiple_node_ensemble(self):
out_path_empty_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_empty_file.unlink(missing_ok=True)
ensemble_networks = [INPUT_DIR + 'ensemble-network.tsv', INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_network)
input_data = OUT_DIR + 'data.pickle'
input_interactome = Evaluation.from_file(input_data).get_interactome()
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_interactome)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)
node_ensemble_dict['empty'].to_csv(out_path_empty_file, sep='\t', index=False)
@@ -159,9 +161,11 @@ def test_precision_recall_curve_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
input_nodes = Evaluation.from_file(input_data).get_node_columns(["sources", "targets", "prize", "active"])
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble': ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes.txt', shallow=False)
@@ -171,9 +175,11 @@ def test_precision_recall_curve_ensemble_nodes_empty(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes-empty.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
input_nodes = Evaluation.from_file(input_data).get_node_columns(["sources", "targets", "prize", "active"])
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes-empty.txt', shallow=False)
@@ -183,10 +189,12 @@ def test_precision_recall_curve_multiple_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-multiple-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
input_nodes = Evaluation.from_file(input_data).get_node_columns(["sources", "targets", "prize", "active"])
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble1': ensemble_file, 'ensemble2': ensemble_file, 'ensemble3': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_nodes, out_path_png,
out_path_file, True)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-multiple-ensemble-nodes.txt', shallow=False)