Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ private_common_attrs = {
"//support/public:result",
"//support/public:tuple",
"//support/public:vec",
"//support/public:iterator_adapter",
],
),
"_process_wrapper": attr.label(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ impl<'tcx> CcPrerequisites<'tcx> {
.extend(std::mem::take(&mut self.template_specializations));
}

pub fn move_only_defs_to_fwd_decls(&mut self) {
self.fwd_decls.extend(std::mem::take(&mut self.defs));
}

/// Move any definitions that appear in `ty` to the forward declarations of `prereqs`.
pub fn forward_declare_type(&mut self, ty: Ty<'tcx>) {
let mut adts = HashSet::new();
Expand Down
272 changes: 271 additions & 1 deletion cc_bindings_from_rs/generate_bindings/generate_struct_and_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use arc_anyhow::{Context, Result};
use code_gen_utils::{
escape_non_identifier_chars, expect_format_cc_type_name, make_rs_ident, CcInclude,
};
use database::code_snippet::{ApiSnippets, CcPrerequisites, CcSnippet};
use database::code_snippet::{
ApiSnippets, CcPrerequisites, CcSnippet, TemplateSpecialization,
TraitImplTemplateSpecialization,
};
use database::{AdtCoreBindings, BindingsGenerator, StaticMethodMode, TypeLocation};
use error_report::{anyhow, bail, ensure};
use itertools::Itertools;
Expand Down Expand Up @@ -1180,6 +1183,8 @@ pub fn generate_adt<'tcx>(
let trait_operator_snippets = generate_trait_operator_impls(db, core.as_ref());
let constructor_operator_snippets = generate_constructor_impls(db, core.as_ref());
let display_snippets = generate_display_impl(db, core.as_ref());
let into_iterator_snippets =
generate_into_iterator_impls(db, core.as_ref(), &mut member_function_names);

let ApiSnippets {
main_api: public_functions_main_api,
Expand All @@ -1197,6 +1202,7 @@ pub fn generate_adt<'tcx>(
trait_operator_snippets,
constructor_operator_snippets,
display_snippets,
into_iterator_snippets,
]
.into_iter()
.collect();
Expand Down Expand Up @@ -2548,3 +2554,267 @@ pub(crate) fn generate_relocating_ctor<'tcx>(
main_api.prereqs.includes.insert(CcInclude::cstring());
main_api.into_main_api()
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum PassingMode {
Value,
SharedRef,
MutRef,
}

fn get_into_iter_ty<'tcx>(
tcx: TyCtxt<'tcx>,
self_ty: Ty<'tcx>,
into_iterator_trait_id: DefId,
) -> Result<Ty<'tcx>> {
let into_iter_assoc_item = tcx
.associated_items(into_iterator_trait_id)
.in_definition_order()
.find(|item| {
item.name() == rustc_span::symbol::Symbol::intern("IntoIter")
&& matches!(item.kind, ty::AssocKind::Type { .. })
})
.expect("IntoIter to be a required associated item of IntoIterator");

let projection_ty = Ty::new_projection(tcx, into_iter_assoc_item.def_id, [self_ty]);

query_compiler::try_normalize(
tcx,
ty::PseudoCanonicalInput {
typing_env: rustc_middle::ty::TypingEnv::fully_monomorphized(),
value: projection_ty,
},
)
.map_err(|_| anyhow!("Failed to normalize `<{} as IntoIterator>::IntoIter`", self_ty))
}

fn get_into_iter_item_ty<'tcx>(
tcx: TyCtxt<'tcx>,
self_ty: Ty<'tcx>,
into_iterator_trait_id: DefId,
) -> Result<Ty<'tcx>> {
let item_assoc_item = tcx
.associated_items(into_iterator_trait_id)
.in_definition_order()
.find(|item| {
item.name() == rustc_span::symbol::Symbol::intern("Item")
&& matches!(item.kind, ty::AssocKind::Type { .. })
})
.expect("Item to be a required associated item of IntoIterator");

let projection_ty = Ty::new_projection(tcx, item_assoc_item.def_id, [self_ty]);

query_compiler::try_normalize(
tcx,
ty::PseudoCanonicalInput {
typing_env: rustc_middle::ty::TypingEnv::fully_monomorphized(),
value: projection_ty,
},
)
.map_err(|_| anyhow!("Failed to normalize `<{} as IntoIterator>::Item`", self_ty))
}

fn generate_begin_and_end_for_type<'tcx>(
db: &BindingsGenerator<'tcx>,
core: &AdtCoreBindings<'tcx>,
into_iterator_trait_id: DefId,
passing_mode: PassingMode,
) -> Result<Option<ApiSnippets<'tcx>>> {
let tcx = db.tcx();
let self_ty = core.self_ty;

let check_ty = match passing_mode {
PassingMode::Value => self_ty,
PassingMode::SharedRef => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty),
PassingMode::MutRef => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty),
};

