Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offload literals #2872

Merged
merged 28 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
12f9edf
wip - Implement offloading of literals
eapolinario Oct 18, 2024
5fefa85
Fix use of metadata bucket prefix
eapolinario Oct 18, 2024
1668fda
Fix repeated use of uri
eapolinario Oct 18, 2024
853df62
Add temporary representation for offloaded literal
eapolinario Oct 18, 2024
5e53a1b
Add one unit test
eapolinario Oct 22, 2024
177368d
Add another test
eapolinario Oct 22, 2024
5fc2e84
Stylistic changes to the two tests
eapolinario Oct 22, 2024
db48d18
Add test for min offloading threshold set to 1MB
eapolinario Oct 28, 2024
6884ee0
Pick a unique engine-dir for tests
eapolinario Oct 28, 2024
5a6423c
s/new_outputs/literal_map_copy/
eapolinario Oct 28, 2024
dbfea93
Remove unused constant
eapolinario Oct 28, 2024
adeed34
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Oct 28, 2024
25908d9
Use output_prefix in definition of offloaded literals
eapolinario Oct 29, 2024
e827693
Add initial version of pbhash.py
eapolinario Nov 6, 2024
e0e2016
Add tests to verify that overriding the hash is carried over to offlo…
eapolinario Nov 7, 2024
b284492
Add a few more tests
eapolinario Nov 7, 2024
b579b83
Always import ParamSpec from `typing_extensions`
eapolinario Nov 7, 2024
8c2336e
Fix lint warnings
eapolinario Nov 7, 2024
c28d537
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 7, 2024
37b2bb4
Set inferred_type using the task type interface
eapolinario Nov 8, 2024
e25496d
Add comment about offloaded literals files and how they are uploaded …
eapolinario Nov 8, 2024
9276e96
Add offloading_enabled
eapolinario Nov 18, 2024
a8bdbca
Add more unit tests including a negative test
eapolinario Nov 19, 2024
fe822b9
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 19, 2024
a4fcfab
Fix bad merge
eapolinario Nov 19, 2024
b3a1b0d
Incorporate feedback.
eapolinario Nov 19, 2024
32c5896
Fix image name (unrelated to this PR - just a nice-to-have to decreas…
eapolinario Nov 21, 2024
12e194a
Add `is_map_task` to `_dispatch_execute`
eapolinario Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import traceback
import warnings
from sys import exit
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional

import click
from flyteidl.core import literals_pb2 as _literals_pb2
Expand Down Expand Up @@ -137,7 +137,41 @@ def _dispatch_execute(
logger.warning("Task produces no outputs")
output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
elif isinstance(outputs, _literal_models.LiteralMap):
output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
offloaded_literals: Dict[str, _literal_models.Literal] = {}
literal_map_copy = {}

min_offloaded_size = int(os.environ.get("FK_L_MIN_SIZE_MB", "10")) * 1024 * 1024
max_offloaded_size = int(os.environ.get("FK_L_MAX_SIZE_MB", "1000")) * 1024 * 1024
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll need to update the injection of environment variables to match those names.

The need for short env vars is described flyteorg/flyte#5665.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok with keeping the current ones .


# Go over each output and create a separate offloaded in case its size is too large
for k, v in outputs.literals.items():
lit = v.to_flyte_idl()
if lit.ByteSize() >= max_offloaded_size:
raise ValueError(
f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes"
)

if lit.ByteSize() >= min_offloaded_size:
logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")

# TODO: hash calculation
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how will this hash interoperate with hashmethod as specified by the user? how does caching work again?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the hash calculation in entrypoint only happens after dispatch_execute (in other words, if a user defined a HashMethod then we can reuse it here). I just pushed an update that implements this idea.


offloaded_filename = f"{k}_offloaded_metadata.pb"

offloaded_literal = _literal_models.Literal(
offloaded_metadata=_literal_models.LiteralOffloadedMetadata(
uri=f"{output_prefix}/{offloaded_filename}",
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
size_bytes=lit.ByteSize(),
# TODO: do I have to set the inferred literal type?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmahindrakar-oss , in what conditions we have to set this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed when we have launcplan execution created by propeller
eg : here
https://demo.hosted.unionai.cloud/console/projects/flytesnacks/domains/development/executions/akzthr8gxbdvpnq6f5lr/nodes

[UserError] failed to launch workflow, caused by: rpc error: code = InvalidArgument desc = invalid input input wrong type. Expected collection_type:{simple:STRING}, but got collection_type:{}

Code link in my comments below https://github.com/flyteorg/flytekit/pull/2872/files#r1819666906

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I commented in about separating those discussions and how we use the task typed interface to set inferred_type with the appropriate literal type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to get rid of LiteralTypeForLiteral in flyteorg/flyte#5909. Still pending a review.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we get rid of the literaltypeforliteral function, and we don't need inferredtype, we should deprecate the field in the idl. don't want to keep it around and have future us wonder if it's necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's separate those two discussions. We might be able to remove LiteralTypeForLiteral, but that shouldn't block us from getting the actual literal type from the typed interface in flytekit and use it here.

Also, if/when we end up removing LiteralTypeForLiteral we should use a different field to store the offloaded literal literal type and not use inferred in the name since it'll be the exact type.

)
)
literal_map_copy[k] = offloaded_literal
offloaded_literals[offloaded_filename] = v
else:
literal_map_copy[k] = v
outputs = _literal_models.LiteralMap(literals=literal_map_copy)

output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals}
elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
else:
Expand Down
3 changes: 3 additions & 0 deletions flytekit/interaction/string_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any:
return [literal_string_repr(i) for i in lit.collection.literals]
if lit.map:
return {k: literal_string_repr(v) for k, v in lit.map.literals.items()}
if lit.offloaded_metadata:
# TODO: load literal from offloaded literal?
return f"Offloaded literal metadata: {lit.offloaded_metadata}"
Comment on lines +64 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an example we can try to see if we need to load the literal here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is relevant for the pyflyte fetch command:

❯ pyflyte --config ~/.flyte/config-sandbox.yaml fetch flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o
Fetching data from flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o...
╭───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ {                                                                                                                                                                                                                                                                                             │
│     'o0': [                                                                                                                                                                                                                                                                                   │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
...
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015'                                                                                                 │
│     ]                                                                                                                                                                                                                                                                                         │
│ }                                                                                                                                                                                                                                                                                             │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

raise ValueError(f"Unknown literal type {lit}")


Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def to_flyte_idl(self):
map=self.map.to_flyte_idl() if self.map is not None else None,
hash=self.hash,
metadata=self.metadata,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None,
)

