Skip to content

Commit

Permalink
feat: add IS_NOT_NULL operator to filters
Browse files Browse the repository at this point in the history
* unit tests
  • Loading branch information
mgraczyk committed Dec 19, 2024
1 parent a1596a3 commit 503a1ce
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
34 changes: 23 additions & 11 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@


_BAD_DIR_STRING: str
_BAD_OP_NAN_NULL: str
_BAD_OP_NAN: str
_BAD_OP_NULL: str
_BAD_OP_STRING: str
_COMPARISON_OPERATORS: Dict[str, Any]
_EQ_OP: str
_NEQ_OP: str
_INVALID_CURSOR_TRANSFORM: str
_INVALID_WHERE_TRANSFORM: str
_MISMATCH_CURSOR_W_ORDER_BY: str
Expand All @@ -80,12 +82,13 @@


_EQ_OP = "=="
_NEQ_OP = "!="
_operator_enum = StructuredQuery.FieldFilter.Operator
_COMPARISON_OPERATORS = {
"<": _operator_enum.LESS_THAN,
"<=": _operator_enum.LESS_THAN_OR_EQUAL,
_EQ_OP: _operator_enum.EQUAL,
"!=": _operator_enum.NOT_EQUAL,
_NEQ_OP: _operator_enum.NOT_EQUAL,
">=": _operator_enum.GREATER_THAN_OR_EQUAL,
">": _operator_enum.GREATER_THAN,
"array_contains": _operator_enum.ARRAY_CONTAINS,
Expand All @@ -104,7 +107,8 @@
_operator_enum.NOT_IN,
)
_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}."
_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values'
_BAD_OP_NAN = 'Only an equality filter ("==") can be used with NaN values'
_BAD_OP_NULL = 'Only equality ("==") or not-equal ("!=") filters can be used with None values'
_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values."
_BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}."
_INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values."
Expand Down Expand Up @@ -144,13 +148,16 @@ def __init__(self, field_path, op_string, value=None):
self.value = value

if value is None:
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL
if op_string == _EQ_OP:
self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL
elif op_string == _NEQ_OP:
self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL
else:
raise ValueError(_BAD_OP_NULL)

elif _isnan(value):
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
raise ValueError(_BAD_OP_NAN)
self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NAN
elif isinstance(value, (transforms.Sentinel, transforms._ValueList)):
raise ValueError(_INVALID_WHERE_TRANSFORM)
Expand Down Expand Up @@ -479,15 +486,20 @@ def where(
stacklevel=2,
)
if value is None:
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
if op_string == _EQ_OP:
op = StructuredQuery.UnaryFilter.Operator.IS_NULL
elif op_string == _NEQ_OP:
op = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL
else:
raise ValueError(_BAD_OP_NULL)

filter_pb = query.StructuredQuery.UnaryFilter(
field=query.StructuredQuery.FieldReference(field_path=field_path),
op=StructuredQuery.UnaryFilter.Operator.IS_NULL,
op=op
)
elif _isnan(value):
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
raise ValueError(_BAD_OP_NAN)
filter_pb = query.StructuredQuery.UnaryFilter(
field=query.StructuredQuery.FieldReference(field_path=field_path),
op=StructuredQuery.UnaryFilter.Operator.IS_NAN,
Expand Down
22 changes: 18 additions & 4 deletions tests/unit/v1/test_base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,20 @@ def test_basequery_where_eq_null(unary_helper_function):
unary_helper_function(None, op_enum)


@pytest.mark.parametrize(
"unary_helper_function",
[
(_where_unary_helper),
(_where_unary_helper_field_filter),
],
)
def test_basequery_where_neq_null(unary_helper_function):
from google.cloud.firestore_v1.types import StructuredQuery

op_enum = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL
unary_helper_function(None, op_enum, op_string="!=")


@pytest.mark.parametrize(
"unary_helper_function",
[
Expand All @@ -309,11 +323,11 @@ def test_basequery_where_eq_null(unary_helper_function):
],
)
def test_basequery_where_gt_null(unary_helper_function):
from google.cloud.firestore_v1.base_query import _BAD_OP_NAN_NULL
from google.cloud.firestore_v1.base_query import _BAD_OP_NULL

with pytest.raises(ValueError) as exc:
unary_helper_function(None, 0, op_string=">")
assert str(exc.value) == _BAD_OP_NAN_NULL
assert str(exc.value) == _BAD_OP_NULL


@pytest.mark.parametrize(
Expand All @@ -338,11 +352,11 @@ def test_basequery_where_eq_nan(unary_helper_function):
],
)
def test_basequery_where_le_nan(unary_helper_function):
from google.cloud.firestore_v1.base_query import _BAD_OP_NAN_NULL
from google.cloud.firestore_v1.base_query import _BAD_OP_NAN

with pytest.raises(ValueError) as exc:
unary_helper_function(float("nan"), 0, op_string="<=")
assert str(exc.value) == _BAD_OP_NAN_NULL
assert str(exc.value) == _BAD_OP_NAN


@pytest.mark.parametrize(
Expand Down

0 comments on commit 503a1ce

Please sign in to comment.