if !does_type_implement_trait(tcx, check_ty, into_iterator_trait_id, []) {
return Ok(None);
}
if let Some(iterator_trait_id) = tcx.get_diagnostic_item(sym::Iterator)
&& does_type_implement_trait(tcx, self_ty, iterator_trait_id, [])
{
return Ok(None);
}

let into_iter_ty = get_into_iter_ty(tcx, check_ty, into_iterator_trait_id)?;

let item_ty = get_into_iter_item_ty(tcx, check_ty, into_iterator_trait_id)?;

let _ = db
.format_ty_for_cc(item_ty, TypeLocation::Other)
.context("Failed to format IntoIterator::Item")?;

let into_iter_cc_ty = db
.format_ty_for_cc(into_iter_ty, TypeLocation::Other)
.context("Failed to format IntoIterator::IntoIter")?;

let static_check_ty = replace_all_regions_with_static(tcx, check_ty);
let rs_fully_qualified_name = db.format_ty_for_rs(static_check_ty)?;

let TraitThunks { method_name_to_cc_thunk_name, cc_thunk_decls, rs_thunk_impls } =
generate_trait_thunks(
db,
into_iterator_trait_id,
&[],
check_ty,
core.def_id,
rs_fully_qualified_name,
/*is_constructor=*/ false,
)?;

let into_iter_thunk_name = method_name_to_cc_thunk_name
.get(&sym::into_iter)
.expect("IntoIterator trait missing into_iter method");

let into_iter_fn_assoc_item = tcx
.associated_items(into_iterator_trait_id)
.in_definition_order()
.find(|item| item.name() == sym::into_iter && matches!(item.kind, ty::AssocKind::Fn { .. }))
.expect("IntoIterator should have into_iter method");
let into_iter_fn_id = into_iter_fn_assoc_item.def_id;

