diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs index 2f467db1d5ec1..966c5a3e6d724 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/callable.rs @@ -157,6 +157,7 @@ pub(crate) struct UserProviderCallableData { /// Type id of provider callable instance. pub(crate) ty_provider_type_instance_id: TypeInstanceId, pub(crate) fields: IndexMap, + pub(crate) ty_provider: Ty, } /// Initialized after the name is assigned to the provider. @@ -387,6 +388,7 @@ impl<'v> StarlarkValue<'v> for UserProviderCallable { provider_id, fields: self.fields.clone(), ty_provider_type_instance_id, + ty_provider: ty_provider.dupe(), }), ty_provider, ty_callable, diff --git a/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs b/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs index 487c3bd5745aa..b18935bacb857 100644 --- a/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs +++ b/app/buck2_build_api/src/interpreter/rule_defs/provider/user.rs @@ -148,6 +148,10 @@ where fn provide(&'v self, demand: &mut Demand<'_, 'v>) { demand.provide_value::<&dyn ProviderLike>(self); } + + fn typechecker_ty(&self) -> Option { + Some(self.callable.ty_provider.dupe()) + } } impl<'v, V: ValueLike<'v>> serde::Serialize for UserProviderGen<'v, V> { diff --git a/app/buck2_interpreter_for_build_tests/src/attrs/attrs_global.rs b/app/buck2_interpreter_for_build_tests/src/attrs/attrs_global.rs index ec44eb4309d77..052816e030a4f 100644 --- a/app/buck2_interpreter_for_build_tests/src/attrs/attrs_global.rs +++ b/app/buck2_interpreter_for_build_tests/src/attrs/attrs_global.rs @@ -14,10 +14,6 @@ use buck2_interpreter_for_build::interpreter::testing::Tester; fn test_attr_display() -> buck2_error::Result<()> { let mut tester = Tester::new().unwrap(); tester.run_starlark_bzl_test(r#" -def assert_eq(a, b): - if a != b: - fail(a + " != " + b) - assert_eq(repr(attrs.bool(default = True)), "attrs.bool(default=True)") assert_eq(repr(attrs.string()), "attrs.string()") assert_eq(repr(attrs.list(attrs.string())), "attrs.list(attrs.string())") diff --git a/prelude/android/android_binary_native_library_rules.bzl b/prelude/android/android_binary_native_library_rules.bzl index e00be8b191607..602de49f4e684 100644 --- a/prelude/android/android_binary_native_library_rules.bzl +++ b/prelude/android/android_binary_native_library_rules.bzl @@ -232,13 +232,13 @@ def get_android_binary_native_library_info( if native_library_merge_non_asset_libs: mergemap_cmd.add(cmd_args("--merge-non-asset-libs")) native_library_merge_dir = ctx.actions.declare_output("merge_sequence_output") - native_library_merge_map = native_library_merge_dir.project("merge.map") + native_library_merge_map_file = native_library_merge_dir.project("merge.map") split_groups_map = native_library_merge_dir.project("split_groups.map") mergemap_cmd.add(cmd_args(native_library_merge_dir.as_output(), format = "--output={}")) ctx.actions.run(mergemap_cmd, category = "compute_mergemap", allow_cache_upload = True) enhance_ctx.debug_output("compute_merge_sequence", native_library_merge_dir) - dynamic_inputs.append(native_library_merge_map) + dynamic_inputs.append(native_library_merge_map_file) dynamic_inputs.append(split_groups_map) mergemap_gencode_jar = None @@ -268,7 +268,7 @@ def get_android_binary_native_library_info( # When changing this dynamic_output, the workflow is a lot better if you compute the module graph once and # then set it as the binary's precomputed_apk_module_graph attr. if native_library_merge_sequence: - merge_map_by_platform = artifacts[native_library_merge_map].read_json() + merge_map_by_platform = artifacts[native_library_merge_map_file].read_json() split_groups = artifacts[split_groups_map].read_json() native_library_merge_debug_outputs["merge_sequence_output"] = native_library_merge_dir elif native_library_merge_map: diff --git a/prelude/android/android_manifest.bzl b/prelude/android/android_manifest.bzl index 075ff39f55182..c84312d124b8b 100644 --- a/prelude/android/android_manifest.bzl +++ b/prelude/android/android_manifest.bzl @@ -47,11 +47,13 @@ def generate_android_manifest( ]) if not manifests: - manifests = [] + manifests_args = [] elif isinstance(manifests, TransitiveSet): - manifests = manifests.project_as_args("artifacts", ordering = "bfs") + manifests_args = manifests.project_as_args("artifacts", ordering = "bfs") + else: + manifests_args = manifests - library_manifest_paths_file = argfile(actions = ctx.actions, name = "{}/library_manifest_paths_file".format(module_name), args = manifests) + library_manifest_paths_file = argfile(actions = ctx.actions, name = "{}/library_manifest_paths_file".format(module_name), args = manifests_args) generate_manifest_cmd.add(["--library-manifests-list", library_manifest_paths_file]) diff --git a/prelude/android/android_providers.bzl b/prelude/android/android_providers.bzl index 1daf9eb6f9541..0b55196b82151 100644 --- a/prelude/android/android_providers.bzl +++ b/prelude/android/android_providers.bzl @@ -296,7 +296,7 @@ def merge_android_packageable_info( AndroidBuildConfigInfoTSet, ) - deps = _get_transitive_set( + deps_tset = _get_transitive_set( actions, filter(None, [dep.deps for dep in android_packageable_deps]), DepsInfo( @@ -334,7 +334,7 @@ def merge_android_packageable_info( return AndroidPackageableInfo( target_label = label.raw_target(), build_config_infos = build_config_infos, - deps = deps, + deps = deps_tset, manifests = manifests, prebuilt_native_library_dirs = prebuilt_native_library_dirs, resource_infos = resource_infos, diff --git a/prelude/apple/apple_library.bzl b/prelude/apple/apple_library.bzl index d41824f457cf1..61257fbc3aa49 100644 --- a/prelude/apple/apple_library.bzl +++ b/prelude/apple/apple_library.bzl @@ -630,13 +630,9 @@ def _get_link_style_sub_targets_and_providers( ) if get_apple_stripped_attr_value_with_default_fallback(ctx): - if False: - # TODO(nga): `output.unstripped` is never `None`. - def unknown(): - pass - - output = unknown() - expect(output.unstripped != None, "Expecting unstripped output to be non-null when stripping is enabled.") + # TODO(nga): `output.unstripped` is never `None`. + unstripped: None | typing.Any = output.unstripped + expect(unstripped != None, "Expecting unstripped output to be non-null when stripping is enabled.") dsym_executable = output.unstripped else: dsym_executable = output.default diff --git a/prelude/apple/swift/swift_compilation.bzl b/prelude/apple/swift/swift_compilation.bzl index 36c22d1ae1c84..01e3ce9a7621f 100644 --- a/prelude/apple/swift/swift_compilation.bzl +++ b/prelude/apple/swift/swift_compilation.bzl @@ -876,7 +876,7 @@ def _compile_with_argsfile( allow_cache_upload, action_execution_attributes = _get_action_properties(ctx, toolchain, cacheable, build_swift_incrementally, explicit_modules_enabled) - argsfile, output_file_map = compile_with_argsfile( + argsfile, output_file_map_artifact = compile_with_argsfile( ctx = ctx, category = category, shared_flags = shared_flags, @@ -903,9 +903,9 @@ def _compile_with_argsfile( if extension: # Swift correctly handles relative paths and we can utilize the relative argsfile for Xcode. - return CompileArgsfiles(relative = {extension: argsfile}, xcode = {extension: argsfile}), output_file_map + return CompileArgsfiles(relative = {extension: argsfile}, xcode = {extension: argsfile}), output_file_map_artifact else: - return None, output_file_map + return None, output_file_map_artifact def _get_action_properties( ctx: AnalysisContext, @@ -1154,10 +1154,10 @@ def _add_swift_module_map_args( cmd: cmd_args, is_macro: bool): module_name = get_module_name(ctx) - sdk_swiftmodule_deps_tset = [sdk_swiftmodule_deps_tset] if sdk_swiftmodule_deps_tset else [] + sdk_swiftmodule_deps_tsets = [sdk_swiftmodule_deps_tset] if sdk_swiftmodule_deps_tset else [] all_deps_tset = ctx.actions.tset( SwiftCompiledModuleTset, - children = _get_swift_paths_tsets(is_macro, ctx.attrs.deps + getattr(ctx.attrs, "exported_deps", [])) + [pcm_deps_tset, sdk_deps_tset] + sdk_swiftmodule_deps_tset, + children = _get_swift_paths_tsets(is_macro, ctx.attrs.deps + getattr(ctx.attrs, "exported_deps", [])) + [pcm_deps_tset, sdk_deps_tset] + sdk_swiftmodule_deps_tsets, ) swift_module_map_artifact = write_swift_module_map_with_deps( ctx, diff --git a/prelude/artifact_tset.bzl b/prelude/artifact_tset.bzl index 1281c766ecac5..dacbe045b5819 100644 --- a/prelude/artifact_tset.bzl +++ b/prelude/artifact_tset.bzl @@ -60,7 +60,7 @@ def make_artifact_tset( ) # As a convenience for our callers, filter our `None` children. - children = [c._tset for c in children if c._tset != None] + children_ = [c._tset for c in children if c._tset != None] # Build list of all non-child values. values = [] @@ -69,15 +69,15 @@ def make_artifact_tset( values.extend(infos) # If there's no children or artifacts, return `None`. - if not values and not children: + if not values and not children_: return ArtifactTSet() # We only build a `_ArtifactTSet` if there's something to package. kwargs = {} if values: kwargs["value"] = values - if children: - kwargs["children"] = children + if children_: + kwargs["children"] = children_ return ArtifactTSet( _tset = actions.tset(_ArtifactTSet, **kwargs), ) diff --git a/prelude/cfg/modifier/cfg_constructor.bzl b/prelude/cfg/modifier/cfg_constructor.bzl index b99494c727c3f..1ad1266341170 100644 --- a/prelude/cfg/modifier/cfg_constructor.bzl +++ b/prelude/cfg/modifier/cfg_constructor.bzl @@ -67,15 +67,15 @@ def cfg_constructor_pre_constraint_analysis( Returns `(refs, PostConstraintAnalysisParams)`, where `refs` is a list of fully qualified configuration targets we need providers for. """ - package_modifiers = package_modifiers or [] + package_modifiers_1 = package_modifiers or [] target_modifiers = target_modifiers or [] # Convert JSONs back to TaggedModifiers - package_modifiers = [json_to_tagged_modifiers(modifier_json) for modifier_json in package_modifiers] + tagged_package_modifiers: list[TaggedModifiers] = [json_to_tagged_modifiers(modifier_json) for modifier_json in package_modifiers_1] # Filter PACKAGE modifiers based on rule name. # This only filters out PACKAGE modifiers from `extra_cfg_modifiers_per_rule` argument of `set_cfg_modifiers` function. - package_modifiers = [tagged_modifiers for tagged_modifiers in package_modifiers if tagged_modifiers.rule_name == None or tagged_modifiers.rule_name == rule_name] + tagged_package_modifiers = [tagged_modifiers for tagged_modifiers in tagged_package_modifiers if tagged_modifiers.rule_name == None or tagged_modifiers.rule_name == rule_name] # Resolve all aliases in CLI modifiers cli_modifiers = [resolved_modifier for modifier in cli_modifiers for resolved_modifier in resolve_alias(modifier, aliases)] @@ -85,7 +85,7 @@ def cfg_constructor_pre_constraint_analysis( if buckconfig_backed_modifiers: refs.append(buckconfig_backed_modifiers) - for tagged_modifiers in package_modifiers: + for tagged_modifiers in tagged_package_modifiers: for modifier in tagged_modifiers.modifiers: refs.extend(modifier_to_refs(modifier, tagged_modifiers.location)) for modifier in target_modifiers: @@ -95,7 +95,7 @@ def cfg_constructor_pre_constraint_analysis( return refs, PostConstraintAnalysisParams( legacy_platform = legacy_platform, - package_modifiers = package_modifiers, + package_modifiers = tagged_package_modifiers, target_modifiers = target_modifiers, cli_modifiers = cli_modifiers, extra_data = extra_data, diff --git a/prelude/cxx/cxx_library.bzl b/prelude/cxx/cxx_library.bzl index 8331687e4586e..29e9bf382e8fc 100644 --- a/prelude/cxx/cxx_library.bzl +++ b/prelude/cxx/cxx_library.bzl @@ -1434,7 +1434,7 @@ def _form_library_outputs( outputs = {} solibs = {} link_infos = {} - providers = [] + providers: list[Provider] = [] sanitizer_runtime_files = [] gcno_files = [] diff --git a/prelude/cxx/groups.bzl b/prelude/cxx/groups.bzl index 8166515abae12..45960c6ca861d 100644 --- a/prelude/cxx/groups.bzl +++ b/prelude/cxx/groups.bzl @@ -183,11 +183,11 @@ def _parse_filter(entry: str) -> GroupFilterInfo: if target_regex != None: regex_expr = regex("^{}$".format(target_regex), fancy = True) - def matches_regex(_r, t, _labels): + def matches_regex2(_r, t, _labels): return regex_expr.match(str(t.raw_target())) return GroupFilterInfo( - matches = matches_regex, + matches = matches_regex2, info = {"target_regex": str(regex_expr)}, ) diff --git a/prelude/cxx/link.bzl b/prelude/cxx/link.bzl index 5df6aa230749b..d8a7cd5236921 100644 --- a/prelude/cxx/link.bzl +++ b/prelude/cxx/link.bzl @@ -531,7 +531,10 @@ def _anon_cxx_link( if generates_split_debug(cxx_toolchain): split_debug_output = anon_link_target.artifact("split_debug_output") - output = ctx.actions.assert_short_path(anon_link_target.artifact("output"), short_path = output) + output_artifact = ctx.actions.assert_short_path( + anon_link_target.artifact("output"), + short_path = output, + ) external_debug_info = link_external_debug_info( ctx = ctx, @@ -541,12 +544,12 @@ def _anon_cxx_link( # The anon target API doesn't allow us to return the list of artifacts for # sanitizer runtime, so it has be computed here - sanitizer_runtime_args = cxx_sanitizer_runtime_arguments(ctx, cxx_toolchain, output) + sanitizer_runtime_args = cxx_sanitizer_runtime_arguments(ctx, cxx_toolchain, output_artifact) return CxxLinkResult( linked_object = LinkedObject( - output = output, - unstripped_output = output, + output = output_artifact, + unstripped_output = output_artifact, dwp = dwp, external_debug_info = external_debug_info, ), diff --git a/prelude/cxx/shared_library_interface.bzl b/prelude/cxx/shared_library_interface.bzl index dbdcce486a75d..025e5ab4b3a32 100644 --- a/prelude/cxx/shared_library_interface.bzl +++ b/prelude/cxx/shared_library_interface.bzl @@ -23,14 +23,14 @@ def _shared_library_interface( linker_info = get_cxx_toolchain_info(ctx).linker_info args = cmd_args(linker_info.mk_shlib_intf[RunInfo]) args.add(shared_lib) - output = ctx.actions.declare_output(output) - args.add(output.as_output()) + output_artifact = ctx.actions.declare_output(output) + args.add(output_artifact.as_output()) ctx.actions.run( args, category = "generate_shared_library_interface", identifier = identifier, ) - return output + return output_artifact _InterfaceInfo = provider(fields = { "artifact": provider_field(typing.Any, default = None), # "artifact" diff --git a/prelude/erlang/erlang_application.bzl b/prelude/erlang/erlang_application.bzl index 3d87c57200fbc..2d2260de98e0b 100644 --- a/prelude/erlang/erlang_application.bzl +++ b/prelude/erlang/erlang_application.bzl @@ -317,8 +317,8 @@ def _link_src_dir(ctx: AnalysisContext, *, extra_srcs: list[Artifact]) -> Artifa if ctx.attrs.app_src: srcs[ctx.attrs.app_src.basename] = ctx.attrs.app_src - for extra_srcs in extra_srcs: - srcs[extra_srcs.basename] = extra_srcs + for extra_src in extra_srcs: + srcs[extra_src.basename] = extra_src return ctx.actions.symlinked_dir(paths.join(erlang_build.utils.BUILD_DIR, "src"), srcs) diff --git a/prelude/http_archive/unarchive.bzl b/prelude/http_archive/unarchive.bzl index 64668a14f4e5e..4425604e191a9 100644 --- a/prelude/http_archive/unarchive.bzl +++ b/prelude/http_archive/unarchive.bzl @@ -215,17 +215,18 @@ def unarchive( if needs_strip_prefix: ctx.actions.copy_dir(output.as_output(), script_output.project(strip_prefix), has_content_based_path = has_content_based_path) + sub_targets_dict: dict[str, list[Provider]] = {} if type(sub_targets) == type([]): - sub_targets = { + sub_targets_dict = { path: [DefaultInfo(default_output = output.project(path))] for path in sub_targets } elif type(sub_targets) == type({}): - sub_targets = { + sub_targets_dict = { name: [DefaultInfo(default_outputs = [output.project(path) for path in paths])] for name, paths in sub_targets.items() } else: fail("sub_targets must be a list or dict") - return output, sub_targets + return output, sub_targets_dict diff --git a/prelude/jvm/cd_jar_creator_util.bzl b/prelude/jvm/cd_jar_creator_util.bzl index 0c35f1b6cdcaf..75c16d421fb50 100644 --- a/prelude/jvm/cd_jar_creator_util.bzl +++ b/prelude/jvm/cd_jar_creator_util.bzl @@ -119,7 +119,7 @@ def define_output_paths(actions: AnalysisActions, prefix: [str, None], label: La ) def encode_output_paths(label: Label, paths: OutputPaths, target_type: TargetType) -> struct: - paths = struct( + paths_value = struct( classesDir = paths.classes.as_output(), outputJarDirPath = cmd_args(paths.jar.as_output(), parent = 1), annotationPath = paths.annotations.as_output(), @@ -127,9 +127,9 @@ def encode_output_paths(label: Label, paths: OutputPaths, target_type: TargetTyp ) return struct( - libraryPaths = paths if target_type == TargetType("library") else None, - sourceAbiPaths = paths if target_type == TargetType("source_abi") else None, - sourceOnlyAbiPaths = paths if target_type == TargetType("source_only_abi") else None, + libraryPaths = paths_value if target_type == TargetType("library") else None, + sourceAbiPaths = paths_value if target_type == TargetType("source_abi") else None, + sourceOnlyAbiPaths = paths_value if target_type == TargetType("source_only_abi") else None, libraryTargetFullyQualifiedName = base_qualified_name(label), ) diff --git a/prelude/linking/link_info.bzl b/prelude/linking/link_info.bzl index db1bcc887f44b..1076a40c56f34 100644 --- a/prelude/linking/link_info.bzl +++ b/prelude/linking/link_info.bzl @@ -1039,9 +1039,9 @@ LinkCommandDebugOutput = record( # NB: Debug output is _not_ transitive over deps, so tsets are not used here. LinkCommandDebugOutputInfo = provider( - fields = [ - "debug_outputs", # ["LinkCommandDebugOutput"] - ], + fields = { + "debug_outputs": provider_field(list[LinkCommandDebugOutput]), + }, ) UnstrippedLinkOutputInfo = provider(fields = { diff --git a/prelude/linking/linkable_graph.bzl b/prelude/linking/linkable_graph.bzl index c0cd640e1add8..acd90b5922be3 100644 --- a/prelude/linking/linkable_graph.bzl +++ b/prelude/linking/linkable_graph.bzl @@ -190,15 +190,15 @@ def create_linkable_node( ) if not linker_flags: linker_flags = LinkerFlags() - deps = linkable_deps(deps) - exported_deps = linkable_deps(exported_deps) + ldeps = linkable_deps(deps) + lexported_deps = linkable_deps(exported_deps) return LinkableNode( labels = ctx.attrs.labels, preferred_linkage = preferred_linkage, default_link_strategy = default_link_strategy, - deps = deps, - exported_deps = exported_deps, - all_deps = deps + exported_deps, + deps = ldeps, + exported_deps = lexported_deps, + all_deps = ldeps + lexported_deps, link_infos = link_infos, shared_libs = shared_libs, can_be_asset = can_be_asset, diff --git a/prelude/rust/build.bzl b/prelude/rust/build.bzl index 42812b54e2e01..43801aa3a0fe7 100644 --- a/prelude/rust/build.bzl +++ b/prelude/rust/build.bzl @@ -931,7 +931,7 @@ def _abbreviated_subdir( infallible_diagnostics: bool, incremental_enabled: bool, profile_mode: ProfileMode | None) -> str: - crate_type = { + crate_type_str = { CrateType("bin"): "X", # mnemonic: "eXecutable" CrateType("rlib"): "L", # "Library" CrateType("dylib"): "D", @@ -940,7 +940,7 @@ def _abbreviated_subdir( CrateType("staticlib"): "S", }[crate_type] - reloc_model = { + reloc_model_str = { RelocModel("static"): "S", RelocModel("pic"): "P", RelocModel("pie"): "I", @@ -951,13 +951,13 @@ def _abbreviated_subdir( RelocModel("default"): "D", }[reloc_model] - dep_link_strategy = { + dep_link_strategy_str = { LinkStrategy("static"): "T", LinkStrategy("static_pic"): "P", LinkStrategy("shared"): "H", }[dep_link_strategy] - emit = { + emit_str = { Emit("asm"): "s", Emit("llvm-bc"): "b", Emit("llvm-ir"): "i", @@ -972,17 +972,17 @@ def _abbreviated_subdir( Emit("metadata-fast"): "M", # "Metadata" }[emit] - profile_mode = { + profile_mode_str = { None: "", ProfileMode("llvm-time-trace"): "L", ProfileMode("self-profile"): "P", }[profile_mode] - return crate_type + reloc_model + dep_link_strategy + emit + \ + return crate_type_str + reloc_model_str + dep_link_strategy_str + emit_str + \ ("T" if is_rustdoc_test else "") + \ ("D" if infallible_diagnostics else "") + \ ("I" if incremental_enabled else "") + \ - profile_mode + profile_mode_str # Compute which are common to both rustc and rustdoc def _compute_common_args( diff --git a/prelude/rust/extern.bzl b/prelude/rust/extern.bzl index 15f1dfa41739f..855323e2e0292 100644 --- a/prelude/rust/extern.bzl +++ b/prelude/rust/extern.bzl @@ -26,13 +26,13 @@ def crate_name_as_cmd_arg(crate: CrateName) -> cmd_args | str | ResolvedStringWi # def extern_arg(flags: list[str], crate: CrateName, lib: Artifact) -> cmd_args: if flags == []: - flags = "" + flags_ = "" else: - flags = ",".join(flags) + ":" + flags_ = ",".join(flags) + ":" return cmd_args( "--extern=", - flags, + flags_, crate_name_as_cmd_arg(crate), "=", lib, diff --git a/prelude/sh_binary.bzl b/prelude/sh_binary.bzl index 4e72a16e3c900..e53e025e64470 100644 --- a/prelude/sh_binary.bzl +++ b/prelude/sh_binary.bzl @@ -35,15 +35,15 @@ def _generate_script( main_link = main_path if main_path.endswith(".bat") or main_path.endswith(".cmd") else main_path + ".bat" else: main_link = main_path if main_path.endswith(".sh") else main_path + ".sh" - resources = {_derive_link(src): src for src in resources} - resources[main_link] = main + resources_dict = {_derive_link(src): src for src in resources} + resources_dict[main_link] = main # windows isn't stable with resources passed in as symbolic links for # remote execution. Allow using copies instead. if copy_resources: - resources_dir = actions.copied_dir("resources", resources, has_content_based_path = has_content_based_path) + resources_dir = actions.copied_dir("resources", resources_dict, has_content_based_path = has_content_based_path) else: - resources_dir = actions.symlinked_dir("resources", resources, has_content_based_path = has_content_based_path) + resources_dir = actions.symlinked_dir("resources", resources_dict, has_content_based_path = has_content_based_path) script_name = name + (".bat" if is_windows else "") script = actions.declare_output(script_name, has_content_based_path = has_content_based_path) diff --git a/prelude/tests/test_toolchain.bzl b/prelude/tests/test_toolchain.bzl index 0381b4d747ee7..4d0abb68a3b70 100644 --- a/prelude/tests/test_toolchain.bzl +++ b/prelude/tests/test_toolchain.bzl @@ -25,5 +25,5 @@ noop_test_toolchain = rule( def test_toolchain_labels( test_toolchain: Dependency) -> list[str]: asserts.true(TestToolchainInfo in test_toolchain, "Expected a TestToolchainInfo provider") - test_toolchain = test_toolchain[TestToolchainInfo] - return [test_toolchain.sanitizer] if test_toolchain.sanitizer else [] + test_toolchain_info = test_toolchain[TestToolchainInfo] + return [test_toolchain_info.sanitizer] if test_toolchain_info.sanitizer else [] diff --git a/prelude/third-party/build.bzl b/prelude/third-party/build.bzl index 915c4eecf3c9b..f02ff3374cf6c 100644 --- a/prelude/third-party/build.bzl +++ b/prelude/third-party/build.bzl @@ -88,12 +88,12 @@ def create_third_party_build_root( ) cmd.add(cmd_args(argsfile, format = "@{}", hidden = [s.lib.output for s in shared_libs])) - out = ctx.actions.declare_output(out, dir = True) - cmd.add(out.as_output()) + out_dir = ctx.actions.declare_output(out, dir = True) + cmd.add(out_dir.as_output()) ctx.actions.run(cmd, category = "third_party_build_root") - return artifact_ext(out) + return artifact_ext(out_dir) def create_third_party_build_info( ctx: AnalysisContext, diff --git a/starlark-rust/starlark/src/eval/compiler/module.rs b/starlark-rust/starlark/src/eval/compiler/module.rs index dce7c13a42c41..35c83f3f16c85 100644 --- a/starlark-rust/starlark/src/eval/compiler/module.rs +++ b/starlark-rust/starlark/src/eval/compiler/module.rs @@ -17,6 +17,8 @@ //! Compile and evaluate module top-level statements. +use std::collections::HashMap; + use starlark_syntax::eval_exception::EvalException; use starlark_syntax::syntax::ast::LoadP; use starlark_syntax::syntax::ast::StmtP; @@ -36,11 +38,9 @@ use crate::eval::runtime::frame_span::FrameSpan; use crate::eval::runtime::frozen_file_span::FrozenFileSpan; use crate::typing::Ty; use crate::typing::TypingOracleCtx; -use crate::typing::bindings::BindingsCollect; -use crate::typing::error::InternalError; use crate::typing::fill_types_for_lint::ModuleVarTypes; use crate::typing::mode::TypecheckMode; -use crate::typing::typecheck::solve_bindings; +use crate::typing::typecheck::TypeChecker; use crate::values::FrozenRef; use crate::values::FrozenStringValue; use crate::values::Value; @@ -162,12 +162,12 @@ impl<'v> Compiler<'v, '_, '_, '_> { } } - self.typecheck(&mut stmts)?; + self.typecheck(stmt)?; Ok(last) } - fn typecheck(&mut self, stmts: &mut [&mut CstStmt]) -> Result<(), EvalException> { + fn typecheck(&mut self, stmt: &CstStmt) -> Result<(), EvalException> { let typecheck = self.eval.static_typechecking || self.typecheck; if !typecheck { return Ok(()); @@ -177,27 +177,19 @@ impl<'v> Compiler<'v, '_, '_, '_> { codemap: &self.codemap, }; let module_var_types = self.mk_module_var_types(); - for top in stmts.iter_mut() { - if let StmtP::Def(_) = &mut top.node { - let BindingsCollect { bindings, .. } = BindingsCollect::collect_one( - top, - TypecheckMode::Compiler, - &self.codemap, - &mut Vec::new(), - ) - .map_err(InternalError::into_eval_exception)?; - let (errors, ..) = match solve_bindings(bindings, oracle, &module_var_types) { - Ok(x) => x, - Err(e) => return Err(e.into_eval_exception()), - }; - - if let Some(error) = errors.into_iter().next() { - return Err(error.into_eval_exception()); - } - } - } - - Ok(()) + let mut approximations = Vec::new(); + let mut checker = TypeChecker { + oracle, + typecheck_mode: TypecheckMode::Compiler, + module_var_types: &module_var_types, + approximations: &mut approximations, + all_solved_types: HashMap::new(), + }; + // Just immediately return on any type error + let mut error_handler = Err; + checker + .check_module_scope(stmt, &mut error_handler) + .map_err(|e| e.into_eval_exception()) } fn mk_module_var_types(&self) -> ModuleVarTypes { diff --git a/starlark-rust/starlark/src/tests/derive/attrs.rs b/starlark-rust/starlark/src/tests/derive/attrs.rs index 84b00257f5f98..6069f074b12fb 100644 --- a/starlark-rust/starlark/src/tests/derive/attrs.rs +++ b/starlark-rust/starlark/src/tests/derive/attrs.rs @@ -25,6 +25,7 @@ use crate as starlark; use crate::any::ProvidesStaticType; use crate::assert::Assert; use crate::starlark_simple_value; +use crate::typing::Ty; use crate::values::StarlarkAttrs; use crate::values::StarlarkValue; @@ -53,6 +54,11 @@ fn test_derive_attrs() { #[starlark_value(type = "example")] impl<'v> StarlarkValue<'v> for Example { starlark_attrs!(); + + // Ideally we could derive a TyUser. + fn typechecker_ty(&self) -> Option { + Some(Ty::any()) + } } #[derive( diff --git a/starlark-rust/starlark/src/tests/type_annot.rs b/starlark-rust/starlark/src/tests/type_annot.rs index a9d2bc0a52b54..cb028edc0830a 100644 --- a/starlark-rust/starlark/src/tests/type_annot.rs +++ b/starlark-rust/starlark/src/tests/type_annot.rs @@ -55,13 +55,6 @@ fn test_types_enable() { #[test] fn test_type_assign_annotation() { - assert::pass( - r#" -x : str = "test" -xs: typing.Any = [1,2] -xs[0] : int = 4 -"#, - ); assert::fail( "a, b : typing.Any = 1, 2", "not allowed on multiple assignments", @@ -70,9 +63,117 @@ xs[0] : int = 4 "a = 1\na : typing.Any += 1", "not allowed on augmented assignments", ); + assert::fail( + r#" +def check(): + x = struct(y = 5) + x.y: int = 10 +"#, + "not allowed on complex assignments", + ); + assert::fail( + r#" +def check(): + x = list() + x[0]: int = 10 +"#, + "not allowed on complex assignments", + ); assert::fail("a : str = noop(1)", "does not match the type annotation"); } +#[test] +fn test_type_annotation_is_definitive() { + assert::fail( + r#" +def mk_any() -> typing.Any: + return 5 +def float_only(a: float): + pass +def check(): + x: str = mk_any() + float_only(x) +"#, + "Expected type `float` but got `str`", + ); +} + +#[test] +fn test_type_annotation_is_checked() { + assert::fail( + r#" +def check(): + x: float = "hello" +"#, + "Expected type `float` but got `str`", + ); + assert::fail( + r#" +def check(): + x: float = 5.5 + x = "hello" +"#, + "Expected type `float` but got `str`", + ); + assert::fail( + r#" +def check(): + x: float = 5.5 + x, _ = ("", "") +"#, + "Expected type `float` but got `str`", + ); +} + +#[test] +fn test_fail_double_type_declaration() { + assert::fail( + r#" +def check(): + x: str = "hello" + x: int = 5 +"#, + "Second declaration of binding `x`", + ); +} + +#[test] +fn test_fail_second_type_declaration_only() { + assert::fail( + r#" +def check(): + x = "hello" + x: int = 5 +"#, + "Second declaration of binding `x`", + ); +} + +#[test] +fn test_checks_assignment() { + assert::fail( + r#" +def check(): + x = "hello" + y: int = x +"#, + "Expected type `int` but got `str`", + ); +} + +#[test] +fn test_checks_second_assignment() { + assert::fail( + r#" +def check(): + y: int = 4 + x = "hello" + y = x +"#, + "Expected type `int` but got `str`", + ); +} + #[test] fn test_only_globals_or_bultins_allowed() { assert::fail( @@ -132,3 +233,72 @@ def foo(x: T): pass "String literals are not allowed in type expressions", ); } + +#[test] +fn test_double_def_toplevel() { + assert::fail( + r#" +def foo(): pass +def foo(): pass + "#, + "Second declaration of binding `foo`", + ); +} + +#[test] +fn test_double_def_different_signatures_toplevel() { + assert::fail( + r#" +def foo(): pass +def foo(x: int): pass + "#, + "Second declaration of binding `foo`", + ); +} + +#[test] +fn test_assign_to_def_toplevel() { + assert::fails( + r#" +def foo(): pass +foo = 5 + "#, + &["Expected type", "but got `int`"], + ); +} + +#[test] +fn test_double_def() { + assert::fail( + r#" +def check(): + def foo(): pass + def foo(): pass +"#, + "Second declaration of binding `foo`", + ); +} + +#[test] +fn test_double_def_different_signatures() { + assert::fail( + r#" +def check(): + def foo(): pass + def foo(x: int): pass +"#, + "Second declaration of binding `foo`", + ); +} + +#[test] +fn test_assign_to_def() { + assert::fails( + r#" +def check(): + def foo(): pass + foo = 5 + "#, + &["Expected type", "but got `int`"], + ); +} diff --git a/starlark-rust/starlark/src/typing/bindings.rs b/starlark-rust/starlark/src/typing/bindings.rs index e7ff9b2f67b8b..32d21a6d9f3c0 100644 --- a/starlark-rust/starlark/src/typing/bindings.rs +++ b/starlark-rust/starlark/src/typing/bindings.rs @@ -39,6 +39,7 @@ use crate::codemap::Span; use crate::codemap::Spanned; use crate::eval::compiler::scope::BindingId; use crate::eval::compiler::scope::ResolvedIdent; +use crate::eval::compiler::scope::payload::CstAssignIdent; use crate::eval::compiler::scope::payload::CstAssignIdentExt; use crate::eval::compiler::scope::payload::CstAssignTarget; use crate::eval::compiler::scope::payload::CstExpr; @@ -50,6 +51,7 @@ use crate::typing::TyBasic; use crate::typing::arc_ty::ArcTy; use crate::typing::callable_param::ParamIsRequired; use crate::typing::error::InternalError; +use crate::typing::error::TypingError; use crate::typing::mode::TypecheckMode; use crate::typing::tuple::TyTuple; use crate::typing::ty::Approximation; @@ -97,31 +99,55 @@ pub(crate) struct Bindings<'a> { /// if expr: ... /// ``` pub(crate) check: Vec<&'a CstExpr>, - pub(crate) check_type: Vec<(Span, Option<&'a CstExpr>, Ty)>, + /// Bind expressions that need to be checked. + /// + /// While collecting, this is just `return` (None bindexpr), + /// `return bindexpr`, and `x: ty = bindexpr`. + /// + /// But once we know which bindings are annotated, we can + /// insert a check_type for any expression bound to an annotated binding. + pub(crate) check_type: Vec<(Span, Option>, Ty)>, + + /// Includes double-binding errors, like `x: str = ...; x: str = ...`. + /// These break compiler typechecker, but LSP continues past them. + pub(crate) errors: Vec, +} + +/// Basically a child `def`. +pub(crate) struct ChildDef<'a> { + pub(crate) body: &'a CstStmt, + pub(crate) return_type: Ty, + pub(crate) param_types: HashMap, } pub(crate) struct BindingsCollect<'a, 'b> { pub(crate) bindings: Bindings<'a>, pub(crate) approximations: &'b mut Vec, + pub(crate) children_sink: &'b mut Vec>, } impl<'a, 'b> BindingsCollect<'a, 'b> { - /// Collect all the assignments to variables. + /// Collect all the assignments to variables in a scope. /// /// This function only fails on internal errors. - pub(crate) fn collect_one( - x: &'a mut CstStmt, + pub(crate) fn collect_scope( + scope: &'a CstStmt, + return_type: &Ty, + visible: &HashMap, typecheck_mode: TypecheckMode, codemap: &CodeMap, approximations: &'b mut Vec, - ) -> Result { + children_sink: &'b mut Vec>, + ) -> Result, InternalError> { let mut res = BindingsCollect { bindings: Bindings::default(), approximations, + children_sink, }; + res.bindings.types = visible.clone(); - res.visit(Visit::Stmt(x), &Ty::any(), typecheck_mode, codemap)?; - Ok(res) + res.visit(Visit::Stmt(scope), return_type, typecheck_mode, codemap)?; + Ok(res.bindings) } fn assign( @@ -206,6 +232,43 @@ impl<'a, 'b> BindingsCollect<'a, 'b> { } } + /// Check declare-once and insert an explicit type for the binding. + fn type_annotation( + &mut self, + binding: &CstAssignIdent, + ty: Ty, + error_span: Span, + codemap: &CodeMap, + ) -> Result<(), InternalError> { + let resolved_id = binding.resolved_binding_id(codemap)?; + // This is always an error for the compiler, but LSP might continue typechecking in a + // degraded state. + // + // Two cases: + // 1. `x = ...; x: str = ...`. We don't add the annotation. Binding goes through the solver. + // 2. `x: str = ...; x: int = ...`, even if they're the same type. Keep the original for + // degraded typecheck. + // + let entry = self.bindings.types.entry(resolved_id); + // Previous un-annotated assignment + if self.bindings.expressions.contains_key(&resolved_id) + // Previous typed assignment + || matches!(entry, std::collections::hash_map::Entry::Occupied(_)) + { + self.bindings.errors.push(TypingError::new( + crate::Error::new_kind(starlark_syntax::ErrorKind::Other(anyhow::anyhow!( + "Second declaration of binding `{}`", + binding.node.ident, + ))), + error_span, + codemap, + )); + return Ok(()); + } + entry.or_insert(ty); + Ok(()) + } + fn visit_def( &mut self, def: &'a DefP, @@ -226,12 +289,13 @@ impl<'a, 'b> BindingsCollect<'a, 'b> { let mut args = None; let mut named_only = Vec::new(); let mut kwargs = None; + let mut param_types = HashMap::new(); for p in params { let name = &p.node.ident; let ty = p.node.ty; let ty = Self::resolve_ty_opt(ty, typecheck_mode, codemap)?; - let name_ty = match &p.node.kind { + let ty = match &p.node.kind { DefParamKind::Regular(mode, default_value) => { let required = match default_value.is_some() { true => ParamIsRequired::No, @@ -256,34 +320,48 @@ impl<'a, 'b> BindingsCollect<'a, 'b> { )); } } - Some((name, ty)) + ty } DefParamKind::Args => { // There is the type we require people calling us use (usually any) // and then separately the type we are when we are running (always tuple) args = Some(ty.dupe()); - Some((name, Ty::basic(TyBasic::Tuple(TyTuple::Of(ArcTy::new(ty)))))) + Ty::basic(TyBasic::Tuple(TyTuple::Of(ArcTy::new(ty)))) } DefParamKind::Kwargs => { let var_ty = Ty::dict(Ty::string(), ty.clone()); kwargs = Some(ty.dupe()); - Some((name, var_ty)) + var_ty } }; - if let Some((name, ty)) = name_ty { - self.bindings - .types - .insert(name.resolved_binding_id(codemap)?, ty); - } + param_types.insert(name.resolved_binding_id(codemap)?, ty); } let params2 = ParamSpec::new_parts(pos_only, pos_or_named, args, named_only, kwargs) .map_err(|e| InternalError::from_error(e, def.signature_span(), codemap))?; let ret_ty = Self::resolve_ty_opt(return_type.as_deref(), typecheck_mode, codemap)?; - self.bindings.types.insert( - name.resolved_binding_id(codemap)?, + + // Defining a function is like an annotated assignment + // + // foo: Function[...] = lambda x, y: ... + // + // (even if there are no type annotations, because even just having parameters + // is a kind of type annotation). + // + self.type_annotation( + name, Ty::function(params2, ret_ty.clone()), - ); - def.visit_children_err(|x| self.visit(x, &ret_ty, typecheck_mode, codemap))?; + name.span, + codemap, + )?; + + def.visit_header_err(|x| self.visit(x, &ret_ty, typecheck_mode, codemap))?; + + // Function body just gets added to the queue. + self.children_sink.push(ChildDef { + body: &def.body, + param_types, + return_type: ret_ty, + }); Ok(()) } @@ -299,15 +377,25 @@ impl<'a, 'b> BindingsCollect<'a, 'b> { StmtP::Assign(AssignP { lhs, ty, rhs }) => { if let Some(ty) = ty { let ty2 = Self::resolved_ty(ty, typecheck_mode, codemap)?; - self.bindings - .check_type - .push((ty.span, Some(rhs), ty2.clone())); + self.bindings.check_type.push(( + ty.span, + Some(BindExpr::Expr(rhs)), + ty2.clone(), + )); if let AssignTargetP::Identifier(id) = &**lhs { - // FIXME: This could be duplicated if you declare the type of a variable twice, - // we would only see the second one. - self.bindings - .types - .insert(id.resolved_binding_id(codemap)?, ty2); + self.type_annotation(id, ty2, lhs.span.merge(ty.span), codemap)?; + } else { + // We should have caught this when parsing. + return Err(InternalError::from_error( + crate::Error::new_kind(starlark_syntax::ErrorKind::Other( + anyhow::anyhow!( + "Cannot annotate type of complex assign target", + ), + )), + lhs.span.merge(ty.span), + codemap, + ) + .into()); } } self.assign(lhs, BindExpr::Expr(rhs), codemap)? @@ -324,11 +412,11 @@ impl<'a, 'b> BindingsCollect<'a, 'b> { return Ok(()); } StmtP::Load(..) => {} - StmtP::Return(ret) => { - self.bindings - .check_type - .push((x.span, ret.as_ref(), return_type.clone())) - } + StmtP::Return(ret) => self.bindings.check_type.push(( + x.span, + ret.as_ref().map(BindExpr::Expr), + return_type.clone(), + )), StmtP::Expression(x) => { // We want to find ident.append(), ident.extend(), ident.extend() // to fake up a BindExpr::ListAppend/ListExtend diff --git a/starlark-rust/starlark/src/typing/ctx.rs b/starlark-rust/starlark/src/typing/ctx.rs index a4a5d5dca6731..6ec81d37a1e75 100644 --- a/starlark-rust/starlark/src/typing/ctx.rs +++ b/starlark-rust/starlark/src/typing/ctx.rs @@ -52,13 +52,45 @@ use crate::typing::oracle::traits::TypingUnOp; use crate::typing::ty::Approximation; use crate::typing::ty::Ty; +pub(crate) enum BindingType { + /// Do not modify during type solver, user provided + Annotated(Ty), + /// Please solve this binding type + Solver(Ty), +} + +impl BindingType { + pub(crate) fn into_inner(self) -> Ty { + match self { + Self::Annotated(x) => x, + Self::Solver(x) => x, + } + } +} + +impl AsRef for BindingType { + fn as_ref(&self) -> &Ty { + self + } +} + +impl std::ops::Deref for BindingType { + type Target = Ty; + fn deref(&self) -> &Self::Target { + match self { + Self::Annotated(x) => x, + Self::Solver(x) => x, + } + } +} + pub(crate) struct TypingContext<'a> { pub(crate) oracle: TypingOracleCtx<'a>, // We'd prefer this to be a &mut self, // but that makes writing the code more fiddly, so just RefCell the errors pub(crate) errors: RefCell>, pub(crate) approximoations: RefCell>, - pub(crate) types: UnorderedMap, + pub(crate) types: UnorderedMap, pub(crate) module_var_types: &'a ModuleVarTypes, } @@ -240,7 +272,7 @@ impl TypingContext<'_> { AssignTargetP::Identifier(x) => { if let Some(i) = x.payload { if let Some(ty) = self.types.get(&i) { - return Ok(ty.clone()); + return Ok(ty.as_ref().clone()); } } Err(InternalError::msg( @@ -390,7 +422,7 @@ impl TypingContext<'_> { .unwrap_or_else(Ty::any), Some(ResolvedIdent::Slot(_, i)) => { if let Some(ty) = self.types.get(i) { - ty.clone() + ty.as_ref().clone() } else { // All types must be resolved to this point, // this code is unreachable. diff --git a/starlark-rust/starlark/src/typing/error.rs b/starlark-rust/starlark/src/typing/error.rs index 05926fba430b0..6be14e11d59c5 100644 --- a/starlark-rust/starlark/src/typing/error.rs +++ b/starlark-rust/starlark/src/typing/error.rs @@ -136,6 +136,24 @@ impl From for TypingOrInternalError { } } +impl TypingOrInternalError { + #[cold] + pub(crate) fn into_eval_exception(self) -> EvalException { + match self { + Self::Typing(e) => e.into_eval_exception(), + Self::Internal(e) => e.into_eval_exception(), + } + } + + #[cold] + pub(crate) fn into_error(self) -> crate::Error { + match self { + Self::Typing(e) => e.into_error(), + Self::Internal(e) => e.into_error(), + } + } +} + pub enum TypingNoContextOrInternalError { Typing, Internal(InternalError), diff --git a/starlark-rust/starlark/src/typing/function.rs b/starlark-rust/starlark/src/typing/function.rs index 54dea06f6bde8..7b14117e13bc7 100644 --- a/starlark-rust/starlark/src/typing/function.rs +++ b/starlark-rust/starlark/src/typing/function.rs @@ -55,6 +55,11 @@ pub trait TyCustomFunctionImpl: fn as_function(&self) -> Option<&TyFunction> { None } + + /// Type check an attribute of the callable value + fn attribute(&self, _attr: &str) -> Result { + Err(TypingNoContextError) + } } #[derive( @@ -119,8 +124,8 @@ impl TyCustomImpl for TyCustomFunction { Ok(Ty::any()) } - fn attribute(&self, _attr: &str) -> Result { - Err(TypingNoContextError) + fn attribute(&self, attr: &str) -> Result { + self.0.attribute(attr) } fn matcher(&self, factory: T) -> T::Result { @@ -181,4 +186,13 @@ impl TyCustomFunctionImpl for TyFunction { fn as_function(&self) -> Option<&TyFunction> { Some(self) } + + fn attribute(&self, attr: &str) -> Result { + if attr == "type" + && let Some(as_type) = self.type_attr.dupe() + { + return Ok(as_type); + } + Err(TypingNoContextError) + } } diff --git a/starlark-rust/starlark/src/typing/tests.rs b/starlark-rust/starlark/src/typing/tests.rs index 11c4e7e1589cb..3a81989d54bf5 100644 --- a/starlark-rust/starlark/src/typing/tests.rs +++ b/starlark-rust/starlark/src/typing/tests.rs @@ -468,8 +468,8 @@ def foo() -> str: "test_bit_or_with_load", r#" load("foo.bzl", "foo") -test = int | foo() -def test() -> test: +test_type = int | foo() +def test() -> test_type: pass "#, ); diff --git a/starlark-rust/starlark/src/typing/tests/golden/test_bit_or_with_load.golden b/starlark-rust/starlark/src/typing/tests/golden/test_bit_or_with_load.golden index 7eeba04aa61c9..459790f14e359 100644 --- a/starlark-rust/starlark/src/typing/tests/golden/test_bit_or_with_load.golden +++ b/starlark-rust/starlark/src/typing/tests/golden/test_bit_or_with_load.golden @@ -6,21 +6,21 @@ Code: load("foo.bzl", "foo") -test = int | foo() -def test() -> test: +test_type = int | foo() +def test() -> test_type: pass No errors. Approximations: -Approximation: Unknown type = "Span { begin: Pos(57), end: Pos(61) }" +Approximation: Unknown type = "Span { begin: Pos(62), end: Pos(71) }" Compiler typechecker (eval): Compiler typechecker and eval results mismatch. error: String literals are not allowed in type expressions: `"test"` - --> filename:3:8 + --> filename:3:13 | -3 | test = int | foo() - | ^^^^^^^^^^^ +3 | test_type = int | foo() + | ^^^^^^^^^^^ | diff --git a/starlark-rust/starlark/src/typing/typecheck.rs b/starlark-rust/starlark/src/typing/typecheck.rs index 955dd493d80ff..cee9eef7dd14e 100644 --- a/starlark-rust/starlark/src/typing/typecheck.rs +++ b/starlark-rust/starlark/src/typing/typecheck.rs @@ -23,7 +23,6 @@ use std::fmt::Display; use dupe::Dupe; use starlark_map::unordered_map::UnorderedMap; use starlark_syntax::slice_vec_ext::VecExt; -use starlark_syntax::syntax::ast::StmtP; use starlark_syntax::syntax::ast::Visibility; use starlark_syntax::syntax::module::AstModuleFields; use starlark_syntax::syntax::top_level_stmts::top_level_stmts_mut; @@ -43,9 +42,11 @@ use crate::syntax::AstModule; use crate::syntax::Dialect; use crate::typing::bindings::Bindings; use crate::typing::bindings::BindingsCollect; +use crate::typing::ctx::BindingType; use crate::typing::ctx::TypingContext; use crate::typing::error::InternalError; use crate::typing::error::TypingError; +use crate::typing::error::TypingOrInternalError; use crate::typing::fill_types_for_lint::ModuleVarTypes; use crate::typing::fill_types_for_lint::fill_types_for_lint_typechecker; use crate::typing::interface::Interface; @@ -55,20 +56,140 @@ use crate::typing::ty::Approximation; use crate::typing::ty::Ty; use crate::values::FrozenHeap; +/// Recursive function type-checker. +/// +/// You can call `solve_bindings` on as big or as little input as you like, but it +/// is O(n * m) where n is the number of expressions assigned to bindings, and m is +/// a measure of complexity of the assignments (`a=1; b=a; b=""; c=b; a=c` is +/// deliberately complex). If you lump all bindings in an entire module together and +/// solve them all at once, `m` is set to the most complex function and `n` is large. +/// +/// Fortunately, because of variable shadowing and no `global` keyword, we only ever +/// need to look at a function body to determine the types of its local bindings from +/// the assignments made to it. +/// +/// ```python +/// x = 5 +/// def child(): +/// x = "string" # completely different binding, no relation to outer `x` +/// ``` +/// +/// In some circumstances child scopes could influence parent scopes via mutation +/// methods, but it's not a big loss if we don't infer in these cases. +/// +/// ```python +/// xs = list() +/// def child(): +/// xs.push(5) # could influence type of xs, but we will not support this +/// ``` +/// +/// So we can solve bindings for every scope individually. That's what this does. +/// As for ordering, a parent scope is finished and solved before we check any of +/// its child scopes, so child scopes have access to parent binding solutions. +/// +/// ```python +/// x = 5 +/// def child(): +/// y = x # y solves to `int | str` +/// x = "string" +/// ``` +/// +pub(crate) struct TypeChecker<'a> { + pub(crate) oracle: TypingOracleCtx<'a>, + pub(crate) typecheck_mode: TypecheckMode, + pub(crate) module_var_types: &'a ModuleVarTypes, + pub(crate) approximations: &'a mut Vec, + pub(crate) all_solved_types: HashMap, +} + +impl<'a> TypeChecker<'a> { + fn codemap(&self) -> &'a CodeMap { + self.oracle.codemap + } + + /// Typecheck an entire module. + /// + /// To just immediately return on encountering a type error, pass `&mut Err` as the error + /// handler. To collect errors and continue, be sure to return Ok from the error handler. + pub(crate) fn check_module_scope( + &mut self, + module: &CstStmt, + eh: &mut dyn FnMut(TypingError) -> Result<(), TypingError>, + ) -> Result<(), TypingOrInternalError> { + self.check_scope(module, &Ty::any(), &HashMap::default(), eh) + } + + /// Recursive scope check. Checks the scope's bindings, and then all child defs. + pub(crate) fn check_scope( + &mut self, + body: &CstStmt, + return_type: &Ty, + visible: &HashMap, + eh: &mut dyn FnMut(TypingError) -> Result<(), TypingError>, + ) -> Result<(), TypingOrInternalError> { + let mut children = Vec::new(); + let bindings = BindingsCollect::collect_scope( + body, + return_type, + visible, + self.typecheck_mode, + self.codemap(), + self.approximations, + &mut children, + )?; + let (errors, solved, mut approx) = + solve_bindings(bindings, self.oracle, self.module_var_types)?; + + self.approximations.append(&mut approx); + + for error in errors { + eh(error)?; + } + + // Save all solved types to the output + let solved_copy = solved.iter().map(|(&b, ty)| (b, ty.dupe())); + self.all_solved_types.extend(solved_copy.clone()); + + for child in children { + let mut child_visible = visible.clone(); + child_visible.extend(solved_copy.clone()); + child_visible.extend(child.param_types); + self.check_scope(child.body, &child.return_type, &child_visible, eh)?; + } + Ok(()) + } +} + // Things which are None in the map have type void - they are never constructed pub(crate) fn solve_bindings( - bindings: Bindings, + mut bindings: Bindings, oracle: TypingOracleCtx, module_var_types: &ModuleVarTypes, ) -> Result<(Vec, HashMap, Vec), InternalError> { - let mut types = bindings - .expressions - .keys() - .map(|x| (*x, Ty::never())) - .collect::>(); + let mut types: UnorderedMap = UnorderedMap::new(); + + // No need to (expensively) solve over bound expressions where the binding's type + // is already provided by user. Move these into check_type, to check all assignments + // match the type annotation. for (k, ty) in bindings.types { - types.insert(k, ty); + types.insert(k, BindingType::Annotated(ty.clone())); + + if let Some(exprs) = bindings.expressions.get_mut(&k) { + for expr in std::mem::take(exprs) { + bindings + .check_type + .push((expr.span(), Some(expr), ty.dupe())); + } + } } + // So we don't have to shift_remove N times, just call retain at the end + bindings.expressions.retain(|_, exprs| !exprs.is_empty()); + + // Initialize unsolved types + bindings.expressions.keys().for_each(|x| { + types.insert(*x, BindingType::Solver(Ty::never())); + }); + // FIXME: Should be a fixed point, just do 10 iterations since that probably converges let mut changed = false; let mut ctx = TypingContext { @@ -79,13 +200,17 @@ pub(crate) fn solve_bindings( module_var_types, }; const ITERATIONS: usize = 100; + for _iteration in 0..ITERATIONS { changed = false; ctx.errors.borrow_mut().clear(); for (name, exprs) in &bindings.expressions { for expr in exprs { let ty = ctx.expression_bind_type(expr)?; - let t = ctx.types.get_mut(name).unwrap(); + let BindingType::Solver(t) = ctx.types.get_mut(name).unwrap() else { + // unreachable + continue; + }; let new = Ty::union2(t.clone(), ty); if &new != t { changed = true; @@ -110,7 +235,7 @@ pub(crate) fn solve_bindings( for (span, e, require) in &bindings.check_type { let ty = match e { None => Ty::none(), - Some(x) => ctx.expression_type(x)?, + Some(x) => ctx.expression_bind_type(x)?, }; ctx.validate_type( Spanned { @@ -120,9 +245,15 @@ pub(crate) fn solve_bindings( require, )?; } + // Put binding errors first as the compiler fails with the first error and otherwise they + // never show up in tests. + bindings.errors.append(&mut ctx.errors.into_inner()); Ok(( - ctx.errors.into_inner(), - ctx.types.into_hash_map(), + bindings.errors, + ctx.types + .into_entries_unordered() + .map(|(k, v)| (k, v.into_inner())) + .collect(), ctx.approximoations.into_inner(), )) } @@ -216,12 +347,12 @@ impl AstModuleTypecheck for AstModule { let scope_errors = scope_errors.into_map(TypingError::from_eval_exception); // We don't really need to properly unpack top-level statements, // but make it safe against future changes. - let mut cst: Vec<&mut CstStmt> = top_level_stmts_mut(&mut cst); + let mut cst_toplevel: Vec<&mut CstStmt> = top_level_stmts_mut(&mut cst); let oracle = TypingOracleCtx { codemap: &codemap }; let mut approximations = Vec::new(); let (fill_types_errors, module_var_types) = match fill_types_for_lint_typechecker( - &mut cst, + &mut cst_toplevel, oracle, &scope_data, &mut approximations, @@ -241,58 +372,43 @@ impl AstModuleTypecheck for AstModule { }; let mut typemap = UnorderedMap::new(); + let oracle = TypingOracleCtx { codemap: &codemap }; + let mut type_checker = TypeChecker { + oracle, + typecheck_mode: TypecheckMode::Lint, + module_var_types: &module_var_types, + approximations: &mut approximations, + all_solved_types: HashMap::new(), + }; let mut all_solve_errors = Vec::new(); - for top in cst.iter_mut() { - if let StmtP::Def(_) = &mut top.node { - let bindings = match BindingsCollect::collect_one( - top, - TypecheckMode::Lint, - &codemap, - &mut approximations, - ) { - Ok(bindings) => bindings, - Err(e) => { - return ( - vec![InternalError::into_error(e)], - TypeMap { - codemap, - bindings: UnorderedMap::new(), - }, - Interface::default(), - Vec::new(), - ); - } - }; - let (solve_errors, types, solve_approximations) = - match solve_bindings(bindings.bindings, oracle, &module_var_types) { - Ok(x) => x, - Err(e) => { - return ( - vec![e.into_error()], - TypeMap { - codemap, - bindings: UnorderedMap::new(), - }, - Interface::default(), - Vec::new(), - ); - } - }; - - all_solve_errors.extend(solve_errors); - approximations.extend(solve_approximations); + let mut error_handler = |error| { + all_solve_errors.push(error); + // continue checking + Ok(()) + }; + if let Err(unrecoverable) = type_checker.check_module_scope(&cst, &mut error_handler) { + // This is generally an internal error, since most typing errors are handled + // by error_handler. Except some type errors that can't (yet?) be recovered. + return ( + vec![unrecoverable.into_error()], + TypeMap { + codemap, + bindings: UnorderedMap::new(), + }, + Interface::default(), + Vec::new(), + ); + } - for (id, ty) in &types { - let binding = scope_data.get_binding(*id); - let name = binding.name.as_str().to_owned(); - let span = match binding.source { - BindingSource::Source(span) => span, - BindingSource::FromModule => Span::default(), - }; - typemap.insert(*id, (name, span, ty.clone())); - } - } + for (id, ty) in &type_checker.all_solved_types { + let binding = scope_data.get_binding(*id); + let name = binding.name.as_str().to_owned(); + let span = match binding.source { + BindingSource::Source(span) => span, + BindingSource::FromModule => Span::default(), + }; + typemap.insert(*id, (name, span, ty.clone())); } let typemap = TypeMap { diff --git a/starlark-rust/starlark/src/typing/user.rs b/starlark-rust/starlark/src/typing/user.rs index 87e028675682f..e7e920ae6b1ef 100644 --- a/starlark-rust/starlark/src/typing/user.rs +++ b/starlark-rust/starlark/src/typing/user.rs @@ -253,11 +253,13 @@ impl TyCustomImpl for TyUser { } fn as_callable(&self) -> Option { - if self.base.is_callable() { - Some(TyCallable::any()) - } else { - None - } + self.callable.as_ref().cloned().or_else(|| { + if self.base.is_callable() { + Some(TyCallable::any()) + } else { + None + } + }) } fn validate_call( @@ -292,6 +294,15 @@ impl TyCustomImpl for TyUser { } self.supertypes.iter().any(|x| x == other) } + + fn bin_op( + &self, + bin_op: super::TypingBinOp, + rhs: &TyBasic, + _ctx: &TypingOracleCtx, + ) -> Result { + Ok(self.base.bin_op(bin_op, rhs)?) + } } #[cfg(test)] diff --git a/starlark-rust/starlark/src/values/types/function.rs b/starlark-rust/starlark/src/values/types/function.rs index 20821c22196f5..cac565263aecd 100644 --- a/starlark-rust/starlark/src/values/types/function.rs +++ b/starlark-rust/starlark/src/values/types/function.rs @@ -172,9 +172,12 @@ impl<'v> StarlarkValue<'v> for NativeFunction { self.as_type.clone() } - fn has_attr(&self, _attribute: &str, _heap: &'v Heap) -> bool { - // TODO(nga): implement properly. - false + fn has_attr(&self, attribute: &str, _heap: &'v Heap) -> bool { + if self.as_type.is_some() { + attribute == "type" + } else { + false + } } fn dir_attr(&self) -> Vec { diff --git a/starlark-rust/starlark_map/src/unordered_map.rs b/starlark-rust/starlark_map/src/unordered_map.rs index c6c7edd270b55..e139819a03019 100644 --- a/starlark-rust/starlark_map/src/unordered_map.rs +++ b/starlark-rust/starlark_map/src/unordered_map.rs @@ -221,7 +221,7 @@ impl UnorderedMap { /// Into entries, in arbitrary order. #[inline] - pub(crate) fn into_entries_unordered(self) -> impl ExactSizeIterator { + pub fn into_entries_unordered(self) -> impl ExactSizeIterator { self.0.into_iter() } diff --git a/starlark-rust/starlark_syntax/src/syntax/grammar_util.rs b/starlark-rust/starlark_syntax/src/syntax/grammar_util.rs index 2f8a7397e00ea..2f0de6e9046d4 100644 --- a/starlark-rust/starlark_syntax/src/syntax/grammar_util.rs +++ b/starlark-rust/starlark_syntax/src/syntax/grammar_util.rs @@ -66,6 +66,8 @@ enum GrammarUtilError { TypeAnnotationOnAssignOp, #[error("type annotations not allowed on multiple assignments")] TypeAnnotationOnTupleAssign, + #[error("type annotations not allowed on complex assignments")] + TypeAnnotationOnComplexAssign, #[error("`load` statement requires at least two arguments")] LoadRequiresAtLeastTwoArguments, } @@ -125,12 +127,13 @@ pub fn check_assignment( } let lhs = check_assign(codemap, lhs)?; if let Some(ty) = &ty { - let err = if op.is_some() { - Some(GrammarUtilError::TypeAnnotationOnAssignOp) - } else if matches!(lhs.node, AssignTargetP::Tuple(_)) { - Some(GrammarUtilError::TypeAnnotationOnTupleAssign) - } else { - None + let err = match lhs.node { + _ if op.is_some() => Some(GrammarUtilError::TypeAnnotationOnAssignOp), + AssignTargetP::Tuple(_) => Some(GrammarUtilError::TypeAnnotationOnTupleAssign), + AssignTargetP::Index(_) | AssignTargetP::Dot(..) => { + Some(GrammarUtilError::TypeAnnotationOnComplexAssign) + } + AssignTargetP::Identifier(_) => None, }; if let Some(err) = err { return Err(EvalException::new_anyhow(err.into(), ty.span, codemap)); diff --git a/starlark-rust/starlark_syntax/src/syntax/uniplate.rs b/starlark-rust/starlark_syntax/src/syntax/uniplate.rs index c1df16dceb79a..a68a64dba2402 100644 --- a/starlark-rust/starlark_syntax/src/syntax/uniplate.rs +++ b/starlark-rust/starlark_syntax/src/syntax/uniplate.rs @@ -86,6 +86,36 @@ impl DefP

{ f(Visit::Stmt(body)); } + /// Visit the parameters and return type, but not the body + pub fn visit_header_err<'a, E>( + &'a self, + mut f: impl FnMut(Visit<'a, P>) -> Result<(), E>, + ) -> Result<(), E> { + let DefP { + name: _, + params, + return_type, + body: _, + payload: _, + } = self; + let mut result = Ok(()); + params.iter().for_each(|x| { + x.visit_expr(|x| { + if result.is_ok() { + result = f(Visit::Expr(x)); + } + }) + }); + return_type.iter().for_each(|x| { + x.visit_expr(|x| { + if result.is_ok() { + result = f(Visit::Expr(x)); + } + }) + }); + result + } + pub fn visit_children_err<'a, E>( &'a self, mut f: impl FnMut(Visit<'a, P>) -> Result<(), E>,