From 0072598ef707c81efd12b18d770f4e8f2e655804 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 10 Sep 2024 21:07:16 -0700 Subject: [PATCH 1/5] test generic alias Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 38 ++++++++++++++++++++ flytekit/interaction/click_types.py | 1 + tests/flytekit/unit/core/test_type_engine.py | 10 ++++++ 3 files changed, 49 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index be5cbc6255..d51cf85cb6 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -148,7 +148,35 @@ def type_assertions_enabled(self) -> bool: """ return self._type_assertions_enabled + @staticmethod + def isinstance_generic(obj, generic_alias): + origin = get_origin(generic_alias) # list from list[int]) + args = get_args(generic_alias) # (int,) from list[int] + + if not isinstance(obj, origin): + raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}") + + # Optionally check the type of elements if it's a collection like list or dict + if origin in {list, tuple, set}: + if all(isinstance(item, args[0]) for item in obj): + return + raise TypeTransformerFailedError(f"Not all items in '{obj}' are of type {args[0]}") + + if origin is dict: + key_type, value_type = args + if all(isinstance(k, key_type) and isinstance(v, value_type) for k, v in obj.items()): + return + raise TypeTransformerFailedError(f"Not all values in '{obj}' are of type {value_type}") + + return + def assert_type(self, t: Type[T], v: T): + if sys.version_info >= (3, 10): + import types + + if isinstance(t, types.GenericAlias): + return self.isinstance_generic(v, t) + if not hasattr(t, "__origin__") and not isinstance(v, t): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @@ -1506,6 +1534,16 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]: """ return get_args(t)[0] + def assert_type(self, t: Type[T], v: T): + python_type = get_underlying_type(t) + for sub_type in get_args(python_type): + try: + super().assert_type(sub_type, v) + return + except TypeTransformerFailedError: + continue + raise TypeTransformerFailedError(f"Value {v} is not of type {t}") + def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_underlying_type(t) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 04a1848f84..d008b89767 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -485,6 +485,7 @@ def convert( # If the input matches the default value in the launch plan, serialization can be skipped. if param and value == param.default: return None + breakpoint() lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type) if not self._is_remote: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 57f6cddecf..293012ed4a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3275,3 +3275,13 @@ def run(): assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()]*N expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N assert all([mock_call[0] == "mock_register" for mock_call in mock_wrapper.mock_calls[:expected_number_of_register_calls]]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe(): + pt = list[int] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt) + assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3 From 598d76cf2751ac455b39f5e6deadedf8265f6030 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 10 Sep 2024 21:08:47 -0700 Subject: [PATCH 2/5] remove breakpoint Signed-off-by: Yee Hing Tong --- flytekit/interaction/click_types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index d008b89767..04a1848f84 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -485,7 +485,6 @@ def convert( # If the input matches the default value in the launch plan, serialization can be skipped. if param and value == param.default: return None - breakpoint() lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type) if not self._is_remote: From 2255237c1abebfb50721344d3248afce4b10c9a0 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 04:38:55 -0700 Subject: [PATCH 3/5] add more complicated check Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 27 ++++++++++++-------- tests/flytekit/unit/core/test_type_engine.py | 19 ++++++++++++++ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d51cf85cb6..e7ca650c49 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -148,8 +148,7 @@ def type_assertions_enabled(self) -> bool: """ return self._type_assertions_enabled - @staticmethod - def isinstance_generic(obj, generic_alias): + def isinstance_generic(self, obj, generic_alias): origin = get_origin(generic_alias) # list from list[int]) args = get_args(generic_alias) # (int,) from list[int] @@ -158,13 +157,16 @@ def isinstance_generic(obj, generic_alias): # Optionally check the type of elements if it's a collection like list or dict if origin in {list, tuple, set}: - if all(isinstance(item, args[0]) for item in obj): + for item in obj: + self.assert_type(args[0], item) return raise TypeTransformerFailedError(f"Not all items in '{obj}' are of type {args[0]}") if origin is dict: key_type, value_type = args - if all(isinstance(k, key_type) and isinstance(v, value_type) for k, v in obj.items()): + for k, v in obj.items(): + self.assert_type(key_type, k) + self.assert_type(value_type, v) return raise TypeTransformerFailedError(f"Not all values in '{obj}' are of type {value_type}") @@ -1536,13 +1538,16 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]: def assert_type(self, t: Type[T], v: T): python_type = get_underlying_type(t) - for sub_type in get_args(python_type): - try: - super().assert_type(sub_type, v) - return - except TypeTransformerFailedError: - continue - raise TypeTransformerFailedError(f"Value {v} is not of type {t}") + if _is_union_type(python_type): + for sub_type in get_args(python_type): + try: + super().assert_type(sub_type, v) + return + except TypeTransformerFailedError: + continue + raise TypeTransformerFailedError(f"Value {v} is not of type {t}") + else: + super().assert_type(t, v) def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_underlying_type(t) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 293012ed4a..159c552156 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3285,3 +3285,22 @@ def test_option_list_with_pipe(): ctx = FlyteContextManager.current_context() lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt) assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3 + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe_2(): + pt = list[list[dict[str, str]] | None] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": "two"}]], pt, lt) + uv = lit.scalar.union.value + assert uv is not None + assert len(uv.collection.literals) == 3 + first = uv.collection.literals[0] + assert first.scalar.union.value.collection.literals[0].map.literals["a"].scalar.primitive.string_value == "one" + + assert len(lt.union_type.variants) == 2 + v1 = lt.union_type.variants[0] + assert len(v1.collection_type.union_type.variants) == 2 + assert v1.collection_type.union_type.variants[0].collection_type.map_value_type.simple == SimpleType.STRING From 10174b9ec1eaa2fa92504c1ceeb5ac7c1722c8d7 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 04:43:05 -0700 Subject: [PATCH 4/5] test Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/core/test_type_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 159c552156..eb4c0d0f44 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3286,6 +3286,11 @@ def test_option_list_with_pipe(): lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt) assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3 + TypeEngine.to_literal(ctx, None, pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [1, 2, "3"], pt, lt) + @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") def test_option_list_with_pipe_2(): @@ -3304,3 +3309,6 @@ def test_option_list_with_pipe_2(): v1 = lt.union_type.variants[0] assert len(v1.collection_type.union_type.variants) == 2 assert v1.collection_type.union_type.variants[0].collection_type.map_value_type.simple == SimpleType.STRING + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": 3}]], pt, lt) From 7a9d4207eaec68b158d5b553a22723d606eaeaa2 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 06:52:27 -0700 Subject: [PATCH 5/5] try to fix an edge case Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e7ca650c49..34444cd1b4 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1540,6 +1540,9 @@ def assert_type(self, t: Type[T], v: T): python_type = get_underlying_type(t) if _is_union_type(python_type): for sub_type in get_args(python_type): + if sub_type == typing.Any: + # this is an edge case + return try: super().assert_type(sub_type, v) return