From 12f9edfaf1fd0b00b6e898ad4aa055252a731b75 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Fri, 18 Oct 2024 17:54:30 -0400 Subject: [PATCH 01/25] wip - Implement offloading of literals Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 33 +++++++++++- .../unit/bin/test_python_entrypoint.py | 51 ++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 71bfaa8708..72f7bb1081 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -53,6 +53,10 @@ from flytekit.tools.module_loader import load_object_from_module +# MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 * 1024 * 1024 +MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 + + def get_version_message(): import flytekit @@ -137,7 +141,34 @@ def _dispatch_execute( logger.warning("Task produces no outputs") output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): - output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs} + # output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs} + offloaded_literals = {} + new_outputs = {} + # Offload literals if they are too large + for k, v in outputs.literals.items(): + assert type(v) == _literal_models.Literal + lit = v.to_flyte_idl() + if lit.ByteSize() > MAX_OFFLOADED_LITERAL_SIZE_BYTES: + logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") + + # TODO: hash calculation + + offloaded_filename = f"{k}_offloaded_metadata.pb" + # Offload the literal to a remote file in the metadata bucket + offloaded_literal = _literal_models.Literal( + offloaded_metadata=_literal_models.LiteralOffloadedMetadata( + uri=offloaded_filename, + size_bytes=lit.ByteSize(), + # TODO: do I have to set the inferred literal type? + ) + ) + new_outputs[k] = offloaded_literal + offloaded_literals[offloaded_filename] = v + else: + new_outputs[k] = v + outputs = _literal_models.LiteralMap(literals=new_outputs) + + output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals} elif isinstance(outputs, _dynamic_job.DynamicJobSpec): output_file_dict = {_constants.FUTURES_FILE_NAME: outputs} else: diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index fea5706cd3..8334c1de50 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -18,7 +18,7 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions -from flytekit.exceptions.scopes import system_entry_point +from flytekit.exceptions.scopes import system_entry_point, user_entry_point from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models @@ -453,3 +453,52 @@ def test_get_traceback_str(): expected_error_re = re.compile(expected_error_pattern) print(traceback_str) # helpful for debugging assert expected_error_re.match(traceback_str) is not None + +# Write a unit test to exercise the offloading of literals in entrypoint.py +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_dispatch_execute_offloaded_literals(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + # Just leave these here, mock them out so nothing happens + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + @task + def t1(a: typing.List[int]) -> typing.List[str]: + return [f"string is: {x}" for x in a] + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + xs: typing.List[int] = [5]*1_000 + input_literal_map = _literal_models.LiteralMap( + { + "a": _literal_models.Literal( + collection=_literal_models.LiteralCollection( + literals=[ + _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)), + ) for x in xs + ] + ) + ) + } + ) + mock_load_proto.return_value = input_literal_map.to_flyte_idl() + + files = OrderedDict() + mock_write_to_file.side_effect = get_output_collector(files) + # See comment in test_dispatch_execute_ignore for why we need to decorate + user_entry_point(_dispatch_execute)(ctx, lambda: t1, "inputs path", "outputs prefix") + assert len(files) == 2 + + k = list(files.keys())[0] + assert "outputs.pb" in k + + # v = list(files.values())[0] + # lm = _literal_models.LiteralMap.from_flyte_idl(v) + # assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5" From 5fefa850b099477f9f8dc02a829468382cdc43a0 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Fri, 18 Oct 2024 18:22:44 -0400 Subject: [PATCH 02/25] Fix use of metadata bucket prefix Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 72f7bb1081..46da41b979 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -12,7 +12,7 @@ import traceback import warnings from sys import exit -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional import click from flyteidl.core import literals_pb2 as _literals_pb2 @@ -52,7 +52,6 @@ from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module - # MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 * 1024 * 1024 MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 @@ -141,12 +140,11 @@ def _dispatch_execute( logger.warning("Task produces no outputs") output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): - # output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs} - offloaded_literals = {} + offloaded_literals: Dict[str, _literal_models.Literal] = {} new_outputs = {} - # Offload literals if they are too large + + # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): - assert type(v) == _literal_models.Literal lit = v.to_flyte_idl() if lit.ByteSize() > MAX_OFFLOADED_LITERAL_SIZE_BYTES: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") @@ -154,9 +152,10 @@ def _dispatch_execute( # TODO: hash calculation offloaded_filename = f"{k}_offloaded_metadata.pb" - # Offload the literal to a remote file in the metadata bucket + offloaded_literal = _literal_models.Literal( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( + uri=f"{ctx.user_space_params.output_metadata_prefix}/{offloaded_filename}", uri=offloaded_filename, size_bytes=lit.ByteSize(), # TODO: do I have to set the inferred literal type? From 1668fdaf6e3ff4132b56e2e3ff1c6ebfa2e25d67 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Fri, 18 Oct 2024 18:26:05 -0400 Subject: [PATCH 03/25] Fix repeated use of uri Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 46da41b979..37c7d51968 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -156,7 +156,6 @@ def _dispatch_execute( offloaded_literal = _literal_models.Literal( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( uri=f"{ctx.user_space_params.output_metadata_prefix}/{offloaded_filename}", - uri=offloaded_filename, size_bytes=lit.ByteSize(), # TODO: do I have to set the inferred literal type? ) From 853df625229a09a1df0de16976d7034b337e06a8 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Fri, 18 Oct 2024 18:32:15 -0400 Subject: [PATCH 04/25] Add temporary representation for offloaded literal Signed-off-by: Eduardo Apolinario --- flytekit/interaction/string_literals.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flytekit/interaction/string_literals.py b/flytekit/interaction/string_literals.py index 0bfb3c866a..7366901c3e 100644 --- a/flytekit/interaction/string_literals.py +++ b/flytekit/interaction/string_literals.py @@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any: return [literal_string_repr(i) for i in lit.collection.literals] if lit.map: return {k: literal_string_repr(v) for k, v in lit.map.literals.items()} + if lit.offloaded_metadata: + # TODO: load literal from offloaded literal + return f"Offloaded literal metadata: {lit.offloaded_metadata}" raise ValueError(f"Unknown literal type {lit}") From 5e53a1b76769c624dd8b119b14e9936ec734fcca Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 21 Oct 2024 22:35:27 -0400 Subject: [PATCH 05/25] Add one unit test Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 8 +- flytekit/interaction/string_literals.py | 2 +- .../unit/bin/test_python_entrypoint.py | 88 +++++++++++++------ 3 files changed, 68 insertions(+), 30 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 37c7d51968..d7734cd5a5 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -143,10 +143,16 @@ def _dispatch_execute( offloaded_literals: Dict[str, _literal_models.Literal] = {} new_outputs = {} + min_offloaded_size = int(os.environ.get("FK_L_MIN_SIZE_MB", "10")) * 1024 * 1024 + max_offloaded_size = int(os.environ.get("FK_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 + # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): lit = v.to_flyte_idl() - if lit.ByteSize() > MAX_OFFLOADED_LITERAL_SIZE_BYTES: + if lit.ByteSize() >= max_offloaded_size: + raise ValueError(f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes") + + if lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") # TODO: hash calculation diff --git a/flytekit/interaction/string_literals.py b/flytekit/interaction/string_literals.py index 7366901c3e..6f70488981 100644 --- a/flytekit/interaction/string_literals.py +++ b/flytekit/interaction/string_literals.py @@ -62,7 +62,7 @@ def literal_string_repr(lit: Literal) -> typing.Any: if lit.map: return {k: literal_string_repr(v) for k, v in lit.map.literals.items()} if lit.offloaded_metadata: - # TODO: load literal from offloaded literal + # TODO: load literal from offloaded literal? return f"Offloaded literal metadata: {lit.offloaded_metadata}" raise ValueError(f"Unknown literal type {lit}") diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 8334c1de50..74cbc7df30 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,3 +1,4 @@ +from datetime import datetime import os import re import textwrap @@ -8,9 +9,13 @@ import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument +from flyteidl.core import literals_pb2 +from flyteidl.core.literals_pb2 import Literal, LiteralCollection, Scalar, Primitive from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution, get_traceback_str from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core import mock_stats +from flytekit.models.core import identifier as id_models from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.dynamic_workflow_task import dynamic @@ -21,8 +26,9 @@ from flytekit.exceptions.scopes import system_entry_point, user_entry_point from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.models import literals as _literal_models -from flytekit.models.core import errors as error_models +from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models +from flytekit.core.utils import write_proto_to_file @mock.patch("flytekit.core.utils.load_proto_from_file") @@ -454,27 +460,32 @@ def test_get_traceback_str(): print(traceback_str) # helpful for debugging assert expected_error_re.match(traceback_str) is not None -# Write a unit test to exercise the offloading of literals in entrypoint.py -@mock.patch("flytekit.core.utils.load_proto_from_file") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.core.utils.write_proto_to_file") -def test_dispatch_execute_offloaded_literals(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): - # Just leave these here, mock them out so nothing happens - mock_get_data.return_value = True - mock_upload_dir.return_value = True - +def test_dispatch_execute_offloaded_literals(tmp_path_factory): @task def t1(a: typing.List[int]) -> typing.List[str]: return [f"string is: {x}" for x in a] + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) - ) + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, + user_space_params=context_manager.ExecutionParameters( + execution_date=datetime.now(), + tmp_dir="/tmp", + stats=mock_stats.MockStats(), + logging=None, + raw_output_prefix="", + output_metadata_prefix=str(outputs_path.absolute()), + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + ), + ), + ), ) as ctx: - xs: typing.List[int] = [5]*1_000 + xs: typing.List[int] = [1, 2, 3] input_literal_map = _literal_models.LiteralMap( { "a": _literal_models.Literal( @@ -488,17 +499,38 @@ def t1(a: typing.List[int]) -> typing.List[str]: ) } ) - mock_load_proto.return_value = input_literal_map.to_flyte_idl() - - files = OrderedDict() - mock_write_to_file.side_effect = get_output_collector(files) - # See comment in test_dispatch_execute_ignore for why we need to decorate - user_entry_point(_dispatch_execute)(ctx, lambda: t1, "inputs path", "outputs prefix") - assert len(files) == 2 - - k = list(files.keys())[0] - assert "outputs.pb" in k - # v = list(files.values())[0] - # lm = _literal_models.LiteralMap.from_flyte_idl(v) - # assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5" + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + for ff in os.listdir(outputs_path): + if ff == "outputs.pb": + with open(outputs_path/ff, "rb") as f: + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + elif ff == "o0_offloaded_metadata.pb": + with open(outputs_path/ff, "rb") as f: + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 1")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 2")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 3")), + ), + ] + ) + ) From 177368d2ef46bbfa2cd3590b773c6890a46d0eed Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 21 Oct 2024 22:49:53 -0400 Subject: [PATCH 06/25] Add another test Signed-off-by: Eduardo Apolinario --- .../unit/bin/test_python_entrypoint.py | 92 ++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 74cbc7df30..0310cae315 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -460,6 +460,7 @@ def test_get_traceback_str(): print(traceback_str) # helpful for debugging assert expected_error_re.match(traceback_str) is not None + def test_dispatch_execute_offloaded_literals(tmp_path_factory): @task def t1(a: typing.List[int]) -> typing.List[str]: @@ -503,7 +504,9 @@ def t1(a: typing.List[int]) -> typing.List[str]: write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}): - _dispatch_execute(ctx, t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) for ff in os.listdir(outputs_path): if ff == "outputs.pb": @@ -534,3 +537,90 @@ def t1(a: typing.List[int]) -> typing.List[str]: ] ) ) + + + +def test_dispatch_execute_offloaded_literals_only_o1_is_offloaded(tmp_path_factory): + @task + def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: + return sum(xs), [f"string is: {x}" for x in xs] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, + user_space_params=context_manager.ExecutionParameters( + execution_date=datetime.now(), + tmp_dir="/tmp", + stats=mock_stats.MockStats(), + logging=None, + raw_output_prefix="", + output_metadata_prefix=str(outputs_path.absolute()), + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + ), + ), + ), + ) as ctx: + xs: typing.List[int] = [1, 2, 3, 4] + input_literal_map = _literal_models.LiteralMap( + { + "xs": _literal_models.Literal( + collection=_literal_models.LiteralCollection( + literals=[ + _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)), + ) for x in xs + ] + ) + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 2 + assert "o0" in lit.literals + assert lit.literals["o1"].offloaded_metadata is not None + assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 + assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + scalar=Scalar(primitive=Primitive(integer=10)), + ) + elif ff == "o1_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit == Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 1")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 2")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 3")), + ), + Literal( + scalar=Scalar(primitive=Primitive(string_value="string is: 4")), + ), + ] + ) + ) From 5fc2e84db2050352d7f2a2107bcb56b7ebf5e5d7 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 21 Oct 2024 22:56:55 -0400 Subject: [PATCH 07/25] Stylistic changes to the two tests Signed-off-by: Eduardo Apolinario --- .../unit/bin/test_python_entrypoint.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 0310cae315..2693c25488 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -509,8 +509,8 @@ def t1(a: typing.List[int]) -> typing.List[str]: assert "error.pb" not in os.listdir(outputs_path) for ff in os.listdir(outputs_path): - if ff == "outputs.pb": - with open(outputs_path/ff, "rb") as f: + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": lit = literals_pb2.LiteralMap() lit.ParseFromString(f.read()) assert len(lit.literals) == 1 @@ -518,8 +518,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: assert lit.literals["o0"].offloaded_metadata is not None assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") - elif ff == "o0_offloaded_metadata.pb": - with open(outputs_path/ff, "rb") as f: + elif ff == "o0_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) assert lit == Literal( @@ -537,10 +536,11 @@ def t1(a: typing.List[int]) -> typing.List[str]: ] ) ) + else: + assert False, f"Unexpected file {ff}" - -def test_dispatch_execute_offloaded_literals_only_o1_is_offloaded(tmp_path_factory): +def test_dispatch_execute_offloaded_literals_two_outputs_offloaded(tmp_path_factory): @task def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: return sum(xs), [f"string is: {x}" for x in xs] @@ -594,6 +594,10 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: lit.ParseFromString(f.read()) assert len(lit.literals) == 2 assert "o0" in lit.literals + assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].offloaded_metadata.size_bytes == 6 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert "o1" in lit.literals assert lit.literals["o1"].offloaded_metadata is not None assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") @@ -624,3 +628,5 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: ] ) ) + else: + assert False, f"Unexpected file {ff}" From db48d18640364e44fa9e0781e4c6a9bc01bd76f1 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 28 Oct 2024 14:47:07 -0400 Subject: [PATCH 08/25] Add test for min offloading threshold set to 1MB Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 4 +- flytekit/models/literals.py | 2 +- .../unit/bin/test_python_entrypoint.py | 77 +++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index d7734cd5a5..3c54a72d9a 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -150,7 +150,9 @@ def _dispatch_execute( for k, v in outputs.literals.items(): lit = v.to_flyte_idl() if lit.ByteSize() >= max_offloaded_size: - raise ValueError(f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes") + raise ValueError( + f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes" + ) if lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index f433c2fad1..d65ebfafae 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -991,7 +991,7 @@ def to_flyte_idl(self): map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, metadata=self.metadata, - offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None, + offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None, ) @classmethod diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 2693c25488..821d0e77cc 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import datetime import os import re @@ -630,3 +631,79 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: ) else: assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_literals_two_outputs_only_second_one_offloaded(tmp_path_factory): + @dataclass + class DC: + a: typing.List[int] + b: typing.List[str] + + @task + def t1(n: int) -> typing.Tuple[int, DC]: + return n, DC(a=list(range(n)), b=[f"string is: {x}" for x in range(n)]) + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, + user_space_params=context_manager.ExecutionParameters( + execution_date=datetime.now(), + tmp_dir="/tmp", + stats=mock_stats.MockStats(), + logging=None, + raw_output_prefix="", + output_metadata_prefix=str(outputs_path.absolute()), + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + ), + ), + ), + ) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=56_000)), + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Notice how the threshold is set to 1MB + with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "1"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + # o0 is not offloaded + assert "o0_offloaded_metadata.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 2 + + # o0 is not offloaded + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") is False + + # o1 is offloaded + assert "o1" in lit.literals + assert lit.literals["o1"].HasField("offloaded_metadata") is True + assert lit.literals["o1"].offloaded_metadata.size_bytes == 1108538 + assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + elif ff == "o1_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + # Load the dataclass from the proto + transformer = TypeEngine.get_transformer(DC) + dc = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), DC) + assert dc.a == list(range(56_000)) + else: + assert False, f"Unexpected file {ff}" From 6884ee0238f955927f36ec3b77f84b771790cf21 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 28 Oct 2024 15:06:34 -0400 Subject: [PATCH 09/25] Pick a unique engine-dir for tests Signed-off-by: Eduardo Apolinario --- tests/flytekit/unit/bin/test_python_entrypoint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 821d0e77cc..28abf70a2f 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -474,6 +474,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( ctx.execution_state.with_params( + engine_dir=tmp_path_factory.mktemp("engine_dir"), mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, user_space_params=context_manager.ExecutionParameters( execution_date=datetime.now(), @@ -553,6 +554,7 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( ctx.execution_state.with_params( + engine_dir=tmp_path_factory.mktemp("engine_dir"), mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, user_space_params=context_manager.ExecutionParameters( execution_date=datetime.now(), @@ -650,6 +652,7 @@ def t1(n: int) -> typing.Tuple[int, DC]: with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( ctx.execution_state.with_params( + engine_dir=tmp_path_factory.mktemp("engine_dir"), mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, user_space_params=context_manager.ExecutionParameters( execution_date=datetime.now(), From 5a6423ca6f30d5dc203e7d40cf6d84fed09e2f0b Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 28 Oct 2024 15:26:13 -0400 Subject: [PATCH 10/25] s/new_outputs/literal_map_copy/ Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 3c54a72d9a..a22aac47b5 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -141,7 +141,7 @@ def _dispatch_execute( output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): offloaded_literals: Dict[str, _literal_models.Literal] = {} - new_outputs = {} + literal_map_copy = {} min_offloaded_size = int(os.environ.get("FK_L_MIN_SIZE_MB", "10")) * 1024 * 1024 max_offloaded_size = int(os.environ.get("FK_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 @@ -168,11 +168,11 @@ def _dispatch_execute( # TODO: do I have to set the inferred literal type? ) ) - new_outputs[k] = offloaded_literal + literal_map_copy[k] = offloaded_literal offloaded_literals[offloaded_filename] = v else: - new_outputs[k] = v - outputs = _literal_models.LiteralMap(literals=new_outputs) + literal_map_copy[k] = v + outputs = _literal_models.LiteralMap(literals=literal_map_copy) output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals} elif isinstance(outputs, _dynamic_job.DynamicJobSpec): From dbfea9339c91eaec8a76c48eb6ef7e9bcd4686a8 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 28 Oct 2024 15:28:48 -0400 Subject: [PATCH 11/25] Remove unused constant Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index a22aac47b5..d5d5bd26bf 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -52,9 +52,6 @@ from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module -# MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 * 1024 * 1024 -MAX_OFFLOADED_LITERAL_SIZE_BYTES = 10 - def get_version_message(): import flytekit From 25908d95927fa506a76d4aec361dbe48a0e23afa Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 28 Oct 2024 22:30:59 -0400 Subject: [PATCH 12/25] Use output_prefix in definition of offloaded literals Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index d5d5bd26bf..1cbf326b08 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -160,7 +160,7 @@ def _dispatch_execute( offloaded_literal = _literal_models.Literal( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( - uri=f"{ctx.user_space_params.output_metadata_prefix}/{offloaded_filename}", + uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), # TODO: do I have to set the inferred literal type? ) From e827693a2f1823c7e9ef0db852c7f7a77c4f0f07 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Wed, 6 Nov 2024 18:51:08 -0500 Subject: [PATCH 13/25] Add initial version of pbhash.py Signed-off-by: Eduardo Apolinario --- flytekit/utils/pbhash.py | 35 ++++++++++++ tests/flytekit/unit/utils/test_pbhash.py | 69 ++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 flytekit/utils/pbhash.py create mode 100644 tests/flytekit/unit/utils/test_pbhash.py diff --git a/flytekit/utils/pbhash.py b/flytekit/utils/pbhash.py new file mode 100644 index 0000000000..723ab62265 --- /dev/null +++ b/flytekit/utils/pbhash.py @@ -0,0 +1,35 @@ +# This is a module that provides hashing utilities for Protobuf objects. +import base64 +import hashlib +import json +from google.protobuf import json_format +from google.protobuf.message import Message + +def compute_hash(pb: Message) -> bytes: + """ + Computes a deterministic hash in bytes for the Protobuf object. + """ + try: + pb_dict = json_format.MessageToDict(pb) + # json.dumps with sorted keys to ensure stability + stable_json_str = json.dumps(pb_dict, sort_keys=True, separators=(',', ':')) # separators to ensure no extra spaces + except Exception as e: + raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") + + try: + # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here, + # assuming it provides a consistent hash output. + hash_obj = hashlib.sha256(stable_json_str.encode('utf-8')) + except Exception as e: + raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") + + # The digest is guaranteed to be 32 bytes long + return hash_obj.digest() + + +def compute_hash_string(pb: Message) -> str: + """ + Computes a deterministic hash in base64 encoded string for the Protobuf object + """ + hash_bytes = compute_hash(pb) + return base64.b64encode(hash_bytes).decode('utf-8') diff --git a/tests/flytekit/unit/utils/test_pbhash.py b/tests/flytekit/unit/utils/test_pbhash.py new file mode 100644 index 0000000000..c00cfdc89b --- /dev/null +++ b/tests/flytekit/unit/utils/test_pbhash.py @@ -0,0 +1,69 @@ +import pytest +from flyteidl.core.literals_pb2 import Blob, BlobMetadata, Literal, LiteralCollection, Scalar +from flyteidl.core.types_pb2 import BlobType + +from flytekit.utils.pbhash import compute_hash_string + + +@pytest.mark.parametrize( + "lit, expected_hash", + [ + ( + Literal( + scalar=Scalar( + blob=Blob( + uri="s3://my-bucket/my-key", + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + ), + ), + ), + "KdNNbLBYoamXYLz8SBuJd/kVDPxO4gVGdNQl61qeTfA=", + ), + ( + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="s3://my-bucket/my-key", + ), + ), + ), + "KdNNbLBYoamXYLz8SBuJd/kVDPxO4gVGdNQl61qeTfA=", + ), + ( + # Literal collection + Literal( + collection=LiteralCollection( + literals=[ + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="s3://my-bucket/my-key", + ), + ), + ), + ], + ), + ), + "RauoCNnZfCSHgcmMKVugozLAcssq/mWdMjbGanRJufI=", + ) + ], +) +def test_lit(lit, expected_hash): + assert compute_hash_string(lit) == expected_hash From e0e20164a63d0b9c7a867c072ea862bb470036bc Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Wed, 6 Nov 2024 21:44:08 -0500 Subject: [PATCH 14/25] Add tests to verify that overriding the hash is carried over to offloaded literals Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 10 +- .../unit/bin/test_python_entrypoint.py | 143 +++++++++++------- 2 files changed, 93 insertions(+), 60 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 1cbf326b08..6448e04e99 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -51,6 +51,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module +from flytekit.utils.pbhash import compute_hash_string def get_version_message(): @@ -140,8 +141,8 @@ def _dispatch_execute( offloaded_literals: Dict[str, _literal_models.Literal] = {} literal_map_copy = {} - min_offloaded_size = int(os.environ.get("FK_L_MIN_SIZE_MB", "10")) * 1024 * 1024 - max_offloaded_size = int(os.environ.get("FK_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 + min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024 + max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): @@ -154,8 +155,6 @@ def _dispatch_execute( if lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") - # TODO: hash calculation - offloaded_filename = f"{k}_offloaded_metadata.pb" offloaded_literal = _literal_models.Literal( @@ -163,7 +162,8 @@ def _dispatch_execute( uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), # TODO: do I have to set the inferred literal type? - ) + ), + hash=v.hash if v.hash is not None else compute_hash_string(lit), ) literal_map_copy[k] = offloaded_literal offloaded_literals[offloaded_filename] = v diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 28abf70a2f..aa9b36fd74 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -16,6 +16,7 @@ from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution, get_traceback_str from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import mock_stats +from flytekit.core.hash import HashMethod from flytekit.models.core import identifier as id_models from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs @@ -462,6 +463,30 @@ def test_get_traceback_str(): assert expected_error_re.match(traceback_str) is not None +def get_flyte_context(tmp_path_factory, outputs_path): + """ + This is a helper function to create a flyte context with the right parameters for testing offloading of literals. + """ + ctx = context_manager.FlyteContext.current_context() + return context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + engine_dir=tmp_path_factory.mktemp("engine_dir"), + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, + user_space_params=context_manager.ExecutionParameters( + execution_date=datetime.now(), + tmp_dir="/tmp", + stats=mock_stats.MockStats(), + logging=None, + raw_output_prefix="", + output_metadata_prefix=str(outputs_path.absolute()), + execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), + ), + ), + ), + ) + + def test_dispatch_execute_offloaded_literals(tmp_path_factory): @task def t1(a: typing.List[int]) -> typing.List[str]: @@ -471,23 +496,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: outputs_path = tmp_path_factory.mktemp("outputs") ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params( - engine_dir=tmp_path_factory.mktemp("engine_dir"), - mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, - user_space_params=context_manager.ExecutionParameters( - execution_date=datetime.now(), - tmp_dir="/tmp", - stats=mock_stats.MockStats(), - logging=None, - raw_output_prefix="", - output_metadata_prefix=str(outputs_path.absolute()), - execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), - ), - ), - ), - ) as ctx: + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: xs: typing.List[int] = [1, 2, 3] input_literal_map = _literal_models.LiteralMap( { @@ -505,7 +514,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) - with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}): + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) assert "error.pb" not in os.listdir(outputs_path) @@ -551,23 +560,7 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: outputs_path = tmp_path_factory.mktemp("outputs") ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params( - engine_dir=tmp_path_factory.mktemp("engine_dir"), - mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, - user_space_params=context_manager.ExecutionParameters( - execution_date=datetime.now(), - tmp_dir="/tmp", - stats=mock_stats.MockStats(), - logging=None, - raw_output_prefix="", - output_metadata_prefix=str(outputs_path.absolute()), - execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), - ), - ), - ), - ) as ctx: + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: xs: typing.List[int] = [1, 2, 3, 4] input_literal_map = _literal_models.LiteralMap( { @@ -585,7 +578,7 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) - with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}): + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) assert "error.pb" not in os.listdir(outputs_path) @@ -648,24 +641,7 @@ def t1(n: int) -> typing.Tuple[int, DC]: inputs_path = tmp_path_factory.mktemp("inputs") outputs_path = tmp_path_factory.mktemp("outputs") - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params( - engine_dir=tmp_path_factory.mktemp("engine_dir"), - mode=context_manager.ExecutionState.Mode.TASK_EXECUTION, - user_space_params=context_manager.ExecutionParameters( - execution_date=datetime.now(), - tmp_dir="/tmp", - stats=mock_stats.MockStats(), - logging=None, - raw_output_prefix="", - output_metadata_prefix=str(outputs_path.absolute()), - execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"), - ), - ), - ), - ) as ctx: + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: input_literal_map = _literal_models.LiteralMap( { "n": _literal_models.Literal( @@ -677,7 +653,7 @@ def t1(n: int) -> typing.Tuple[int, DC]: write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) # Notice how the threshold is set to 1MB - with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "1"}): + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "1"}): _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) assert "error.pb" not in os.listdir(outputs_path) @@ -695,18 +671,75 @@ def t1(n: int) -> typing.Tuple[int, DC]: # o0 is not offloaded assert "o0" in lit.literals assert lit.literals["o0"].HasField("offloaded_metadata") is False + assert lit.literals["o0"].hash == "" # o1 is offloaded assert "o1" in lit.literals assert lit.literals["o1"].HasField("offloaded_metadata") is True assert lit.literals["o1"].offloaded_metadata.size_bytes == 1108538 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + assert lit.literals["o1"].hash == "VS9bthLslGa8tjuVBCcmO3UdGHrkpyOBXzJlmY47fw8=" elif ff == "o1_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) + assert lit.hash == "" # Load the dataclass from the proto transformer = TypeEngine.get_transformer(DC) dc = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), DC) assert dc.a == list(range(56_000)) else: assert False, f"Unexpected file {ff}" + + + +def test_dispatch_execute_offloaded_literals_annotated_hash(tmp_path_factory): + class A: + def __init__(self, a: int): + self.a = a + + @task + def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]: + return A(a=n) + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=1234)), + ) + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # All literals should be offloaded + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + + # o1 is offloaded + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") is True + assert lit.literals["o0"].offloaded_metadata.size_bytes > 0 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].hash == "1234" + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + assert lit.hash == "1234" + transformer = TypeEngine.get_transformer(A) + a = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), A) + assert a.a == 1234 + else: + assert False, f"Unexpected file {ff}" From b2844920a632981481056b740c5df1d26a9adc3a Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 7 Nov 2024 07:31:03 -0500 Subject: [PATCH 15/25] Add a few more tests Signed-off-by: Eduardo Apolinario --- tests/flytekit/unit/utils/test_pbhash.py | 81 +++++++++++++++++++++++- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/tests/flytekit/unit/utils/test_pbhash.py b/tests/flytekit/unit/utils/test_pbhash.py index c00cfdc89b..f608271177 100644 --- a/tests/flytekit/unit/utils/test_pbhash.py +++ b/tests/flytekit/unit/utils/test_pbhash.py @@ -1,13 +1,35 @@ +import tempfile +import mock import pytest -from flyteidl.core.literals_pb2 import Blob, BlobMetadata, Literal, LiteralCollection, Scalar -from flyteidl.core.types_pb2 import BlobType +import typing +from dataclasses import dataclass, field +from google.protobuf import json_format +from google.protobuf import struct_pb2 +from dataclasses_json import DataClassJsonMixin +from flyteidl.core.literals_pb2 import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar +from flyteidl.core.types_pb2 import BlobType +from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.type_engine import DataclassTransformer +from flytekit.types.file.file import FlyteFile from flytekit.utils.pbhash import compute_hash_string @pytest.mark.parametrize( "lit, expected_hash", [ + ( + Literal(scalar=Scalar(primitive=Primitive(integer=1))), + "aJB6fp0kDrfAZt22e/IFnT8IJIlobjxcweiZA8I7/dA=", + ), + ( + Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(primitive=Primitive(integer=1)))])), + "qN7iA0GnbLzFGcHB7y09lbxgx+9cTIViSlyL9/kCSC0=", + ), + ( + Literal(map=LiteralMap(literals={"a": Literal(scalar=Scalar(primitive=Primitive(integer=1)))})), + "JhrkdOQ+xzPVNYiKzD5sHhZprQB5Nq1GsYUVbmLAswU=", + ), ( Literal( scalar=Scalar( @@ -65,5 +87,58 @@ ) ], ) -def test_lit(lit, expected_hash): +def test_direct_literals(lit, expected_hash): assert compute_hash_string(lit) == expected_hash + + +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") +def test_dataclass_literals(mock_put_data): + + @dataclass + class A(DataClassJsonMixin): + a: int + + @dataclass + class TestFileStruct(DataClassJsonMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A] = None + i_prime: typing.Optional[A] = field(default_factory=lambda: A(a=99)) + + remote_path = "s3://tmp/file" + mock_put_data.return_value = remote_path + + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + + assert compute_hash_string(lv.to_flyte_idl()) == "Hp/cWul3sBI5r8XKdVzAlvNBJ4OSX9L2d/SADI8+YOY=" From b579b835ec8e77a0a9c3e2259635cedc0deaedfe Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 7 Nov 2024 07:31:14 -0500 Subject: [PATCH 16/25] Always import ParamSpec from `typing_extensions` Signed-off-by: Eduardo Apolinario --- flytekit/core/task.py | 5 +---- flytekit/core/workflow.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 745f452a83..8cb7e00f11 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -7,10 +7,7 @@ from flytekit.core.utils import str2bool -try: - from typing import ParamSpec -except ImportError: - from typing_extensions import ParamSpec # type: ignore +from typing_extensions import ParamSpec # type: ignore from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index de0f620e96..5c6f98e518 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -10,10 +10,7 @@ from typing_inspect import is_optional_type -try: - from typing import ParamSpec -except ImportError: - from typing_extensions import ParamSpec # type: ignore +from typing_extensions import ParamSpec # type: ignore from flytekit.core import constants as _common_constants from flytekit.core import launch_plan as _annotated_launch_plan From 8c2336ee0734beb802530de86641748146a73e7c Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 7 Nov 2024 07:41:06 -0500 Subject: [PATCH 17/25] Fix lint warnings Signed-off-by: Eduardo Apolinario --- flytekit/core/task.py | 3 +-- flytekit/core/workflow.py | 3 +-- flytekit/utils/pbhash.py | 10 +++++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 8cb7e00f11..6c6ed6a2a3 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,8 +5,6 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload -from flytekit.core.utils import str2bool - from typing_extensions import ParamSpec # type: ignore from flytekit.core import launch_plan as _annotated_launchplan @@ -17,6 +15,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.core.utils import str2bool from flytekit.deck import DeckField from flytekit.extras.accelerators import BaseAccelerator from flytekit.image_spec.image_spec import ImageSpec diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 5c6f98e518..a07b7f09de 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -8,9 +8,8 @@ from functools import update_wrapper from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload -from typing_inspect import is_optional_type - from typing_extensions import ParamSpec # type: ignore +from typing_inspect import is_optional_type from flytekit.core import constants as _common_constants from flytekit.core import launch_plan as _annotated_launch_plan diff --git a/flytekit/utils/pbhash.py b/flytekit/utils/pbhash.py index 723ab62265..ae4a364d12 100644 --- a/flytekit/utils/pbhash.py +++ b/flytekit/utils/pbhash.py @@ -2,9 +2,11 @@ import base64 import hashlib import json + from google.protobuf import json_format from google.protobuf.message import Message + def compute_hash(pb: Message) -> bytes: """ Computes a deterministic hash in bytes for the Protobuf object. @@ -12,14 +14,16 @@ def compute_hash(pb: Message) -> bytes: try: pb_dict = json_format.MessageToDict(pb) # json.dumps with sorted keys to ensure stability - stable_json_str = json.dumps(pb_dict, sort_keys=True, separators=(',', ':')) # separators to ensure no extra spaces + stable_json_str = json.dumps( + pb_dict, sort_keys=True, separators=(",", ":") + ) # separators to ensure no extra spaces except Exception as e: raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") try: # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here, # assuming it provides a consistent hash output. - hash_obj = hashlib.sha256(stable_json_str.encode('utf-8')) + hash_obj = hashlib.sha256(stable_json_str.encode("utf-8")) except Exception as e: raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") @@ -32,4 +36,4 @@ def compute_hash_string(pb: Message) -> str: Computes a deterministic hash in base64 encoded string for the Protobuf object """ hash_bytes = compute_hash(pb) - return base64.b64encode(hash_bytes).decode('utf-8') + return base64.b64encode(hash_bytes).decode("utf-8") From 37b2bb4ca421d92a72807eb7b9214a8bd83f1ac3 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 7 Nov 2024 20:52:56 -0500 Subject: [PATCH 18/25] Set inferred_type using the task type interface Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 2 +- tests/flytekit/unit/bin/test_python_entrypoint.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 013a5d1815..4d01b5d87f 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -161,7 +161,7 @@ def _dispatch_execute( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), - # TODO: do I have to set the inferred literal type? + inferred_type=task_def.interface.outputs[k].type, ), hash=v.hash if v.hash is not None else compute_hash_string(lit), ) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index aa9b36fd74..7e58799caf 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -23,7 +23,7 @@ from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, DataclassTransformer from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.scopes import system_entry_point, user_entry_point from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException @@ -31,6 +31,7 @@ from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models from flytekit.core.utils import write_proto_to_file +from flytekit.models.types import LiteralType, SimpleType @mock.patch("flytekit.core.utils.load_proto_from_file") @@ -529,6 +530,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: assert lit.literals["o0"].offloaded_metadata is not None assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() elif ff == "o0_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) @@ -593,10 +595,12 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: assert lit.literals["o0"].offloaded_metadata is not None assert lit.literals["o0"].offloaded_metadata.size_bytes == 6 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() assert "o1" in lit.literals assert lit.literals["o1"].offloaded_metadata is not None assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") + assert lit.literals["o1"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() elif ff == "o0_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) @@ -679,6 +683,7 @@ def t1(n: int) -> typing.Tuple[int, DC]: assert lit.literals["o1"].offloaded_metadata.size_bytes == 1108538 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") assert lit.literals["o1"].hash == "VS9bthLslGa8tjuVBCcmO3UdGHrkpyOBXzJlmY47fw8=" + assert lit.literals["o1"].offloaded_metadata.inferred_type == DataclassTransformer().get_literal_type(DC).to_flyte_idl() elif ff == "o1_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) @@ -728,12 +733,13 @@ def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]: lit.ParseFromString(f.read()) assert len(lit.literals) == 1 - # o1 is offloaded + # o0 is offloaded assert "o0" in lit.literals assert lit.literals["o0"].HasField("offloaded_metadata") is True assert lit.literals["o0"].offloaded_metadata.size_bytes > 0 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") assert lit.literals["o0"].hash == "1234" + assert lit.literals["o0"].offloaded_metadata.inferred_type == t1.interface.outputs["o0"].type.to_flyte_idl() elif ff == "o0_offloaded_metadata.pb": lit = literals_pb2.Literal() lit.ParseFromString(f.read()) From e25496d8c5ded4def5b3b8a8138b4825d1ed0fad Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 7 Nov 2024 22:30:35 -0500 Subject: [PATCH 19/25] Add comment about offloaded literals files and how they are uploaded to the metadata bucket Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 4d01b5d87f..b09c15148a 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -155,8 +155,10 @@ def _dispatch_execute( if lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") + # This file will hold the offloaded literal and will be written to the output prefix + # alongside the regular outputs.pb, deck.pb, etc. + # N.B.: by construction `offloaded_filename` is guaranteed to be unique offloaded_filename = f"{k}_offloaded_metadata.pb" - offloaded_literal = _literal_models.Literal( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( uri=f"{output_prefix}/{offloaded_filename}", From 9276e96e5db4c3bfce5a19d3825dded13d70d72a Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Mon, 18 Nov 2024 18:45:07 -0500 Subject: [PATCH 20/25] Add offloading_enabled Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index b09c15148a..0970e2b6ab 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -141,18 +141,26 @@ def _dispatch_execute( offloaded_literals: Dict[str, _literal_models.Literal] = {} literal_map_copy = {} - min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024 - max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 + offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None + min_offloaded_size = -1 + max_offloaded_size = -1 + if offloading_enabled: + min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024 + max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024 # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): + if not offloading_enabled: + literal_map_copy[k] = v + continue + lit = v.to_flyte_idl() - if lit.ByteSize() >= max_offloaded_size: + if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size: raise ValueError( f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes" ) - if lit.ByteSize() >= min_offloaded_size: + if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") # This file will hold the offloaded literal and will be written to the output prefix @@ -169,8 +177,6 @@ def _dispatch_execute( ) literal_map_copy[k] = offloaded_literal offloaded_literals[offloaded_filename] = v - else: - literal_map_copy[k] = v outputs = _literal_models.LiteralMap(literals=literal_map_copy) output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals} From a8bdbca037106fcdfa6a5a643a2b996ac288ea34 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 19 Nov 2024 14:26:02 -0500 Subject: [PATCH 21/25] Add more unit tests including a negative test Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 3 +- .../unit/bin/test_python_entrypoint.py | 99 ++++++++++++++++--- 2 files changed, 89 insertions(+), 13 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0970e2b6ab..90e54ad3b6 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -150,8 +150,9 @@ def _dispatch_execute( # Go over each output and create a separate offloaded in case its size is too large for k, v in outputs.literals.items(): + literal_map_copy[k] = v + if not offloading_enabled: - literal_map_copy[k] = v continue lit = v.to_flyte_idl() diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 7e58799caf..b62bfd8e34 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -501,15 +501,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: xs: typing.List[int] = [1, 2, 3] input_literal_map = _literal_models.LiteralMap( { - "a": _literal_models.Literal( - collection=_literal_models.LiteralCollection( - literals=[ - _literal_models.Literal( - scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)), - ) for x in xs - ] - ) - ) + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), } ) @@ -527,7 +519,7 @@ def t1(a: typing.List[int]) -> typing.List[str]: lit.ParseFromString(f.read()) assert len(lit.literals) == 1 assert "o0" in lit.literals - assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].HasField("offloaded_metadata") == True assert lit.literals["o0"].offloaded_metadata.size_bytes == 62 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() @@ -592,12 +584,12 @@ def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]: lit.ParseFromString(f.read()) assert len(lit.literals) == 2 assert "o0" in lit.literals - assert lit.literals["o0"].offloaded_metadata is not None + assert lit.literals["o0"].HasField("offloaded_metadata") == True assert lit.literals["o0"].offloaded_metadata.size_bytes == 6 assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() assert "o1" in lit.literals - assert lit.literals["o1"].offloaded_metadata is not None + assert lit.literals["o1"].HasField("offloaded_metadata") == True assert lit.literals["o1"].offloaded_metadata.size_bytes == 82 assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb") assert lit.literals["o1"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(simple=SimpleType.STRING)).to_flyte_idl() @@ -749,3 +741,86 @@ def t1(n: int) -> typing.Annotated[A, HashMethod(lambda x: str(x.a))]: assert a.a == 1234 else: assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict(os.environ, {"_F_L_MIN_SIZE_MB": "0"}): + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.size_bytes == 195 + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(collection_type=LiteralType(collection_type=LiteralType(simple=SimpleType.STRING))).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = [[f"string is: {x}" for x in xs] for _ in range(len(xs))] + assert lit == TypeEngine.to_literal(ctx, expected_output, typing.List[typing.List[str]], TypeEngine.to_literal_type(typing.List[typing.List[str]])).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}" From a4fcfabef361bf0eb01c6dd6c68095331e96b489 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 19 Nov 2024 14:56:21 -0500 Subject: [PATCH 22/25] Fix bad merge Signed-off-by: Eduardo Apolinario --- tests/flytekit/unit/bin/test_python_entrypoint.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index ca05a97f8b..d4dc88b9d1 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,7 +1,4 @@ -<<<<<<< HEAD from dataclasses import dataclass -======= ->>>>>>> origin from datetime import datetime import os import re @@ -15,12 +12,9 @@ import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument -<<<<<<< HEAD from flyteidl.core import literals_pb2 from flyteidl.core.literals_pb2 import Literal, LiteralCollection, Scalar, Primitive -======= from google.protobuf.timestamp_pb2 import Timestamp ->>>>>>> origin from flytekit.bin.entrypoint import _dispatch_execute, get_container_error_timestamp, normalize_inputs, setup_execution, get_traceback_str @@ -35,12 +29,8 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine, DataclassTransformer from flytekit.exceptions import user as user_exceptions -<<<<<<< HEAD -from flytekit.exceptions.scopes import system_entry_point, user_entry_point -======= from flytekit.exceptions.base import FlyteException from flytekit.exceptions.scopes import system_entry_point ->>>>>>> origin from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models, execution From b3a1b0ddeed930267f44e63ff1a9f02024aa8580 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 19 Nov 2024 15:52:33 -0500 Subject: [PATCH 23/25] Incorporate feedback. Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ae9a266ffb..ed04335b00 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -180,6 +180,7 @@ def _dispatch_execute( logger.warning("Task produces no outputs") output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})} elif isinstance(outputs, _literal_models.LiteralMap): + # The keys in this map hold the filenames to the offloaded proto literals. offloaded_literals: Dict[str, _literal_models.Literal] = {} literal_map_copy = {} @@ -214,6 +215,7 @@ def _dispatch_execute( offloaded_metadata=_literal_models.LiteralOffloadedMetadata( uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), + # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged inferred_type=task_def.interface.outputs[k].type, ), hash=v.hash if v.hash is not None else compute_hash_string(lit), From 32c5896f3d4c1a1db417d13e6509dad71e02773e Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 21 Nov 2024 17:17:04 -0500 Subject: [PATCH 24/25] Fix image name (unrelated to this PR - just a nice-to-have to decrease flakiness) Signed-off-by: Eduardo Apolinario --- tests/flytekit/unit/core/image_spec/test_image_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index b20030c3a0..aa6808215b 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -114,7 +114,7 @@ def test_build_existing_image_with_force_push(): image_spec = ImageSpec(name="hello", builder="test").force_push() builder = Mock() - builder.build_image.return_value = "new_image_name" + builder.build_image.return_value = "fqn.xyz/new_image_name:v-test" ImageBuildEngine.register("test", builder) ImageBuildEngine.build(image_spec) From 12e194a308f0eacc622d7358ea5698ac741aedaf Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 21 Nov 2024 17:29:42 -0500 Subject: [PATCH 25/25] Add `is_map_task` to `_dispatch_execute` Signed-off-by: Eduardo Apolinario --- flytekit/bin/entrypoint.py | 19 +++- .../unit/bin/test_python_entrypoint.py | 90 +++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ed04335b00..084e8f733b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -136,6 +136,7 @@ def _dispatch_execute( load_task: Callable[[], PythonTask], inputs_path: str, output_prefix: str, + is_map_task: bool = False, ): """ Dispatches execute to PythonTask @@ -145,6 +146,12 @@ def _dispatch_execute( a: [Optional] Record outputs to output_prefix b: OR if IgnoreOutputs is raised, then ignore uploading outputs c: OR if an unhandled exception is retrieved - record it as an errors.pb + + :param ctx: FlyteContext + :param load_task: Callable[[], PythonTask] + :param inputs: Where to read inputs + :param output_prefix: Where to write primitive outputs + :param is_map_task: Whether this task is executing as part of a map task """ error_file_name = _build_error_file_name() worker_name = _get_worker_name() @@ -206,6 +213,14 @@ def _dispatch_execute( if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size: logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket") + inferred_type = task_def.interface.outputs[k].type + + # In the case of map tasks we need to use the type of the collection as inferred type as the task + # typed interface of the offloaded literal. This is done because the map task interface present in + # the task template contains the (correct) type for the entire map task, not the single node execution. + # For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal. + if is_map_task: + inferred_type = inferred_type.collection_type # This file will hold the offloaded literal and will be written to the output prefix # alongside the regular outputs.pb, deck.pb, etc. @@ -216,7 +231,7 @@ def _dispatch_execute( uri=f"{output_prefix}/{offloaded_filename}", size_bytes=lit.ByteSize(), # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged - inferred_type=task_def.interface.outputs[k].type, + inferred_type=inferred_type, ), hash=v.hash if v.hash is not None else compute_hash_string(lit), ) @@ -633,7 +648,7 @@ def load_task(): ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + _dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True) def normalize_inputs( diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index d4dc88b9d1..3955019cd0 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -20,6 +20,7 @@ from flytekit.bin.entrypoint import _dispatch_execute, get_container_error_timestamp, normalize_inputs, setup_execution, get_traceback_str from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import mock_stats +from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.hash import HashMethod from flytekit.models.core import identifier as id_models from flytekit.core import context_manager @@ -893,3 +894,92 @@ def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: assert lit.literals["o0"].HasField("offloaded_metadata") == False else: assert False, f"Unexpected file {ff}" + + + +def test_dispatch_execute_offloaded_map_task(tmp_path_factory): + @task + def t1(n: int) -> int: + return n + 1 + + inputs: typing.List[int] = [1, 2, 3, 4] + for i, v in enumerate(inputs): + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + input_literal_map = _literal_models.LiteralMap( + { + "n": TypeEngine.to_literal(ctx, inputs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + with mock.patch.dict( + os.environ, + { + "_F_L_MIN_SIZE_MB": "0", # Always offload + "BATCH_JOB_ARRAY_INDEX_OFFSET": str(i), + }): + _dispatch_execute(ctx, lambda: ArrayNodeMapTask(python_function_task=t1), str(inputs_path/"inputs.pb"), str(outputs_path.absolute()), is_map_task=True) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == True + assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb") + assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl() + elif ff == "o0_offloaded_metadata.pb": + lit = literals_pb2.Literal() + lit.ParseFromString(f.read()) + expected_output = v + 1 + assert lit == TypeEngine.to_literal(ctx, expected_output, int, TypeEngine.to_literal_type(int)).to_flyte_idl() + else: + assert False, f"Unexpected file {ff}" + + +def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory): + @task + def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: + return [[f"string is: {x}" for x in a] for _ in range(len(a))] + + inputs_path = tmp_path_factory.mktemp("inputs") + outputs_path = tmp_path_factory.mktemp("outputs") + + ctx = context_manager.FlyteContext.current_context() + with get_flyte_context(tmp_path_factory, outputs_path) as ctx: + xs: typing.List[int] = [1, 2, 3] + input_literal_map = _literal_models.LiteralMap( + { + "a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])), + } + ) + + write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb")) + + # Ensure that this is not set by an external source + assert os.environ.get("_F_L_MIN_SIZE_MB") is None + + # Notice how we're setting the env var to None, which disables offloading completely + _dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute())) + + assert "error.pb" not in os.listdir(outputs_path) + + for ff in os.listdir(outputs_path): + with open(outputs_path/ff, "rb") as f: + if ff == "outputs.pb": + lit = literals_pb2.LiteralMap() + lit.ParseFromString(f.read()) + assert len(lit.literals) == 1 + assert "o0" in lit.literals + assert lit.literals["o0"].HasField("offloaded_metadata") == False + else: + assert False, f"Unexpected file {ff}"