@classmethod
Expand Down
261 changes: 259 additions & 2 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import dataclass
from datetime import datetime
import os
import re
import textwrap
Expand All @@ -8,21 +10,26 @@
import mock
import pytest
from flyteidl.core.errors_pb2 import ErrorDocument
from flyteidl.core import literals_pb2
from flyteidl.core.literals_pb2 import Literal, LiteralCollection, Scalar, Primitive

from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution, get_traceback_str
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import mock_stats
from flytekit.models.core import identifier as id_models
from flytekit.core import context_manager
from flytekit.core.base_task import IgnoreOutputs
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.promise import VoidPromise
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.scopes import system_entry_point
from flytekit.exceptions.scopes import system_entry_point, user_entry_point
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.models import literals as _literal_models
from flytekit.models.core import errors as error_models
from flytekit.models.core import errors as error_models, execution
from flytekit.models.core import execution as execution_models
from flytekit.core.utils import write_proto_to_file


@mock.patch("flytekit.core.utils.load_proto_from_file")
Expand Down Expand Up @@ -453,3 +460,253 @@ def test_get_traceback_str():
expected_error_re = re.compile(expected_error_pattern)
print(traceback_str) # helpful for debugging
assert expected_error_re.match(traceback_str) is not None


def test_dispatch_execute_offloaded_literals(tmp_path_factory):
@task
def t1(a: typing.List[int]) -> typing.List[str]:
return [f"string is: {x}" for x in a]

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
engine_dir=tmp_path_factory.mktemp("engine_dir"),
mode=context_manager.ExecutionState.Mode.TASK_EXECUTION,
user_space_params=context_manager.ExecutionParameters(
execution_date=datetime.now(),
tmp_dir="/tmp",
stats=mock_stats.MockStats(),
logging=None,
raw_output_prefix="",
output_metadata_prefix=str(outputs_path.absolute()),
execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"),
),
),
),
) as ctx:
xs: typing.List[int] = [1, 2, 3]
input_literal_map = _literal_models.LiteralMap(
{
"a": _literal_models.Literal(
collection=_literal_models.LiteralCollection(
literals=[
_literal_models.Literal(
scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)),
) for x in xs
]
)
)
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}):
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].offloaded_metadata is not None
assert lit.literals["o0"].offloaded_metadata.size_bytes == 62
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
elif ff == "o0_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
assert lit == Literal(
collection=LiteralCollection(
literals=[
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 1")),
),
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 2")),
),
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 3")),
),
]
)
)
else:
assert False, f"Unexpected file {ff}"


