fix(schema): recursive validation of arbitrarily deep nested structure #790

Merged · 15 commits · Sep 30, 2024
Changes from 12 commits
211 changes: 198 additions & 13 deletions src/gentropy/common/schemas.py
@@ -1,36 +1,65 @@
"""Methods for handling schemas."""

from __future__ import annotations

import importlib.resources as pkg_resources
import json
from collections import namedtuple
from collections import defaultdict, namedtuple
from typing import Any

import pyspark.sql.types as t
from pyspark.sql.types import ArrayType, StructType

from gentropy.assets import schemas


def parse_spark_schema(schema_json: str) -> t.StructType:
class SchemaValidationError(Exception):
"""This exception is raised when a schema validation fails."""

def __init__(
self: SchemaValidationError, message: str, errors: defaultdict[str, list[str]]
Contributor:
Cool. So the default dict acts as a dictionary that stores errors found in the schemas but whose keys are not predetermined.

) -> None:
"""Initialize the SchemaValidationError.

Args:
message (str): The message to be displayed.
errors (defaultdict[str, list[str]]): The collection of observed discrepancies
"""
super().__init__(message)
self.message = message # Explicitly set the message attribute
self.errors = errors

def __str__(self: SchemaValidationError) -> str:
Contributor:
This is the method printed when you raise the exception, is that right?

Contributor Author:
Yes, exactly.

"""Return a string representation of the exception.

Returns:
str: The string representation of the exception.
"""
stringified_errors = "\n ".join(
[f'{k}: {",".join(v)}' for k, v in self.errors.items()]
)
return f"{self.message}\nErrors:\n {stringified_errors}"
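
For illustration (not part of the diff), a small sketch of how the raised error renders, assuming gentropy with this PR is installed; the column names below are made up:

from collections import defaultdict

from gentropy.common.schemas import SchemaValidationError

errors: defaultdict[str, list[str]] = defaultdict(list)
# defaultdict(list) creates each key on first use, so error categories need no pre-declaration:
errors["unexpected_columns"].append("geneSymbol")
errors["columns_with_non_matching_type"].append(
    'For column "position" found string instead of integer'
)

print(SchemaValidationError("Schema validation failed for ExampleDataset", errors))
# Prints roughly:
# Schema validation failed for ExampleDataset
# Errors:
#  unexpected_columns: geneSymbol
#  columns_with_non_matching_type: For column "position" found string instead of integer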


def parse_spark_schema(schema_json: str) -> StructType:
"""Parse Spark schema from JSON.

Args:
schema_json (str): JSON filename containing spark schema in the schemas package

Returns:
t.StructType: Spark schema
StructType: Spark schema
"""
core_schema = json.loads(
pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
)
return t.StructType.fromJson(core_schema)
return StructType.fromJson(core_schema)


def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
def flatten_schema(schema: StructType, prefix: str = "") -> list[Any]:
Contributor:
Since we are now parsing each schema without flattening, I'd suggest removing this function.

Contributor Author:
I can drop it if you think this function would not be useful in other contexts.

"""It takes a Spark schema and returns a list of all fields in the schema once flattened.

Args:
schema (t.StructType): The schema of the dataframe
schema (StructType): The schema of the dataframe
prefix (str): The prefix to prepend to the field names. Defaults to "".

Returns:
@@ -53,14 +82,170 @@ def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
for field in schema.fields:
name = f"{prefix}.{field.name}" if prefix else field.name
dtype = field.dataType
if isinstance(dtype, t.StructType):
fields.append(Field(name, t.ArrayType(t.StructType())))
if isinstance(dtype, StructType):
fields.append(Field(name, ArrayType(StructType())))
fields += flatten_schema(dtype, prefix=name)
elif isinstance(dtype, t.ArrayType) and isinstance(
dtype.elementType, t.StructType
):
fields.append(Field(name, t.ArrayType(t.StructType())))
elif isinstance(dtype, ArrayType) and isinstance(dtype.elementType, StructType):
fields.append(Field(name, ArrayType(StructType())))
fields += flatten_schema(dtype.elementType, prefix=name)
else:
fields.append(Field(name, dtype))
return fields


def compare_array_schemas(
observed_schema: ArrayType,
expected_schema: ArrayType,
parent_field_name: str | None = None,
schema_issues: defaultdict[str, list[str]] | None = None,
) -> defaultdict[str, list[str]]:
"""Compare two array schemas.

The comparison is done recursively, so nested structs are also compared.

Args:
observed_schema (ArrayType): The observed schema.
expected_schema (ArrayType): The expected schema.
parent_field_name (str | None): The parent field name. Defaults to None.
schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.

Returns:
defaultdict[str, list[str]]: The schema issues.
"""
# Create default values if not provided:
if schema_issues is None:
schema_issues = defaultdict(list)

if parent_field_name is None:
parent_field_name = ""

observed_type = observed_schema.elementType.typeName()
expected_type = expected_schema.elementType.typeName()

# If element types are not matching, no further tests are needed:
if observed_type != expected_type:
schema_issues["columns_with_non_matching_type"].append(
f'For column "{parent_field_name}[]" found {observed_type} instead of {expected_type}'
)

# If element type is a struct, resolve nesting:
elif observed_type == "struct":
Contributor:
To enforce that both types are the same:

Suggested change:
- elif observed_type == "struct":
+ elif observed_type == "struct" and expected_type == "struct":

Contributor Author:
I agree, explicit is better than implicit. (The equality of the two schemas was already tested.)

schema_issues = compare_struct_schemas(
observed_schema.elementType,
Contributor:
My mypy is raising an issue here because both the observed_schema and expected_schema are technically of type DataType, whereas the parameters have to be StructType. Are you having issues as well? I think mypy would be able to resolve it if the conditional above was made using the elementType, not the name of the type, i.e. elif observed_schema.elementType == StructType(). Same applies below when you call compare_array_schemas.

expected_schema.elementType,
f"{parent_field_name}[].",
schema_issues,
)

# If element type is an array, resolve nesting:
elif observed_type == "array":
Contributor:
To enforce that both types are the same:

Suggested change:
- elif observed_type == "array":
+ elif observed_type == "array" and expected_type == "array":

schema_issues = compare_array_schemas(
observed_schema.elementType,
expected_schema.elementType,
parent_field_name,
schema_issues,
)

return schema_issues
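
As a side note on the typing question above, here is a sketch of the isinstance-based narrowing the reviewer alludes to; this is an assumption about one way to satisfy mypy, not necessarily the change that was merged:

from pyspark.sql.types import ArrayType, StructType

def _narrowing_example(observed: ArrayType, expected: ArrayType) -> None:
    """Illustrative only: isinstance lets mypy narrow elementType from DataType."""
    if isinstance(observed.elementType, StructType) and isinstance(
        expected.elementType, StructType
    ):
        # Both element types are narrowed to StructType here, so a call like
        # compare_struct_schemas(observed.elementType, expected.elementType, ...) type-checks.
        pass
    elif isinstance(observed.elementType, ArrayType) and isinstance(
        expected.elementType, ArrayType
    ):
        # Both element types are narrowed to ArrayType for compare_array_schemas.
        pass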


def compare_struct_schemas(
observed_schema: StructType,
expected_schema: StructType,
parent_field_name: str | None = None,
schema_issues: defaultdict[str, list[str]] | None = None,
) -> defaultdict[str, list[str]]:
"""Compare two struct schemas.

The comparison is done recursively, so nested structs are also compared.

Checking logic:
1. Checking for duplicated columns in the observed schema.
2. Checking for missing mandatory columns in the observed schema.
3. Once we know that all mandatory columns are present, we iterate over the observed schema and compare the types.
4. Flagging unexpected columns in the observed schema.
5. Flagging columns with non-matching types.
6. If a column is a struct -> call compare_struct_schemas
7. If a column is an array -> call compare_array_schemas
8. Return dictionary with issues.

Args:
observed_schema (StructType): The observed schema.
expected_schema (StructType): The expected schema.
parent_field_name (str | None): The parent field name. Defaults to None.
schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.

Returns:
defaultdict[str, list[str]]: The schema issues.
"""
# Create default values if not provided:
if schema_issues is None:
schema_issues = defaultdict(list)

if parent_field_name is None:
parent_field_name = ""

# Flagging duplicated columns if present:
if duplicated_columns := list(
{
f"{parent_field_name}{field.name}"
for field in observed_schema
if list(observed_schema).count(field) > 1
Contributor:
Really nice

}
):
schema_issues["duplicated_columns"] = duplicated_columns

# Testing mandatory fields:
required_fields = [x.name for x in expected_schema if not x.nullable]
if missing_required_fields := [
f"{parent_field_name}{req}"
for req in required_fields
if not any(field.name == req for field in observed_schema)
]:
schema_issues["missing_mandatory_columns"] = missing_required_fields

# Converting schema to dictionaries for easier comparison:
observed_schema_dict = {field.name: field for field in observed_schema}
expected_schema_dict = {field.name: field for field in expected_schema}

# Testing optional fields and types:
for field_name, field in observed_schema_dict.items():
# Testing observed field name, if name is not matched, no further tests are needed:
if field_name not in expected_schema_dict:
schema_issues["unexpected_columns"].append(
f"{parent_field_name}{field_name}"
)
continue

# When we made sure the field is in both schemas, extracting field type information:
observed_type = field.dataType
observed_type_name = field.dataType.typeName()

expected_type = expected_schema_dict[field_name].dataType
expected_type_name = expected_schema_dict[field_name].dataType.typeName()

# Flagging non-matching types if types don't match, jumping to next field:
if observed_type_name != expected_type_name:
schema_issues["columns_with_non_matching_type"].append(
f'For column "{parent_field_name}{field_name}" found {observed_type_name} instead of {expected_type_name}'
)
continue

# If column is a struct, resolve nesting:
if observed_type_name == "struct":
schema_issues = compare_struct_schemas(
observed_type,
expected_type,
f"{parent_field_name}{field_name}.",
schema_issues,
)
# If column is an array, resolve nesting:
elif observed_type_name == "array":
schema_issues = compare_array_schemas(
observed_type,
expected_type,
f"{parent_field_name}{field_name}[]",
schema_issues,
)

return schema_issues
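
Since StructType objects are plain Python values, the checking logic above can be exercised without a Spark session. A minimal sketch (not part of the diff; all field names are made up) assuming gentropy with this PR is installed:

from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from gentropy.common.schemas import compare_struct_schemas

expected = StructType(
    [
        StructField("studyId", StringType(), nullable=False),
        StructField("projectId", StringType(), nullable=False),
        StructField("nSamples", IntegerType(), nullable=True),
        StructField(
            "cohort",
            StructType([StructField("name", StringType(), nullable=False)]),
            nullable=True,
        ),
    ]
)

observed = StructType(
    [
        StructField("studyId", StringType(), nullable=False),
        StructField("projectId", StringType(), nullable=False),
        StructField("nSamples", StringType(), nullable=True),  # wrong type
        StructField("publicationYear", IntegerType(), nullable=True),  # unexpected column
        StructField(
            "cohort",
            # "name" is missing and "label" is unexpected inside the nested struct:
            StructType([StructField("label", StringType(), nullable=True)]),
            nullable=True,
        ),
    ]
)

print(dict(compare_struct_schemas(observed, expected)))
# Prints roughly (key order may differ):
# {'columns_with_non_matching_type': ['For column "nSamples" found string instead of integer'],
#  'unexpected_columns': ['publicationYear', 'cohort.label'],
#  'missing_mandatory_columns': ['cohort.name']}
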
52 changes: 5 additions & 47 deletions src/gentropy/dataset/dataset.py
@@ -13,7 +13,7 @@
from pyspark.sql.window import Window
from typing_extensions import Self

from gentropy.common.schemas import flatten_schema
from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas

if TYPE_CHECKING:
from enum import Enum
@@ -142,57 +142,15 @@ def validate_schema(self: Dataset) -> None:
"""Validate DataFrame schema against expected class schema.

Raises:
ValueError: DataFrame schema is not valid
SchemaValidationError: If the DataFrame schema does not match the expected schema
"""
expected_schema = self._schema
expected_fields = flatten_schema(expected_schema)
observed_schema = self._df.schema
observed_fields = flatten_schema(observed_schema)

# Unexpected fields in dataset
if unexpected_field_names := [
x.name
for x in observed_fields
if x.name not in [y.name for y in expected_fields]
]:
raise ValueError(
f"The {unexpected_field_names} fields are not included in DataFrame schema: {expected_fields}"
)

# Required fields not in dataset
required_fields = [x.name for x in expected_schema if not x.nullable]
if missing_required_fields := [
req
for req in required_fields
if not any(field.name == req for field in observed_fields)
]:
raise ValueError(
f"The {missing_required_fields} fields are required but missing: {required_fields}"
)

# Fields with duplicated names
if duplicated_fields := [
x for x in set(observed_fields) if observed_fields.count(x) > 1
]:
raise ValueError(
f"The following fields are duplicated in DataFrame schema: {duplicated_fields}"
)

# Fields with different datatype
observed_field_types = {
field.name: type(field.dataType) for field in observed_fields
}
expected_field_types = {
field.name: type(field.dataType) for field in expected_fields
}
if fields_with_different_observed_datatype := [
name
for name, observed_type in observed_field_types.items()
if name in expected_field_types
and observed_type != expected_field_types[name]
]:
raise ValueError(
f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
if discrepancies := compare_struct_schemas(observed_schema, expected_schema):
raise SchemaValidationError(
f"Schema validation failed for {type(self).__name__}", discrepancies
)
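
A hedged end-to-end sketch of the same pattern applied to a live DataFrame (requires a running Spark session and gentropy with this PR installed; the column names and values are made up):

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType, StructField, StructType

from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas

spark = SparkSession.builder.getOrCreate()

expected_schema = StructType(
    [
        StructField("studyId", StringType(), nullable=False),
        StructField("nCases", LongType(), nullable=True),
    ]
)
# nCases arrives as a string, so the comparison reports a type mismatch:
observed_df = spark.createDataFrame(
    [("GCST000001", "250")], schema="studyId string, nCases string"
)

# Same check as the new Dataset.validate_schema body above:
if discrepancies := compare_struct_schemas(observed_df.schema, expected_schema):
    raise SchemaValidationError(
        "Schema validation failed for the example DataFrame", discrepancies
    )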

def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> Self: