55 #include "clang/AST/ASTContext.h"
56 #include "clang/AST/Decl.h"
57 #include "clang/AST/DeclTemplate.h"
58 #include "clang/AST/RecursiveASTVisitor.h"
59 #include "clang/AST/Stmt.h"
60 #include "clang/Basic/LangOptions.h"
61 #include "clang/Basic/SourceLocation.h"
62 #include "clang/Basic/SourceManager.h"
63 #include "clang/Lex/Lexer.h"
64 #include "clang/Tooling/Core/Replacement.h"
65 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
66 #include "llvm/ADT/None.h"
67 #include "llvm/ADT/Optional.h"
68 #include "llvm/ADT/SmallVector.h"
69 #include "llvm/ADT/StringRef.h"
70 #include "llvm/ADT/iterator_range.h"
71 #include "llvm/Support/Casting.h"
72 #include "llvm/Support/Error.h"
// Shorthand for the selection-tree node type used throughout this file.
78 using Node = SelectionTree::Node;
// Predicate on a selection-tree node. getParentOfRootStmts applies it with
// llvm::all_of to a candidate parent's children.
// The visible lines show that it first checks whether the node holds a Stmt.
// The rest of the body is elided in this excerpt.
93 bool isRootStmt(
const Node *N) {
94 if (!N->ASTNode.get<Stmt>())
// Returns the selection-tree node whose children become the root statements
// of the extraction zone, or nullptr if no candidate qualifies.
// - The candidate parent depends on how completely CommonAnc is selected
//   (Unselected / Partial / Complete). In the Complete case, the visible
//   assignment takes CommonAnc->Parent.
// - A DeclStmt parent is special-cased; the action taken is on an elided line.
// - The candidate is returned only if every one of its children satisfies
//   isRootStmt(const Node *).
// NOTE(review): the switch dereferences CommonAnc, and the DeclStmt check
// dereferences Parent. Null checks for both are presumably on elided lines.
// Confirm against the full source.
116 const Node *getParentOfRootStmts(
const Node *CommonAnc) {
119 const Node *
Parent =
nullptr;
120 switch (CommonAnc->Selected) {
121 case SelectionTree::Selection::Unselected:
126 case SelectionTree::Selection::Partial:
129 case SelectionTree::Selection::Complete:
132 Parent = CommonAnc->Parent;
137 if (
Parent->ASTNode.get<DeclStmt>())
142 return llvm::all_of(
Parent->Children, isRootStmt) ?
Parent :
nullptr;
// Describes the region chosen for extraction. Parent is the selection-tree
// node whose children are the root statements; the struct also holds derived
// data. Members visible in this excerpt:
// - getInsertionPoint(): location at which the new function definition is
//   inserted.
// - isRootStmt(S): membership test against RootStmts.
// - getLastRootStmt(): the last child of Parent.
//   NOTE(review): it calls Children.back(), so Parent must have at least one
//   child. findExtractionZone rejects childless parents before calling it.
// - generateRootStmts(): fills RootStmts from Parent's children.
// Parent, EnclosingFunction, ZoneRange and EnclosingFuncRange are used
// elsewhere but declared on elided lines.
146 struct ExtractionZone {
155 SourceLocation getInsertionPoint()
const {
158 bool isRootStmt(
const Stmt *S)
const;
161 const Node *getLastRootStmt()
const {
return Parent->Children.back(); }
162 void generateRootStmts();
165 llvm::DenseSet<const Stmt *> RootStmts;
// Returns true if the extraction zone ends in a return statement.
// It starts at the last root statement. While that statement is a
// CompoundStmt, it descends into the compound's last statement, then tests
// the final statement with isa<ReturnStmt>.
// An empty CompoundStmt is handled on an elided line.
// NOTE(review): llvm::dyn_cast and llvm::isa assert on a null pointer. This
// relies on the last root node holding a Stmt, which getParentOfRootStmts'
// isRootStmt filter appears intended to guarantee. Confirm.
172 bool alwaysReturns(
const ExtractionZone &EZ) {
173 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
175 while (
const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
176 if (CS->body_empty())
179 Last = CS->body_back();
181 return llvm::isa<ReturnStmt>(Last);
// True if S is one of the root statements recorded by generateRootStmts().
184 bool ExtractionZone::isRootStmt(
const Stmt *S)
const {
185 return RootStmts.find(S) != RootStmts.end();
// Records the Stmt of every child of Parent in RootStmts, so that
// isRootStmt() can answer membership queries during the zone traversal.
189 void ExtractionZone::generateRootStmts() {
190 for (
const Node *Child :
Parent->Children)
191 RootStmts.insert(Child->ASTNode.get<Stmt>());
// Walks from CommonAnc up through its ancestors to find the FunctionDecl
// that contains the selection.
// - LambdaExpr nodes met along the way are special-cased.
// - A FunctionDecl that is found is checked for being a CXXMethodDecl and for
//   being templated.
// What each check does is on elided lines. findExtractionZone treats a null
// result as "no enclosing function".
// NOTE(review): these cases presumably make the tweak unavailable by
// returning nullptr. Confirm against the full source.
195 const FunctionDecl *findEnclosingFunction(
const Node *CommonAnc) {
197 for (
const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
199 if (CurNode->ASTNode.get<LambdaExpr>())
201 if (
const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
203 if (isa<CXXMethodDecl>(Func))
206 if (Func->isTemplated())
// Computes the source range covered by Parent's children. The range runs
// from the begin of the first child's file range to the end of the last
// child's file range.
// The call that produces BeginFileRange/EndFileRange is on elided lines; it
// presumably converts each AST range to a file range. Confirm.
// findExtractionZone treats an absent or still-invalid result as failure.
216 llvm::Optional<SourceRange> findZoneRange(
const Node *
Parent,
217 const SourceManager &SM,
218 const LangOptions &LangOpts) {
221 SM, LangOpts,
Parent->Children.front()->ASTNode.getSourceRange()))
222 SR.setBegin(BeginFileRange->getBegin());
226 SM, LangOpts,
Parent->Children.back()->ASTNode.getSourceRange()))
227 SR.setEnd(EndFileRange->getEnd());
// Presumably the definition of computeEnclosingFuncRange; its name and first
// parameter are on an elided line.
// findExtractionZone calls that function with the enclosing function, SM and
// LangOpts, and stores the result in ExtractionZone::EnclosingFuncRange.
237 llvm::Optional<SourceRange>
239 const SourceManager &SM,
240 const LangOptions &LangOpts) {
// Decides whether a lone child of the zone parent may be extracted on its
// own. findExtractionZone calls it only when Parent has exactly one child.
// Visible checks:
// - an Expr child is special-cased;
// - a child that is the enclosing function's entire body is special-cased.
// Both results are on elided lines.
// The assert documents that extraction always happens inside a function body.
246 bool validSingleChild(
const Node *Child,
const FunctionDecl *EnclosingFunc) {
250 if (Child->ASTNode.get<Expr>())
253 assert(EnclosingFunc->hasBody() &&
254 "We should always be extracting from a function body.");
255 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
// Builds the ExtractionZone for the current selection. It yields an empty
// Optional on failure; the failure returns are on elided lines. Steps visible
// here:
//  1. Pick the root-statement parent; reject it if missing or childless.
//  2. Find the enclosing function; reject if there is none.
//  3. If there is exactly one child, require validSingleChild().
//  4. Compute the enclosing function's range and the zone range; reject if
//     either is still invalid.
//  5. Populate RootStmts for later traversal.
262 llvm::Optional<ExtractionZone> findExtractionZone(
const Node *CommonAnc,
263 const SourceManager &SM,
264 const LangOptions &LangOpts) {
265 ExtractionZone ExtZone;
266 ExtZone.Parent = getParentOfRootStmts(CommonAnc);
267 if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
269 ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
270 if (!ExtZone.EnclosingFunction)
274 if (ExtZone.Parent->Children.size() == 1 &&
275 !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
278 computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
279 ExtZone.EnclosingFuncRange = *FuncRange;
280 if (
auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
282 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
284 ExtZone.generateRootStmts();
// Fragment of NewFunction, the description of the function to be generated.
// The struct header and most fields are on elided lines. Visible members:
// - Parameter::render();
// - the function Name, which defaults to "extracted";
// - the helpers that render the call site and the definition.
296 std::string render(
const DeclContext *Context)
const;
301 std::string
Name =
"extracted";
314 std::string renderCall()
const;
316 std::string renderDefinition(
const SourceManager &SM)
const;
319 std::string renderParametersForDefinition()
const;
320 std::string renderParametersForCall()
const;
322 std::string getFuncBody(
const SourceManager &SM)
const;
// Renders the parameter list for the new function's definition.
// NeedCommaBefore starts false and is set to true once a parameter has been
// emitted, so later entries are comma-separated.
// The per-parameter rendering is on elided lines.
325 std::string NewFunction::renderParametersForDefinition()
const {
327 bool NeedCommaBefore =
false;
331 NeedCommaBefore =
true;
// Renders the argument list for the call that replaces the extracted code.
// It uses the same NeedCommaBefore comma-separation scheme as
// renderParametersForDefinition(). The per-argument rendering is on elided
// lines.
337 std::string NewFunction::renderParametersForCall()
const {
339 bool NeedCommaBefore =
false;
343 NeedCommaBefore =
true;
// Renders the call text that replaceWithFuncCall() substitutes for the
// extracted range. Its argument list comes from renderParametersForCall();
// the rest of the format is on elided lines.
349 std::string NewFunction::renderCall()
const {
352 renderParametersForCall(),
// Renders the full text of the new function definition with llvm::formatv.
// It combines Name, the definition parameter list and the body text from
// getFuncBody(SM). The format string is on an elided line.
356 std::string NewFunction::renderDefinition(
const SourceManager &SM)
const {
357 return std::string(llvm::formatv(
359 Name, renderParametersForDefinition(), getFuncBody(SM)));
// Returns the body text of the new function; the implementation is elided.
// NOTE(review): it presumably reads the source text covered by BodyRange
// through SM. Confirm.
362 std::string NewFunction::getFuncBody(
const SourceManager &SM)
const {
// Renders one parameter declaration; the implementation is elided.
// NOTE(review): Context is presumably the enclosing function's DeclContext,
// used to print the type relative to it. Confirm.
371 std::string NewFunction::Parameter::render(
const DeclContext *Context)
const {
// Information that captureZoneInfo() gathers about the declarations seen and
// referenced around the extraction zone.
// It also holds control-flow facts; HasReturnStmt, BrokenControlFlow and
// AlwaysReturns are declared on elided lines.
// For each declaration, DeclInformation records where it was declared, and
// markOccurence() records where it is referenced.
376 struct CapturedZoneInfo {
377 struct DeclInformation {
389 void markOccurence(ZoneRelative ReferenceLoc);
400 DeclInformation *createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc);
401 DeclInformation *getDeclInfoFor(
const Decl *D);
// Adds a DeclInformation entry for D and returns a pointer to the stored
// entry.
// The entry is tagged with where D was declared. The current map size is
// passed as the third constructor argument, presumably its DeclIndex, which
// gives insertion order.
// NOTE(review): the insert call is elided. If it is a map insert and D is
// already present, the existing entry is returned unchanged. Confirm.
404 CapturedZoneInfo::DeclInformation *
405 CapturedZoneInfo::createDeclInfo(
const Decl *D, ZoneRelative RelativeLoc) {
408 {D, DeclInformation(D, RelativeLoc,
DeclInfoMap.size())});
410 return &InsertionResult.first->second;
// Looks up the DeclInformation recorded for D and returns a pointer to it.
// The not-found path is elided. VisitDeclRefExpr falls back to
// createDeclInfo() afterwards, so a null result presumably means "not
// recorded yet".
413 CapturedZoneInfo::DeclInformation *
414 CapturedZoneInfo::getDeclInfoFor(
const Decl *D) {
419 return &Iter->second;
// Records a reference to this declaration at ReferenceLoc.
// The visible cases are Inside and After the zone. The flags they set are on
// elided lines, presumably IsReferencedInZone and IsReferencedInPostZone,
// which createParameters() reads.
422 void CapturedZoneInfo::DeclInformation::markOccurence(
423 ZoneRelative ReferenceLoc) {
424 switch (ReferenceLoc) {
425 case ZoneRelative::Inside:
428 case ZoneRelative::After:
// Returns true if S is a loop statement: for, do-while, while or
// range-based for.
436 bool isLoop(
const Stmt *S) {
437 return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
438 isa<CXXForRangeStmt>(S);
// Traverses the function enclosing ExtZone and collects a CapturedZoneInfo.
// The visitor's constructor performs the traversal. It records:
// - Declarations: every declaration visited (VisitDecl) and every
//   declaration referenced (VisitDeclRefExpr), tagged with the current
//   location relative to the zone. A referenced declaration that was not
//   already recorded is presumably registered as OutsideFunc.
// - Returns: whether a return statement occurs inside the zone.
// - Escaping control flow (BrokenControlFlow): a break inside the zone with
//   no enclosing loop or switch inside the zone, or a continue with no
//   enclosing loop inside the zone.
// - Whether the zone always returns, via alwaysReturns().
// Location tracking: TraverseStmt switches CurrentLocation to Inside when it
// enters a root statement and to After when it leaves one. The IsRootStmt
// guards on those assignments are on elided lines.
442 CapturedZoneInfo captureZoneInfo(
const ExtractionZone &ExtZone) {
446 class ExtractionZoneVisitor
447 :
public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
449 ExtractionZoneVisitor(
const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
450 TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
453 bool TraverseStmt(Stmt *S) {
456 bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
460 CurrentLocation = ZoneRelative::Inside;
461 addToLoopSwitchCounters(S, 1);
463 RecursiveASTVisitor::TraverseStmt(S);
464 addToLoopSwitchCounters(S, -1);
468 CurrentLocation = ZoneRelative::After;
// Adjusts the nesting counters only while inside the zone.
// The counters are unsigned and Increment is int. An Increment of -1 is
// converted to the unsigned maximum, and modular arithmetic makes += behave
// as a decrement.
474 void addToLoopSwitchCounters(Stmt *S,
int Increment) {
475 if (CurrentLocation != ZoneRelative::Inside)
478 CurNumberOfNestedLoops += Increment;
479 else if (isa<SwitchStmt>(S))
480 CurNumberOfSwitch += Increment;
// NOTE(review): no visible line calls decrementLoopSwitchCounters;
// TraverseStmt uses addToLoopSwitchCounters(S, -1) instead. It looks like
// dead code. Confirm and remove.
485 void decrementLoopSwitchCounters(Stmt *S) {
486 if (CurrentLocation != ZoneRelative::Inside)
489 CurNumberOfNestedLoops--;
490 else if (isa<SwitchStmt>(S))
494 bool VisitDecl(
Decl *D) {
495 Info.createDeclInfo(D, CurrentLocation);
499 bool VisitDeclRefExpr(DeclRefExpr *DRE) {
501 const Decl *D = DRE->getDecl();
502 auto *DeclInfo =
Info.getDeclInfoFor(D);
505 DeclInfo =
Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
506 DeclInfo->markOccurence(CurrentLocation);
511 bool VisitReturnStmt(ReturnStmt *Return) {
512 if (CurrentLocation == ZoneRelative::Inside)
513 Info.HasReturnStmt =
true;
// A break is only safe if some loop or switch inside the zone encloses it.
517 bool VisitBreakStmt(BreakStmt *Break) {
520 if (CurrentLocation == ZoneRelative::Inside &&
521 !(CurNumberOfNestedLoops || CurNumberOfSwitch))
522 Info.BrokenControlFlow =
true;
// A continue is only safe if some loop inside the zone encloses it.
526 bool VisitContinueStmt(ContinueStmt *Continue) {
529 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
530 Info.BrokenControlFlow =
true;
533 CapturedZoneInfo
Info;
534 const ExtractionZone &ExtZone;
538 unsigned CurNumberOfNestedLoops = 0;
539 unsigned CurNumberOfSwitch = 0;
541 ExtractionZoneVisitor Visitor(ExtZone);
542 CapturedZoneInfo Result = std::move(Visitor.Info);
543 Result.AlwaysReturns = alwaysReturns(ExtZone);
// Fills ExtractedFunc.Parameters from the captured declaration information.
// It returns false when extraction is not possible; those return statements
// are on elided lines. For each recorded declaration, the visible checks
// are:
// - Declared inside the zone but referenced after it: special-cased,
//   presumably a failure, since the name would not be visible after
//   extraction.
// - Not referenced inside the zone: special-cased, presumably skipped.
// - Declared inside the zone or outside the function: special-cased,
//   presumably skipped because no parameter is needed.
// - Not a ValueDecl, or a FunctionDecl: special-cased; the action is elided.
// - Otherwise: added as a parameter with its non-reference type and its
//   DeclIndex. IsPassedByReference is set to true on the visible line.
// The parameters are sorted at the end, using Parameter's ordering, which is
// declared elsewhere.
551 bool createParameters(NewFunction &ExtractedFunc,
552 const CapturedZoneInfo &CapturedInfo) {
553 for (
const auto &KeyVal : CapturedInfo.DeclInfoMap) {
554 const auto &DeclInfo = KeyVal.second;
558 if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
559 DeclInfo.IsReferencedInPostZone)
561 if (!DeclInfo.IsReferencedInZone)
563 if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
564 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
567 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
570 if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
573 QualType
TypeInfo = VD->getType().getNonReferenceType();
579 bool IsPassedByReference =
true;
581 ExtractedFunc.Parameters.push_back({std::string(VD->getName()),
TypeInfo,
583 DeclInfo.DeclIndex});
585 llvm::sort(ExtractedFunc.Parameters);
// Computes how the trailing semicolon of the extracted code should be
// handled, based on the last root statement. It works as follows:
// - The end of the zone range is moved back by one character to form
//   FuncBodyRange, which is passed to the policy computation. That call is
//   elided; presumably it may extend FuncBodyRange to cover a semicolon.
// - Afterwards, ExtZone.ZoneRange's end is reset to one past FuncBodyRange's
//   end.
// Side effect: mutates ExtZone.ZoneRange. getExtractedFunction() copies
// BodyRange from ZoneRange after calling this.
592 tooling::ExtractionSemicolonPolicy
593 getSemicolonPolicy(ExtractionZone &ExtZone,
const SourceManager &SM,
594 const LangOptions &LangOpts) {
596 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
597 ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
599 ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
602 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
// Sets ExtractedFunc.ReturnType.
// - If the zone contains a return statement, it must always return; the
//   failing branch is elided and is presumably a failure. The enclosing
//   function's return type is then reused, with dependent types
//   special-cased on an elided line.
// - The visible fallback sets the return type to void.
// Returns false when extraction is not possible; the returns are elided.
607 bool generateReturnProperties(NewFunction &ExtractedFunc,
608 const FunctionDecl &EnclosingFunc,
609 const CapturedZoneInfo &CapturedInfo) {
613 if (CapturedInfo.HasReturnStmt) {
616 if (!CapturedInfo.AlwaysReturns)
618 QualType Ret = EnclosingFunc.getReturnType();
621 if (Ret->isDependentType())
623 ExtractedFunc.ReturnType = Ret;
627 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
// Builds the NewFunction description for ExtZone, or returns an error:
// - break/continue that would escape the zone -> "Cannot extract ..." error;
// - parameters or return type cannot be derived -> "Too complex to extract.".
// getSemicolonPolicy() adjusts ExtZone.ZoneRange before BodyRange is copied
// from it, so the two statements must stay in this order.
// NOTE(review): the unary '+' on the message literals decays them to
// const char *, presumably to pick a specific createStringError overload.
// Confirm before removing it.
633 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
634 const SourceManager &SM,
635 const LangOptions &LangOpts) {
636 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
638 if (CapturedInfo.BrokenControlFlow)
639 return llvm::createStringError(llvm::inconvertibleErrorCode(),
640 +
"Cannot extract break/continue without "
641 "corresponding loop/switch statement.");
642 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts));
643 ExtractedFunc.BodyRange = ExtZone.ZoneRange;
644 ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint();
645 ExtractedFunc.EnclosingFuncContext =
646 ExtZone.EnclosingFunction->getDeclContext();
647 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
648 if (!createParameters(ExtractedFunc, CapturedInfo) ||
649 !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
651 return llvm::createStringError(llvm::inconvertibleErrorCode(),
652 +
"Too complex to extract.");
653 return ExtractedFunc;
// The "Extract to function" refactoring tweak.
// prepare() locates the extraction zone for the selection and caches it in
// ExtZone; apply() turns the cached zone into source edits.
656 class ExtractFunction :
public Tweak {
658 const char *id() const override final;
659 bool prepare(const Selection &
Inputs) override;
661 std::
string title()
const override {
return "Extract to function"; }
662 Intent intent()
const override {
return Refactor; }
665 ExtractionZone ExtZone;
// Replaces the extracted code with the rendered call.
// ExtractedFunc.BodyRange is passed as a character range (IsTokenRange =
// false) rather than a token range.
669 tooling::Replacement replaceWithFuncCall(
const NewFunction &ExtractedFunc,
670 const SourceManager &SM,
671 const LangOptions &LangOpts) {
672 std::string FuncCall = ExtractedFunc.renderCall();
673 return tooling::Replacement(
674 SM, CharSourceRange(ExtractedFunc.BodyRange,
false), FuncCall, LangOpts);
// Inserts the rendered function definition at ExtractedFunc.InsertionPoint.
// The replacement has zero length, so it is a pure insertion.
677 tooling::Replacement createFunctionDefinition(
const NewFunction &ExtractedFunc,
678 const SourceManager &SM) {
679 std::string FunctionDef = ExtractedFunc.renderDefinition(SM);
680 return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef);
// Finds the extraction zone for the selection's common ancestor and caches it
// in ExtZone for apply().
// The return statements are elided; presumably it returns true only when a
// zone was found.
// NOTE(review): the common ancestor is passed on unchecked, and
// getParentOfRootStmts dereferences it. A null commonAncestor() is
// presumably rejected on an elided line there. Confirm.
683 bool ExtractFunction::prepare(
const Selection &
Inputs) {
684 const Node *CommonAnc =
Inputs.ASTSelection.commonAncestor();
685 const SourceManager &SM =
Inputs.AST->getSourceManager();
686 const LangOptions &LangOpts =
Inputs.AST->getLangOpts();
687 if (
auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts)) {
688 ExtZone = std::move(*MaybeExtZone);
// Builds the NewFunction for the cached zone and emits two replacements in
// the main file:
// - the new definition at the insertion point;
// - the call in place of the extracted range.
// It propagates errors from getExtractedFunction() and from
// Replacements::add(), e.g. when two replacements conflict.
694 Expected<Tweak::Effect> ExtractFunction::apply(
const Selection &
Inputs) {
695 const SourceManager &SM =
Inputs.AST->getSourceManager();
696 const LangOptions &LangOpts =
Inputs.AST->getLangOpts();
697 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
700 return ExtractedFunc.takeError();
701 tooling::Replacements Result;
702 if (
auto Err = Result.add(createFunctionDefinition(*ExtractedFunc, SM)))
703 return std::move(Err);
704 if (
auto Err = Result.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
705 return std::move(Err);
706 return Effect::mainFileEdit(SM, std::move(Result));