let adt_cc_name = &core.cc_short_name;
let param_cc_type_tokens = match passing_mode {
PassingMode::Value => quote! { #adt_cc_name && },
PassingMode::SharedRef => quote! { const #adt_cc_name & },
PassingMode::MutRef => quote! { #adt_cc_name & },
};

let param = Param {
cc_name: format_ident!("self_"),
cpp_type: CcParamTy {
snippet: CcSnippet::new(param_cc_type_tokens.clone()),
is_lifetime_bound: false,
},
ty: check_ty,
};

let impl_body = generate_thunk_call(
db,
into_iter_fn_id,
into_iter_thunk_name.clone(),
into_iter_ty,
ThunkSelfParameter::new(
/*has_self=*/ false, /*by_copy=*/ false, /*is_trait_method=*/ false,
),
&[param],
)?;

let mut main_api_prereqs = CcPrerequisites::default();
let into_iter_cc_ty_tokens_main = into_iter_cc_ty.clone().into_tokens(&mut main_api_prereqs);
main_api_prereqs.includes.insert(db.support_header("rs_std/iterator_adapter.h"));

let iterator_trait_id = tcx
.get_diagnostic_item(sym::Iterator)
.ok_or_else(|| anyhow!("Iterator trait not found"))?;
let mut impls = tcx.non_blanket_impls_for_ty(iterator_trait_id, into_iter_ty);
let Some(trait_impl_def_id) = impls.next() else {
return Ok(None);
};
let generics = tcx.generics_of(trait_impl_def_id);
let has_type_or_const_params = generics.own_params.iter().any(|param| {
matches!(
param.kind,
ty::GenericParamDefKind::Type { .. } | ty::GenericParamDefKind::Const { .. }
)
});
if has_type_or_const_params {
bail!("IntoIterator/Iterator impls with generic type or const parameters are not supported yet.");
}
let specialization = TemplateSpecialization::TraitImpl(TraitImplTemplateSpecialization {
self_ty_cc_name: into_iter_cc_ty_tokens_main.clone(),
trait_impl: trait_impl_def_id,
});
main_api_prereqs.template_specializations.insert(specialization);

let (ref_qualifiers, self_binding) = match passing_mode {
PassingMode::Value => {
(quote! { && }, quote! { #adt_cc_name&& self_ = ::std::move(*this); })
}
PassingMode::SharedRef => {
(quote! { const & }, quote! { const #adt_cc_name& self_ = *this; })
}
PassingMode::MutRef => (quote! { & }, quote! { #adt_cc_name& self_ = *this; }),
};

main_api_prereqs.move_only_defs_to_fwd_decls();
let main_api = CcSnippet {
tokens: quote! {
rs::IteratorAdapter< #into_iter_cc_ty_tokens_main > begin() #ref_qualifiers;
rs::IteratorEnd end() #ref_qualifiers;
},
prereqs: main_api_prereqs,
};

let mut cc_details_prereqs = CcPrerequisites::default();
let into_iter_cc_ty_tokens_details = into_iter_cc_ty.into_tokens(&mut cc_details_prereqs);
cc_details_prereqs.includes.insert(db.support_header("rs_std/iterator_adapter.h"));

let cc_thunk_decls_tokens = cc_thunk_decls.into_tokens(&mut cc_details_prereqs);
let impl_body_tokens = impl_body.into_tokens(&mut cc_details_prereqs);
cc_details_prereqs.move_defs_to_fwd_decls();

let call_expr = if matches!(into_iter_ty.kind(), ty::Ref(..)) {
quote! { &call_into_iter() }
} else {
quote! { call_into_iter() }
};

let cc_details = CcSnippet {
tokens: quote! {
#cc_thunk_decls_tokens

inline rs::IteratorAdapter< #into_iter_cc_ty_tokens_details > #adt_cc_name :: begin () #ref_qualifiers {
#self_binding
auto call_into_iter = [&]() -> decltype(auto) {
#impl_body_tokens
};
return rs::IteratorAdapter< #into_iter_cc_ty_tokens_details >(#call_expr);
}
inline rs::IteratorEnd #adt_cc_name :: end () #ref_qualifiers {
return rs::IteratorEnd();
}
},
prereqs: cc_details_prereqs,
};

Ok(Some(ApiSnippets { main_api, cc_details, rs_details: rs_thunk_impls }))
}

fn generate_into_iterator_impls<'tcx>(
db: &BindingsGenerator<'tcx>,
core: &AdtCoreBindings<'tcx>,
member_function_names: &mut HashSet<String>,
) -> ApiSnippets<'tcx> {
let tcx = db.tcx();
let Some(into_iterator_trait_id) = tcx.get_diagnostic_item(sym::IntoIterator) else {
return ApiSnippets::default();
};

if member_function_names.contains("begin") || member_function_names.contains("end") {
return ApiSnippets::default();
}

let mut snippets = Vec::new();

let mut try_generate = |passing_mode| -> Result<Option<ApiSnippets<'tcx>>> {
generate_begin_and_end_for_type(db, core, into_iterator_trait_id, passing_mode)
};

for mode in [PassingMode::Value, PassingMode::SharedRef, PassingMode::MutRef] {
match try_generate(mode) {
Ok(Some(s)) => snippets.push(s),
Ok(None) => {}
Err(err) => {
if let Some(def_id) = core.def_id {
let main_api = generate_unsupported_def(db, def_id, err);
snippets.push(ApiSnippets { main_api, ..Default::default() });
}
}
}
}

snippets.into_iter().collect()
}
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,15 @@ fn generate_trait_impl_specialization<'tcx>(

prereqs.depend_on_def(db, trait_def_id).map_err(|err| (impl_def_id, err))?;
if let Some(adt) = trait_ref.self_ty().ty_adt_def() {
prereqs.depend_on_def(db, adt.did()).map_err(|err| (impl_def_id, err))?;
let def_id = adt.did();
let canonical_name = db.symbol_canonical_name(def_id).expect(
"Self type should have a canonical name if we are generating a specialization for it",
);
if canonical_name.krate_num == db.source_crate_num() {
prereqs.fwd_decls.insert(def_id);
} else {
prereqs.depend_on_def(db, def_id).map_err(|err| (impl_def_id, err))?;
}
}

let mut member_function_names = HashSet::new();
Expand Down
6 changes: 5 additions & 1 deletion cc_bindings_from_rs/generate_bindings/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,7 @@ pub(crate) fn create_type_alias_with_rs_type<'tcx>(
let cc_bindings = db.format_ty_for_cc(alias_type, TypeLocation::Other)?;
let mut main_api_prereqs = CcPrerequisites::default();
let actual_type_name = cc_bindings.into_tokens(&mut main_api_prereqs);
main_api_prereqs.move_defs_to_fwd_decls();

let alias_name = format_cc_ident(db, alias_name).context("Error formatting type alias name")?;

Expand Down Expand Up @@ -2271,6 +2272,7 @@ fn generate_crate(db: &BindingsGenerator) -> Result<BindingsTokens> {
.map(|(spec, main_api)| (Node::Specialization(spec.clone()), main_api)),
)
.flat_map(|(successor, main_api)| {
let successor_clone = successor.clone();
let predecessors = main_api
.prereqs
.defs
Expand All @@ -2286,13 +2288,15 @@ fn generate_crate(db: &BindingsGenerator) -> Result<BindingsTokens> {
.iter()
.cloned()
.map(Node::Specialization),
);
)
.filter(move |pre| pre != &successor_clone);
predecessors.map(move |predecessor| toposort::Dependency {
predecessor,
successor: successor.clone(),
})
})
.collect::<Vec<_>>();

let spec_keys: HashMap<&TemplateSpecialization<'_>, NodeSortKey> =
specializations.keys().map(|spec| (spec, NodeSortKey::new(tcx, spec))).collect();

Expand Down
Loading
Loading