Skip to content

Commit

Permalink
fix(l2g): direct model path to hf repo when none and align training d…
Browse files Browse the repository at this point in the history
…ata filename (#997)

* fix(model): typo in filename when uploading training data to hub

* fix(l2g): direct model to hf repo when path is none

* fix: run_train to keep GS and geneId

* chore: pre-commit auto fixes [...]

---------

Co-authored-by: xyg123 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 20, 2025
1 parent 2f86159 commit 198019f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/gentropy/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def __init__(

self.session = session
self.run_mode = run_mode
self.model_path = model_path
self.predictions_path = predictions_path
self.features_list = list(features_list) if features_list else None
self.hyperparameters = dict(hyperparameters)
Expand All @@ -164,6 +163,11 @@ def __init__(
self.gold_standard_curation_path = gold_standard_curation_path
self.gene_interactions_path = gene_interactions_path
self.variant_index_path = variant_index_path
self.model_path = (
hf_hub_repo_id
if not model_path and download_from_hub and hf_hub_repo_id
else model_path
)

# Load common inputs
self.credible_set = StudyLocus.from_parquet(
Expand Down Expand Up @@ -333,9 +337,7 @@ def run_train(self) -> None:
# we upload the model saved in the filesystem
self.model_path.split("/")[-1],
hf_hub_token,
data=trained_model.training_data._df.drop(
"goldStandardSet", "geneId"
).toPandas(),
data=trained_model.training_data._df.toPandas(),
repo_id=self.hf_hub_repo_id,
commit_message=self.hf_model_commit_message,
)
Expand Down
2 changes: 1 addition & 1 deletion src/gentropy/method/l2g/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def export_to_hugging_face_hub(
data=data,
)
self._create_hugging_face_model_card(local_repo)
data.to_parquet(f"{local_repo}/training_set.parquet")
data.to_parquet(f"{local_repo}/training_data.parquet")
hub_utils.push(
repo_id=repo_id,
source=local_repo,
Expand Down

0 comments on commit 198019f

Please sign in to comment.