diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json
index 57247a49a..1d100bf94 100644
--- a/src/gentropy/assets/schemas/l2g_predictions.json
+++ b/src/gentropy/assets/schemas/l2g_predictions.json
@@ -21,14 +21,41 @@
     },
     {
       "metadata": {},
-      "name": "locusToGeneFeatures",
+      "name": "features",
       "nullable": true,
       "type": {
-        "keyType": "string",
-        "type": "map",
-        "valueContainsNull": true,
-        "valueType": "float"
+        "containsNull": false,
+        "elementType": {
+          "fields": [
+            {
+              "metadata": {},
+              "name": "name",
+              "nullable": false,
+              "type": "string"
+            },
+            {
+              "metadata": {},
+              "name": "value",
+              "nullable": false,
+              "type": "float"
+            },
+            {
+              "metadata": {},
+              "name": "shapValue",
+              "nullable": true,
+              "type": "float"
+            }
+          ],
+          "type": "struct"
+        },
+        "type": "array"
       }
+    },
+    {
+      "name": "shapBaseValue",
+      "type": "float",
+      "nullable": true,
+      "metadata": {}
     }
   ]
 }
diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py
index 3735d9812..1262043d4 100644
--- a/src/gentropy/dataset/dataset.py
+++ b/src/gentropy/dataset/dataset.py
@@ -174,11 +174,13 @@ def from_parquet(
     def filter(self: Self, condition: Column) -> Self:
         """Creates a new instance of a Dataset with the DataFrame filtered by the condition.
 
+        Preserves all attributes from the original instance.
+
         Args:
             condition (Column): Condition to filter the DataFrame
 
         Returns:
-            Self: Filtered Dataset
+            Self: Filtered Dataset with preserved attributes
         """
         filtered_df = self._df.filter(condition)
         attrs = {k: v for k, v in self.__dict__.items() if k != "_df"}
diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index 255722414..d7477a354 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -2,14 +2,18 @@
 
 from __future__ import annotations
 
+import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
+import shap
 from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.session import Session
+from gentropy.common.spark_helpers import pivot_df
 from gentropy.dataset.dataset import Dataset
 from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
 from gentropy.dataset.study_index import StudyIndex
@@ -17,6 +21,7 @@ from gentropy.method.l2g.model import LocusToGeneModel
 
 if TYPE_CHECKING:
+    from pandas import DataFrame as pd_dataframe
     from pyspark.sql.types import StructType
@@ -47,6 +52,7 @@ def from_credible_set(
         credible_set: StudyLocus,
         feature_matrix: L2GFeatureMatrix,
         model_path: str | None,
+        features_list: list[str] | None = None,
         hf_token: str | None = None,
         download_from_hub: bool = True,
     ) -> L2GPrediction:
@@ -57,19 +63,29 @@ def from_credible_set(
             credible_set (StudyLocus): Dataset containing credible sets from GWAS only
             feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations
             model_path (str | None): Path to the model file. It can be either a path in the filesystem or a model name on the Hugging Face Hub (in the form username/repo_name).
+            features_list (list[str] | None): Default list of features the model uses. Only used if the model is not downloaded from the Hub. CAUTION: This default list can differ from the actual list the model was trained on.
             hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private.
             download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.
 
         Returns:
             L2GPrediction: L2G scores for a set of credible sets.
+
+        Raises:
+            AttributeError: If `features_list` is not provided and the model is not downloaded from the Hub.
         """
         # Load the model
         if download_from_hub:
             # Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops".
             model_id = model_path or "opentargets/locus_to_gene"
-            l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token)
+            l2g_model = LocusToGeneModel.load_from_hub(session, model_id, hf_token)
         elif model_path:
-            l2g_model = LocusToGeneModel.load_from_disk(model_path)
+            if not features_list:
+                raise AttributeError(
+                    "features_list is required if the model is not downloaded from the Hub"
+                )
+            l2g_model = LocusToGeneModel.load_from_disk(
+                session, path=model_path, features_list=features_list
+            )
 
         # Prepare data
         fm = (
@@ -79,7 +95,7 @@ def from_credible_set(
                     .select("studyLocusId")
                     .join(feature_matrix._df, "studyLocusId")
                     .filter(f.col("isProteinCoding") == 1)
-                )
+                ),
             )
             .fill_na()
             .select_features(l2g_model.features_list)
@@ -127,7 +143,129 @@ def to_disease_target_evidence(
             )
         )
 
-    def add_locus_to_gene_features(
+    def explain(
+        self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None = None
+    ) -> L2GPrediction:
+        """Extract Shapley values for the L2G predictions and add them in an additional column.
+
+        Args:
+            feature_matrix (L2GFeatureMatrix | None): Feature matrix in case the predictions are missing the feature annotation. If None, the features are fetched from the dataset.
+
+        Returns:
+            L2GPrediction: L2GPrediction object with an additional column containing feature name to Shapley value mappings
+
+        Raises:
+            ValueError: If the model is not set, or if the feature matrix is not provided and the predictions do not have features
+        """
+        # Fetch features if they are not present:
+        if "features" not in self.df.columns:
+            if feature_matrix is None:
+                raise ValueError(
+                    "Feature matrix is required to explain the L2G predictions"
+                )
+            self.add_features(feature_matrix)
+
+        if self.model is None:
+            raise ValueError("Model not set, explainer cannot be created")
+
+        # Format and pivot the dataframe before calculating the Shapley values
+        pdf = pivot_df(
+            df=self.df.withColumn("feature", f.explode("features")).select(
+                "studyLocusId",
+                "geneId",
+                "score",
+                f.col("feature.name").alias("feature_name"),
+                f.col("feature.value").alias("feature_value"),
+            ),
+            pivot_col="feature_name",
+            value_col="feature_value",
+            grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")],
+        ).toPandas()
+        pdf = pdf.rename(
+            # trim the suffix that is added after pivoting the df
+            columns={
+                col: col.replace("_feature_value", "")
+                for col in pdf.columns
+                if col.endswith("_feature_value")
+            }
+        )
+
+        # The matrix needs to present the features in the same order the model was trained on
+        features_list = self.model.features_list
+        base_value, shap_values = L2GPrediction._explain(
+            model=self.model,
+            pdf=pdf.filter(items=features_list),
+        )
+        for i, feature in enumerate(features_list):
+            pdf[f"shap_{feature}"] = [row[i] for row in shap_values]
+
+        spark_session = self.df.sparkSession
+        return L2GPrediction(
+            _df=(
+                spark_session.createDataFrame(pdf.to_dict(orient="records"))
+                .withColumn(
+                    "features",
+                    f.array(
+                        *(
+                            f.struct(
+                                f.lit(feature).alias("name"),
+                                f.col(feature).cast("float").alias("value"),
+                                f.col(f"shap_{feature}")
.cast("float") + .alias("shapValue"), + ) + for feature in features_list + ) + ), + ) + .withColumn("shapBaseValue", f.lit(base_value).cast("float")) + .select(*L2GPrediction.get_schema().names) + ), + _schema=self.get_schema(), + model=self.model, + ) + + @staticmethod + def _explain( + model: LocusToGeneModel, pdf: pd_dataframe + ) -> tuple[float, list[list[float]]]: + """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios). + + Args: + model (LocusToGeneModel): L2G model + pdf (pd_dataframe): Pandas dataframe containing the feature matrix in the same order that the model was trained on + + Returns: + tuple[float, list[list[float]]]: A tuple containing: + - base_value (float): Base value of the model + - shap_values (list[list[float]]): SHAP values for prediction + + Raises: + AttributeError: If model.training_data is not set, seed dataset to get shapley values cannot be created. + """ + if not model.training_data: + raise AttributeError( + "`model.training_data` is missing, seed dataset to get shapley values cannot be created." + ) + background_data = model.training_data._df.select( + *model.features_list + ).toPandas() + explainer = shap.TreeExplainer( + model.model, + data=background_data, + model_output="probability", + ) + if pdf.shape[0] >= 10_000: + logging.warning( + "Calculating SHAP values for more than 10,000 rows. This may take a while..." + ) + shap_values = explainer.shap_values( + pdf.to_numpy(), + check_additivity=False, + ) + base_value = explainer.expected_value + return (base_value, shap_values) + + def add_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix, ) -> L2GPrediction: @@ -137,41 +275,30 @@ def add_locus_to_gene_features( feature_matrix (L2GFeatureMatrix): Feature matrix dataset Returns: - L2GPrediction: L2G predictions with additional features + L2GPrediction: L2G predictions with additional column `features` Raises: ValueError: If model is not set, feature list won't be available """ if self.model is None: raise ValueError("Model not set, feature annotation cannot be created.") - # Testing if `locusToGeneFeatures` column already exists: - if "locusToGeneFeatures" in self.df.columns: - self.df = self.df.drop("locusToGeneFeatures") - - # Aggregating all features into a single map column: - aggregated_features = ( - feature_matrix._df.withColumn( - "locusToGeneFeatures", - f.create_map( - *sum( - ( - (f.lit(feature), f.col(feature)) - for feature in self.model.features_list - ), - (), - ) - ), - ) - .withColumn( - "locusToGeneFeatures", - f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"), - ) - .drop(*self.model.features_list) - ) - return L2GPrediction( - _df=self.df.join( - aggregated_features, on=["studyLocusId", "geneId"], how="left" - ), - _schema=self.get_schema(), - model=self.model, + # Testing if `features` column already exists: + if "features" in self.df.columns: + self.df = self.df.drop("features") + + features_list = self.model.features_list + feature_expressions = [ + f.struct(f.lit(col).alias("name"), f.col(col).alias("value")) + for col in features_list + ] + self.df = self.df.join( + feature_matrix._df.select(*features_list, "studyLocusId", "geneId"), + on=["studyLocusId", "geneId"], + how="left", + ).select( + "studyLocusId", + "geneId", + "score", + f.array(*feature_expressions).alias("features"), ) + return self diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 5f22471e3..39b7fcca1 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -284,14 +284,13 @@ def 
             self.credible_set,
             self.feature_matrix,
             model_path=self.model_path,
+            features_list=self.features_list,
             hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"),
             download_from_hub=self.download_from_hub,
         )
