From 198019fd8aa0d68be74e5a0d3fce415f8480ab25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez=20Santiago?= <45119610+ireneisdoomed@users.noreply.github.com> Date: Thu, 20 Feb 2025 15:40:09 +0000 Subject: [PATCH] fix(l2g): direct model path to hf repo when none and align training data filename (#997) * fix(model): typo in filename when uploading training data to hub * fix(l2g): direct model to hf repo when path is none * fix: run_train to keep GS and geneId * chore: pre-commit auto fixes [...] --------- Co-authored-by: xyg123 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/gentropy/l2g.py | 10 ++++++---- src/gentropy/method/l2g/model.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 39b7fcca1..4c3d6c867 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -151,7 +151,6 @@ def __init__( self.session = session self.run_mode = run_mode - self.model_path = model_path self.predictions_path = predictions_path self.features_list = list(features_list) if features_list else None self.hyperparameters = dict(hyperparameters) @@ -164,6 +163,11 @@ def __init__( self.gold_standard_curation_path = gold_standard_curation_path self.gene_interactions_path = gene_interactions_path self.variant_index_path = variant_index_path + self.model_path = ( + hf_hub_repo_id + if not model_path and download_from_hub and hf_hub_repo_id + else model_path + ) # Load common inputs self.credible_set = StudyLocus.from_parquet( @@ -333,9 +337,7 @@ def run_train(self) -> None: # we upload the model saved in the filesystem self.model_path.split("/")[-1], hf_hub_token, - data=trained_model.training_data._df.drop( - "goldStandardSet", "geneId" - ).toPandas(), + data=trained_model.training_data._df.toPandas(), repo_id=self.hf_hub_repo_id, commit_message=self.hf_model_commit_message, ) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 8d31597ee..c695f8a68 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -325,7 +325,7 @@ def export_to_hugging_face_hub( data=data, ) self._create_hugging_face_model_card(local_repo) - data.to_parquet(f"{local_repo}/training_set.parquet") + data.to_parquet(f"{local_repo}/training_data.parquet") hub_utils.push( repo_id=repo_id, source=local_repo,