From 189a820276b0fc5121e0f310d763a9bfbe668cdf Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Fri, 21 Jun 2024 15:51:44 +0200 Subject: [PATCH] Fix writing of Nullable attributes (#768) --- matter_server/common/helpers/util.py | 39 ++++++++++++++++------- matter_server/server/device_controller.py | 1 + tests/common/test_parser.py | 21 +++++++++++- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/matter_server/common/helpers/util.py b/matter_server/common/helpers/util.py index 9cf1a4ec..1786bcce 100644 --- a/matter_server/common/helpers/util.py +++ b/matter_server/common/helpers/util.py @@ -107,7 +107,7 @@ def parse_value( value: Any, value_type: Any, default: Any = MISSING, - allow_none: bool = True, + allow_none: bool = False, allow_sdk_types: bool = False, ) -> Any: """ @@ -121,6 +121,16 @@ def parse_value( # this shouldn't happen, but just in case value_type = get_type_hints(value_type, globals(), locals()) + # handle value is None/missing but a default value is set + if value is None and not isinstance(default, type(MISSING)): + return default + # handle value is None and sdk type is Nullable + if value is None and value_type is Nullable: + return Nullable() if allow_sdk_types else None + # handle value is None (but that is allowed according to the annotations) + if value is None and value_type is NoneType: + return None + if isinstance(value, dict): if descriptor := getattr(value_type, "descriptor", None): # handle matter TLV dicts where the keys are just tag identifiers @@ -132,14 +142,6 @@ def parse_value( return None value = None - if value is None and not isinstance(default, type(MISSING)): - return default - if value is None and value_type is NoneType: - return None - if value is None and value_type is Nullable: - return Nullable() if allow_sdk_types else None - if value is None and allow_none: - return None if is_dataclass(value_type) and isinstance(value, dict): return dataclass_from_dict(value_type, value) # get origin value type and inspect one-by-one @@ -156,7 +158,11 @@ def parse_value( subvalue_type = get_args(value_type)[1] return { parse_value(subkey, subkey, subkey_type): parse_value( - f"{subkey}.value", subvalue, subvalue_type + f"{subkey}.value", + subvalue, + subvalue_type, + allow_none=allow_none, + allow_sdk_types=allow_sdk_types, ) for subkey, subvalue in value.items() } @@ -169,7 +175,13 @@ def parse_value( return value # try them all until one succeeds try: - return parse_value(name, value, sub_arg_type) + return parse_value( + name, + value, + sub_arg_type, + allow_none=allow_none, + allow_sdk_types=allow_sdk_types, + ) except (KeyError, TypeError, ValueError): pass # if we get to this point, all possibilities failed @@ -189,8 +201,11 @@ def parse_value( # handle Any as value type (which is basically unprocessable) if value_type is Any: return value + # handle value is None (but that is allowed) + if value is None and allow_none: + return None # raise if value is None and the value is required according to annotations - if value is None and value_type is not NoneType and not allow_none: + if value is None: raise KeyError(f"`{name}` of type `{value_type}` is required.") try: diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index bdb06fec..d8b0d5fa 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -695,6 +695,7 @@ async def write_attribute( name=attribute_path, value=value, value_type=attribute.attribute_type.Type, + allow_none=False, allow_sdk_types=True, ) if node_id >= TEST_NODE_START: diff --git a/tests/common/test_parser.py b/tests/common/test_parser.py index 41f9a44e..e7609695 100644 --- a/tests/common/test_parser.py +++ b/tests/common/test_parser.py @@ -3,8 +3,9 @@ from dataclasses import dataclass import datetime from enum import Enum, IntEnum -from typing import Optional +from typing import Optional, Union +from chip.clusters.Types import Nullable, NullValue import pytest from matter_server.common.helpers.util import dataclass_from_dict, parse_value @@ -110,3 +111,21 @@ def test_dataclass_from_dict(): # test NOCStruct.noc edge case res = parse_value("NOCStruct.noc", 5, bytes) assert res == b"" + + +def test_parse_value(): + """Test special cases in the parse_value helper.""" + # test None value which is allowed + assert parse_value("test", None, int, allow_none=True) is None + # test unexpected None value + with pytest.raises(KeyError): + parse_value("test", None, int, allow_none=False) + # test sdk Nullable type + assert parse_value("test", None, Nullable) is None + assert parse_value("test", None, Nullable, allow_sdk_types=True) == NullValue + assert ( + parse_value( + "test", None, Union[int, Nullable], allow_none=False, allow_sdk_types=True + ) + == NullValue + )