Skip to content
3 changes: 3 additions & 0 deletions clang/include/clang/AST/TypeBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -2956,6 +2956,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
return TST;
}

const TemplateSpecializationType *
getAsTemplateSpecializationTypeWithoutAliases(const ASTContext &Ctx) const;

/// Member-template getAsAdjusted<specific type>. Look through specific kinds
/// of sugar (parens, attributes, etc) for an instance of \<specific type>.
/// This is used when you need to walk over sugar nodes that represent some
Expand Down
35 changes: 35 additions & 0 deletions clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,41 @@ Type::getAsNonAliasTemplateSpecializationType() const {
return TST;
}

const TemplateSpecializationType *
Type::getAsTemplateSpecializationTypeWithoutAliases(
const ASTContext &Ctx) const {
const TemplateSpecializationType *TST =
getAsNonAliasTemplateSpecializationType();
if (!TST)
return TST;

// Ensure the template arguments of the template specialization type are
// without aliases.
SmallVector<TemplateArgument, 4> ArgsWithoutAliases;
ArgsWithoutAliases.reserve(TST->template_arguments().size());
for (const TemplateArgument &TA : TST->template_arguments()) {
if (TA.getKind() == TemplateArgument::ArgKind::Type) {
QualType TAQTy = TA.getAsType();
const Type *TATy = TAQTy->getUnqualifiedDesugaredType();
if (isa<TemplateSpecializationType>(TATy))
TATy = TATy->getAsTemplateSpecializationTypeWithoutAliases(Ctx);
ArgsWithoutAliases.emplace_back(QualType(TATy, TAQTy.getCVRQualifiers()));
} else if (TA.getKind() == TemplateArgument::ArgKind::Template) {
TemplateName TN = TA.getAsTemplate();
while (std::optional<TemplateName> DesugaredTN =
TN.desugar(/*IgnoreDeduced=*/false))
TN = *DesugaredTN;
ArgsWithoutAliases.emplace_back(TN);
} else {
ArgsWithoutAliases.push_back(TA);
}
}
return Ctx
.getTemplateSpecializationType(TST->getKeyword(), TST->getTemplateName(),
ArgsWithoutAliases, {}, QualType{})
->getAs<TemplateSpecializationType>();
}

NestedNameSpecifier Type::getPrefix() const {
switch (getTypeClass()) {
case Type::DependentName:
Expand Down
279 changes: 221 additions & 58 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6626,6 +6626,195 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
[](raw_ostream &, const NamespaceDecl *) {}, OS, DC);
}

/// Dedicated visitor which helps with printing of kernel arguments in forward
/// declarations of free function kernels which are declared as function
/// templates.
///
/// Based on:
/// \code
/// template <typename T1, typename T2>
/// void foo(T1 a, int b, T2 c);
/// \endcode
///
/// It prints into the output stream "T1, int, T2".
///
/// The main complexity (which motivates addition of such visitor) comes from
/// the fact that there could be type aliases and default template arguments.
/// For example:
/// \code
/// template<typename T>
/// void kernel(sycl::accessor<T, 1>);
/// template void kernel(sycl::accessor<int, 1>);
/// \endcode
/// sycl::accessor has many template arguments which have default values. If
/// we iterate over non-canonicalized argument type, we don't get those default
/// values and we don't get necessary namespace qualifiers for all the template
/// arguments. If we iterate over canonicalized argument type, then all
/// references to T will be replaced with something like type-argument-X-Y.
/// What this visitor does is it iterates over both in sync, picking the right
/// values from one or another.
///
/// Moral of the story: drop integration header ASAP (but that is blocked
/// by support for 3rd-party host compilers, which is important).
class FreeFunctionTemplateKernelArgsPrinter
: public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter,
void, ArrayRef<TemplateArgument>> {
raw_ostream &O;
PrintingPolicy &Policy;
ASTContext &Context;

using Base =
ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void,
ArrayRef<TemplateArgument>>;

void PrintTemplateDeclName(const TemplateDecl *TD,
ArrayRef<TemplateArgument> SpecArgs) {}

public:
FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy,
ASTContext &Context)
: O(O), Policy(Policy), Context(Context) {}

