Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix(9559)] - Validation fails for enum field with decimal type #1324

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ docs/_build/
htmlcov/
node_modules/

.venv
/.benchmarks/
/.idea/
/.pytest_cache/
Expand Down
2 changes: 2 additions & 0 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub(crate) use return_enums::{
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
};

pub(crate) use shared::decimal_as_int;

// Defined here as it's not exported by pyo3
pub fn py_error_on_minusone(py: Python<'_>, result: c_int) -> PyResult<()> {
if result != -1 {
Expand Down
26 changes: 25 additions & 1 deletion src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::errors::ErrorType;
use crate::errors::ValResult;
use crate::errors::{ErrorTypeDefaults, Number};
use crate::errors::{ToErrorValue, ValError};
use crate::input::Input;
use crate::input::{decimal_as_int, EitherInt, Input};
use crate::tools::SchemaDict;

use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
Expand Down Expand Up @@ -288,3 +288,27 @@ fn handle_decimal_new_error(input: impl ToErrorValue, error: PyErr, decimal_exce
ValError::InternalErr(error)
}
}

pub(crate) fn try_from_decimal_to_int<'a, 'py, I: Input<'py> + ?Sized>(
py: Python<'py>,
input: &'a I,
) -> ValResult<i64> {
let Some(py_input) = input.as_python() else {
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
};

if let Ok(false) = py_input.is_instance(get_decimal_type(py)) {
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
}

let dec_value = match decimal_as_int(input, py_input)? {
EitherInt::Py(value) => value,
_ => return Err(ValError::new(ErrorType::DecimalParsing { context: None }, input)),
};

let either_int = dec_value.exact_int()?;

let int = either_int.into_i64(py)?;

Ok(int)
}
2 changes: 1 addition & 1 deletion src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl EnumValidateValue for PlainEnumValidator {
lookup: &LiteralLookup<PyObject>,
strict: bool,
) -> ValResult<Option<PyObject>> {
match lookup.validate(py, input)? {
match lookup.validate(py, input, strict)? {
Some((_, v)) => Ok(Some(v.clone_ref(py))),
None => {
if !strict {
Expand Down
22 changes: 21 additions & 1 deletion src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::input::{Input, ValidationMatch};
use crate::py_gc::PyGcTraverse;
use crate::tools::SchemaDict;

use super::decimal::try_from_decimal_to_int;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -104,6 +105,7 @@ impl<T: Debug> LiteralLookup<T> {
&self,
py: Python<'py>,
input: &'a I,
strict: bool,
) -> ValResult<Option<(&'a I, &T)>> {
if let Some(expected_bool) = &self.expected_bool {
if let Ok(bool_value) = input.validate_bool(true) {
Expand All @@ -123,7 +125,15 @@ impl<T: Debug> LiteralLookup<T> {
return Ok(Some((input, &self.values[*id])));
}
}
// if the input is a Decimal type, we need to check if its value is in the expected_ints
if let Ok(value) = try_from_decimal_to_int(py, input) {
let Some(id) = expected_ints.get(&value) else {
return Ok(None);
};
return Ok(Some((input, &self.values[*id])));
}
}

if let Some(expected_strings) = &self.expected_str {
let validation_result = if input.as_python().is_some() {
input.exact_str()
Expand All @@ -142,6 +152,15 @@ impl<T: Debug> LiteralLookup<T> {
return Ok(Some((input, &self.values[*id])));
}
}
if !strict {
// if the input is a Decimal type, we need to check if its value is in the expected_ints
if let Ok(value) = try_from_decimal_to_int(py, input) {
let Some(id) = expected_strings.get(&value.to_string()) else {
return Ok(None);
};
return Ok(Some((input, &self.values[*id])));
}
}
}
if let Some(expected_py_dict) = &self.expected_py_dict {
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
Expand All @@ -163,6 +182,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

Ok(None)
}

Expand Down Expand Up @@ -269,7 +289,7 @@ impl Validator for LiteralValidator {
input: &(impl Input<'py> + ?Sized),
_state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
match self.lookup.validate(py, input)? {
match self.lookup.validate(py, input, _state.strict_or(false))? {
Some((_, v)) => Ok(v.clone()),
None => Err(ValError::new(
ErrorType::LiteralError {
Expand Down
2 changes: 1 addition & 1 deletion src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ impl TaggedUnionValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag, state.strict_or(false)) {
return match validator.validate(py, input, state) {
Ok(res) => Ok(res),
Err(err) => Err(err.with_outer_location(tag)),
Expand Down
105 changes: 105 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,107 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
# GIVEN
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE

assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

# THEN
assert v.validate_python(Decimal(1)) is MyEnum.VALUE


def test_enum_str_validation_should_fail_for_decimal_with_strict_enabled():
# GIVEN
class MyEnum(Enum):
VALUE = '1'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()), strict=True),
default=MyEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
# GIVEN
class MyEnum(Enum):
VALUE = 1

class MyStrEnum(Enum):
VALUE = '2'

# WHEN
v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_str = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyStrEnum, list(MyStrEnum.__members__.values())),
default=MyStrEnum.VALUE,
)
)

# THEN
with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v_str.validate_python(Decimal(1))
38 changes: 38 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1313,40 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


def test_model_with_enum_int_field_validation_should_succeed_for_decimal():
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)
v.validate_json('{"enum_field": 1, "enum_field_2": 2}')
m = v.validate_python({'enum_field': Decimal(1), 'enum_field_2': Decimal(2)})
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))