diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py
index 913e67a67..60b49d88d 100644
--- a/src/gentropy/common/spark_helpers.py
+++ b/src/gentropy/common/spark_helpers.py
@@ -4,19 +4,20 @@
 
 import re
 import sys
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar
 
 import pyspark.sql.functions as f
 import pyspark.sql.types as t
 from pyspark.ml import Pipeline
 from pyspark.ml.feature import MinMaxScaler, VectorAssembler
 from pyspark.ml.functions import vector_to_array
-from pyspark.sql import Row, Window
-from pyspark.sql.types import FloatType
+from pyspark.sql import Column, Row, Window
 from scipy.stats import norm
 
 if TYPE_CHECKING:
-    from pyspark.sql import Column, DataFrame, WindowSpec
+    from pyspark.sql import DataFrame, WindowSpec
+    from pyspark.sql.types import StructType
 
 
 def convert_from_wide_to_long(
@@ -54,7 +55,7 @@ def convert_from_wide_to_long(
     _vars_and_vals = f.array(
         *(
             f.struct(
-                f.lit(c).alias(var_name), f.col(c).cast(FloatType()).alias(value_name)
+                f.lit(c).alias(var_name), f.col(c).cast(t.FloatType()).alias(value_name)
             )
             for c in value_vars
         )
@@ -433,3 +434,37 @@ def get_value_from_row(row: Row, column: str) -> Any:
     if column not in row:
         raise ValueError(f"Column {column} not found in row {row}")
     return row[column]
+
+
+def enforce_schema(
+    expected_schema: StructType,
+) -> Callable[..., Any]:
+    """A decorator to enforce that a function's output follows the expected schema.
+
+    Behaviour:
+        - Fields that are not present in the expected schema will be dropped.
+        - Expected but missing fields will be added with null values.
+        - Fields with incorrect data types will be cast to the expected data type.
+
+    This is a decorator function and is expected to be used like this:
+
+    @enforce_schema(spark_schema)
+    def my_function() -> Column:
+        return ...
+
+    Args:
+        expected_schema (StructType): The expected schema of the output.
+
+    Returns:
+        Callable[..., Any]: A decorator function.
+ """ + T = TypeVar("T", str, Column) + + def decorator(function: Callable[..., T]) -> Callable[..., T]: + @wraps(function) + def wrapper(*args: str, **kwargs: str) -> Any: + return f.from_json(f.to_json(function(*args, **kwargs)), expected_schema) + + return wrapper + + return decorator diff --git a/tests/gentropy/common/test_session.py b/tests/gentropy/common/test_session.py index 2ea7ccc1b..6afc640ab 100644 --- a/tests/gentropy/common/test_session.py +++ b/tests/gentropy/common/test_session.py @@ -1,4 +1,4 @@ -"""Tests GWAS Catalog study splitter.""" +"""Tests Gentropy session.""" from __future__ import annotations diff --git a/tests/gentropy/common/test_spark_helpers.py b/tests/gentropy/common/test_spark_helpers.py new file mode 100644 index 000000000..19ef6a436 --- /dev/null +++ b/tests/gentropy/common/test_spark_helpers.py @@ -0,0 +1,154 @@ +"""Tests spark-helper functions.""" + +from __future__ import annotations + +import pytest +from gentropy.common.spark_helpers import ( + enforce_schema, + order_array_of_structs_by_field, +) +from pyspark.sql import Column, SparkSession +from pyspark.sql import functions as f +from pyspark.sql import types as t + + +def test_order_array_of_structs_by_field(spark: SparkSession) -> None: + """Test order_array_of_structs_by_field.""" + data = [ + # Values are the same: + ("a", 12), + ("a", 12), + # First value bigger: + ("b", 12), + ("b", 1), + # Second value bigger: + ("c", 1), + ("c", 12), + # First value is null: + ("d", None), + ("d", 12), + # Second value is null: + ("e", 12), + ("e", None), + # Both values are null: + ("f", None), + ("f", None), + ] + + processed_data = ( + spark.createDataFrame(data, ["group", "value"]) + .groupBy("group") + .agg( + f.collect_list(f.struct(f.col("value").alias("value"))).alias("values"), + f.max(f.col("value")).alias("max_value"), + ) + .withColumn("sorted_values", order_array_of_structs_by_field("values", "value")) + .withColumn("sorted_max", f.col("sorted_values")[0].getField("value")) + .select("max_value", "sorted_max") + .collect() + ) + + for row in processed_data: + assert row["max_value"] == row["sorted_max"] + + +class TestEnforceSchema: + """Test enforce schema.""" + + EXPECTED_SCHEMA = t.StructType( + [ + t.StructField("field1", t.StringType(), True), + t.StructField("field2", t.StringType(), True), + t.StructField("field3", t.StringType(), True), + t.StructField("field4", t.FloatType(), True), + ] + ) + + @staticmethod + @enforce_schema(expected_schema=EXPECTED_SCHEMA) + def good_schema_test() -> Column: + """Create a struct with the expected schema.""" + return f.struct( + f.lit("test1").alias("field1"), + f.lit("test2").alias("field2"), + f.lit("test3").alias("field3"), + f.lit(2.0).alias("field4"), + ) + + @staticmethod + @enforce_schema(expected_schema=EXPECTED_SCHEMA) + def missing_column_test() -> Column: + """Create a struct with a missing column.""" + return f.struct( + f.lit("test1").alias("field1"), + f.lit("test3").alias("field3"), + ) + + @staticmethod + @enforce_schema(expected_schema=EXPECTED_SCHEMA) + def wrong_order_test() -> Column: + """Create a struct with the wrong order.""" + return f.struct( + f.lit("test2").alias("field2"), + f.lit("test1").alias("field1"), + ) + + @staticmethod + @enforce_schema(expected_schema=EXPECTED_SCHEMA) + def extra_column_test() -> Column: + """Create a struct with an extra column.""" + return f.struct( + f.lit("test2").alias("field2"), + f.lit("test1").alias("field1"), + f.lit("test5").alias("field5"), + f.lit(12.1).alias("field6"), + ) + + 
+    @staticmethod
+    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
+    def wrong_type_test_1() -> Column:
+        """Create a struct with the wrong type."""
+        return f.struct(
+            f.lit("test2").alias("field2"),
+            f.lit("test1").alias("field1"),
+            f.lit(5).cast(t.IntegerType()).alias("field3"),
+        )
+
+    @staticmethod
+    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
+    def wrong_type_test_2() -> Column:
+        """Create a struct with the wrong type."""
+        return f.struct(
+            f.lit("test2").alias("field2"),
+            f.lit("test1").alias("field1"),
+            f.lit("test").alias("field4"),
+        )
+
+    @pytest.fixture(autouse=True)
+    def _setup(self: TestEnforceSchema, spark: SparkSession) -> None:
+        """Setup fixture."""
+        self.test_dataset = (
+            spark.createDataFrame(
+                [("a",)],
+                ["label"],
+            )
+            .withColumn("struct_1", self.good_schema_test())
+            .withColumn("struct_2", self.missing_column_test())
+            .withColumn("struct_3", self.wrong_order_test())
+            .withColumn("struct_4", self.extra_column_test())
+            .withColumn("struct_5", self.wrong_type_test_1())
+            .withColumn("struct_6", self.wrong_type_test_2())
+        )
+
+    def test_schema_consistency(self: TestEnforceSchema) -> None:
+        """Test enforce schema consistency."""
+        # Loop through all the struct columns and check that the schema is consistent
+        for column in [
+            "struct_1",
+            "struct_2",
+            "struct_3",
+            "struct_4",
+            "struct_5",
+            "struct_6",
+        ]:
+            assert self.test_dataset.schema[column].dataType == self.EXPECTED_SCHEMA
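
For reviewers, a minimal usage sketch of the new enforce_schema decorator outside the test suite; the schema, DataFrame and column names below are illustrative only and not part of this change:

from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f
from pyspark.sql import types as t

from gentropy.common.spark_helpers import enforce_schema

# Hypothetical schema used only for this illustration:
EFFECT_SCHEMA = t.StructType(
    [
        t.StructField("beta", t.FloatType(), True),
        t.StructField("se", t.FloatType(), True),
    ]
)


@enforce_schema(expected_schema=EFFECT_SCHEMA)
def make_effect_struct(beta_column: str) -> Column:
    """Build an effect struct; the decorator drops the extra field and adds `se` as null."""
    return f.struct(
        f.col(beta_column).alias("beta"),
        f.lit("dropped").alias("not_in_schema"),  # removed by the enforced schema
    )


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.5,)], ["beta_value"])
df.withColumn("effect", make_effect_struct("beta_value")).printSchema()

The round trip through f.to_json / f.from_json inside the decorator is what drops, adds and casts the fields, as described in its docstring.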