From 40c9540fc9a01ec61643de8ad45dd51af9212b57 Mon Sep 17 00:00:00 2001
From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
Date: Thu, 21 Nov 2024 20:39:24 -0500
Subject: [PATCH] Offload literals (#2872)

* wip - Implement offloading of literals
Signed-off-by: Eduardo Apolinario

* Fix use of metadata bucket prefix
Signed-off-by: Eduardo Apolinario

* Fix repeated use of uri
Signed-off-by: Eduardo Apolinario

* Add temporary representation for offloaded literal
Signed-off-by: Eduardo Apolinario

* Add one unit test
Signed-off-by: Eduardo Apolinario

* Add another test
Signed-off-by: Eduardo Apolinario

* Stylistic changes to the two tests
Signed-off-by: Eduardo Apolinario

* Add test for min offloading threshold set to 1MB
Signed-off-by: Eduardo Apolinario

* Pick a unique engine-dir for tests
Signed-off-by: Eduardo Apolinario

* s/new_outputs/literal_map_copy/
Signed-off-by: Eduardo Apolinario

* Remove unused constant
Signed-off-by: Eduardo Apolinario

* Use output_prefix in definition of offloaded literals
Signed-off-by: Eduardo Apolinario

* Add initial version of pbhash.py
Signed-off-by: Eduardo Apolinario

* Add tests to verify that overriding the hash is carried over to offloaded literals
Signed-off-by: Eduardo Apolinario

* Add a few more tests
Signed-off-by: Eduardo Apolinario

* Always import ParamSpec from `typing_extensions`
Signed-off-by: Eduardo Apolinario

* Fix lint warnings
Signed-off-by: Eduardo Apolinario

* Set inferred_type using the task type interface
Signed-off-by: Eduardo Apolinario

* Add comment about offloaded literals files and how they are uploaded to the metadata bucket
Signed-off-by: Eduardo Apolinario

* Add offloading_enabled
Signed-off-by: Eduardo Apolinario

* Add more unit tests including a negative test
Signed-off-by: Eduardo Apolinario

* Fix bad merge
Signed-off-by: Eduardo Apolinario

* Incorporate feedback.
Signed-off-by: Eduardo Apolinario * Fix image name (unrelated to this PR - just a nice-to-have to decrease flakiness) Signed-off-by: Eduardo Apolinario * Add `is_map_task` to `_dispatch_execute` Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 66 ++- flytekit/core/task.py | 8 +- flytekit/core/workflow.py | 6 +- flytekit/interaction/string_literals.py | 3 + flytekit/models/literals.py | 2 +- flytekit/utils/pbhash.py | 39 ++ .../unit/bin/test_python_entrypoint.py | 465 +++++++++++++++++- .../unit/core/image_spec/test_image_spec.py | 2 +- tests/flytekit/unit/utils/test_pbhash.py | 144 ++++++ 9 files changed, 717 insertions(+), 18 deletions(-) create mode 100644 flytekit/utils/pbhash.py create mode 100644 tests/flytekit/unit/utils/test_pbhash.py diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 61579f00ee..084e8f733b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -14,7 +14,7 @@ import uuid import warnings from sys import exit -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional import click from flyteidl.core import literals_pb2 as _literals_pb2 @@ -55,6 +55,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module +from flytekit.utils.pbhash import compute_hash_string def get_version_message(): @@ -135,6 +136,7 @@ def _dispatch_execute( load_task: Callable[[], PythonTask], inputs_path: str, output_prefix: str, + is_map_task: bool = False, ): """ Dispatches execute to PythonTask @@ -144,6 +146,12 @@ def _dispatch_execute( a: [Optional] Record outputs to output_prefix b: OR if IgnoreOutputs is raised, then ignore uploading outputs c: OR if an unhandled exception is retrieved - record it as an errors.pb + + :param ctx: FlyteContext + :param load_task: Callable[[], PythonTask] + :param inputs: Where to read inputs + :param output_prefix: Where to write primitive outputs + :param is_map_task: Whether this task is executing as part of a map task """ error_file_name = _build_error_file_name() worker_name = _get_worker_name() @@ -179,7 +187,59 @@ def _dispatch_execute( logger.warning("Task produces no outputs") output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): - output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs} + # The keys in this map hold the filenames to the offloaded proto literals. + offloaded_literals: Dict[str, _literal_models.Literal] = {} + literal_map_copy = {} + + offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None + min_offloaded_size = -1 + max_offloaded_size = -1 + if offloading_enabled: + min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024 + max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 + + # Go over each output and create a separate offloaded in case its size is too large + for k, v in outputs.literals.items(): + literal_map_copy[k] = v + + if not offloading_enabled: + continue + + lit = v.to_flyte_idl() + if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size: + raise ValueError( + f"Literal {k} is too large to be offloaded. 
Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes" + ) + + if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size: + logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") + inferred_type = task_def.interface.outputs[k].type + + # In the case of map tasks we need to use the type of the collection as inferred type as the task + # typed interface of the offloaded literal. This is done because the map task interface present in + # the task template contains the (correct) type for the entire map task, not the single node execution. + # For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal. + if is_map_task: + inferred_type = inferred_type.collection_type + + # This file will hold the offloaded literal and will be written to the output prefix + # alongside the regular outputs.pb, deck.pb, etc. + # N.B.: by construction `offloaded_filename` is guaranteed to be unique + offloaded_filename = f"{k}_offloaded_metadata.pb" + offloaded_literal = _literal_models.Literal( + offloaded_metadata=_literal_models.LiteralOffloadedMetadata( + uri=f"{output_prefix}/{offloaded_filename}", + size_bytes=lit.ByteSize(), + # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged + inferred_type=inferred_type, + ), + hash=v.hash if v.hash is not None else compute_hash_string(lit), + ) + literal_map_copy[k] = offloaded_literal + offloaded_literals[offloaded_filename] = v + outputs = _literal_models.LiteralMap(literals=literal_map_copy) + + output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals} elif isinstance(outputs, _dynamic_job.DynamicJobSpec): output_file_dict = {_constants.FUTURES_FILE_NAME: outputs} else: @@ -588,7 +648,7 @@ def load_task(): ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + _dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True) def normalize_inputs( diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 9709adda08..1196fd95c7 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,12 +5,7 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload -from flytekit.core.utils import str2bool - -try: - from typing import ParamSpec -except ImportError: - from typing_extensions import ParamSpec # type: ignore +from typing_extensions import ParamSpec # type: ignore from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow @@ -20,6 +15,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.core.utils import str2bool from flytekit.deck import DeckField from flytekit.extras.accelerators import BaseAccelerator from flytekit.image_spec.image_spec import ImageSpec diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 9cccf19e58..bb48cde73b 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -8,13 +8,9 @@ from functools import update_wrapper from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload +from typing_extensions import ParamSpec # type: ignore from typing_inspect import is_optional_type -try: - from typing import ParamSpec -except ImportError: - from typing_extensions import 
ParamSpec # type: ignore - from flytekit.core import constants as _common_constants from flytekit.core import launch_plan as _annotated_launch_plan from flytekit.core.base_task import PythonTask, Task diff --git a/flytekit/interaction/string_literals.py b/flytekit/interaction/string_literals.py index 0bfb3c866a..6f70488981 100644 --- a/flytekit/interaction/string_literals.py +++ b/flytekit/interaction/string_literals.py @@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any: return [literal_string_repr(i) for i in lit.collection.literals] if lit.map: return {k: literal_string_repr(v) for k, v in lit.map.literals.items()} + if lit.offloaded_metadata: + # TODO: load literal from offloaded literal? + return f"Offloaded literal metadata: {lit.offloaded_metadata}" raise ValueError(f"Unknown literal type {lit}") diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index f433c2fad1..d65ebfafae 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -991,7 +991,7 @@ def to_flyte_idl(self): map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, metadata=self.metadata, - offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None, + offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None, ) @classmethod diff --git a/flytekit/utils/pbhash.py b/flytekit/utils/pbhash.py new file mode 100644 index 0000000000..ae4a364d12 --- /dev/null +++ b/flytekit/utils/pbhash.py @@ -0,0 +1,39 @@ +# This is a module that provides hashing utilities for Protobuf objects. +import base64 +import hashlib +import json + +from google.protobuf import json_format +from google.protobuf.message import Message + + +def compute_hash(pb: Message) -> bytes: + """ + Computes a deterministic hash in bytes for the Protobuf object. + """ + try: + pb_dict = json_format.MessageToDict(pb) + # json.dumps with sorted keys to ensure stability + stable_json_str = json.dumps( + pb_dict, sort_keys=True, separators=(",", ":") + ) # separators to ensure no extra spaces + except Exception as e: + raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") + + try: + # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here, + # assuming it provides a consistent hash output. 
+ hash_obj = hashlib.sha256(stable_json_str.encode("utf-8")) + except Exception as e: + raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") + + # The digest is guaranteed to be 32 bytes long + return hash_obj.digest() + + +def compute_hash_string(pb: Message) -> str: + """ + Computes a deterministic hash in base64 encoded string for the Protobuf object + """ + hash_bytes = compute_hash(pb) + return base64.b64encode(hash_bytes).decode("utf-8") diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index fb32193a76..3955019cd0 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import datetime import os import re @@ -11,24 +12,32 @@ import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument +from flyteidl.core import literals_pb2 +from flyteidl.core.literals_pb2 import Literal, LiteralCollection, Scalar, Primitive from google.protobuf.timestamp_pb2 import Timestamp from flytekit.bin.entrypoint import _dispatch_execute, get_container_error_timestamp, normalize_inputs, setup_execution, get_traceback_str from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core import mock_stats +from flytekit.core.array_node_map_task import ArrayNodeMapTask +from flytekit.core.hash import HashMethod +from flytekit.models.core import identifier as id_models from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, DataclassTransformer from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.base import FlyteException from flytekit.exceptions.scopes import system_entry_point from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.models import literals as _literal_models -from flytekit.models.core import errors as error_models +from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models +from flytekit.core.utils import write_proto_to_file +from flytekit.models.types import LiteralType, SimpleType @mock.patch("flytekit.core.utils.load_proto_from_file") @@ -508,6 +517,7 @@ def test_get_traceback_str(): print(traceback_str) # helpful for debugging assert expected_error_re.match(traceback_str) is not None + def test_get_container_error_timestamp() -> None: assert get_container_error_timestamp(FlyteException("foo", timestamp=10.5)) == Timestamp(seconds=10, nanos=500000000) @@ -522,3 +532,454 @@ def test_get_container_error_timestamp() -> None: current_dtime = datetime.now() error_timestamp = get_container_error_timestamp(None) assert error_timestamp.ToDatetime() >= current_dtime + + +def get_flyte_context(tmp_path_factory, outputs_path): + """ + This is a helper function to create a flyte context with the right parameters for testing offloading of literals. 
+ """ + ctx = context_manager.FlyteContext.current_context() + return context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + engine_dir=tmp_path_factory.mktemp("engine_dir"), + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, + user_space_params=context_manager.ExecutionParameters( + execution_date=datetime.now(), + tmp_dir="/tmp", + stats=mock_stats.MockStats(), + logging=None, + raw_output_prefix="", + output_metadata_prefix=str(outputs_path.absolute()), + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + ), + ), + ), + ) + + +def test_dispatch_execute_offloaded_literals(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[str]: + return [f"string is: {x}" for x in a] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 1")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 2")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 3")), + ), + ] + ) + ) + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_literals_two_outputs_offloaded(tmp_path_factory): + @task + def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: + return sum(xs), [f"string is: {x}" for x in xs] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3, 4] + input_literal_map = _literal_models.LiteralMap( + { + "xs": _literal_models.Literal( + collection=_literal_models.LiteralCollection( + literals=[ + _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)), + ) for x in xs + ] + ) + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, 
lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 2 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.size_bytes == 6 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() + assert "o1" in lit.literals + assert lit.literals["o1"].HasField("offloaded_metadata") == True + assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 + assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + assert lit.literals["o1"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + scalar=Scalar(primitive=Primitive(integer=10)), + ) + elif ff == "o1_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 1")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 2")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 3")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 4")), + ), + ] + ) + ) + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_literals_two_outputs_only_second_one_offloaded(tmp_path_factory): + @dataclass + class DC: + a: typing.List[int] + b: typing.List[str] + + @task + def t1(n: int) -> typing.Tuple[int, DC]: + return n, DC(a=list(range(n)), b=[f"string is: {x}" for x in range(n)]) + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=56_000)), + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Notice how the threshold is set to 1MB + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "1"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + # o0 is not offloaded + assert "o0_offloaded_metadata.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 2 + + # o0 is not offloaded + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") is False + assert lit.literals["o0"].hash == "" + + # o1 is offloaded + assert "o1" in lit.literals + assert lit.literals["o1"].HasField("offloaded_metadata") is True + assert lit.literals["o1"].offloaded_metadata.size_bytes == 1108538 + assert 
lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + assert lit.literals["o1"].hash == "VS9bthLslGa8tjuVBCcmO3UdGHrkpyOBXzJlmY47fw8=" + assert lit.literals["o1"].offloaded_metadata.inferred_type == DataclassTransformer().get_literal_type(DC).to_flyte_idl() + elif ff == "o1_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit.hash == "" + # Load the dataclass from the proto + transformer = TypeEngine.get_transformer(DC) + dc = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), DC) + assert dc.a == list(range(56_000)) + else: + assert False, f"Unexpected file {ff}" + + + +def test_dispatch_execute_offloaded_literals_annotated_hash(tmp_path_factory): + class A: + def __init__(self, a: int): + self.a = a + + @task + def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]: + return A(a=n) + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=1234)), + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # All literals should be offloaded + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + + # o0 is offloaded + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") is True + assert lit.literals["o0"].offloaded_metadata.size_bytes > 0 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].hash == "1234" + assert lit.literals["o0"].offloaded_metadata.inferred_type == t1.interface.outputs["o0"].type.to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit.hash == "1234" + transformer = TypeEngine.get_transformer(A) + a = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), A) + assert a.a == 1234 + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in 
os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.size_bytes == 195 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(collection_type=LiteralType(simple=SimpleType.STRING))).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = [[f"string is: {x}" for x in xs] for _ in range(len(xs))] + assert lit == TypeEngine.to_literal(ctx, expected_output, typing.List[typing.List[str]], TypeEngine.to_literal_type(typing.List[typing.List[str]])).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}" + + + +def test_dispatch_execute_offloaded_map_task(tmp_path_factory): + @task + def t1(n: int) -> int: + return n + 1 + + inputs: typing.List[int] = [1, 2, 3, 4] + for i, v in enumerate(inputs): + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": TypeEngine.to_literal(ctx, inputs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict( + os.environ, + { + "_F_L_MIN_SIZE_MB": "0", # Always offload + "BATCH_JOB_ARRAY_INDEX_OFFSET": str(i), + }): + _dispatch_execute(ctx, lambda: ArrayNodeMapTask(python_function_task=t1), str(inputs_path/"inputs.pb"), str(outputs_path.absolute()), is_map_task=True) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, 
"rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = v + 1 + assert lit == TypeEngine.to_literal(ctx, expected_output, int, TypeEngine.to_literal_type(int)).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}" diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index b20030c3a0..aa6808215b 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -114,7 +114,7 @@ def test_build_existing_image_with_force_push(): image_spec = ImageSpec(name="hello", builder="test").force_push() builder = Mock() - builder.build_image.return_value = "new_image_name" + builder.build_image.return_value = "fqn.xyz/new_image_name:v-test" ImageBuildEngine.register("test", builder) ImageBuildEngine.build(image_spec) diff --git a/tests/flytekit/unit/utils/test_pbhash.py b/tests/flytekit/unit/utils/test_pbhash.py new file mode 100644 index 0000000000..f608271177 --- /dev/null +++ b/tests/flytekit/unit/utils/test_pbhash.py @@ -0,0 +1,144 @@ +import tempfile +import mock +import pytest +import typing +from dataclasses import dataclass, field +from google.protobuf import json_format +from google.protobuf import struct_pb2 +from dataclasses_json import DataClassJsonMixin + +from flyteidl.core.literals_pb2 import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar +from flyteidl.core.types_pb2 import BlobType +from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.type_engine 
import DataclassTransformer +from flytekit.types.file.file import FlyteFile +from flytekit.utils.pbhash import compute_hash_string + + +@pytest.mark.parametrize( + "lit, expected_hash", + [ + ( + Literal(scalar=Scalar(primitive=Primitive(integer=1))), + "aJB6fp0kDrfAZt22e/IFnT8IJIlobjxcweiZA8I7/dA=", + ), + ( + Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(primitive=Primitive(integer=1)))])), + "qN7iA0GnbLzFGcHB7y09lbxgx+9cTIViSlyL9/kCSC0=", + ), + ( + Literal(map=LiteralMap(literals={"a": Literal(scalar=Scalar(primitive=Primitive(integer=1)))})), + "JhrkdOQ+xzPVNYiKzD5sHhZprQB5Nq1GsYUVbmLAswU=", + ), + ( + Literal( + scalar=Scalar( + blob=Blob( + uri="s3://my-bucket/my-key", + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + ), + ), + ), + "KdNNbLBYoamXYLz8SBuJd/kVDPxO4gVGdNQl61qeTfA=", + ), + ( + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="s3://my-bucket/my-key", + ), + ), + ), + "KdNNbLBYoamXYLz8SBuJd/kVDPxO4gVGdNQl61qeTfA=", + ), + ( + # Literal collection + Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="s3://my-bucket/my-key", + ), + ), + ), + ], + ), + ), + "RauoCNnZfCSHgcmMKVugozLAcssq/mWdMjbGanRJufI=", + ) + ], +) +def test_direct_literals(lit, expected_hash): + assert compute_hash_string(lit) == expected_hash + + +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") +def test_dataclass_literals(mock_put_data): + + @dataclass + class A(DataClassJsonMixin): + a: int + + @dataclass + class TestFileStruct(DataClassJsonMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A] = None + i_prime: typing.Optional[A] = field(default_factory=lambda: A(a=99)) + + remote_path = "s3://tmp/file" + mock_put_data.return_value = remote_path + + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + + assert compute_hash_string(lv.to_flyte_idl()) == "Hp/cWul3sBI5r8XKdVzAlvNBJ4OSX9L2d/SADI8+YOY="
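
A minimal usage sketch of the new `flytekit.utils.pbhash` helpers added by this patch. The literal and the expected digest mirror the first parametrized case in tests/flytekit/unit/utils/test_pbhash.py; everything else in the snippet is illustrative only.

    from flyteidl.core.literals_pb2 import Literal, Primitive, Scalar
    from flytekit.utils.pbhash import compute_hash, compute_hash_string

    lit = Literal(scalar=Scalar(primitive=Primitive(integer=1)))

    # compute_hash returns the raw 32-byte SHA-256 digest of the sorted-key JSON
    # rendering of the message; compute_hash_string base64-encodes that digest.
    digest = compute_hash(lit)
    assert len(digest) == 32
    assert compute_hash_string(lit) == "aJB6fp0kDrfAZt22e/IFnT8IJIlobjxcweiZA8I7/dA="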
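
For readers tracing what `_dispatch_execute` now writes: a sketch, assuming a local output_prefix directory (as in the unit tests above), of how outputs.pb relates to the sibling `<var>_offloaded_metadata.pb` files. In a real execution the Flyte backend resolves `offloaded_metadata.uri` against the metadata bucket; the helper name `resolve_offloaded` is hypothetical and only illustrates the file layout produced by this patch.

    import os

    from flyteidl.core import literals_pb2


    def resolve_offloaded(output_prefix: str) -> dict:
        # outputs.pb holds the literal map; offloaded entries only carry metadata.
        outputs = literals_pb2.LiteralMap()
        with open(os.path.join(output_prefix, "outputs.pb"), "rb") as f:
            outputs.ParseFromString(f.read())

        resolved = {}
        for name, lit in outputs.literals.items():
            if lit.HasField("offloaded_metadata"):
                # The full literal was written next to outputs.pb under a
                # filename of the form f"{name}_offloaded_metadata.pb".
                filename = os.path.basename(lit.offloaded_metadata.uri)
                full = literals_pb2.Literal()
                with open(os.path.join(output_prefix, filename), "rb") as f:
                    full.ParseFromString(f.read())
                resolved[name] = full
            else:
                resolved[name] = lit
        return resolved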
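
The hash-override behaviour exercised by test_dispatch_execute_offloaded_literals_annotated_hash, in one place: when an output type is annotated with a HashMethod, the offloaded literal reuses that user-supplied hash instead of the pbhash digest, keeping cache keys stable. The task below is an illustrative sketch, not part of the patch.

    import typing

    from flytekit import task
    from flytekit.core.hash import HashMethod


    @task
    def produce(n: int) -> typing.Annotated[int, HashMethod(lambda x: str(x))]:
        # If this output crosses the offloading threshold, the offloaded literal
        # carries hash=str(return_value) rather than a SHA-256 digest of the proto.
        return n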