Offload literals #2872

Merged on Nov 22, 2024 (28 commits)
Changes from 26 commits

Commits (28):
12f9edf
wip - Implement offloading of literals
eapolinario Oct 18, 2024
5fefa85
Fix use of metadata bucket prefix
eapolinario Oct 18, 2024
1668fda
Fix repeated use of uri
eapolinario Oct 18, 2024
853df62
Add temporary representation for offloaded literal
eapolinario Oct 18, 2024
5e53a1b
Add one unit test
eapolinario Oct 22, 2024
177368d
Add another test
eapolinario Oct 22, 2024
5fc2e84
Stylistic changes to the two tests
eapolinario Oct 22, 2024
db48d18
Add test for min offloading threshold set to 1MB
eapolinario Oct 28, 2024
6884ee0
Pick a unique engine-dir for tests
eapolinario Oct 28, 2024
5a6423c
s/new_outputs/literal_map_copy/
eapolinario Oct 28, 2024
dbfea93
Remove unused constant
eapolinario Oct 28, 2024
adeed34
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Oct 28, 2024
25908d9
Use output_prefix in definition of offloaded literals
eapolinario Oct 29, 2024
e827693
Add initial version of pbhash.py
eapolinario Nov 6, 2024
e0e2016
Add tests to verify that overriding the hash is carried over to offlo…
eapolinario Nov 7, 2024
b284492
Add a few more tests
eapolinario Nov 7, 2024
b579b83
Always import ParamSpec from `typing_extensions`
eapolinario Nov 7, 2024
8c2336e
Fix lint warnings
eapolinario Nov 7, 2024
c28d537
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 7, 2024
37b2bb4
Set inferred_type using the task type interface
eapolinario Nov 8, 2024
e25496d
Add comment about offloaded literals files and how they are uploaded …
eapolinario Nov 8, 2024
9276e96
Add offloading_enabled
eapolinario Nov 18, 2024
a8bdbca
Add more unit tests including a negative test
eapolinario Nov 19, 2024
fe822b9
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 19, 2024
a4fcfab
Fix bad merge
eapolinario Nov 19, 2024
b3a1b0d
Incorporate feedback.
eapolinario Nov 19, 2024
32c5896
Fix image name (unrelated to this PR - just a nice-to-have to decreas…
eapolinario Nov 21, 2024
12e194a
Add `is_map_task` to `_dispatch_execute`
eapolinario Nov 21, 2024
49 changes: 47 additions & 2 deletions flytekit/bin/entrypoint.py
@@ -14,7 +14,7 @@
 import uuid
 import warnings
 from sys import exit
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional

 import click
 from flyteidl.core import literals_pb2 as _literals_pb2
@@ -55,6 +55,7 @@
 from flytekit.models.core import identifier as _identifier
 from flytekit.tools.fast_registration import download_distribution as _download_distribution
 from flytekit.tools.module_loader import load_object_from_module
+from flytekit.utils.pbhash import compute_hash_string


 def get_version_message():
@@ -179,7 +180,51 @@ def _dispatch_execute(
             logger.warning("Task produces no outputs")
             output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
         elif isinstance(outputs, _literal_models.LiteralMap):
-            output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
+            # The keys in this map hold the filenames to the offloaded proto literals.
+            offloaded_literals: Dict[str, _literal_models.Literal] = {}
+            literal_map_copy = {}
+
+            offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None
+            min_offloaded_size = -1
+            max_offloaded_size = -1
+            if offloading_enabled:
+                min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024
+                max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024
+
+            # Go over each output and create a separate offloaded in case its size is too large
+            for k, v in outputs.literals.items():
+                literal_map_copy[k] = v
+
+                if not offloading_enabled:
+                    continue
+
+                lit = v.to_flyte_idl()
+                if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size:
+                    raise ValueError(
+                        f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes"
+                    )
+
+                if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size:
+                    logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")
+
+                    # This file will hold the offloaded literal and will be written to the output prefix
+                    # alongside the regular outputs.pb, deck.pb, etc.
+                    # N.B.: by construction `offloaded_filename` is guaranteed to be unique
+                    offloaded_filename = f"{k}_offloaded_metadata.pb"
+                    offloaded_literal = _literal_models.Literal(
+                        offloaded_metadata=_literal_models.LiteralOffloadedMetadata(
+                            uri=f"{output_prefix}/{offloaded_filename}",
+                            size_bytes=lit.ByteSize(),
+                            # TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged
+                            inferred_type=task_def.interface.outputs[k].type,
+                        ),
+                        hash=v.hash if v.hash is not None else compute_hash_string(lit),
+                    )
+                    literal_map_copy[k] = offloaded_literal
+                    offloaded_literals[offloaded_filename] = v
+            outputs = _literal_models.LiteralMap(literals=literal_map_copy)
+
+            output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals}
         elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
             output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
         else:

Review conversation on the `uri=` line: marked as resolved by eapolinario.
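
Reviewer note: the threshold logic in the `_dispatch_execute` hunk can be exercised in isolation. The following is a minimal sketch, not the flytekit implementation. It assumes literal sizes are already known in bytes, and `read_thresholds` and `plan_outputs` are hypothetical helper names introduced only for illustration. The environment variable names, the default values, and the `{k}_offloaded_metadata.pb` naming are taken from the diff.

```python
from typing import Dict, Mapping, Tuple

MB = 1024 * 1024


def read_thresholds(env: Mapping[str, str]) -> Tuple[bool, int, int]:
    """Return (enabled, min_bytes, max_bytes), mirroring the env lookups in the diff."""
    # Offloading is enabled only when _F_L_MIN_SIZE_MB is present.
    enabled = env.get("_F_L_MIN_SIZE_MB") is not None
    if not enabled:
        return False, -1, -1
    min_bytes = int(env.get("_F_L_MIN_SIZE_MB", "10")) * MB
    max_bytes = int(env.get("_F_L_MAX_SIZE_MB", "1000")) * MB
    return True, min_bytes, max_bytes


def plan_outputs(sizes: Mapping[str, int], env: Mapping[str, str]) -> Dict[str, str]:
    """Map each output name to "inline" or to the name of its offloaded file."""
    enabled, min_bytes, max_bytes = read_thresholds(env)
    plan: Dict[str, str] = {}
    for name, size in sizes.items():
        if not enabled:
            plan[name] = "inline"
        elif size >= max_bytes:
            # Same order as the diff: the upper bound is checked first.
            raise ValueError(f"Literal {name} is too large to be offloaded ({size} >= {max_bytes} bytes)")
        elif size >= min_bytes:
            plan[name] = f"{name}_offloaded_metadata.pb"
        else:
            plan[name] = "inline"
    return plan


sizes = {"o0": 512, "o1": 2 * MB}
print(plan_outputs(sizes, {}))  # {'o0': 'inline', 'o1': 'inline'}
print(plan_outputs(sizes, {"_F_L_MIN_SIZE_MB": "1"}))  # {'o0': 'inline', 'o1': 'o1_offloaded_metadata.pb'}
```

Because the output names are the keys of a map, the derived file names cannot collide, which is what the `N.B.` comment in the diff relies on.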
8 changes: 2 additions & 6 deletions flytekit/core/task.py
@@ -5,12 +5,7 @@
 from functools import update_wrapper
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

-from flytekit.core.utils import str2bool
-
-try:
-    from typing import ParamSpec
-except ImportError:
-    from typing_extensions import ParamSpec  # type: ignore
+from typing_extensions import ParamSpec  # type: ignore

 from flytekit.core import launch_plan as _annotated_launchplan
 from flytekit.core import workflow as _annotated_workflow
@@ -20,6 +15,7 @@
 from flytekit.core.python_function_task import PythonFunctionTask
 from flytekit.core.reference_entity import ReferenceEntity, TaskReference
 from flytekit.core.resources import Resources
+from flytekit.core.utils import str2bool
 from flytekit.deck import DeckField
 from flytekit.extras.accelerators import BaseAccelerator
 from flytekit.image_spec.image_spec import ImageSpec
6 changes: 1 addition & 5 deletions flytekit/core/workflow.py
@@ -8,13 +8,9 @@
 from functools import update_wrapper
 from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

+from typing_extensions import ParamSpec  # type: ignore
 from typing_inspect import is_optional_type

-try:
-    from typing import ParamSpec
-except ImportError:
-    from typing_extensions import ParamSpec  # type: ignore
-
 from flytekit.core import constants as _common_constants
 from flytekit.core import launch_plan as _annotated_launch_plan
 from flytekit.core.base_task import PythonTask, Task
3 changes: 3 additions & 0 deletions flytekit/interaction/string_literals.py
@@ -61,6 +61,9 @@
         return [literal_string_repr(i) for i in lit.collection.literals]
     if lit.map:
         return {k: literal_string_repr(v) for k, v in lit.map.literals.items()}
+    if lit.offloaded_metadata:
+        # TODO: load literal from offloaded literal?
+        return f"Offloaded literal metadata: {lit.offloaded_metadata}"

Codecov / codecov/patch warning on flytekit/interaction/string_literals.py#L66: added line #L66 was not covered by tests.

Review comment on lines +64 to +66

Contributor:
Is there an example we can try to see whether we need to load the literal here?

Collaborator Author:
This is relevant for the `pyflyte fetch` command:

❯ pyflyte --config ~/.flyte/config-sandbox.yaml fetch flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o
Fetching data from flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o...
╭───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ {                                                                                                                                                                                                                                                                                             │
│     'o0': [                                                                                                                                                                                                                                                                                   │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
...
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015'                                                                                                 │
│     ]                                                                                                                                                                                                                                                                                         │
│ }                                                                                                                                                                                                                                                                                             │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

     raise ValueError(f"Unknown literal type {lit}")
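
Reviewer note: the `pyflyte fetch` output in the thread follows from the recursion in `literal_string_repr`: a collection is rendered element by element, so each offloaded element yields its own placeholder string. The following is a minimal sketch using hypothetical stand-in dataclasses (`OffloadedMetadata`, `Lit`) and a hypothetical `string_repr` function. It does not import flytekit and does not reproduce the real model classes.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class OffloadedMetadata:
    """Hypothetical stand-in for LiteralOffloadedMetadata."""

    uri: str
    size_bytes: int


@dataclass
class Lit:
    """Hypothetical stand-in for a literal: exactly one field is expected to be set."""

    value: Optional[int] = None
    collection: Optional[List["Lit"]] = None
    offloaded_metadata: Optional[OffloadedMetadata] = None


def string_repr(lit: Lit) -> Any:
    """Render a literal; offloaded literals become a placeholder string instead of being loaded."""
    if lit.value is not None:
        return lit.value
    if lit.collection is not None:
        return [string_repr(item) for item in lit.collection]
    if lit.offloaded_metadata is not None:
        return f"Offloaded literal metadata: {lit.offloaded_metadata}"
    raise ValueError(f"Unknown literal type {lit}")


# The URI below is made up for the example.
outputs = Lit(
    collection=[
        Lit(value=1),
        Lit(offloaded_metadata=OffloadedMetadata(uri="s3://bucket/o0_offloaded_metadata.pb", size_bytes=10)),
    ]
)
print(string_repr(outputs))
```

Loading the offloaded payload at this point would require a read from blob storage per element, which may be why the diff leaves it as a TODO.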
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
@@ -991,7 +991,7 @@ def to_flyte_idl(self):
             map=self.map.to_flyte_idl() if self.map is not None else None,
             hash=self.hash,
             metadata=self.metadata,
-            offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
+            offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None,
         )

     @classmethod
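
Reviewer note: the change from a truthiness test to `is not None` matches the neighbouring `self.map is not None` check. As a general Python point, and not a claim about how `LiteralOffloadedMetadata` behaves, the two tests differ for objects that can be falsy. A minimal sketch with a hypothetical `Sized` class:

```python
class Sized:
    """Hypothetical container whose truthiness follows its length."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)


def pick_truthy(obj):
    # Treats any falsy object (empty, zero-length) as if it were absent.
    return "set" if obj else None


def pick_not_none(obj):
    # Only None counts as absent.
    return "set" if obj is not None else None


empty = Sized([])
print(pick_truthy(empty))    # None
print(pick_not_none(empty))  # set
```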
39 changes: 39 additions & 0 deletions flytekit/utils/pbhash.py
@@ -0,0 +1,39 @@
+# This is a module that provides hashing utilities for Protobuf objects.
+import base64
+import hashlib
+import json
+
+from google.protobuf import json_format
+from google.protobuf.message import Message
+
+
+def compute_hash(pb: Message) -> bytes:
+    """
+    Computes a deterministic hash in bytes for the Protobuf object.
+    """
+    try:
+        pb_dict = json_format.MessageToDict(pb)
+        # json.dumps with sorted keys to ensure stability
+        stable_json_str = json.dumps(
+            pb_dict, sort_keys=True, separators=(",", ":")
+        )  # separators to ensure no extra spaces
+    except Exception as e:
+        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}")
+
+    try:
+        # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here,
+        # assuming it provides a consistent hash output.
+        hash_obj = hashlib.sha256(stable_json_str.encode("utf-8"))
+    except Exception as e:
+        raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}")
+
+    # The digest is guaranteed to be 32 bytes long
+    return hash_obj.digest()
+
+
+def compute_hash_string(pb: Message) -> str:
+    """
+    Computes a deterministic hash in base64 encoded string for the Protobuf object
+    """
+    hash_bytes = compute_hash(pb)
+    return base64.b64encode(hash_bytes).decode("utf-8")

Codecov / codecov/patch warnings on flytekit/utils/pbhash.py: added lines #L2-L4, #L6-L7, #L10, #L14-L15, #L17, #L20-L21, #L23, #L26-L28, #L31, #L34, and #L38-L39 were not covered by tests.
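
Reviewer note: the hashing scheme in `pbhash.py` (canonical JSON with sorted keys and compact separators, then SHA-256, then base64) can be tried without Protobuf. The following is a minimal sketch in which a plain dict stands in for the result of `json_format.MessageToDict`. `stable_hash` and `stable_hash_string` are hypothetical names, and the sample dict is made up for illustration.

```python
import base64
import hashlib
import json
from typing import Any, Dict


def stable_hash(obj: Dict[str, Any]) -> bytes:
    """SHA-256 digest of a canonical JSON rendering of obj."""
    # sort_keys applies to nested dicts as well, so key order never affects the result;
    # the compact separators remove whitespace differences.
    canonical = json.dumps(obj, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).digest()


def stable_hash_string(obj: Dict[str, Any]) -> str:
    """Base64 text form of stable_hash, suitable for a string-typed hash field."""
    return base64.b64encode(stable_hash(obj)).decode("utf-8")


# Same content, different key order.
a = {"scalar": {"primitive": {"integer": "42"}}, "metadata": {}}
b = {"metadata": {}, "scalar": {"primitive": {"integer": "42"}}}

print(stable_hash(a) == stable_hash(b))  # True
print(len(stable_hash(a)))               # 32
print(len(stable_hash_string(a)))        # 44
```

A 32-byte digest always encodes to 44 base64 characters, so the string form has a fixed length regardless of the size of the literal.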