Skip to content

Commit

Permalink
Merge branch 'dev' into vh-3448
Browse files Browse the repository at this point in the history
  • Loading branch information
DSuveges committed Sep 26, 2024
2 parents f1b0817 + 51125c7 commit c441b79
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 15 deletions.
52 changes: 50 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def order_array_of_structs_by_two_fields(
"""Sort array of structs by a field in descending order and by another field in an ascending order.
This function doesn't deal with null values, assumes the sort columns are not nullable.
The sorting function compares the descending_column first, in case when two values from descending_column are equal
it compares the ascending_column. When values in both columns are equal, the rows order is preserved.
Args:
array_name (str): Column name with array of structs
Expand All @@ -406,6 +408,20 @@ def order_array_of_structs_by_two_fields(
|[{1.0, 45, First}, {1.0, 125, Second}, {0.5, 232, Third}, {0.5, 233, Fourth}]|
+-----------------------------------------------------------------------------+
<BLANKLINE>
>>> data = [(1.0, 45, 'First'), (1.0, 45, 'Second'), (0.5, 233, 'Fourth'), (1.0, 125, 'Third'),]
>>> (
... spark.createDataFrame(data, ['col1', 'col2', 'ranking'])
... .groupBy(f.lit('c'))
... .agg(f.collect_list(f.struct('col1','col2', 'ranking')).alias('list'))
... .select(order_array_of_structs_by_two_fields('list', 'col1', 'col2').alias('sorted_list'))
... .show(truncate=False)
... )
+----------------------------------------------------------------------------+
|sorted_list |
+----------------------------------------------------------------------------+
|[{1.0, 45, First}, {1.0, 45, Second}, {1.0, 125, Third}, {0.5, 233, Fourth}]|
+----------------------------------------------------------------------------+
<BLANKLINE>
"""
return f.expr(
f"""
Expand All @@ -425,6 +441,7 @@ def order_array_of_structs_by_two_fields(
when left.{descending_column} > right.{descending_column} then -1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} > right.{ascending_column} then 1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} < right.{ascending_column} then -1
when left.{ascending_column} == right.{ascending_column} and left.{descending_column} == right.{descending_column} then 0
end)
"""
)
Expand Down Expand Up @@ -525,7 +542,7 @@ def get_value_from_row(row: Row, column: str) -> Any:


def enforce_schema(
expected_schema: t.StructType,
expected_schema: t.ArrayType | t.StructType | Column | str,
) -> Callable[..., Any]:
"""A function to enforce the schema of a function output follows expectation.
Expand All @@ -541,7 +558,7 @@ def my_function() -> t.StructType:
return ...
Args:
expected_schema (t.StructType): The expected schema of the output.
expected_schema (t.ArrayType | t.StructType | Column | str): The expected schema of the output.
Returns:
Callable[..., Any]: A decorator function.
Expand Down Expand Up @@ -687,3 +704,34 @@ def get_standard_error_from_confidence_interval(lower: Column, upper: Column) ->
<BLANKLINE>
"""
return (upper - lower) / (2 * 1.96)


def get_nested_struct_schema(dtype: t.DataType) -> t.StructType:
"""Get the bottom StructType from a nested ArrayType type.
Args:
dtype (t.DataType): The nested data structure.
Returns:
t.StructType: The nested struct schema.
Raises:
TypeError: If the input data type is not a nested struct.
Examples:
>>> get_nested_struct_schema(t.ArrayType(t.StructType([t.StructField('a', t.StringType())])))
StructType([StructField('a', StringType(), True)])
>>> get_nested_struct_schema(t.ArrayType(t.ArrayType(t.StructType([t.StructField("a", t.StringType())]))))
StructType([StructField('a', StringType(), True)])
"""
if isinstance(dtype, t.StructField):
dtype = dtype.dataType

match dtype:
case t.StructType(fields=_):
return dtype
case t.ArrayType(elementType=dtype):
return get_nested_struct_schema(dtype)
case _:
raise TypeError("The input data type must be a nested struct.")
1 change: 0 additions & 1 deletion src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ class StudyLocusValidationStepConfig(StepConfig):
valid_study_locus_path: str = MISSING
invalid_study_locus_path: str = MISSING
invalid_qc_reasons: list[str] = MISSING
gwas_significance: float = WindowBasedClumpingStepConfig.gwas_significance
_target_: str = "gentropy.study_locus_validation.StudyLocusValidationStep"


Expand Down
23 changes: 13 additions & 10 deletions src/gentropy/datasource/ensembl/vep_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
enforce_schema,
get_nested_struct_schema,
map_column_by_dictionary,
order_array_of_structs_by_field,
order_array_of_structs_by_two_fields,
Expand All @@ -26,14 +27,16 @@

class VariantEffectPredictorParser:
"""Collection of methods to parse VEP output in json format."""
# NOTE: Because xrefs are compared on the basis of rsIDs,
# if the `colocalised_variants` field has multiple rsIDs, extracting the xrefs will result in
# an array of xref structs, rather than the struct itself.

# Schema description of the dbXref object:
DBXREF_SCHEMA = VariantIndex.get_schema()["dbXrefs"].dataType

# Schema description of the in silico predictor object:
IN_SILICO_PREDICTOR_SCHEMA = VariantIndex.get_schema()[
"inSilicoPredictors"
].dataType
IN_SILICO_PREDICTOR_SCHEMA = get_nested_struct_schema(
VariantIndex.get_schema()["inSilicoPredictors"]
)

# Schema for the allele frequency column:
ALLELE_FREQUENCY_SCHEMA = VariantIndex.get_schema()["alleleFrequencies"].dataType
Expand Down Expand Up @@ -350,12 +353,12 @@ def _get_max_alpha_missense(transcripts: Column) -> Column:
... .select(VariantEffectPredictorParser._get_max_alpha_missense(f.col('transcripts')).alias('am'))
... .show(truncate=False)
... )
+------------------------------------------------------+
|am |
+------------------------------------------------------+
|[{max alpha missense, assessment 1, 0.4, null, gene1}]|
|[{max alpha missense, null, null, null, gene1}] |
+------------------------------------------------------+
+----------------------------------------------------+
|am |
+----------------------------------------------------+
|{max alpha missense, assessment 1, 0.4, null, gene1}|
|{max alpha missense, null, null, null, gene1} |
+----------------------------------------------------+
<BLANKLINE>
"""
return f.transform(
Expand Down
2 changes: 0 additions & 2 deletions src/gentropy/study_locus_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(
session: Session,
study_index_path: str,
study_locus_path: list[str],
gwas_significance: float,
valid_study_locus_path: str,
invalid_study_locus_path: str,
invalid_qc_reasons: list[str] = [],
Expand All @@ -30,7 +29,6 @@ def __init__(
session (Session): Session object.
study_index_path (str): Path to study index file.
study_locus_path (list[str]): Path to study locus dataset.
gwas_significance (float): GWAS significance threshold.
valid_study_locus_path (str): Path to write the valid records.
invalid_study_locus_path (str): Path to write the output file.
invalid_qc_reasons (list[str]): List of invalid quality check reason names from `StudyLocusQualityCheck` (e.g. ['SUBSIGNIFICANT_FLAG']).
Expand Down
Loading

0 comments on commit c441b79

Please sign in to comment.