Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ rule evaluation_ensemble_pr_curve:
run:
node_table = Evaluation.from_file(input.gold_standard_file).node_table
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_file, input.dataset_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, output.pr_curve_png, output.pr_curve_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensemble_dict, node_table, input.dataset_file,output.pr_curve_png, output.pr_curve_file)

# Returns list of algorithm specific ensemble files per dataset
def collect_ensemble_per_algo_per_dataset(wildcards):
Expand All @@ -544,7 +544,7 @@ rule evaluation_per_algo_ensemble_pr_curve:
run:
node_table = Evaluation.from_file(input.gold_standard_file).node_table
node_ensembles_dict = Evaluation.edge_frequency_node_ensemble(node_table, input.ensemble_files, input.dataset_file)
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, node_table, input.dataset_file,output.pr_curve_png, output.pr_curve_file, include_aggregate_algo_eval)


# Remove the output directory
Expand Down
49 changes: 47 additions & 2 deletions spras/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def edge_frequency_node_ensemble(node_table: pd.DataFrame, ensemble_files: list[
return node_ensembles_dict

@staticmethod
def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, output_png: str | PathLike,
def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.DataFrame, dataset_file: str, output_png: str | PathLike,
output_file: str | PathLike, aggregate_per_algorithm: bool = False):
"""
Plots precision-recall (PR) curves for a set of node ensembles evaluated against a gold standard.
Expand All @@ -380,13 +380,58 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da

prc_dfs = []
metric_dfs = []
prc_input_nodes_baseline_df = None

baseline = None
precision_input_nodes = None
recall_input_nodes = None
thresholds_input_nodes = None


for label, node_ensemble in node_ensembles.items():
if not node_ensemble.empty:
y_true = [1 if node in gold_standard_nodes else 0 for node in node_ensemble['Node']]
y_scores = node_ensemble['Frequency'].tolist()

if precision_input_nodes is None and recall_input_nodes is None and thresholds_input_nodes is None:
pickle = Evaluation.from_file(dataset_file)
input_nodes_df = pickle.get_node_columns(["sources", "targets", "prize", "active"])
input_nodes = set(input_nodes_df['NODEID'])
input_nodes_gold_intersection = input_nodes & gold_standard_nodes # TODO should this be all inputs nodes or the intersection with the gold standard for this baseline?
input_nodes_ensemble_df = node_ensemble.copy()

input_nodes_ensemble_df.loc[
input_nodes_ensemble_df['Node'].isin(input_nodes_gold_intersection),
'Frequency'
] = 1.0

input_nodes_ensemble_df.loc[
~input_nodes_ensemble_df['Node'].isin(input_nodes_gold_intersection),
'Frequency'
] = 0.0

print(input_nodes_ensemble_df)

y_scores_input_nodes = input_nodes_ensemble_df['Frequency'].tolist()

precision_input_nodes, recall_input_nodes, thresholds_input_nodes = precision_recall_curve(y_true, y_scores_input_nodes)
plt.plot(recall_input_nodes, precision_input_nodes, color='black', marker='o', linestyle='--',
label=f'Input Nodes Baseline')

print(precision_input_nodes)
print(recall_input_nodes)
prc_input_nodes_baseline_data = {
'Threshold': thresholds_input_nodes,
'Precision': precision_input_nodes[:-1],
'Recall': recall_input_nodes[:-1],
}

prc_input_nodes_baseline_data = {'Ensemble_Source': ["input_nodes_baseline"] * len(thresholds_input_nodes), **prc_input_nodes_baseline_data}

prc_input_nodes_baseline_df = pd.DataFrame.from_dict(prc_input_nodes_baseline_data)



precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
# avg precision summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold
avg_precision = average_precision_score(y_true, y_scores)
Expand Down Expand Up @@ -446,7 +491,7 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da
combined_metrics_df['Baseline'] = baseline

# merge dfs and NaN out metric values except for first row of each Ensemble_Source
complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left')
complete_df = combined_prc_df.merge(combined_metrics_df, on='Ensemble_Source', how='left').merge(prc_input_nodes_baseline_df, on=['Ensemble_Source', 'Threshold', 'Precision', 'Recall'], how='outer')
not_last_rows = complete_df.duplicated(subset='Ensemble_Source', keep='first')
complete_df.loc[not_last_rows, ['Average_Precision', 'Baseline']] = None
complete_df.to_csv(output_file, index=False, sep='\t')
4 changes: 2 additions & 2 deletions test/evaluate/input/input-nodes.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
NODEID prize active dummy sources targets
N
C 5.7 True True
N
C 5.7 True True
32 changes: 23 additions & 9 deletions test/evaluate/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def setup_class(cls):
'other_files': []
})

# TODO figure out why the input-nodes file is not being included in the data.pickle file
# it keeps coming up empty
with open(out_dataset, 'wb') as f:
pickle.dump(dataset, f)

Expand Down Expand Up @@ -126,18 +128,18 @@ def test_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
ensemble_network = [INPUT_DIR + 'ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_network)
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_data)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)

def test_empty_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
empty_ensemble_network = [INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network,
input_network)
input_data)
node_ensemble_dict['empty'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False)

Expand All @@ -147,8 +149,8 @@ def test_multiple_node_ensemble(self):
out_path_empty_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_empty_file.unlink(missing_ok=True)
ensemble_networks = [INPUT_DIR + 'ensemble-network.tsv', INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_network)
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_data)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)
node_ensemble_dict['empty'].to_csv(out_path_empty_file, sep='\t', index=False)
Expand All @@ -159,9 +161,19 @@ def test_precision_recall_curve_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'

pickle = Evaluation.from_file(input_data)
input_nodes_df = pickle.get_node_columns(["prize", "active"])

print(input_nodes_df)

ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)

print(ensemble_file)

node_ensembles_dict = {'ensemble': ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes.txt', shallow=False)
Expand All @@ -171,9 +183,10 @@ def test_precision_recall_curve_ensemble_nodes_empty(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes-empty.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes-empty.txt', shallow=False)
Expand All @@ -183,10 +196,11 @@ def test_precision_recall_curve_multiple_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-multiple-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble1': ensemble_file, 'ensemble2': ensemble_file, 'ensemble3': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file, True)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-multiple-ensemble-nodes.txt', shallow=False)
Loading