-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto Cache Plugin #2971
base: master
Are you sure you want to change the base?
Auto Cache Plugin #2971
Changes from all commits
2786c5b
b18bac7
50552a9
73d2327
6d5cdbf
f76f59a
a5fc1bb
ff3af99
7d06370
011bd67
d5d9576
2b0fbb8
b4911ab
02f6b53
47d01ef
f1ebdc9
ff3555f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union, runtime_checkable | ||
|
||
from flytekit.image_spec.image_spec import ImageSpec | ||
|
||
|
||
@dataclass | ||
class VersionParameters: | ||
""" | ||
Parameters used for version hash generation. | ||
|
||
Args: | ||
func (Optional[Callable]): The function to generate a version for | ||
container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for | ||
""" | ||
|
||
func: Optional[Callable[..., Any]] = None | ||
container_image: Optional[Union[str, ImageSpec]] = None | ||
|
||
|
||
@runtime_checkable | ||
class AutoCache(Protocol): | ||
""" | ||
A protocol that defines the interface for a caching mechanism | ||
that generates a version hash of a function based on its source code. | ||
""" | ||
|
||
salt: str | ||
|
||
def get_version(self, params: VersionParameters) -> str: | ||
""" | ||
Generate a version hash based on the provided parameters. | ||
|
||
Args: | ||
params (VersionParameters): Parameters to use for hash generation. | ||
|
||
Returns: | ||
str: The generated version hash. | ||
""" | ||
... | ||
|
||
|
||
class CachePolicy: | ||
""" | ||
A class that combines multiple caching mechanisms to generate a version hash. | ||
|
||
Args: | ||
auto_cache_policies: A list of AutoCache instances (optional). | ||
salt: Optional salt string to add uniqueness to the hash. | ||
cache_serialize: Boolean to indicate if serialization should be used. | ||
cache_version: A version string for the cache. | ||
cache_ignore_input_vars: Tuple of input variable names to ignore. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
auto_cache_policies: List["AutoCache"] = None, | ||
salt: str = "", | ||
cache_serialize: bool = False, | ||
cache_version: str = "", | ||
cache_ignore_input_vars: Tuple[str, ...] = (), | ||
) -> None: | ||
self.auto_cache_policies = auto_cache_policies or [] # Use an empty list if None is provided | ||
self.salt = salt | ||
self.cache_serialize = cache_serialize | ||
self.cache_version = cache_version | ||
self.cache_ignore_input_vars = cache_ignore_input_vars | ||
|
||
def get_version(self, params: "VersionParameters") -> str: | ||
""" | ||
Generate a version hash using all cache objects. If the user passes a version, it takes precedence over auto_cache_policies. | ||
|
||
Args: | ||
params (VersionParameters): Parameters to use for hash generation. | ||
|
||
Returns: | ||
str: The combined hash from all cache objects. | ||
""" | ||
if self.cache_version: | ||
return self.cache_version | ||
|
||
if self.auto_cache_policies: | ||
task_hash = "" | ||
for cache_instance in self.auto_cache_policies: | ||
# Apply the policy's salt to each cache instance | ||
cache_instance.salt = self.salt | ||
task_hash += cache_instance.get_version(params) | ||
|
||
# Generate SHA-256 hash | ||
import hashlib | ||
|
||
hash_obj = hashlib.sha256(task_hash.encode()) | ||
return hash_obj.hexdigest() | ||
|
||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
|
||
from flytekit.core import launch_plan as _annotated_launchplan | ||
from flytekit.core import workflow as _annotated_workflow | ||
from flytekit.core.auto_cache import CachePolicy, VersionParameters | ||
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin | ||
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface | ||
from flytekit.core.pod_template import PodTemplate | ||
|
@@ -95,7 +96,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction | |
def task( | ||
_task_function: None = ..., | ||
task_config: Optional[T] = ..., | ||
cache: bool = ..., | ||
cache: Union[bool, CachePolicy] = ..., | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make this accept any Basically the user can provide just a single autocache object like |
||
cache_serialize: bool = ..., | ||
cache_version: str = ..., | ||
cache_ignore_input_vars: Tuple[str, ...] = ..., | ||
|
@@ -132,9 +133,9 @@ def task( | |
|
||
@overload | ||
def task( | ||
_task_function: Callable[P, FuncOut], | ||
_task_function: Callable[..., FuncOut], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change |
||
task_config: Optional[T] = ..., | ||
cache: bool = ..., | ||
cache: Union[bool, CachePolicy] = ..., | ||
cache_serialize: bool = ..., | ||
cache_version: str = ..., | ||
cache_ignore_input_vars: Tuple[str, ...] = ..., | ||
|
@@ -166,13 +167,13 @@ def task( | |
pod_template_name: Optional[str] = ..., | ||
accelerator: Optional[BaseAccelerator] = ..., | ||
pickle_untyped: bool = ..., | ||
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... | ||
) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ... | ||
|
||
|
||
def task( | ||
_task_function: Optional[Callable[P, FuncOut]] = None, | ||
_task_function: Optional[Callable[..., FuncOut]] = None, | ||
task_config: Optional[T] = None, | ||
cache: bool = False, | ||
cache: Union[bool, CachePolicy] = False, | ||
cache_serialize: bool = False, | ||
cache_version: str = "", | ||
cache_ignore_input_vars: Tuple[str, ...] = (), | ||
|
@@ -211,8 +212,8 @@ def task( | |
accelerator: Optional[BaseAccelerator] = None, | ||
pickle_untyped: bool = False, | ||
) -> Union[ | ||
Callable[P, FuncOut], | ||
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], | ||
Callable[..., FuncOut], | ||
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], | ||
PythonFunctionTask[T], | ||
]: | ||
""" | ||
|
@@ -247,7 +248,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str: | |
:param _task_function: This argument is implicitly passed and represents the decorated function | ||
:param task_config: This argument provides configuration for a specific task types. | ||
Please refer to the plugins documentation for the right object to use. | ||
:param cache: Boolean that indicates if caching should be enabled | ||
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations | ||
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be | ||
executed in serial when caching is enabled. This means that given multiple concurrent executions over | ||
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This | ||
|
@@ -342,12 +343,24 @@ def launch_dynamically(): | |
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types. | ||
""" | ||
|
||
def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: | ||
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: | ||
if isinstance(cache, CachePolicy): | ||
cache_val = True | ||
params = VersionParameters(func=fn, container_image=container_image) | ||
cache_version_val = cache_version or cache.get_version(params=params) | ||
cache_serialize_val = cache_serialize or cache.cache_serialize | ||
cache_serialize_val = cache_ignore_input_vars or cache.cache_ignore_input_vars | ||
else: | ||
cache_val = cache | ||
cache_version_val = cache_version | ||
cache_serialize_val = cache_serialize | ||
cache_ignore_input_vars_val = cache_ignore_input_vars | ||
Comment on lines
+350
to
+357
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the purpose of forwarding all of these parameters via the |
||
|
||
_metadata = TaskMetadata( | ||
cache=cache, | ||
cache_serialize=cache_serialize, | ||
cache_version=cache_version, | ||
cache_ignore_input_vars=cache_ignore_input_vars, | ||
cache=cache_val, | ||
cache_serialize=cache_serialize_val, | ||
cache_version=cache_version_val, | ||
cache_ignore_input_vars=cache_ignore_input_vars_val, | ||
retries=retries, | ||
interruptible=interruptible, | ||
deprecated=deprecated, | ||
|
@@ -433,7 +446,7 @@ def wrapper(fn) -> ReferenceTask: | |
return wrapper | ||
|
||
|
||
def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]: | ||
def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]: | ||
""" | ||
Decorates the task with additional functionality if necessary. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Flyte Auto Cache Plugin | ||
|
||
This plugin provides a caching mechanism for Flyte tasks that generates a version hash based on the source code of the task and its dependencies. It allows users to manage the cache behavior. | ||
|
||
## Usage | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-auto-cache | ||
``` | ||
|
||
To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up: | ||
|
||
```python | ||
from flytekit import task | ||
from flytekit.core.auto_cache import CachePolicy | ||
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules | ||
|
||
cache_policy = CachePolicy( | ||
auto_cache_policies = [ | ||
CacheFunctionBody(), | ||
CachePrivateModules(root_dir="../my_package"), | ||
..., | ||
] | ||
salt="my_salt" | ||
) | ||
Comment on lines
+20
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's also provide an example of not needing to provide a |
||
|
||
@task(cache=cache_policy) | ||
def task_fn(): | ||
... | ||
``` | ||
|
||
### Salt Parameter | ||
|
||
The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task. | ||
|
||
## Cache Implementations | ||
|
||
Users can add any number of cache policies that implement the `AutoCache` protocol defined in `@auto_cache.py`. Below are the implementations available so far: | ||
|
||
### 1. CacheFunctionBody | ||
|
||
This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning. | ||
|
||
### 2. CacheImage | ||
|
||
This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task. | ||
|
||
### 3. CachePrivateModules | ||
|
||
This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash. | ||
|
||
It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored. | ||
|
||
### 4. CacheExternalDependencies | ||
|
||
This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
.. currentmodule:: flytekitplugins.auto_cache | ||
|
||
This package contains things that are useful when extending Flytekit. | ||
|
||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
|
||
CacheFunctionBody | ||
CachePrivateModules | ||
""" | ||
|
||
from .cache_external_dependencies import CacheExternalDependencies | ||
from .cache_function_body import CacheFunctionBody | ||
from .cache_image import CacheImage | ||
from .cache_private_modules import CachePrivateModules |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the purpose of saving this state here? aren't these just forwarded to the underlying
TaskMetadata
?