11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/RecursiveASTVisitor.h"
13 #include "clang/ASTMatchers/ASTMatchFinder.h"
14 #include "clang/Lex/Preprocessor.h"
15 #include "clang/Tooling/FixIt.h"
16 #include "llvm/ADT/StringExtras.h"
// NOTE(review): this file is a lossy listing - the leading numbers are the
// original line numbers and several lines (including closing braces) are
// missing. The comments below describe only what the visible lines show.
//
// RecursiveASTVisitor bound to a FunctionDecl F. It checks unqualified names
// it meets (record, enum and template-specialization type names, and
// unqualified DeclRefExpr names) against the identifiers of F's parameters.
26 struct UnqualNameVisitor :
public RecursiveASTVisitor<UnqualNameVisitor> {
// Stores F by reference (see the member below), so the visitor must not
// outlive the FunctionDecl it is given.
28 UnqualNameVisitor(
const FunctionDecl &F) : F(F) {}
// Returning false tells RecursiveASTVisitor not to additionally walk the
// Type behind each TypeLoc; this visitor works on TypeLocs through the
// Traverse*TypeLoc overrides below.
32 bool shouldWalkTypesOfTypeLocs()
const {
return false; }
// Compares UnqualName with the name of every named parameter of F. The
// action taken on a match is on lines missing from this listing
// (presumably it records a collision and returns false - TODO confirm).
34 bool VisitUnqualName(StringRef UnqualName) {
36 for (ParmVarDecl *Param : F.parameters())
37 if (
const IdentifierInfo *Ident = Param->getIdentifier())
38 if (Ident->getName() == UnqualName) {
// Checks record, enum and template-specialization type names with
// VisitUnqualName, then falls back to the base-class traversal.
// Elaborated is passed as true by TraverseElaboratedTypeLoc; how it is
// used is on lines missing from this listing.
45 bool TraverseTypeLoc(TypeLoc TL,
bool Elaborated =
false) {
50 switch (TL.getTypeLocClass()) {
53 TL.getAs<RecordTypeLoc>().getTypePtr()->getDecl()->getName()))
58 TL.getAs<EnumTypeLoc>().getTypePtr()->getDecl()->getName()))
61 case TypeLoc::TemplateSpecialization:
62 if (VisitUnqualName(TL.getAs<TemplateSpecializationTypeLoc>()
74 return RecursiveASTVisitor<UnqualNameVisitor>::TraverseTypeLoc(TL);
// Skips the qualifiers and traverses only the unqualified TypeLoc.
79 bool TraverseQualifiedTypeLoc(QualifiedTypeLoc TL) {
80 return TraverseTypeLoc(TL.getUnqualifiedLoc());
// Traverses the nested-name-specifier first (if there is one), then the
// named type with Elaborated = true.
85 bool TraverseElaboratedTypeLoc(ElaboratedTypeLoc TL) {
86 if (TL.getQualifierLoc() &&
87 !TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()))
89 return TraverseTypeLoc(TL.getNamedTypeLoc(),
true);
// Qualified references and non-identifier names (e.g. operator names)
// return true, so traversal continues. For an unqualified identifier the
// result is the negation of VisitUnqualName.
92 bool VisitDeclRefExpr(DeclRefExpr *S) {
93 DeclarationName
Name = S->getNameInfo().getName();
94 return S->getQualifierLoc() || !
Name.isIdentifier() ||
95 !VisitUnqualName(
Name.getAsIdentifierInfo()->getName());
// The function whose parameter names are checked; held by reference.
99 const FunctionDecl &F;
// Diagnostic text emitted by this check. The declaration line of this
// initializer is missing from this listing.
104 "use a trailing return type for this function";
// Fragment of a helper that takes a SourceManager. Its signature starts and
// its body continues on lines missing from this listing. The visible part
// asserts that Loc is no longer a macro ID; per the assert message, macro
// locations are expected to have been expanded recursively before this point.
107 const SourceManager &SM) {
110 assert(!
Loc.isMacroID() &&
111 "SourceLocation must not be a macro ID after recursive expansion");
// Computes where a trailing return type can be inserted for F. Parts of this
// body are missing from this listing.
//  - If F has an exception specification, the result is the location just
//    past the end of that specification.
//  - Otherwise Result starts just past the closing ')' of FTL. The source is
//    then raw-lexed forward, and Result is moved past each token such as &,
//    &&, const or volatile. The full token list, the loop exit and the final
//    return are not shown (presumably Result is returned).
// A closing ')' that comes from a macro expansion takes a branch whose body
// is not shown here.
115 SourceLocation UseTrailingReturnTypeCheck::findTrailingReturnTypeSourceLocation(
116 const FunctionDecl &F,
const FunctionTypeLoc &FTL,
const ASTContext &
Ctx,
117 const SourceManager &SM,
const LangOptions &LangOpts) {
119 SourceRange ExceptionSpecRange = F.getExceptionSpecSourceRange();
120 if (ExceptionSpecRange.isValid())
121 return Lexer::getLocForEndOfToken(ExceptionSpecRange.getEnd(), 0, SM,
126 SourceLocation ClosingParen = FTL.getRParenLoc();
127 if (ClosingParen.isMacroID())
130 SourceLocation Result =
131 Lexer::getLocForEndOfToken(ClosingParen, 0, SM, LangOpts);
// Set up a raw lexer over the file buffer, starting at Result.
134 std::pair<FileID, unsigned>
Loc = SM.getDecomposedLoc(Result);
135 StringRef File = SM.getBufferData(
Loc.first);
136 const char *TokenBegin = File.data() +
Loc.second;
137 Lexer Lexer(SM.getLocForStartOfFile(
Loc.first), LangOpts, File.begin(),
138 TokenBegin, File.end());
140 while (!Lexer.LexFromRawLexer(T)) {
// Raw lexing yields raw_identifier tokens. Resolve them through the
// identifier table so that keywords such as const and volatile get their
// keyword token kind before the isOneOf check below.
141 if (T.is(tok::raw_identifier)) {
142 IdentifierInfo &
Info =
Ctx.Idents.get(
143 StringRef(SM.getCharacterData(T.getLocation()), T.getLength()));
144 T.setIdentifierInfo(&Info);
145 T.setKind(
Info.getTokenID());
148 if (T.isOneOf(tok::amp, tok::ampamp, tok::kw_const, tok::kw_volatile,
150 Result = T.getEndLoc();
// Fragment (signature missing from this listing), presumably the body of
// IsCVR: true when T is one of the keywords const, volatile or restrict.
159 return T.isOneOf(tok::kw_const, tok::kw_volatile, tok::kw_restrict);
// Fragment of a second predicate whose signature is missing from this
// listing: true when T is one of the keywords constexpr, inline, extern,
// static, friend or virtual.
163 return T.isOneOf(tok::kw_constexpr, tok::kw_inline, tok::kw_extern,
164 tok::kw_static, tok::kw_friend, tok::kw_virtual);
// Classifies one token. Its name and parameters are on lines missing from
// this listing; the token is presumably Tok, and PP is the preprocessor.
// The token is pushed into the preprocessor followed by an eof token. Each
// token that comes out is then recorded as a qualifier (IsCVR), a specifier
// (Spec, computed on a missing line) or something else.
167 static llvm::Optional<ClassifiedToken>
173 bool ContainsQualifiers =
false;
174 bool ContainsSpecifiers =
false;
175 bool ContainsSomethingElse =
false;
// End marks the end of the injected token stream.
179 End.setKind(tok::eof);
180 SmallVector<Token, 2> Stream{Tok, End};
183 PP.EnterTokenStream(Stream,
false,
false);
190 bool Qual =
IsCVR(T);
194 ContainsQualifiers |= Qual;
195 ContainsSpecifiers |= Spec;
196 ContainsSomethingElse |= !Qual && !Spec;
// More than one category set means the expansion mixes kinds of tokens.
// The handling of that case is on lines missing from this listing
// (presumably no classification is returned - TODO confirm).
201 if (ContainsQualifiers + ContainsSpecifiers + ContainsSomethingElse > 1)
// Raw-lexes the source text from BeginF up to, but not including, BeginNameF
// and collects a classification for every token. BeginF, BeginNameF, T and
// CT are set up on lines missing from this listing. They are presumably the
// start of F's declaration, the start of its name, the current token and its
// classification.
// F is diagnosed with Message (no fix-it) in two cases: when an identifier
// names a macro that is function-like or has no MacroInfo, and on the path
// where a token's classification is not pushed. Presumably no token list is
// returned in those cases (TODO confirm).
207 llvm::Optional<SmallVector<ClassifiedToken, 8>>
208 UseTrailingReturnTypeCheck::classifyTokensBeforeFunctionName(
209 const FunctionDecl &F,
const ASTContext &
Ctx,
const SourceManager &SM,
210 const LangOptions &LangOpts) {
215 std::pair<FileID, unsigned>
Loc = SM.getDecomposedLoc(BeginF);
216 StringRef File = SM.getBufferData(
Loc.first);
217 const char *TokenBegin = File.data() +
Loc.second;
218 Lexer Lexer(SM.getLocForStartOfFile(
Loc.first), LangOpts, File.begin(),
219 TokenBegin, File.end());
221 SmallVector<ClassifiedToken, 8> ClassifiedTokens;
// Stop at the first token that is not before the function name.
222 while (!Lexer.LexFromRawLexer(T) &&
223 SM.isBeforeInTranslationUnit(T.getLocation(), BeginNameF)) {
224 if (T.is(tok::raw_identifier)) {
225 IdentifierInfo &
Info =
Ctx.Idents.get(
226 StringRef(SM.getCharacterData(T.getLocation()), T.getLength()));
// Diagnose F when the identifier is a function-like macro or one for which
// the preprocessor (PP) has no MacroInfo.
228 if (
Info.hasMacroDefinition()) {
229 const MacroInfo *MI =
PP->getMacroInfo(&Info);
230 if (!MI || MI->isFunctionLike()) {
232 diag(F.getLocation(),
Message);
237 T.setIdentifierInfo(&Info);
238 T.setKind(
Info.getTokenID());
242 ClassifiedTokens.push_back(*CT);
244 diag(F.getLocation(),
Message);
249 return ClassifiedTokens;
// Fragment; the signature and the combining statements are missing from this
// listing. Result starts as whether Type itself has local qualifiers. For
// pointer and reference types, the pointee type is also examined (presumably
// recursively, OR-ed into Result - TODO confirm).
253 bool Result =
Type.hasLocalQualifiers();
254 if (
Type->isPointerType())
256 Type->castAs<PointerType>()->getPointeeType());
257 if (
Type->isReferenceType())
259 Type->castAs<ReferenceType>()->getPointeeType());
// Computes the source range of F's declared return type. The range is
// widened to cover qualifier tokens written next to it before the function
// name. If F's return type range is invalid, F is diagnosed with Message
// (no fix-it). Several guard conditions are on lines missing from this
// listing. Both ends of the returned range are asserted to be non-macro
// locations.
263 SourceRange UseTrailingReturnTypeCheck::findReturnTypeAndCVSourceRange(
264 const FunctionDecl &F,
const ASTContext &
Ctx,
const SourceManager &SM,
265 const LangOptions &LangOpts) {
269 SourceRange ReturnTypeRange = F.getReturnTypeSourceRange();
270 if (ReturnTypeRange.isInvalid()) {
273 diag(F.getLocation(),
Message);
279 return ReturnTypeRange;
282 llvm::Optional<SmallVector<ClassifiedToken, 8>> MaybeTokens =
283 classifyTokensBeforeFunctionName(F,
Ctx, SM, LangOpts);
286 const SmallVector<ClassifiedToken, 8> &Tokens = *MaybeTokens;
// Make sure the begin of the range is not inside a macro expansion.
288 ReturnTypeRange.setBegin(
expandIfMacroId(ReturnTypeRange.getBegin(), SM));
// Widen the range in two directions:
//  - leftwards over consecutive qualifier tokens directly before the first
//    token at or after its begin. I is converted to int so the backwards
//    index J can reach -1; the assert guards that conversion. ExtendedLeft's
//    role is on missing lines (presumably "extend left only once").
//  - rightwards over consecutive qualifier tokens starting at a token that
//    lies after its end.
291 bool ExtendedLeft =
false;
292 for (
size_t I = 0; I < Tokens.size(); I++) {
294 if (!SM.isBeforeInTranslationUnit(Tokens[I].T.getLocation(),
295 ReturnTypeRange.getBegin()) &&
297 assert(I <=
size_t(std::numeric_limits<int>::max()) &&
298 "Integer overflow detected");
299 for (
int J = static_cast<int>(I) - 1; J >= 0 && Tokens[J].isQualifier;
301 ReturnTypeRange.setBegin(Tokens[J].T.getLocation());
305 if (SM.isBeforeInTranslationUnit(ReturnTypeRange.getEnd(),
306 Tokens[I].T.getLocation())) {
307 for (
size_t J = I; J < Tokens.size() && Tokens[J].isQualifier; J++)
308 ReturnTypeRange.setEnd(Tokens[J].T.getLocation());
313 assert(!ReturnTypeRange.getBegin().isMacroID() &&
314 "Return type source range begin must not be a macro");
315 assert(!ReturnTypeRange.getEnd().isMacroID() &&
316 "Return type source range end must not be a macro");
317 return ReturnTypeRange;
// Copies specifier tokens that lie inside ReturnTypeCVRange from the
// ReturnType text into Auto. Each one is placed before the original `auto`
// text and after any specifiers already copied.
// Returns early when F is not constexpr, inline-specified, extern or static,
// is not a friend (Fr is null) and is not a method written as virtual. The
// returned value is on a missing line.
// NOTE(review): lines missing from this listing include the filtering on
// each token's classification and, presumably, the erase of each specifier
// from ReturnType (implied by DeletedChars - TODO confirm).
320 bool UseTrailingReturnTypeCheck::keepSpecifiers(
321 std::string &
ReturnType, std::string &Auto, SourceRange ReturnTypeCVRange,
322 const FunctionDecl &F,
const FriendDecl *Fr,
const ASTContext &
Ctx,
323 const SourceManager &SM,
const LangOptions &LangOpts) {
326 const auto *M = dyn_cast<CXXMethodDecl>(&F);
327 if (!F.isConstexpr() && !F.isInlineSpecified() &&
328 F.getStorageClass() != SC_Extern && F.getStorageClass() != SC_Static &&
329 !Fr && !(M && M->isVirtualAsWritten()))
334 llvm::Optional<SmallVector<ClassifiedToken, 8>> MaybeTokens =
335 classifyTokensBeforeFunctionName(F,
Ctx, SM, LangOpts);
// Token offsets are file offsets. Subtracting the return type's begin
// offset and DeletedChars (characters already taken out of ReturnType)
// maps them to positions inside ReturnType.
340 unsigned int ReturnTypeBeginOffset =
341 SM.getDecomposedLoc(ReturnTypeCVRange.getBegin()).second;
342 size_t InitialAutoLength =
Auto.size();
343 unsigned int DeletedChars = 0;
// Tokens before the begin or after the end of ReturnTypeCVRange take the
// branch whose body is on missing lines (presumably skipped).
344 for (ClassifiedToken CT : *MaybeTokens) {
345 if (SM.isBeforeInTranslationUnit(CT.T.getLocation(),
346 ReturnTypeCVRange.getBegin()) ||
347 SM.isBeforeInTranslationUnit(ReturnTypeCVRange.getEnd(),
355 unsigned int TOffset = SM.getDecomposedLoc(CT.T.getLocation()).second;
356 assert(TOffset >= ReturnTypeBeginOffset &&
357 "Token location must be after the beginning of the return type");
358 unsigned int TOffsetInRT = TOffset - ReturnTypeBeginOffset - DeletedChars;
359 unsigned int TLengthWithWS = CT.T.getLength();
// Extend the copied span over whitespace that follows the token (the loop
// body is on a missing line, presumably ++TLengthWithWS).
360 while (TOffsetInRT + TLengthWithWS <
ReturnType.size() &&
361 llvm::isSpace(
ReturnType[TOffsetInRT + TLengthWithWS]))
363 std::string Specifier =
ReturnType.substr(TOffsetInRT, TLengthWithWS);
// Make sure the copied specifier ends with whitespace.
364 if (!llvm::isSpace(Specifier.back()))
365 Specifier.push_back(
' ');
// Auto.size() - InitialAutoLength is the number of characters inserted so
// far, so this places the specifier after previously copied specifiers and
// before the original `auto` text.
366 Auto.insert(
Auto.size() - InitialAutoLength, Specifier);
368 DeletedChars += TLengthWithWS;
// Registers two matchers:
//  - function declarations that are none of the following: already having a
//    trailing return type, returning void or auto, a conversion operator, or
//    an implicit method;
//  - friend declarations that contain such a function, bound as "Friend".
// The line binding the function matcher is missing from this listing.
374 void UseTrailingReturnTypeCheck::registerMatchers(MatchFinder *Finder) {
375 auto F = functionDecl(unless(anyOf(hasTrailingReturn(), returns(voidType()),
376 returns(autoType()), cxxConversionDecl(),
377 cxxMethodDecl(isImplicit()))))
380 Finder->addMatcher(F,
this);
381 Finder->addMatcher(friendDecl(hasDescendant(F)).bind(
"Friend"),
this);
// Receives the preprocessor for the main file. The body is missing from this
// listing; it presumably stores PP, which check() asserts has been set.
384 void UseTrailingReturnTypeCheck::registerPPCallbacks(
385 const SourceManager &SM, Preprocessor *
PP, Preprocessor *ModuleExpanderPP) {
389 void UseTrailingReturnTypeCheck::check(
const MatchFinder::MatchResult &Result) {
390 assert(
PP &&
"Expected registerPPCallbacks() to have been called before so "
391 "preprocessor is available");
393 const auto *F = Result.Nodes.getNodeAs<FunctionDecl>(
"Func");
394 const auto *Fr = Result.Nodes.getNodeAs<FriendDecl>(
"Friend");
395 assert(F &&
"Matcher is expected to find only FunctionDecls");
397 if (F->getLocation().isInvalid())
401 if (F->getDeclaredReturnType()->isFunctionPointerType() ||
402 F->getDeclaredReturnType()->isMemberFunctionPointerType() ||
403 F->getDeclaredReturnType()->isMemberPointerType() ||
404 F->getDeclaredReturnType()->getAs<DecltypeType>() !=
nullptr) {
405 diag(F->getLocation(),
Message);
409 const ASTContext &
Ctx = *Result.Context;
410 const SourceManager &SM = *Result.SourceManager;
411 const LangOptions &LangOpts = getLangOpts();
413 const TypeSourceInfo *TSI = F->getTypeSourceInfo();
417 FunctionTypeLoc FTL =
418 TSI->getTypeLoc().IgnoreParens().getAs<FunctionTypeLoc>();
423 diag(F->getLocation(),
Message);
427 SourceLocation InsertionLoc =
428 findTrailingReturnTypeSourceLocation(*F, FTL,
Ctx, SM, LangOpts);
429 if (InsertionLoc.isInvalid()) {
430 diag(F->getLocation(),
Message);
437 SourceRange ReturnTypeCVRange =
438 findReturnTypeAndCVSourceRange(*F,
Ctx, SM, LangOpts);
439 if (ReturnTypeCVRange.isInvalid())
450 UnqualNameVisitor UNV{*F};
451 UNV.TraverseTypeLoc(FTL.getReturnLoc());
453 diag(F->getLocation(),
Message);
457 SourceLocation ReturnTypeEnd =
458 Lexer::getLocForEndOfToken(ReturnTypeCVRange.getEnd(), 0, SM, LangOpts);
459 StringRef CharAfterReturnType = Lexer::getSourceText(
460 CharSourceRange::getCharRange(ReturnTypeEnd,
461 ReturnTypeEnd.getLocWithOffset(1)),
463 bool NeedSpaceAfterAuto =
464 CharAfterReturnType.empty() || !llvm::isSpace(CharAfterReturnType[0]);
466 std::string Auto = NeedSpaceAfterAuto ?
"auto " :
"auto";
468 std::string(tooling::fixit::getText(ReturnTypeCVRange,
Ctx));
469 keepSpecifiers(
ReturnType, Auto, ReturnTypeCVRange, *F, Fr,
Ctx, SM,
472 diag(F->getLocation(),
Message)
473 << FixItHint::CreateReplacement(ReturnTypeCVRange, Auto)
474 << FixItHint::CreateInsertion(InsertionLoc,
" -> " +
ReturnType);