From b29269fca5e803d10fc855e2deb8f52c8ed1200b Mon Sep 17 00:00:00 2001
From: Szymon Szyszkowski
Date: Thu, 26 Sep 2024 15:55:10 +0100
Subject: [PATCH 1/2] fix: remove study_index_path from coloc step

---
 src/gentropy/config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index c56a9dfb3..3a67e7868 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -36,7 +36,6 @@ class ColocalisationConfig(StepConfig):
     """Colocalisation step configuration."""
 
     credible_set_path: str = MISSING
-    study_index_path: str = MISSING
     coloc_path: str = MISSING
     colocalisation_method: str = MISSING
     _target_: str = "gentropy.colocalisation.ColocalisationStep"

From 7b133799f445ab8ccaa4fe4d19ad26dad996f0cb Mon Sep 17 00:00:00 2001
From: Szymon Szyszkowski
Date: Fri, 27 Sep 2024 13:57:45 +0100
Subject: [PATCH 2/2] fix(safe_array_union): sort struct fields in array

---
 src/gentropy/common/spark_helpers.py          | 88 ++++++++++++++++++-
 src/gentropy/dataset/variant_index.py         | 14 ++-
 .../datasource/open_targets/variants.py       |  1 -
 3 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py
index 680975ef6..3fdabfbcc 100644
--- a/src/gentropy/common/spark_helpers.py
+++ b/src/gentropy/common/spark_helpers.py
@@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
     )
 
 
-def safe_array_union(a: Column, b: Column) -> Column:
+def safe_array_union(
+    a: Column, b: Column, fields_order: list[str] | None = None
+) -> Column:
     """Merge the content of two optional columns.
 
-    The function assumes the array columns have the same schema. Otherwise, the function will fail.
+    The function assumes the array columns have the same schema.
+    If the `fields_order` is passed, the function assumes that it deals with array of structs and sorts the nested
+    struct fields by the provided `fields_order` before conducting array_merge.
+    If the `fields_order` is not passed and both columns are <array<struct>> type then function assumes struct fields have the same order,
+    otherwise the function will raise an AnalysisException.
 
     Args:
         a (Column): One optional array column.
         b (Column): The other optional array column.
+        fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
 
     Returns:
         Column: array column with merged content.
@@ -644,12 +651,89 @@ def safe_array_union(a: Column, b: Column) -> Column:
         |  null|
         +------+
 
+        >>> schema="arr2: array<struct<b: int, a: string>>, arr: array<struct<a: string, b: int>>"
+        >>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
+        +----------------+
+        |          merged|
+        +----------------+
+        |[{a, 1}, {c, 2}]|
+        +----------------+
+
+        >>> schema="arr2: array<struct<b: int, a: string>>, arr: array<struct<a: string, b: int>>"
+        >>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        pyspark.sql.utils.AnalysisException: ...
     """
+    if fields_order:
+        # sort the nested struct fields by the provided order
+        a = sort_array_struct_by_columns(a, fields_order)
+        b = sort_array_struct_by_columns(b, fields_order)
     return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
         f.coalesce(a, b)
     )
 
 
+def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
+    """Sort nested struct fields by provided fields order.
+
+    Args:
+        column (Column): Column with array of structs.
+        fields_order (list[str]): List of field names to sort by.
+
+    Returns:
+        Column: Sorted column.
+
+    Examples:
+        >>> schema="arr: array<struct<b: int, a: string>>"
+        >>> data = [([(1,"a",), (2, "c")],)]
+        >>> fields_order = ["a", "b"]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
+        +----------------+
+        |          sorted|
+        +----------------+
+        |[{c, 2}, {a, 1}]|
+        +----------------+
+
+    """
+    column_name = extract_column_name(column)
+    fields_order_expr = ", ".join([f"x.{field}" for field in fields_order])
+    return f.expr(
+        f"sort_array(transform({column_name}, x -> struct({fields_order_expr})), False)"
+    ).alias(column_name)
+
+
+def extract_column_name(column: Column) -> str:
+    """Extract column name from a column expression.
+
+    Args:
+        column (Column): Column expression.
+
+    Returns:
+        str: Column name.
+
+    Raises:
+        ValueError: If the column name cannot be extracted.
+
+    Examples:
+        >>> extract_column_name(f.col('col1'))
+        'col1'
+        >>> extract_column_name(f.sort_array(f.col('col1')))
+        'sort_array(col1, true)'
+    """
+    pattern = re.compile("^Column<'(?P<name>.*)'>?")
+
+    _match = pattern.search(str(column))
+    if not _match:
+        raise ValueError(f"Cannot extract column name from {column}")
+    return _match.group("name")
+
+
 def create_empty_column_if_not_exists(
     col_name: str, col_schema: t.DataType = t.NullType()
 ) -> Column:
diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py
index 1cc1eac1b..2f24cd985 100644
--- a/src/gentropy/dataset/variant_index.py
+++ b/src/gentropy/dataset/variant_index.py
@@ -6,9 +6,11 @@
 from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
+import pyspark.sql.types as t
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.spark_helpers import (
+    get_nested_struct_schema,
     get_record_with_maximum_value,
     normalise_column,
     rename_all_columns,
@@ -131,6 +133,7 @@ def add_annotation(
 
         # Prefix for renaming columns:
         prefix = "annotation_"
+        # Generate select expressions to merge and import columns from annotation:
         select_expressions = []
 
@@ -141,10 +144,17 @@ def add_annotation(
             # If an annotation column can be found in both datasets:
             if (column in self.df.columns) and (column in annotation_source.df.columns):
                 # Arrays are merged:
-                if "ArrayType" in field.dataType.__str__():
+                if isinstance(field.dataType, t.ArrayType):
+                    fields_order = None
+                    if isinstance(field.dataType.elementType, t.StructType):
+                        # Extract the schema of the array to get the order of the fields:
+                        array_schema = [
+                            field for field in VariantIndex.get_schema().fields if field.name == column
+                        ][0].dataType
+                        fields_order = get_nested_struct_schema(array_schema).fieldNames()
                     select_expressions.append(
                         safe_array_union(
-                            f.col(column), f.col(f"{prefix}{column}")
+                            f.col(column), f.col(f"{prefix}{column}"), fields_order
                         ).alias(column)
                     )
                 # Non-array columns are coalesced:
diff --git a/src/gentropy/datasource/open_targets/variants.py b/src/gentropy/datasource/open_targets/variants.py
index 03018438b..5b6822ae6 100644
--- a/src/gentropy/datasource/open_targets/variants.py
+++ b/src/gentropy/datasource/open_targets/variants.py
@@ -95,7 +95,6 @@ def as_vcf_df(
             variant_df = variant_df.withColumn(
                 col, create_empty_column_if_not_exists(col)
             )
-
         return (
             variant_df.filter(f.col("variantId").isNotNull())
            .withColumn(