Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
Browse files Browse the repository at this point in the history
…-3434
  • Loading branch information
ireneisdoomed committed Sep 27, 2024
2 parents bb47b01 + 9f83329 commit 3f5de30
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 13 deletions.
52 changes: 50 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def order_array_of_structs_by_two_fields(
"""Sort array of structs by a field in descending order and by an other field in an ascending order.
This function doesn't deal with null values, assumes the sort columns are not nullable.
The sorting function compares the descending_column first, in case when two values from descending_column are equal
it compares the ascending_column. When values in both columns are equal, the rows order is preserved.
Args:
array_name (str): Column name with array of structs
Expand All @@ -406,6 +408,20 @@ def order_array_of_structs_by_two_fields(
|[{1.0, 45, First}, {1.0, 125, Second}, {0.5, 232, Third}, {0.5, 233, Fourth}]|
+-----------------------------------------------------------------------------+
<BLANKLINE>
>>> data = [(1.0, 45, 'First'), (1.0, 45, 'Second'), (0.5, 233, 'Fourth'), (1.0, 125, 'Third'),]
>>> (
... spark.createDataFrame(data, ['col1', 'col2', 'ranking'])
... .groupBy(f.lit('c'))
... .agg(f.collect_list(f.struct('col1','col2', 'ranking')).alias('list'))
... .select(order_array_of_structs_by_two_fields('list', 'col1', 'col2').alias('sorted_list'))
... .show(truncate=False)
... )
+----------------------------------------------------------------------------+
|sorted_list |
+----------------------------------------------------------------------------+
|[{1.0, 45, First}, {1.0, 45, Second}, {1.0, 125, Third}, {0.5, 233, Fourth}]|
+----------------------------------------------------------------------------+
<BLANKLINE>
"""
return f.expr(
f"""
Expand All @@ -425,6 +441,7 @@ def order_array_of_structs_by_two_fields(
when left.{descending_column} > right.{descending_column} then -1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} > right.{ascending_column} then 1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} < right.{ascending_column} then -1
when left.{ascending_column} == right.{ascending_column} and left.{descending_column} == right.{descending_column} then 0
end)
"""
)
Expand Down Expand Up @@ -525,7 +542,7 @@ def get_value_from_row(row: Row, column: str) -> Any:


def enforce_schema(
expected_schema: t.StructType,
expected_schema: t.ArrayType | t.StructType | Column | str,
) -> Callable[..., Any]:
"""A function to enforce the schema of a function output follows expectation.
Expand All @@ -541,7 +558,7 @@ def my_function() -> t.StructType:
return ...
Args:
expected_schema (t.StructType): The expected schema of the output.
expected_schema (t.ArrayType | t.StructType | Column | str): The expected schema of the output.
Returns:
Callable[..., Any]: A decorator function.
Expand Down Expand Up @@ -687,3 +704,34 @@ def get_standard_error_from_confidence_interval(lower: Column, upper: Column) ->
<BLANKLINE>
"""
return (upper - lower) / (2 * 1.96)


def get_nested_struct_schema(dtype: t.DataType) -> t.StructType:
"""Get the bottom StructType from a nested ArrayType type.
Args:
dtype (t.DataType): The nested data structure.
Returns:
t.StructType: The nested struct schema.
Raises:
TypeError: If the input data type is not a nested struct.
Examples:
>>> get_nested_struct_schema(t.ArrayType(t.StructType([t.StructField('a', t.StringType())])))
StructType([StructField('a', StringType(), True)])
>>> get_nested_struct_schema(t.ArrayType(t.ArrayType(t.StructType([t.StructField("a", t.StringType())]))))
StructType([StructField('a', StringType(), True)])
"""
if isinstance(dtype, t.StructField):
dtype = dtype.dataType

match dtype:
case t.StructType(fields=_):
return dtype
case t.ArrayType(elementType=dtype):
return get_nested_struct_schema(dtype)
case _:
raise TypeError("The input data type must be a nested struct.")
1 change: 0 additions & 1 deletion src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class ColocalisationConfig(StepConfig):
"""Colocalisation step configuration."""

credible_set_path: str = MISSING
study_index_path: str = MISSING
coloc_path: str = MISSING
colocalisation_method: str = MISSING
_target_: str = "gentropy.colocalisation.ColocalisationStep"
Expand Down
23 changes: 13 additions & 10 deletions src/gentropy/datasource/ensembl/vep_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
enforce_schema,
get_nested_struct_schema,
map_column_by_dictionary,
order_array_of_structs_by_field,
order_array_of_structs_by_two_fields,
Expand All @@ -26,14 +27,16 @@

class VariantEffectPredictorParser:
"""Collection of methods to parse VEP output in json format."""
# NOTE: Due to the fact that the comparison of the xrefs is done om the base of rsids
# if the field `colocalised_variants` have multiple rsids, this extracting xrefs will result in
# an array of xref structs, rather then the struct itself.

# Schema description of the dbXref object:
DBXREF_SCHEMA = VariantIndex.get_schema()["dbXrefs"].dataType

# Schema description of the in silico predictor object:
IN_SILICO_PREDICTOR_SCHEMA = VariantIndex.get_schema()[
"inSilicoPredictors"
].dataType
IN_SILICO_PREDICTOR_SCHEMA = get_nested_struct_schema(
VariantIndex.get_schema()["inSilicoPredictors"]
)

# Schema for the allele frequency column:
ALLELE_FREQUENCY_SCHEMA = VariantIndex.get_schema()["alleleFrequencies"].dataType
Expand Down Expand Up @@ -350,12 +353,12 @@ def _get_max_alpha_missense(transcripts: Column) -> Column:
... .select(VariantEffectPredictorParser._get_max_alpha_missense(f.col('transcripts')).alias('am'))
... .show(truncate=False)
... )
+------------------------------------------------------+
|am |
+------------------------------------------------------+
|[{max alpha missense, assessment 1, 0.4, null, gene1}]|
|[{max alpha missense, null, null, null, gene1}] |
+------------------------------------------------------+
+----------------------------------------------------+
|am |
+----------------------------------------------------+
|{max alpha missense, assessment 1, 0.4, null, gene1}|
|{max alpha missense, null, null, null, gene1} |
+----------------------------------------------------+
<BLANKLINE>
"""
return f.transform(
Expand Down
Loading

0 comments on commit 3f5de30

Please sign in to comment.