Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix(9559)] - Validation fails for enum field with decimal type #1324

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ docs/_build/
htmlcov/
node_modules/

.venv
/.benchmarks/
/.idea/
/.pytest_cache/
Expand Down
13 changes: 12 additions & 1 deletion src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync {
py: Python<'py>,
input: &I,
lookup: &LiteralLookup<PyObject>,
class: &Py<PyType>,
strict: bool,
) -> ValResult<Option<PyObject>>;
}
Expand Down Expand Up @@ -116,7 +117,7 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
},
input,
));
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
} else if let Some(v) = T::validate_value(py, input, &self.lookup, &self.class, strict)? {
state.floor_exactness(Exactness::Lax);
return Ok(v);
} else if let Some(ref missing) = self.missing {
Comment on lines +120 to 123
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So now that I look at this new point in the diff, I see that (and I regret forgetting this) that we already call _missing_ here. But as we observe in the issue and this PR, simply calling _missing_ is not enough because enum __new__ has more complex logic which isn't encapsulated purely by _missing_.

I wonder if there is a case to have a new branch (before or after this one? not sure 🤔) which calls the enum type, with the logic going here instead of in PlainEnumValidator. Putting the logic here would also solve the special-cased enums like IntEnum, I think.

That does beg the question, though: if we add the case of calling the enum type here, do we need logic for _missing_ at all? My intuition is that we don't, and we should try to phase out the _missing_ logic.

Expand Down Expand Up @@ -167,6 +168,7 @@ impl EnumValidateValue for PlainEnumValidator {
py: Python<'py>,
input: &I,
lookup: &LiteralLookup<PyObject>,
class: &Py<PyType>,
strict: bool,
) -> ValResult<Option<PyObject>> {
match lookup.validate(py, input)? {
Expand All @@ -183,8 +185,14 @@ impl EnumValidateValue for PlainEnumValidator {
} else if py_input.is_instance_of::<PyFloat>() {
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
}
if py_input.is_instance_of::<PyAny>() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_instance_of::<PyAny>() I think will always be true, will probably be optimized away by the compiler but also not necessary at all IMO.

Suggested change
if py_input.is_instance_of::<PyAny>() {

if let Ok(res) = class.call1(py, (py_input,)) {
return Ok(Some(res));
}
}
}
}

Ok(None)
}
}
Expand All @@ -201,6 +209,7 @@ impl EnumValidateValue for IntEnumValidator {
py: Python<'py>,
input: &I,
lookup: &LiteralLookup<PyObject>,
_class: &Py<PyType>,
strict: bool,
) -> ValResult<Option<PyObject>> {
Ok(lookup.validate_int(py, input, strict)?.map(|v| v.clone_ref(py)))
Expand All @@ -217,6 +226,7 @@ impl EnumValidateValue for StrEnumValidator {
py: Python,
input: &I,
lookup: &LiteralLookup<PyObject>,
_class: &Py<PyType>,
strict: bool,
) -> ValResult<Option<PyObject>> {
Ok(lookup.validate_str(input, strict)?.map(|v| v.clone_ref(py)))
Expand All @@ -233,6 +243,7 @@ impl EnumValidateValue for FloatEnumValidator {
py: Python<'py>,
input: &I,
lookup: &LiteralLookup<PyObject>,
_class: &Py<PyType>,
strict: bool,
) -> ValResult<Option<PyObject>> {
Ok(lookup.validate_float(py, input, strict)?.map(|v| v.clone_ref(py)))
Expand Down
1 change: 1 addition & 0 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl Validator for GeneratorValidator {
hide_input_in_errors: self.hide_input_in_errors,
validation_error_cause: self.validation_error_cause,
};

Ok(v_iterator.into_py(py))
}

Expand Down
2 changes: 2 additions & 0 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
}

if let Some(expected_strings) = &self.expected_str {
let validation_result = if input.as_python().is_some() {
input.exact_str()
Expand Down Expand Up @@ -163,6 +164,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

Ok(None)
}

Expand Down
1 change: 0 additions & 1 deletion src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ impl ModelValidator {
.map_err(|e| convert_err(py, e, input));
}
}

let output = self.validator.validate(py, input, state)?;

let instance = create_class(self.class.bind(py))?;
Expand Down
128 changes: 128 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,130 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
# GIVEN
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE

assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


def test_enum_int_validation_should_succeed_for_custom_type():
# GIVEN
class AnyWrapper:
def __init__(self, value):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyEnum(Enum):
VALUE = 999
SECOND_VALUE = 1000000
THIRD_VALUE = 'Py03'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE


def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
# GIVEN
class MyEnum(Enum):
VALUE = 1

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v.validate_python(Decimal(1.1))


def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
# GIVEN
class MyEnum(Enum):
VALUE = 1
Comment on lines +456 to +458
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realise now that we probably also want to test this with e.g. MyEnum(IntEnum), which I think goes through a separate code pathway but probably was also broken when we moved enum validation to Rust?


class MyClass:
def __init__(self, value):
self.value = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(MyClass(1))
60 changes: 60 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1313,62 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
# GIVEN
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2
enum_value_3 = 3

class IntWrappable:
def __init__(self, value: int):
self.value = value

def __eq__(self, value: object) -> bool:
return self.value == value

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

# WHEN
v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_3': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)

# THEN
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
m = v.validate_python(
{
'enum_field': Decimal(1),
'enum_field_2': Decimal(2),
'enum_field_3': IntWrappable(3),
}
)
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))