diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index be5cbc6255..34444cd1b4 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -148,7 +148,37 @@ def type_assertions_enabled(self) -> bool: """ return self._type_assertions_enabled + def isinstance_generic(self, obj, generic_alias): + origin = get_origin(generic_alias) # list from list[int]) + args = get_args(generic_alias) # (int,) from list[int] + + if not isinstance(obj, origin): + raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}") + + # Optionally check the type of elements if it's a collection like list or dict + if origin in {list, tuple, set}: + for item in obj: + self.assert_type(args[0], item) + return + raise TypeTransformerFailedError(f"Not all items in '{obj}' are of type {args[0]}") + + if origin is dict: + key_type, value_type = args + for k, v in obj.items(): + self.assert_type(key_type, k) + self.assert_type(value_type, v) + return + raise TypeTransformerFailedError(f"Not all values in '{obj}' are of type {value_type}") + + return + def assert_type(self, t: Type[T], v: T): + if sys.version_info >= (3, 10): + import types + + if isinstance(t, types.GenericAlias): + return self.isinstance_generic(v, t) + if not hasattr(t, "__origin__") and not isinstance(v, t): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @@ -1506,6 +1536,22 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]: """ return get_args(t)[0] + def assert_type(self, t: Type[T], v: T): + python_type = get_underlying_type(t) + if _is_union_type(python_type): + for sub_type in get_args(python_type): + if sub_type == typing.Any: + # this is an edge case + return + try: + super().assert_type(sub_type, v) + return + except TypeTransformerFailedError: + continue + raise TypeTransformerFailedError(f"Value {v} is not of type {t}") + else: + super().assert_type(t, v) + def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_underlying_type(t) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 57f6cddecf..eb4c0d0f44 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3275,3 +3275,40 @@ def run(): assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()]*N expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N assert all([mock_call[0] == "mock_register" for mock_call in mock_wrapper.mock_calls[:expected_number_of_register_calls]]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe(): + pt = list[int] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [1, 2, 3], pt, lt) + assert lit.scalar.union.value.collection.literals[2].scalar.primitive.integer == 3 + + TypeEngine.to_literal(ctx, None, pt, lt) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [1, 2, "3"], pt, lt) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") +def test_option_list_with_pipe_2(): + pt = list[list[dict[str, str]] | None] | None + lt = TypeEngine.to_literal_type(pt) + + ctx = FlyteContextManager.current_context() + lit = TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": "two"}]], pt, lt) + uv = lit.scalar.union.value + assert uv is not None + assert len(uv.collection.literals) == 3 + first = uv.collection.literals[0] + assert first.scalar.union.value.collection.literals[0].map.literals["a"].scalar.primitive.string_value == "one" + + assert len(lt.union_type.variants) == 2 + v1 = lt.union_type.variants[0] + assert len(v1.collection_type.union_type.variants) == 2 + assert v1.collection_type.union_type.variants[0].collection_type.map_value_type.simple == SimpleType.STRING + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": 3}]], pt, lt)