
Auto Cache Plugin #2971

Open · wants to merge 17 commits into base: master
flytekit/core/auto_cache.py (95 additions, 0 deletions)

@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union, runtime_checkable

from flytekit.image_spec.image_spec import ImageSpec


@dataclass
class VersionParameters:
"""
Parameters used for version hash generation.

Args:
func (Optional[Callable]): The function to generate a version for
container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for
"""

func: Optional[Callable[..., Any]] = None
container_image: Optional[Union[str, ImageSpec]] = None


@runtime_checkable
class AutoCache(Protocol):
"""
A protocol that defines the interface for a caching mechanism
that generates a version hash of a function based on its source code.
"""

salt: str

def get_version(self, params: VersionParameters) -> str:
"""
Generate a version hash based on the provided parameters.

Args:
params (VersionParameters): Parameters to use for hash generation.

Returns:
str: The generated version hash.
"""
...


class CachePolicy:
"""
A class that combines multiple caching mechanisms to generate a version hash.

Args:
auto_cache_policies: A list of AutoCache instances (optional).
salt: Optional salt string to add uniqueness to the hash.
cache_serialize: Boolean to indicate if serialization should be used.
cache_version: A version string for the cache.
cache_ignore_input_vars: Tuple of input variable names to ignore.
"""

def __init__(
self,
auto_cache_policies: Optional[List["AutoCache"]] = None,
salt: str = "",
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
) -> None:
self.auto_cache_policies = auto_cache_policies or [] # Use an empty list if None is provided
self.salt = salt
self.cache_serialize = cache_serialize
self.cache_version = cache_version
self.cache_ignore_input_vars = cache_ignore_input_vars
**Contributor** commented on lines +65 to +67:

> what's the purpose of saving this state here? aren't these just forwarded to the underlying TaskMetadata?


def get_version(self, params: "VersionParameters") -> Optional[str]:
"""
Generate a version hash using all cache objects. A user-supplied cache_version takes precedence over auto_cache_policies.

Args:
params (VersionParameters): Parameters to use for hash generation.

Returns:
Optional[str]: The combined hash from all cache objects, or None if neither cache_version nor auto_cache_policies is set.
"""
if self.cache_version:
return self.cache_version

if self.auto_cache_policies:
task_hash = ""
for cache_instance in self.auto_cache_policies:
# Apply the policy's salt to each cache instance
cache_instance.salt = self.salt
task_hash += cache_instance.get_version(params)

# Generate SHA-256 hash
import hashlib

hash_obj = hashlib.sha256(task_hash.encode())
return hash_obj.hexdigest()

return None
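To make the protocol concrete, here is a minimal illustrative policy satisfying `AutoCache` (a sketch only — it hashes the function's compiled bytecode rather than its source, so it is not equivalent to the plugin's `CacheFunctionBody`; `BytecodeCache` and `my_task` are hypothetical names):

```python
import hashlib
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class VersionParameters:
    # Mirror of the fields used below; the real class lives in
    # flytekit/core/auto_cache.py.
    func: Optional[Callable[..., Any]] = None
    container_image: Optional[str] = None


class BytecodeCache:
    """Toy AutoCache-compliant policy: version = SHA-256 of salted bytecode."""

    def __init__(self, salt: str = "") -> None:
        self.salt = salt

    def get_version(self, params: VersionParameters) -> str:
        if params.func is None:
            raise ValueError("BytecodeCache requires a function")
        # Salt plus compiled bytecode: stable across runs, changes when
        # the function's logic (not just its name) changes.
        payload = self.salt.encode() + params.func.__code__.co_code
        return hashlib.sha256(payload).hexdigest()


def my_task(x: int) -> int:
    return x + 1


version = BytecodeCache(salt="v1").get_version(VersionParameters(func=my_task))
print(version)
```

Because the object only needs a `salt` attribute and a `get_version` method, `isinstance(BytecodeCache(), AutoCache)` holds under the `runtime_checkable` protocol above.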
flytekit/core/task.py (28 additions, 15 deletions)

@@ -9,6 +9,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.auto_cache import CachePolicy, VersionParameters
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
@@ -95,7 +96,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy] = ...,
**Contributor** commented:

> Can we make this accept any AutoCache-compliant object?
>
> Basically the user can provide just a single autocache object like CacheFunctionBody or compose multiple into a CachePolicy, but users shouldn't be forced to always use a CachePolicy object.

cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
@@ -132,9 +133,9 @@ def task(

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
**Contributor** commented:

> why change P to ...?

task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
@@ -166,13 +167,13 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ...


def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
_task_function: Optional[Callable[..., FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache: Union[bool, CachePolicy] = False,
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
@@ -211,8 +212,8 @@ def task(
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Callable[..., FuncOut],
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
]:
"""
@@ -247,7 +248,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str:
:param _task_function: This argument is implicitly passed and represents the decorated function
:param task_config: This argument provides configuration for a specific task types.
Please refer to the plugins documentation for the right object to use.
:param cache: Boolean that indicates if caching should be enabled
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be
executed in serial when caching is enabled. This means that given multiple concurrent executions over
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This
@@ -342,12 +343,24 @@ def launch_dynamically():
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
if isinstance(cache, CachePolicy):
cache_val = True
params = VersionParameters(func=fn, container_image=container_image)
cache_version_val = cache_version or cache.get_version(params=params)
cache_serialize_val = cache_serialize or cache.cache_serialize
cache_ignore_input_vars_val = cache_ignore_input_vars or cache.cache_ignore_input_vars
else:
cache_val = cache
cache_version_val = cache_version
cache_serialize_val = cache_serialize
cache_ignore_input_vars_val = cache_ignore_input_vars
**Contributor** commented on lines +350 to +357:

> what's the purpose of forwarding all of these parameters via the CachePolicy object? It doesn't look like it's being modified there.


_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_ignore_input_vars=cache_ignore_input_vars,
cache=cache_val,
cache_serialize=cache_serialize_val,
cache_version=cache_version_val,
cache_ignore_input_vars=cache_ignore_input_vars_val,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
@@ -433,7 +446,7 @@ def wrapper(fn) -> ReferenceTask:
return wrapper


def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]:
def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorates the task with additional functionality if necessary.

flytekit/core/workflow.py (5 additions, 5 deletions)

@@ -843,25 +843,25 @@ def workflow(

@overload
def workflow(
_workflow_function: Callable[P, FuncOut],
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ...


def workflow(
_workflow_function: Optional[Callable[P, FuncOut]] = None,
_workflow_function: Optional[Callable[..., FuncOut]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
@@ -898,7 +898,7 @@ def workflow(
the labels and annotations are allowed to be set as defaults.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
plugins/flytekit-auto-cache/README.md (58 additions, 0 deletions)

@@ -0,0 +1,58 @@
# Flyte Auto Cache Plugin

This plugin provides a caching mechanism for Flyte tasks that generates a version hash from the source code of the task and its dependencies, giving users fine-grained control over when a task's cached results are invalidated.

## Usage

To install the plugin, run the following command:

```bash
pip install flytekitplugins-auto-cache
```

To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up:

```python
from flytekit import task
from flytekit.core.auto_cache import CachePolicy
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules

cache_policy = CachePolicy(
    auto_cache_policies=[
        CacheFunctionBody(),
        CachePrivateModules(root_dir="../my_package"),
        ...,
    ],
    salt="my_salt",
)
**Contributor** commented on lines +20 to +27:

> let's also provide an example of not needing to provide a CachePolicy object, e.g. just passing in CacheFunctionBody.


@task(cache=cache_policy)
def task_fn():
    ...
```

### Salt Parameter

The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task.
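The effect can be sketched with plain `hashlib` (illustrative only — in the plugin, the salt is applied inside each cache implementation, and `salted_version` is a hypothetical helper):

```python
import hashlib


def salted_version(source: str, salt: str = "") -> str:
    # Identical source with a different salt yields a different hash,
    # which is how changing the salt forces cache invalidation.
    return hashlib.sha256((salt + source).encode()).hexdigest()


body = "def f(x): return x + 1"
print(salted_version(body, salt="2024-01"))
print(salted_version(body, salt="2024-02"))
```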

## Cache Implementations

Users can add any number of cache policies that implement the `AutoCache` protocol defined in `flytekit/core/auto_cache.py`. Below are the implementations available so far:

### 1. CacheFunctionBody

This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning.
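One way to ignore formatting and comments (a sketch of the idea, not necessarily the plugin's exact approach) is to hash the dump of the parsed AST, since comments, whitespace, and redundant parentheses never reach the syntax tree:

```python
import ast
import hashlib


def hash_function_body(source: str) -> str:
    # ast.dump() serializes only the syntax tree, so cosmetic edits
    # (comments, spacing, extra parentheses) do not change the hash.
    tree = ast.parse(source)
    return hashlib.sha256(ast.dump(tree).encode()).hexdigest()


plain = "def f(x):\n    return x + 1\n"
noisy = "def f(x):  # add one\n    return (x + 1)\n"
changed = "def f(x):\n    return x + 2\n"

print(hash_function_body(plain) == hash_function_body(noisy))    # -> True
print(hash_function_body(plain) == hash_function_body(changed))  # -> False
```

Note that AST dumps can differ between Python versions, so hashes computed this way are only stable within one interpreter version.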

### 2. CacheImage

This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task.

### 3. CachePrivateModules

This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash.

It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored.
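A rough sketch of the discovery and filtering steps (using compiled-code metadata rather than the plugin's actual recursive source analysis; `referenced_globals` and `is_external` are hypothetical helpers):

```python
import inspect
import json  # stands in for an arbitrary dependency of the task

ANSWER = 41  # a module-level constant the task uses


def task_fn() -> str:
    return json.dumps({"answer": ANSWER + 1})


def referenced_globals(func):
    """Map the global names a function touches to their current values.

    func.__code__.co_names lists every global and attribute name the
    compiled function references; filtering by __globals__ keeps only
    names resolvable at module level.
    """
    return {n: func.__globals__[n] for n in func.__code__.co_names if n in func.__globals__}


def is_external(obj) -> bool:
    # Objects whose defining module lives under site-packages would be
    # skipped, mirroring how external libraries are ignored here.
    mod = inspect.getmodule(obj)
    return "site-packages" in (getattr(mod, "__file__", "") or "")


deps = referenced_globals(task_fn)
print(sorted(deps))  # -> ['ANSWER', 'json']
```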

### 4. CacheExternalDependencies

This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning.
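The version-recording idea can be sketched with the standard library's `importlib.metadata` (an assumption about the mechanism, not the plugin's exact code; `dependency_fingerprint` is a hypothetical helper):

```python
import hashlib
from importlib import metadata


def dependency_fingerprint(package: str) -> str:
    """Fold an installed distribution's name and version into a hash,
    so upgrading the dependency changes the resulting task version."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        version = "not-installed"
    return hashlib.sha256(f"{package}=={version}".encode()).hexdigest()


print(dependency_fingerprint("pip"))
```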
@@ -0,0 +1,17 @@
"""
.. currentmodule:: flytekitplugins.auto_cache

This package contains the auto-cache policy implementations for Flytekit tasks.

.. autosummary::
:template: custom.rst
:toctree: generated/

CacheExternalDependencies
CacheFunctionBody
CacheImage
CachePrivateModules
"""

from .cache_external_dependencies import CacheExternalDependencies
from .cache_function_body import CacheFunctionBody
from .cache_image import CacheImage
from .cache_private_modules import CachePrivateModules