def test_dispatch_execute_offloaded_literals_two_outputs_offloaded(tmp_path_factory):
@task
def t1(xs: typing.List[int]) -> typing.Tuple[int, typing.List[str]]:
return sum(xs), [f"string is: {x}" for x in xs]

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
engine_dir=tmp_path_factory.mktemp("engine_dir"),
mode=context_manager.ExecutionState.Mode.TASK_EXECUTION,
user_space_params=context_manager.ExecutionParameters(
execution_date=datetime.now(),
tmp_dir="/tmp",
stats=mock_stats.MockStats(),
logging=None,
raw_output_prefix="",
output_metadata_prefix=str(outputs_path.absolute()),
execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"),
),
),
),
) as ctx:
xs: typing.List[int] = [1, 2, 3, 4]
input_literal_map = _literal_models.LiteralMap(
{
"xs": _literal_models.Literal(
collection=_literal_models.LiteralCollection(
literals=[
_literal_models.Literal(
scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=x)),
) for x in xs
]
)
)
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "0"}):
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 2
assert "o0" in lit.literals
assert lit.literals["o0"].offloaded_metadata is not None
assert lit.literals["o0"].offloaded_metadata.size_bytes == 6
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
assert "o1" in lit.literals
assert lit.literals["o1"].offloaded_metadata is not None
assert lit.literals["o1"].offloaded_metadata.size_bytes == 82
assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb")
elif ff == "o0_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
assert lit == Literal(
scalar=Scalar(primitive=Primitive(integer=10)),
)
elif ff == "o1_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
assert lit == Literal(
collection=LiteralCollection(
literals=[
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 1")),
),
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 2")),
),
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 3")),
),
Literal(
scalar=Scalar(primitive=Primitive(string_value="string is: 4")),
),
]
)
)
else:
assert False, f"Unexpected file {ff}"


def test_dispatch_execute_offloaded_literals_two_outputs_only_second_one_offloaded(tmp_path_factory):
@dataclass
class DC:
a: typing.List[int]
b: typing.List[str]

@task
def t1(n: int) -> typing.Tuple[int, DC]:
return n, DC(a=list(range(n)), b=[f"string is: {x}" for x in range(n)])

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
engine_dir=tmp_path_factory.mktemp("engine_dir"),
mode=context_manager.ExecutionState.Mode.TASK_EXECUTION,
user_space_params=context_manager.ExecutionParameters(
execution_date=datetime.now(),
tmp_dir="/tmp",
stats=mock_stats.MockStats(),
logging=None,
raw_output_prefix="",
output_metadata_prefix=str(outputs_path.absolute()),
execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"),
),
),
),
) as ctx:
input_literal_map = _literal_models.LiteralMap(
{
"n": _literal_models.Literal(
scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=56_000)),
)
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

# Notice how the threshold is set to 1MB
with mock.patch.dict(os.environ, {"FK_L_MIN_SIZE_MB": "1"}):
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

# o0 is not offloaded
assert "o0_offloaded_metadata.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 2

# o0 is not offloaded
assert "o0" in lit.literals
assert lit.literals["o0"].HasField("offloaded_metadata") is False

# o1 is offloaded
assert "o1" in lit.literals
assert lit.literals["o1"].HasField("offloaded_metadata") is True
assert lit.literals["o1"].offloaded_metadata.size_bytes == 1108538
assert lit.literals["o1"].offloaded_metadata.uri.endswith("/o1_offloaded_metadata.pb")
elif ff == "o1_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
# Load the dataclass from the proto
transformer = TypeEngine.get_transformer(DC)
dc = transformer.to_python_value(ctx, _literal_models.Literal.from_flyte_idl(lit), DC)
assert dc.a == list(range(56_000))
else:
assert False, f"Unexpected file {ff}"
Loading