From bd75a82eaf135298cb4fa9250d396a0151a31284 Mon Sep 17 00:00:00 2001 From: jx2lee Date: Wed, 18 Dec 2024 00:37:23 +0900 Subject: [PATCH 1/4] extra forbid in execution api --- airflow/api_fastapi/execution_api/datamodels/taskinstance.py | 4 +++- airflow/api_fastapi/execution_api/datamodels/variable.py | 2 ++ airflow/api_fastapi/execution_api/datamodels/xcom.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c1bf588c2bbd4..498d94fcc9de4 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -24,7 +24,7 @@ from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator from airflow.api_fastapi.common.types import UtcDateTime -from airflow.api_fastapi.core_api.base import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState @@ -187,6 +187,8 @@ class DagRun(BaseModel): class TIRunContext(BaseModel): """Response schema for TaskInstance run context.""" + model_config = ConfigDict(extra="forbid") + dag_run: DagRun """DAG run information for the task instance.""" diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py index 6c597524763aa..546a06f09b83e 100644 --- a/airflow/api_fastapi/execution_api/datamodels/variable.py +++ b/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -25,6 +25,8 @@ class VariableResponse(BaseModel): """Variable schema for responses with fields that are needed for Runtime.""" + model_config = 
ConfigDict(extra="forbid") + key: str val: str | None = Field(alias="value") diff --git a/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow/api_fastapi/execution_api/datamodels/xcom.py index 1f913f9ac380e..6f897aa6966b9 100644 --- a/airflow/api_fastapi/execution_api/datamodels/xcom.py +++ b/airflow/api_fastapi/execution_api/datamodels/xcom.py @@ -19,12 +19,14 @@ from typing import Any -from airflow.api_fastapi.core_api.base import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict class XComResponse(BaseModel): """XCom schema for responses with fields that are needed for Runtime.""" + model_config = ConfigDict(extra="forbid") + key: str value: Any """The returned XCom value in a JSON-compatible format.""" From dafa7fe79bc5bdbc718201e5925a72f2ec76e7d2 Mon Sep 17 00:00:00 2001 From: jx2lee Date: Wed, 18 Dec 2024 00:45:22 +0900 Subject: [PATCH 2/4] generated datamodel-codegen --- task_sdk/src/airflow/sdk/api/datamodels/_generated.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 00187364c8669..3ee086096c2e6 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -154,6 +154,9 @@ class VariableResponse(BaseModel): Variable schema for responses with fields that are needed for Runtime. """ + model_config = ConfigDict( + extra="forbid", + ) key: Annotated[str, Field(title="Key")] value: Annotated[str | None, Field(title="Value")] = None @@ -163,6 +166,9 @@ class XComResponse(BaseModel): XCom schema for responses with fields that are needed for Runtime. """ + model_config = ConfigDict( + extra="forbid", + ) key: Annotated[str, Field(title="Key")] value: Annotated[Any, Field(title="Value")] @@ -205,6 +211,9 @@ class TIRunContext(BaseModel): Response schema for TaskInstance run context. 
""" + model_config = ConfigDict( + extra="forbid", + ) dag_run: DagRun variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None From 823a6f0f9de3cca58b8b48cfc95a1bededd69e24 Mon Sep 17 00:00:00 2001 From: jx2lee Date: Wed, 18 Dec 2024 22:59:57 +0900 Subject: [PATCH 3/4] add missing --- .../execution_api/datamodels/taskinstance.py | 12 ++++++++++-- .../airflow/sdk/api/datamodels/_generated.py | 18 +++++++++++++++--- .../execution_api/routes/test_variables.py | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 498d94fcc9de4..1228e11cd510c 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -36,6 +36,8 @@ class TIEnterRunningPayload(BaseModel): """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[TIState.RUNNING], # Specify a default in the schema, but not in code. 
@@ -54,6 +56,8 @@ class TIEnterRunningPayload(BaseModel): class TITerminalStatePayload(BaseModel): """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).""" + model_config = ConfigDict(extra="forbid") + state: TerminalTIState end_date: UtcDateTime @@ -63,12 +67,16 @@ class TITerminalStatePayload(BaseModel): class TITargetStatePayload(BaseModel): """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" + model_config = ConfigDict(extra="forbid") + state: IntermediateTIState class TIDeferredStatePayload(BaseModel): """Schema for updating TaskInstance to a deferred state.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[IntermediateTIState.DEFERRED], # Specify a default in the schema, but not in code, so Pydantic marks it as required. @@ -148,6 +156,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: class TIHeartbeatInfo(BaseModel): """Schema for TaskInstance heartbeat endpoint.""" + model_config = ConfigDict(extra="forbid") + hostname: str pid: int @@ -187,8 +197,6 @@ class DagRun(BaseModel): class TIRunContext(BaseModel): """Response schema for TaskInstance run context.""" - model_config = ConfigDict(extra="forbid") - dag_run: DagRun """DAG run information for the task instance.""" diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 3ee086096c2e6..a7a89ee7f4663 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -74,6 +74,9 @@ class TIDeferredStatePayload(BaseModel): Schema for updating TaskInstance to a deferred state. 
""" + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" classpath: Annotated[str, Field(title="Classpath")] trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None @@ -86,6 +89,9 @@ class TIEnterRunningPayload(BaseModel): Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["running"] | None, Field(title="State")] = "running" hostname: Annotated[str, Field(title="Hostname")] unixname: Annotated[str, Field(title="Unixname")] @@ -98,6 +104,9 @@ class TIHeartbeatInfo(BaseModel): Schema for TaskInstance heartbeat endpoint. """ + model_config = ConfigDict( + extra="forbid", + ) hostname: Annotated[str, Field(title="Hostname")] pid: Annotated[int, Field(title="Pid")] @@ -117,6 +126,9 @@ class TITargetStatePayload(BaseModel): Schema for updating TaskInstance to a target state, excluding terminal and running states. """ + model_config = ConfigDict( + extra="forbid", + ) state: IntermediateTIState @@ -211,9 +223,6 @@ class TIRunContext(BaseModel): Response schema for TaskInstance run context. """ - model_config = ConfigDict( - extra="forbid", - ) dag_run: DagRun variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None @@ -224,5 +233,8 @@ class TITerminalStatePayload(BaseModel): Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). 
""" + model_config = ConfigDict( + extra="forbid", + ) state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py index 20a9b43c07ace..45868e2a6092e 100644 --- a/tests/api_fastapi/execution_api/routes/test_variables.py +++ b/tests/api_fastapi/execution_api/routes/test_variables.py @@ -54,7 +54,7 @@ def test_variable_get_from_db(self, client, session): {"AIRFLOW_VAR_KEY1": "VALUE"}, ) def test_variable_get_from_env_var(self, client, session): - response = client.get("/execution/variables/key1") + response = client.get("/execution/variables/key1", params={"foo": "bar"}) assert response.status_code == 200 assert response.json() == {"key": "key1", "value": "VALUE"} From ac39ba4f53579ebd1128053f2588c8630439124d Mon Sep 17 00:00:00 2001 From: jx2lee Date: Thu, 19 Dec 2024 10:29:12 +0900 Subject: [PATCH 4/4] test --- .../routes/test_task_instances.py | 83 +++++++++++++++++++ .../execution_api/routes/test_variables.py | 2 +- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 4ed5f8f1598f3..cd7a49971cdb6 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -136,6 +136,39 @@ def test_ti_run_state_conflict_if_not_queued( assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state + def test_ti_run_failed_with_extra(self, client, session, create_task_instance, time_machine): + """ + Test that a 422 error is returned when extra fields are included in the payload. 
+ """ + instant_str = "2024-12-19T00:00:00Z" + instant = timezone.parse(instant_str) + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_ti_run_failed_with_extra", + state=State.QUEUED, + session=session, + start_date=instant, + ) + + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + "foo": "bar", + }, + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIUpdateState: def setup_method(self): @@ -340,6 +373,27 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan assert trs[0].map_index == -1 assert trs[0].duration == 129600 + def test_ti_update_state_failed_with_extra(self, client, session, create_task_instance, time_machine): + """ + Test that a 422 error is returned when extra fields are included in the payload. 
+ """ + ti = create_task_instance( + task_id="test_ti_update_state_failed_with_extra", + state=State.RUNNING, + session=session, + start_date=DEFAULT_START_DATE, + ) + + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", json={"state": "scheduled", "foo": "bar"} + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIHealthEndpoint: def setup_method(self): @@ -536,6 +590,35 @@ def test_ti_update_state_to_failed_table_check(self, client, session, create_tas assert ti.next_kwargs is None assert ti.duration == 3600.00 + def test_ti_heartbeat_with_extra( + self, + client, + session, + create_task_instance, + time_machine, + ): + """ + Test that a 422 error is returned when extra fields are included in the payload. + """ + ti = create_task_instance( + task_id="test_ti_heartbeat_when_task_not_running", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + session=session, + ) + session.commit() + task_instance_id = ti.id + + response = client.put( + f"/execution/task-instances/{task_instance_id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547, "foo": "bar"}, + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIPutRTIF: def setup_method(self): diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py index 45868e2a6092e..20a9b43c07ace 100644 --- a/tests/api_fastapi/execution_api/routes/test_variables.py +++ b/tests/api_fastapi/execution_api/routes/test_variables.py @@ -54,7 +54,7 @@ def test_variable_get_from_db(self, client, session): {"AIRFLOW_VAR_KEY1": "VALUE"}, ) def test_variable_get_from_env_var(self, client, session): - response = 
client.get("/execution/variables/key1", params={"foo": "bar"}) + response = client.get("/execution/variables/key1") assert response.status_code == 200 assert response.json() == {"key": "key1", "value": "VALUE"}