diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 3509bbf17..5a9efaf78 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -67,10 +67,12 @@ _BAD_DIR_STRING: str -_BAD_OP_NAN_NULL: str +_BAD_OP_NAN: str +_BAD_OP_NULL: str _BAD_OP_STRING: str _COMPARISON_OPERATORS: Dict[str, Any] _EQ_OP: str +_NEQ_OP: str _INVALID_CURSOR_TRANSFORM: str _INVALID_WHERE_TRANSFORM: str _MISMATCH_CURSOR_W_ORDER_BY: str @@ -80,12 +82,13 @@ _EQ_OP = "==" +_NEQ_OP = "!=" _operator_enum = StructuredQuery.FieldFilter.Operator _COMPARISON_OPERATORS = { "<": _operator_enum.LESS_THAN, "<=": _operator_enum.LESS_THAN_OR_EQUAL, _EQ_OP: _operator_enum.EQUAL, - "!=": _operator_enum.NOT_EQUAL, + _NEQ_OP: _operator_enum.NOT_EQUAL, ">=": _operator_enum.GREATER_THAN_OR_EQUAL, ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, @@ -104,7 +107,7 @@ _operator_enum.NOT_IN, ) _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." -_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' +_BAD_OP_NAN_NULL = 'Only equality ("==") or not-equal ("!=") filters can be used with None or NaN values' _INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." _BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." _INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." @@ -136,26 +139,49 @@ def _to_pb(self): """Build the protobuf representation based on values in the filter""" +def _validate_opation(op_string, value): + """ + Given an input operator string (e.g, '!='), and a value (e.g. None), + ensure that the operator and value combination is valid, and return + an approproate new operator value. A new operator will be used if + the operaion is a comparison against Null or NaN + + Args: + op_string (Optional[str]): the requested operator + value (Any): the value the operator is acting on + Returns: + str | StructuredQuery.UnaryFilter.Operator: operator to use in requests + Raises: + ValueError: if the operator and value combination is invalid + """ + if value is None: + if op_string == _EQ_OP: + return StructuredQuery.UnaryFilter.Operator.IS_NULL + elif op_string == _NEQ_OP: + return StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL + else: + raise ValueError(_BAD_OP_NAN_NULL) + + elif _isnan(value): + if op_string == _EQ_OP: + return StructuredQuery.UnaryFilter.Operator.IS_NAN + elif op_string == _NEQ_OP: + return StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN + else: + raise ValueError(_BAD_OP_NAN_NULL) + elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): + raise ValueError(_INVALID_WHERE_TRANSFORM) + else: + return op_string + + class FieldFilter(BaseFilter): """Class representation of a Field Filter.""" def __init__(self, field_path, op_string, value=None): self.field_path = field_path self.value = value - - if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL - - elif _isnan(value): - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NAN - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): - raise ValueError(_INVALID_WHERE_TRANSFORM) - else: - self.op_string = op_string + self.op_string = _validate_opation(op_string, value) def _to_pb(self): """Returns the protobuf representation, either a StructuredQuery.UnaryFilter or a StructuredQuery.FieldFilter""" @@ -478,22 +504,12 @@ def where( UserWarning, stacklevel=2, ) - if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - filter_pb = query.StructuredQuery.UnaryFilter( - field=query.StructuredQuery.FieldReference(field_path=field_path), - op=StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - elif _isnan(value): - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) + op = _validate_opation(op_string, value) + if isinstance(op, StructuredQuery.UnaryFilter.Operator): filter_pb = query.StructuredQuery.UnaryFilter( field=query.StructuredQuery.FieldReference(field_path=field_path), - op=StructuredQuery.UnaryFilter.Operator.IS_NAN, + op=op, ) - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): - raise ValueError(_INVALID_WHERE_TRANSFORM) else: filter_pb = query.StructuredQuery.FieldFilter( field=query.StructuredQuery.FieldReference(field_path=field_path), diff --git a/tests/system/test_system.py b/tests/system/test_system.py index ed525db57..0ab599c31 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1503,6 +1503,10 @@ def test_query_unary(client, cleanup, database): # Add to clean-up. cleanup(document1.delete) + _, document2 = collection.add({field_name: 123}) + # Add to clean-up. + cleanup(document2.delete) + # 0. Query for null. query0 = collection.where(filter=FieldFilter(field_name, "==", None)) values0 = list(query0.stream()) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 675b23a98..200be7d8a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -1444,6 +1444,10 @@ async def test_query_unary(client, cleanup, database): # Add to clean-up. cleanup(document1.delete) + _, document2 = await collection.add({field_name: 123}) + # Add to clean-up. + cleanup(document2.delete) + # 0. Query for null. query0 = collection.where(filter=FieldFilter(field_name, "==", None)) values0 = [i async for i in query0.stream()] @@ -1462,6 +1466,23 @@ async def test_query_unary(client, cleanup, database): assert len(data1) == 1 assert math.isnan(data1[field_name]) + # 2. Query for not null + query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) + values2 = [i async for i in query2.stream()] + assert len(values2) == 2 + # should fetch documents 1 (NaN) and 2 (int) + assert any(snapshot.reference._path == document1._path for snapshot in values2) + assert any(snapshot.reference._path == document2._path for snapshot in values2) + + # 3. Query for not NAN. + query3 = collection.where(filter=FieldFilter(field_name, "!=", nan_val)) + values3 = [i async for i in query3.stream()] + assert len(values3) == 1 + snapshot3 = values3[0] + assert snapshot3.reference._path == document2._path + # only document2 is not NaN + assert snapshot3.to_dict() == {field_name: 123} + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_collection_group_queries(client, cleanup, database): diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 24caa5e40..7f6b0e5e2 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -301,6 +301,20 @@ def test_basequery_where_eq_null(unary_helper_function): unary_helper_function(None, op_enum) +@pytest.mark.parametrize( + "unary_helper_function", + [ + (_where_unary_helper), + (_where_unary_helper_field_filter), + ], +) +def test_basequery_where_neq_null(unary_helper_function): + from google.cloud.firestore_v1.types import StructuredQuery + + op_enum = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL + unary_helper_function(None, op_enum, op_string="!=") + + @pytest.mark.parametrize( "unary_helper_function", [ @@ -330,6 +344,20 @@ def test_basequery_where_eq_nan(unary_helper_function): unary_helper_function(float("nan"), op_enum) +@pytest.mark.parametrize( + "unary_helper_function", + [ + (_where_unary_helper), + (_where_unary_helper_field_filter), + ], +) +def test_basequery_where_neq_nan(unary_helper_function): + from google.cloud.firestore_v1.types import StructuredQuery + + op_enum = StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN + unary_helper_function(float("nan"), op_enum, op_string="!=") + + @pytest.mark.parametrize( "unary_helper_function", [