From 2786c5b4fc17d0c2ba91cb8e03a27ebf4b746e99 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Tue, 5 Nov 2024 15:07:05 -0800 Subject: [PATCH 01/16] wip --- plugins/flytekit-auto-cache/README.md | 9 +++++ .../flytekitplugins/auto_cache/__init__.py | 10 +++++ plugins/flytekit-auto-cache/setup.py | 37 +++++++++++++++++++ plugins/flytekit-auto-cache/tests/__init__.py | 0 .../tests/test_auto_cache.py | 0 5 files changed, 56 insertions(+) create mode 100644 plugins/flytekit-auto-cache/README.md create mode 100644 plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py create mode 100644 plugins/flytekit-auto-cache/setup.py create mode 100644 plugins/flytekit-auto-cache/tests/__init__.py create mode 100644 plugins/flytekit-auto-cache/tests/test_auto_cache.py diff --git a/plugins/flytekit-auto-cache/README.md b/plugins/flytekit-auto-cache/README.md new file mode 100644 index 0000000000..76d0e1f853 --- /dev/null +++ b/plugins/flytekit-auto-cache/README.md @@ -0,0 +1,9 @@ +# Flytekit Auto Cache Plugin + + + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-auto-cache +``` diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py new file mode 100644 index 0000000000..54d65b9e43 --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py @@ -0,0 +1,10 @@ +""" +.. currentmodule:: flytekitplugins.auto_cache + +This package contains things that are useful when extending Flytekit. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + +""" diff --git a/plugins/flytekit-auto-cache/setup.py b/plugins/flytekit-auto-cache/setup.py new file mode 100644 index 0000000000..6ba1d5060f --- /dev/null +++ b/plugins/flytekit-auto-cache/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "auto_cache" + +microlib_name = "flytekitplugins-auto-cache" + +plugin_requires = ["flytekit"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the auto cache plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-auto-cache/tests/__init__.py b/plugins/flytekit-auto-cache/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-auto-cache/tests/test_auto_cache.py b/plugins/flytekit-auto-cache/tests/test_auto_cache.py new file mode 100644 index 0000000000..e69de29bb2 From b18bac7f112fc567225476d964ef19d76129706d Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 6 Nov 2024 15:41:58 -0800 Subject: [PATCH 02/16] initial auto cache method Signed-off-by: 
Daniel Sola --- flytekit/core/auto_cache.py | 37 +++++++++ flytekit/core/task.py | 17 ++-- .../flytekitplugins/auto_cache/__init__.py | 3 + .../auto_cache/cache_function_body.py | 53 ++++++++++++ .../tests/test_auto_cache.py | 80 +++++++++++++++++++ 5 files changed, 185 insertions(+), 5 deletions(-) create mode 100644 flytekit/core/auto_cache.py create mode 100644 plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py diff --git a/flytekit/core/auto_cache.py b/flytekit/core/auto_cache.py new file mode 100644 index 0000000000..adf4741a0b --- /dev/null +++ b/flytekit/core/auto_cache.py @@ -0,0 +1,37 @@ +from typing import Any, Callable, Protocol, runtime_checkable + + +@runtime_checkable +class AutoCache(Protocol): + """ + A protocol that defines the interface for a caching mechanism + that generates a version hash of a function based on its source code. + + Attributes: + salt (str): A string used to add uniqueness to the generated hash. Default is "salt". + + Methods: + get_version(func: Callable[..., Any]) -> str: + Given a function, generates a version hash based on its source code and the salt. + """ + + def __init__(self, salt: str = "salt") -> None: + """ + Initialize the AutoCache instance with a salt value. + + Args: + salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". + """ + self.salt = salt + + def get_version(self, func: Callable[..., Any]) -> str: + """ + Generate a version hash for the provided function. + + Args: + func (Callable[..., Any]): A callable function whose version hash needs to be generated. + + Returns: + str: The SHA-256 hash of the function's source code combined with the salt. + """ + ... 
diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 745f452a83..7519f341d6 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,6 +5,7 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload +from flytekit.core.auto_cache import AutoCache from flytekit.core.utils import str2bool try: @@ -99,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction def task( _task_function: None = ..., task_config: Optional[T] = ..., - cache: bool = ..., + cache: Union[bool, list[AutoCache]] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -137,7 +138,7 @@ def task( def task( _task_function: Callable[P, FuncOut], task_config: Optional[T] = ..., - cache: bool = ..., + cache: Union[bool, list[AutoCache]] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -174,7 +175,7 @@ def task( def task( _task_function: Optional[Callable[P, FuncOut]] = None, task_config: Optional[T] = None, - cache: bool = False, + cache: Union[bool, list[AutoCache]] = False, cache_serialize: bool = False, cache_version: str = "", cache_ignore_input_vars: Tuple[str, ...] = (), @@ -248,7 +249,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str: :param _task_function: This argument is implicitly passed and represents the decorated function :param task_config: This argument provides configuration for a specific task types. Please refer to the plugins documentation for the right object to use. - :param cache: Boolean that indicates if caching should be enabled + :param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations :param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be executed in serial when caching is enabled. 
This means that given multiple concurrent executions over identical inputs, only a single instance executes and the rest wait to reuse the cached results. This @@ -343,10 +344,16 @@ def launch_dynamically(): """ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: + if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache): + cache_versions = [item.get_version() for item in cache] + task_hash = "".join(cache_versions) + else: + task_hash = "" + _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize, - cache_version=cache_version, + cache_version=cache_version if not task_hash else task_hash, cache_ignore_input_vars=cache_ignore_input_vars, retries=retries, interruptible=interruptible, diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py index 54d65b9e43..5872eda7ee 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py @@ -7,4 +7,7 @@ :template: custom.rst :toctree: generated/ + CacheFunctionBody """ + +from .cache_function_body import CacheFunctionBody diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py new file mode 100644 index 0000000000..2e4472109d --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py @@ -0,0 +1,53 @@ +import ast +import hashlib +import inspect +from typing import Any, Callable + + +class CacheFunctionBody: + """ + A class that implements a versioning mechanism for functions by generating + a SHA-256 hash of the function's source code combined with a salt. + + Attributes: + salt (str): A string used to add uniqueness to the generated hash. Default is "salt". 
+ + Methods: + get_version(func: Callable[..., Any]) -> str: + Given a function, generates a version hash based on its source code and the salt. + """ + + def __init__(self, salt: str = "salt") -> None: + """ + Initialize the CacheFunctionBody instance with a salt value. + + Args: + salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". + """ + self.salt = salt + + def get_version(self, func: Callable[..., Any]) -> str: + """ + Generate a version hash for the provided function by parsing its source code + and adding a salt before applying the SHA-256 hash function. + + Args: + func (Callable[..., Any]): A callable function whose version hash needs to be generated. + + Returns: + str: The SHA-256 hash of the function's source code combined with the salt. + """ + # Get the source code of the function + source = inspect.getsource(func) + + # Parse the source code into an Abstract Syntax Tree (AST) + parsed_ast = ast.parse(source) + + # Convert the AST into a string representation (dump it) + ast_bytes = ast.dump(parsed_ast).encode("utf-8") + + # Combine the AST bytes with the salt (encoded into bytes) + combined_data = ast_bytes + self.salt.encode("utf-8") + + # Return the SHA-256 hash of the combined data (AST + salt) + return hashlib.sha256(combined_data).hexdigest() diff --git a/plugins/flytekit-auto-cache/tests/test_auto_cache.py b/plugins/flytekit-auto-cache/tests/test_auto_cache.py index e69de29bb2..bd65b5461a 100644 --- a/plugins/flytekit-auto-cache/tests/test_auto_cache.py +++ b/plugins/flytekit-auto-cache/tests/test_auto_cache.py @@ -0,0 +1,80 @@ +from flytekitplugins.auto_cache import CacheFunctionBody + + +# Dummy functions +def dummy_function(x: int, y: int) -> int: + result = x + y + return result + +def dummy_function_modified(x: int, y: int) -> int: + result = x * y + return result + + +def dummy_function_with_comments_and_formatting(x: int, y: int) -> int: + # Adding a new line here + result = ( + x + y + ) + # Another 
new line + return result + + + +def test_get_version_with_same_function_and_salt(): + """ + Test that calling get_version with the same function and salt returns the same hash. + """ + cache1 = CacheFunctionBody(salt="salt") + cache2 = CacheFunctionBody(salt="salt") + + # Both calls should return the same hash since the function and salt are the same + version1 = cache1.get_version(dummy_function) + version2 = cache2.get_version(dummy_function) + + assert version1 == version2, f"Expected {version1}, but got {version2}" + + +def test_get_version_with_different_salt(): + """ + Test that calling get_version with different salts returns different hashes for the same function. + """ + cache1 = CacheFunctionBody(salt="salt1") + cache2 = CacheFunctionBody(salt="salt2") + + # The hashes should be different because the salts are different + version1 = cache1.get_version(dummy_function) + version2 = cache2.get_version(dummy_function) + + assert version1 != version2, f"Expected different hashes but got the same: {version1}" + + +def test_get_version_with_different_function_source(): + """ + Test that calling get_version with different function sources returns different hashes. + """ + cache = CacheFunctionBody(salt="salt") + + # The hash should be different because the function source has changed + version1 = cache.get_version(dummy_function) + version2 = cache.get_version(dummy_function_modified) + + assert version1 != version2, f"Expected different hashes but got the same: {version1} and {version2}" + + +def test_get_version_with_comments_and_formatting_changes(): + """ + Test that adding comments, changing formatting, or modifying the function signature + results in a different hash. 
+ """ + # Modify the function by adding comments and changing the formatting + cache = CacheFunctionBody(salt="salt") + + # Get the hash for the original dummy function + original_version = cache.get_version(dummy_function) + + # Get the hash for the function with comments and formatting changes + version_with_comments_and_formatting = cache.get_version(dummy_function_with_comments_and_formatting) + + # Assert that the hashes are different + assert original_version != version_with_comments_and_formatting, f"Expected different hashes but got the same: {original_version} and {version_with_comments_and_formatting}" From 50552a96bab88d8052c05c5e76f8bc7c257a17dc Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 6 Nov 2024 16:24:17 -0800 Subject: [PATCH 03/16] fix tests to import dummy functions Signed-off-by: Daniel Sola --- .../tests/dummy_functions/__init__.py | 0 .../tests/dummy_functions/dummy_function.py | 3 + ...mmy_function_comments_formatting_change.py | 9 +++ .../dummy_function_logic_change.py | 3 + .../tests/test_auto_cache.py | 77 ++++++++++--------- 5 files changed, 56 insertions(+), 36 deletions(-) create mode 100644 plugins/flytekit-auto-cache/tests/dummy_functions/__init__.py create mode 100644 plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function.py create mode 100644 plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_comments_formatting_change.py create mode 100644 plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_logic_change.py diff --git a/plugins/flytekit-auto-cache/tests/dummy_functions/__init__.py b/plugins/flytekit-auto-cache/tests/dummy_functions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function.py b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function.py new file mode 100644 index 0000000000..413311f465 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function.py @@ -0,0 
+1,3 @@ +def dummy_function(x: int, y: int) -> int: + result = x + y + return result diff --git a/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_comments_formatting_change.py b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_comments_formatting_change.py new file mode 100644 index 0000000000..68b95eb025 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_comments_formatting_change.py @@ -0,0 +1,9 @@ +def dummy_function(x: int, y: int) -> int: + # Adding some comments + result = ( + x + # Adding inline comment + y # Another inline comment + ) + + # More comments + return result diff --git a/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_logic_change.py b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_logic_change.py new file mode 100644 index 0000000000..0a63e11adf --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_logic_change.py @@ -0,0 +1,3 @@ +def dummy_function(x: int, y: int) -> int: + result = x * y + return result diff --git a/plugins/flytekit-auto-cache/tests/test_auto_cache.py b/plugins/flytekit-auto-cache/tests/test_auto_cache.py index bd65b5461a..efebe8ad0a 100644 --- a/plugins/flytekit-auto-cache/tests/test_auto_cache.py +++ b/plugins/flytekit-auto-cache/tests/test_auto_cache.py @@ -1,26 +1,9 @@ +from dummy_functions.dummy_function import dummy_function +from dummy_functions.dummy_function_comments_formatting_change import dummy_function as dummy_function_comments_formatting_change +from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change from flytekitplugins.auto_cache import CacheFunctionBody -# Dummy functions -def dummy_function(x: int, y: int) -> int: - result = x + y - return result - -def dummy_function_modified(x: int, y: int) -> int: - result = x * y - return result - - -def dummy_function_with_comments_and_formatting(x: int, y: int) -> int: - # Adding a new line 
here - result = ( - x + y - ) - # Another new line - return result - - - def test_get_version_with_same_function_and_salt(): """ Test that calling get_version with the same function and salt returns the same hash. @@ -49,32 +32,54 @@ def test_get_version_with_different_salt(): assert version1 != version2, f"Expected different hashes but got the same: {version1}" -def test_get_version_with_different_function_source(): + +def test_get_version_with_different_logic(): """ - Test that calling get_version with different function sources returns different hashes. + Test that functions with the same name but different logic produce different hashes. """ cache = CacheFunctionBody(salt="salt") - - # The hash should be different because the function source has changed version1 = cache.get_version(dummy_function) - version2 = cache.get_version(dummy_function_modified) + version2 = cache.get_version(dummy_function_logic_change) - assert version1 != version2, f"Expected different hashes but got the same: {version1} and {version2}" + assert version1 != version2, ( + f"Hashes should be different for functions with same name but different logic. " + f"Got {version1} and {version2}" + ) +# Test functions with different names but same logic +def function_one(x: int, y: int) -> int: + result = x + y + return result -def test_get_version_with_comments_and_formatting_changes(): +def function_two(x: int, y: int) -> int: + result = x + y + return result + +def test_get_version_with_different_function_names(): """ - Test that adding comments, changing formatting, or modifying the function signature - results in a different hash. + Test that functions with different names but same logic produce different hashes. 
""" - # Modify the function by adding comments and changing the formatting cache = CacheFunctionBody(salt="salt") - # Get the hash for the original dummy function - original_version = cache.get_version(dummy_function) + version1 = cache.get_version(function_one) + version2 = cache.get_version(function_two) - # Get the hash for the function with comments and formatting changes - version_with_comments_and_formatting = cache.get_version(dummy_function_with_comments_and_formatting) + assert version1 != version2, ( + f"Hashes should be different for functions with different names. " + f"Got {version1} and {version2}" + ) - # Assert that the hashes are different - assert original_version != version_with_comments_and_formatting, f"Expected different hashes but got the same: {original_version} and {version_with_comments_and_formatting}" +def test_get_version_with_formatting_changes(): + """ + Test that changing formatting and comments but keeping the same function name + results in the same hash. + """ + + cache = CacheFunctionBody(salt="salt") + version1 = cache.get_version(dummy_function) + version2 = cache.get_version(dummy_function_comments_formatting_change) + + assert version1 == version2, ( + f"Hashes should be the same for functions with same name but different formatting. 
" + f"Got {version1} and {version2}" + ) From 73d23270237db061d99485997606b07ec015df09 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 13 Nov 2024 14:48:36 -0800 Subject: [PATCH 04/16] add recursive cache for private modules --- .../flytekitplugins/auto_cache/__init__.py | 2 + .../auto_cache/cache_function_body.py | 9 +- .../auto_cache/cache_private_modules.py | 156 ++++++++++++++++++ .../tests/my_package/__init__.py | 0 .../tests/my_package/main.py | 22 +++ .../tests/my_package/module_a.py | 16 ++ .../tests/my_package/module_b.py | 11 ++ .../tests/my_package/module_c.py | 21 +++ .../tests/my_package/module_d.py | 2 + .../tests/my_package/my_dir/__init__.py | 1 + .../tests/my_package/my_dir/module_in_dir.py | 5 + ...st_auto_cache.py => test_function_body.py} | 0 .../tests/test_recursive.py | 29 ++++ 13 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/__init__.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/main.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/module_a.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/module_b.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/module_c.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/module_d.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py create mode 100644 plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py rename plugins/flytekit-auto-cache/tests/{test_auto_cache.py => test_function_body.py} (100%) create mode 100644 plugins/flytekit-auto-cache/tests/test_recursive.py diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py index 5872eda7ee..475bc7515f 100644 --- 
a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py @@ -8,6 +8,8 @@ :toctree: generated/ CacheFunctionBody + CachePrivateModules """ from .cache_function_body import CacheFunctionBody +from .cache_private_modules import CachePrivateModules diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py index 2e4472109d..6751b53c0f 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py @@ -1,6 +1,7 @@ import ast import hashlib import inspect +import textwrap from typing import Any, Callable @@ -27,6 +28,9 @@ def __init__(self, salt: str = "salt") -> None: self.salt = salt def get_version(self, func: Callable[..., Any]) -> str: + return self._get_version(func=func) + + def _get_version(self, func: Callable[..., Any]) -> str: """ Generate a version hash for the provided function by parsing its source code and adding a salt before applying the SHA-256 hash function. @@ -37,11 +41,12 @@ def get_version(self, func: Callable[..., Any]) -> str: Returns: str: The SHA-256 hash of the function's source code combined with the salt. 
""" - # Get the source code of the function + # Get the source code of the function and dedent source = inspect.getsource(func) + dedented_source = textwrap.dedent(source) # Parse the source code into an Abstract Syntax Tree (AST) - parsed_ast = ast.parse(source) + parsed_ast = ast.parse(dedented_source) # Convert the AST into a string representation (dump it) ast_bytes = ast.dump(parsed_ast).encode("utf-8") diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py new file mode 100644 index 0000000000..2296948ccc --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -0,0 +1,156 @@ +import ast +import hashlib +import importlib.util +import inspect +import sys +import textwrap +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, Set, Union + + +@contextmanager +def temporarily_add_to_syspath(path): + """Temporarily add the given path to sys.path.""" + sys.path.insert(0, str(path)) + try: + yield + finally: + sys.path.pop(0) + + +class CachePrivateModules: + def __init__(self, salt: str, root_dir: str): + self.salt = salt + self.root_dir = Path(root_dir).resolve() + self.dependencies = self._get_function_dependencies(func, set()) + + def get_version(self, func: Callable[..., Any]) -> str: + hash_components = [self._get_version(func)] + for dep in self.dependencies: + hash_components.append(self._get_version(dep)) + # Combine all component hashes into a single version hash + combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() + return combined_hash + + def _get_version(self, func: Callable[..., Any]) -> str: + source = inspect.getsource(func) + dedented_source = textwrap.dedent(source) + parsed_ast = ast.parse(dedented_source) + ast_bytes = ast.dump(parsed_ast).encode("utf-8") + combined_data = ast_bytes + 
self.salt.encode("utf-8") + return hashlib.sha256(combined_data).hexdigest() + + def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: + """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" + dependencies = set() + # Dedent the source code to handle class method indentation + source = textwrap.dedent(inspect.getsource(func)) + parsed_ast = ast.parse(source) + + # Build a locals dictionary for function-level imports + locals_dict = {} + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Import): + for alias in node.names: + module = importlib.import_module(alias.name) + locals_dict[alias.asname or alias.name] = module + elif isinstance(node, ast.ImportFrom): + module = importlib.import_module(node.module) + for alias in node.names: + imported_obj = getattr(module, alias.name, None) + if imported_obj: + locals_dict[alias.asname or alias.name] = imported_obj + + # Check each function call in the AST + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Call): + func_name = self._get_callable_name(node.func) + if func_name and func_name not in visited: + visited.add(func_name) + try: + # Attempt to resolve using locals first, then globals + func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) + if inspect.isclass(func_obj) and self._is_user_defined(func_obj): + # Add class methods as dependencies + for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): + if method not in visited: + visited.add(method.__qualname__) + dependencies.add(method) + dependencies.update(self._get_function_dependencies(method, visited)) + elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( + func_obj + ): + dependencies.add(func_obj) + dependencies.update(self._get_function_dependencies(func_obj, visited)) + except (NameError, AttributeError): + pass + 
return dependencies + + def _get_callable_name(self, node: ast.AST) -> Union[str, None]: + """Retrieve the name of the callable from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr + return None + + def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary.""" + parts = func_name.split(".") + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + obj = getattr(obj, part, None) + if obj is None: + break + return obj + + def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" + parts = func_name.split(".") + + # First, try resolving directly from globals_dict for a straightforward reference + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + if obj is None: + break + obj = getattr(obj, part, None) + + # If not found, iterate through modules in globals_dict and attempt resolution from them + if not callable(obj): + for module in globals_dict.values(): + if isinstance(module, type(sys)): # Check if the global value is a module + obj = module + for part in parts: + obj = getattr(obj, part, None) + if obj is None: + break + if callable(obj): # Exit if we successfully resolve the callable + break + obj = None # Reset if we didn't find the callable in this module + + # Return the callable if successfully resolved; otherwise, None + return obj if callable(obj) else None + + def _is_user_defined(self, obj: Any) -> bool: + """Check if a callable or class is user-defined within the package.""" + module_name = getattr(obj, "__module__", None) + return module_name and self._can_import_module(module_name) + + def _can_import_module(self, module_name: str) -> bool: 
+ """ + Check if a module with the given name can be imported from the specified root package directory. + + Args: + module_name (str): The module name to check for import. + + Returns: + bool: True if the module can be imported from the root directory, False otherwise. + """ + with temporarily_add_to_syspath(self.root_dir): + spec = importlib.util.find_spec(module_name) + if spec and spec.origin: + # Check if the module's path is within the root directory + module_path = Path(spec.origin).resolve() + return self.root_dir in module_path.parents + return False diff --git a/plugins/flytekit-auto-cache/tests/my_package/__init__.py b/plugins/flytekit-auto-cache/tests/my_package/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-auto-cache/tests/my_package/main.py b/plugins/flytekit-auto-cache/tests/my_package/main.py new file mode 100644 index 0000000000..80811d3f48 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/main.py @@ -0,0 +1,22 @@ +import sys +from pathlib import Path + +# Add the parent directory of `my_package` to sys.path +sys.path.append(str(Path(__file__).resolve().parent)) + +# import module_a +from module_a import helper_function +from my_dir.module_in_dir import helper_in_directory +from module_c import DummyClass +import pandas as pd # External library + +def my_main_function(): + print("Main function") + helper_in_directory() + helper_function() + # module_a.helper_function() + df = pd.DataFrame({"a": [1, 2, 3]}) + print(df) + dc = DummyClass() + print(dc) + dc.dummy_method() diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_a.py b/plugins/flytekit-auto-cache/tests/my_package/module_a.py new file mode 100644 index 0000000000..cc85c7f9ab --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/module_a.py @@ -0,0 +1,16 @@ +# import sys +# from pathlib import Path +# +# # Add the parent directory of `my_package` to sys.path +# 
sys.path.append(str(Path(__file__).resolve().parent)) + +# from module_b import another_helper +import module_b + +def helper_function(): + print("Helper function") + module_b.another_helper() + # another_helper() + +def unused_function(): + print("Unused function") diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_b.py b/plugins/flytekit-auto-cache/tests/my_package/module_b.py new file mode 100644 index 0000000000..ba8a6d21be --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/module_b.py @@ -0,0 +1,11 @@ +# import sys +# from pathlib import Path +# +# # Add the parent directory of `my_package` to sys.path +# sys.path.append(str(Path(__file__).resolve().parent)) + +from module_c import third_helper + +def another_helper(): + print("Another helper") + third_helper() diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_c.py b/plugins/flytekit-auto-cache/tests/my_package/module_c.py new file mode 100644 index 0000000000..eb144754fe --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/module_c.py @@ -0,0 +1,21 @@ +# import sys +# from pathlib import Path +# +# # Add the parent directory of `my_package` to sys.path +# sys.path.append(str(Path(__file__).resolve().parent)) + +# from module_d import fourth_helper +import my_dir + +def third_helper(): + print("Third helper") + +class DummyClass: + def dummy_method(self) -> str: + my_dir.module_in_dir.other_helper_in_directory() + return "Hello from dummy method!" 
+ + def other_dummy_method(self): + from module_d import fourth_helper + print("Other dummy method") + fourth_helper() diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_d.py b/plugins/flytekit-auto-cache/tests/my_package/module_d.py new file mode 100644 index 0000000000..92db897762 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/module_d.py @@ -0,0 +1,2 @@ +def fourth_helper(): + print("Fourth helper") diff --git a/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py b/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py new file mode 100644 index 0000000000..14fd12b8cc --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py @@ -0,0 +1 @@ +from .module_in_dir import other_helper_in_directory diff --git a/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py new file mode 100644 index 0000000000..7159a09643 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py @@ -0,0 +1,5 @@ +def helper_in_directory(): + print("Helper in directory") + +def other_helper_in_directory(): + print("Other helper in directory") diff --git a/plugins/flytekit-auto-cache/tests/test_auto_cache.py b/plugins/flytekit-auto-cache/tests/test_function_body.py similarity index 100% rename from plugins/flytekit-auto-cache/tests/test_auto_cache.py rename to plugins/flytekit-auto-cache/tests/test_function_body.py diff --git a/plugins/flytekit-auto-cache/tests/test_recursive.py b/plugins/flytekit-auto-cache/tests/test_recursive.py new file mode 100644 index 0000000000..913bcd921d --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/test_recursive.py @@ -0,0 +1,29 @@ +from flytekitplugins.auto_cache import CachePrivateModules +from my_package.main import my_main_function as func + + + +def test_dependencies(): + expected_dependencies = { + "module_a.helper_function", + "module_b.another_helper", + 
"module_c.DummyClass.dummy_method", + "module_c.DummyClass.other_dummy_method", + "module_c.third_helper", + "module_d.fourth_helper", + "my_dir.module_in_dir.helper_in_directory", + "my_dir.module_in_dir.other_helper_in_directory", + } + + cache = CachePrivateModules(salt="salt", root_dir="./my_package") + + actual_dependencies = { + f"{dep.__module__}.{dep.__qualname__}".replace("my_package.", "") + for dep in cache.dependencies + } + + assert actual_dependencies == expected_dependencies, ( + f"Dependencies do not match:\n" + f"Expected: {expected_dependencies}\n" + f"Actual: {actual_dependencies}" + ) From 6d5cdbf1ebf5cd7fec526711720f1aaf4b201562 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 13 Nov 2024 15:44:47 -0800 Subject: [PATCH 05/16] cleanup Signed-off-by: Daniel Sola --- .../auto_cache/cache_private_modules.py | 46 ++--- .../tests/recursive_modules_v3.py | 158 ++++++++++++++ .../tests/recursive_modules_v4.py | 194 ++++++++++++++++++ .../tests/test_recursive.py | 9 +- 4 files changed, 376 insertions(+), 31 deletions(-) create mode 100644 plugins/flytekit-auto-cache/tests/recursive_modules_v3.py create mode 100644 plugins/flytekit-auto-cache/tests/recursive_modules_v4.py diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index 2296948ccc..bee640b240 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -23,11 +23,11 @@ class CachePrivateModules: def __init__(self, salt: str, root_dir: str): self.salt = salt self.root_dir = Path(root_dir).resolve() - self.dependencies = self._get_function_dependencies(func, set()) def get_version(self, func: Callable[..., Any]) -> str: hash_components = [self._get_version(func)] - for dep in self.dependencies: + dependencies = self._get_function_dependencies(func, 
set()) + for dep in dependencies: hash_components.append(self._get_version(dep)) # Combine all component hashes into a single version hash combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() @@ -95,16 +95,6 @@ def _get_callable_name(self, node: ast.AST) -> Union[str, None]: return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr return None - def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary.""" - parts = func_name.split(".") - obj = globals_dict.get(parts[0], None) - for part in parts[1:]: - obj = getattr(obj, part, None) - if obj is None: - break - return obj - def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" parts = func_name.split(".") @@ -135,22 +125,24 @@ def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., def _is_user_defined(self, obj: Any) -> bool: """Check if a callable or class is user-defined within the package.""" module_name = getattr(obj, "__module__", None) - return module_name and self._can_import_module(module_name) - - def _can_import_module(self, module_name: str) -> bool: - """ - Check if a module with the given name can be imported from the specified root package directory. - - Args: - module_name (str): The module name to check for import. + if not module_name: + return False - Returns: - bool: True if the module can be imported from the root directory, False otherwise. 
- """ + # Retrieve the module specification to get its path with temporarily_add_to_syspath(self.root_dir): spec = importlib.util.find_spec(module_name) - if spec and spec.origin: - # Check if the module's path is within the root directory - module_path = Path(spec.origin).resolve() - return self.root_dir in module_path.parents + if not spec or not spec.origin: + return False + + module_path = Path(spec.origin).resolve() + + # Check if the module is within the root directory but not in site-packages + if self.root_dir in module_path.parents: + # Exclude standard library or site-packages by checking common paths + site_packages_paths = {Path(p).resolve() for p in sys.path if "site-packages" in p} + is_in_site_packages = any(sp in module_path.parents for sp in site_packages_paths) + + # Return True if within root_dir but not in site-packages + return not is_in_site_packages + return False diff --git a/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py b/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py new file mode 100644 index 0000000000..f522596d8d --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py @@ -0,0 +1,158 @@ +import ast +import hashlib +import importlib.util +import inspect +import sys +import textwrap +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Any, Set, Union + +@contextmanager +def temporarily_add_to_syspath(path): + """Temporarily add the given path to sys.path.""" + sys.path.insert(0, str(path)) + try: + yield + finally: + sys.path.pop(0) + + +class VersionHasher: + def __init__(self, salt: str, root_dir: str): + self.salt = salt + self.root_dir = Path(root_dir).resolve() + + def get_version(self, func: Callable[..., Any]) -> str: + source = inspect.getsource(func) + # Dedent the source code to handle class method indentation + dedented_source = textwrap.dedent(source) + parsed_ast = ast.parse(dedented_source) + ast_bytes = ast.dump(parsed_ast).encode("utf-8") 
+ combined_data = ast_bytes + self.salt.encode("utf-8") + return hashlib.sha256(combined_data).hexdigest() + + def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: + """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" + dependencies = set() + # Dedent the source code to handle class method indentation + source = textwrap.dedent(inspect.getsource(func)) + parsed_ast = ast.parse(source) + + # Build a locals dictionary for function-level imports + locals_dict = {} + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Import): + for alias in node.names: + module = importlib.import_module(alias.name) + locals_dict[alias.asname or alias.name] = module + elif isinstance(node, ast.ImportFrom): + module = importlib.import_module(node.module) + for alias in node.names: + imported_obj = getattr(module, alias.name, None) + if imported_obj: + locals_dict[alias.asname or alias.name] = imported_obj + + # Check each function call in the AST + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Call): + func_name = self._get_callable_name(node.func) + if func_name and func_name not in visited: + visited.add(func_name) + try: + # Attempt to resolve using locals first, then globals + func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) + if inspect.isclass(func_obj) and self._is_user_defined(func_obj): + # Add class methods as dependencies + for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): + if method not in visited: + visited.add(method.__qualname__) + dependencies.add(method) + dependencies.update(self._get_function_dependencies(method, visited)) + elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( + func_obj): + dependencies.add(func_obj) + dependencies.update(self._get_function_dependencies(func_obj, visited)) + except (NameError, 
AttributeError): + pass + return dependencies + def _get_callable_name(self, node: ast.AST) -> Union[str, None]: + """Retrieve the name of the callable from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr + return None + + def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary.""" + parts = func_name.split(".") + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + obj = getattr(obj, part, None) + if obj is None: + break + return obj + + def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" + parts = func_name.split(".") + + # First, try resolving directly from globals_dict for a straightforward reference + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + if obj is None: + break + obj = getattr(obj, part, None) + + # If not found, iterate through modules in globals_dict and attempt resolution from them + if not callable(obj): + for module in globals_dict.values(): + if isinstance(module, type(sys)): # Check if the global value is a module + obj = module + for part in parts: + obj = getattr(obj, part, None) + if obj is None: + break + if callable(obj): # Exit if we successfully resolve the callable + break + obj = None # Reset if we didn't find the callable in this module + + # Return the callable if successfully resolved; otherwise, None + return obj if callable(obj) else None + + def _is_user_defined(self, obj: Any) -> bool: + """Check if a callable or class is user-defined within the package.""" + module_name = getattr(obj, "__module__", None) + return module_name and self._can_import_module(module_name) + + def _can_import_module(self, 
module_name: str) -> bool: + """ + Check if a module with the given name can be imported from the specified root package directory. + + Args: + module_name (str): The module name to check for import. + + Returns: + bool: True if the module can be imported from the root directory, False otherwise. + """ + with temporarily_add_to_syspath(self.root_dir): + spec = importlib.util.find_spec(module_name) + if spec and spec.origin: + # Check if the module's path is within the root directory + module_path = Path(spec.origin).resolve() + return self.root_dir in module_path.parents + return False + +from my_package.main import my_main_function as func + +vh = VersionHasher(salt="salt", root_dir="./my_package") + +hash_components = [vh.get_version(func)] +# Gather dependencies and add their hashes +dependencies = vh._get_function_dependencies(func, set()) +for dep in dependencies: + hash_components.append(vh.get_version(dep)) +# Combine all component hashes into a single version hash +combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() +print(combined_hash) diff --git a/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py b/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py new file mode 100644 index 0000000000..019a5ba16a --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py @@ -0,0 +1,194 @@ +import ast +import hashlib +import importlib.util +import inspect +import sys +import textwrap +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Any, Set, Union + +@contextmanager +def temporarily_add_to_syspath(path): + """Temporarily add the given path to sys.path.""" + sys.path.insert(0, str(path)) + try: + yield + finally: + sys.path.pop(0) + + +class CachePrivateModules: + def __init__(self, salt: str, root_dir: str): + self.salt = salt + self.root_dir = Path(root_dir).resolve() + + def get_version(self, func: Callable[..., Any]) -> str: + hash_components = 
[self._get_version(func)] + dependencies = self._get_function_dependencies(func, set()) + for dep in dependencies: + hash_components.append(self._get_version(dep)) + # Combine all component hashes into a single version hash + combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() + return combined_hash + + def _get_version(self, func: Callable[..., Any]) -> str: + source = inspect.getsource(func) + dedented_source = textwrap.dedent(source) + parsed_ast = ast.parse(dedented_source) + ast_bytes = ast.dump(parsed_ast).encode("utf-8") + combined_data = ast_bytes + self.salt.encode("utf-8") + return hashlib.sha256(combined_data).hexdigest() + + def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: + """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" + dependencies = set() + # Dedent the source code to handle class method indentation + source = textwrap.dedent(inspect.getsource(func)) + parsed_ast = ast.parse(source) + + # Build a locals dictionary for function-level imports + locals_dict = {} + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Import): + for alias in node.names: + module = importlib.import_module(alias.name) + locals_dict[alias.asname or alias.name] = module + elif isinstance(node, ast.ImportFrom): + module = importlib.import_module(node.module) + for alias in node.names: + imported_obj = getattr(module, alias.name, None) + if imported_obj: + locals_dict[alias.asname or alias.name] = imported_obj + + # Check each function call in the AST + for node in ast.walk(parsed_ast): + if isinstance(node, ast.Call): + func_name = self._get_callable_name(node.func) + if func_name and func_name not in visited: + visited.add(func_name) + try: + # Attempt to resolve using locals first, then globals + func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) + if 
inspect.isclass(func_obj) and self._is_user_defined(func_obj): + # Add class methods as dependencies + for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): + if method not in visited: + visited.add(method.__qualname__) + dependencies.add(method) + dependencies.update(self._get_function_dependencies(method, visited)) + elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( + func_obj + ): + dependencies.add(func_obj) + dependencies.update(self._get_function_dependencies(func_obj, visited)) + except (NameError, AttributeError): + pass + return dependencies + + def _get_callable_name(self, node: ast.AST) -> Union[str, None]: + """Retrieve the name of the callable from an AST node.""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr + return None + + def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary.""" + parts = func_name.split(".") + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + obj = getattr(obj, part, None) + if obj is None: + break + return obj + + def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" + parts = func_name.split(".") + + # First, try resolving directly from globals_dict for a straightforward reference + obj = globals_dict.get(parts[0], None) + for part in parts[1:]: + if obj is None: + break + obj = getattr(obj, part, None) + + # If not found, iterate through modules in globals_dict and attempt resolution from them + if not callable(obj): + for module in globals_dict.values(): + if isinstance(module, type(sys)): # Check if the global value is a module + obj = module + for part in 
parts: + obj = getattr(obj, part, None) + if obj is None: + break + if callable(obj): # Exit if we successfully resolve the callable + break + obj = None # Reset if we didn't find the callable in this module + + # Return the callable if successfully resolved; otherwise, None + return obj if callable(obj) else None + + def _is_user_defined(self, obj: Any) -> bool: + """Check if a callable or class is user-defined within the package.""" + module_name = getattr(obj, "__module__", None) + if not module_name: + return False + + # Retrieve the module specification to get its path + with temporarily_add_to_syspath(self.root_dir): + spec = importlib.util.find_spec(module_name) + if not spec or not spec.origin: + return False + + module_path = Path(spec.origin).resolve() + + # Check if the module is within the root directory but not in site-packages + if self.root_dir in module_path.parents: + # Exclude standard library or site-packages by checking common paths + site_packages_paths = {Path(p).resolve() for p in sys.path if 'site-packages' in p} + is_in_site_packages = any(sp in module_path.parents for sp in site_packages_paths) + + # Return True if within root_dir but not in site-packages + return not is_in_site_packages + + return False + + def __is_user_defined(self, obj: Any) -> bool: + """Check if a callable or class is user-defined within the package.""" + module_name = getattr(obj, "__module__", None) + return module_name and self._can_import_module(module_name) + + def _can_import_module(self, module_name: str) -> bool: + """ + Check if a module with the given name can be imported from the specified root package directory. + + Args: + module_name (str): The module name to check for import. + + Returns: + bool: True if the module can be imported from the root directory, False otherwise. 
+ """ + with temporarily_add_to_syspath(self.root_dir): + spec = importlib.util.find_spec(module_name) + if spec and spec.origin: + # Check if the module's path is within the root directory + module_path = Path(spec.origin).resolve() + return self.root_dir in module_path.parents + return False + + +from my_package.main import my_main_function as func + +vh = CachePrivateModules(salt="salt", root_dir="./my_package") + +hash_components = [vh._get_version(func)] +# Gather dependencies and add their hashes +dependencies = vh._get_function_dependencies(func, set()) +for dep in dependencies: + hash_components.append(vh._get_version(dep)) +# Combine all component hashes into a single version hash +combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() +print(combined_hash) diff --git a/plugins/flytekit-auto-cache/tests/test_recursive.py b/plugins/flytekit-auto-cache/tests/test_recursive.py index 913bcd921d..1dbbf56a83 100644 --- a/plugins/flytekit-auto-cache/tests/test_recursive.py +++ b/plugins/flytekit-auto-cache/tests/test_recursive.py @@ -16,14 +16,15 @@ def test_dependencies(): } cache = CachePrivateModules(salt="salt", root_dir="./my_package") + actual_dependencies = cache._get_function_dependencies(func, set()) - actual_dependencies = { + actual_dependencies_str = { f"{dep.__module__}.{dep.__qualname__}".replace("my_package.", "") - for dep in cache.dependencies + for dep in actual_dependencies } - assert actual_dependencies == expected_dependencies, ( + assert actual_dependencies_str == expected_dependencies, ( f"Dependencies do not match:\n" f"Expected: {expected_dependencies}\n" - f"Actual: {actual_dependencies}" + f"Actual: {actual_dependencies_str}" ) From f76f59a752e7639117279c865031c5764ad4571e Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 13 Nov 2024 15:45:39 -0800 Subject: [PATCH 06/16] cleanup Signed-off-by: Daniel Sola --- .../tests/recursive_modules_v3.py | 158 -------------- .../tests/recursive_modules_v4.py | 194 
------------------ 2 files changed, 352 deletions(-) delete mode 100644 plugins/flytekit-auto-cache/tests/recursive_modules_v3.py delete mode 100644 plugins/flytekit-auto-cache/tests/recursive_modules_v4.py diff --git a/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py b/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py deleted file mode 100644 index f522596d8d..0000000000 --- a/plugins/flytekit-auto-cache/tests/recursive_modules_v3.py +++ /dev/null @@ -1,158 +0,0 @@ -import ast -import hashlib -import importlib.util -import inspect -import sys -import textwrap -from contextlib import contextmanager -from pathlib import Path -from typing import Callable, Any, Set, Union - -@contextmanager -def temporarily_add_to_syspath(path): - """Temporarily add the given path to sys.path.""" - sys.path.insert(0, str(path)) - try: - yield - finally: - sys.path.pop(0) - - -class VersionHasher: - def __init__(self, salt: str, root_dir: str): - self.salt = salt - self.root_dir = Path(root_dir).resolve() - - def get_version(self, func: Callable[..., Any]) -> str: - source = inspect.getsource(func) - # Dedent the source code to handle class method indentation - dedented_source = textwrap.dedent(source) - parsed_ast = ast.parse(dedented_source) - ast_bytes = ast.dump(parsed_ast).encode("utf-8") - combined_data = ast_bytes + self.salt.encode("utf-8") - return hashlib.sha256(combined_data).hexdigest() - - def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: - """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" - dependencies = set() - # Dedent the source code to handle class method indentation - source = textwrap.dedent(inspect.getsource(func)) - parsed_ast = ast.parse(source) - - # Build a locals dictionary for function-level imports - locals_dict = {} - for node in ast.walk(parsed_ast): - if isinstance(node, ast.Import): - for alias in node.names: - 
module = importlib.import_module(alias.name) - locals_dict[alias.asname or alias.name] = module - elif isinstance(node, ast.ImportFrom): - module = importlib.import_module(node.module) - for alias in node.names: - imported_obj = getattr(module, alias.name, None) - if imported_obj: - locals_dict[alias.asname or alias.name] = imported_obj - - # Check each function call in the AST - for node in ast.walk(parsed_ast): - if isinstance(node, ast.Call): - func_name = self._get_callable_name(node.func) - if func_name and func_name not in visited: - visited.add(func_name) - try: - # Attempt to resolve using locals first, then globals - func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) - if inspect.isclass(func_obj) and self._is_user_defined(func_obj): - # Add class methods as dependencies - for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): - if method not in visited: - visited.add(method.__qualname__) - dependencies.add(method) - dependencies.update(self._get_function_dependencies(method, visited)) - elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( - func_obj): - dependencies.add(func_obj) - dependencies.update(self._get_function_dependencies(func_obj, visited)) - except (NameError, AttributeError): - pass - return dependencies - def _get_callable_name(self, node: ast.AST) -> Union[str, None]: - """Retrieve the name of the callable from an AST node.""" - if isinstance(node, ast.Name): - return node.id - elif isinstance(node, ast.Attribute): - return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr - return None - - def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary.""" - parts = func_name.split(".") - obj = globals_dict.get(parts[0], None) - for part in parts[1:]: - obj = getattr(obj, part, None) - if obj is None: - 
break - return obj - - def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" - parts = func_name.split(".") - - # First, try resolving directly from globals_dict for a straightforward reference - obj = globals_dict.get(parts[0], None) - for part in parts[1:]: - if obj is None: - break - obj = getattr(obj, part, None) - - # If not found, iterate through modules in globals_dict and attempt resolution from them - if not callable(obj): - for module in globals_dict.values(): - if isinstance(module, type(sys)): # Check if the global value is a module - obj = module - for part in parts: - obj = getattr(obj, part, None) - if obj is None: - break - if callable(obj): # Exit if we successfully resolve the callable - break - obj = None # Reset if we didn't find the callable in this module - - # Return the callable if successfully resolved; otherwise, None - return obj if callable(obj) else None - - def _is_user_defined(self, obj: Any) -> bool: - """Check if a callable or class is user-defined within the package.""" - module_name = getattr(obj, "__module__", None) - return module_name and self._can_import_module(module_name) - - def _can_import_module(self, module_name: str) -> bool: - """ - Check if a module with the given name can be imported from the specified root package directory. - - Args: - module_name (str): The module name to check for import. - - Returns: - bool: True if the module can be imported from the root directory, False otherwise. 
- """ - with temporarily_add_to_syspath(self.root_dir): - spec = importlib.util.find_spec(module_name) - if spec and spec.origin: - # Check if the module's path is within the root directory - module_path = Path(spec.origin).resolve() - return self.root_dir in module_path.parents - return False - -from my_package.main import my_main_function as func - -vh = VersionHasher(salt="salt", root_dir="./my_package") - -hash_components = [vh.get_version(func)] -# Gather dependencies and add their hashes -dependencies = vh._get_function_dependencies(func, set()) -for dep in dependencies: - hash_components.append(vh.get_version(dep)) -# Combine all component hashes into a single version hash -combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() -print(combined_hash) diff --git a/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py b/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py deleted file mode 100644 index 019a5ba16a..0000000000 --- a/plugins/flytekit-auto-cache/tests/recursive_modules_v4.py +++ /dev/null @@ -1,194 +0,0 @@ -import ast -import hashlib -import importlib.util -import inspect -import sys -import textwrap -from contextlib import contextmanager -from pathlib import Path -from typing import Callable, Any, Set, Union - -@contextmanager -def temporarily_add_to_syspath(path): - """Temporarily add the given path to sys.path.""" - sys.path.insert(0, str(path)) - try: - yield - finally: - sys.path.pop(0) - - -class CachePrivateModules: - def __init__(self, salt: str, root_dir: str): - self.salt = salt - self.root_dir = Path(root_dir).resolve() - - def get_version(self, func: Callable[..., Any]) -> str: - hash_components = [self._get_version(func)] - dependencies = self._get_function_dependencies(func, set()) - for dep in dependencies: - hash_components.append(self._get_version(dep)) - # Combine all component hashes into a single version hash - combined_hash = 
hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() - return combined_hash - - def _get_version(self, func: Callable[..., Any]) -> str: - source = inspect.getsource(func) - dedented_source = textwrap.dedent(source) - parsed_ast = ast.parse(dedented_source) - ast_bytes = ast.dump(parsed_ast).encode("utf-8") - combined_data = ast_bytes + self.salt.encode("utf-8") - return hashlib.sha256(combined_data).hexdigest() - - def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: - """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" - dependencies = set() - # Dedent the source code to handle class method indentation - source = textwrap.dedent(inspect.getsource(func)) - parsed_ast = ast.parse(source) - - # Build a locals dictionary for function-level imports - locals_dict = {} - for node in ast.walk(parsed_ast): - if isinstance(node, ast.Import): - for alias in node.names: - module = importlib.import_module(alias.name) - locals_dict[alias.asname or alias.name] = module - elif isinstance(node, ast.ImportFrom): - module = importlib.import_module(node.module) - for alias in node.names: - imported_obj = getattr(module, alias.name, None) - if imported_obj: - locals_dict[alias.asname or alias.name] = imported_obj - - # Check each function call in the AST - for node in ast.walk(parsed_ast): - if isinstance(node, ast.Call): - func_name = self._get_callable_name(node.func) - if func_name and func_name not in visited: - visited.add(func_name) - try: - # Attempt to resolve using locals first, then globals - func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) - if inspect.isclass(func_obj) and self._is_user_defined(func_obj): - # Add class methods as dependencies - for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): - if method not in visited: - visited.add(method.__qualname__) - 
dependencies.add(method) - dependencies.update(self._get_function_dependencies(method, visited)) - elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( - func_obj - ): - dependencies.add(func_obj) - dependencies.update(self._get_function_dependencies(func_obj, visited)) - except (NameError, AttributeError): - pass - return dependencies - - def _get_callable_name(self, node: ast.AST) -> Union[str, None]: - """Retrieve the name of the callable from an AST node.""" - if isinstance(node, ast.Name): - return node.id - elif isinstance(node, ast.Attribute): - return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr - return None - - def __resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary.""" - parts = func_name.split(".") - obj = globals_dict.get(parts[0], None) - for part in parts[1:]: - obj = getattr(obj, part, None) - if obj is None: - break - return obj - - def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" - parts = func_name.split(".") - - # First, try resolving directly from globals_dict for a straightforward reference - obj = globals_dict.get(parts[0], None) - for part in parts[1:]: - if obj is None: - break - obj = getattr(obj, part, None) - - # If not found, iterate through modules in globals_dict and attempt resolution from them - if not callable(obj): - for module in globals_dict.values(): - if isinstance(module, type(sys)): # Check if the global value is a module - obj = module - for part in parts: - obj = getattr(obj, part, None) - if obj is None: - break - if callable(obj): # Exit if we successfully resolve the callable - break - obj = None # Reset if we didn't find the callable in this module - - # Return the callable if successfully 
resolved; otherwise, None - return obj if callable(obj) else None - - def _is_user_defined(self, obj: Any) -> bool: - """Check if a callable or class is user-defined within the package.""" - module_name = getattr(obj, "__module__", None) - if not module_name: - return False - - # Retrieve the module specification to get its path - with temporarily_add_to_syspath(self.root_dir): - spec = importlib.util.find_spec(module_name) - if not spec or not spec.origin: - return False - - module_path = Path(spec.origin).resolve() - - # Check if the module is within the root directory but not in site-packages - if self.root_dir in module_path.parents: - # Exclude standard library or site-packages by checking common paths - site_packages_paths = {Path(p).resolve() for p in sys.path if 'site-packages' in p} - is_in_site_packages = any(sp in module_path.parents for sp in site_packages_paths) - - # Return True if within root_dir but not in site-packages - return not is_in_site_packages - - return False - - def __is_user_defined(self, obj: Any) -> bool: - """Check if a callable or class is user-defined within the package.""" - module_name = getattr(obj, "__module__", None) - return module_name and self._can_import_module(module_name) - - def _can_import_module(self, module_name: str) -> bool: - """ - Check if a module with the given name can be imported from the specified root package directory. - - Args: - module_name (str): The module name to check for import. - - Returns: - bool: True if the module can be imported from the root directory, False otherwise. 
- """ - with temporarily_add_to_syspath(self.root_dir): - spec = importlib.util.find_spec(module_name) - if spec and spec.origin: - # Check if the module's path is within the root directory - module_path = Path(spec.origin).resolve() - return self.root_dir in module_path.parents - return False - - -from my_package.main import my_main_function as func - -vh = CachePrivateModules(salt="salt", root_dir="./my_package") - -hash_components = [vh._get_version(func)] -# Gather dependencies and add their hashes -dependencies = vh._get_function_dependencies(func, set()) -for dep in dependencies: - hash_components.append(vh._get_version(dep)) -# Combine all component hashes into a single version hash -combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() -print(combined_hash) From a5fc1bb65a12631e9fdd4c2ddb4ab64ee7a4000f Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 13 Nov 2024 15:53:40 -0800 Subject: [PATCH 07/16] cleanup Signed-off-by: Daniel Sola --- plugins/flytekit-auto-cache/tests/my_package/main.py | 2 -- plugins/flytekit-auto-cache/tests/my_package/module_a.py | 8 -------- plugins/flytekit-auto-cache/tests/my_package/module_b.py | 6 ------ plugins/flytekit-auto-cache/tests/my_package/module_c.py | 7 ------- 4 files changed, 23 deletions(-) diff --git a/plugins/flytekit-auto-cache/tests/my_package/main.py b/plugins/flytekit-auto-cache/tests/my_package/main.py index 80811d3f48..fa113e44f4 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/main.py +++ b/plugins/flytekit-auto-cache/tests/my_package/main.py @@ -4,7 +4,6 @@ # Add the parent directory of `my_package` to sys.path sys.path.append(str(Path(__file__).resolve().parent)) -# import module_a from module_a import helper_function from my_dir.module_in_dir import helper_in_directory from module_c import DummyClass @@ -14,7 +13,6 @@ def my_main_function(): print("Main function") helper_in_directory() helper_function() - # module_a.helper_function() df = pd.DataFrame({"a": 
[1, 2, 3]}) print(df) dc = DummyClass() diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_a.py b/plugins/flytekit-auto-cache/tests/my_package/module_a.py index cc85c7f9ab..3c6d29ec13 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_a.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_a.py @@ -1,16 +1,8 @@ -# import sys -# from pathlib import Path -# -# # Add the parent directory of `my_package` to sys.path -# sys.path.append(str(Path(__file__).resolve().parent)) - -# from module_b import another_helper import module_b def helper_function(): print("Helper function") module_b.another_helper() - # another_helper() def unused_function(): print("Unused function") diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_b.py b/plugins/flytekit-auto-cache/tests/my_package/module_b.py index ba8a6d21be..641810f068 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_b.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_b.py @@ -1,9 +1,3 @@ -# import sys -# from pathlib import Path -# -# # Add the parent directory of `my_package` to sys.path -# sys.path.append(str(Path(__file__).resolve().parent)) - from module_c import third_helper def another_helper(): diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_c.py b/plugins/flytekit-auto-cache/tests/my_package/module_c.py index eb144754fe..7c711d58f4 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_c.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_c.py @@ -1,10 +1,3 @@ -# import sys -# from pathlib import Path -# -# # Add the parent directory of `my_package` to sys.path -# sys.path.append(str(Path(__file__).resolve().parent)) - -# from module_d import fourth_helper import my_dir def third_helper(): From ff3af99d7bbc0b6fb5992d3d6d26e71e1f8c0e7a Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Thu, 21 Nov 2024 14:11:24 -0800 Subject: [PATCH 08/16] add image hash Signed-off-by: Daniel Sola --- flytekit/core/auto_cache.py 
| 71 ++++++++++---- flytekit/core/task.py | 36 +++---- flytekit/core/workflow.py | 10 +- .../auto_cache/cache_function_body.py | 8 +- .../flytekitplugins/auto_cache/cache_image.py | 22 +++++ .../auto_cache/cache_private_modules.py | 9 +- .../tests/test_function_body.py | 36 +++++-- .../flytekit-auto-cache/tests/test_image.py | 93 +++++++++++++++++++ 8 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py create mode 100644 plugins/flytekit-auto-cache/tests/test_image.py diff --git a/flytekit/core/auto_cache.py b/flytekit/core/auto_cache.py index adf4741a0b..2915abb729 100644 --- a/flytekit/core/auto_cache.py +++ b/flytekit/core/auto_cache.py @@ -1,4 +1,21 @@ -from typing import Any, Callable, Protocol, runtime_checkable +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable + +from flytekit.image_spec.image_spec import ImageSpec + + +@dataclass +class VersionParameters: + """ + Parameters used for version hash generation. + + Args: + func (Optional[Callable]): The function to generate a version for + container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for + """ + + func: Optional[Callable[..., Any]] = None + container_image: Optional[Union[str, ImageSpec]] = None @runtime_checkable @@ -6,32 +23,54 @@ class AutoCache(Protocol): """ A protocol that defines the interface for a caching mechanism that generates a version hash of a function based on its source code. - - Attributes: - salt (str): A string used to add uniqueness to the generated hash. Default is "salt". - - Methods: - get_version(func: Callable[..., Any]) -> str: - Given a function, generates a version hash based on its source code and the salt. 
""" - def __init__(self, salt: str = "salt") -> None: + salt: str + + def get_version(self, params: VersionParameters) -> str: """ - Initialize the AutoCache instance with a salt value. + Generate a version hash based on the provided parameters. Args: - salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". + params (VersionParameters): Parameters to use for hash generation. + + Returns: + str: The generated version hash. """ + ... + + +class CachePolicy: + """ + A class that combines multiple caching mechanisms to generate a version hash. + + Args: + *cache_objects: Variable number of AutoCache instances + salt: Optional salt string to add uniqueness to the hash + """ + + def __init__(self, *cache_objects: AutoCache, salt: str = "") -> None: + self.cache_objects = cache_objects self.salt = salt - def get_version(self, func: Callable[..., Any]) -> str: + def get_version(self, params: VersionParameters) -> str: """ - Generate a version hash for the provided function. + Generate a version hash using all cache objects. Args: - func (Callable[..., Any]): A callable function whose version hash needs to be generated. + params (VersionParameters): Parameters to use for hash generation. Returns: - str: The SHA-256 hash of the function's source code combined with the salt. + str: The combined hash from all cache objects. """ - ... 
+ task_hash = "" + for cache_instance in self.cache_objects: + # Apply the policy's salt to each cache instance + cache_instance.salt = self.salt + task_hash += cache_instance.get_version(params) + + # Generate SHA-256 hash + import hashlib + + hash_obj = hashlib.sha256(task_hash.encode()) + return hash_obj.hexdigest() diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 7519f341d6..cfd509b278 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,7 +5,7 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload -from flytekit.core.auto_cache import AutoCache +from flytekit.core.auto_cache import CachePolicy, VersionParameters from flytekit.core.utils import str2bool try: @@ -100,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction def task( _task_function: None = ..., task_config: Optional[T] = ..., - cache: Union[bool, list[AutoCache]] = ..., + cache: Union[bool, CachePolicy] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -136,9 +136,9 @@ def task( @overload def task( - _task_function: Callable[P, FuncOut], + _task_function: Callable[..., FuncOut], task_config: Optional[T] = ..., - cache: Union[bool, list[AutoCache]] = ..., + cache: Union[bool, CachePolicy] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -169,13 +169,13 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., -) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... +) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ... 
def task( - _task_function: Optional[Callable[P, FuncOut]] = None, + _task_function: Optional[Callable[..., FuncOut]] = None, task_config: Optional[T] = None, - cache: Union[bool, list[AutoCache]] = False, + cache: Union[bool, CachePolicy] = False, cache_serialize: bool = False, cache_version: str = "", cache_ignore_input_vars: Tuple[str, ...] = (), @@ -213,8 +213,8 @@ def task( pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, ) -> Union[ - Callable[P, FuncOut], - Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], + Callable[..., FuncOut], + Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], ]: """ @@ -343,17 +343,19 @@ def launch_dynamically(): :param accelerator: The accelerator to use for this task. """ - def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: - if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache): - cache_versions = [item.get_version() for item in cache] - task_hash = "".join(cache_versions) + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: + if isinstance(cache, CachePolicy): + params = VersionParameters(func=fn, container_image=container_image) + cache_version_val = cache.get_version(params=params) + cache_val = True else: - task_hash = "" + cache_val = cache + cache_version_val = cache_version _metadata = TaskMetadata( - cache=cache, + cache=cache_val, cache_serialize=cache_serialize, - cache_version=cache_version if not task_hash else task_hash, + cache_version=cache_version_val, cache_ignore_input_vars=cache_ignore_input_vars, retries=retries, interruptible=interruptible, @@ -439,7 +441,7 @@ def wrapper(fn) -> ReferenceTask: return wrapper -def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]: +def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]: """ Decorates the task with additional functionality if necessary. 
diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index de0f620e96..07f99103e2 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -843,23 +843,23 @@ def workflow( @overload def workflow( - _workflow_function: Callable[P, FuncOut], + _workflow_function: Callable[..., FuncOut], failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., default_options: Optional[Options] = ..., -) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ... +) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ... def workflow( - _workflow_function: Optional[Callable[P, FuncOut]] = None, + _workflow_function: Optional[Callable[..., FuncOut]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, on_failure: Optional[Union[WorkflowBase, Task]] = None, docs: Optional[Documentation] = None, default_options: Optional[Options] = None, -) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: +) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. @@ -894,7 +894,7 @@ def workflow( the labels and annotations are allowed to be set as defaults. 
""" - def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow: + def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py index 6751b53c0f..830aee94f0 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py @@ -4,6 +4,8 @@ import textwrap from typing import Any, Callable +from flytekit.core.auto_cache import VersionParameters + class CacheFunctionBody: """ @@ -27,8 +29,10 @@ def __init__(self, salt: str = "salt") -> None: """ self.salt = salt - def get_version(self, func: Callable[..., Any]) -> str: - return self._get_version(func=func) + def get_version(self, params: VersionParameters) -> str: + if params.func is None: + raise ValueError("Function-based cache requires a function parameter") + return self._get_version(func=params.func) def _get_version(self, func: Callable[..., Any]) -> str: """ diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py new file mode 100644 index 0000000000..cf57b645ce --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py @@ -0,0 +1,22 @@ +import hashlib + +from flytekit.core.auto_cache import VersionParameters +from flytekit.image_spec.image_spec import ImageSpec + + +class CacheImage: + def __init__(self, salt: str): + self.salt = salt + + def get_version(self, params: VersionParameters) -> str: + if params.container_image is None: + raise ValueError("Image-based cache requires a container_image parameter") + + # If the image 
is an ImageSpec, combine tag with salt + if isinstance(params.container_image, ImageSpec): + combined = params.container_image.tag + self.salt + return hashlib.sha256(combined.encode("utf-8")).hexdigest() + + # If the image is a string, combine with salt + combined = params.container_image + self.salt + return hashlib.sha256(combined.encode("utf-8")).hexdigest() diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index bee640b240..8166c8811d 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import Any, Callable, Set, Union +from flytekit.core.auto_cache import VersionParameters + @contextmanager def temporarily_add_to_syspath(path): @@ -24,8 +26,11 @@ def __init__(self, salt: str, root_dir: str): self.salt = salt self.root_dir = Path(root_dir).resolve() - def get_version(self, func: Callable[..., Any]) -> str: - hash_components = [self._get_version(func)] + def get_version(self, params: VersionParameters) -> str: + if params.func is None: + raise ValueError("Function-based cache requires a function parameter") + + hash_components = [self._get_version(params.func)] dependencies = self._get_function_dependencies(func, set()) for dep in dependencies: hash_components.append(self._get_version(dep)) diff --git a/plugins/flytekit-auto-cache/tests/test_function_body.py b/plugins/flytekit-auto-cache/tests/test_function_body.py index efebe8ad0a..351b7e62af 100644 --- a/plugins/flytekit-auto-cache/tests/test_function_body.py +++ b/plugins/flytekit-auto-cache/tests/test_function_body.py @@ -1,6 +1,7 @@ from dummy_functions.dummy_function import dummy_function from dummy_functions.dummy_function_comments_formatting_change import dummy_function as 
dummy_function_comments_formatting_change from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change +from flytekit.core.auto_cache import VersionParameters from flytekitplugins.auto_cache import CacheFunctionBody @@ -11,9 +12,11 @@ def test_get_version_with_same_function_and_salt(): cache1 = CacheFunctionBody(salt="salt") cache2 = CacheFunctionBody(salt="salt") + params = VersionParameters(func=dummy_function) + # Both calls should return the same hash since the function and salt are the same - version1 = cache1.get_version(dummy_function) - version2 = cache2.get_version(dummy_function) + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) assert version1 == version2, f"Expected {version1}, but got {version2}" @@ -25,9 +28,11 @@ def test_get_version_with_different_salt(): cache1 = CacheFunctionBody(salt="salt1") cache2 = CacheFunctionBody(salt="salt2") + params = VersionParameters(func=dummy_function) + # The hashes should be different because the salts are different - version1 = cache1.get_version(dummy_function) - version2 = cache2.get_version(dummy_function) + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) assert version1 != version2, f"Expected different hashes but got the same: {version1}" @@ -38,8 +43,12 @@ def test_get_version_with_different_logic(): Test that functions with the same name but different logic produce different hashes. """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(dummy_function) - version2 = cache.get_version(dummy_function_logic_change) + + params1 = VersionParameters(func=dummy_function) + params2 = VersionParameters(func=dummy_function_logic_change) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 != version2, ( f"Hashes should be different for functions with same name but different logic. 
" @@ -61,8 +70,11 @@ def test_get_version_with_different_function_names(): """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(function_one) - version2 = cache.get_version(function_two) + params1 = VersionParameters(func=function_one) + params2 = VersionParameters(func=function_two) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 != version2, ( f"Hashes should be different for functions with different names. " @@ -76,8 +88,12 @@ def test_get_version_with_formatting_changes(): """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(dummy_function) - version2 = cache.get_version(dummy_function_comments_formatting_change) + + params1 = VersionParameters(func=dummy_function) + params2 = VersionParameters(func=dummy_function_comments_formatting_change) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 == version2, ( f"Hashes should be the same for functions with same name but different formatting. " diff --git a/plugins/flytekit-auto-cache/tests/test_image.py b/plugins/flytekit-auto-cache/tests/test_image.py new file mode 100644 index 0000000000..8a9020ab5c --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/test_image.py @@ -0,0 +1,93 @@ +import pytest # type: ignore +import hashlib +from flytekit.core.auto_cache import VersionParameters +from flytekit.image_spec.image_spec import ImageSpec +from flytekitplugins.auto_cache import CacheImage + + +def test_get_version_with_same_image_and_salt(): + """ + Test that calling get_version with the same image and salt returns the same hash. 
+ """ + cache1 = CacheImage(salt="salt") + cache2 = CacheImage(salt="salt") + + params = VersionParameters(container_image="python:3.9") + + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) + + assert version1 == version2, f"Expected {version1}, but got {version2}" + + +def test_get_version_with_different_salt(): + """ + Test that calling get_version with different salts returns different hashes for the same image. + """ + cache1 = CacheImage(salt="salt1") + cache2 = CacheImage(salt="salt2") + + params = VersionParameters(container_image="python:3.9") + + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) + + assert version1 != version2, f"Expected different hashes but got the same: {version1}" + + +def test_get_version_with_different_images(): + """ + Test that different images produce different hashes. + """ + cache = CacheImage(salt="salt") + + params1 = VersionParameters(container_image="python:3.9") + params2 = VersionParameters(container_image="python:3.8") + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) + + assert version1 != version2, ( + f"Hashes should be different for different images. " + f"Got {version1} and {version2}" + ) + + +def test_get_version_with_image_spec(): + """ + Test that ImageSpec objects use their tag directly. + """ + cache = CacheImage(salt="salt") + + image_spec = ImageSpec( + name="my-image", + registry="my-registry", + tag="v1.0.0" + ) + params = VersionParameters(container_image=image_spec) + + version = cache.get_version(params) + expected = hashlib.sha256("v1.0.0".encode("utf-8")).hexdigest() + assert version == expected, f"Expected {expected}, but got {version}" + + +def test_get_version_without_image(): + """ + Test that calling get_version without an image raises ValueError. 
+ """ + cache = CacheImage(salt="salt") + params = VersionParameters(func=lambda x: x) # Only providing func, no image + + with pytest.raises(ValueError, match="Image-based cache requires a container_image parameter"): + cache.get_version(params) + + +def test_get_version_with_none_image(): + """ + Test that calling get_version with None image raises ValueError. + """ + cache = CacheImage(salt="salt") + params = VersionParameters(container_image=None) + + with pytest.raises(ValueError, match="Image-based cache requires a container_image parameter"): + cache.get_version(params) From 7d0637015b2f00e8eeabe08396a593cdb7949331 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Fri, 22 Nov 2024 13:27:34 -0800 Subject: [PATCH 09/16] formatting --- .../flytekitplugins/auto_cache/cache_private_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index 8166c8811d..64781f9740 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -31,7 +31,7 @@ def get_version(self, params: VersionParameters) -> str: raise ValueError("Function-based cache requires a function parameter") hash_components = [self._get_version(params.func)] - dependencies = self._get_function_dependencies(func, set()) + dependencies = self._get_function_dependencies(params.func, set()) for dep in dependencies: hash_components.append(self._get_version(dep)) # Combine all component hashes into a single version hash From d5d9576f80ef5e23de0a1c27d3c616b1188dd341 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Wed, 27 Nov 2024 15:26:28 -0800 Subject: [PATCH 10/16] test for internal versions Signed-off-by: Daniel Sola --- .../flytekitplugins/auto_cache/__init__.py | 2 + 
.../auto_cache/cache_external_dependencies.py | 119 ++++++++++++++++++ .../tests/requirements-test.txt | 20 +++ .../tests/test_external_dependencies.py | 45 +++++++ .../tests/verify_versions.py | 42 +++++++ 5 files changed, 228 insertions(+) create mode 100644 plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py create mode 100644 plugins/flytekit-auto-cache/tests/requirements-test.txt create mode 100644 plugins/flytekit-auto-cache/tests/test_external_dependencies.py create mode 100644 plugins/flytekit-auto-cache/tests/verify_versions.py diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py index 475bc7515f..9c80ba89ea 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py @@ -11,5 +11,7 @@ CachePrivateModules """ +from .cache_external_dependencies import CacheExternalDependencies from .cache_function_body import CacheFunctionBody +from .cache_image import CacheImage from .cache_private_modules import CachePrivateModules diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py new file mode 100644 index 0000000000..3814143530 --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py @@ -0,0 +1,119 @@ +import hashlib +import importlib +import sys +from pathlib import Path +from typing import Any, Optional + +import click +from flytekitplugins.auto_cache.cache_private_modules import CachePrivateModules, temporarily_add_to_syspath + +from flytekit.core.auto_cache import VersionParameters + + +class CacheExternalDependencies(CachePrivateModules): + """ + A cache implementation that tracks external package dependencies and their versions. 
+ Inherits the dependency traversal logic from CachePrivateModules but focuses on external packages. + """ + + def __init__(self, salt: str, root_dir: str): + super().__init__(salt=salt, root_dir=root_dir) + self._package_versions = {} # Cache for package versions + self._external_dependencies = set() + + def get_version_dict(self) -> dict[str, str]: + """ + Get a dictionary mapping package names to their versions. + + Returns: + dict[str, str]: Dictionary mapping package names to version strings + """ + versions = {} + for package in sorted(self._external_dependencies): + version = self._get_package_version(package) + if version: + versions[package] = version + return versions + + def get_version(self, params: VersionParameters) -> str: + if params.func is None: + raise ValueError("Function-based cache requires a function parameter") + + # Get all dependencies including nested function calls + _ = self._get_function_dependencies(params.func, set()) + + # Get package versions and create version string + versions = self.get_version_dict() + version_components = [f"{pkg}=={ver}" for pkg, ver in versions.items()] + + # Combine package versions with salt + combined_data = "|".join(version_components).encode("utf-8") + self.salt.encode("utf-8") + return hashlib.sha256(combined_data).hexdigest() + + def _is_user_defined(self, obj: Any) -> bool: + """Check if a callable or class is user-defined within the package.""" + module_name = getattr(obj, "__module__", None) + if not module_name: + return False + + # Retrieve the module specification to get its path + with temporarily_add_to_syspath(self.root_dir): + spec = importlib.util.find_spec(module_name) + if not spec or not spec.origin: + return False + + module_path = Path(spec.origin).resolve() + + site_packages_paths = {Path(p).resolve() for p in sys.path if "site-packages" in p} + is_in_site_packages = any(sp in module_path.parents for sp in site_packages_paths) + + # If it's in site-packages, add the module name to 
external dependencies + if is_in_site_packages: + root_package = module_name.split(".")[0] + self._external_dependencies.add(root_package) + + # Check if the module is within the root directory but not in site-packages + if self.root_dir in module_path.parents: + # Exclude standard library or site-packages by checking common paths but return True if within root_dir but not in site-packages + return not is_in_site_packages + + return False + + def _get_package_version(self, package_name: str) -> str: + """ + Get the version of an installed package. + + Args: + package_name: Name of the package + + Returns: + str: Version string of the package or "unknown" if version cannot be determined + """ + if package_name in self._package_versions: + return self._package_versions[package_name] + + version: Optional[str] = None + try: + # Try importlib.metadata first (most reliable) + version = importlib.metadata.version(package_name) + except Exception as e: + click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow") + try: + # Fall back to checking package attributes + package = importlib.import_module(package_name) + version = getattr(package, "__version__", None) + if not version: + version = getattr(package, "version", None) + click.secho(f"Found by {package_name} importing module.", fg="yellow") + except ImportError as e: + click.secho(f"Could not import {package_name}: {str(e)}", fg="yellow") + + if not version: + click.secho( + f"Could not determine version for package {package_name}. 
" "This may affect cache invalidation.", + fg="yellow", + ) + version = "unknown" + + self._package_versions[package_name] = version + return version diff --git a/plugins/flytekit-auto-cache/tests/requirements-test.txt b/plugins/flytekit-auto-cache/tests/requirements-test.txt new file mode 100644 index 0000000000..c597925b08 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/requirements-test.txt @@ -0,0 +1,20 @@ +numpy==1.24.3 +pandas==2.0.3 +requests==2.31.0 +matplotlib==3.7.2 +pillow==10.0.0 +scipy==1.11.2 +pytest==7.4.0 +urllib3==2.0.4 +cryptography==41.0.3 +setuptools==68.0.0 +flask==2.3.2 +django==4.2.4 +scikit-learn==1.3.0 +beautifulsoup4==4.12.2 +pyyaml==6.0 +fastapi==0.100.0 +sqlalchemy==2.0.36 +tqdm==4.65.0 +pytest-mock==3.11.0 +jinja2==3.1.2 diff --git a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py new file mode 100644 index 0000000000..8f9c23d65b --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py @@ -0,0 +1,45 @@ +import subprocess +from pathlib import Path + + + +def test_package_versions_in_isolated_env(): + """ + Test package version detection in an isolated environment with known package versions. + Creates a temporary venv in the test directory and cleans it up after. 
+ """ + test_dir = Path(__file__).parent + plugin_dir = test_dir.parent # Get the plugin root directory + flytekit_dir = plugin_dir.parent.parent # Get the flytekit root directory + reqs_file = test_dir / "requirements-test.txt" + + venv_path = test_dir / ".venv" + subprocess.run(["python", "-m", "venv", str(venv_path)], check=True) + + try: + pip = str(venv_path / "bin" / "pip") + # First install flytekit in editable mode + subprocess.run([pip, "install", "-e", str(flytekit_dir)], check=True) + # Then install the local plugin in editable mode + subprocess.run([pip, "install", "-e", str(plugin_dir)], check=True) + # Finally install the test requirements + subprocess.run([pip, "install", "-r", str(reqs_file)], check=True) + + python = str(venv_path / "bin" / "python") + verify_script = test_dir / "verify_versions.py" + + result = subprocess.run( + [python, str(verify_script)], + capture_output=True, + text=True, + check=True + ) + + assert result.returncode == 0, f"Version verification failed: {result.stderr}" + + finally: + import shutil + shutil.rmtree(venv_path) + +if __name__ == "__main__": + test_package_versions_in_isolated_env() diff --git a/plugins/flytekit-auto-cache/tests/verify_versions.py b/plugins/flytekit-auto-cache/tests/verify_versions.py new file mode 100644 index 0000000000..1dbc32352c --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/verify_versions.py @@ -0,0 +1,42 @@ +from flytekitplugins.auto_cache import CacheExternalDependencies + +# These versions should match requirements-test.txt +EXPECTED_VERSIONS = { + "numpy": "1.24.3", + "pandas": "2.0.3", + "requests": "2.31.0", + "matplotlib": "3.7.2", + "pillow": "10.0.0", + "scipy": "1.11.2", + "pytest": "7.4.0", + "urllib3": "2.0.4", + "cryptography": "41.0.3", + "setuptools": "68.0.0", + "flask": "2.3.2", + "django": "4.2.4", + "scikit-learn": "1.3.0", + "beautifulsoup4": "4.12.2", + "pyyaml": "6.0", + "fastapi": "0.100.0", + "sqlalchemy": "2.0.36", + "tqdm": "4.65.0", + "pytest-mock": 
"3.11.0", + "jinja2": "3.1.2", +} + + +def main(): + cache = CacheExternalDependencies(salt="salt", root_dir="./my_package") + # Hydrate _external_dependencies that would be discovered by running _get_function_dependencies. + cache._external_dependencies = set(EXPECTED_VERSIONS.keys()) + versions = cache.get_version_dict() + + # Verify that the versions extracted by the cache match the versions that are actually in the environment + for package, expected_version in EXPECTED_VERSIONS.items(): + actual_version = versions.get(package) + assert actual_version == expected_version, \ + f"Version mismatch for {package}. Expected {expected_version}, got {actual_version}" + print("Package versions dict matches expected versions!") + +if __name__ == "__main__": + main() From 2b0fbb81f5f85275d2a8b8f5a40ed5bd6d7435e5 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Thu, 28 Nov 2024 09:34:45 -0800 Subject: [PATCH 11/16] test for identified packages Signed-off-by: Daniel Sola --- .../auto_cache/cache_private_modules.py | 10 +++++++++- .../tests/my_package/module_a.py | 6 ++++++ .../tests/my_package/module_b.py | 4 ++++ .../tests/my_package/module_c.py | 8 ++++++-- .../tests/my_package/my_dir/__init__.py | 1 + .../tests/my_package/my_dir/module_in_dir.py | 7 ++++++- .../tests/test_external_dependencies.py | 20 +++++++++++++++---- .../tests/verify_identified_packages.py | 14 +++++++++++++ 8 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 plugins/flytekit-auto-cache/tests/verify_identified_packages.py diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index 64781f9740..e055f388cc 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -63,9 +63,14 @@ def _get_function_dependencies(self, func: Callable[..., Any], 
visited: Set[str] elif isinstance(node, ast.ImportFrom): module = importlib.import_module(node.module) for alias in node.names: + # Resolve attributes or submodules imported_obj = getattr(module, alias.name, None) if imported_obj: locals_dict[alias.asname or alias.name] = imported_obj + else: + # Fallback: attempt to import as submodule. e.g. `from PIL import Image` + submodule = importlib.import_module(f"{node.module}.{alias.name}") + locals_dict[alias.asname or alias.name] = submodule # Check each function call in the AST for node in ast.walk(parsed_ast): @@ -75,7 +80,10 @@ def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str] visited.add(func_name) try: # Attempt to resolve using locals first, then globals - func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) + # func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) + func_obj = self._resolve_callable(func_name, locals_dict) or self._resolve_callable( + func_name, func.__globals__ + ) if inspect.isclass(func_obj) and self._is_user_defined(func_obj): # Add class methods as dependencies for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_a.py b/plugins/flytekit-auto-cache/tests/my_package/module_a.py index 3c6d29ec13..03a687eb87 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_a.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_a.py @@ -1,8 +1,14 @@ import module_b +from scipy.linalg import norm +from cryptography.fernet import Fernet def helper_function(): print("Helper function") module_b.another_helper() + result = norm([1, 2, 3]) + print(result) def unused_function(): print("Unused function") + key = Fernet.generate_key() + print(key) diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_b.py b/plugins/flytekit-auto-cache/tests/my_package/module_b.py index 
641810f068..2634b5f32a 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_b.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_b.py @@ -1,5 +1,9 @@ from module_c import third_helper +from my_dir import bs def another_helper(): print("Another helper") third_helper() + html = "

Hello, world!

" + soup = bs.BeautifulSoup(html, "html.parser") + print(soup.p.text) diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_c.py b/plugins/flytekit-auto-cache/tests/my_package/module_c.py index 7c711d58f4..a02cc26c29 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_c.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_c.py @@ -5,10 +5,14 @@ def third_helper(): class DummyClass: def dummy_method(self) -> str: - my_dir.module_in_dir.other_helper_in_directory() + my_dir.other_helper_in_directory() + import numpy as np + print(np.mean(np.array([1, 2, 3, 4, 5]))) return "Hello from dummy method!" def other_dummy_method(self): from module_d import fourth_helper - print("Other dummy method") + from PIL import Image + img = Image.new("RGB", (100, 100), color="white") + print(img.info) fourth_helper() diff --git a/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py b/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py index 14fd12b8cc..8ab0842a0c 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py +++ b/plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py @@ -1 +1,2 @@ from .module_in_dir import other_helper_in_directory +import bs4 as bs diff --git a/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py index 7159a09643..a8279c677a 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py +++ b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py @@ -1,5 +1,10 @@ +from sklearn.preprocessing import StandardScaler + def helper_in_directory(): print("Helper in directory") def other_helper_in_directory(): - print("Other helper in directory") + data = [[1, 2], [3, 4], [5, 6]] + scaler = StandardScaler() + scaled_data = scaler.fit_transform(data) + print(scaled_data) diff --git a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py 
b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py index 8f9c23d65b..4f4f96318a 100644 --- a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py +++ b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py @@ -26,16 +26,28 @@ def test_package_versions_in_isolated_env(): subprocess.run([pip, "install", "-r", str(reqs_file)], check=True) python = str(venv_path / "bin" / "python") - verify_script = test_dir / "verify_versions.py" - result = subprocess.run( - [python, str(verify_script)], + # Run a test to verify that CacheExternalDependencies can identify the version of various popular packages + # verify_version_script = test_dir / "verify_versions.py" + # result_version = subprocess.run( + # [python, str(verify_version_script)], + # capture_output=True, + # text=True, + # check=True + # ) + + # assert result_version.returncode == 0, f"Version verification failed: {result_version.stderr}" + + # Run a test to verify that CacheExternalDependencies cen identify packages used in a complex repo + verify_packages_script = test_dir / "verify_identified_packages.py" + result_package = subprocess.run( + [python, str(verify_packages_script)], capture_output=True, text=True, check=True ) - assert result.returncode == 0, f"Version verification failed: {result.stderr}" + assert result_package.returncode == 0, f"Package verification failed: {result_package.stderr}" finally: import shutil diff --git a/plugins/flytekit-auto-cache/tests/verify_identified_packages.py b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py new file mode 100644 index 0000000000..c4e1639486 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py @@ -0,0 +1,14 @@ +from flytekitplugins.auto_cache import CacheExternalDependencies +from my_package.main import my_main_function as func + + +def main(): + cache = CacheExternalDependencies(salt="salt", root_dir="./my_package") + _ = cache._get_function_dependencies(func, set()) + packages 
= cache.get_version_dict().keys() + + expected_packages = {'PIL', 'bs4', 'numpy', 'pandas', 'scipy', 'sklearn'} + set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" + +if __name__ == "__main__": + main() From b4911abcf82ae2d8ea77c9099c33c9d397f55fef Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Thu, 28 Nov 2024 14:53:28 -0800 Subject: [PATCH 12/16] account for constants in basic recursive cache Signed-off-by: Daniel Sola --- .../auto_cache/cache_external_dependencies.py | 5 +- .../auto_cache/cache_private_modules.py | 141 +++++++++++++++++- .../tests/my_package/main.py | 4 +- .../tests/my_package/module_a.py | 2 + .../tests/my_package/module_c.py | 4 +- .../tests/my_package/module_d.py | 4 + .../tests/my_package/my_dir/module_in_dir.py | 1 + .../tests/my_package/utils.py | 2 + .../tests/verify_identified_packages.py | 5 + 9 files changed, 158 insertions(+), 10 deletions(-) create mode 100644 plugins/flytekit-auto-cache/tests/my_package/utils.py diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py index 3814143530..a61d4a9728 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py @@ -52,7 +52,10 @@ def get_version(self, params: VersionParameters) -> str: def _is_user_defined(self, obj: Any) -> bool: """Check if a callable or class is user-defined within the package.""" - module_name = getattr(obj, "__module__", None) + if isinstance(obj, type(sys)): # Check if the object is a module + module_name = obj.__name__ + else: + module_name = getattr(obj, "__module__", None) if not module_name: return False diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py 
b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index e055f388cc..bcb0ddad51 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -25,6 +25,7 @@ class CachePrivateModules: def __init__(self, salt: str, root_dir: str): self.salt = salt self.root_dir = Path(root_dir).resolve() + self.constants = {} def get_version(self, params: VersionParameters) -> str: if params.func is None: @@ -34,6 +35,8 @@ def get_version(self, params: VersionParameters) -> str: dependencies = self._get_function_dependencies(params.func, set()) for dep in dependencies: hash_components.append(self._get_version(dep)) + for key, value in self.constants.items(): + hash_components.append(f"{key}={value}") # Combine all component hashes into a single version hash combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() return combined_hash @@ -46,8 +49,25 @@ def _get_version(self, func: Callable[..., Any]) -> str: combined_data = ast_bytes + self.salt.encode("utf-8") return hashlib.sha256(combined_data).hexdigest() - def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]: - """Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package.""" + def _get_function_dependencies( + self, func: Callable[..., Any], visited: Set[str], class_attributes: dict = None + ) -> Set[Callable[..., Any]]: + """ + Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package. + + This method walks through the Abstract Syntax Tree (AST) of the given function to identify all imported modules, + functions, methods, and classes. It then checks each function call to identify the dependencies. 
The method + also extracts literal constants from the function's global namespace and imported modules. + + Parameters: + - func: The function for which to gather dependencies. + - visited: A set to keep track of visited functions to avoid infinite recursion. + - class_attributes: A dictionary of attributes if the func is a method from a class. + + Returns: + - A set of all dependencies found. + """ + dependencies = set() # Dedent the source code to handle class method indentation source = textwrap.dedent(inspect.getsource(func)) @@ -55,13 +75,22 @@ def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str] # Build a locals dictionary for function-level imports locals_dict = {} + constant_imports = {} + if class_attributes: + constant_imports.update(class_attributes) + for node in ast.walk(parsed_ast): if isinstance(node, ast.Import): for alias in node.names: module = importlib.import_module(alias.name) locals_dict[alias.asname or alias.name] = module + module_constants = self.get_module_literal_constants(module) + constant_imports.update( + {f"{alias.asname or alias.name}.{name}": value for name, value in module_constants.items()} + ) elif isinstance(node, ast.ImportFrom): - module = importlib.import_module(node.module) + module_name = node.module + module = importlib.import_module(module_name) for alias in node.names: # Resolve attributes or submodules imported_obj = getattr(module, alias.name, None) @@ -69,8 +98,22 @@ def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str] locals_dict[alias.asname or alias.name] = imported_obj else: # Fallback: attempt to import as submodule. e.g. 
`from PIL import Image` - submodule = importlib.import_module(f"{node.module}.{alias.name}") + submodule = importlib.import_module(f"{module_name}.{alias.name}") locals_dict[alias.asname or alias.name] = submodule + # If it's a module, find its constants + if inspect.ismodule(imported_obj): + module_constants = self.get_module_literal_constants(imported_obj) + constant_imports.update( + { + f"{alias.asname or alias.name}.{name}": value + for name, value in module_constants.items() + } + ) + # If the import itself is a constant, add it + elif self.is_literal_constant(imported_obj): + constant_imports.update({f"{module_name}.{alias.asname or alias.name}": imported_obj}) + + global_constants = {key: value for key, value in func.__globals__.items() if self.is_literal_constant(value)} # Check each function call in the AST for node in ast.walk(parsed_ast): @@ -80,17 +123,22 @@ def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str] visited.add(func_name) try: # Attempt to resolve using locals first, then globals - # func_obj = locals_dict.get(func_name) or self._resolve_callable(func_name, func.__globals__) func_obj = self._resolve_callable(func_name, locals_dict) or self._resolve_callable( func_name, func.__globals__ ) if inspect.isclass(func_obj) and self._is_user_defined(func_obj): + # Add class attributes as potential constants + current_class_attributes = { + f"class.{func_name}.{name}": value for name, value in func_obj.__dict__.items() + } # Add class methods as dependencies for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): if method not in visited: visited.add(method.__qualname__) dependencies.add(method) - dependencies.update(self._get_function_dependencies(method, visited)) + dependencies.update( + self._get_function_dependencies(method, visited, current_class_attributes) + ) elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( func_obj ): @@ -98,8 +146,84 @@ def 
_get_function_dependencies(self, func: Callable[..., Any], visited: Set[str] dependencies.update(self._get_function_dependencies(func_obj, visited)) except (NameError, AttributeError): pass + + referenced_constants = self.get_referenced_constants( + func=func, constant_imports=constant_imports, global_constants=global_constants + ) + self.constants.update(referenced_constants) + return dependencies + def is_literal_constant(self, value): + """ + Check if a value is a literal constant + + Supports basic immutable types and nested structures of those types + """ + # Basic immutable types + literal_types = (int, float, str, bool, type(None), complex, tuple, frozenset) + + # Recursively check for literals + def _is_literal(val): + # Direct type check + if isinstance(val, literal_types): + return True + + # Check nested structures + if isinstance(val, (tuple, list, frozenset)): + return all(_is_literal(item) for item in val) + + return False + + return _is_literal(value) + + def get_module_literal_constants(self, module): + """ + Find all literal constants in a module + + Uses module's __dict__ to find uppercase attributes that are literals + """ + constants = {} + for name, value in module.__dict__.items(): + # Check for uppercase name (convention for constants) + # and verify it's a literal + if self.is_literal_constant(value): + constants[name] = value + return constants + + def get_referenced_constants(self, func, constant_imports=None, global_constants=None): + """ + Find constants that are actually referenced in the function + + :param func: The function to analyze + :param constant_imports: Dictionary of potential constant imports + :param global_constants: Dictionary of global constants + :return: Dictionary of referenced constants + """ + referenced_constants = {} + source_code = inspect.getsource(func) + dedented_source = textwrap.dedent(source_code) + + # Check imported constants + if constant_imports: + for name, value in constant_imports.items(): + 
name_to_search, final_name = name, name + name_parts = name.split(".") + # If the constant is a class attribute, use the class name in the final constant name but search for "self." + if len(name_parts) == 3 and name_parts[0] == "class": + name_to_search = f"self.{name_parts[2]}" + final_name = f"{name_parts[1]}.{name_parts[2]}" + if name_to_search in dedented_source: + referenced_constants[final_name] = str(value) + + # Check global constants + if global_constants: + for name, value in global_constants.items(): + if name in dedented_source: + referenced_constants[name] = str(value) + + return referenced_constants + def _get_callable_name(self, node: ast.AST) -> Union[str, None]: """Retrieve the name of the callable from an AST node.""" if isinstance(node, ast.Name): @@ -137,7 +261,10 @@ def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., def _is_user_defined(self, obj: Any) -> bool: """Check if a callable or class is user-defined within the package.""" - module_name = getattr(obj, "__module__", None) + if isinstance(obj, type(sys)): # Check if the object is a module + module_name = obj.__name__ + else: + module_name = getattr(obj, "__module__", None) if not module_name: return False diff --git a/plugins/flytekit-auto-cache/tests/my_package/main.py b/plugins/flytekit-auto-cache/tests/my_package/main.py index fa113e44f4..1ea6b1bf8a 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/main.py +++ b/plugins/flytekit-auto-cache/tests/my_package/main.py @@ -9,6 +9,8 @@ from module_c import DummyClass import pandas as pd # External library +from utils import SOME_CONSTANT + def my_main_function(): print("Main function") helper_in_directory() @@ -16,5 +18,5 @@ def my_main_function(): df = pd.DataFrame({"a": [1, 2, 3]}) print(df) dc = DummyClass() - print(dc) dc.dummy_method() + sum([SOME_CONSTANT, 1]) diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_a.py b/plugins/flytekit-auto-cache/tests/my_package/module_a.py index 
03a687eb87..a9863fc55e 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_a.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_a.py @@ -1,12 +1,14 @@ import module_b from scipy.linalg import norm from cryptography.fernet import Fernet +from utils import SOME_CONSTANT def helper_function(): print("Helper function") module_b.another_helper() result = norm([1, 2, 3]) print(result) + sum([SOME_CONSTANT, 1]) def unused_function(): print("Unused function") diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_c.py b/plugins/flytekit-auto-cache/tests/my_package/module_c.py index a02cc26c29..6562bd2276 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_c.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_c.py @@ -4,11 +4,13 @@ def third_helper(): print("Third helper") class DummyClass: + some_attr = "some_custom_attr" + def dummy_method(self) -> str: my_dir.other_helper_in_directory() import numpy as np print(np.mean(np.array([1, 2, 3, 4, 5]))) - return "Hello from dummy method!" 
+ return f"{self.some_attr}" def other_dummy_method(self): from module_d import fourth_helper diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_d.py b/plugins/flytekit-auto-cache/tests/my_package/module_d.py index 92db897762..6533938e58 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_d.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_d.py @@ -1,2 +1,6 @@ def fourth_helper(): print("Fourth helper") + import yaml + print(yaml.__version__) + import my_dir.module_in_dir as mod + print(mod.SOME_OTHER_CONSTANT) diff --git a/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py index a8279c677a..f756acfa64 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py +++ b/plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py @@ -1,4 +1,5 @@ from sklearn.preprocessing import StandardScaler +SOME_OTHER_CONSTANT = 222 def helper_in_directory(): print("Helper in directory") diff --git a/plugins/flytekit-auto-cache/tests/my_package/utils.py b/plugins/flytekit-auto-cache/tests/my_package/utils.py new file mode 100644 index 0000000000..942d164625 --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/my_package/utils.py @@ -0,0 +1,2 @@ + +SOME_CONSTANT = 111 diff --git a/plugins/flytekit-auto-cache/tests/verify_identified_packages.py b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py index c4e1639486..7748d49a5c 100644 --- a/plugins/flytekit-auto-cache/tests/verify_identified_packages.py +++ b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py @@ -10,5 +10,10 @@ def main(): expected_packages = {'PIL', 'bs4', 'numpy', 'pandas', 'scipy', 'sklearn'} set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" + expected_constants = {'SOME_CONSTANT': '111', 'DummyClass.some_attr': 'some_custom_attr', 'yaml.__version__': '6.0', 'mod.SOME_OTHER_CONSTANT': 
'222'} + assert set(cache.constants.keys()) == set(expected_constants.keys()), f"Expected constants keys {set(expected_constants.keys())}, but got {set(cache.constants.keys())}" + for key in expected_constants: + assert cache.constants[key] == expected_constants[key], f"Expected value for {key} to be {expected_constants[key]}, but got {cache.constants[key]}" + if __name__ == "__main__": main() From 02f6b53c238bfdc71797f0afd87b6c9aa3950b06 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Mon, 2 Dec 2024 15:09:40 -0800 Subject: [PATCH 13/16] add comments Signed-off-by: Daniel Sola --- .../auto_cache/cache_external_dependencies.py | 6 +- .../auto_cache/cache_function_body.py | 2 +- .../flytekitplugins/auto_cache/cache_image.py | 31 ++- .../auto_cache/cache_private_modules.py | 185 ++++++++++++------ .../tests/my_package/module_a.py | 3 +- .../tests/my_package/module_d.py | 2 + .../tests/my_package/utils.py | 2 + .../tests/test_external_dependencies.py | 18 +- .../tests/test_recursive.py | 2 + .../tests/verify_identified_packages.py | 2 +- 10 files changed, 181 insertions(+), 72 deletions(-) diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py index a61d4a9728..72a8813c4a 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py @@ -51,7 +51,11 @@ def get_version(self, params: VersionParameters) -> str: return hashlib.sha256(combined_data).hexdigest() def _is_user_defined(self, obj: Any) -> bool: - """Check if a callable or class is user-defined within the package.""" + """ + Similar to the parent, this method checks if a callable or class is user-defined within the package. 
+ If it identifies a non-user-defined package, it adds the external dependency to a list of packages + for which we will check their versions and hash. + """ if isinstance(obj, type(sys)): # Check if the object is a module module_name = obj.__name__ else: diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py index 830aee94f0..42f485afea 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py @@ -20,7 +20,7 @@ class CacheFunctionBody: Given a function, generates a version hash based on its source code and the salt. """ - def __init__(self, salt: str = "salt") -> None: + def __init__(self, salt: str = "") -> None: """ Initialize the CacheFunctionBody instance with a salt value. diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py index cf57b645ce..e68d2cfc0a 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py @@ -5,10 +5,39 @@ class CacheImage: - def __init__(self, salt: str): + """ + A class that generates a version hash given a container image. + + Attributes: + salt (str): A string used to add uniqueness to the generated hash. Defaults to an empty string. + + Methods: + get_version(params: VersionParameters) -> str: + Given a VersionParameters object, generates a version hash based on the container_image and the salt. + """ + + def __init__(self, salt: str = ""): + """ + Initialize the CacheImage instance with a salt value. + + Args: + salt (str): A string to be used as the salt in the hashing process. Defaults to an empty string. 
+ """ self.salt = salt def get_version(self, params: VersionParameters) -> str: + """ + Generates a version hash for the container image specified in the VersionParameters object. + + Args: + params (VersionParameters): An object containing the container_image parameter. + + Returns: + str: The SHA-256 hash of the container_image combined with the salt. + + Raises: + ValueError: If the container_image parameter is None. + """ if params.container_image is None: raise ValueError("Image-based cache requires a container_image parameter") diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index bcb0ddad51..aabd0022fc 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any, Callable, Set, Union +import click +from flytekitplugins.auto_cache.cache_function_body import CacheFunctionBody + from flytekit.core.auto_cache import VersionParameters @@ -21,117 +24,167 @@ def temporarily_add_to_syspath(path): sys.path.pop(0) -class CachePrivateModules: +class CachePrivateModules(CacheFunctionBody): + """ + A class that extends CacheFunctionBody to cache private modules and their dependencies. + + It extends the functionality to recursively follow all callables, making a list of all functions, + classes, and methods that are eventually used by the initial function of interest. Only functions + internal to this package are included, while externally imported packages are ignored. It handles + both import and from-import statements, as well as aliases, for both top-level imports global to + the module and local imports to the function or method. 
Additionally, it identifies constants at + the same import levels described above, accounting for both constants defined in the internal + package and external packages. The contents of all these functions are hashed using the same + logic as CacheFunctionBody, and the constants and their values are also fed into the hash. + + Attributes: + salt (str): A string used to add uniqueness to the generated hash. + root_dir (Path): The root directory of the project, used to resolve module paths. + constants (dict): A dictionary to store constants that are part of the versioning process. + """ + def __init__(self, salt: str, root_dir: str): + """ + Initialize the CachePrivateModules instance with a salt value and a root directory. + + Args: + salt (str): A string to be used as the salt in the hashing process. + root_dir (str): The root directory of the project, used to resolve module paths. + """ self.salt = salt self.root_dir = Path(root_dir).resolve() self.constants = {} def get_version(self, params: VersionParameters) -> str: + """ + Generates a version hash for the provided function and its dependencies. + + This method recursively identifies all callables (functions, methods, classes) used by the provided function, + hashes their contents, and combines these hashes with the hashes of all constants and their values identified. + The resulting combined hash is then returned as the version string. The user provided salt is by `_get_version` + of the parent CacheFunctionBody. + + Args: + params (VersionParameters): An object containing the function parameter. + + Returns: + str: The SHA-256 hash of the combined hashes of the function, its dependencies, and constants. + + Raises: + ValueError: If the function parameter is None. 
+ """ if params.func is None: raise ValueError("Function-based cache requires a function parameter") - hash_components = [self._get_version(params.func)] + # Initialize a list to hold all hash components + hash_components = [self._get_version(params.func)] # Start with the hash of the provided function + # Identify all dependencies of the provided function dependencies = self._get_function_dependencies(params.func, set()) + # Hash each dependency and add to the list of hash components for dep in dependencies: hash_components.append(self._get_version(dep)) + # Add hashes of constants and their values to the list of hash components for key, value in self.constants.items(): hash_components.append(f"{key}={value}") # Combine all component hashes into a single version hash combined_hash = hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest() return combined_hash - def _get_version(self, func: Callable[..., Any]) -> str: - source = inspect.getsource(func) - dedented_source = textwrap.dedent(source) - parsed_ast = ast.parse(dedented_source) - ast_bytes = ast.dump(parsed_ast).encode("utf-8") - combined_data = ast_bytes + self.salt.encode("utf-8") - return hashlib.sha256(combined_data).hexdigest() + def _get_alias_name(self, alias: ast.alias) -> str: + """ + Extracts the alias name from an AST alias node. + + This method takes an AST alias node and returns its alias name. If the alias has an 'asname', it returns the 'asname'; + otherwise, it returns the 'name' of the alias. + + Args: + alias (ast.alias): The AST alias node from which to extract the alias name. + + Returns: + str: The alias name or the name of the alias if 'asname' is not provided. + """ + return alias.asname or alias.name def _get_function_dependencies( self, func: Callable[..., Any], visited: Set[str], class_attributes: dict = None ) -> Set[Callable[..., Any]]: """ - Recursively gather all functions, methods, and classes used within `func` and defined in the user’s package. 
+ Recursively identifies all functions, methods, and classes used within `func` and defined in the user’s package. - This method walks through the Abstract Syntax Tree (AST) of the given function to identify all imported modules, - functions, methods, and classes. It then checks each function call to identify the dependencies. The method - also extracts literal constants from the function's global namespace and imported modules. + This method traverses the Abstract Syntax Tree (AST) of the given function to identify all imported modules, + functions, methods, and classes. It then inspects each function call to identify the dependencies. Additionally, + the method extracts literal constants from the function's global namespace and imported modules. - Parameters: - - func: The function for which to gather dependencies. - - visited: A set to keep track of visited functions to avoid infinite recursion. - - class_attributes: A dictionary of attributes if the func is a method from a class. + Args: + func (Callable[..., Any]): The function for which to gather dependencies. + visited (Set[str]): A set to keep track of visited functions to avoid infinite recursion. + class_attributes (dict, optional): A dictionary of attributes if the func is a method from a class. Returns: - - A set of all dependencies found. + Set[Callable[..., Any]]: A set of all dependencies found. 
""" dependencies = set() - # Dedent the source code to handle class method indentation source = textwrap.dedent(inspect.getsource(func)) parsed_ast = ast.parse(source) - # Build a locals dictionary for function-level imports + # Initialize a dictionary to mimic the function's global namespace for locally defined imports locals_dict = {} + # Initialize a dictionary to hold constant imports and class attributes constant_imports = {} + # If class attributes are provided, include them in the constant imports if class_attributes: constant_imports.update(class_attributes) + # Check each function call in the AST for node in ast.walk(parsed_ast): if isinstance(node, ast.Import): + # For each alias in the import statement, we import the module and add it to the locals_dict. + # This is because the module itself is being imported, not a specific attribute or function. for alias in node.names: module = importlib.import_module(alias.name) - locals_dict[alias.asname or alias.name] = module + locals_dict[self._get_alias_name(alias)] = module + # We then get all the literal constants defined in the module's __init__.py file. + # These constants are later checked for usage within the function. 
module_constants = self.get_module_literal_constants(module) constant_imports.update( - {f"{alias.asname or alias.name}.{name}": value for name, value in module_constants.items()} + {f"{self._get_alias_name(alias)}.{name}": value for name, value in module_constants.items()} ) elif isinstance(node, ast.ImportFrom): module_name = node.module module = importlib.import_module(module_name) for alias in node.names: - # Resolve attributes or submodules + # Attempt to resolve the imported object directly from the module imported_obj = getattr(module, alias.name, None) if imported_obj: - locals_dict[alias.asname or alias.name] = imported_obj + # If the object is found directly in the module, add it to the locals_dict + locals_dict[self._get_alias_name(alias)] = imported_obj + # Check if the imported object is a literal constant and add it to constant_imports if so + if self.is_literal_constant(imported_obj): + constant_imports.update({f"{self._get_alias_name(alias)}": imported_obj}) else: - # Fallback: attempt to import as submodule. e.g. 
`from PIL import Image` + # If the object is not found directly in the module, attempt to import it as a submodule + # This is necessary for cases like `from PIL import Image`, where Image is not imported in PIL's __init__.py + # PIL and similar packages use different mechanisms to expose their objects, requiring this fallback approach submodule = importlib.import_module(f"{module_name}.{alias.name}") - locals_dict[alias.asname or alias.name] = submodule - # If it's a module, find its constants - if inspect.ismodule(imported_obj): - module_constants = self.get_module_literal_constants(imported_obj) - constant_imports.update( - { - f"{alias.asname or alias.name}.{name}": value - for name, value in module_constants.items() - } - ) - # If the import itself is a constant, add it - elif self.is_literal_constant(imported_obj): - constant_imports.update({f"{module_name}.{alias.asname or alias.name}": imported_obj}) - - global_constants = {key: value for key, value in func.__globals__.items() if self.is_literal_constant(value)} + imported_obj = getattr(submodule, alias.name, None) + locals_dict[self._get_alias_name(alias)] = imported_obj - # Check each function call in the AST - for node in ast.walk(parsed_ast): - if isinstance(node, ast.Call): + elif isinstance(node, ast.Call): + # Add callable to the set of dependencies if it's user defined and continue the recursive search within those callables. func_name = self._get_callable_name(node.func) if func_name and func_name not in visited: visited.add(func_name) try: - # Attempt to resolve using locals first, then globals + # Attempt to resolve the callable object using locals first, then globals func_obj = self._resolve_callable(func_name, locals_dict) or self._resolve_callable( func_name, func.__globals__ ) + # If the callable is a class and user-defined, we add and search all method. We also include attributes as potential constants. 
if inspect.isclass(func_obj) and self._is_user_defined(func_obj): - # Add class attributes as potential constants current_class_attributes = { f"class.{func_name}.{name}": value for name, value in func_obj.__dict__.items() } - # Add class methods as dependencies for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): if method not in visited: visited.add(method.__qualname__) @@ -139,14 +192,27 @@ def _get_function_dependencies( dependencies.update( self._get_function_dependencies(method, visited, current_class_attributes) ) + # If the callable is a function or method and user-defined, add it as a dependency and search its dependencies elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( func_obj ): + # Add the function or method as a dependency dependencies.add(func_obj) + # Recursively search the function or method's dependencies dependencies.update(self._get_function_dependencies(func_obj, visited)) - except (NameError, AttributeError): - pass - + except (NameError, AttributeError) as e: + click.secho(f"Could not process the callable {func_name} due to error: {str(e)}", fg="yellow") + + # Extract potential constants from the global import context + global_constants = {} + for key, value in func.__globals__.items(): + if hasattr(value, "__dict__"): + module_constants = self.get_module_literal_constants(value) + global_constants.update({f"{key}.{name}": value for name, value in module_constants.items()}) + elif self.is_literal_constant(value): + global_constants[key] = value + + # Check for the usage of all potential constants and update the set of constants to be hashed referenced_constants = self.get_referenced_constants( func=func, constant_imports=constant_imports, global_constants=global_constants ) @@ -193,12 +259,15 @@ def get_module_literal_constants(self, module): def get_referenced_constants(self, func, constant_imports=None, global_constants=None): """ - Find constants that are actually
referenced in the function + Identifies constants that are actually used within the given function. + + Args: + func: The function to be analyzed for constant references. + constant_imports: A dictionary containing potential constant imports. + global_constants: A dictionary of global constants. - :param func: The function to analyze - :param constant_imports: Dictionary of potential constant imports - :param global_constants: Dictionary of global constants - :return: Dictionary of referenced constants + Returns: + A dictionary containing the constants that are actually referenced within the function. """ referenced_constants = {} source_code = inspect.getsource(func) @@ -232,12 +301,12 @@ def _get_callable_name(self, node: ast.AST) -> Union[str, None]: return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else node.attr return None - def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., Any]: - """Resolve a callable from its name within the given globals dictionary, handling modules as entry points.""" + def _resolve_callable(self, func_name: str, locals_globals_dict: dict) -> Callable[..., Any]: + """Resolve a callable from its name within the given locals/globals dictionary, handling modules as entry points.""" parts = func_name.split(".") # First, try resolving directly from globals_dict for a straightforward reference - obj = globals_dict.get(parts[0], None) + obj = locals_globals_dict.get(parts[0], None) for part in parts[1:]: if obj is None: break @@ -245,7 +314,7 @@ def _resolve_callable(self, func_name: str, globals_dict: dict) -> Callable[..., # If not found, iterate through modules in globals_dict and attempt resolution from them if not callable(obj): - for module in globals_dict.values(): + for module in locals_globals_dict.values(): if isinstance(module, type(sys)): # Check if the global value is a module obj = module for part in parts: diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_a.py 
b/plugins/flytekit-auto-cache/tests/my_package/module_a.py index a9863fc55e..2ab9985ee4 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_a.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_a.py @@ -2,13 +2,14 @@ from scipy.linalg import norm from cryptography.fernet import Fernet from utils import SOME_CONSTANT +import utils def helper_function(): print("Helper function") module_b.another_helper() result = norm([1, 2, 3]) print(result) - sum([SOME_CONSTANT, 1]) + sum([SOME_CONSTANT, utils.THIRD_CONSTANT]) def unused_function(): print("Unused function") diff --git a/plugins/flytekit-auto-cache/tests/my_package/module_d.py b/plugins/flytekit-auto-cache/tests/my_package/module_d.py index 6533938e58..315167c555 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/module_d.py +++ b/plugins/flytekit-auto-cache/tests/my_package/module_d.py @@ -4,3 +4,5 @@ def fourth_helper(): print(yaml.__version__) import my_dir.module_in_dir as mod print(mod.SOME_OTHER_CONSTANT) + from utils import OTHER_CONSTANT as MY_OTHER_CONSTANT + print(MY_OTHER_CONSTANT) diff --git a/plugins/flytekit-auto-cache/tests/my_package/utils.py b/plugins/flytekit-auto-cache/tests/my_package/utils.py index 942d164625..2064c09e4a 100644 --- a/plugins/flytekit-auto-cache/tests/my_package/utils.py +++ b/plugins/flytekit-auto-cache/tests/my_package/utils.py @@ -1,2 +1,4 @@ SOME_CONSTANT = 111 +OTHER_CONSTANT = 999 +THIRD_CONSTANT = 1010 diff --git a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py index 4f4f96318a..7734833fb9 100644 --- a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py +++ b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py @@ -28,15 +28,15 @@ def test_package_versions_in_isolated_env(): python = str(venv_path / "bin" / "python") # Run a test to verify that CacheExternalDependencies can identify the version of various popular packages - # 
verify_version_script = test_dir / "verify_versions.py" - # result_version = subprocess.run( - # [python, str(verify_version_script)], - # capture_output=True, - # text=True, - # check=True - # ) - - # assert result_version.returncode == 0, f"Version verification failed: {result_version.stderr}" + verify_version_script = test_dir / "verify_versions.py" + result_version = subprocess.run( + [python, str(verify_version_script)], + capture_output=True, + text=True, + check=True + ) + + assert result_version.returncode == 0, f"Version verification failed: {result_version.stderr}" # Run a test to verify that CacheExternalDependencies can identify packages used in a complex repo verify_packages_script = test_dir / "verify_identified_packages.py" diff --git a/plugins/flytekit-auto-cache/tests/test_recursive.py b/plugins/flytekit-auto-cache/tests/test_recursive.py index 1dbbf56a83..df53ee2034 100644 --- a/plugins/flytekit-auto-cache/tests/test_recursive.py +++ b/plugins/flytekit-auto-cache/tests/test_recursive.py @@ -28,3 +28,5 @@ def test_dependencies(): f"Expected: {expected_dependencies}\n" f"Actual: {actual_dependencies_str}" ) + +test_dependencies() diff --git a/plugins/flytekit-auto-cache/tests/verify_identified_packages.py b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py index 7748d49a5c..1963d7f82e 100644 --- a/plugins/flytekit-auto-cache/tests/verify_identified_packages.py +++ b/plugins/flytekit-auto-cache/tests/verify_identified_packages.py @@ -10,7 +10,7 @@ def main(): expected_packages = {'PIL', 'bs4', 'numpy', 'pandas', 'scipy', 'sklearn'} set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" - expected_constants = {'SOME_CONSTANT': '111', 'DummyClass.some_attr': 'some_custom_attr', 'yaml.__version__': '6.0', 'mod.SOME_OTHER_CONSTANT': '222'} + expected_constants = {'SOME_CONSTANT': '111', 'utils.THIRD_CONSTANT': '1010', 'DummyClass.some_attr': 'some_custom_attr', 'yaml.__version__': '6.0',
'mod.SOME_OTHER_CONSTANT': '222', 'MY_OTHER_CONSTANT': '999'} assert set(cache.constants.keys()) == set(expected_constants.keys()), f"Expected constants keys {set(expected_constants.keys())}, but got {set(cache.constants.keys())}" for key in expected_constants: assert cache.constants[key] == expected_constants[key], f"Expected value for {key} to be {expected_constants[key]}, but got {cache.constants[key]}" From 47d01ef464d34ddd50af9d79aa0df4434b27e893 Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Mon, 2 Dec 2024 15:37:47 -0800 Subject: [PATCH 14/16] readme Signed-off-by: Daniel Sola --- plugins/flytekit-auto-cache/README.md | 49 ++++++++++++++++++- .../auto_cache/cache_external_dependencies.py | 2 +- .../auto_cache/cache_private_modules.py | 2 +- .../tests/test_external_dependencies.py | 3 -- .../tests/test_recursive.py | 2 - 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/plugins/flytekit-auto-cache/README.md b/plugins/flytekit-auto-cache/README.md index 76d0e1f853..da4f110d12 100644 --- a/plugins/flytekit-auto-cache/README.md +++ b/plugins/flytekit-auto-cache/README.md @@ -1,9 +1,56 @@ -# Flytekit Auto Cache Plugin +# Flyte Auto Cache Plugin +This plugin provides a caching mechanism for Flyte tasks that generates a version hash based on the source code of the task and its dependencies. It allows users to manage the cache behavior. +## Usage To install the plugin, run the following command: ```bash pip install flytekitplugins-auto-cache ``` + +To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up: + +```python +from flytekit import task +from flytekit.core.auto_cache import CachePolicy +from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules + +cache_policy = CachePolicy( + CacheFunctionBody(), + CachePrivateModules(root_dir="../my_package"), + ..., + salt="my_salt" +) + +@task(cache=cache_policy) +def task_fn(): + ... 
+``` + +### Salt Parameter + +The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task. + +## Cache Implementations + +Users can add any number of cache policies that implement the `AutoCache` protocol defined in `@auto_cache.py`. Below are the implementations available so far: + +### 1. CacheFunctionBody + +This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning. + +### 2. CacheImage + +This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task. + +### 3. CachePrivateModules + +This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash. + +It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored. + +### 4. CacheExternalDependencies + +This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning. 
diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py index 72a8813c4a..4da6a02e04 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py @@ -16,7 +16,7 @@ class CacheExternalDependencies(CachePrivateModules): Inherits the dependency traversal logic from CachePrivateModules but focuses on external packages. """ - def __init__(self, salt: str, root_dir: str): + def __init__(self, root_dir: str, salt: str = ""): super().__init__(salt=salt, root_dir=root_dir) self._package_versions = {} # Cache for package versions self._external_dependencies = set() diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index aabd0022fc..69b697076d 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -43,7 +43,7 @@ class CachePrivateModules(CacheFunctionBody): constants (dict): A dictionary to store constants that are part of the versioning process. """ - def __init__(self, salt: str, root_dir: str): + def __init__(self, root_dir: str, salt: str = ""): """ Initialize the CachePrivateModules instance with a salt value and a root directory. 
diff --git a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py index 7734833fb9..a4ef5046b4 100644 --- a/plugins/flytekit-auto-cache/tests/test_external_dependencies.py +++ b/plugins/flytekit-auto-cache/tests/test_external_dependencies.py @@ -52,6 +52,3 @@ def test_package_versions_in_isolated_env(): finally: import shutil shutil.rmtree(venv_path) - -if __name__ == "__main__": - test_package_versions_in_isolated_env() diff --git a/plugins/flytekit-auto-cache/tests/test_recursive.py b/plugins/flytekit-auto-cache/tests/test_recursive.py index df53ee2034..1dbbf56a83 100644 --- a/plugins/flytekit-auto-cache/tests/test_recursive.py +++ b/plugins/flytekit-auto-cache/tests/test_recursive.py @@ -28,5 +28,3 @@ def test_dependencies(): f"Expected: {expected_dependencies}\n" f"Actual: {actual_dependencies_str}" ) - -test_dependencies() From f1ebdc93957878908a2a8b05d911e7c0bdac6dcb Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Mon, 2 Dec 2024 15:51:01 -0800 Subject: [PATCH 15/16] add other cache params in CachePolicy --- flytekit/core/auto_cache.py | 51 +++++++++++++++++++++++++------------ flytekit/core/task.py | 8 ++++-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/flytekit/core/auto_cache.py b/flytekit/core/auto_cache.py index 2915abb729..4ddb4d2963 100644 --- a/flytekit/core/auto_cache.py +++ b/flytekit/core/auto_cache.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable +from typing import Any, Callable, List, Optional, Protocol, Tuple, Union, runtime_checkable from flytekit.image_spec.image_spec import ImageSpec @@ -45,17 +45,30 @@ class CachePolicy: A class that combines multiple caching mechanisms to generate a version hash. 
Args: - *cache_objects: Variable number of AutoCache instances - salt: Optional salt string to add uniqueness to the hash + auto_cache_policies: A list of AutoCache instances (optional). + salt: Optional salt string to add uniqueness to the hash. + cache_serialize: Boolean to indicate if serialization should be used. + cache_version: A version string for the cache. + cache_ignore_input_vars: Tuple of input variable names to ignore. """ - def __init__(self, *cache_objects: AutoCache, salt: str = "") -> None: - self.cache_objects = cache_objects + def __init__( + self, + auto_cache_policies: List["AutoCache"] = None, + salt: str = "", + cache_serialize: bool = False, + cache_version: str = "", + cache_ignore_input_vars: Tuple[str, ...] = (), + ) -> None: + self.auto_cache_policies = auto_cache_policies or [] # Use an empty list if None is provided self.salt = salt + self.cache_serialize = cache_serialize + self.cache_version = cache_version + self.cache_ignore_input_vars = cache_ignore_input_vars - def get_version(self, params: VersionParameters) -> str: + def get_version(self, params: "VersionParameters") -> str: """ - Generate a version hash using all cache objects. + Generate a version hash using all cache objects. If the user passes a version, it takes precedence over auto_cache_policies. Args: params (VersionParameters): Parameters to use for hash generation. @@ -63,14 +76,20 @@ def get_version(self, params: VersionParameters) -> str: Returns: str: The combined hash from all cache objects. 
""" - task_hash = "" - for cache_instance in self.cache_objects: - # Apply the policy's salt to each cache instance - cache_instance.salt = self.salt - task_hash += cache_instance.get_version(params) + if self.cache_version: + return self.cache_version + + if self.auto_cache_policies: + task_hash = "" + for cache_instance in self.auto_cache_policies: + # Apply the policy's salt to each cache instance + cache_instance.salt = self.salt + task_hash += cache_instance.get_version(params) + + # Generate SHA-256 hash + import hashlib - # Generate SHA-256 hash - import hashlib + hash_obj = hashlib.sha256(task_hash.encode()) + return hash_obj.hexdigest() - hash_obj = hashlib.sha256(task_hash.encode()) - return hash_obj.hexdigest() + return None diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 64d4095dda..e0c60f09be 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -348,15 +348,19 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: params = VersionParameters(func=fn, container_image=container_image) cache_version_val = cache.get_version(params=params) cache_val = True + cache_serialize_val = cache_serialize or cache.cache_serialize + cache_serialize_val = cache_ignore_input_vars or cache.cache_ignore_input_vars else: cache_val = cache cache_version_val = cache_version + cache_serialize_val = cache_serialize + cache_ignore_input_vars_val = cache_ignore_input_vars _metadata = TaskMetadata( cache=cache_val, - cache_serialize=cache_serialize, + cache_serialize=cache_serialize_val, cache_version=cache_version_val, - cache_ignore_input_vars=cache_ignore_input_vars, + cache_ignore_input_vars=cache_ignore_input_vars_val, retries=retries, interruptible=interruptible, deprecated=deprecated, From ff3555f6696078e9439adcd887de004d8abec39f Mon Sep 17 00:00:00 2001 From: Daniel Sola Date: Mon, 2 Dec 2024 15:53:52 -0800 Subject: [PATCH 16/16] readme Signed-off-by: Daniel Sola --- flytekit/core/task.py | 4 ++-- 
plugins/flytekit-auto-cache/README.md | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index e0c60f09be..69d8f23d4f 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -345,9 +345,9 @@ def launch_dynamically(): def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: if isinstance(cache, CachePolicy): - params = VersionParameters(func=fn, container_image=container_image) - cache_version_val = cache.get_version(params=params) cache_val = True + params = VersionParameters(func=fn, container_image=container_image) + cache_version_val = cache_version or cache.get_version(params=params) cache_serialize_val = cache_serialize or cache.cache_serialize cache_ignore_input_vars_val = cache_ignore_input_vars or cache.cache_ignore_input_vars else: diff --git a/plugins/flytekit-auto-cache/README.md b/plugins/flytekit-auto-cache/README.md index da4f110d12..acb54c9ff3 100644 --- a/plugins/flytekit-auto-cache/README.md +++ b/plugins/flytekit-auto-cache/README.md @@ -18,9 +18,11 @@ from flytekit.core.auto_cache import CachePolicy from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules cache_policy = CachePolicy( - CacheFunctionBody(), - CachePrivateModules(root_dir="../my_package"), - ..., + auto_cache_policies = [ + CacheFunctionBody(), + CachePrivateModules(root_dir="../my_package"), + ..., + ], salt="my_salt" )