-
Notifications
You must be signed in to change notification settings - Fork 801
[SYCL][clang] Fix more free-function kernel integration header cases #20877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sycl
Are you sure you want to change the base?
Changes from 1 commit
549ca8f
bea8af0
14b14d5
3c7f972
abbb457
9bdcbac
b4a108d
07ff294
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6626,6 +6626,195 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | |
| [](raw_ostream &, const NamespaceDecl *) {}, OS, DC); | ||
| } | ||
|
|
||
| /// Dedicated visitor which helps with printing of kernel arguments in forward | ||
| /// declarations of free function kernels which are declared as function | ||
| /// templates. | ||
| /// | ||
| /// Based on: | ||
| /// \code | ||
| /// template <typename T1, typename T2> | ||
| /// void foo(T1 a, int b, T2 c); | ||
| /// \endcode | ||
| /// | ||
| /// It prints into the output stream "T1, int, T2". | ||
| /// | ||
| /// The main complexity (which motivates addition of such visitor) comes from | ||
| /// the fact that there could be type aliases and default template arguments. | ||
| /// For example: | ||
| /// \code | ||
| /// template<typename T> | ||
| /// void kernel(sycl::accessor<T, 1>); | ||
| /// template void kernel(sycl::accessor<int, 1>); | ||
| /// \endcode | ||
| /// sycl::accessor has many template arguments which have default values. If | ||
| /// we iterate over non-canonicalized argument type, we don't get those default | ||
| /// values and we don't get necessary namespace qualifiers for all the template | ||
| /// arguments. If we iterate over canonicalized argument type, then all | ||
| /// references to T will be replaced with something like type-argument-X-Y. | ||
| /// What this visitor does is it iterates over both in sync, picking the right | ||
| /// values from one or another. | ||
| /// | ||
| /// Moral of the story: drop integration header ASAP (but that is blocked | ||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// by support for 3rd-party host compilers, which is important). | ||
| class FreeFunctionTemplateKernelArgsPrinter | ||
| : public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, | ||
| void, ArrayRef<TemplateArgument>> { | ||
| raw_ostream &O; | ||
| PrintingPolicy &Policy; | ||
| ASTContext &Context; | ||
|
|
||
| using Base = | ||
| ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void, | ||
| ArrayRef<TemplateArgument>>; | ||
|
|
||
| void PrintTemplateDeclName(const TemplateDecl *TD, | ||
| ArrayRef<TemplateArgument> SpecArgs) {} | ||
|
|
||
| public: | ||
| FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy, | ||
| ASTContext &Context) | ||
| : O(O), Policy(Policy), Context(Context) {} | ||
|
|
||
| void Visit(const TemplateSpecializationType *T, | ||
|
||
| const TemplateSpecializationType *CT) { | ||
| ArrayRef<TemplateArgument> SpecArgs = T->template_arguments(); | ||
| ArrayRef<TemplateArgument> DeclArgs = CT->template_arguments(); | ||
|
|
||
| const TemplateDecl *TD = CT->getTemplateName().getAsTemplateDecl(); | ||
| if (!TD->getIdentifier()) | ||
| TD = T->getTemplateName().getAsTemplateDecl(); | ||
|
||
| TD->printQualifiedName(O); | ||
|
|
||
| O << "<"; | ||
| for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), | ||
| SE = SpecArgs.size(); | ||
| I < E; ++I) { | ||
| if (I != 0) | ||
| O << ", "; | ||
| // If we have a specialized argument, use it. Otherwise fallback to a | ||
| // default argument. | ||
| // We pass specialized arguments in case there are references to them | ||
| // from other types. | ||
| // FIXME: passing SpecArgs here is incorrect. It refers to template | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That FIXME does seem concerning.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, though the current version of this still addresses a chunk of the current issues we are seeing with the prototype generation. I can try to add some more disabled cases to the test so we know what to fix.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After having a look at this, I cannot seem to find a case that allows template arguments of differing depth. Maybe @AlexeySachkov knows of one? |
||
| // arguments of a single function argument, but DeclArgs contain | ||
| // references (in form of depth-index) to template arguments of the | ||
| // function itself which results in incorrect integration header being | ||
| // produced. | ||
| Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs); | ||
| } | ||
| O << ">"; | ||
| } | ||
|
|
||
| // Internal version of the function above that is used when template argument | ||
| // is a template by itself | ||
| void Visit(const TemplateSpecializationType *T, | ||
| ArrayRef<TemplateArgument> SpecArgs) { | ||
| const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl(); | ||
| const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD); | ||
| if (TTPD && !TTPD->getIdentifier()) | ||
| SpecArgs[TTPD->getIndex()].print(Policy, O, /* IncludeType = */ false); | ||
| else | ||
| TD->printQualifiedName(O); | ||
| O << "<"; | ||
| ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments(); | ||
| for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) { | ||
| if (I != 0) | ||
| O << ", "; | ||
| Base::Visit(DeclArgs[I], SpecArgs); | ||
| } | ||
| O << ">"; | ||
| } | ||
|
|
||
| void VisitNullTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("If template argument has not been deduced, then we can't " | ||
| "forward-declare it, something went wrong"); | ||
| } | ||
|
|
||
| void VisitTypeTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument> SpecArgs) { | ||
| // If we reference an existing template argument without a known identifier, | ||
| // print it instead. | ||
| const auto *TPT = dyn_cast<TemplateTypeParmType>(Arg.getAsType()); | ||
| if (TPT && !TPT->getIdentifier()) { | ||
| SpecArgs[TPT->getIndex()].print(Policy, O, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| const auto *TST = dyn_cast<TemplateSpecializationType>(Arg.getAsType()); | ||
| if (TST && Arg.isInstantiationDependent()) { | ||
| // This is an instantiation dependent template specialization, meaning | ||
| // that some of its arguments reference template arguments of the free | ||
| // function kernel itself. | ||
| Visit(TST, SpecArgs); | ||
| return; | ||
| } | ||
|
|
||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitDeclarationTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("Free function kernels cannot have non-type template " | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to diagnose this case prior integration header emission? Like in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should, but I would do it in a new
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! |
||
| "arguments which are pointers or references"); | ||
| } | ||
|
|
||
| void VisitNullPtrTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("Free function kernels cannot have non-type template " | ||
| "arguments which are pointers or references"); | ||
| } | ||
|
|
||
| void VisitIntegralTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitTemplateTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| // Likely does not work similar to the one above | ||
|
||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitExpressionTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be great to explain somewhere why all the methods additionally accept an ArrayRef that sometimes ends up unused.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where would be a good place to explain it? Top of the visitor definition?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I think that would be fine.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has been added! |
||
| Expr *E = Arg.getAsExpr(); | ||
| assert(E && "Failed to get an Expr for an Expression template arg?"); | ||
|
|
||
| if (Arg.isInstantiationDependent() || | ||
| E->getType().getTypePtr()->isScopedEnumeralType()) { | ||
steffenlarsen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // Scoped enumerations can't be implicitly cast from integers, so | ||
| // we don't need to evaluate them. | ||
| // If expression is instantiation-dependent, then we can't evaluate it | ||
| // either, let's fallback to default printing mechanism. | ||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| Expr::EvalResult Res; | ||
| [[maybe_unused]] bool Success = | ||
| Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); | ||
| assert(Success && "invalid non-type template argument?"); | ||
| assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); | ||
| Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context); | ||
| } | ||
|
|
||
| void VisitPackTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| Arg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
| }; | ||
|
|
||
| class FreeFunctionPrinter { | ||
| raw_ostream &O; | ||
| PrintingPolicy &Policy; | ||
|
|
@@ -6789,7 +6978,10 @@ class FreeFunctionPrinter { | |
| llvm::raw_svector_ostream ParmListOstream{ParamList}; | ||
| Policy.SuppressTagKeyword = true; | ||
|
|
||
| for (ParmVarDecl *Param : Parameters) { | ||
| FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy, | ||
| Context); | ||
|
|
||
| for (const ParmVarDecl *Param : Parameters) { | ||
| if (FirstParam) | ||
| FirstParam = false; | ||
| else | ||
|
|
@@ -6822,53 +7014,11 @@ class FreeFunctionPrinter { | |
| } | ||
|
|
||
| const TemplateSpecializationType *TSTAsNonAlias = | ||
| TST->getAsNonAliasTemplateSpecializationType(); | ||
| TST->getAsTemplateSpecializationTypeWithoutAliases(Context); | ||
|
||
| if (TSTAsNonAlias) | ||
| TST = TSTAsNonAlias; | ||
|
|
||
| TemplateName CTN = CTST->getTemplateName(); | ||
| CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); | ||
| ParmListOstream << "<"; | ||
|
|
||
| ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments(); | ||
| ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments(); | ||
|
|
||
| auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { | ||
| if (Arg.getKind() != TemplateArgument::ArgKind::Expression || | ||
| Arg.isInstantiationDependent()) { | ||
| Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| Expr *E = Arg.getAsExpr(); | ||
| assert(E && "Failed to get an Expr for an Expression template arg?"); | ||
| if (E->getType().getTypePtr()->isScopedEnumeralType()) { | ||
| // Scoped enumerations can't be implicitly cast from integers, so | ||
| // we don't need to evaluate them. | ||
| Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| Expr::EvalResult Res; | ||
| [[maybe_unused]] bool Success = | ||
| Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); | ||
| assert(Success && "invalid non-type template argument?"); | ||
| assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); | ||
| Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), | ||
| &Context); | ||
| }; | ||
|
|
||
| for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), | ||
| SE = SpecArgs.size(); | ||
| I < E; ++I) { | ||
| if (I != 0) | ||
| ParmListOstream << ", "; | ||
| // If we have a specialized argument, use it. Otherwise fallback to a | ||
| // default argument. | ||
| TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); | ||
| } | ||
|
|
||
| ParmListOstream << ">"; | ||
| Printer.Visit(TST, CTST); | ||
| } | ||
| return ParamList.str().str(); | ||
| } | ||
|
|
@@ -6886,26 +7036,39 @@ class FreeFunctionPrinter { | |
| std::string getTemplateParameters(const clang::TemplateParameterList *TPL) { | ||
| std::string TemplateParams{"template <"}; | ||
| bool FirstParam{true}; | ||
| for (NamedDecl *Param : *TPL) { | ||
| for (const NamedDecl *Param : *TPL) { | ||
| if (!FirstParam) | ||
| TemplateParams += ", "; | ||
| FirstParam = false; | ||
| if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) { | ||
| TemplateParams += | ||
| TemplateParam->wasDeclaredWithTypename() ? "typename " : "class "; | ||
| if (TemplateParam->isParameterPack()) | ||
| TemplateParams += "... "; | ||
| TemplateParams += TemplateParam->getNameAsString(); | ||
| } else if (const auto *NonTypeParam = | ||
| dyn_cast<NonTypeTemplateParmDecl>(Param)) { | ||
| TemplateParams += NonTypeParam->getType().getAsString(); | ||
| TemplateParams += " "; | ||
| TemplateParams += NonTypeParam->getNameAsString(); | ||
| } | ||
| TemplateParams += getTemplateParameter(Param); | ||
| } | ||
| TemplateParams += "> "; | ||
| return TemplateParams; | ||
| } | ||
|
|
||
| /// Helper method to get text representation of a template parameter. | ||
| /// \param Param The template parameter. | ||
| std::string getTemplateParameter(const NamedDecl *Param) { | ||
| auto GetTypenameOrClass = [](const auto *Param) { | ||
| return Param->wasDeclaredWithTypename() ? "typename " : "class "; | ||
| }; | ||
| if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) { | ||
| std::string TemplateParamStr = GetTypenameOrClass(TemplateParam); | ||
| if (TemplateParam->isParameterPack()) | ||
| TemplateParamStr += "... "; | ||
| TemplateParamStr += TemplateParam->getNameAsString(); | ||
| return TemplateParamStr; | ||
| } else if (const auto *NonTypeParam = | ||
| dyn_cast<NonTypeTemplateParmDecl>(Param)) { | ||
| return NonTypeParam->getType().getAsString() + " " + | ||
| NonTypeParam->getNameAsString(); | ||
| } else if (const auto *TTParam = | ||
| dyn_cast<TemplateTemplateParmDecl>(Param)) { | ||
| return getTemplateParameters(TTParam->getTemplateParameters()) + " " + | ||
| GetTypenameOrClass(TTParam) + TTParam->getNameAsString(); | ||
| } | ||
| return ""; | ||
| } | ||
| }; | ||
|
|
||
| void SYCLIntegrationHeader::emit(raw_ostream &O) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.