fix(ld clumping): revised logic allows more accurate clumping #772

Merged 8 commits on Sep 23, 2024
8 changes: 6 additions & 2 deletions src/gentropy/dataset/study_locus.py
@@ -798,11 +798,12 @@ def clump(self: StudyLocus) -> StudyLocus:
         Returns:
             StudyLocus: with empty credible sets for linked variants and QC flag.
         """
-        self.df = (
+        clumped_df = (
             self.df.withColumn(
                 "is_lead_linked",
                 LDclumping._is_lead_linked(
                     self.df.studyId,
+                    self.df.chromosome,
                     self.df.variantId,
                     self.df.pValueExponent,
                     self.df.pValueMantissa,
@@ -823,7 +824,10 @@
             )
             .drop("is_lead_linked")
         )
-        return self
+        return StudyLocus(
+            _df=clumped_df,
+            _schema=self.get_schema(),
+        )
 
     def exclude_region(
         self: StudyLocus, region: GenomicRegion, exclude_overlap: bool = False
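One behavioural note on this file: clump() previously mutated self.df in place and returned self; it now builds the flagged dataframe into clumped_df and returns a fresh StudyLocus, leaving the input object untouched. A minimal usage sketch, assuming a populated study_locus instance (the variable name is hypothetical, not part of the diff):

# Hypothetical caller; `study_locus` stands for any populated StudyLocus instance.
flagged = study_locus.clump()

# The input is no longer mutated in place, so the result must be rebound;
# linked leads in `flagged` carry an emptied credible set and a QC flag.
assert flagged is not study_locus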
45 changes: 23 additions & 22 deletions src/gentropy/method/clump.py
@@ -1,6 +1,5 @@
 """Clumps GWAS significant variants to generate a studyLocus dataset of independent variants."""
 
-
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
@@ -20,6 +19,7 @@ class LDclumping:
     @staticmethod
     def _is_lead_linked(
         study_id: Column,
+        chromosome: Column,
         variant_id: Column,
         p_value_exponent: Column,
         p_value_mantissa: Column,
@@ -29,6 +29,7 @@
 
         Args:
             study_id (Column): studyId
+            chromosome (Column): chromosome
             variant_id (Column): Lead variant id
             p_value_exponent (Column): p-value exponent
             p_value_mantissa (Column): p-value mantissa
@@ -37,31 +38,31 @@
         Returns:
             Column: Boolean in which True indicates that the lead is linked to another tag in the same dataset.
         """
-        leads_in_study = f.collect_set(variant_id).over(Window.partitionBy(study_id))
-        tags_in_studylocus = f.array_union(
-            # Get all tag variants from the credible set per studyLocusId
-            f.transform(ld_set, lambda x: x.tagVariantId),
-            # And append the lead variant so that the intersection is the same for all studyLocusIds in a study
-            f.array(variant_id),
-        )
-        intersect_lead_tags = f.array_sort(
-            f.array_intersect(leads_in_study, tags_in_studylocus)
-        )
-        return (
-            # If the lead is in the credible set, we rank the peaks by p-value
-            f.when(
-                f.size(intersect_lead_tags) > 0,
-                f.row_number().over(
-                    Window.partitionBy(study_id, intersect_lead_tags).orderBy(
-                        p_value_exponent, p_value_mantissa
-                    )
-                )
-                > 1,
-            )
-            # If the intersection is empty (lead is not in the credible set or cred set is empty), the association is not linked
-            .otherwise(f.lit(False))
-        )
+        # Partitioning data by study and chromosome - this is the scope for looking for linked loci.
+        # Within the partition, we order the data by increasing p-value, and we collect the more significant lead variants in the window.
+        windowspec = (
+            Window.partitionBy(study_id, chromosome)
+            .orderBy(p_value_exponent.asc(), p_value_mantissa.asc())
+            .rowsBetween(Window.unboundedPreceding, Window.currentRow)
+        )
+        more_significant_leads = f.collect_set(variant_id).over(windowspec)
+
+        # Collect all variants from the ld_set and add the lead variant to make sure that the lead is always in the list.
+        tags_in_studylocus = f.array_distinct(
+            f.array_union(
+                f.array(variant_id),
+                f.transform(ld_set, lambda x: x.getField("tagVariantId")),
+            )
+        )
+
+        # If more than one tag of the ld_set can be found in the list of more significant leads, the lead is linked.
+        # Study loci without a variantId are considered not linked.
+        # Leads that were not found in the LD index are also considered not linked.
+        return f.when(
+            variant_id.isNotNull(),
+            f.size(f.array_intersect(more_significant_leads, tags_in_studylocus)) > 1,
+        ).otherwise(f.lit(False))
 
     @classmethod
     def clump(cls: type[LDclumping], associations: StudyLocus) -> StudyLocus:
         """Perform clumping on studyLocus dataset.
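To make the revised logic concrete, here is a minimal, self-contained PySpark sketch (not part of the PR; the toy session, DDL schema string, and variant names are illustrative assumptions, and the null-variantId guard is omitted for brevity). It mirrors the approach above: because the lead is always appended to its own tag list and is always collected by its own window frame, the intersection has size at least one, and a size above one means a more significant lead on the same study and chromosome sits in the credible set.

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()

# Two leads in the same study and chromosome; v1 is more significant than v2,
# and each lead's credible set (ldSet) contains the other variant.
df = spark.createDataFrame(
    [
        ("s1", "c1", "v1", 1.0, -10, [{"tagVariantId": "v1"}, {"tagVariantId": "v2"}]),
        ("s1", "c1", "v2", 1.0, -9, [{"tagVariantId": "v2"}, {"tagVariantId": "v1"}]),
    ],
    "studyId string, chromosome string, variantId string, "
    "pValueMantissa double, pValueExponent int, "
    "ldSet array<struct<tagVariantId:string>>",
)

# Scope for linked loci: one study on one chromosome, scanned in order of significance.
windowspec = (
    Window.partitionBy("studyId", "chromosome")
    .orderBy(f.col("pValueExponent").asc(), f.col("pValueMantissa").asc())
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
more_significant_leads = f.collect_set("variantId").over(windowspec)

# Tag list per locus, with the lead itself appended so the intersection is never empty.
tags_in_studylocus = f.array_distinct(
    f.array_union(
        f.array(f.col("variantId")),
        f.transform(f.col("ldSet"), lambda x: x.getField("tagVariantId")),
    )
)

df.withColumn(
    "is_lead_linked",
    f.size(f.array_intersect(more_significant_leads, tags_in_studylocus)) > 1,
).show(truncate=False)
# v1 (p = 1e-10) is kept as independent; v2 (p = 1e-9) is flagged
# because the stronger lead v1 appears in its credible set.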
250 changes: 102 additions & 148 deletions tests/gentropy/method/test_clump.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
 import pyspark.sql.types as t
@@ -20,142 +20,77 @@ def test_clump(mock_study_locus: StudyLocus) -> None:
     assert isinstance(LDclumping.clump(mock_study_locus), StudyLocus)


-@pytest.mark.parametrize(
-    ("observed_data", "expected_data"),
-    [
-        (
-            [
-                (
-                    # Dependent locus - lead is correlated with a more significant variant
-                    1,
-                    "L1",
-                    "GCST005650_1",
-                    1.0,
-                    -17,
-                    [{"tagVariantId": "T1"}, {"tagVariantId": "L2"}],
-                    None,
-                ),
-                (
-                    # Dependent locus - lead shows a stronger association than the row above
-                    2,
-                    "L2",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [
-                        {"tagVariantId": "T2"},
-                        {"tagVariantId": "T3"},
-                        {"tagVariantId": "L1"},
-                    ],
-                    None,
-                ),
-                (
-                    # Independent locus
-                    3,
-                    "L2",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [
-                        {"tagVariantId": "L3"},
-                        {"tagVariantId": "T4"},
-                        {"tagVariantId": "L5"},
-                    ],
-                    None,
-                ),
-                (
-                    # Empty credible set
-                    4,
-                    "L3",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [],
-                    None,
-                ),
-                (
-                    # Null credible set
-                    5,
-                    "L4",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    None,
-                    None,
-                ),
-            ],
-            [
-                (
-                    # Signal is linked to the next row
-                    1,
-                    "L1",
-                    "GCST005650_1",
-                    1.0,
-                    -17,
-                    [{"tagVariantId": "T1"}, {"tagVariantId": "L2"}],
-                    True,
-                ),
-                (
-                    # Signal is the most significant
-                    2,
-                    "L2",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [
-                        {"tagVariantId": "T2"},
-                        {"tagVariantId": "T3"},
-                        {"tagVariantId": "L1"},
-                    ],
-                    False,
-                ),
-                (
-                    # Signal is not linked
-                    3,
-                    "L2",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [
-                        {"tagVariantId": "L3"},
-                        {"tagVariantId": "T4"},
-                        {"tagVariantId": "L5"},
-                    ],
-                    False,
-                ),
-                (
-                    # Empty credible set - signal is not linked
-                    4,
-                    "L3",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    [],
-                    False,
-                ),
-                (
-                    # Null credible set - signal is not linked
-                    5,
-                    "L4",
-                    "GCST005650_1",
-                    4.0,
-                    -18,
-                    None,
-                    False,
-                ),
-            ],
-        )
-    ],
-)
-def test_is_lead_linked(
-    spark: SparkSession, observed_data: list[Any], expected_data: list[Any]
-) -> None:
-    """Test function that annotates whether a studyLocusId is linked to a more statistically significant studyLocusId."""
+class TestIsLeadLinked:
+    """Testing the is_lead_linked method."""
+
+    DATA = [
+        # Linked to V2:
+        (
+            "s1",
+            1,
+            "c1",
+            "v3",
+            1.0,
+            -8,
+            [{"tagVariantId": "v3"}, {"tagVariantId": "v2"}, {"tagVariantId": "v4"}],
+            True,
+        ),
+        # True lead:
+        (
+            "s1",
+            2,
+            "c1",
+            "v1",
+            1.0,
+            -10,
+            [{"tagVariantId": "v1"}, {"tagVariantId": "v2"}, {"tagVariantId": "v3"}],
+            False,
+        ),
+        # Linked to V1:
+        (
+            "s1",
+            3,
+            "c1",
+            "v2",
+            1.0,
+            -9,
+            [{"tagVariantId": "v2"}, {"tagVariantId": "v1"}],
+            True,
+        ),
+        # Independent - no LD set:
+        ("s1", 4, "c1", "v10", 1.0, -10, [], False),
+        # Independent - no variantId:
+        ("s1", 5, "c1", None, 1.0, -10, [], False),
+        # Another independent variant on the same chromosome; the lead is not in its ldSet:
+        (
+            "s1",
+            6,
+            "c1",
+            "v6",
+            1.0,
+            -8,
+            [{"tagVariantId": "v7"}, {"tagVariantId": "v8"}, {"tagVariantId": "v9"}],
+            False,
+        ),
+        # Another independent variant on a different chromosome; its ldSet contains a lead from chromosome c1:
+        (
+            "s1",
+            7,
+            "c2",
+            "v10",
+            1.0,
+            -8,
+            [{"tagVariantId": "v2"}, {"tagVariantId": "v10"}],
+            False,
+        ),
+    ]
+
-    schema = t.StructType(
+    SCHEMA = t.StructType(
         [
+            t.StructField("studyId", t.StringType(), True),
             t.StructField("studyLocusId", t.LongType(), True),
+            t.StructField("chromosome", t.StringType(), True),
             t.StructField("variantId", t.StringType(), True),
-            t.StructField("studyId", t.StringType(), True),
             t.StructField("pValueMantissa", t.FloatType(), True),
             t.StructField("pValueExponent", t.IntegerType(), True),
             t.StructField(
@@ -169,28 +104,47 @@
                 ),
                 True,
             ),
-            t.StructField("is_lead_linked", t.BooleanType(), True),
+            t.StructField("expected_flag", t.BooleanType(), True),
         ]
     )
-    study_locus_df = spark.createDataFrame(
-        observed_data,
-        schema,
-    )
-    observed_df = (
-        study_locus_df.withColumn(
+
+    @pytest.fixture(autouse=True)
+    def _setup(self: TestIsLeadLinked, spark: SparkSession) -> None:
+        """Set up the mock study locus data for testing."""
+        # Store input data:
+        self.df = spark.createDataFrame(self.DATA, self.SCHEMA)
+
+    def test_is_lead_correctness(self: TestIsLeadLinked) -> None:
+        """Test the correctness of the is_lead_linked method."""
+        observed = self.df.withColumn(
             "is_lead_linked",
             LDclumping._is_lead_linked(
                 f.col("studyId"),
+                f.col("chromosome"),
                 f.col("variantId"),
                 f.col("pValueExponent"),
                 f.col("pValueMantissa"),
                 f.col("ldSet"),
             ),
-        )
-        .orderBy("studyLocusId")
-        .collect()
-    )
-    expected_df = (
-        spark.createDataFrame(expected_data, schema).orderBy("studyLocusId").collect()
-    )
-    assert observed_df == expected_df
+        ).collect()
+
+        for row in observed:
+            assert row["is_lead_linked"] == row["expected_flag"]
+
+    def test_flagging(self: TestIsLeadLinked) -> None:
+        """Test flagging of lead variants."""
+        # Create the study locus and clump:
+        sl_flagged = StudyLocus(
+            _df=self.df.drop("expected_flag").withColumn("qualityControls", f.array()),
+            _schema=StudyLocus.get_schema(),
+        ).clump()
+
+        # Assert that the clumped locus is a StudyLocus:
+        assert isinstance(sl_flagged, StudyLocus)
+
+        # Assert that flagged loci carry the expected QC flag:
+        for row in sl_flagged.df.join(self.df, on="studyLocusId").collect():
+            if len(row["qualityControls"]) == 0:
+                assert not row["expected_flag"]
+            else:
+                assert row["expected_flag"]