From 1b1b658d147b751bdec0cbf912d757a65e8d22d1 Mon Sep 17 00:00:00 2001 From: themylogin Date: Tue, 26 Nov 2024 12:57:01 +0100 Subject: [PATCH] Only redact secrets when the method return value is passed to the user --- .../middlewared/api/base/handler/accept.py | 6 +-- .../api/base/handler/dump_params.py | 8 +-- .../middlewared/api/base/handler/result.py | 8 +++ .../middlewared/api/base/handler/version.py | 10 +++- .../api/base/server/legacy_api_method.py | 38 ++++++++------ .../middlewared/api/base/server/method.py | 10 ++-- src/middlewared/middlewared/main.py | 50 +++++++++++-------- .../api/base/server/test_legacy_api_method.py | 4 -- src/middlewared/middlewared/restful.py | 1 + .../middlewared/service/core_service.py | 6 ++- tests/api2/test_legacy_websocket.py | 40 +++++++++++++-- 11 files changed, 123 insertions(+), 58 deletions(-) diff --git a/src/middlewared/middlewared/api/base/handler/accept.py b/src/middlewared/middlewared/api/base/handler/accept.py index d0e3b52478d81..99a9a75d5fc63 100644 --- a/src/middlewared/middlewared/api/base/handler/accept.py +++ b/src/middlewared/middlewared/api/base/handler/accept.py @@ -15,7 +15,7 @@ def accept_params(model: type[BaseModel], args: list, *, exclude_unset=False, ex :param model: `BaseModel` that defines method args. :param args: a list of method args. :param exclude_unset: if true, will not append default parameters to the list. - :param expose_secrets: if false, will replace `Private` parameters with a placeholder. + :param expose_secrets: if false, will replace `Secret` parameters with a placeholder. :return: a validated list of method args. """ args_as_dict = model_dict_from_list(model, args) @@ -60,7 +60,7 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e :param model: `BaseModel` subclass. :param data: provided data. :param exclude_unset: if true, will not add default values. - :param expose_secrets: if false, will replace `Private` fields with a placeholder. + :param expose_secrets: if false, will replace `Secret` fields with a placeholder. :return: validated data. """ try: @@ -83,5 +83,5 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e context={"expose_secrets": expose_secrets}, exclude_unset=exclude_unset, warnings=False, - by_alias=True + by_alias=True, ) diff --git a/src/middlewared/middlewared/api/base/handler/dump_params.py b/src/middlewared/middlewared/api/base/handler/dump_params.py index 85dd0491c9c05..4c11a6e660f10 100644 --- a/src/middlewared/middlewared/api/base/handler/dump_params.py +++ b/src/middlewared/middlewared/api/base/handler/dump_params.py @@ -17,7 +17,7 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis :param model: `BaseModel` that defines method args. :param args: a list of method args. - :param expose_secrets: if false, will replace `Private` parameters with a placeholder. + :param expose_secrets: if false, will replace `Secret` parameters with a placeholder. :return: A list of method call arguments ready to be printed. """ try: @@ -32,10 +32,10 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis def remove_secrets(model: type[BaseModel], value): """ - Removes `Private` values from a model value. + Removes `Secret` values from a model value. :param model: `BaseModel` that corresponds to `value`. - :param value: value that potentially contains `Private` data. - :return: `value` with `Private` parameters replaced with a placeholder. + :param value: value that potentially contains `Secret` data. + :return: `value` with `Secret` parameters replaced with a placeholder. """ if isinstance(value, dict) and (nested_model := model_field_is_model(model)): return { diff --git a/src/middlewared/middlewared/api/base/handler/result.py b/src/middlewared/middlewared/api/base/handler/result.py index 862e68bd992f5..27015f91fa3f9 100644 --- a/src/middlewared/middlewared/api/base/handler/result.py +++ b/src/middlewared/middlewared/api/base/handler/result.py @@ -2,6 +2,14 @@ def serialize_result(model, result, expose_secrets): + """ + Serializes a `result` of the method execution using the corresponding `model`. + + :param model: `BaseModel` that defines method return value. + :param result: method return value. + :param expose_secrets: if false, will replace `Secret` parameters with a placeholder. + :return: serialized method execution result. + """ return model(result=result).model_dump( context={"expose_secrets": expose_secrets}, warnings=False, diff --git a/src/middlewared/middlewared/api/base/handler/version.py b/src/middlewared/middlewared/api/base/handler/version.py index c17a3b0c37767..0b7e692013b4a 100644 --- a/src/middlewared/middlewared/api/base/handler/version.py +++ b/src/middlewared/middlewared/api/base/handler/version.py @@ -84,6 +84,12 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d :param version2: target API version that needs `value` :return: converted value """ + return self.adapt_model(value, model_name, version1, version2)[1] + + def adapt_model(self, value: dict, model_name: str, version1: str, version2: str) -> tuple[type[BaseModel], dict]: + """ + Same as `adapt`, but returned value will be a tuple of `version2` model instance and converted value. + """ try: version1_index = self.versions_history.index(version1) except ValueError: @@ -101,6 +107,7 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d raise APIVersionDoesNotContainModelException(current_version.version, model_name) value_factory = functools.partial(validate_model, current_version_model, value) + model = current_version_model if version1_index < version2_index: step = 1 @@ -115,10 +122,11 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d value_factory = functools.partial( self._adapt_model, value_factory, model_name, current_version, new_version, direction, ) + model = new_version.models.get(model_name) current_version = new_version - return value_factory() + return model, value_factory() def _adapt_model(self, value_factory: Callable[[], dict], model_name: str, current_version: APIVersion, new_version: APIVersion, direction: Direction): diff --git a/src/middlewared/middlewared/api/base/server/legacy_api_method.py b/src/middlewared/middlewared/api/base/server/legacy_api_method.py index c0df80cd4146b..ec01a31da370c 100644 --- a/src/middlewared/middlewared/api/base/server/legacy_api_method.py +++ b/src/middlewared/middlewared/api/base/server/legacy_api_method.py @@ -35,7 +35,7 @@ def __init__(self, middleware: "Middleware", name: str, api_version: str, adapte methodobj = self.methodobj if crud_methodobj := real_crud_method(methodobj): methodobj = crud_methodobj - if hasattr(methodobj, "new_style_accepts"): + if hasattr(methodobj, "new_style_accepts"): # FIXME: Remove this check when all models become new style self.accepts_model = methodobj.new_style_accepts self.returns_model = methodobj.new_style_returns else: @@ -43,8 +43,8 @@ def __init__(self, middleware: "Middleware", name: str, api_version: str, adapte self.returns_model = None async def call(self, app: "RpcWebSocketApp", params): - if self.accepts_model: - return self._adapt_result(await super().call(app, self._adapt_params(params))) + if self.accepts_model: # FIXME: Remove this check when all models become new style + params = self._adapt_params(params) return await super().call(app, params) @@ -70,22 +70,28 @@ def _adapt_params(self, params): return [adapted_params_dict[field] for field in self.accepts_model.model_fields] - def _adapt_result(self, result): - try: - return self.adapter.adapt( - {"result": result}, - self.returns_model.__name__, - self.adapter.current_version, - self.api_version, - )["result"] - except APIVersionDoesNotContainModelException: - if self.passthrough_nonexistent_methods: - return result + def _dump_result(self, app: "RpcWebSocketApp", methodobj, result): + if self.accepts_model: # FIXME: Remove this check when all models become new style + try: + model, result = self.adapter.adapt_model( + {"result": result}, + self.returns_model.__name__, + self.adapter.current_version, + self.api_version, + ) + except APIVersionDoesNotContainModelException: + if self.passthrough_nonexistent_methods: + return super()._dump_result(app, methodobj, result) + + raise + + return self.middleware.dump_result(self.serviceobj, methodobj, app, result["result"], + new_style_returns_model=model) - raise + return super()._dump_result(app, methodobj, result) def dump_args(self, params): - if self.accepts_model: + if self.accepts_model: # FIXME: Remove this check when all models become new style return dump_params(self.accepts_model, params, False) return super().dump_args(params) diff --git a/src/middlewared/middlewared/api/base/server/method.py b/src/middlewared/middlewared/api/base/server/method.py index ebdd9b92b0ecc..53e344a3b29d1 100644 --- a/src/middlewared/middlewared/api/base/server/method.py +++ b/src/middlewared/middlewared/api/base/server/method.py @@ -39,13 +39,17 @@ async def call(self, app: "RpcWebSocketApp", params: list): result = await self.middleware.call_with_audit(self.name, self.serviceobj, methodobj, params, app) if isinstance(result, Job): - result = result.id - elif isinstance(result, types.GeneratorType): + return result.id + + if isinstance(result, types.GeneratorType): result = list(result) elif isinstance(result, types.AsyncGeneratorType): result = [i async for i in result] - return result + return self._dump_result(app, methodobj, result) + + def _dump_result(self, app: "RpcWebSocketApp", methodobj, result): + return self.middleware.dump_result(self.serviceobj, methodobj, app, result) def dump_args(self, params: list) -> list: """ diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 4d8e18077fa58..8b323caa528df 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -212,7 +212,9 @@ async def call_method(self, message, serviceobj, methodobj): result = [i async for i in result] else: if lam.returns_model: - result = lam._adapt_result(result) + result = lam._dump_result(self, methodobj, result) + else: + result = self.middleware.dump_result(serviceobj, methodobj, self, result) self._send({ 'id': message['id'], @@ -1510,22 +1512,37 @@ def dump_args(self, args, method=None, method_name=None): return [method.accepts[i].dump(arg) if i < len(method.accepts) else arg for i, arg in enumerate(args)] - def dump_result(self, method, result, expose_secrets): + def dump_result(self, serviceobj, methodobj, app, result, *, new_style_returns_model=None): + expose_secrets = True + if app and app.authenticated_credentials: + if app.authenticated_credentials.is_user_session and not ( + credential_has_full_admin(app.authenticated_credentials) or + ( + serviceobj._config.role_prefix and + app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE') + ) + ): + expose_secrets = False + if isinstance(result, Job): return result - if method_self := getattr(method, "__self__", None): - if method.__name__ in ["create", "update", "delete"]: - if do_method := getattr(method_self, f"do_{method.__name__}", None): + if method_self := getattr(methodobj, "__self__", None): + if methodobj.__name__ in ["create", "update", "delete"]: + if do_method := getattr(method_self, f"do_{methodobj.__name__}", None): if hasattr(do_method, "new_style_returns"): # FIXME: Get rid of `create`/`do_create` duality - method = do_method + methodobj = do_method + + if hasattr(methodobj, "new_style_returns"): + # FIXME: When all models become new style, this should be passed explicitly + if new_style_returns_model is None: + new_style_returns_model = methodobj.new_style_returns - if hasattr(method, "new_style_returns"): - return serialize_result(method.new_style_returns, result, expose_secrets) + return serialize_result(new_style_returns_model, result, expose_secrets) - if not expose_secrets and hasattr(method, "returns") and method.returns: - schema = method.returns[0] + if not expose_secrets and hasattr(methodobj, "returns") and methodobj.returns: + schema = methodobj.returns[0] if isinstance(schema, OROperator): result = schema.dump(result, False) else: @@ -1620,18 +1637,7 @@ async def job_on_finish_cb(job): job = result await job.set_on_finish_cb(job_on_finish_cb) - expose_secrets = True - if app and app.authenticated_credentials: - if app.authenticated_credentials.is_user_session and not ( - credential_has_full_admin(app.authenticated_credentials) or - ( - serviceobj._config.role_prefix and - app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE') - ) - ): - expose_secrets = False - - result = self.dump_result(methodobj, result, expose_secrets) + return result finally: # If the method is a job, audit message will be logged by `job_on_finish_cb` if job is None: diff --git a/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py b/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py index 1873d15eafcad..f64efc67738f6 100644 --- a/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py +++ b/src/middlewared/middlewared/pytest/unit/api/base/server/test_legacy_api_method.py @@ -59,7 +59,3 @@ def method(number, text, multiplier): def test_adapt_params(): assert legacy_api_method._adapt_params([1]) == [1, "Default", 2] - - -def test_adapt_result(): - assert legacy_api_method._adapt_result(1) == "1" diff --git a/src/middlewared/middlewared/restful.py b/src/middlewared/middlewared/restful.py index f908ac1cbff0b..e54846efae252 100644 --- a/src/middlewared/middlewared/restful.py +++ b/src/middlewared/middlewared/restful.py @@ -856,6 +856,7 @@ async def do(self, http_method, req, resp, app, authorized, **kwargs): if authorized: result = await self.middleware.call_with_audit(methodname, serviceobj, methodobj, method_args, **method_kwargs) + result = self.middleware.dump_result(serviceobj, methodobj, app, result) else: await self.middleware.log_audit_message_for_method(methodname, methodobj, method_args, app, True, False, False) diff --git a/src/middlewared/middlewared/service/core_service.py b/src/middlewared/middlewared/service/core_service.py index 86ba073dd60a9..be967564f1601 100644 --- a/src/middlewared/middlewared/service/core_service.py +++ b/src/middlewared/middlewared/service/core_service.py @@ -804,11 +804,11 @@ async def bulk(self, app, job, method, params, description): # entries for external callers to methods. app is only None # on internal calls to core.bulk. if app: - msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app=app) + msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app) else: msg = await self.middleware.call(method, *p) - status = {"result": msg, "error": None} + status = {"error": None} if isinstance(msg, Job): b_job = msg @@ -817,6 +817,8 @@ async def bulk(self, app, job, method, params, description): if b_job.error: status["error"] = b_job.error + else: + status["result"] = self.middleware.dump_result(serviceobj, methodobj, app, msg) statuses.append(status) except Exception as e: diff --git a/tests/api2/test_legacy_websocket.py b/tests/api2/test_legacy_websocket.py index 4a217021131e9..70ab9eeecdd71 100644 --- a/tests/api2/test_legacy_websocket.py +++ b/tests/api2/test_legacy_websocket.py @@ -1,7 +1,11 @@ +import random +import string + import pytest from truenas_api_client import Client +from middlewared.test.integration.assets.account import unprivileged_user from middlewared.test.integration.assets.cloud_sync import credential from middlewared.test.integration.utils import password, websocket_url @@ -17,7 +21,28 @@ def c(): yield c -def test_adapts_cloud_credentials(c): +@pytest.fixture(scope="module") +def unprivileged_client(): + suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(8)]) + with unprivileged_user( + username=f"unprivileged_{suffix}", + group_name=f"unprivileged_users_{suffix}", + privilege_name=f"Unprivileged users ({suffix})", + allowlist=[], + roles=["READONLY_ADMIN"], + web_shell=False, + ) as t: + with Client(websocket_url() + "/websocket") as c: + c.call("auth.login_ex", { + "mechanism": "PASSWORD_PLAIN", + "username": t.username, + "password": t.password, + }) + yield c + + +@pytest.fixture(scope="module") +def ftp_credential(): with credential({ "provider": { "type": "FTP", @@ -27,5 +52,14 @@ def test_adapts_cloud_credentials(c): "pass": "", }, }) as cred: - result = c.call("cloudsync.credentials.get_instance", cred["id"]) - assert result["provider"] == "FTP" + yield cred + + +def test_adapts_cloud_credentials(c, ftp_credential): + result = c.call("cloudsync.credentials.get_instance", ftp_credential["id"]) + assert result["provider"] == "FTP" + + +def test_adapts_cloud_credentials_for_unprivileged(unprivileged_client, ftp_credential): + result = unprivileged_client.call("cloudsync.credentials.get_instance", ftp_credential["id"]) + assert result["attributes"] == "********"