diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs index 2f467db1d5ec1..966c5a3e6d724 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs @@ -157,6 +157,7 @@ pub(crate) struct UserProviderCallableData { /// Type id of provider callable instance. pub(crate) ty_provider_type_instance_id: TypeInstanceId, pub(crate) fields: IndexMap, + pub(crate) ty_provider: Ty, } /// Initialized after the name is assigned to the provider. @@ -387,6 +388,7 @@ impl<'v> StarlarkValue<'v> for UserProviderCallable { provider_id, fields: self.fields.clone(), ty_provider_type_instance_id, + ty_provider: ty_provider.dupe(), }), ty_provider, ty_callable, diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/collection.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/collection.rs index 1adf4482a3c95..d83301dad2628 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/collection.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/collection.rs @@ -37,6 +37,10 @@ use starlark::environment::Methods; use starlark::environment::MethodsBuilder; use starlark::environment::MethodsStatic; use starlark::typing::Ty; +use starlark::typing::TyCustomIndex; +use starlark::typing::TyStarlarkValue; +use starlark::typing::TyUser; +use starlark::typing::TyUserParams; use starlark::values::AllocFrozenValue; use starlark::values::AllocStaticSimple; use starlark::values::AllocValue; @@ -64,6 +68,7 @@ use starlark::values::none::NoneOr; use starlark::values::starlark_value; use starlark::values::starlark_value_as_type::StarlarkValueAsType; use starlark::values::type_repr::StarlarkTypeRepr; +use starlark::values::typing::TypeInstanceId; use crate::interpreter::rule_defs::provider::DefaultInfo; use crate::interpreter::rule_defs::provider::DefaultInfoCallable; @@ -372,6 +377,7 @@ impl FrozenProviderCollection { /// ``` #[starlark_module] fn provider_collection_methods(builder: &mut MethodsBuilder) { + #[starlark(ty_custom_function = super::dependency::GetTyIdentity)] fn get<'v>( this: &ProviderCollection<'v>, index: Value<'v>, @@ -380,11 +386,38 @@ fn provider_collection_methods(builder: &mut MethodsBuilder) { } } +static PROVIDER_COLLECTION_TYPE: std::sync::LazyLock = std::sync::LazyLock::new(|| { + Ty::custom( + TyUser::new( + "ProviderCollection".to_owned(), + TyStarlarkValue::new::(), + TypeInstanceId::r#gen(), + TyUserParams { + index_custom: Some(TyCustomIndex::new(super::dependency::GetTyIdentity)), + ..TyUserParams::default() + }, + ) + .unwrap(), + ) +}); + #[starlark_value(type = "ProviderCollection")] impl<'v, V: ValueLike<'v>> StarlarkValue<'v> for ProviderCollectionGen where Self: ProvidesStaticType<'v>, { + fn get_type_starlark_repr() -> Ty { + PROVIDER_COLLECTION_TYPE.dupe() + } + + fn eval_type(&self) -> Option { + Some(PROVIDER_COLLECTION_TYPE.dupe()) + } + + fn typechecker_ty(&self) -> Option { + Some(PROVIDER_COLLECTION_TYPE.dupe()) + } + fn at(&self, index: Value<'v>, _heap: &'v Heap) -> starlark::Result> { match self.get_impl(index, GetOp::At)? { Either::Left(v) => Ok(v), diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/dependency.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/dependency.rs index 4258d15e526d3..4253f78682e6d 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/dependency.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/dependency.rs @@ -19,13 +19,25 @@ use buck2_core::provider::label::ConfiguredProvidersLabel; use buck2_core::provider::label::ProviderName; use buck2_error::BuckErrorContext; use buck2_interpreter::types::configured_providers_label::StarlarkConfiguredProvidersLabel; +use dupe::Dupe; use starlark::any::ProvidesStaticType; use starlark::coerce::Coerce; use starlark::environment::GlobalsBuilder; use starlark::environment::Methods; use starlark::environment::MethodsBuilder; use starlark::environment::MethodsStatic; +use starlark::typing::ParamSpec; use starlark::typing::Ty; +use starlark::typing::TyBasic; +use starlark::typing::TyCallArgs; +use starlark::typing::TyCallable; +use starlark::typing::TyCustomIndex; +use starlark::typing::TyStarlarkValue; +use starlark::typing::TyUser; +use starlark::typing::TyUserParams; +use starlark::typing::TypingNoContextOrInternalError; +use starlark::typing::TypingOrInternalError; +use starlark::typing::TypingOracleCtx; use starlark::values::Freeze; use starlark::values::FrozenValue; use starlark::values::FrozenValueTyped; @@ -41,6 +53,7 @@ use starlark::values::ValueOfUncheckedGeneric; use starlark::values::none::NoneOr; use starlark::values::starlark_value; use starlark::values::starlark_value_as_type::StarlarkValueAsType; +use starlark::values::typing::TypeInstanceId; use starlark_map::StarlarkHasher; use crate::interpreter::rule_defs::provider::collection::FrozenProviderCollection; @@ -127,13 +140,36 @@ impl<'v> Dependency<'v> { } } +static DEPENDENCY_TYPE: std::sync::LazyLock = std::sync::LazyLock::new(|| { + Ty::custom( + TyUser::new( + "Dependency".to_owned(), + TyStarlarkValue::new::(), + TypeInstanceId::r#gen(), + TyUserParams { + index_custom: Some(TyCustomIndex::new(GetTyIdentity)), + ..TyUserParams::default() + }, + ) + .unwrap(), + ) +}); + #[starlark_value(type = "Dependency")] impl<'v, V: ValueLike<'v>> StarlarkValue<'v> for DependencyGen where Self: ProvidesStaticType<'v>, { fn get_type_starlark_repr() -> Ty { - Ty::starlark_value::>>() + DEPENDENCY_TYPE.dupe() + } + + fn eval_type(&self) -> Option { + Some(DEPENDENCY_TYPE.dupe()) + } + + fn typechecker_ty(&self) -> Option { + Some(DEPENDENCY_TYPE.dupe()) } fn get_methods() -> Option<&'static Methods> { @@ -222,6 +258,7 @@ fn dependency_methods(builder: &mut MethodsBuilder) { /// .... /// collection.get(FooInfo) # None if absent, a FooInfo instance if present /// ``` + #[starlark(ty_custom_function = GetTyIdentity)] fn get<'v>( this: &Dependency<'v>, index: Value<'v>, @@ -239,3 +276,44 @@ fn dependency_methods(builder: &mut MethodsBuilder) { pub(crate) fn register_dependency(globals: &mut GlobalsBuilder) { const Dependency: StarlarkValueAsType> = StarlarkValueAsType::new(); } + +#[derive(Debug, PartialEq, PartialOrd, Eq, Ord, Hash, Allocative)] +pub(super) struct GetTyIdentity; + +impl starlark::typing::TyCustomFunctionImpl for GetTyIdentity { + fn as_callable(&self) -> starlark::typing::TyCallable { + TyCallable::new(ParamSpec::pos_only([Ty::any()], []), Ty::any()) + } + + fn validate_call( + &self, + span: starlark::codemap::Span, + args: &TyCallArgs, + oracle: starlark::typing::TypingOracleCtx, + ) -> Result { + let first_arg = args + .pos + .first() + .ok_or_else(|| oracle.mk_error(span, anyhow::anyhow!("No first argument")))?; + let Some(ret) = first_arg.node.as_callable_return() else { + return Err(oracle + .mk_error(span, anyhow::anyhow!("Not a provider callable")) + .into()); + }; + Ok(Ty::union2(Ty::none(), ret)) + } +} + +impl starlark::typing::TyCustomIndexImpl for GetTyIdentity { + fn index( + &self, + item: &TyBasic, + _ctx: &TypingOracleCtx, + ) -> Result { + let item_ty = Ty::basic(item.dupe()); + let Some(ret) = item_ty.as_callable_return() else { + return Err(TypingNoContextOrInternalError::Typing); + }; + Ok(ret) + } +} diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs index 487c3bd5745aa..b18935bacb857 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs @@ -148,6 +148,10 @@ where fn provide(&'v self, demand: &mut Demand<'_, 'v>) { demand.provide_value::<&dyn ProviderLike>(self); } + + fn typechecker_ty(&self) -> Option { + Some(self.callable.ty_provider.dupe()) + } } impl<'v, V: ValueLike<'v>> serde::Serialize for UserProviderGen<'v, V> { diff --git a/app/buck2_build_api_tests/src/interpreter/rule_defs/provider.rs b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider.rs index b8af2b3471fcb..9ca4e6ff63eb0 100644 --- a/app/buck2_build_api_tests/src/interpreter/rule_defs/provider.rs +++ b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider.rs @@ -11,6 +11,7 @@ mod build_defs; mod builtin; mod collection; +mod dependency; mod field_types; mod provider_symbol; mod tests; diff --git a/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/dependency.rs b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/dependency.rs new file mode 100644 index 0000000000000..7e12c6b600d53 --- /dev/null +++ b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/dependency.rs @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is dual-licensed under either the MIT license found in the + * LICENSE-MIT file in the root directory of this source tree or the Apache + * License, Version 2.0 found in the LICENSE-APACHE file in the root directory + * of this source tree. You may select, at your option, one of the + * above-listed licenses. + */ + +use buck2_interpreter_for_build::interpreter::testing::Tester; +use indoc::indoc; + +#[test] +fn dependency_type_check_get_ok() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test(indoc!( + r#" + def check(dep: Dependency): + prov = dep.get(DefaultInfo) + if prov: + x = prov.default_outputs[0] + def test(): pass + "# + ))?; + Ok(()) +} + +#[test] +#[ignore = "starlark can't check None yet"] +fn dependency_type_check_get_fail_none() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test_expecting_error( + indoc!( + r#" + def check(dep: Dependency): + prov = dep.get(DefaultInfo) + x = prov.default_outputs[0] + def test(): pass + "# + ), + "The attribute `default_outputs` is not available on the type `None | DefaultInfo`", + ); + Ok(()) +} + +#[test] +fn dependency_type_check_get_fail_field() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test_expecting_error( + indoc!( + r#" + def check(dep: Dependency): + prov = dep.get(DefaultInfo) + if prov: + x = prov.xxxx + def test(): pass + "# + ), + // TODO: this error message could be better, we ruled out None + "The attribute `xxxx` is not available on the type `None | DefaultInfo`", + ); + Ok(()) +} + +#[test] +fn dependency_type_check_index_ok() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test(indoc!( + r#" + def check(dep: Dependency): + prov = dep[DefaultInfo] + x = prov.default_outputs[0] + + def test(): pass + "# + ))?; + Ok(()) +} + +#[test] +fn dependency_type_check_index_fail() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test_expecting_error( + indoc!( + r#" + def check(dep: Dependency): + prov = dep[DefaultInfo] + x = prov.xxxx + def test(): pass + "# + ), + "The attribute `xxxx` is not available on the type `DefaultInfo`", + ); + Ok(()) +} + +#[test] +fn dependency_type_check_binop_in() -> buck2_error::Result<()> { + let mut tester = Tester::new()?; + tester.run_starlark_bzl_test(indoc!( + r#" + def check(dep: Dependency): + return DefaultInfo in dep + def test(): pass + "# + ))?; + Ok(()) +} diff --git a/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/tests.rs b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/tests.rs index 0d9210fe1b4ec..4a7fef6683c6c 100644 --- a/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/tests.rs +++ b/app/buck2_build_api_tests/src/interpreter/rule_defs/provider/tests.rs @@ -188,3 +188,39 @@ fn test_provider_non_unique_fields() { "non-unique field names", ) } + +#[test] +fn test_typecheck_provider_field_inline() { + let mut tester = provider_tester(); + tester.run_starlark_bzl_test_expecting_error( + indoc!( + r#" + FooInfo = provider(fields=["field"]) + def checkonly(): + foo = FooInfo(field="foo1") + x = foo.ield + def test(): + pass + "# + ), + "The attribute `ield` is not available on the type `FooInfo`", + ); +} + +#[test] +fn test_typecheck_provider_field_ambient() { + let mut tester = provider_tester(); + tester.run_starlark_bzl_test_expecting_error( + indoc!( + r#" + FooInfo = provider(fields=["field"]) + foo = FooInfo(field="foo1") + def checkonly(): + x = foo.ield + def test(): + pass + "# + ), + "The attribute `ield` is not available on the type `FooInfo`", + ); +} diff --git a/prelude/python/needed_coverage.bzl b/prelude/python/needed_coverage.bzl index 00a942de5a82c..f58406a6c53fe 100644 --- a/prelude/python/needed_coverage.bzl +++ b/prelude/python/needed_coverage.bzl @@ -30,7 +30,7 @@ def _parse_python_needed_coverage_spec( fail("ratio_percentage must be between 0 and 100 (inclusive): {}".format(ratio_percentage)) ratio_percentage = ratio_percentage / 100.0 - coverage = dep[PythonNeededCoverageInfo] + coverage = dep.get(PythonNeededCoverageInfo) expect(coverage != None, "{} doesn't have a `PythonNeededCoverageInfo` provider", dep.label) # Extract modules for this dep. diff --git a/starlark-rust/starlark/src/environment/globals.rs b/starlark-rust/starlark/src/environment/globals.rs index a827436216bbe..5a799bd32a2cc 100644 --- a/starlark-rust/starlark/src/environment/globals.rs +++ b/starlark-rust/starlark/src/environment/globals.rs @@ -306,7 +306,6 @@ impl GlobalsBuilder { &components, as_type.as_ref().map(|x| x.0.dupe()), ) - .unwrap() // TODO(nga): do not unwrap. }), docs: components.into_docs(as_type), special_builtin_function, diff --git a/starlark-rust/starlark/src/environment/methods.rs b/starlark-rust/starlark/src/environment/methods.rs index 2236cf8f5ea46..aee81b4f57869 100644 --- a/starlark-rust/starlark/src/environment/methods.rs +++ b/starlark-rust/starlark/src/environment/methods.rs @@ -213,10 +213,10 @@ impl MethodsBuilder { name: &str, components: NativeCallableComponents, sig: ParametersSpec, + ty: Option, f: NativeMethFn, ) { - // TODO(nga): do not unwrap. - let ty = Ty::from_native_callable_components(&components, None).unwrap(); + let ty = ty.unwrap_or_else(|| Ty::from_native_callable_components(&components, None)); self.members.insert( name, diff --git a/starlark-rust/starlark/src/typing.rs b/starlark-rust/starlark/src/typing.rs index b33f68be3b7dd..958e40467526f 100644 --- a/starlark-rust/starlark/src/typing.rs +++ b/starlark-rust/starlark/src/typing.rs @@ -34,6 +34,7 @@ pub(crate) mod custom; pub(crate) mod error; pub(crate) mod fill_types_for_lint; pub(crate) mod function; +pub(crate) mod index; pub(crate) mod interface; pub(crate) mod mode; pub(crate) mod oracle; @@ -52,10 +53,17 @@ pub mod macro_support; mod tests; pub use basic::TyBasic; +pub use call_args::TyCallArgs; pub use callable::TyCallable; pub use callable_param::ParamIsRequired; pub use callable_param::ParamSpec; +pub use error::TypingError; +pub use error::TypingNoContextOrInternalError; +pub use error::TypingOrInternalError; +pub use function::TyCustomFunctionImpl; pub use function::TyFunction; +pub use index::TyCustomIndex; +pub use index::TyCustomIndexImpl; pub use interface::Interface; pub use oracle::ctx::TypingOracleCtx; pub use oracle::traits::TypingBinOp; diff --git a/starlark-rust/starlark/src/typing/call_args.rs b/starlark-rust/starlark/src/typing/call_args.rs index 6675a45024a5c..41f4dda51375d 100644 --- a/starlark-rust/starlark/src/typing/call_args.rs +++ b/starlark-rust/starlark/src/typing/call_args.rs @@ -21,9 +21,12 @@ use crate::typing::Ty; /// Function call arguments. pub struct TyCallArgs<'a> { - pub(crate) pos: Vec>, - pub(crate) named: Vec>, + /// Positional args + pub pos: Vec>, + /// Named args + pub named: Vec>, /// In starlark, `*args` always come after all positional and named arguments. - pub(crate) args: Option>, - pub(crate) kwargs: Option>, + pub args: Option>, + /// `**kwargs` + pub kwargs: Option>, } diff --git a/starlark-rust/starlark/src/typing/callable.rs b/starlark-rust/starlark/src/typing/callable.rs index 49760d59feb52..b6d3eded79b21 100644 --- a/starlark-rust/starlark/src/typing/callable.rs +++ b/starlark-rust/starlark/src/typing/callable.rs @@ -63,7 +63,8 @@ impl TyCallable { &self.inner.params } - pub(crate) fn result(&self) -> &Ty { + /// The return type of the callable + pub fn result(&self) -> &Ty { &self.inner.result } diff --git a/starlark-rust/starlark/src/typing/callable_param.rs b/starlark-rust/starlark/src/typing/callable_param.rs index bdc92da01bbf6..deac01dd96a7c 100644 --- a/starlark-rust/starlark/src/typing/callable_param.rs +++ b/starlark-rust/starlark/src/typing/callable_param.rs @@ -292,7 +292,7 @@ impl ParamSpec { } /// `*args`. - pub(crate) fn args(ty: Ty) -> ParamSpec { + pub fn args(ty: Ty) -> ParamSpec { ParamSpec::new_parts([], [], Some(ty), [], None).expect("Cannot fail") } @@ -301,8 +301,8 @@ impl ParamSpec { ParamSpec::new_parts([], [], None, [], Some(ty)).expect("Cannot fail") } - /// `arg=, arg=, ..., arg, arg, ..., /`. - pub(crate) fn pos_only( + /// `arg, arg, ..., arg=, arg=, ..., /`. + pub fn pos_only( required: impl IntoIterator, optional: impl IntoIterator, ) -> ParamSpec { diff --git a/starlark-rust/starlark/src/typing/error.rs b/starlark-rust/starlark/src/typing/error.rs index 05926fba430b0..f987c47409b35 100644 --- a/starlark-rust/starlark/src/typing/error.rs +++ b/starlark-rust/starlark/src/typing/error.rs @@ -78,8 +78,9 @@ impl TypingError { // TODO(nga): some errors we create, we ignore later. For example, when typechecking a union, // if either variant is good, we ignore the other variant errors. // So we pay for expensive error creation we ignore. Make this function cheap. + /// New with a message #[cold] - pub(crate) fn msg(message: impl Display, span: Span, codemap: &CodeMap) -> TypingError { + pub fn msg(message: impl Display, span: Span, codemap: &CodeMap) -> TypingError { TypingError(EvalException::new_anyhow( anyhow::Error::msg(message.to_string()), span, @@ -120,7 +121,9 @@ pub struct TypingNoContextError; /// * Typing error means, types are not compatible. /// * Internal error means, bug in the typechecker. pub enum TypingOrInternalError { + /// Types are not compatible Typing(TypingError), + /// Bug in the type checker Internal(InternalError), } @@ -136,8 +139,12 @@ impl From for TypingOrInternalError { } } +/// Either an acontextual typing error (without a message/span), +/// or an internal error with details. pub enum TypingNoContextOrInternalError { + /// Types are not compatible Typing, + /// Bug in the type checker Internal(InternalError), } diff --git a/starlark-rust/starlark/src/typing/function.rs b/starlark-rust/starlark/src/typing/function.rs index 54dea06f6bde8..b9a885795896d 100644 --- a/starlark-rust/starlark/src/typing/function.rs +++ b/starlark-rust/starlark/src/typing/function.rs @@ -36,13 +36,20 @@ use crate::typing::error::TypingOrInternalError; use crate::values::typing::type_compiled::alloc::TypeMatcherAlloc; /// Custom function typechecker. +/// +/// Can be used to implement generics, where e.g. the return type depends on the arguments, +/// or where arguments are checked to be all the same type, etc. pub trait TyCustomFunctionImpl: Debug + Eq + Ord + Hash + Allocative + Send + Sync + 'static { + /// Whether this function is also a type. For example, `list` is a function and also a type. + /// + /// Default is false. fn is_type(&self) -> bool { false } + /// Type-check a function call. Returns the return type of the function. fn validate_call( &self, span: Span, @@ -50,8 +57,11 @@ pub trait TyCustomFunctionImpl: oracle: TypingOracleCtx, ) -> Result; + /// Represent this as a [`TyCallable`], for display purposes. fn as_callable(&self) -> TyCallable; + /// Only for `TyFunction`'s implementation. + #[doc(hidden)] fn as_function(&self) -> Option<&TyFunction> { None } diff --git a/starlark-rust/starlark/src/typing/index.rs b/starlark-rust/starlark/src/typing/index.rs new file mode 100644 index 0000000000000..7d54523c7fee1 --- /dev/null +++ b/starlark-rust/starlark/src/typing/index.rs @@ -0,0 +1,49 @@ +/* + * Copyright 2019 The Starlark in Rust Authors. + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::fmt::Debug; +use std::sync::Arc; + +use allocative::Allocative; +use dupe::Dupe; + +use crate::typing::Ty; +use crate::typing::TyBasic; +use crate::typing::TypingOracleCtx; +use crate::typing::error::TypingNoContextOrInternalError; + +/// Custom index (`[x] -> y`) typechecker. +pub trait TyCustomIndexImpl: Allocative + Debug + Send + Sync + 'static { + /// Type check an index operation. The idea of using this is for the return type to depend on + /// the index item type. + fn index( + &self, + item: &TyBasic, + ctx: &TypingOracleCtx, + ) -> Result; +} + +/// An `Arc` wrapper for [`TyCustomIndexImpl`]. +#[derive(Debug, Allocative, Clone, Dupe)] +pub struct TyCustomIndex(pub(crate) Arc); + +impl TyCustomIndex { + /// Create new + pub fn new(ty: T) -> Self { + Self(Arc::new(ty)) + } +} diff --git a/starlark-rust/starlark/src/typing/oracle/ctx.rs b/starlark-rust/starlark/src/typing/oracle/ctx.rs index 404e26141af28..51dc5f6c6777c 100644 --- a/starlark-rust/starlark/src/typing/oracle/ctx.rs +++ b/starlark-rust/starlark/src/typing/oracle/ctx.rs @@ -87,7 +87,8 @@ pub struct TypingOracleCtx<'a> { } impl<'a> TypingOracleCtx<'a> { - pub(crate) fn mk_error(&self, span: Span, err: impl Into) -> TypingError { + /// Make an error + pub fn mk_error(&self, span: Span, err: impl Into) -> TypingError { TypingError::new_anyhow(err.into(), span, self.codemap) } diff --git a/starlark-rust/starlark/src/typing/ty.rs b/starlark-rust/starlark/src/typing/ty.rs index 751004b3a9012..5fc96413ade34 100644 --- a/starlark-rust/starlark/src/typing/ty.rs +++ b/starlark-rust/starlark/src/typing/ty.rs @@ -148,7 +148,8 @@ impl Ty { } } - pub(crate) const fn basic(basic: TyBasic) -> Self { + /// Create a type that is not a union + pub const fn basic(basic: TyBasic) -> Self { Ty { alternatives: SmallArcVec1::one(basic), } @@ -305,6 +306,21 @@ impl Ty { } } + /// Typecheck through a callable to its return type. + pub fn as_callable_return(&self) -> Option { + self.typecheck_union_simple(|basic| match basic { + TyBasic::Callable(c) => Ok(c.result().dupe()), + TyBasic::Custom(custom) => custom + .0 + .as_callable_dyn() + .map(|callable| callable.result().dupe()) + .ok_or(TypingNoContextError), + TyBasic::Any => Ok(Ty::any()), + _ => Err(TypingNoContextError), + }) + .ok() + } + /// Create a unions type, which will be normalised before being created. pub fn unions(xs: Vec) -> Self { // Handle common cases first. @@ -494,14 +510,14 @@ impl Ty { pub(crate) fn from_native_callable_components( comp: &NativeCallableComponents, as_type: Option, - ) -> starlark::Result { + ) -> Self { let result = comp.return_type.clone(); let params = comp.param_spec.param_spec(); match as_type { - None => Ok(Ty::function(params, result)), - Some(type_attr) => Ok(Ty::ctor_function(type_attr, params, result)), + None => Ty::function(params, result), + Some(type_attr) => Ty::ctor_function(type_attr, params, result), } } diff --git a/starlark-rust/starlark/src/typing/user.rs b/starlark-rust/starlark/src/typing/user.rs index 87e028675682f..021bb38bd7b00 100644 --- a/starlark-rust/starlark/src/typing/user.rs +++ b/starlark-rust/starlark/src/typing/user.rs @@ -26,6 +26,7 @@ use starlark_syntax::codemap::Span; use crate::typing::Ty; use crate::typing::TyBasic; +use crate::typing::TyCustomIndex; use crate::typing::TypingOracleCtx; use crate::typing::call_args::TyCallArgs; use crate::typing::callable::TyCallable; @@ -107,6 +108,8 @@ pub struct TyUserParams { pub fields: TyUserFields, /// Set if more precise callable signature is known than `base` provides. pub callable: Option, + /// Set for a custom typing function for the index + pub index_custom: Option, /// Set if more precise index signature is known than `base` provides. pub index: Option, /// Set if more precise iter item is known than `base` provides. @@ -130,6 +133,7 @@ pub struct TyUser { /// Set if more precise callable signature is known than `base` provides. callable: Option, /// Set if more precise index signature is known than `base` provides. + index_custom: Option, index: Option, /// Set if more precise iter item is known than `base` provides. iter_item: Option, @@ -148,6 +152,7 @@ impl TyUser { matcher, fields, callable, + index_custom, index, iter_item, _non_exhaustive: (), @@ -157,7 +162,7 @@ impl TyUser { name, ))); } - if index.is_some() && !base.is_indexable() { + if (index_custom.is_some() || index.is_some()) && !base.is_indexable() { return Err(crate::Error::new_native( TyUserError::IndexableNotIndexable(name), )); @@ -175,6 +180,7 @@ impl TyUser { id, fields, callable, + index_custom, index, iter_item, }) @@ -234,7 +240,9 @@ impl TyCustomImpl for TyUser { item: &TyBasic, ctx: &TypingOracleCtx, ) -> Result { - if let Some(index) = &self.index { + if let Some(index_custom) = &self.index_custom { + return index_custom.0.index(item, ctx); + } else if let Some(index) = &self.index { if !ctx.intersects(&Ty::basic(item.dupe()), &index.index)? { return Err(TypingNoContextOrInternalError::Typing); } @@ -253,11 +261,10 @@ impl TyCustomImpl for TyUser { } fn as_callable(&self) -> Option { - if self.base.is_callable() { - Some(TyCallable::any()) - } else { - None - } + self.callable + .as_ref() + .filter(|_| self.base.is_callable()) + .cloned() } fn validate_call( @@ -292,6 +299,15 @@ impl TyCustomImpl for TyUser { } self.supertypes.iter().any(|x| x == other) } + + fn bin_op( + &self, + bin_op: super::TypingBinOp, + rhs: &TyBasic, + _ctx: &TypingOracleCtx, + ) -> Result { + Ok(self.base.bin_op(bin_op, rhs)?) + } } #[cfg(test)] diff --git a/starlark-rust/starlark_derive/src/module/parse/fun.rs b/starlark-rust/starlark_derive/src/module/parse/fun.rs index 02dc7a388cbe9..4e071516258e8 100644 --- a/starlark-rust/starlark_derive/src/module/parse/fun.rs +++ b/starlark-rust/starlark_derive/src/module/parse/fun.rs @@ -425,13 +425,6 @@ pub(crate) fn parse_fun(func: ItemFn, module_kind: ModuleKind) -> syn::Result StarFunSource::Arguments, diff --git a/starlark-rust/starlark_derive/src/module/render/fun.rs b/starlark-rust/starlark_derive/src/module/render/fun.rs index 63869d83cd26d..8415d91d30374 100644 --- a/starlark-rust/starlark_derive/src/module/render/fun.rs +++ b/starlark-rust/starlark_derive/src/module/render/fun.rs @@ -160,18 +160,14 @@ impl StarFun { "methods cannot have an `as_type` attribute", )); } - if self.starlark_ty_custom_function.is_some() { - return Err(syn::Error::new( - self.span(), - "methods cannot have a `ty_custom_function` attribute", - )); - } + let ty_custom = self.ty_custom_expr(); Ok(syn::parse_quote! { #[allow(clippy::redundant_closure)] globals_builder.set_method( #name_str, #components, #param_spec, + #ty_custom, __starlark_invoke_outer #turbofish, ); })