From ad3f503ae988f7b09fb6723ca2c6a8c7ba0b8d90 Mon Sep 17 00:00:00 2001 From: David Ochoa Date: Fri, 20 Sep 2024 15:29:30 +0100 Subject: [PATCH] feat: flag PICS top hits in studies with credset sumstats (#777) Co-authored-by: Daniel Suveges --- src/gentropy/dataset/study_locus.py | 42 +++++++++++++ src/gentropy/study_locus_validation.py | 1 + tests/gentropy/dataset/test_study_locus.py | 71 ++++++++++++++++++++++ 3 files changed, 114 insertions(+) diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index 283280527..e2f9dfece 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -52,6 +52,7 @@ class StudyLocusQualityCheck(Enum): INVALID_VARIANT_IDENTIFIER (str): Flagging study loci where identifier of any tagging variant was not found in the variant index TOP_HIT (str): Study locus from curated top hit IN_MHC (str): Flagging study loci in the MHC region + REDUNDANT_PICS_TOP_HIT (str): Flagging study loci in studies with PICS results from summary statistics """ SUBSIGNIFICANT_FLAG = "Subsignificant p-value" @@ -74,6 +75,9 @@ class StudyLocusQualityCheck(Enum): "Some variant identifiers of this locus were not found in variant index" ) IN_MHC = "MHC region" + REDUNDANT_PICS_TOP_HIT = ( + "PICS results from summary statistics available for this same study" + ) TOP_HIT = "Study locus from curated top hit" @@ -878,6 +882,44 @@ def qc_MHC_region(self: StudyLocus) -> StudyLocus: ) return self + def qc_redundant_top_hits_from_PICS(self: StudyLocus) -> StudyLocus: + """Flag associations from top hits when the study contains other PICS associations from summary statistics. + + This flag can be useful to identify top hits that should be explained by other associations in the study derived from the summary statistics. + + Returns: + StudyLocus: Updated study locus with redundant top hits flagged. + """ + studies_with_pics_sumstats = ( + self.df.filter(f.col("finemappingMethod") == "pics") + # Returns True if the study contains any PICS associations from summary statistics + .withColumn( + "hasPicsSumstats", + ~f.array_contains( + "qualityControls", StudyLocusQualityCheck.TOP_HIT.value + ), + ) + .groupBy("studyId") + .agg(f.max(f.col("hasPicsSumstats")).alias("studiesWithPicsSumstats")) + ) + + return StudyLocus( + _df=self.df.join(studies_with_pics_sumstats, on="studyId", how="left") + .withColumn( + "qualityControls", + self.update_quality_flag( + f.col("qualityControls"), + f.array_contains( + "qualityControls", StudyLocusQualityCheck.TOP_HIT.value + ) + & f.col("studiesWithPicsSumstats"), + StudyLocusQualityCheck.REDUNDANT_PICS_TOP_HIT, + ), + ) + .drop("studiesWithPicsSumstats"), + _schema=StudyLocus.get_schema(), + ) + def _qc_no_population(self: StudyLocus) -> StudyLocus: """Flag associations where the study doesn't have population information to resolve LD. diff --git a/src/gentropy/study_locus_validation.py b/src/gentropy/study_locus_validation.py index da660ca57..e3d10f3db 100644 --- a/src/gentropy/study_locus_validation.py +++ b/src/gentropy/study_locus_validation.py @@ -46,6 +46,7 @@ def __init__( # Add flag for MHC region .qc_MHC_region() .validate_study(study_index) # Flagging studies not in study index + .qc_redundant_top_hits_from_PICS() # Flagging top hits from studies with PICS summary statistics .validate_unique_study_locus_id() # Flagging duplicated study locus ids ).persist() # we will need this for 2 types of outputs diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index 1daf9bb89..94390d20b 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -778,3 +778,74 @@ def test_study_validation_correctness(self: TestStudyLocusValidation) -> None: ) .count() ) == 1 + + +class TestStudyLocusRedundancyFlagging: + """Collection of tests related to flagging redundant credible sets.""" + + STUDY_LOCUS_DATA = [ + (1, "v1", "s1", "pics", []), + (2, "v2", "s1", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + (3, "v3", "s1", "pics", []), + (3, "v3", "s1", "pics", []), + (1, "v1", "s1", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + (1, "v1", "s2", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + (1, "v1", "s2", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + (1, "v1", "s3", "SuSie", []), + (1, "v1", "s3", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + (1, "v1", "s4", "pics", []), + (1, "v1", "s4", "SuSie", []), + (1, "v1", "s4", "pics", [StudyLocusQualityCheck.TOP_HIT.value]), + ] + + STUDY_LOCUS_SCHEMA = t.StructType( + [ + t.StructField("studyLocusId", t.LongType(), False), + t.StructField("variantId", t.StringType(), False), + t.StructField("studyId", t.StringType(), False), + t.StructField("finemappingMethod", t.StringType(), False), + t.StructField("qualityControls", t.ArrayType(t.StringType()), False), + ] + ) + + @pytest.fixture(autouse=True) + def _setup(self: TestStudyLocusRedundancyFlagging, spark: SparkSession) -> None: + """Setup study locus for testing.""" + self.study_locus = StudyLocus( + _df=spark.createDataFrame( + self.STUDY_LOCUS_DATA, schema=self.STUDY_LOCUS_SCHEMA + ), + _schema=StudyLocus.get_schema(), + ) + + def test_qc_redundant_top_hits_from_PICS_returntype( + self: TestStudyLocusRedundancyFlagging, + ) -> None: + """Test qc_redundant_top_hits_from_PICS.""" + assert isinstance( + self.study_locus.qc_redundant_top_hits_from_PICS(), StudyLocus + ) + + def test_qc_redundant_top_hits_from_PICS_no_data_loss( + self: TestStudyLocusRedundancyFlagging, + ) -> None: + """Testing if the redundancy flagging returns the same number of rows.""" + assert ( + self.study_locus.qc_redundant_top_hits_from_PICS().df.count() + == self.study_locus.df.count() + ) + + def test_qc_redundant_top_hits_from_PICS_correctness( + self: TestStudyLocusRedundancyFlagging, + ) -> None: + """Testing if the study validation flags the right number of studies.""" + assert ( + self.study_locus.qc_redundant_top_hits_from_PICS() + .df.filter( + f.array_contains( + f.col("qualityControls"), + StudyLocusQualityCheck.REDUNDANT_PICS_TOP_HIT.value, + ) + ) + .count() + ) == 3