diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 8f972858a3..e906ee270e 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -2513,6 +2513,8 @@ impl<'a> Transaction<'a> { && let Some(ast) = self.get_ast(handle) && let Some(module_info) = self.get_module_info(handle) { + let mut autoimport_candidates = Vec::new(); + let mut names_with_public = OrderedSet::new(); for (handle_to_import_from, name, export) in self.search_exports_fuzzy(identifier.as_str()) { @@ -2529,7 +2531,7 @@ impl<'a> Transaction<'a> { &ast, self.config_finder(), handle.dupe(), - handle_to_import_from, + handle_to_import_from.clone(), &name, import_format, ); @@ -2541,29 +2543,47 @@ impl<'a> Transaction<'a> { }; let auto_import_label_detail = format!(" (import {imported_module})"); - completions.push(CompletionItem { - label: name, - detail: Some(insert_text), - kind: export - .symbol_kind - .map_or(Some(CompletionItemKind::VARIABLE), |k| { - Some(k.to_lsp_completion_item_kind()) - }), - additional_text_edits, - label_details: supports_completion_item_details.then_some( - CompletionItemLabelDetails { - detail: Some(auto_import_label_detail), - description: Some(module_description), + let is_private_import = handle_to_import_from + .module() + .components() + .last() + .is_some_and(|component| component.as_str().starts_with('_')); + if !is_private_import { + names_with_public.insert(name.clone()); + } + autoimport_candidates.push(( + CompletionItem { + label: name, + detail: Some(insert_text), + kind: export + .symbol_kind + .map_or(Some(CompletionItemKind::VARIABLE), |k| { + Some(k.to_lsp_completion_item_kind()) + }), + additional_text_edits, + label_details: supports_completion_item_details.then_some( + CompletionItemLabelDetails { + detail: Some(auto_import_label_detail), + description: Some(module_description), + }, + ), + tags: if export.deprecation.is_some() { + Some(vec![CompletionItemTag::DEPRECATED]) + } else { + None }, - ), - tags: if export.deprecation.is_some() { - Some(vec![CompletionItemTag::DEPRECATED]) - } else { - None + sort_text: Some(format!("4{}", depth)), + ..Default::default() }, - sort_text: Some(format!("4{}", depth)), - ..Default::default() - }); + is_private_import, + )); + } + + for (mut item, is_private_import) in autoimport_candidates { + if is_private_import && names_with_public.contains(&item.label) { + item.sort_text = Some("b".to_owned()); + } + completions.push(item); } for module_name in self.search_modules_fuzzy(identifier.as_str()) { @@ -3038,20 +3058,23 @@ impl<'a> Transaction<'a> { .as_ref() .is_some_and(|tags| tags.contains(&CompletionItemTag::DEPRECATED)) { - "9" + "9".to_owned() } else if item.additional_text_edits.is_some() { - "4" + if let Some(sort_text) = &item.sort_text { + format!("4{sort_text}") + } else { + "4".to_owned() + } } else if item.label.starts_with("__") { - "3" + "3".to_owned() } else if item.label.as_str().starts_with("_") { - "2" + "2".to_owned() } else if let Some(sort_text) = &item.sort_text { // 1 is reserved for re-exports - sort_text.as_str() + sort_text.clone() } else { - "0" - } - .to_owned(); + "0".to_owned() + }; item.sort_text = Some(sort_text); } (result, is_incomplete) @@ -3078,12 +3101,14 @@ impl<'a> Transaction<'a> { /// - Handles stdlib patterns where a public module (`io`) re-exports from a /// private implementation module (`_io`). fn should_include_reexport(original: &Handle, canonical: &Handle) -> bool { - let canonical_components = canonical.module().components(); + let canonical_module = canonical.module(); + let original_module = original.module(); + let canonical_components = canonical_module.components(); let canonical_component = canonical_components .last() .map(|name| name.as_str()) .unwrap_or(""); - let original_components = original.module().components(); + let original_components = original_module.components(); let original_component = original_components .last() .map(|name| name.as_str()) @@ -3095,15 +3120,23 @@ impl<'a> Transaction<'a> { return true; } - // Include re-export if original is a parent package of canonical - if canonical_components.len() > original_components.len() { - canonical_components + // Include re-export if original is a parent package of canonical. + if canonical_components.len() > original_components.len() + && canonical_components .iter() .zip(original_components.iter()) .all(|(c, o)| c == o) - } else { - false + { + return true; + } + // Some stdlib shims encode dotted modules with underscores (e.g. _collections_abc). + if canonical_module.as_str().starts_with('_') && original_module.as_str().contains('.') { + let canonical_trim = canonical_module.as_str().trim_start_matches('_'); + if canonical_trim == original_module.as_str().replace('.', "_") { + return true; + } } + false } pub fn search_exports_exact(&self, name: &str) -> Vec<(Handle, Export)> { diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index 51e5324fd7..007ea26c15 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -1809,6 +1809,41 @@ Completion Results: ); } +#[test] +fn autoimport_prefers_public_reexport_for_dotted_private_module() { + let code = r#" +T = Thing +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error( + &[ + ("main", code), + ("_foo_bar", "Thing = 1\n"), + ("foo.bar", "from _foo_bar import Thing\n"), + ], + get_test_report(Default::default(), ImportFormat::Absolute), + ); + assert_eq!( + r#" +# main.py +2 | T = Thing + ^ +Completion Results: +- (Variable) Thing: from foo.bar import Thing + +- (Variable) Thing: from _foo_bar import Thing + + + +# _foo_bar.py + +# foo.bar.py +"# + .trim(), + report.trim(), + ); +} + #[test] fn autoimport_completions_set_label_details() { let code = r#"