diff --git a/clang-tools-extra/clangd/InlayHints.cpp b/clang-tools-extra/clangd/InlayHints.cpp index 20a238612a7e4..197c62c40dcf0 100644 --- a/clang-tools-extra/clangd/InlayHints.cpp +++ b/clang-tools-extra/clangd/InlayHints.cpp @@ -33,7 +33,6 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -#include "llvm/ADT/identity.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" @@ -339,53 +338,6 @@ QualType maybeDesugar(ASTContext &AST, QualType QT) { return QT; } -// Given a callee expression `Fn`, if the call is through a function pointer, -// try to find the declaration of the corresponding function pointer type, -// so that we can recover argument names from it. -// FIXME: This function is mostly duplicated in SemaCodeComplete.cpp; unify. -static FunctionProtoTypeLoc getPrototypeLoc(Expr *Fn) { - TypeLoc Target; - Expr *NakedFn = Fn->IgnoreParenCasts(); - if (const auto *T = NakedFn->getType().getTypePtr()->getAs()) { - Target = T->getDecl()->getTypeSourceInfo()->getTypeLoc(); - } else if (const auto *DR = dyn_cast(NakedFn)) { - const auto *D = DR->getDecl(); - if (const auto *const VD = dyn_cast(D)) { - Target = VD->getTypeSourceInfo()->getTypeLoc(); - } - } - - if (!Target) - return {}; - - // Unwrap types that may be wrapping the function type - while (true) { - if (auto P = Target.getAs()) { - Target = P.getPointeeLoc(); - continue; - } - if (auto A = Target.getAs()) { - Target = A.getModifiedLoc(); - continue; - } - if (auto P = Target.getAs()) { - Target = P.getInnerLoc(); - continue; - } - break; - } - - if (auto F = Target.getAs()) { - // In some edge cases the AST can contain a "trivial" FunctionProtoTypeLoc - // which has null parameters. Avoid these as they don't contain useful - // information. - if (llvm::all_of(F.getParams(), llvm::identity())) - return F; - } - - return {}; -} - ArrayRef maybeDropCxxExplicitObjectParameters(ArrayRef Params) { if (!Params.empty() && Params.front()->isExplicitObjectParameter()) @@ -514,7 +466,8 @@ class InlayHintVisitor : public RecursiveASTVisitor { Callee.Decl = FD; else if (const auto *FTD = dyn_cast(CalleeDecls[0])) Callee.Decl = FTD->getTemplatedDecl(); - else if (FunctionProtoTypeLoc Loc = getPrototypeLoc(E->getCallee())) + else if (FunctionProtoTypeLoc Loc = + Resolver->getFunctionProtoTypeLoc(E->getCallee())) Callee.Loc = Loc; else return true; diff --git a/clang/include/clang/Sema/HeuristicResolver.h b/clang/include/clang/Sema/HeuristicResolver.h index df60d3359c6a6..e193c0bc14cd9 100644 --- a/clang/include/clang/Sema/HeuristicResolver.h +++ b/clang/include/clang/Sema/HeuristicResolver.h @@ -20,6 +20,7 @@ class CXXBasePath; class CXXDependentScopeMemberExpr; class DeclarationName; class DependentScopeDeclRefExpr; +class FunctionProtoTypeLoc; class NamedDecl; class Type; class UnresolvedUsingValueDecl; @@ -93,6 +94,12 @@ class HeuristicResolver { // during simplification, and the operation fails if no pointer type is found. QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer); + // Given an expression `Fn` representing the callee in a function call, + // if the call is through a function pointer, try to find the declaration of + // the corresponding function pointer type, so that we can recover argument + // names from it. + FunctionProtoTypeLoc getFunctionProtoTypeLoc(const Expr *Fn) const; + private: ASTContext &Ctx; }; diff --git a/clang/lib/Sema/HeuristicResolver.cpp b/clang/lib/Sema/HeuristicResolver.cpp index 0c67f1f2a3878..6874d30516f8f 100644 --- a/clang/lib/Sema/HeuristicResolver.cpp +++ b/clang/lib/Sema/HeuristicResolver.cpp @@ -13,6 +13,7 @@ #include "clang/AST/ExprCXX.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" +#include "llvm/ADT/identity.h" namespace clang { @@ -50,6 +51,7 @@ class HeuristicResolverImpl { llvm::function_ref Filter); TagDecl *resolveTypeToTagDecl(QualType T); QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer); + FunctionProtoTypeLoc getFunctionProtoTypeLoc(const Expr *Fn); private: ASTContext &Ctx; @@ -506,6 +508,56 @@ std::vector HeuristicResolverImpl::resolveDependentMember( } return {}; } + +FunctionProtoTypeLoc +HeuristicResolverImpl::getFunctionProtoTypeLoc(const Expr *Fn) { + TypeLoc Target; + const Expr *NakedFn = Fn->IgnoreParenCasts(); + if (const auto *T = NakedFn->getType().getTypePtr()->getAs()) { + Target = T->getDecl()->getTypeSourceInfo()->getTypeLoc(); + } else if (const auto *DR = dyn_cast(NakedFn)) { + const auto *D = DR->getDecl(); + if (const auto *const VD = dyn_cast(D)) { + Target = VD->getTypeSourceInfo()->getTypeLoc(); + } + } else if (const auto *ME = dyn_cast(NakedFn)) { + const auto *MD = ME->getMemberDecl(); + if (const auto *FD = dyn_cast(MD)) { + Target = FD->getTypeSourceInfo()->getTypeLoc(); + } + } + + if (!Target) + return {}; + + // Unwrap types that may be wrapping the function type + while (true) { + if (auto P = Target.getAs()) { + Target = P.getPointeeLoc(); + continue; + } + if (auto A = Target.getAs()) { + Target = A.getModifiedLoc(); + continue; + } + if (auto P = Target.getAs()) { + Target = P.getInnerLoc(); + continue; + } + break; + } + + if (auto F = Target.getAs()) { + // In some edge cases the AST can contain a "trivial" FunctionProtoTypeLoc + // which has null parameters. Avoid these as they don't contain useful + // information. + if (llvm::all_of(F.getParams(), llvm::identity())) + return F; + } + + return {}; +} + } // namespace std::vector HeuristicResolver::resolveMemberExpr( @@ -557,4 +609,9 @@ QualType HeuristicResolver::simplifyType(QualType Type, const Expr *E, return HeuristicResolverImpl(Ctx).simplifyType(Type, E, UnwrapPointer); } +FunctionProtoTypeLoc +HeuristicResolver::getFunctionProtoTypeLoc(const Expr *Fn) const { + return HeuristicResolverImpl(Ctx).getFunctionProtoTypeLoc(Fn); +} + } // namespace clang diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp index f9f7c192f19d2..84739d4cee802 100644 --- a/clang/lib/Sema/SemaCodeComplete.cpp +++ b/clang/lib/Sema/SemaCodeComplete.cpp @@ -6283,54 +6283,6 @@ ProduceSignatureHelp(Sema &SemaRef, MutableArrayRef Candidates, return getParamType(SemaRef, Candidates, CurrentArg); } -// Given a callee expression `Fn`, if the call is through a function pointer, -// try to find the declaration of the corresponding function pointer type, -// so that we can recover argument names from it. -static FunctionProtoTypeLoc GetPrototypeLoc(Expr *Fn) { - TypeLoc Target; - - if (const auto *T = Fn->getType().getTypePtr()->getAs()) { - Target = T->getDecl()->getTypeSourceInfo()->getTypeLoc(); - - } else if (const auto *DR = dyn_cast(Fn)) { - const auto *D = DR->getDecl(); - if (const auto *const VD = dyn_cast(D)) { - Target = VD->getTypeSourceInfo()->getTypeLoc(); - } - } else if (const auto *ME = dyn_cast(Fn)) { - const auto *MD = ME->getMemberDecl(); - if (const auto *FD = dyn_cast(MD)) { - Target = FD->getTypeSourceInfo()->getTypeLoc(); - } - } - - if (!Target) - return {}; - - // Unwrap types that may be wrapping the function type - while (true) { - if (auto P = Target.getAs()) { - Target = P.getPointeeLoc(); - continue; - } - if (auto A = Target.getAs()) { - Target = A.getModifiedLoc(); - continue; - } - if (auto P = Target.getAs()) { - Target = P.getInnerLoc(); - continue; - } - break; - } - - if (auto F = Target.getAs()) { - return F; - } - - return {}; -} - QualType SemaCodeCompletion::ProduceCallSignatureHelp(Expr *Fn, ArrayRef Args, SourceLocation OpenParLoc) { @@ -6419,7 +6371,7 @@ SemaCodeCompletion::ProduceCallSignatureHelp(Expr *Fn, ArrayRef Args, // Lastly we check whether expression's type is function pointer or // function. - FunctionProtoTypeLoc P = GetPrototypeLoc(NakedFn); + FunctionProtoTypeLoc P = Resolver.getFunctionProtoTypeLoc(NakedFn); QualType T = NakedFn->getType(); if (!T->getPointeeType().isNull()) T = T->getPointeeType(); diff --git a/clang/unittests/Sema/HeuristicResolverTest.cpp b/clang/unittests/Sema/HeuristicResolverTest.cpp index 3ed6bba790be3..ee434f7a1d43a 100644 --- a/clang/unittests/Sema/HeuristicResolverTest.cpp +++ b/clang/unittests/Sema/HeuristicResolverTest.cpp @@ -766,5 +766,85 @@ TEST(HeuristicResolver, UsingValueDecl) { cxxMethodDecl(hasName("waldo")).bind("output")); } +// `arg` is a ParamVarDecl*, `Expected` is a string +MATCHER_P(ParamNameMatcher, Expected, "paramNameMatcher") { + EXPECT_TRUE(arg); + if (IdentifierInfo *Ident = arg->getDeclName().getAsIdentifierInfo()) { + return Ident->getName() == Expected; + } + return false; +} + +// Helper function for testing HeuristicResolver::getProtoTypeLoc. +// Takes a matcher that selects a callee expression bound to the ID "input", +// calls getProtoTypeLoc() on it, and checks that the call found a +// FunctionProtoTypeLoc encoding the given parameter names. +template +void expectParameterNames(ASTContext &Ctx, const InputMatcher &IM, + ParameterNames... ExpectedParameterNames) { + auto InputMatches = match(IM, Ctx); + ASSERT_EQ(1u, InputMatches.size()); + const auto *Input = InputMatches[0].template getNodeAs("input"); + ASSERT_TRUE(Input); + + HeuristicResolver H(Ctx); + auto Loc = H.getFunctionProtoTypeLoc(Input); + ASSERT_TRUE(Loc); + EXPECT_THAT(Loc.getParams(), + ElementsAre(ParamNameMatcher(ExpectedParameterNames)...)); +} + +TEST(HeuristicResolver, ProtoTypeLoc) { + std::string Code = R"cpp( + void (*f1)(int param1); + void (__stdcall *f2)(int param2); + using f3_t = void(*)(int param3); + f3_t f3; + using f4_t = void(__stdcall *)(int param4); + f4_t f4; + struct S { + void (*f5)(int param5); + using f6_t = void(*)(int param6); + f6_t f6; + }; + void bar() { + f1(42); + f2(42); + f3(42); + f4(42); + S s; + s.f5(42); + s.f6(42); + } + )cpp"; + auto TU = tooling::buildASTFromCodeWithArgs(Code, {"-std=c++20"}); + auto &Ctx = TU->getASTContext(); + auto checkFreeFunction = [&](llvm::StringRef FunctionName, + llvm::StringRef ParamName) { + expectParameterNames( + Ctx, + callExpr( + callee(implicitCastExpr(hasSourceExpression(declRefExpr( + to(namedDecl(hasName(FunctionName)))))) + .bind("input"))), + ParamName); + }; + checkFreeFunction("f1", "param1"); + checkFreeFunction("f2", "param2"); + checkFreeFunction("f3", "param3"); + checkFreeFunction("f4", "param4"); + auto checkMemberFunction = [&](llvm::StringRef MemberName, + llvm::StringRef ParamName) { + expectParameterNames( + Ctx, + callExpr(callee(implicitCastExpr(hasSourceExpression(memberExpr( + member(hasName(MemberName))))) + .bind("input"))), + ParamName); + }; + checkMemberFunction("f5", "param5"); + checkMemberFunction("f6", "param6"); +} + } // namespace } // namespace clang