Skip to content

Commit

Permalink
Types/generic alias - assert fix only (#2743)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Sep 14, 2024
1 parent c06ef30 commit e3dc8f9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
46 changes: 46 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,37 @@ def type_assertions_enabled(self) -> bool:
"""
return self._type_assertions_enabled

def isinstance_generic(self, obj, generic_alias):
origin = get_origin(generic_alias) # list from list[int])
args = get_args(generic_alias) # (int,) from list[int]

if not isinstance(obj, origin):
raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}")

# Optionally check the type of elements if it's a collection like list or dict
if origin in {list, tuple, set}:
for item in obj:
self.assert_type(args[0], item)
return
raise TypeTransformerFailedError(f"Not all items in '{obj}' are of type {args[0]}")

if origin is dict:
key_type, value_type = args
for k, v in obj.items():
self.assert_type(key_type, k)
self.assert_type(value_type, v)
return
raise TypeTransformerFailedError(f"Not all values in '{obj}' are of type {value_type}")

return

def assert_type(self, t: Type[T], v: T):
if sys.version_info >= (3, 10):
import types

if isinstance(t, types.GenericAlias):
return self.isinstance_generic(v, t)

if not hasattr(t, "__origin__") and not isinstance(v, t):
raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

Expand Down Expand Up @@ -1509,6 +1539,22 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
"""
return get_args(t)[0]

def assert_type(self, t: Type[T], v: T):
python_type = get_underlying_type(t)
if _is_union_type(python_type):
for sub_type in get_args(python_type):
if sub_type == typing.Any:
# this is an edge case
return
try:
super().assert_type(sub_type, v)
return
except TypeTransformerFailedError:
continue
raise TypeTransformerFailedError(f"Value {v} is not of type {t}")
else:
super().assert_type(t, v)

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
t = get_underlying_type(t)

Expand Down
37 changes: 37 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3291,3 +3291,40 @@ def run():
assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()]*N
expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N
assert all([mock_call[0] == "mock_register" for mock_call in mock_wrapper.mock_calls[:expected_number_of_register_calls]])


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9")
def test_option_list_with_pipe():
pt = list[int] | None
lt = TypeEngine.to_literal_type(pt)

ctx = FlyteContextManager.current_context()
lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt)
assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3

TypeEngine.to_literal(ctx, None, pt, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, [1, 2, "3"], pt, lt)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9")
def test_option_list_with_pipe_2():
pt = list[list[dict[str, str]] | None] | None
lt = TypeEngine.to_literal_type(pt)

ctx = FlyteContextManager.current_context()
lit = TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": "two"}]], pt, lt)
uv = lit.scalar.union.value
assert uv is not None
assert len(uv.collection.literals) == 3
first = uv.collection.literals[0]
assert first.scalar.union.value.collection.literals[0].map.literals["a"].scalar.primitive.string_value == "one"

assert len(lt.union_type.variants) == 2
v1 = lt.union_type.variants[0]
assert len(v1.collection_type.union_type.variants) == 2
assert v1.collection_type.union_type.variants[0].collection_type.map_value_type.simple == SimpleType.STRING

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": 3}]], pt, lt)

0 comments on commit e3dc8f9

Please sign in to comment.