c4173426277cfb737331430bbb08989dfac89c7e
[lldb.git] / clang-tools-extra / clang-include-fixer / IncludeFixer.cpp
1 //===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IncludeFixer.h"
10 #include "clang/Format/Format.h"
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "clang/Lex/HeaderSearch.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Parse/ParseAST.h"
15 #include "clang/Sema/Sema.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/raw_ostream.h"
18
19 #define DEBUG_TYPE "clang-include-fixer"
20
21 using namespace clang;
22
23 namespace clang {
24 namespace include_fixer {
25 namespace {
26 /// Manages the parse, gathers include suggestions.
27 class Action : public clang::ASTFrontendAction {
28 public:
29   explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths)
30       : SemaSource(SymbolIndexMgr, MinimizeIncludePaths,
31                    /*GenerateDiagnostics=*/false) {}
32
33   std::unique_ptr<clang::ASTConsumer>
34   CreateASTConsumer(clang::CompilerInstance &Compiler,
35                     StringRef InFile) override {
36     SemaSource.setFilePath(InFile);
37     return llvm::make_unique<clang::ASTConsumer>();
38   }
39
40   void ExecuteAction() override {
41     clang::CompilerInstance *Compiler = &getCompilerInstance();
42     assert(!Compiler->hasSema() && "CI already has Sema");
43
44     // Set up our hooks into sema and parse the AST.
45     if (hasCodeCompletionSupport() &&
46         !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty())
47       Compiler->createCodeCompletionConsumer();
48
49     clang::CodeCompleteConsumer *CompletionConsumer = nullptr;
50     if (Compiler->hasCodeCompletionConsumer())
51       CompletionConsumer = &Compiler->getCodeCompletionConsumer();
52
53     Compiler->createSema(getTranslationUnitKind(), CompletionConsumer);
54     SemaSource.setCompilerInstance(Compiler);
55     Compiler->getSema().addExternalSource(&SemaSource);
56
57     clang::ParseAST(Compiler->getSema(), Compiler->getFrontendOpts().ShowStats,
58                     Compiler->getFrontendOpts().SkipFunctionBodies);
59   }
60
61   IncludeFixerContext
62   getIncludeFixerContext(const clang::SourceManager &SourceManager,
63                          clang::HeaderSearch &HeaderSearch) const {
64     return SemaSource.getIncludeFixerContext(SourceManager, HeaderSearch,
65                                              SemaSource.getMatchedSymbols());
66   }
67
68 private:
69   IncludeFixerSemaSource SemaSource;
70 };
71
72 } // namespace
73
74 IncludeFixerActionFactory::IncludeFixerActionFactory(
75     SymbolIndexManager &SymbolIndexMgr,
76     std::vector<IncludeFixerContext> &Contexts, StringRef StyleName,
77     bool MinimizeIncludePaths)
78     : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts),
79       MinimizeIncludePaths(MinimizeIncludePaths) {}
80
81 IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
82
83 bool IncludeFixerActionFactory::runInvocation(
84     std::shared_ptr<clang::CompilerInvocation> Invocation,
85     clang::FileManager *Files,
86     std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,
87     clang::DiagnosticConsumer *Diagnostics) {
88   assert(Invocation->getFrontendOpts().Inputs.size() == 1);
89
90   // Set up Clang.
91   clang::CompilerInstance Compiler(PCHContainerOps);
92   Compiler.setInvocation(std::move(Invocation));
93   Compiler.setFileManager(Files);
94
95   // Create the compiler's actual diagnostics engine. We want to drop all
96   // diagnostics here.
97   Compiler.createDiagnostics(new clang::IgnoringDiagConsumer,
98                              /*ShouldOwnClient=*/true);
99   Compiler.createSourceManager(*Files);
100
101   // We abort on fatal errors so don't let a large number of errors become
102   // fatal. A missing #include can cause thousands of errors.
103   Compiler.getDiagnostics().setErrorLimit(0);
104
105   // Run the parser, gather missing includes.
106   auto ScopedToolAction =
107       llvm::make_unique<Action>(SymbolIndexMgr, MinimizeIncludePaths);
108   Compiler.ExecuteAction(*ScopedToolAction);
109
110   Contexts.push_back(ScopedToolAction->getIncludeFixerContext(
111       Compiler.getSourceManager(),
112       Compiler.getPreprocessor().getHeaderSearchInfo()));
113
114   // Technically this should only return true if we're sure that we have a
115   // parseable file. We don't know that though. Only inform users of fatal
116   // errors.
117   return !Compiler.getDiagnostics().hasFatalErrorOccurred();
118 }
119
120 static bool addDiagnosticsForContext(TypoCorrection &Correction,
121                                      const IncludeFixerContext &Context,
122                                      StringRef Code, SourceLocation StartOfFile,
123                                      ASTContext &Ctx) {
124   auto Reps = createIncludeFixerReplacements(
125       Code, Context, format::getLLVMStyle(), /*AddQualifiers=*/false);
126   if (!Reps || Reps->size() != 1)
127     return false;
128
129   unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID(
130       DiagnosticsEngine::Note, "Add '#include %0' to provide the missing "
131                                "declaration [clang-include-fixer]");
132
133   // FIXME: Currently we only generate a diagnostic for the first header. Give
134   // the user choices.
135   const tooling::Replacement &Placed = *Reps->begin();
136
137   auto Begin = StartOfFile.getLocWithOffset(Placed.getOffset());
138   auto End = Begin.getLocWithOffset(std::max(0, (int)Placed.getLength() - 1));
139   PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator());
140   PD << Context.getHeaderInfos().front().Header
141      << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin, End),
142                                      Placed.getReplacementText());
143   Correction.addExtraDiagnostic(std::move(PD));
144   return true;
145 }
146
147 /// Callback for incomplete types. If we encounter a forward declaration we
148 /// have the fully qualified name ready. Just query that.
149 bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
150     clang::SourceLocation Loc, clang::QualType T) {
151   // Ignore spurious callbacks from SFINAE contexts.
152   if (CI->getSema().isSFINAEContext())
153     return false;
154
155   clang::ASTContext &context = CI->getASTContext();
156   std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0)
157                                 .getAsString(context.getPrintingPolicy());
158   LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
159                           << "'");
160   // Pass an empty range here since we don't add qualifier in this case.
161   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
162       query(QueryString, "", tooling::Range());
163
164   if (!MatchedSymbols.empty() && GenerateDiagnostics) {
165     TypoCorrection Correction;
166     FileID FID = CI->getSourceManager().getFileID(Loc);
167     StringRef Code = CI->getSourceManager().getBufferData(FID);
168     SourceLocation StartOfFile =
169         CI->getSourceManager().getLocForStartOfFile(FID);
170     addDiagnosticsForContext(
171         Correction,
172         getIncludeFixerContext(CI->getSourceManager(),
173                                CI->getPreprocessor().getHeaderSearchInfo(),
174                                MatchedSymbols),
175         Code, StartOfFile, CI->getASTContext());
176     for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics())
177       CI->getSema().Diag(Loc, PD);
178   }
179   return true;
180 }
181
182 /// Callback for unknown identifiers. Try to piece together as much
183 /// qualification as we can get and do a query.
184 clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo(
185     const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
186     CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
187     bool EnteringContext, const ObjCObjectPointerType *OPT) {
188   // Ignore spurious callbacks from SFINAE contexts.
189   if (CI->getSema().isSFINAEContext())
190     return clang::TypoCorrection();
191
192   // We currently ignore the unidentified symbol which is not from the
193   // main file.
194   //
195   // However, this is not always true due to templates in a non-self contained
196   // header, consider the case:
197   //
198   //   // header.h
199   //   template <typename T>
200   //   class Foo {
201   //     T t;
202   //   };
203   //
204   //   // test.cc
205   //   // We need to add <bar.h> in test.cc instead of header.h.
206   //   class Bar;
207   //   Foo<Bar> foo;
208   //
209   // FIXME: Add the missing header to the header file where the symbol comes
210   // from.
211   if (!CI->getSourceManager().isWrittenInMainFile(Typo.getLoc()))
212     return clang::TypoCorrection();
213
214   std::string TypoScopeString;
215   if (S) {
216     // FIXME: Currently we only use namespace contexts. Use other context
217     // types for query.
218     for (const auto *Context = S->getEntity(); Context;
219          Context = Context->getParent()) {
220       if (const auto *ND = dyn_cast<NamespaceDecl>(Context)) {
221         if (!ND->getName().empty())
222           TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString;
223       }
224     }
225   }
226
227   auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) {
228     StringRef Source =
229         Lexer::getSourceText(Range, CI->getSourceManager(), CI->getLangOpts());
230
231     // Skip forward until we find a character that's neither identifier nor
232     // colon. This is a bit of a hack around the fact that we will only get a
233     // single callback for a long nested name if a part of the beginning is
234     // unknown. For example:
235     //
236     // llvm::sys::path::parent_path(...)
237     // ^~~~  ^~~
238     //    known
239     //            ^~~~
240     //      unknown, last callback
241     //                  ^~~~~~~~~~~
242     //                  no callback
243     //
244     // With the extension we get the full nested name specifier including
245     // parent_path.
246     // FIXME: Don't rely on source text.
247     const char *End = Source.end();
248     while (isIdentifierBody(*End) || *End == ':')
249       ++End;
250
251     return std::string(Source.begin(), End);
252   };
253
254   /// If we have a scope specification, use that to get more precise results.
255   std::string QueryString;
256   tooling::Range SymbolRange;
257   const auto &SM = CI->getSourceManager();
258   auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) {
259     return tooling::Range(SM.getDecomposedLoc(BeginLoc).second,
260                           QueryString.size());
261   };
262   if (SS && SS->getRange().isValid()) {
263     auto Range = CharSourceRange::getTokenRange(SS->getRange().getBegin(),
264                                                 Typo.getLoc());
265
266     QueryString = ExtendNestedNameSpecifier(Range);
267     SymbolRange = CreateToolingRange(Range.getBegin());
268   } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) {
269     auto Range =
270         CharSourceRange::getTokenRange(Typo.getBeginLoc(), Typo.getEndLoc());
271
272     QueryString = ExtendNestedNameSpecifier(Range);
273     SymbolRange = CreateToolingRange(Range.getBegin());
274   } else {
275     QueryString = Typo.getAsString();
276     SymbolRange = CreateToolingRange(Typo.getLoc());
277   }
278
279   LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
280                           << "\n");
281   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
282       query(QueryString, TypoScopeString, SymbolRange);
283
284   if (!MatchedSymbols.empty() && GenerateDiagnostics) {
285     TypoCorrection Correction(Typo.getName());
286     Correction.setCorrectionRange(SS, Typo);
287     FileID FID = SM.getFileID(Typo.getLoc());
288     StringRef Code = SM.getBufferData(FID);
289     SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
290     if (addDiagnosticsForContext(
291             Correction, getIncludeFixerContext(
292                             SM, CI->getPreprocessor().getHeaderSearchInfo(),
293                             MatchedSymbols),
294             Code, StartOfFile, CI->getASTContext()))
295       return Correction;
296   }
297   return TypoCorrection();
298 }
299
300 /// Get the minimal include for a given path.
301 std::string IncludeFixerSemaSource::minimizeInclude(
302     StringRef Include, const clang::SourceManager &SourceManager,
303     clang::HeaderSearch &HeaderSearch) const {
304   if (!MinimizeIncludePaths)
305     return Include;
306
307   // Get the FileEntry for the include.
308   StringRef StrippedInclude = Include.trim("\"<>");
309   const FileEntry *Entry =
310       SourceManager.getFileManager().getFile(StrippedInclude);
311
312   // If the file doesn't exist return the path from the database.
313   // FIXME: This should never happen.
314   if (!Entry)
315     return Include;
316
317   bool IsSystem = false;
318   std::string Suggestion =
319       HeaderSearch.suggestPathToFileForDiagnostics(Entry, "", &IsSystem);
320
321   return IsSystem ? '<' + Suggestion + '>' : '"' + Suggestion + '"';
322 }
323
324 /// Get the include fixer context for the queried symbol.
325 IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext(
326     const clang::SourceManager &SourceManager,
327     clang::HeaderSearch &HeaderSearch,
328     ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const {
329   std::vector<find_all_symbols::SymbolInfo> SymbolCandidates;
330   for (const auto &Symbol : MatchedSymbols) {
331     std::string FilePath = Symbol.getFilePath().str();
332     std::string MinimizedFilePath = minimizeInclude(
333         ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath
334                                                     : "\"" + FilePath + "\""),
335         SourceManager, HeaderSearch);
336     SymbolCandidates.emplace_back(Symbol.getName(), Symbol.getSymbolKind(),
337                                   MinimizedFilePath, Symbol.getContexts());
338   }
339   return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates);
340 }
341
342 std::vector<find_all_symbols::SymbolInfo>
343 IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers,
344                               tooling::Range Range) {
345   assert(!Query.empty() && "Empty query!");
346
347   // Save all instances of an unidentified symbol.
348   //
349   // We use conservative behavior for detecting the same unidentified symbol
350   // here. The symbols which have the same ScopedQualifier and RawIdentifier
351   // are considered equal. So that clang-include-fixer avoids false positives,
352   // and always adds missing qualifiers to correct symbols.
353   if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) {
354     if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers &&
355         Query == QuerySymbolInfos.front().RawIdentifier) {
356       QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range});
357     }
358     return {};
359   }
360
361   LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at ");
362   LLVM_DEBUG(CI->getSourceManager()
363                  .getLocForStartOfFile(CI->getSourceManager().getMainFileID())
364                  .getLocWithOffset(Range.getOffset())
365                  .print(llvm::dbgs(), CI->getSourceManager()));
366   LLVM_DEBUG(llvm::dbgs() << " ...");
367   llvm::StringRef FileName = CI->getSourceManager().getFilename(
368       CI->getSourceManager().getLocForStartOfFile(
369           CI->getSourceManager().getMainFileID()));
370
371   QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range});
372
373   // Query the symbol based on C++ name Lookup rules.
374   // Firstly, lookup the identifier with scoped namespace contexts;
375   // If that fails, falls back to look up the identifier directly.
376   //
377   // For example:
378   //
379   // namespace a {
380   // b::foo f;
381   // }
382   //
383   // 1. lookup a::b::foo.
384   // 2. lookup b::foo.
385   std::string QueryString = ScopedQualifiers.str() + Query.str();
386   // It's unsafe to do nested search for the identifier with scoped namespace
387   // context, it might treat the identifier as a nested class of the scoped
388   // namespace.
389   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
390       SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName);
391   if (MatchedSymbols.empty())
392     MatchedSymbols =
393         SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName);
394   LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
395                           << " symbols\n");
396   // We store a copy of MatchedSymbols in a place where it's globally reachable.
397   // This is used by the standalone version of the tool.
398   this->MatchedSymbols = MatchedSymbols;
399   return MatchedSymbols;
400 }
401
402 llvm::Expected<tooling::Replacements> createIncludeFixerReplacements(
403     StringRef Code, const IncludeFixerContext &Context,
404     const clang::format::FormatStyle &Style, bool AddQualifiers) {
405   if (Context.getHeaderInfos().empty())
406     return tooling::Replacements();
407   StringRef FilePath = Context.getFilePath();
408   std::string IncludeName =
409       "#include " + Context.getHeaderInfos().front().Header + "\n";
410   // Create replacements for the new header.
411   clang::tooling::Replacements Insertions;
412   auto Err =
413       Insertions.add(tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName));
414   if (Err)
415     return std::move(Err);
416
417   auto CleanReplaces = cleanupAroundReplacements(Code, Insertions, Style);
418   if (!CleanReplaces)
419     return CleanReplaces;
420
421   auto Replaces = std::move(*CleanReplaces);
422   if (AddQualifiers) {
423     for (const auto &Info : Context.getQuerySymbolInfos()) {
424       // Ignore the empty range.
425       if (Info.Range.getLength() > 0) {
426         auto R = tooling::Replacement(
427             {FilePath, Info.Range.getOffset(), Info.Range.getLength(),
428              Context.getHeaderInfos().front().QualifiedName});
429         auto Err = Replaces.add(R);
430         if (Err) {
431           llvm::consumeError(std::move(Err));
432           R = tooling::Replacement(
433               R.getFilePath(), Replaces.getShiftedCodePosition(R.getOffset()),
434               R.getLength(), R.getReplacementText());
435           Replaces = Replaces.merge(tooling::Replacements(R));
436         }
437       }
438     }
439   }
440   return formatReplacements(Code, Replaces, Style);
441 }
442
443 } // namespace include_fixer
444 } // namespace clang