Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix(safe_array_union): allow for sorting nested structs #793

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
)


def safe_array_union(a: Column, b: Column) -> Column:
def safe_array_union(
a: Column, b: Column, fields_order: list[str] | None = None
) -> Column:
"""Merge the content of two optional columns.

The function assumes the array columns have the same schema. Otherwise, the function will fail.
The function assumes the array columns have the same schema.
If the `fields_order` is passed, the function assumes that it deals with array of structs and sorts the nested
struct fields by the provided `fields_order` before conducting array_merge.
If the `fields_order` is not passed and both columns are <array<struct<...>> type then function assumes struct fields have the same order,
otherwise the function will raise an AnalysisException.

Args:
a (Column): One optional array column.
b (Column): The other optional array column.
fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.

Returns:
Column: array column with merged content.
Expand All @@ -644,12 +651,89 @@ def safe_array_union(a: Column, b: Column) -> Column:
| null|
+------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
+----------------+
| merged|
+----------------+
|[{a, 1}, {c, 2}]|
+----------------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
pyspark.sql.utils.AnalysisException: ...
"""
if fields_order:
# sort the nested struct fields by the provided order
a = sort_array_struct_by_columns(a, fields_order)
b = sort_array_struct_by_columns(b, fields_order)
return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
f.coalesce(a, b)
)



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this do the same as order_array_of_structs_by_field and order_array_of_structs_by_two_fields?

def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
    """Sort nested struct fields by provided fields order.

    Rebuilds each struct inside the array with its fields in `fields_order`,
    then sorts the array in descending order on the reordered structs.

    Args:
        column (Column): Column with array of structs.
        fields_order (list[str]): List of field names to sort by.

    Returns:
        Column: Sorted column.

    Examples:
        >>> schema="arr: array<struct<b:int,a:string>>"
        >>> data = [([(1,"a",), (2, "c")],)]
        >>> fields_order = ["a", "b"]
        >>> df = spark.createDataFrame(data=data, schema=schema)
        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
        +----------------+
        |          sorted|
        +----------------+
        |[{c, 2}, {a, 1}]|
        +----------------+
        <BLANKLINE>
    """
    # Resolve the SQL-level name of the input column so it can be referenced in expr().
    name = extract_column_name(column)
    # Build the struct field list for the transform lambda, e.g. "x.a, x.b".
    struct_fields = ", ".join(f"x.{field}" for field in fields_order)
    sort_expr = f"sort_array(transform({name}, x -> struct({struct_fields})), False)"
    return f.expr(sort_expr).alias(name)


def extract_column_name(column: Column) -> str:
    """Extract column name from a column expression.

    Relies on the string representation of a pyspark Column, which has the
    form ``Column<'expression'>``.

    Args:
        column (Column): Column expression.

    Returns:
        str: Column name.

    Raises:
        ValueError: If the column name cannot be extracted.

    Examples:
        >>> extract_column_name(f.col('col1'))
        'col1'
        >>> extract_column_name(f.sort_array(f.col('col1')))
        'sort_array(col1, true)'
    """
    # Greedy match keeps everything between the outer quotes, so nested
    # quotes inside the expression are preserved.
    matched = re.search("^Column<'(?P<name>.*)'>?", str(column))
    if matched is None:
        raise ValueError(f"Cannot extract column name from {column}")
    return matched.group("name")


def create_empty_column_if_not_exists(
col_name: str, col_schema: t.DataType = t.NullType()
) -> Column:
Expand Down
14 changes: 12 additions & 2 deletions src/gentropy/dataset/variant_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
import pyspark.sql.types as t

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
get_nested_struct_schema,
get_record_with_maximum_value,
normalise_column,
rename_all_columns,
Expand Down Expand Up @@ -131,6 +133,7 @@ def add_annotation(
# Prefix for renaming columns:
prefix = "annotation_"


# Generate select expressions that to merge and import columns from annotation:
select_expressions = []

Expand All @@ -141,10 +144,17 @@ def add_annotation(
# If an annotation column can be found in both datasets:
if (column in self.df.columns) and (column in annotation_source.df.columns):
# Arrays are merged:
if "ArrayType" in field.dataType.__str__():
if isinstance(field.dataType, t.ArrayType):
fields_order = None
if isinstance(field.dataType.elementType, t.StructType):
# Extract the schema of the array to get the order of the fields:
array_schema = [
field for field in VariantIndex.get_schema().fields if field.name == column
][0].dataType
fields_order = get_nested_struct_schema(array_schema).fieldNames()
select_expressions.append(
safe_array_union(
f.col(column), f.col(f"{prefix}{column}")
f.col(column), f.col(f"{prefix}{column}"), fields_order
).alias(column)
)
# Non-array columns are coalesced:
Expand Down
1 change: 0 additions & 1 deletion src/gentropy/datasource/open_targets/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def as_vcf_df(
variant_df = variant_df.withColumn(
col, create_empty_column_if_not_exists(col)
)

return (
variant_df.filter(f.col("variantId").isNotNull())
.withColumn(
Expand Down
Loading