-        predictions.filter(
-            f.col("score") >= self.l2g_threshold
-        ).add_locus_to_gene_features(
+        predictions.filter(f.col("score") >= self.l2g_threshold).add_features(
             self.feature_matrix,
-        ).df.coalesce(self.session.output_partitions).write.mode(
+        ).explain().df.coalesce(self.session.output_partitions).write.mode(
             self.session.write_mode
         ).parquet(self.predictions_path)
         self.session.logger.info("L2G predictions saved successfully.")
@@ -331,7 +330,7 @@ def run_train(self) -> None:
                 "hfhub-key", "open-targets-genetics-dev"
             )
             trained_model.export_to_hugging_face_hub(
-                # we upload the model in the filesystem
+                # we upload the model saved in the filesystem
                 self.model_path.split("/")[-1],
                 hf_hub_token,
                 data=trained_model.training_data._df.drop(
diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py
index 9d9011332..8d31597ee 100644
--- a/src/gentropy/method/l2g/model.py
+++ b/src/gentropy/method/l2g/model.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import json
+import logging
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -16,9 +17,9 @@
 
 from gentropy.common.session import Session
 from gentropy.common.utils import copy_to_gcs
+from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
 
 if TYPE_CHECKING:
-    from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
     from gentropy.dataset.l2g_prediction import L2GPrediction
 
 
@@ -27,9 +28,7 @@ class LocusToGeneModel:
     """Wrapper for the Locus to Gene classifier."""
 
     model: Any = GradientBoostingClassifier(random_state=42)
-    features_list: list[str] = field(
-        default_factory=list
-    )  # TODO: default to list in config if not provided
+    features_list: list[str] = field(default_factory=list)
     hyperparameters: dict[str, Any] = field(
         default_factory=lambda: {
             "n_estimators": 100,
@@ -55,12 +54,18 @@ def __post_init__(self: LocusToGeneModel) -> None:
 
     @classmethod
     def load_from_disk(
-        cls: type[LocusToGeneModel], path: str, **kwargs: Any
+        cls: type[LocusToGeneModel],
+        session: Session,
+        path: str,
+        model_name: str = "classifier.skops",
+        **kwargs: Any,
     ) -> LocusToGeneModel:
         """Load a fitted model from disk.
 
         Args:
-            path (str): Path to the model
+            session (Session): Session object that loads the training data
+            path (str): Path to the directory containing model and metadata
+            model_name (str): Name of the persisted model to load. Defaults to "classifier.skops".
             **kwargs(Any): Keyword arguments to pass to the constructor
 
         Returns:
@@ -69,8 +74,9 @@ def load_from_disk(
         Raises:
             ValueError: If the model has not been fitted yet
         """
-        if path.startswith("gs://"):
-            path = path.removeprefix("gs://")
+        model_path = (Path(path) / model_name).as_posix()
+        if model_path.startswith("gs://"):
+            path = model_path.removeprefix("gs://")
             bucket_name = path.split("/")[0]
             blob_name = "/".join(path.split("/")[1:])
             from google.cloud import storage
@@ -81,25 +87,41 @@ def load_from_disk(
             data = blob.download_as_string(client=client)
             loaded_model = sio.loads(data, trusted=sio.get_untrusted_types(data=data))
         else:
-            loaded_model = sio.load(path, trusted=sio.get_untrusted_types(file=path))
+            loaded_model = sio.load(
+                model_path, trusted=sio.get_untrusted_types(file=model_path)
+            )
+        try:
+            # Try loading the training data if it is in the model directory
+            training_data = L2GFeatureMatrix(
+                _df=session.spark.createDataFrame(
+                    # Parquet is read with Pandas to easily read local files
+                    pd.read_parquet(
+                        (Path(path) / "training_data.parquet").as_posix()
+                    )
+                ),
+                features_list=kwargs.get("features_list"),
+            )
+        except Exception as e:
+            logging.error("Training data set to None. Error: %s", e)
+            training_data = None
 
         if not loaded_model._is_fitted():
             raise ValueError("Model has not been fitted yet.")
-        return cls(model=loaded_model, **kwargs)
+        return cls(model=loaded_model, training_data=training_data, **kwargs)
 
     @classmethod
     def load_from_hub(
         cls: type[LocusToGeneModel],
+        session: Session,
         model_id: str,
         hf_token: str | None = None,
-        model_name: str = "classifier.skops",
     ) -> LocusToGeneModel:
         """Load a model from the Hugging Face Hub. This will download the model from the hub and load it from disk.
 
         Args:
+            session (Session): Session object to load the training data
             model_id (str): Model ID on the Hugging Face Hub
             hf_token (str | None): Hugging Face Hub token to download the model (only required if private)
-            model_name (str): Name of the persisted model to load. Defaults to "classifier.skops".
 
         Returns:
             LocusToGeneModel: L2G model loaded from the Hugging Face Hub
@@ -119,14 +141,22 @@ def get_features_list_from_metadata() -> list[str]:
             return [
                 column
                 for column in model_config["sklearn"]["columns"]
-                if column != "studyLocusId"
+                if column
+                not in [
+                    "studyLocusId",
+                    "geneId",
+                    "traitFromSourceMappedId",
+                    "goldStandardSet",
+                ]
             ]
 
-        local_path = Path(model_id)
+        local_path = model_id
         hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token)
         features_list = get_features_list_from_metadata()
         return cls.load_from_disk(
-            str(Path(local_path) / model_name), features_list=features_list
+            session,
+            local_path,
+            features_list=features_list,
         )
 
     @property
@@ -196,6 +226,8 @@ def save(self: LocusToGeneModel, path: str) -> None:
             sio.dump(self.model, local_path)
             copy_to_gcs(local_path, path)
         else:
+            # create directory if path does not exist
+            Path(path).parent.mkdir(parents=True, exist_ok=True)
             sio.dump(self.model, path)
 
     @staticmethod
@@ -231,7 +263,6 @@ def _create_hugging_face_model_card(
 
         - Distance: (from credible set variants to gene)
         - Molecular QTL Colocalization
-        - Chromatin Interaction: (e.g., promoter-capture Hi-C)
         - Variant Pathogenicity: (from VEP)
 
         More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/
@@ -270,7 +301,7 @@ def export_to_hugging_face_hub(
         repo_id: str = "opentargets/locus_to_gene",
         local_repo: str = "locus_to_gene",
     ) -> None:
-        """Share the model on Hugging Face Hub.
+ """Share the model and training dataset on Hugging Face Hub. Args: model_path (str): The path to the L2G model file. @@ -294,6 +325,7 @@ def export_to_hugging_face_hub( data=data, ) self._create_hugging_face_model_card(local_repo) + data.to_parquet(f"{local_repo}/training_set.parquet") hub_utils.push( repo_id=repo_id, source=local_repo,