Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions clusterclue/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,13 @@ def get_commands():
)
# GWMs
build.add_argument(
"--n_comp",
dest="n_comp_list",
"--k_values",
nargs="+",
type=int,
default=[100,],
default=[1500,],
metavar="<int>",
help="Specify one or more integers to define the number of components "
"for dimensionality reduction when generating the motifs (default: 100).",
help="Specify one or more integers to define the number of motifs (GWMS). "
"Each integer represents a different clustering configuration (default: 1500).",
)
build.add_argument(
"--ref_sc",
Expand Down Expand Up @@ -378,50 +377,48 @@ def get_commands():


def main():
"""
Main function to execute the iPRESTO pipeline.

This function retrieves command line arguments, prints them if verbose mode is enabled,
and then runs the main pipeline with the provided arguments.

Parameters:
None

Returns:
None
"""
start_time = time.time()

cmd = get_commands()

Path(cmd.out_dir_path).mkdir(parents=True, exist_ok=True)
log_file_path = Path(cmd.out_dir_path) / "clusterclue.log"

# Set up multiprocessing-friendly logging
# multiprocessing-friendly logging
queue = Queue(-1)
listener = Process(target=listener_process, args=(queue, log_file_path, cmd.verbose))
listener.start()
worker_configurer(queue)
logger = logging.getLogger("clusterclue.cli")
logger.info("Command: %s", " ".join(sys.argv))

if cmd.mode == "build":
create_new_motifs(cmd, queue)
exit_code = 0
try:
if cmd.mode == "build":
create_new_motifs(cmd, queue)
elif cmd.mode == "detect":
detect_existing_motifs(cmd, queue)

end_time = time.time()
elapsed_time = end_time - start_time
hours, remainder = divmod(elapsed_time, 3600)
minutes, seconds = divmod(remainder, 60)
logger.info("Total runtime: %d hours and %d minutes", int(hours), int(minutes))

elif cmd.mode == "detect":
detect_existing_motifs(cmd, queue)
except Exception as e:
# Log the full traceback so it appears in the log file
logger.exception(f"Pipeline failed with error: {e}")
exit_code = 1

end_time = time.time()
elapsed_time = end_time - start_time
hours, remainder = divmod(elapsed_time, 3600)
minutes, seconds = divmod(remainder, 60)
logger.info("Total runtime: %d hours and %d minutes", int(hours), int(minutes))
finally:
# Always shut down the listener cleanly regardless of success or failure
queue.put(None)
listener.join()
queue.close()
queue.join_thread()

queue.put(None)
listener.join()
queue.close()
queue.join_thread()
sys.exit(exit_code) # non-zero exit code triggers set -e in bash


if __name__ == "__main__":
main()
main()
83 changes: 38 additions & 45 deletions clusterclue/evaluate/evaluate_hits.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,38 @@ def calculate_jaccard(annotated_genes: set, hit_genes: set) -> float:
return len(overlapping_genes) / len(union_genes)


def calculate_penalized_f1(
f1: float, n_overlapping_hits: int, alpha: float
) -> float:
"""Calculate the penalized F1 score based on the number of hits with Jaccard index above a threshold.
def penalize_score(
score: float,
n_overlap: int,
alpha: float = 0.3,
beta: float = 2.0
) -> float:
"""
Penalizes a score based on the number of overlapping hits.

The term (n_overlapping_hits−1) ensures no penalty when only one predicted subcluster overlaps
(ideal case), with penalties scaling for n_overlapping_hits>1.
    Uses tunable parameters alpha and beta for penalty strength and growth rate.


Args:
f1 (float): The F1 score to be penalized.
        n_overlapping_hits (int): Number of hits with Jaccard index above the threshold.
score (float): The original score to be penalized.
n_overlap (int): The number of overlapping hits (n).
alpha (float): Controls the penalty strength, higher values (e.g. 0.5)
penalize more than lower values (e.g. 0.1).

penalize more than lower values (e.g. 0.1). Default is 0.3.
beta (float): Controls the growth rate of the penalty with n_overlap.
beta=1 is linear (original), beta=2 is quadratic, higher = more aggressive.

Returns:
float: The penalized F1 score.
float: The penalized score.
"""
penalty = 1 / (1 + alpha * (n_overlapping_hits - 1))
return f1 * penalty
penalty = 1 / (1 + alpha * (n_overlap - 1)**beta)
return score * penalty


def assign_best_hit(row: pd.Series, hits: Dict[str, List[MotifHit | PrestoHit]], alpha: float = 0.25) -> dict:
def assign_best_hit(
row: pd.Series,
hits: Dict[str, List[MotifHit | PrestoHit]],
alpha: float = 0.25,
beta: float = 2.0
) -> dict:
"""Find the best hit for a given row of annotated subclusters.
alpha is set to 0.25 based on manual evaluation of motif sets,
but can be adjusted based on the desired penalty strength.
Expand Down Expand Up @@ -153,10 +162,13 @@ def assign_best_hit(row: pd.Series, hits: Dict[str, List[MotifHit | PrestoHit]],

if overlapping_hits:
best_hit = max(overlapping_hits, key=lambda x: x["overlap_score"])
n_overlapping_hits = len(overlapping_hits)
penalized_f1 = calculate_penalized_f1(best_hit["overlap_score"], n_overlapping_hits, alpha)
best_hit["n_overlapping_hits"] = n_overlapping_hits
best_hit["redundancy_penalised_overlap_score"] = penalized_f1
best_hit["n_overlapping_hits"] = len(overlapping_hits)
best_hit["penalized_score"] = penalize_score(
best_hit["overlap_score"],
len(overlapping_hits),
alpha=alpha,
beta=beta
)
else:
best_hit = {
"subcluster_id": row["subcluster_id"],
Expand All @@ -169,41 +181,22 @@ def assign_best_hit(row: pd.Series, hits: Dict[str, List[MotifHit | PrestoHit]],
"precision": 0.0,
"overlap_score": 0.0,
"n_overlapping_hits": 0,
"redundancy_penalised_overlap_score": 0.0,
"penalized_score": 0.0,

}

return best_hit


def get_best_hits(
    ref_subclusters: pd.DataFrame,
    hits: Dict[str, List[MotifHit | PrestoHit]],
    alpha: float = 0.25,
) -> pd.DataFrame:
    """Assign the best-matching hit to every annotated reference subcluster.

    Runs ``assign_best_hit`` on each row of *ref_subclusters* and gathers
    the resulting per-row dicts into a fresh DataFrame.

    Args:
        ref_subclusters: One annotated subcluster per row.
        hits: Mapping of BGC id to its motif/PRESTO hits.
        alpha: Redundancy-penalty strength forwarded to ``assign_best_hit``.

    Returns:
        pd.DataFrame: One best-hit record per input row.
    """
    records = ref_subclusters.apply(
        lambda row: assign_best_hit(row, hits, alpha=alpha), axis=1
    ).tolist()
    return pd.DataFrame(records)


def calculate_evaluation(ref_subclusters_with_hits: pd.DataFrame) -> tuple:
    """Summarize hit quality as mean overlap scores.

    Args:
        ref_subclusters_with_hits: Must contain the columns
            ``overlap_score`` and ``redundancy_penalised_overlap_score``.

    Returns:
        tuple: (mean overlap score, mean redundancy-penalised overlap score),
        each rounded to 3 decimal places.
    """
    mean_overlap = ref_subclusters_with_hits["overlap_score"].mean()
    mean_penalised = ref_subclusters_with_hits[
        "redundancy_penalised_overlap_score"
    ].mean()
    return round(mean_overlap, 3), round(mean_penalised, 3)


def write_motif_evaluation(ref_subclusters: pd.DataFrame, best_hits: pd.DataFrame, output_filepath: Path) -> None:
"""Writes the best motif set hits to a tsv file."""
eval_df = ref_subclusters.merge(best_hits, on="subcluster_id")
eval_df["tokenized_genes"] = eval_df["tokenized_genes"].apply(lambda x: ";".join(x))
eval_df["motif_hit_genes"] = eval_df["motif_hit_genes"].apply(lambda x: ";".join(sorted(x)))
eval_df["overlapping_genes"] = eval_df["overlapping_genes"].apply(lambda x: ";".join(sorted(x)))
eval_df["jaccard"] = eval_df["jaccard"].round(3)
eval_df["recall"] = eval_df["recall"].round(3)
eval_df["precision"] = eval_df["precision"].round(3)
eval_df["overlap_score"] = eval_df["overlap_score"].round(3)
eval_df["redundancy_penalised_overlap_score"] = eval_df["redundancy_penalised_overlap_score"].round(3)
eval_df["jaccard"] = eval_df["jaccard"].round(4)
eval_df["recall"] = eval_df["recall"].round(4)
eval_df["precision"] = eval_df["precision"].round(4)
eval_df["overlap_score"] = eval_df["overlap_score"].round(4)
eval_df["penalized_score"] = eval_df["penalized_score"].round(4)
eval_df.to_csv(output_filepath, sep="\t", index=False)
Loading