feat(l2gprediction): add score explanation based on features (#939)
* feat(prediction): add `model` as instance attribute

* feat: added `convert_map_type_to_columns` spark util

* feat(prediction): new method `explain` returns shapley values

* feat(prediction): `explain` returns predictions with shapley values

* chore: compute `shapleyValues` in the l2g step

* refactor: use pandas udf instead

* refactor: forget about udfs and get shaps single threaded

* chore: remove reference to chromatin interaction data in HF card

* fix(l2g_prediction): methods that return new instance preserve attribute

* feat(dataset): `filter` method preserves all instance attributes

* feat(l2gmodel): add features_list as model attribute and load it from the hub metadata

* fix: pass correct order of features to shapley explainer

* feat(l2g): predict mode to extract feature list from model, not from config

* feat(l2g): pass default features list if model is loaded from a path

* feat(l2gmodel): add features_list as model attribute and load it from the hub metadata

* feat(l2g): predict mode to extract feature list from model, not from config

* feat(l2gprediction): add `model` as attribute

* feat(l2gmodel): add features_list as model attribute and load it from the hub metadata

* feat(l2g): predict mode to extract feature list from model, not from config

* feat(l2gprediction): add `model` as attribute

* chore: fix typo

* chore: remove `convert_map_type_to_columns`

* feat(l2gprediction): refactor feature annotation and change schema

* chore: pre-commit auto fixes [...]

* feat: report as log odds

* feat: calculate scaled probabilities

* chore(l2gprediction): remove shapBaseProbability

* chore: correct typo in add_features and make schemas non nullable

* fix: rename columns in pandas df after pivoting

* fix: add raw shap contributions

* fix(model): when saving create directory if not exists

* feat(l2g): bundle model and training data in hf

* feat(model): include data when loading model

* feat: final version of shap explanations

* fix: do not infer features_list from df

* fix: get_features_list_from_metadata returned cols that were not features

* refactor(model): read training data in the local filesystem w pandas

* chore: successful run, remove test
ireneisdoomed authored Feb 19, 2025
1 parent cef8afc commit f952f6c
Showing 5 changed files with 250 additions and 63 deletions.
37 changes: 32 additions & 5 deletions src/gentropy/assets/schemas/l2g_predictions.json
@@ -21,14 +21,41 @@
     },
     {
       "metadata": {},
-      "name": "locusToGeneFeatures",
+      "name": "features",
       "nullable": true,
       "type": {
-        "keyType": "string",
-        "type": "map",
-        "valueContainsNull": true,
-        "valueType": "float"
+        "containsNull": false,
+        "elementType": {
+          "fields": [
+            {
+              "metadata": {},
+              "name": "name",
+              "nullable": false,
+              "type": "string"
+            },
+            {
+              "metadata": {},
+              "name": "value",
+              "nullable": false,
+              "type": "float"
+            },
+            {
+              "metadata": {},
+              "name": "shapValue",
+              "nullable": true,
+              "type": "float"
+            }
+          ],
+          "type": "struct"
+        },
+        "type": "array"
       }
     },
+    {
+      "name": "shapBaseValue",
+      "type": "float",
+      "nullable": true,
+      "metadata": {}
+    }
   ]
 }
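
The array-of-structs layout replaces the old `locusToGeneFeatures` map so that each feature can carry its raw value and its SHAP contribution side by side, with the explainer's expected value stored once per row in `shapBaseValue`. A minimal sketch of reading the new layout downstream (the parquet path is hypothetical; the column names follow the schema above):

import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
predictions = spark.read.parquet("gs://my-bucket/l2g_predictions")  # hypothetical path

(
    predictions
    .withColumn("feature", f.explode("features"))  # one row per (prediction, feature)
    .select(
        "studyLocusId",
        "geneId",
        "score",
        "shapBaseValue",
        f.col("feature.name").alias("featureName"),
        f.col("feature.value").alias("featureValue"),
        f.col("feature.shapValue").alias("shapValue"),
    )
    .orderBy(f.abs(f.col("shapValue")).desc())  # most influential features first
    .show(truncate=False)
)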
4 changes: 3 additions & 1 deletion src/gentropy/dataset/dataset.py
@@ -174,11 +174,13 @@ def from_parquet(
     def filter(self: Self, condition: Column) -> Self:
         """Creates a new instance of a Dataset with the DataFrame filtered by the condition.
+        Preserves all attributes from the original instance.
 
         Args:
             condition (Column): Condition to filter the DataFrame
 
         Returns:
-            Self: Filtered Dataset
+            Self: Filtered Dataset with preserved attributes
         """
         filtered_df = self._df.filter(condition)
+        attrs = {k: v for k, v in self.__dict__.items() if k != "_df"}
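
The `filter` body is truncated above, but the pattern is visible: collect every instance attribute except `_df` and hand it to the new instance. A runnable toy sketch of the same pattern, assuming the subclass constructor accepts those attribute names (as `L2GPrediction` does with `model`):

from dataclasses import dataclass


@dataclass
class Toy:
    _df: list  # stands in for the Spark DataFrame
    model: str | None = None  # extra attribute that must survive filtering

    def filter(self, keep_above: int) -> "Toy":
        filtered = [x for x in self._df if x > keep_above]
        attrs = {k: v for k, v in self.__dict__.items() if k != "_df"}
        return self.__class__(_df=filtered, **attrs)


assert Toy([1, 5], model="m").filter(2).model == "m"  # `model` is preserved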
197 changes: 162 additions & 35 deletions src/gentropy/dataset/l2g_prediction.py
@@ -2,21 +2,26 @@
 
 from __future__ import annotations
 
+import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
+import shap
 from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType
 
 from gentropy.common.schemas import parse_spark_schema
+from gentropy.common.session import Session
+from gentropy.common.spark_helpers import pivot_df
 from gentropy.dataset.dataset import Dataset
 from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
 from gentropy.dataset.study_index import StudyIndex
 from gentropy.dataset.study_locus import StudyLocus
 from gentropy.method.l2g.model import LocusToGeneModel
 
 if TYPE_CHECKING:
+    from pandas import DataFrame as pd_dataframe
+    from pyspark.sql.types import StructType


@@ -47,6 +52,7 @@ def from_credible_set(
         credible_set: StudyLocus,
         feature_matrix: L2GFeatureMatrix,
         model_path: str | None,
+        features_list: list[str] | None = None,
         hf_token: str | None = None,
         download_from_hub: bool = True,
     ) -> L2GPrediction:
@@ -57,19 +63,29 @@
             credible_set (StudyLocus): Dataset containing credible sets from GWAS only
             feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations
             model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name).
+            features_list (list[str] | None): Default list of features the model uses. Only used if the model is not downloaded from the Hub. CAUTION: This default list can differ from the actual list the model was trained on.
             hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private.
             download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.
 
         Returns:
             L2GPrediction: L2G scores for a set of credible sets.
+
+        Raises:
+            AttributeError: If `features_list` is not provided and the model is not downloaded from the Hub.
         """
+        # Load the model
         if download_from_hub:
             # Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops".
             model_id = model_path or "opentargets/locus_to_gene"
-            l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token)
+            l2g_model = LocusToGeneModel.load_from_hub(session, model_id, hf_token)
         elif model_path:
-            l2g_model = LocusToGeneModel.load_from_disk(model_path)
+            if not features_list:
+                raise AttributeError(
+                    "features_list is required if the model is not downloaded from the Hub"
+                )
+            l2g_model = LocusToGeneModel.load_from_disk(
+                session, path=model_path, features_list=features_list
+            )
 
         # Prepare data
         fm = (
@@ -79,7 +95,7 @@
                     .select("studyLocusId")
                     .join(feature_matrix._df, "studyLocusId")
                     .filter(f.col("isProteinCoding") == 1)
-                )
+                ),
             )
             .fill_na()
             .select_features(l2g_model.features_list)
@@ -127,7 +143,129 @@ def to_disease_target_evidence(
             )
         )
 
-    def add_locus_to_gene_features(
+    def explain(
+        self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None = None
+    ) -> L2GPrediction:
+        """Extract Shapley values for the L2G predictions and add them to the `features` column.
+
+        Args:
+            feature_matrix (L2GFeatureMatrix | None): Feature matrix in case the predictions are missing the feature annotation. If None, the features are fetched from the dataset.
+
+        Returns:
+            L2GPrediction: L2GPrediction object with an additional column containing a Shapley value for every feature
+
+        Raises:
+            ValueError: If the model is not set, or if the feature matrix is not provided and the predictions do not have features
+        """
+        # Fetch features if they are not present:
+        if "features" not in self.df.columns:
+            if feature_matrix is None:
+                raise ValueError(
+                    "Feature matrix is required to explain the L2G predictions"
+                )
+            self.add_features(feature_matrix)
+
+        if self.model is None:
+            raise ValueError("Model not set, explainer cannot be created")
+
+        # Format and pivot the dataframe before calculating the Shapley values
+        pdf = pivot_df(
+            df=self.df.withColumn("feature", f.explode("features")).select(
+                "studyLocusId",
+                "geneId",
+                "score",
+                f.col("feature.name").alias("feature_name"),
+                f.col("feature.value").alias("feature_value"),
+            ),
+            pivot_col="feature_name",
+            value_col="feature_value",
+            grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")],
+        ).toPandas()
+        pdf = pdf.rename(
+            # trim the suffix that is added after pivoting the df
+            columns={
+                col: col.replace("_feature_value", "")
+                for col in pdf.columns
+                if col.endswith("_feature_value")
+            }
+        )
+
+        # The matrix must present the features in the same order the model was trained on
+        features_list = self.model.features_list
+        base_value, shap_values = L2GPrediction._explain(
+            model=self.model,
+            pdf=pdf.filter(items=features_list),
+        )
+        for i, feature in enumerate(features_list):
+            pdf[f"shap_{feature}"] = [row[i] for row in shap_values]
+
+        spark_session = self.df.sparkSession
+        return L2GPrediction(
+            _df=(
+                spark_session.createDataFrame(pdf.to_dict(orient="records"))
+                .withColumn(
+                    "features",
+                    f.array(
+                        *(
+                            f.struct(
+                                f.lit(feature).alias("name"),
+                                f.col(feature).cast("float").alias("value"),
+                                f.col(f"shap_{feature}")
+                                .cast("float")
+                                .alias("shapValue"),
+                            )
+                            for feature in features_list
+                        )
+                    ),
+                )
+                .withColumn("shapBaseValue", f.lit(base_value).cast("float"))
+                .select(*L2GPrediction.get_schema().names)
+            ),
+            _schema=self.get_schema(),
+            model=self.model,
+        )
+
+    @staticmethod
+    def _explain(
+        model: LocusToGeneModel, pdf: pd_dataframe
+    ) -> tuple[float, list[list[float]]]:
+        """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios).
+
+        Args:
+            model (LocusToGeneModel): L2G model
+            pdf (pd_dataframe): Pandas dataframe containing the feature matrix in the same order that the model was trained on
+
+        Returns:
+            tuple[float, list[list[float]]]: A tuple containing:
+                - base_value (float): Base value of the model
+                - shap_values (list[list[float]]): SHAP values for each prediction
+
+        Raises:
+            AttributeError: If `model.training_data` is not set; the seed dataset for the Shapley values cannot be created.
+        """
+        if not model.training_data:
+            raise AttributeError(
+                "`model.training_data` is missing, seed dataset to get shapley values cannot be created."
+            )
+        background_data = model.training_data._df.select(
+            *model.features_list
+        ).toPandas()
+        explainer = shap.TreeExplainer(
+            model.model,
+            data=background_data,
+            model_output="probability",
+        )
+        if pdf.shape[0] >= 10_000:
+            logging.warning(
+                "Calculating SHAP values for more than 10,000 rows. This may take a while..."
+            )
+        shap_values = explainer.shap_values(
+            pdf.to_numpy(),
+            check_additivity=False,
+        )
+        base_value = explainer.expected_value
+        return (base_value, shap_values)
+
+    def add_features(
         self: L2GPrediction,
         feature_matrix: L2GFeatureMatrix,
     ) -> L2GPrediction:
@@ -137,41 +275,30 @@ def add_locus_to_gene_features(
             feature_matrix (L2GFeatureMatrix): Feature matrix dataset
 
         Returns:
-            L2GPrediction: L2G predictions with additional features
+            L2GPrediction: L2G predictions with additional column `features`
 
         Raises:
             ValueError: If model is not set, feature list won't be available
         """
         if self.model is None:
             raise ValueError("Model not set, feature annotation cannot be created.")
-        # Testing if `locusToGeneFeatures` column already exists:
-        if "locusToGeneFeatures" in self.df.columns:
-            self.df = self.df.drop("locusToGeneFeatures")
-
-        # Aggregating all features into a single map column:
-        aggregated_features = (
-            feature_matrix._df.withColumn(
-                "locusToGeneFeatures",
-                f.create_map(
-                    *sum(
-                        (
-                            (f.lit(feature), f.col(feature))
-                            for feature in self.model.features_list
-                        ),
-                        (),
-                    )
-                ),
-            )
-            .withColumn(
-                "locusToGeneFeatures",
-                f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"),
-            )
-            .drop(*self.model.features_list)
-        )
-        return L2GPrediction(
-            _df=self.df.join(
-                aggregated_features, on=["studyLocusId", "geneId"], how="left"
-            ),
-            _schema=self.get_schema(),
-            model=self.model,
-        )
+        # Testing if `features` column already exists:
+        if "features" in self.df.columns:
+            self.df = self.df.drop("features")
+
+        features_list = self.model.features_list
+        feature_expressions = [
+            f.struct(f.lit(col).alias("name"), f.col(col).alias("value"))
+            for col in features_list
+        ]
+        self.df = self.df.join(
+            feature_matrix._df.select(*features_list, "studyLocusId", "geneId"),
+            on=["studyLocusId", "geneId"],
+            how="left",
+        ).select(
+            "studyLocusId",
+            "geneId",
+            "score",
+            f.array(*feature_expressions).alias("features"),
+        )
+        return self
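
To make the `_explain` logic concrete, here is a self-contained sketch of the same SHAP recipe on toy data; the model type and feature names are illustrative, not the ones L2G actually uses:

import numpy as np
import shap
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(42)
features = ["distanceTssMean", "eQtlColocClppMaximum"]  # illustrative names
X_background = rng.random((200, len(features)))  # plays the role of model.training_data
y = (X_background.sum(axis=1) > 1).astype(int)
model = GradientBoostingClassifier().fit(X_background, y)

# As in `_explain`: the training data anchors the explanation and the output is
# expressed as probabilities (approximated from the log odds).
explainer = shap.TreeExplainer(model, data=X_background, model_output="probability")
X_new = rng.random((5, len(features)))
shap_values = explainer.shap_values(X_new, check_additivity=False)

# shapBaseValue plus the per-feature shapValues approximately recover the score:
approx = explainer.expected_value + shap_values.sum(axis=1)
print(np.round(approx, 3))
print(np.round(model.predict_proba(X_new)[:, 1], 3))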
9 changes: 4 additions & 5 deletions src/gentropy/l2g.py
@@ -284,14 +284,13 @@ def run_predict(self) -> None:
             self.credible_set,
             self.feature_matrix,
             model_path=self.model_path,
+            features_list=self.features_list,
             hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"),
             download_from_hub=self.download_from_hub,
         )
-        predictions.filter(
-            f.col("score") >= self.l2g_threshold
-        ).add_locus_to_gene_features(
+        predictions.filter(f.col("score") >= self.l2g_threshold).add_features(
             self.feature_matrix,
-        ).df.coalesce(self.session.output_partitions).write.mode(
+        ).explain().df.coalesce(self.session.output_partitions).write.mode(
             self.session.write_mode
         ).parquet(self.predictions_path)
         self.session.logger.info("L2G predictions saved successfully.")
@@ -331,7 +330,7 @@ def run_train(self) -> None:
                 "hfhub-key", "open-targets-genetics-dev"
             )
             trained_model.export_to_hugging_face_hub(
-                # we upload the model in the filesystem
+                # we upload the model saved in the filesystem
                 self.model_path.split("/")[-1],
                 hf_hub_token,
                 data=trained_model.training_data._df.drop(
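
Putting the step together: a sketch of the new predict flow end to end. Paths are hypothetical, and the session argument plus the L2GFeatureMatrix construction are assumptions inferred from the hunks above, not verbatim API:

import pyspark.sql.functions as f

from gentropy.common.session import Session
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.l2g_prediction import L2GPrediction
from gentropy.dataset.study_locus import StudyLocus

session = Session()
credible_set = StudyLocus.from_parquet(session, "gs://bucket/credible_set")  # hypothetical path
feature_matrix = L2GFeatureMatrix(
    _df=session.spark.read.parquet("gs://bucket/feature_matrix")  # construction simplified
)

predictions = L2GPrediction.from_credible_set(
    session,  # assumed first argument, since load_from_hub now receives it
    credible_set,
    feature_matrix,
    model_path=None,  # defaults to opentargets/locus_to_gene on the Hub
    download_from_hub=True,
)
(
    predictions.filter(f.col("score") >= 0.05)  # stands in for l2g_threshold
    .add_features(feature_matrix)  # annotates the `features` struct array
    .explain()  # adds a shapValue per feature plus shapBaseValue
    .df.write.mode("overwrite")
    .parquet("gs://bucket/l2g_predictions")  # hypothetical path
)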