Skip to content

Commit

Permalink
Add Echo task (#2654)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 26, 2024
1 parent 64c56f8 commit 74d847a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
43 changes: 42 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
Expand Down Expand Up @@ -416,3 +416,44 @@ def wrapper(fn) -> ReferenceTask:
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper


class Echo(PythonTask):
_TASK_TYPE = "echo"

def __init__(self, name: str, inputs: Optional[Dict[str, Type]] = None, **kwargs):
"""
A task that simply echoes the inputs back to the user.
The task's inputs and outputs interface are the same.
FlytePropeller uses echo plugin to handle this task, and it won't create a pod for this task.
It will simply pass the inputs to the outputs.
https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/plugins/testing/echo.go
Note: Make sure to enable the echo plugin in the propeller config to use this task.
```
task-plugins:
enabled-plugins:
- echo
```
:param name: The name of the task.
:param inputs: Name and type of inputs specified as a dictionary.
e.g. {"a": int, "b": str}.
:param kwargs: All other args required by the parent type - PythonTask.
"""
outputs = dict(zip(output_name_generator(len(inputs)), inputs.values())) if inputs else None
super().__init__(
task_type=self._TASK_TYPE,
name=name,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)

def execute(self, **kwargs) -> Any:
values = list(kwargs.values())
if len(values) == 1:
return values[0]
else:
return tuple(values)
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.task import Echo
from flytekit.models.core.workflow import Node
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -495,3 +496,41 @@ def multiplier_2(my_input: float) -> float:

res = multiplier_2(my_input=10.0)
assert res == 20


def test_echo_in_condition():
echo1 = Echo(name="echo", inputs={"a": typing.Optional[float]})

@task()
def t1(radius: float) -> typing.Optional[float]:
return 2 * 3.14 * radius

@workflow
def wf1(radius: float) -> typing.Optional[float]:
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius >= 0.1) & (radius < 1.0))
.then(t1(radius=radius))
.else_()
.then(echo1(a=radius))
)

assert wf1(radius=1.8) == 1.8

echo2 = Echo(name="echo", inputs={"a": float, "b": float})

@task()
def t2(radius: float) -> typing.Tuple[float, float]:
return 2 * 3.14 * radius, 2 * 3.14 * radius

@workflow
def wf2(radius1: float, radius2: float) -> typing.Tuple[float, float]:
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius1 >= 0.1) & (radius1 < 1.0))
.then(t2(radius=radius2))
.else_()
.then(echo2(a=radius1, b=radius2))
)

assert wf2(radius1=1.8, radius2=1.8) == (1.8, 1.8)

0 comments on commit 74d847a

Please sign in to comment.