void Visit(const TemplateSpecializationType *T,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please give T and CT meaningful names or add a comment explaining what is what?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved the logic for extracting the types into this function, so I think the intention is a little clearer. That is, this function now takes the parameter declaration and extract the type and canonical type from it.

const TemplateSpecializationType *CT) {
ArrayRef<TemplateArgument> SpecArgs = T->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CT->template_arguments();

const TemplateDecl *TD = CT->getTemplateName().getAsTemplateDecl();
if (!TD->getIdentifier())
TD = T->getTemplateName().getAsTemplateDecl();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is T guaranteed to have an identifier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of a case where it doesn't, but that doesn't mean it's impossible. I can introduce an unreachable or similar to error out instead of generating an invalid identifier if that is better?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an assert might make sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

TD->printQualifiedName(O);

O << "<";
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
O << ", ";
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
// We pass specialized arguments in case there are references to them
// from other types.
// FIXME: passing SpecArgs here is incorrect. It refers to template
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That FIXME does seem concerning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, though the current version of this still addresses a chunk of the current issues we are seeing with the prototype generation. I can try to add some more disabled cases to the test so we know what to fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After having a look at this, I cannot seem to find a case that allows template arguments of differing depth. Maybe @AlexeySachkov knows of one?

// arguments of a single function argument, but DeclArgs contain
// references (in form of depth-index) to template arguments of the
// function itself which results in incorrect integration header being
// produced.
Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs);
}
O << ">";
}

// Internal version of the function above that is used when template argument
// is a template by itself
void Visit(const TemplateSpecializationType *T,
ArrayRef<TemplateArgument> SpecArgs) {
const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl();
const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD);
if (TTPD && !TTPD->getIdentifier())
SpecArgs[TTPD->getIndex()].print(Policy, O, /* IncludeType = */ false);
else
TD->printQualifiedName(O);
O << "<";
ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments();
for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) {
if (I != 0)
O << ", ";
Base::Visit(DeclArgs[I], SpecArgs);
}
O << ">";
}

void VisitNullTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("If template argument has not been deduced, then we can't "
"forward-declare it, something went wrong");
}

void VisitTypeTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument> SpecArgs) {
// If we reference an existing template argument without a known identifier,
// print it instead.
const auto *TPT = dyn_cast<TemplateTypeParmType>(Arg.getAsType());
if (TPT && !TPT->getIdentifier()) {
SpecArgs[TPT->getIndex()].print(Policy, O, /* IncludeType = */ false);
return;
}

const auto *TST = dyn_cast<TemplateSpecializationType>(Arg.getAsType());
if (TST && Arg.isInstantiationDependent()) {
// This is an instantiation dependent template specialization, meaning
// that some of its arguments reference template arguments of the free
// function kernel itself.
Visit(TST, SpecArgs);
return;
}

Arg.print(Policy, O, /* IncludeType = */ false);
}

void VisitDeclarationTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("Free function kernels cannot have non-type template "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to diagnose this case prior integration header emission? Like in SyclKernelFieldChecker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should, but I would do it in a new SyclFreeFunctionKernelChecker (unless I missed another place that would be equally as fitting.) Would it be alright as a follow-up?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

"arguments which are pointers or references");
}

void VisitNullPtrTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("Free function kernels cannot have non-type template "
"arguments which are pointers or references");
}

void VisitIntegralTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Arg.print(Policy, O, /* IncludeType = */ false);
}

void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Arg.print(Policy, O, /* IncludeType = */ false);
}

void VisitTemplateTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Arg.print(Policy, O, /* IncludeType = */ false);
}

void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
// Likely does not work similar to the one above
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that a FIXME?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test case for packs and it seems to work fine, so I've removed the comment.

Arg.print(Policy, O, /* IncludeType = */ false);
}

void VisitExpressionTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to explain somewhere why all the methods additionally accept an ArrayRef that sometimes ends up unused.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where would be a good place to explain it? Top of the visitor definition?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I think that would be fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been added!

Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");

