feat(spark-helpers): enforce schema of returned objects (#617)
* feat(spark-helpers): enforce schema of objects returned by other functions

* fix: fixing docstring

* test: adding test for array sorter

* chore: pre-commit auto fixes [...]

* fix: addressing a typo

* chore: pre-commit auto fixes [...]

* fix: removing duplicated rows

* fix: addressing review comments

* fix: adding StructType hint

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
DSuveges and pre-commit-ci[bot] committed Jun 6, 2024
1 parent fd3154a commit 689340c
Showing 3 changed files with 195 additions and 6 deletions.
45 changes: 40 additions & 5 deletions src/gentropy/common/spark_helpers.py
@@ -4,19 +4,20 @@

import re
import sys
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar

import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.ml import Pipeline
from pyspark.ml.feature import MinMaxScaler, VectorAssembler
from pyspark.ml.functions import vector_to_array
-from pyspark.sql import Row, Window
-from pyspark.sql.types import FloatType
+from pyspark.sql import Column, Row, Window
from scipy.stats import norm

if TYPE_CHECKING:
-    from pyspark.sql import Column, DataFrame, WindowSpec
+    from pyspark.sql import DataFrame, WindowSpec
+    from pyspark.sql.types import StructType


def convert_from_wide_to_long(
@@ -54,7 +55,7 @@ def convert_from_wide_to_long(
    _vars_and_vals = f.array(
        *(
            f.struct(
-                f.lit(c).alias(var_name), f.col(c).cast(FloatType()).alias(value_name)
+                f.lit(c).alias(var_name), f.col(c).cast(t.FloatType()).alias(value_name)
            )
            for c in value_vars
        )
@@ -433,3 +434,37 @@ def get_value_from_row(row: Row, column: str) -> Any:
    if column not in row:
        raise ValueError(f"Column {column} not found in row {row}")
    return row[column]


def enforce_schema(
    expected_schema: StructType,
) -> Callable[..., Any]:
    """A decorator to enforce that the schema of a function's output matches the expected schema.

    Behaviour:
        - Fields that are not present in the expected schema will be dropped.
        - Expected but missing fields will be added with null values.
        - Fields with incorrect data types will be cast to the expected data type.

    The decorator is expected to be used like this:

        @enforce_schema(spark_schema)
        def my_function() -> Column:
            return ...

    Args:
        expected_schema (StructType): The expected schema of the output.

    Returns:
        Callable[..., Any]: A decorator function.
    """
    T = TypeVar("T", str, Column)

    def decorator(function: Callable[..., T]) -> Callable[..., T]:
        @wraps(function)
        def wrapper(*args: str, **kwargs: str) -> Any:
            return f.from_json(f.to_json(function(*args, **kwargs)), expected_schema)

        return wrapper

    return decorator
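
For context, a minimal usage sketch (not part of this commit): the schema, the decorated helper make_variant_struct, and the spark session below are illustrative assumptions rather than names from the repository. The decorator routes the returned struct through a to_json/from_json round trip, so extra fields are dropped, missing fields come back as nulls, and mismatched types are cast to the expected schema.

# Illustrative sketch only; `variant_schema` and `make_variant_struct` are hypothetical names.
import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql import Column, SparkSession

from gentropy.common.spark_helpers import enforce_schema

variant_schema = t.StructType(
    [
        t.StructField("variantId", t.StringType(), True),
        t.StructField("score", t.FloatType(), True),
    ]
)


@enforce_schema(expected_schema=variant_schema)
def make_variant_struct() -> Column:
    # Emit `score` as an integer and add an extra field on purpose:
    # the decorator casts the former to float and drops the latter.
    return f.struct(
        f.lit("1_12345_A_T").alias("variantId"),
        f.lit(1).alias("score"),
        f.lit("dropped").alias("extra"),
    )


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a",)], ["label"]).withColumn("variant", make_variant_struct())
df.printSchema()  # variant: struct<variantId: string, score: float>
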
2 changes: 1 addition & 1 deletion tests/gentropy/common/test_session.py
@@ -1,4 +1,4 @@
"""Tests GWAS Catalog study splitter."""
"""Tests Gentropy session."""

from __future__ import annotations

154 changes: 154 additions & 0 deletions tests/gentropy/common/test_spark_helpers.py
@@ -0,0 +1,154 @@
"""Tests spark-helper functions."""

from __future__ import annotations

import pytest
from gentropy.common.spark_helpers import (
    enforce_schema,
    order_array_of_structs_by_field,
)
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f
from pyspark.sql import types as t


def test_order_array_of_structs_by_field(spark: SparkSession) -> None:
    """Test order_array_of_structs_by_field."""
    data = [
        # Values are the same:
        ("a", 12),
        ("a", 12),
        # First value bigger:
        ("b", 12),
        ("b", 1),
        # Second value bigger:
        ("c", 1),
        ("c", 12),
        # First value is null:
        ("d", None),
        ("d", 12),
        # Second value is null:
        ("e", 12),
        ("e", None),
        # Both values are null:
        ("f", None),
        ("f", None),
    ]

    processed_data = (
        spark.createDataFrame(data, ["group", "value"])
        .groupBy("group")
        .agg(
            f.collect_list(f.struct(f.col("value").alias("value"))).alias("values"),
            f.max(f.col("value")).alias("max_value"),
        )
        .withColumn("sorted_values", order_array_of_structs_by_field("values", "value"))
        .withColumn("sorted_max", f.col("sorted_values")[0].getField("value"))
        .select("max_value", "sorted_max")
        .collect()
    )

    for row in processed_data:
        assert row["max_value"] == row["sorted_max"]


class TestEnforceSchema:
    """Test enforce schema."""

    EXPECTED_SCHEMA = t.StructType(
        [
            t.StructField("field1", t.StringType(), True),
            t.StructField("field2", t.StringType(), True),
            t.StructField("field3", t.StringType(), True),
            t.StructField("field4", t.FloatType(), True),
        ]
    )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def good_schema_test() -> Column:
        """Create a struct with the expected schema."""
        return f.struct(
            f.lit("test1").alias("field1"),
            f.lit("test2").alias("field2"),
            f.lit("test3").alias("field3"),
            f.lit(2.0).alias("field4"),
        )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def missing_column_test() -> Column:
        """Create a struct with a missing column."""
        return f.struct(
            f.lit("test1").alias("field1"),
            f.lit("test3").alias("field3"),
        )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def wrong_order_test() -> Column:
        """Create a struct with the wrong order."""
        return f.struct(
            f.lit("test2").alias("field2"),
            f.lit("test1").alias("field1"),
        )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def extra_column_test() -> Column:
        """Create a struct with an extra column."""
        return f.struct(
            f.lit("test2").alias("field2"),
            f.lit("test1").alias("field1"),
            f.lit("test5").alias("field5"),
            f.lit(12.1).alias("field6"),
        )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def wrong_type_test_1() -> Column:
        """Create a struct with the wrong type."""
        return f.struct(
            f.lit("test2").alias("field2"),
            f.lit("test1").alias("field1"),
            f.lit(5).cast(t.IntegerType()).alias("field3"),
        )

    @staticmethod
    @enforce_schema(expected_schema=EXPECTED_SCHEMA)
    def wrong_type_test_2() -> Column:
        """Create a struct with the wrong type."""
        return f.struct(
            f.lit("test2").alias("field2"),
            f.lit("test1").alias("field1"),
            f.lit("test").alias("field4"),
        )

    @pytest.fixture(autouse=True)
    def _setup(self: TestEnforceSchema, spark: SparkSession) -> None:
        """Setup fixture."""
        self.test_dataset = (
            spark.createDataFrame(
                [("a",)],
                ["label"],
            )
            .withColumn("struct_1", self.good_schema_test())
            .withColumn("struct_2", self.missing_column_test())
            .withColumn("struct_3", self.wrong_order_test())
            .withColumn("struct_4", self.extra_column_test())
            .withColumn("struct_5", self.wrong_type_test_1())
            .withColumn("struct_6", self.wrong_type_test_2())
        )

    def test_schema_consistency(self: TestEnforceSchema) -> None:
        """Test enforce schema consistency."""
        # Loop through all the struct columns and check that each schema matches the expectation.
        for column in [
            "struct_1",
            "struct_2",
            "struct_3",
            "struct_4",
            "struct_5",
            "struct_6",
        ]:
            assert self.test_dataset.schema[column].dataType == self.EXPECTED_SCHEMA
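
As a possible follow-up (not part of this commit), one could also assert on values rather than only on the schema. Assuming Spark's default PERMISSIVE JSON parsing, the non-numeric string written to field4 in wrong_type_test_2 cannot be cast to a float and is expected to surface as null; the method name below is hypothetical.

    def test_uncastable_value_becomes_null(self: TestEnforceSchema) -> None:
        """Hypothetical check: a value that cannot be cast surfaces as null."""
        row = self.test_dataset.select("struct_6.field4").first()
        assert row["field4"] is None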
