// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include #include #include #include #include #include "clang/AST/ASTContext.h" #include "clang/AST/ParentMap.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" #include "clang/ASTMatchers/ASTMatchersMacros.h" #include "clang/Analysis/CFG.h" #include "clang/Basic/SourceManager.h" #include "clang/Frontend/FrontendActions.h" #include "clang/Lex/Lexer.h" #include "clang/Tooling/CommonOptionsParser.h" #include "clang/Tooling/Refactoring.h" #include "clang/Tooling/Tooling.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetSelect.h" using Replacements = std::vector; using clang::ASTContext; using clang::CFG; using clang::CFGBlock; using clang::CFGLifetimeEnds; using clang::CFGStmt; using clang::CallExpr; using clang::Decl; using clang::DeclRefExpr; using clang::FunctionDecl; using clang::LambdaExpr; using clang::Stmt; using clang::UnaryOperator; using clang::ast_type_traits::DynTypedNode; using clang::tooling::CommonOptionsParser; using namespace clang::ast_matchers; namespace { class Rewriter { public: virtual ~Rewriter() {} }; // Removes unneeded base::Passed() on a parameter of base::BindOnce(). // Example: // // Before // base::BindOnce(&Foo, base::Passed(&bar)); // base::BindOnce(&Foo, base::Passed(std::move(baz))); // base::BindOnce(&Foo, base::Passed(qux)); // // // After // base::BindOnce(&Foo, std::move(bar)); // base::BindOnce(&Foo, std::move(baz)); // base::BindOnce(&Foo, std::move(*qux)); class PassedToMoveRewriter : public MatchFinder::MatchCallback, public Rewriter { public: explicit PassedToMoveRewriter(Replacements* replacements) : replacements_(replacements) {} StatementMatcher GetMatcher() { auto is_passed = namedDecl(hasName("::base::Passed")); auto is_bind_once_call = callee(namedDecl(hasName("::base::BindOnce"))); // Matches base::Passed() call on a base::BindOnce() argument. return callExpr(is_bind_once_call, hasAnyArgument(ignoringImplicit( callExpr(callee(is_passed)).bind("target")))); } void run(const MatchFinder::MatchResult& result) override { auto* target = result.Nodes.getNodeAs("target"); auto* callee = target->getCallee()->IgnoreImpCasts(); auto* callee_decl = clang::dyn_cast(callee)->getDecl(); auto* passed_decl = clang::dyn_cast(callee_decl); auto* param_type = passed_decl->getParamDecl(0)->getType().getTypePtr(); if (param_type->isRValueReferenceType()) { // base::Passed(xxx) -> xxx. // The parameter type is already an rvalue reference. // Example: // std::unique_ptr foo(); // std::unique_ptr bar; // base::Passed(foo()); // base::Passed(std::move(bar)); // In these cases, we can just remove base::Passed. auto left = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getBeginLoc()), result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc()) .getLocWithOffset(-1)); auto r_paren = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getRParenLoc()), result.SourceManager->getSpellingLoc(target->getRParenLoc())); replacements_->emplace_back(*result.SourceManager, left, " "); replacements_->emplace_back(*result.SourceManager, r_paren, " "); return; } if (!param_type->isPointerType()) return; auto* passed_arg = target->getArg(0)->IgnoreImpCasts(); if (auto* unary = clang::dyn_cast(passed_arg)) { if (unary->getOpcode() == clang::UO_AddrOf) { // base::Passed(&xxx) -> std::move(xxx). auto left = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getBeginLoc()), result.SourceManager->getSpellingLoc( target->getArg(0)->getExprLoc())); replacements_->emplace_back(*result.SourceManager, left, "std::move("); return; } } // base::Passed(xxx) -> std::move(*xxx) auto left = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getBeginLoc()), result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc()) .getLocWithOffset(-1)); replacements_->emplace_back(*result.SourceManager, left, "std::move(*"); } private: Replacements* replacements_; }; // Replace base::Bind() and base::BindRepeating() to base::BindOnce() where // resulting callbacks are implicitly converted into base::OnceCallback. // Example: // // Before // base::PostTask(FROM_HERE, base::Bind(&Foo)); // base::OnceCallback cb = base::Bind(&Foo); // // // After // base::PostTask(FROM_HERE, base::BindOnce(&Foo)); // base::OnceCallback cb = base::BindOnce(&Foo); class BindOnceRewriter : public MatchFinder::MatchCallback, public Rewriter { public: explicit BindOnceRewriter(Replacements* replacements) : replacements_(replacements) {} StatementMatcher GetMatcher() { auto is_once_callback = hasType(hasCanonicalType(hasDeclaration( classTemplateSpecializationDecl(hasName("::base::OnceCallback"))))); auto is_repeating_callback = hasType(hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl( hasName("::base::RepeatingCallback"))))); auto bind_call = callExpr(callee(namedDecl(anyOf(hasName("::base::Bind"), hasName("::base::BindRepeating"))))) .bind("target"); auto parameter_construction = cxxConstructExpr(is_repeating_callback, argumentCountIs(1), hasArgument(0, ignoringImplicit(bind_call))); auto constructor_conversion = cxxConstructExpr( is_once_callback, argumentCountIs(1), hasArgument(0, ignoringImplicit(parameter_construction))); return implicitCastExpr(is_once_callback, hasSourceExpression(constructor_conversion)); } void run(const MatchFinder::MatchResult& result) override { auto* target = result.Nodes.getNodeAs("target"); auto* callee = target->getCallee(); auto range = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(callee->getEndLoc()), result.SourceManager->getSpellingLoc(callee->getEndLoc())); replacements_->emplace_back(*result.SourceManager, range, "BindOnce"); } private: Replacements* replacements_; }; // Converts pass-by-const-ref base::Callback's to pass-by-value. // Example: // // Before // using BarCallback = base::Callback; // void Foo(const base::Callback& cb); // void Bar(const BarCallback& cb); // // // After // using BarCallback = base::Callback; // void Foo(base::Callback cb); // void Bar(BarCallback cb); class PassByValueRewriter : public MatchFinder::MatchCallback, public Rewriter { public: explicit PassByValueRewriter(Replacements* replacements) : replacements_(replacements) {} DeclarationMatcher GetMatcher() { auto is_repeating_callback = namedDecl(hasName("::base::RepeatingCallback")); return parmVarDecl( hasType(hasCanonicalType(references(is_repeating_callback)))) .bind("target"); } void run(const MatchFinder::MatchResult& result) override { auto* target = result.Nodes.getNodeAs("target"); auto qual_type = target->getType(); auto* ref_type = clang::dyn_cast(qual_type.getTypePtr()); if (!ref_type || !ref_type->getPointeeType().isLocalConstQualified()) return; // Remove the leading `const` and the following `&`. auto type_loc = target->getTypeSourceInfo()->getTypeLoc(); auto const_keyword = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getBeginLoc()), result.SourceManager->getSpellingLoc(target->getBeginLoc())); auto lvalue_ref = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(type_loc.getEndLoc()), result.SourceManager->getSpellingLoc(type_loc.getEndLoc())); replacements_->emplace_back(*result.SourceManager, const_keyword, " "); replacements_->emplace_back(*result.SourceManager, lvalue_ref, " "); } private: Replacements* replacements_; }; // Adds std::move() to base::RepeatingCallback<> where it looks relevant. // Example: // // Before // void Foo(base::Callback cb1) { // base::Closure cb2 = base::Bind(cb1, 42); // PostTask(FROM_HERE, cb2); // } // // // After // void Foo(base::Callback cb1) { // base::Closure cb2 = base::Bind(std::move(cb1), 42); // PostTask(FROM_HERE, std::move(cb2)); // } class AddStdMoveRewriter : public MatchFinder::MatchCallback, public Rewriter { public: explicit AddStdMoveRewriter(Replacements* replacements) : replacements_(replacements) {} StatementMatcher GetMatcher() { return declRefExpr( hasType(hasCanonicalType(hasDeclaration( namedDecl(hasName("::base::RepeatingCallback"))))), anyOf(hasAncestor(cxxConstructorDecl().bind("enclosing_ctor")), hasAncestor(functionDecl().bind("enclosing_func")), hasAncestor(lambdaExpr().bind("enclosing_lambda")))) .bind("target"); } // Build Control Flow Graph (CFG) for |stmt| and populate class members with // the content of the graph. Returns true if the analysis finished // successfully. bool ExtractCFGContentToMembers(Stmt* stmt, ASTContext* context) { // Try to make a cache entry. The failure implies it's already in the cache. auto inserted = cfg_cache_.emplace(stmt, nullptr); if (!inserted.second) return !!inserted.first->second; std::unique_ptr& cfg = inserted.first->second; CFG::BuildOptions opts; opts.AddInitializers = true; opts.AddLifetime = true; opts.AddStaticInitBranches = true; cfg = CFG::buildCFG(nullptr, stmt, context, opts); // CFG construction may fail. Report it to the caller. if (!cfg) return false; if (!parent_map_) parent_map_ = std::make_unique(stmt); else parent_map_->addStmt(stmt); // Populate |top_stmts_|, that contains Stmts that is evaluated in its own // CFGElement. for (auto* block : *cfg) { for (auto& elem : *block) { if (auto stmt = elem.getAs()) top_stmts_.insert(stmt->getStmt()); } } // Populate |enclosing_block_|, that maps a Stmt to a CFGBlock that contains // the Stmt. std::function recursive_set_enclosing = [&](const CFGBlock* block, const Stmt* stmt) { enclosing_block_[stmt] = block; for (auto* c : stmt->children()) { if (!c) continue; if (top_stmts_.find(c) != top_stmts_.end()) continue; recursive_set_enclosing(block, c); } }; for (auto* block : *cfg) { for (auto& elem : *block) { if (auto stmt = elem.getAs()) recursive_set_enclosing(block, stmt->getStmt()); } } return true; } const Stmt* EnclosingCxxStatement(const Stmt* stmt) { while (true) { const Stmt* parent = parent_map_->getParentIgnoreParenCasts(stmt); assert(parent); switch (parent->getStmtClass()) { case Stmt::CompoundStmtClass: case Stmt::ForStmtClass: case Stmt::CXXForRangeStmtClass: case Stmt::WhileStmtClass: case Stmt::DoStmtClass: case Stmt::IfStmtClass: // Other candidates: // Stmt::CXXTryStmtClass // Stmt::CXXCatchStmtClass // Stmt::CapturedStmtClass // Stmt::SwitchStmtClass // Stmt::SwitchCaseClass return stmt; default: stmt = parent; break; } } } bool WasPointerTaken(const Stmt* stmt, const Decl* decl) { std::function visit_stmt = [&](const Stmt* stmt) { if (auto* op = clang::dyn_cast(stmt)) { if (op->getOpcode() == clang::UO_AddrOf) { auto* ref = clang::dyn_cast(op->getSubExpr()); // |ref| may be null if the sub-expr has a dependent type. if (ref && ref->getDecl() == decl) return true; } } for (auto* c : stmt->children()) { if (!c) continue; if (visit_stmt(c)) return true; } return false; }; return visit_stmt(stmt); } bool HasCapturingLambda(const Stmt* stmt, const Decl* decl) { std::function visit_stmt = [&](const Stmt* stmt) { if (auto* l = clang::dyn_cast(stmt)) { for (auto c : l->captures()) { if (c.getCapturedVar() == decl) return true; } } for (auto* c : stmt->children()) { if (!c) continue; if (visit_stmt(c)) return true; } return false; }; return visit_stmt(stmt); } // Returns true if there are multiple occurrences to |decl| in one of C++ // statements in |stmt|. bool HasUnorderedOccurrences(const Decl* decl, const Stmt* stmt) { int count = 0; std::function visit_stmt = [&](const Stmt* s) { if (auto* ref = clang::dyn_cast(s)) { if (ref->getDecl() == decl) ++count; } for (auto* c : s->children()) { if (!c) continue; visit_stmt(c); } }; visit_stmt(EnclosingCxxStatement(stmt)); return count > 1; } void run(const MatchFinder::MatchResult& result) override { auto* target = result.Nodes.getNodeAs("target"); auto* decl = clang::dyn_cast(target->getDecl()); // Other than local variables and parameters are out-of-scope. if (!decl || !decl->isLocalVarDeclOrParm()) return; auto qual_type = decl->getType(); // Qualified variables are out-of-scope. They are likely not movable. if (qual_type.getCanonicalType().hasQualifiers()) return; auto* type = qual_type.getTypePtr(); // References and pointers are out-of-scope. if (type->isReferenceType() || type->isPointerType()) return; Stmt* body = nullptr; if (auto* ctor = result.Nodes.getNodeAs("enclosing_ctor")) return; // Skip constructor case for now. TBD. else if (auto* func = result.Nodes.getNodeAs("enclosing_func")) body = func->getBody(); else if (auto* lambda = result.Nodes.getNodeAs("enclosing_lambda")) body = lambda->getBody(); else return; // Disable the replacement if there is a lambda that captures |decl|. if (HasCapturingLambda(body, decl)) return; // Disable the replacement if the pointer to |decl| is taken in the scope. if (WasPointerTaken(body, decl)) return; if (!ExtractCFGContentToMembers(body, result.Context)) return; auto* parent = parent_map_->getParentIgnoreParenCasts(target); if (auto* p = clang::dyn_cast(parent)) { auto* callee = p->getCalleeDecl(); // |callee| may be null if the CallExpr has an unresolved look up. if (!callee) return; auto* callee_decl = clang::dyn_cast(callee); auto name = callee_decl->getQualifiedNameAsString(); // Disable the replacement if it's already in std::move() or // std::forward(). if (name == "std::__1::move" || name == "std::__1::forward") return; } else if (parent->getStmtClass() == Stmt::ReturnStmtClass) { // Disable the replacement if it's in a return statement. return; } // If the same C++ statement contains multiple reference to the variable, // don't insert std::move() to be conservative. if (HasUnorderedOccurrences(decl, target)) return; bool saw_reuse = false; ForEachFollowingStmts(target, [&](const Stmt* stmt) { if (auto* ref = clang::dyn_cast(stmt)) { if (ref->getDecl() == decl) { saw_reuse = true; return false; } } // TODO: Detect Reset() and operator=() to stop the traversal. return true; }); if (saw_reuse) return; replacements_->emplace_back( *result.SourceManager, result.SourceManager->getSpellingLoc(target->getBeginLoc()), 0, "std::move("); replacements_->emplace_back( *result.SourceManager, clang::Lexer::getLocForEndOfToken(target->getEndLoc(), 0, *result.SourceManager, result.Context->getLangOpts()), 0, ")"); } // Invokes |handler| for each Stmt that follows |target| until it reaches the // end of the lifetime of the variable that |target| references. // If |handler| returns false, stops following the current control flow. void ForEachFollowingStmts(const DeclRefExpr* target, std::function handler) { auto* decl = target->getDecl(); auto* block = enclosing_block_[target]; std::set visited; std::vector stack = {block}; bool saw_target = false; std::function visit_stmt = [&](const Stmt* s) { for (auto* t : s->children()) { if (!t) continue; // |t| is evaluated elsewhere if a sub-Stmt is in |top_stmt_|. if (top_stmts_.find(t) != top_stmts_.end()) continue; if (!visit_stmt(t)) return false; } if (!saw_target) { if (s == target) saw_target = true; return true; } return handler(s); }; bool visited_initial_block_twice = false; while (!stack.empty()) { auto* b = stack.back(); stack.pop_back(); if (!visited.insert(b).second) { if (b != block || visited_initial_block_twice) continue; visited_initial_block_twice = true; } bool cont = true; for (auto e : *b) { if (auto s = e.getAs()) { if (!visit_stmt(s->getStmt())) { cont = false; break; } } else if (auto l = e.getAs()) { if (l->getVarDecl() == decl) { cont = false; break; } } } if (cont) { for (auto s : b->succs()) { if (!s) continue; // Unreachable block. stack.push_back(s); } } } } private: // Function body to CFG. std::map> cfg_cache_; // Statement to the enclosing CFGBlock. std::map enclosing_block_; // Stmt to its parent Stmt. std::unique_ptr parent_map_; // A set of Stmt that a CFGElement has it directly. std::set top_stmts_; Replacements* replacements_; }; // Remove base::AdaptCallbackForRepeating() where resulting // base::RepeatingCallback is implicitly converted into base::OnceCallback. // Example: // // Before // base::PostTask( // FROM_HERE, // base::AdaptCallbackForRepeating(base::BindOnce(&Foo))); // base::OnceCallback cb = base::AdaptCallbackForRepeating( // base::OnceBind(&Foo)); // // // After // base::PostTask(FROM_HERE, base::BindOnce(&Foo)); // base::OnceCallback cb = base::BindOnce(&Foo); class AdaptCallbackForRepeatingRewriter : public MatchFinder::MatchCallback, public Rewriter { public: explicit AdaptCallbackForRepeatingRewriter(Replacements* replacements) : replacements_(replacements) {} StatementMatcher GetMatcher() { auto is_once_callback = hasType(hasCanonicalType(hasDeclaration( classTemplateSpecializationDecl(hasName("::base::OnceCallback"))))); auto is_repeating_callback = hasType(hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl( hasName("::base::RepeatingCallback"))))); auto adapt_callback_call = callExpr( callee(namedDecl(hasName("::base::AdaptCallbackForRepeating")))) .bind("target"); auto parameter_construction = cxxConstructExpr(is_repeating_callback, argumentCountIs(1), hasArgument(0, ignoringImplicit(adapt_callback_call))); auto constructor_conversion = cxxConstructExpr( is_once_callback, argumentCountIs(1), hasArgument(0, ignoringImplicit(parameter_construction))); return implicitCastExpr(is_once_callback, hasSourceExpression(constructor_conversion)); } void run(const MatchFinder::MatchResult& result) override { auto* target = result.Nodes.getNodeAs("target"); auto left = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getBeginLoc()), result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc()) .getLocWithOffset(-1)); // We use " " as replacement to work around https://crbug.com/861886. replacements_->emplace_back(*result.SourceManager, left, " "); auto r_paren = clang::CharSourceRange::getTokenRange( result.SourceManager->getSpellingLoc(target->getRParenLoc()), result.SourceManager->getSpellingLoc(target->getRParenLoc())); replacements_->emplace_back(*result.SourceManager, r_paren, " "); } private: Replacements* replacements_; }; llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage); llvm::cl::OptionCategory rewriter_category("Rewriter Options"); llvm::cl::opt rewriter_option( "rewriter", llvm::cl::desc(R"(One of the name of rewriter to apply. Available rewriters are: remove_unneeded_passed bind_to_bind_once pass_by_value add_std_move remove_unneeded_adapt_callback The default is remove_unneeded_passed. )"), llvm::cl::init("remove_unneeded_passed"), llvm::cl::cat(rewriter_category)); } // namespace. int main(int argc, const char* argv[]) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmParser(); CommonOptionsParser options(argc, argv, rewriter_category); clang::tooling::ClangTool tool(options.getCompilations(), options.getSourcePathList()); MatchFinder match_finder; std::vector replacements; std::unique_ptr rewriter; if (rewriter_option == "remove_unneeded_passed") { auto passed_to_move = std::make_unique(&replacements); match_finder.addMatcher(passed_to_move->GetMatcher(), passed_to_move.get()); rewriter = std::move(passed_to_move); } else if (rewriter_option == "bind_to_bind_once") { auto bind_once = std::make_unique(&replacements); match_finder.addMatcher(bind_once->GetMatcher(), bind_once.get()); rewriter = std::move(bind_once); } else if (rewriter_option == "pass_by_value") { auto pass_by_value = std::make_unique(&replacements); match_finder.addMatcher(pass_by_value->GetMatcher(), pass_by_value.get()); rewriter = std::move(pass_by_value); } else if (rewriter_option == "add_std_move") { auto add_std_move = std::make_unique(&replacements); match_finder.addMatcher(add_std_move->GetMatcher(), add_std_move.get()); rewriter = std::move(add_std_move); } else if (rewriter_option == "remove_unneeded_adapt_callback") { auto remove_unneeded_adapt_callback = std::make_unique(&replacements); match_finder.addMatcher(remove_unneeded_adapt_callback->GetMatcher(), remove_unneeded_adapt_callback.get()); rewriter = std::move(remove_unneeded_adapt_callback); } else { abort(); } std::unique_ptr factory = clang::tooling::newFrontendActionFactory(&match_finder); int result = tool.run(factory.get()); if (result != 0) return result; // Serialization format is documented in tools/clang/scripts/run_tool.py llvm::outs() << "==== BEGIN EDITS ====\n"; for (const auto& r : replacements) { std::string replacement_text = r.getReplacementText().str(); std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0'); llvm::outs() << "r:::" << r.getFilePath() << ":::" << r.getOffset() << ":::" << r.getLength() << ":::" << replacement_text << "\n"; } llvm::outs() << "==== END EDITS ====\n"; return 0; }