if (Arg.isInstantiationDependent() ||
E->getType().getTypePtr()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
// If expression is instantiation-dependent, then we can't evaluate it
// either, let's fallback to default printing mechanism.
Arg.print(Policy, O, /* IncludeType = */ false);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context);
}

void VisitPackTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Arg.print(Policy, O, /* IncludeType = */ false);
}
};

class FreeFunctionPrinter {
raw_ostream &O;
PrintingPolicy &Policy;
Expand Down Expand Up @@ -6789,7 +6978,10 @@ class FreeFunctionPrinter {
llvm::raw_svector_ostream ParmListOstream{ParamList};
Policy.SuppressTagKeyword = true;

for (ParmVarDecl *Param : Parameters) {
FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy,
Context);

for (const ParmVarDecl *Param : Parameters) {
if (FirstParam)
FirstParam = false;
else
Expand Down Expand Up @@ -6822,53 +7014,11 @@ class FreeFunctionPrinter {
}

const TemplateSpecializationType *TSTAsNonAlias =
TST->getAsNonAliasTemplateSpecializationType();
TST->getAsTemplateSpecializationTypeWithoutAliases(Context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I was kind of expecting that the visitor soluiton would help us to avoid unwrapping and modifying types manually like getAsTemplateSpecializationTypeWithoutAliases does. Is that not the case? Could you please elaborate why do we need such a complex solution?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch was the marriage between a patch @AlexeySachkov made and one that I prepared. The solution is still needed to avoid aliases, but it has now been integrated into the visitor, which I indeed think is the better solution. Good thinking!

if (TSTAsNonAlias)
TST = TSTAsNonAlias;

TemplateName CTN = CTST->getTemplateName();
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
ParmListOstream << "<";

ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();

auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
Arg.isInstantiationDependent()) {
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
&Context);
};

for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
ParmListOstream << ", ";
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
}

ParmListOstream << ">";
Printer.Visit(TST, CTST);
}
return ParamList.str().str();
}
Expand All @@ -6886,26 +7036,39 @@ class FreeFunctionPrinter {
std::string getTemplateParameters(const clang::TemplateParameterList *TPL) {
std::string TemplateParams{"template <"};
bool FirstParam{true};
for (NamedDecl *Param : *TPL) {
for (const NamedDecl *Param : *TPL) {
if (!FirstParam)
TemplateParams += ", ";
FirstParam = false;
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
TemplateParams +=
TemplateParam->wasDeclaredWithTypename() ? "typename " : "class ";
if (TemplateParam->isParameterPack())
TemplateParams += "... ";
TemplateParams += TemplateParam->getNameAsString();
} else if (const auto *NonTypeParam =
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
TemplateParams += NonTypeParam->getType().getAsString();
TemplateParams += " ";
TemplateParams += NonTypeParam->getNameAsString();
}
TemplateParams += getTemplateParameter(Param);
}
TemplateParams += "> ";
return TemplateParams;
}

/// Helper method to get text representation of a template parameter.
/// \param Param The template parameter.
std::string getTemplateParameter(const NamedDecl *Param) {
auto GetTypenameOrClass = [](const auto *Param) {
return Param->wasDeclaredWithTypename() ? "typename " : "class ";
};
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
std::string TemplateParamStr = GetTypenameOrClass(TemplateParam);
if (TemplateParam->isParameterPack())
TemplateParamStr += "... ";
TemplateParamStr += TemplateParam->getNameAsString();
return TemplateParamStr;
} else if (const auto *NonTypeParam =
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
return NonTypeParam->getType().getAsString() + " " +
NonTypeParam->getNameAsString();
} else if (const auto *TTParam =
dyn_cast<TemplateTemplateParmDecl>(Param)) {
return getTemplateParameters(TTParam->getTemplateParameters()) + " " +
GetTypenameOrClass(TTParam) + TTParam->getNameAsString();
}
return "";
}
};

void SYCLIntegrationHeader::emit(raw_ostream &O) {
Expand Down
Loading
Loading