Skip to content

Commit d5a95ec

Browse files
authored
[ty] Implicit type aliases: Add support for Callable (#21496)
## Summary Add support for `Callable` special forms in implicit type aliases. ## Typing conformance Four new tests are passing ## Ecosystem impact * All of the `invalid-type-form` errors are from libraries that use `mypy_extensions` and do something like `Callable[[NamedArg("x", str)], int]`. * A handful of new false positives because we do not support generic specializations of implicit type aliases, yet. But other * Everything else looks like true positives or known limitations ## Test Plan New Markdown tests.
1 parent b1e354b commit d5a95ec

File tree

5 files changed

+153
-57
lines changed

5 files changed

+153
-57
lines changed

crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ g(None)
3333
We also support unions in type aliases:
3434

3535
```py
36-
from typing_extensions import Any, Never, Literal, LiteralString, Tuple, Annotated, Optional, Union
36+
from typing_extensions import Any, Never, Literal, LiteralString, Tuple, Annotated, Optional, Union, Callable
3737
from ty_extensions import Unknown
3838

3939
IntOrStr = int | str
@@ -68,6 +68,8 @@ IntOrOptional = int | Optional[str]
6868
OptionalOrInt = Optional[str] | int
6969
IntOrTypeOfStr = int | type[str]
7070
TypeOfStrOrInt = type[str] | int
71+
IntOrCallable = int | Callable[[str], bytes]
72+
CallableOrInt = Callable[[str], bytes] | int
7173

7274
reveal_type(IntOrStr) # revealed: types.UnionType
7375
reveal_type(IntOrStrOrBytes1) # revealed: types.UnionType
@@ -101,6 +103,8 @@ reveal_type(IntOrOptional) # revealed: types.UnionType
101103
reveal_type(OptionalOrInt) # revealed: types.UnionType
102104
reveal_type(IntOrTypeOfStr) # revealed: types.UnionType
103105
reveal_type(TypeOfStrOrInt) # revealed: types.UnionType
106+
reveal_type(IntOrCallable) # revealed: types.UnionType
107+
reveal_type(CallableOrInt) # revealed: types.UnionType
104108

105109
def _(
106110
int_or_str: IntOrStr,
@@ -135,6 +139,8 @@ def _(
135139
optional_or_int: OptionalOrInt,
136140
int_or_type_of_str: IntOrTypeOfStr,
137141
type_of_str_or_int: TypeOfStrOrInt,
142+
int_or_callable: IntOrCallable,
143+
callable_or_int: CallableOrInt,
138144
):
139145
reveal_type(int_or_str) # revealed: int | str
140146
reveal_type(int_or_str_or_bytes1) # revealed: int | str | bytes
@@ -168,6 +174,8 @@ def _(
168174
reveal_type(optional_or_int) # revealed: str | None | int
169175
reveal_type(int_or_type_of_str) # revealed: int | type[str]
170176
reveal_type(type_of_str_or_int) # revealed: type[str] | int
177+
reveal_type(int_or_callable) # revealed: int | ((str, /) -> bytes)
178+
reveal_type(callable_or_int) # revealed: ((str, /) -> bytes) | int
171179
```
172180

173181
If a type is unioned with itself in a value expression, the result is just that type. No
@@ -944,7 +952,60 @@ def _(
944952
reveal_type(dict_too_many_args) # revealed: dict[Unknown, Unknown]
945953
```
946954

947-
## Stringified annotations?
955+
## `Callable[...]`
956+
957+
We support implicit type aliases using `Callable[...]`:
958+
959+
```py
960+
from typing import Callable, Union
961+
962+
CallableNoArgs = Callable[[], None]
963+
BasicCallable = Callable[[int, str], bytes]
964+
GradualCallable = Callable[..., str]
965+
966+
reveal_type(CallableNoArgs) # revealed: GenericAlias
967+
reveal_type(BasicCallable) # revealed: GenericAlias
968+
reveal_type(GradualCallable) # revealed: GenericAlias
969+
970+
def _(
971+
callable_no_args: CallableNoArgs,
972+
basic_callable: BasicCallable,
973+
gradual_callable: GradualCallable,
974+
):
975+
reveal_type(callable_no_args) # revealed: () -> None
976+
reveal_type(basic_callable) # revealed: (int, str, /) -> bytes
977+
reveal_type(gradual_callable) # revealed: (...) -> str
978+
```
979+
980+
Nested callables work as expected:
981+
982+
```py
983+
TakesCallable = Callable[[Callable[[int], str]], bytes]
984+
ReturnsCallable = Callable[[int], Callable[[str], bytes]]
985+
986+
def _(takes_callable: TakesCallable, returns_callable: ReturnsCallable):
987+
reveal_type(takes_callable) # revealed: ((int, /) -> str, /) -> bytes
988+
reveal_type(returns_callable) # revealed: (int, /) -> (str, /) -> bytes
989+
```
990+
991+
Invalid uses result in diagnostics:
992+
993+
```py
994+
# error: [invalid-type-form] "Special form `typing.Callable` expected exactly two arguments (parameter types and return type)"
995+
InvalidCallable1 = Callable[[int]]
996+
997+
# error: [invalid-type-form] "The first argument to `Callable` must be either a list of types, ParamSpec, Concatenate, or `...`"
998+
InvalidCallable2 = Callable[int, str]
999+
1000+
reveal_type(InvalidCallable1) # revealed: GenericAlias
1001+
reveal_type(InvalidCallable2) # revealed: GenericAlias
1002+
1003+
def _(invalid_callable1: InvalidCallable1, invalid_callable2: InvalidCallable2):
1004+
reveal_type(invalid_callable1) # revealed: (...) -> Unknown
1005+
reveal_type(invalid_callable2) # revealed: (...) -> Unknown
1006+
```
1007+
1008+
## Stringified annotations
9481009

9491010
From the [typing spec on type aliases](https://typing.python.org/en/latest/spec/aliases.html):
9501011

@@ -974,14 +1035,15 @@ We *do* support stringified annotations if they appear in a position where a typ
9741035
syntactically expected:
9751036

9761037
```py
977-
from typing import Union, List, Dict, Annotated
1038+
from typing import Union, List, Dict, Annotated, Callable
9781039

9791040
ListOfInts1 = list["int"]
9801041
ListOfInts2 = List["int"]
9811042
StrOrStyle = Union[str, "Style"]
9821043
SubclassOfStyle = type["Style"]
9831044
DictStrToStyle = Dict[str, "Style"]
9841045
AnnotatedStyle = Annotated["Style", "metadata"]
1046+
CallableStyleToStyle = Callable[["Style"], "Style"]
9851047

9861048
class Style: ...
9871049

@@ -992,13 +1054,15 @@ def _(
9921054
subclass_of_style: SubclassOfStyle,
9931055
dict_str_to_style: DictStrToStyle,
9941056
annotated_style: AnnotatedStyle,
1057+
callable_style_to_style: CallableStyleToStyle,
9951058
):
9961059
reveal_type(list_of_ints1) # revealed: list[int]
9971060
reveal_type(list_of_ints2) # revealed: list[int]
9981061
reveal_type(str_or_style) # revealed: str | Style
9991062
reveal_type(subclass_of_style) # revealed: type[Style]
10001063
reveal_type(dict_str_to_style) # revealed: dict[str, Style]
10011064
reveal_type(annotated_style) # revealed: Style
1065+
reveal_type(callable_style_to_style) # revealed: (Style, /) -> Style
10021066
```
10031067

10041068
## Recursive

crates/ty_python_semantic/src/types.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6747,6 +6747,7 @@ impl<'db> Type<'db> {
67476747

67486748
Ok(ty.inner(db).to_meta_type(db))
67496749
}
6750+
KnownInstanceType::Callable(callable) => Ok(Type::Callable(*callable)),
67506751
},
67516752

67526753
Type::SpecialForm(special_form) => match special_form {
@@ -7990,6 +7991,9 @@ pub enum KnownInstanceType<'db> {
79907991
/// An instance of `typing.GenericAlias` representing a `type[...]` expression.
79917992
TypeGenericAlias(InternedType<'db>),
79927993

7994+
/// An instance of `typing.GenericAlias` representing a `Callable[...]` expression.
7995+
Callable(CallableType<'db>),
7996+
79937997
/// An identity callable created with `typing.NewType(name, base)`, which behaves like a
79947998
/// subtype of `base` in type expressions. See the `struct NewType` payload for an example.
79957999
NewType(NewType<'db>),
@@ -8029,6 +8033,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
80298033
| KnownInstanceType::TypeGenericAlias(ty) => {
80308034
visitor.visit_type(db, ty.inner(db));
80318035
}
8036+
KnownInstanceType::Callable(callable) => {
8037+
visitor.visit_callable_type(db, callable);
8038+
}
80328039
KnownInstanceType::NewType(newtype) => {
80338040
if let ClassType::Generic(generic_alias) = newtype.base_class_type(db) {
80348041
visitor.visit_generic_alias_type(db, generic_alias);
@@ -8074,6 +8081,7 @@ impl<'db> KnownInstanceType<'db> {
80748081
Self::Literal(ty) => Self::Literal(ty.normalized_impl(db, visitor)),
80758082
Self::Annotated(ty) => Self::Annotated(ty.normalized_impl(db, visitor)),
80768083
Self::TypeGenericAlias(ty) => Self::TypeGenericAlias(ty.normalized_impl(db, visitor)),
8084+
Self::Callable(callable) => Self::Callable(callable.normalized_impl(db, visitor)),
80778085
Self::NewType(newtype) => Self::NewType(
80788086
newtype
80798087
.map_base_class_type(db, |class_type| class_type.normalized_impl(db, visitor)),
@@ -8096,9 +8104,10 @@ impl<'db> KnownInstanceType<'db> {
80968104
Self::Field(_) => KnownClass::Field,
80978105
Self::ConstraintSet(_) => KnownClass::ConstraintSet,
80988106
Self::UnionType(_) => KnownClass::UnionType,
8099-
Self::Literal(_) | Self::Annotated(_) | Self::TypeGenericAlias(_) => {
8100-
KnownClass::GenericAlias
8101-
}
8107+
Self::Literal(_)
8108+
| Self::Annotated(_)
8109+
| Self::TypeGenericAlias(_)
8110+
| Self::Callable(_) => KnownClass::GenericAlias,
81028111
Self::NewType(_) => KnownClass::NewType,
81038112
}
81048113
}
@@ -8184,7 +8193,9 @@ impl<'db> KnownInstanceType<'db> {
81848193
KnownInstanceType::Annotated(_) => {
81858194
f.write_str("<typing.Annotated special form>")
81868195
}
8187-
KnownInstanceType::TypeGenericAlias(_) => f.write_str("GenericAlias"),
8196+
KnownInstanceType::TypeGenericAlias(_) | KnownInstanceType::Callable(_) => {
8197+
f.write_str("GenericAlias")
8198+
}
81888199
KnownInstanceType::NewType(declaration) => {
81898200
write!(f, "<NewType pseudo-class '{}'>", declaration.name(self.db))
81908201
}

crates/ty_python_semantic/src/types/class_base.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ impl<'db> ClassBase<'db> {
174174
| KnownInstanceType::Deprecated(_)
175175
| KnownInstanceType::Field(_)
176176
| KnownInstanceType::ConstraintSet(_)
177+
| KnownInstanceType::Callable(_)
177178
| KnownInstanceType::UnionType(_)
178179
| KnownInstanceType::Literal(_)
179180
// A class inheriting from a newtype would make intuitive sense, but newtype

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9506,7 +9506,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
95069506
KnownInstanceType::UnionType(_)
95079507
| KnownInstanceType::Literal(_)
95089508
| KnownInstanceType::Annotated(_)
9509-
| KnownInstanceType::TypeGenericAlias(_),
9509+
| KnownInstanceType::TypeGenericAlias(_)
9510+
| KnownInstanceType::Callable(_),
95109511
),
95119512
Type::ClassLiteral(..)
95129513
| Type::SubclassOf(..)
@@ -9516,7 +9517,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
95169517
KnownInstanceType::UnionType(_)
95179518
| KnownInstanceType::Literal(_)
95189519
| KnownInstanceType::Annotated(_)
9519-
| KnownInstanceType::TypeGenericAlias(_),
9520+
| KnownInstanceType::TypeGenericAlias(_)
9521+
| KnownInstanceType::Callable(_),
95209522
),
95219523
ast::Operator::BitOr,
95229524
) if pep_604_unions_allowed() => {
@@ -10827,6 +10829,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
1082710829
InternedType::new(self.db(), argument_ty),
1082810830
));
1082910831
}
10832+
Type::SpecialForm(SpecialFormType::Callable) => {
10833+
let callable = self
10834+
.infer_callable_type(subscript)
10835+
.as_callable()
10836+
.expect("always returns Type::Callable");
10837+
10838+
return Type::KnownInstance(KnownInstanceType::Callable(callable));
10839+
}
1083010840
// `typing` special forms with a single generic argument
1083110841
Type::SpecialForm(
1083210842
special_form @ (SpecialFormType::List

crates/ty_python_semantic/src/types/infer/builder/type_expression.rs

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,10 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
839839
}
840840
Type::unknown()
841841
}
842+
KnownInstanceType::Callable(_) => {
843+
self.infer_type_expression(slice);
844+
todo_type!("Generic specialization of typing.Callable")
845+
}
842846
KnownInstanceType::Annotated(_) => {
843847
self.infer_type_expression(slice);
844848
todo_type!("Generic specialization of typing.Annotated")
@@ -929,6 +933,58 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
929933
ty
930934
}
931935

936+
/// Infer the type of a `Callable[...]` type expression.
937+
pub(crate) fn infer_callable_type(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
938+
let db = self.db();
939+
940+
let arguments_slice = &*subscript.slice;
941+
942+
let mut arguments = match arguments_slice {
943+
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
944+
_ => {
945+
self.infer_callable_parameter_types(arguments_slice);
946+
Either::Right(std::iter::empty::<&ast::Expr>())
947+
}
948+
};
949+
950+
let first_argument = arguments.next();
951+
952+
let parameters = first_argument.and_then(|arg| self.infer_callable_parameter_types(arg));
953+
954+
let return_type = arguments.next().map(|arg| self.infer_type_expression(arg));
955+
956+
let correct_argument_number = if let Some(third_argument) = arguments.next() {
957+
self.infer_type_expression(third_argument);
958+
for argument in arguments {
959+
self.infer_type_expression(argument);
960+
}
961+
false
962+
} else {
963+
return_type.is_some()
964+
};
965+
966+
if !correct_argument_number {
967+
report_invalid_arguments_to_callable(&self.context, subscript);
968+
}
969+
970+
let callable_type = if let (Some(parameters), Some(return_type), true) =
971+
(parameters, return_type, correct_argument_number)
972+
{
973+
CallableType::single(db, Signature::new(parameters, Some(return_type)))
974+
} else {
975+
CallableType::unknown(db)
976+
};
977+
978+
// `Signature` / `Parameters` are not a `Type` variant, so we're storing
979+
// the outer callable type on these expressions instead.
980+
self.store_expression_type(arguments_slice, callable_type);
981+
if let Some(first_argument) = first_argument {
982+
self.store_expression_type(first_argument, callable_type);
983+
}
984+
985+
callable_type
986+
}
987+
932988
pub(crate) fn infer_parameterized_special_form_type_expression(
933989
&mut self,
934990
subscript: &ast::ExprSubscript,
@@ -979,53 +1035,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
9791035
}
9801036
_ => self.infer_type_expression(arguments_slice),
9811037
},
982-
SpecialFormType::Callable => {
983-
let mut arguments = match arguments_slice {
984-
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
985-
_ => {
986-
self.infer_callable_parameter_types(arguments_slice);
987-
Either::Right(std::iter::empty::<&ast::Expr>())
988-
}
989-
};
990-
991-
let first_argument = arguments.next();
992-
993-
let parameters =
994-
first_argument.and_then(|arg| self.infer_callable_parameter_types(arg));
995-
996-
let return_type = arguments.next().map(|arg| self.infer_type_expression(arg));
997-
998-
let correct_argument_number = if let Some(third_argument) = arguments.next() {
999-
self.infer_type_expression(third_argument);
1000-
for argument in arguments {
1001-
self.infer_type_expression(argument);
1002-
}
1003-
false
1004-
} else {
1005-
return_type.is_some()
1006-
};
1007-
1008-
if !correct_argument_number {
1009-
report_invalid_arguments_to_callable(&self.context, subscript);
1010-
}
1011-
1012-
let callable_type = if let (Some(parameters), Some(return_type), true) =
1013-
(parameters, return_type, correct_argument_number)
1014-
{
1015-
CallableType::single(db, Signature::new(parameters, Some(return_type)))
1016-
} else {
1017-
CallableType::unknown(db)
1018-
};
1019-
1020-
// `Signature` / `Parameters` are not a `Type` variant, so we're storing
1021-
// the outer callable type on these expressions instead.
1022-
self.store_expression_type(arguments_slice, callable_type);
1023-
if let Some(first_argument) = first_argument {
1024-
self.store_expression_type(first_argument, callable_type);
1025-
}
1026-
1027-
callable_type
1028-
}
1038+
SpecialFormType::Callable => self.infer_callable_type(subscript),
10291039

10301040
// `ty_extensions` special forms
10311041
SpecialFormType::Not => {
@@ -1491,7 +1501,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
14911501
///
14921502
/// It returns `None` if the argument is invalid i.e., not a list of types, parameter
14931503
/// specification, `typing.Concatenate`, or `...`.
1494-
fn infer_callable_parameter_types(
1504+
pub(super) fn infer_callable_parameter_types(
14951505
&mut self,
14961506
parameters: &ast::Expr,
14971507
) -> Option<Parameters<'db>> {

0 commit comments

Comments
 (0)