Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduce cache version iff cache_version is not set #2734

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ def __post_init__(self):
self.timeout = datetime.timedelta(seconds=self.timeout)
elif not isinstance(self.timeout, datetime.timedelta):
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if self.cache and not self.cache_version:
raise ValueError("Caching is enabled ``cache=True`` but ``cache_version`` is not set.")
if self.cache_serialize and not self.cache:
raise ValueError("Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled.")
if self.cache_ignore_input_vars and not self.cache:
Expand Down
20 changes: 19 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import datetime
import hashlib
import inspect
import os
import sys
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

Expand Down Expand Up @@ -339,7 +342,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_version=_deduce_cache_version(fn) if cache_version == "" else cache_version,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
interruptible=interruptible,
Expand Down Expand Up @@ -378,6 +381,21 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
return wrapper


def _deduce_cache_version(fn: Callable[P, Any]) -> str:
Copy link
Member

@thomasjpfan thomasjpfan Sep 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the task uses an imported function, any updates to the imported function would not be captured by inspect.getsource:

from utils import helper

@task(cache=True)
def my_task(x: int):
    x = my_helper(x)

I think this is okay as long as we documented the behavior and tell users what to do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. I think it's okay too. Could we also update the docstring for cache_version here?

"""
This function is used to deduce the cache version for a task. The cache version is a hash of the function body.
"""
source = inspect.getsource(fn)
# We can be a bit more performant by setting `usedforsecurity=False`
# TODO: remove after dropping support for python 3.8: https://github.com/flyteorg/flyte/issues/5633
if sys.version_info >= (3, 9):
m = hashlib.sha256(usedforsecurity=False)
else:
m = hashlib.sha256()
m.update(source.encode("utf-8"))
return m.hexdigest()


class ReferenceTask(ReferenceEntity, PythonTask): # type: ignore
"""
This is a reference task, the body of the function passed in through the constructor will never be used, only the
Expand Down
74 changes: 65 additions & 9 deletions tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.task import _deduce_cache_version
from flytekit.core.tracker import isnested, istestfunction
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task
Expand Down Expand Up @@ -157,22 +158,77 @@ def bar(i: str):
bar_metadata = bar.metadata
assert bar_metadata.cache is False
assert bar_metadata.cache_serialize is False
assert bar_metadata.cache_version == ""

# test missing cache_version
with pytest.raises(ValueError):

@task(cache=True)
def foo_missing_cache_version(i: str):
print(f"{i}")
assert bar_metadata.cache_version == "b870594331edc52bd4691399d9018c2d7c523bf975f115c349b8ff30af6122de"

# test missing cache
with pytest.raises(ValueError):

@task(cache_serialize=True)
def foo_missing_cache(i: str):
print(f"{i}")

def test_deduce_cache_version_functions():
def foo(a: int, b: int) -> int:
return a + b

assert _deduce_cache_version(foo) == "3da83f75c1dae9691fc4618f72864b2242782f5eb18e404c1e85485804c94963"

def t0(a: int, b: int) -> int:
"""
Sample docstring
"""
return a + b

assert _deduce_cache_version(t0) == "77f42ae196b2948a173363e6c8b3c598bd1892947cc3a5e1d1bc6a8ba50e98cf"

def t1(a: int, b: int) -> int:
"""
Sample docstring plus a dot.
"""
return a + b

assert _deduce_cache_version(t1) == "0795ffaa7c25661592b8aeea20c8464e794f6124591e7222572602b89096b0f2"


def test_deduced_cache_version():
@task(cache=True)
def t0(a: int, b: int) -> int:
"""
Sample docstring
"""
return a + b

t0_metadata = t0.metadata
assert t0_metadata.cache is True
assert t0_metadata.cache_version == "97d4df6ec0e47c539d0ea49b9312a28c3cc5389e70121ae6efc7fb908eccf928"

@task(cache=True)
def t1(a: int, b: int) -> int:
"""
Sample docstring plus a dot.
"""
return a + b

t1_metadata = t1.metadata
assert t1_metadata.cache is True
assert t1_metadata.cache_version == "ff507165c2a93b9542521ef2026c72ed222440393afaef376ce28fe78e1011c3"


def test_deduced_cache_version_same_function_but_different_names():
@task(cache=True)
def t1(a: int, b: int) -> int:
return a + b

t1_metadata = t1.metadata
assert t1_metadata.cache is True
assert t1_metadata.cache_version == "1d811bdb0e792fae5fee8106c71825103aaf0cae404d424c91b68d2864d0ac58"

@task(cache=True)
def t2(a: int, b: int) -> int:
return a + b

t2_metadata = t2.metadata
assert t2_metadata.cache is True
assert t2_metadata.cache_version == "a9a0ce739dd77001d9de8932848ffce95695fdb59fa0c39e9b3849be20610201"

def test_pod_template():
@task(
Expand Down
Loading