diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index 4c611e3da..7ef526704 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -108,6 +108,100 @@ def get_schema(cls: type[L2GFeatureMatrix]) -> StructType:
         """
         return parse_spark_schema("l2g_feature_matrix.json")
 
+    def merge_features_in_efo(
+        self: L2GFeatureMatrix,
+        features: list[str],
+        credible_set: StudyLocus,
+        study_index: StudyIndex,
+        max_distance: int = 500000,
+    ) -> L2GFeatureMatrix:
+        """Harmonise features across studyLocusId-to-geneId pairs sharing an EFO trait.
+
+        Credible sets are annotated with their mapped EFO terms and grouped by
+        (EFO term, gene, chromosome); within a group, variants closer than
+        ``max_distance`` base pairs form one cluster. Every feature column is
+        then set to the cluster-wide maximum, which both fills missing values
+        and propagates the strongest signal to all members of the cluster.
+
+        Args:
+            features (list[str]): List of feature columns to merge
+            credible_set (StudyLocus): Credible set dataset
+            study_index (StudyIndex): Study index dataset
+            max_distance (int): Maximum allowed base pair distance for grouping variants. Default is 500,000.
+
+        Returns:
+            L2GFeatureMatrix: L2G feature matrix dataset with merged features
+        """
+        from pyspark.sql import functions as f
+        from pyspark.sql.window import Window
+
+        # One row per (credible set, mapped EFO term), joined to the matrix.
+        efo_df = (
+            credible_set.df.join(study_index.df, on="studyId", how="inner").select(
+                "studyId",
+                "studyLocusId",
+                "variantId",
+                f.explode(study_index.df["traitFromSourceMappedIds"]).alias(
+                    "efo_terms"
+                ),
+            )
+        ).join(
+            self._df,
+            on="studyLocusId",
+            how="inner",
+        )
+
+        # variantId is underscore-separated: chromosome_position_ref_alt.
+        efo_df = efo_df.withColumn(
+            "chromosome", f.split(f.col("variantId"), "_").getItem(0)
+        )
+        efo_df = efo_df.withColumn(
+            "position", f.split(f.col("variantId"), "_").getItem(1).cast("long")
+        )
+
+        window_spec = Window.partitionBy("efo_terms", "geneId", "chromosome").orderBy(
+            "position"
+        )
+
+        # A gap larger than max_distance starts a new cluster; the running
+        # sum of those breakpoints yields a per-cluster group id.
+        efo_df = efo_df.withColumn(
+            "position_diff", f.col("position") - f.lag("position", 1).over(window_spec)
+        )
+        efo_df = efo_df.withColumn(
+            "group",
+            f.sum(f.when(f.col("position_diff") > max_distance, 1).otherwise(0)).over(
+                window_spec
+            ),
+        )
+
+        # Cluster-wide maximum of every feature column.
+        max_df = efo_df.groupBy("efo_terms", "geneId", "group").agg(
+            *[f.max(col).alias(f"{col}_max") for col in features]
+        )
+
+        imputed_df = efo_df.join(
+            max_df, on=["efo_terms", "geneId", "group"], how="left"
+        )
+
+        for col in features:
+            imputed_df = imputed_df.withColumn(col, f.col(f"{col}_max")).drop(
+                f"{col}_max"
+            )
+
+        # Drop helper columns and deduplicate back to the matrix schema.
+        self.df = imputed_df.drop(
+            "efo_terms",
+            "studyId",
+            "chromosome",
+            "position",
+            "position_diff",
+            "group",
+            "variantId",
+        ).distinct()
+
+        return self
+
     def calculate_feature_missingness_rate(
         self: L2GFeatureMatrix,
     ) -> dict[str, float]:
diff --git a/src/gentropy/dataset/l2g_gold_standard.py b/src/gentropy/dataset/l2g_gold_standard.py
index 5bc48413c..d2bcc77b7 100644
--- a/src/gentropy/dataset/l2g_gold_standard.py
+++ b/src/gentropy/dataset/l2g_gold_standard.py
@@ -1,4 +1,5 @@
 """L2G gold standard dataset."""
+
 from __future__ import annotations
 
 from dataclasses import dataclass
@@ -56,6 +57,7 @@ def from_otg_curation(
             OpenTargetsL2GGoldStandard.as_l2g_gold_standard(gold_standard_curation, v2g)
             # .filter_unique_associations(study_locus_overlap)
             .remove_false_negatives(interactions_df)
+            .balance_classes()
         )
 
     @classmethod
@@ -197,3 +199,39 @@ def remove_false_negatives(
             .distinct()
         )
         return L2GGoldStandard(_df=df, _schema=self.get_schema())
+
+    def balance_classes(
+        self: L2GGoldStandard, imbalance_ratio: float = 2.0
+    ) -> L2GGoldStandard:
+        """Balance the classes of the gold standard dataset.
+
+        Negatives are randomly downsampled (fixed seed, so the result is
+        reproducible) until at most ``imbalance_ratio`` negatives remain
+        per positive.
+
+        Args:
+            imbalance_ratio (float): maximum ratio of negative to positive samples
+
+        Returns:
+            L2GGoldStandard: A balanced gold standard dataset.
+        """
+        positive_df = self.df.filter(f.col("goldStandardSet") == self.GS_POSITIVE_LABEL)
+        negative_df = self.df.filter(f.col("goldStandardSet") == self.GS_NEGATIVE_LABEL)
+
+        # Guard: an empty negative partition would otherwise raise ZeroDivisionError.
+        negative_count = negative_df.count()
+        negative_sample_fraction = (
+            min(positive_df.count() * imbalance_ratio / negative_count, 1.0)
+            if negative_count
+            else 0.0
+        )
+
+        negative_sample = negative_df.sample(
+            withReplacement=False,
+            fraction=negative_sample_fraction,
+            seed=42,
+        )
+
+        return L2GGoldStandard(
+            _df=positive_df.union(negative_sample), _schema=self.get_schema()
+        )