fix(schema): recursive validation of arbitrarily deep nested structure #790
Changes from 12 commits
@@ -1,36 +1,65 @@
"""Methods for handling schemas.""" | ||||||
|
||||||
from __future__ import annotations | ||||||
|
||||||
import importlib.resources as pkg_resources | ||||||
import json | ||||||
from collections import namedtuple | ||||||
from collections import defaultdict, namedtuple | ||||||
from typing import Any | ||||||
|
||||||
import pyspark.sql.types as t | ||||||
from pyspark.sql.types import ArrayType, StructType | ||||||
|
||||||
from gentropy.assets import schemas | ||||||
|
||||||
|
||||||
def parse_spark_schema(schema_json: str) -> t.StructType: | ||||||
class SchemaValidationError(Exception): | ||||||
"""This exception is raised when a schema validation fails.""" | ||||||
|
||||||
def __init__( | ||||||
self: SchemaValidationError, message: str, errors: defaultdict[str, list[str]] | ||||||
) -> None: | ||||||
"""Initialize the SchemaValidationError. | ||||||
|
||||||
Args: | ||||||
message (str): The message to be displayed. | ||||||
errors (defaultdict[str, list[str]]): The collection of observed discrepancies | ||||||
""" | ||||||
super().__init__(message) | ||||||
self.message = message # Explicitly set the message attribute | ||||||
self.errors = errors | ||||||
|
||||||
def __str__(self: SchemaValidationError) -> str: | ||||||
Review comment: This is the method printed when you raise the exception, is that right?

Reply: Yes, exactly.
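For illustration, here is a minimal sketch of how the exception renders when raised. The column names are hypothetical and not part of the diff; only the error keys mirror the ones populated by `compare_struct_schemas` / `compare_array_schemas` below.

```python
from collections import defaultdict

# Hypothetical error collection, using the same keys the comparison
# functions below populate.
errors: defaultdict[str, list[str]] = defaultdict(list)
errors["missing_mandatory_columns"].append("studyId")
errors["unexpected_columns"].append("color")

raise SchemaValidationError("Schema validation failed", errors)
# SchemaValidationError: Schema validation failed
# Errors:
#  missing_mandatory_columns: studyId
#  unexpected_columns: color
```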
"""Return a string representation of the exception. | ||||||
|
||||||
Returns: | ||||||
str: The string representation of the exception. | ||||||
""" | ||||||
stringified_errors = "\n ".join( | ||||||
[f'{k}: {",".join(v)}' for k, v in self.errors.items()] | ||||||
) | ||||||
return f"{self.message}\nErrors:\n {stringified_errors}" | ||||||
|
||||||
|
||||||
def parse_spark_schema(schema_json: str) -> StructType: | ||||||
"""Parse Spark schema from JSON. | ||||||
|
||||||
Args: | ||||||
schema_json (str): JSON filename containing spark schema in the schemas package | ||||||
|
||||||
Returns: | ||||||
t.StructType: Spark schema | ||||||
StructType: Spark schema | ||||||
""" | ||||||
core_schema = json.loads( | ||||||
pkg_resources.read_text(schemas, schema_json, encoding="utf-8") | ||||||
) | ||||||
return t.StructType.fromJson(core_schema) | ||||||
return StructType.fromJson(core_schema) | ||||||
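For context, usage is a one-liner; the filename below is hypothetical and stands in for any schema JSON bundled in the `gentropy.assets.schemas` package:

```python
# "study_index.json" is a hypothetical example filename.
schema = parse_spark_schema("study_index.json")
print(schema.fieldNames())
```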
```diff


-def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
+def flatten_schema(schema: StructType, prefix: str = "") -> list[Any]:
```
Review comment: Since we are now parsing each schema without flattening, I'd suggest removing this function.

Reply: I can drop it if you think this function would not be useful in other contexts.
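For reference while deciding, a small sketch of what `flatten_schema` produces on a nested schema. The field names are invented for illustration, and the exact `Field` repr depends on the namedtuple definition elided from this diff:

```python
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

nested = StructType([
    StructField("id", StringType()),
    StructField("tags", ArrayType(StructType([StructField("label", StringType())]))),
])
for field in flatten_schema(nested):
    print(field)
# Roughly:
#   Field(name='id', dataType=StringType())
#   Field(name='tags', dataType=ArrayType(StructType([]), True))
#   Field(name='tags.label', dataType=StringType())
```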
"""It takes a Spark schema and returns a list of all fields in the schema once flattened. | ||||||
|
||||||
Args: | ||||||
schema (t.StructType): The schema of the dataframe | ||||||
schema (StructType): The schema of the dataframe | ||||||
prefix (str): The prefix to prepend to the field names. Defaults to "". | ||||||
|
||||||
Returns: | ||||||
|
@@ -53,14 +82,170 @@ def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]: | |||||
for field in schema.fields: | ||||||
name = f"{prefix}.{field.name}" if prefix else field.name | ||||||
dtype = field.dataType | ||||||
if isinstance(dtype, t.StructType): | ||||||
fields.append(Field(name, t.ArrayType(t.StructType()))) | ||||||
if isinstance(dtype, StructType): | ||||||
fields.append(Field(name, ArrayType(StructType()))) | ||||||
fields += flatten_schema(dtype, prefix=name) | ||||||
elif isinstance(dtype, t.ArrayType) and isinstance( | ||||||
dtype.elementType, t.StructType | ||||||
): | ||||||
fields.append(Field(name, t.ArrayType(t.StructType()))) | ||||||
elif isinstance(dtype, ArrayType) and isinstance(dtype.elementType, StructType): | ||||||
fields.append(Field(name, ArrayType(StructType()))) | ||||||
fields += flatten_schema(dtype.elementType, prefix=name) | ||||||
else: | ||||||
fields.append(Field(name, dtype)) | ||||||
return fields | ||||||
|
||||||
|
||||||
def compare_array_schemas( | ||||||
observed_schema: ArrayType, | ||||||
expected_schema: ArrayType, | ||||||
parent_field_name: str | None = None, | ||||||
schema_issues: defaultdict[str, list[str]] | None = None, | ||||||
) -> defaultdict[str, list[str]]: | ||||||
"""Compare two array schemas. | ||||||
|
||||||
The comparison is done recursively, so nested structs are also compared. | ||||||
|
||||||
Args: | ||||||
observed_schema (ArrayType): The observed schema. | ||||||
expected_schema (ArrayType): The expected schema. | ||||||
parent_field_name (str | None): The parent field name. Defaults to None. | ||||||
schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None. | ||||||
|
||||||
Returns: | ||||||
defaultdict[str, list[str]]: The schema issues. | ||||||
""" | ||||||
# Create default values if not provided: | ||||||
if schema_issues is None: | ||||||
schema_issues = defaultdict(list) | ||||||
|
||||||
if parent_field_name is None: | ||||||
parent_field_name = "" | ||||||
|
||||||
observed_type = observed_schema.elementType.typeName() | ||||||
expected_type = expected_schema.elementType.typeName() | ||||||
|
||||||
# If element types are not matching, no further tests are needed: | ||||||
if observed_type != expected_type: | ||||||
schema_issues["columns_with_non_matching_type"].append( | ||||||
f'For column "{parent_field_name}[]" found {observed_type} instead of {expected_type}' | ||||||
) | ||||||
|
||||||
# If element type is a struct, resolve nesting: | ||||||
elif observed_type == "struct": | ||||||
Review comment: To enforce both types are the same. (Suggested change)

Reply: I agree, explicit is better than implicit. (The equality of the two schemas was tested already.)
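The body of the suggested change was not captured on this page; a plausible reconstruction, offered only as a sketch, would make the branch explicit about both sides:

```diff
-    elif observed_type == "struct":
+    elif observed_type == "struct" and expected_type == "struct":
```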
```diff
+        schema_issues = compare_struct_schemas(
+            observed_schema.elementType,
```
Review comment: My mypy is raising an issue here because both the …
```diff
+            expected_schema.elementType,
+            f"{parent_field_name}[].",
+            schema_issues,
+        )
+
+    # If element type is an array, resolve nesting:
+    elif observed_type == "array":
```
Review comment: To enforce both types are the same. (Suggested change)
```diff
+        schema_issues = compare_array_schemas(
+            observed_schema.elementType,
+            expected_schema.elementType,
+            parent_field_name,
+            schema_issues,
+        )
+
+    return schema_issues
+
+
+def compare_struct_schemas(
+    observed_schema: StructType,
+    expected_schema: StructType,
+    parent_field_name: str | None = None,
+    schema_issues: defaultdict[str, list[str]] | None = None,
+) -> defaultdict[str, list[str]]:
+    """Compare two struct schemas.
+
+    The comparison is done recursively, so nested structs are also compared.
+
+    Checking logic:
+    1. Checking for duplicated columns in the observed schema.
+    2. Checking for missing mandatory columns in the observed schema.
+    3. Now we know that all mandatory columns are present, we can iterate over the observed schema and compare the types.
+    4. Flagging unexpected columns in the observed schema.
+    5. Flagging columns with non-matching types.
+    6. If a column is a struct -> call compare_struct_schemas
+    7. If a column is an array -> call compare_array_schemas
+    8. Return dictionary with issues.
+
+    Args:
+        observed_schema (StructType): The observed schema.
+        expected_schema (StructType): The expected schema.
+        parent_field_name (str | None): The parent field name. Defaults to None.
+        schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
+
+    Returns:
+        defaultdict[str, list[str]]: The schema issues.
+    """
+    # Create default values if not provided:
+    if schema_issues is None:
+        schema_issues = defaultdict(list)
+
+    if parent_field_name is None:
+        parent_field_name = ""
+
+    # Flagging duplicated columns if present:
+    if duplicated_columns := list(
+        {
+            f"{parent_field_name}{field.name}"
+            for field in observed_schema
+            if list(observed_schema).count(field) > 1
```
Review comment: Really nice
```diff
+        }
+    ):
+        schema_issues["duplicated_columns"] = duplicated_columns
+
+    # Testing mandatory fields:
+    required_fields = [x.name for x in expected_schema if not x.nullable]
+    if missing_required_fields := [
+        f"{parent_field_name}{req}"
+        for req in required_fields
+        if not any(field.name == req for field in observed_schema)
+    ]:
+        schema_issues["missing_mandatory_columns"] = missing_required_fields
+
+    # Converting schema to dictionaries for easier comparison:
+    observed_schema_dict = {field.name: field for field in observed_schema}
+    expected_schema_dict = {field.name: field for field in expected_schema}
+
+    # Testing optional fields and types:
+    for field_name, field in observed_schema_dict.items():
+        # Testing observed field name, if name is not matched, no further tests are needed:
+        if field_name not in expected_schema_dict:
+            schema_issues["unexpected_columns"].append(
+                f"{parent_field_name}{field_name}"
+            )
+            continue
+
+        # When we made sure the field is in both schemas, extracting field type information:
+        observed_type = field.dataType
+        observed_type_name = field.dataType.typeName()
+
+        expected_type = expected_schema_dict[field_name].dataType
+        expected_type_name = expected_schema_dict[field_name].dataType.typeName()
+
+        # Flagging non-matching types if types don't match, jumping to next field:
+        if observed_type_name != expected_type_name:
+            schema_issues["columns_with_non_matching_type"].append(
+                f'For column "{parent_field_name}{field_name}" found {observed_type_name} instead of {expected_type_name}'
+            )
+            continue
+
+        # If column is a struct, resolve nesting:
+        if observed_type_name == "struct":
+            schema_issues = compare_struct_schemas(
+                observed_type,
+                expected_type,
+                f"{parent_field_name}{field_name}.",
+                schema_issues,
+            )
+        # If column is an array, resolve nesting:
+        elif observed_type_name == "array":
+            schema_issues = compare_array_schemas(
+                observed_type,
+                expected_type,
+                f"{parent_field_name}{field_name}[]",
+                schema_issues,
+            )
+
+    return schema_issues
```
Review comment: Cool. So the defaultdict acts as a dictionary that stores errors found in the schemas, but whose keys are not predetermined.
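To make that concrete, here is a minimal sketch (schemas invented for illustration) of what the returned defaultdict holds; only the issue categories actually observed appear as keys:

```python
from pyspark.sql.types import StringType, StructField, StructType

expected = StructType([
    StructField("studyId", StringType(), nullable=False),
    StructField("note", StringType(), nullable=True),
])
observed = StructType([
    StructField("studyId", StringType(), nullable=False),
    StructField("color", StringType(), nullable=True),  # not in the expected schema
])

issues = compare_struct_schemas(observed, expected)
print(dict(issues))
# {'unexpected_columns': ['color']}

# A caller can then raise the new exception when any issue was collected:
if issues:
    raise SchemaValidationError("Schema validation failed", issues)
```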