[mlir] NFC: fix trivial typos
[lldb.git] / mlir / tools / mlir-tblgen / OpDefinitionsGen.cpp
1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "OpFormatGen.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/OpClass.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31
32 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
33
34 using namespace llvm;
35 using namespace mlir;
36 using namespace mlir::tblgen;
37
38 cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
39
40 static cl::opt<std::string> opIncFilter(
41     "op-include-regex",
42     cl::desc("Regex of name of op's to include (no filter if empty)"),
43     cl::cat(opDefGenCat));
44 static cl::opt<std::string> opExcFilter(
45     "op-exclude-regex",
46     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
47     cl::cat(opDefGenCat));
48
49 static const char *const tblgenNamePrefix = "tblgen_";
50 static const char *const generatedArgName = "odsArg";
51 static const char *const builder = "odsBuilder";
52 static const char *const builderOpState = "odsState";
53
54 // The logic to calculate the actual value range for a declared operand/result
55 // of an op with variadic operands/results. Note that this logic is not for
56 // general use; it assumes all variadic operands/results must have the same
57 // number of values.
58 //
59 // {0}: The list of whether each declared operand/result is variadic.
60 // {1}: The total number of non-variadic operands/results.
61 // {2}: The total number of variadic operands/results.
62 // {3}: The total number of actual values.
63 // {4}: "operand" or "result".
64 const char *sameVariadicSizeValueRangeCalcCode = R"(
65   bool isVariadic[] = {{{0}};
66   int prevVariadicCount = 0;
67   for (unsigned i = 0; i < index; ++i)
68     if (isVariadic[i]) ++prevVariadicCount;
69
70   // Calculate how many dynamic values a static variadic {4} corresponds to.
71   // This assumes all static variadic {4}s have the same dynamic value count.
72   int variadicSize = ({3} - {1}) / {2};
73   // `index` passed in as the parameter is the static index which counts each
74   // {4} (variadic or not) as size 1. So here for each previous static variadic
75   // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
76   // value pack for this static {4} starts.
77   int start = index + (variadicSize - 1) * prevVariadicCount;
78   int size = isVariadic[index] ? variadicSize : 1;
79   return {{start, size};
80 )";
81
82 // The logic to calculate the actual value range for a declared operand/result
83 // of an op with variadic operands/results. Note that this logic is assumes
84 // the op has an attribute specifying the size of each operand/result segment
85 // (variadic or not).
86 //
87 // {0}: The name of the attribute specifying the segment sizes.
88 const char *adapterSegmentSizeAttrInitCode = R"(
89   assert(odsAttrs && "missing segment size attribute for op");
90   auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
91 )";
92 const char *opSegmentSizeAttrInitCode = R"(
93   auto sizeAttr = (*this)->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}");
94 )";
95 const char *attrSizedSegmentValueRangeCalcCode = R"(
96   unsigned start = 0;
97   for (unsigned i = 0; i < index; ++i)
98     start += (*(sizeAttr.begin() + i)).getZExtValue();
99   unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
100   return {start, size};
101 )";
102
103 // The logic to build a range of either operand or result values.
104 //
105 // {0}: The begin iterator of the actual values.
106 // {1}: The call to generate the start and length of the value range.
107 const char *valueRangeReturnCode = R"(
108   auto valueRange = {1};
109   return {{std::next({0}, valueRange.first),
110            std::next({0}, valueRange.first + valueRange.second)};
111 )";
112
113 static const char *const opCommentHeader = R"(
114 //===----------------------------------------------------------------------===//
115 // {0} {1}
116 //===----------------------------------------------------------------------===//
117
118 )";
119
120 //===----------------------------------------------------------------------===//
121 // StaticVerifierFunctionEmitter
122 //===----------------------------------------------------------------------===//
123
124 namespace {
125 /// This class deduplicates shared operation verification code by emitting
126 /// static functions alongside the op definitions. These methods are local to
127 /// the definition file, and are invoked within the operation verify methods.
128 /// An example is shown below:
129 ///
130 /// static LogicalResult localVerify(...)
131 ///
132 /// LogicalResult OpA::verify(...) {
133 ///  if (failed(localVerify(...)))
134 ///    return failure();
135 ///  ...
136 /// }
137 ///
138 /// LogicalResult OpB::verify(...) {
139 ///  if (failed(localVerify(...)))
140 ///    return failure();
141 ///  ...
142 /// }
143 ///
144 class StaticVerifierFunctionEmitter {
145 public:
146   StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
147                                 ArrayRef<llvm::Record *> opDefs,
148                                 raw_ostream &os, bool emitDecl);
149
150   /// Get the name of the local function used for the given type constraint.
151   /// These functions are used for operand and result constraints and have the
152   /// form:
153   ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
154   ///                 unsigned valueGroupStartIndex);
155   StringRef getTypeConstraintFn(const Constraint &constraint) const {
156     auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
157     assert(it != localTypeConstraints.end() && "expected valid constraint fn");
158     return it->second;
159   }
160
161 private:
162   /// Returns a unique name to use when generating local methods.
163   static std::string getUniqueName(const llvm::RecordKeeper &records);
164
165   /// Emit local methods for the type constraints used within the provided op
166   /// definitions.
167   void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
168                                  raw_ostream &os, bool emitDecl);
169
170   /// A unique label for the file currently being generated. This is used to
171   /// ensure that the local functions have a unique name.
172   std::string uniqueOutputLabel;
173
174   /// A set of functions implementing type constraints, used for operand and
175   /// result verification.
176   llvm::DenseMap<const void *, std::string> localTypeConstraints;
177 };
178 } // namespace
179
180 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
181     const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
182     raw_ostream &os, bool emitDecl)
183     : uniqueOutputLabel(getUniqueName(records)) {
184   llvm::Optional<NamespaceEmitter> namespaceEmitter;
185   if (!emitDecl) {
186     os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
187     namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
188   }
189
190   emitTypeConstraintMethods(opDefs, os, emitDecl);
191 }
192
193 std::string StaticVerifierFunctionEmitter::getUniqueName(
194     const llvm::RecordKeeper &records) {
195   // Use the input file name when generating a unique name.
196   std::string inputFilename = records.getInputFilename();
197
198   // Drop all but the base filename.
199   StringRef nameRef = llvm::sys::path::filename(inputFilename);
200   nameRef.consume_back(".td");
201
202   // Sanitize any invalid characters.
203   std::string uniqueName;
204   for (char c : nameRef) {
205     if (llvm::isAlnum(c) || c == '_')
206       uniqueName.push_back(c);
207     else
208       uniqueName.append(llvm::utohexstr((unsigned char)c));
209   }
210   return uniqueName;
211 }
212
213 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
214     ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
215   // Collect a set of all of the used type constraints within the operation
216   // definitions.
217   llvm::SetVector<const void *> typeConstraints;
218   for (Record *def : opDefs) {
219     Operator op(*def);
220     for (NamedTypeConstraint &operand : op.getOperands())
221       if (operand.hasPredicate())
222         typeConstraints.insert(operand.constraint.getAsOpaquePointer());
223     for (NamedTypeConstraint &result : op.getResults())
224       if (result.hasPredicate())
225         typeConstraints.insert(result.constraint.getAsOpaquePointer());
226   }
227
228   FmtContext fctx;
229   for (auto it : llvm::enumerate(typeConstraints)) {
230     // Generate an obscure and unique name for this type constraint.
231     std::string name = (Twine("__mlir_ods_local_type_constraint_") +
232                         uniqueOutputLabel + Twine(it.index()))
233                            .str();
234     localTypeConstraints.try_emplace(it.value(), name);
235
236     // Only generate the methods if we are generating definitions.
237     if (emitDecl)
238       continue;
239
240     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
241     os << "static ::mlir::LogicalResult " << name
242        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
243           "valueKind, unsigned valueGroupStartIndex) {\n";
244
245     os << "  if (!("
246        << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
247        << ")) {\n"
248        << formatv(
249               "    return op->emitOpError(valueKind) << \" #\" << "
250               "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
251               constraint.getDescription())
252        << "  }\n"
253        << "  return ::mlir::success();\n"
254        << "}\n\n";
255   }
256 }
257
258 //===----------------------------------------------------------------------===//
259 // Utility structs and functions
260 //===----------------------------------------------------------------------===//
261
262 // Replaces all occurrences of `match` in `str` with `substitute`.
263 static std::string replaceAllSubstrs(std::string str, const std::string &match,
264                                      const std::string &substitute) {
265   std::string::size_type scanLoc = 0, matchLoc = std::string::npos;
266   while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
267     str = str.replace(matchLoc, match.size(), substitute);
268     scanLoc = matchLoc + substitute.size();
269   }
270   return str;
271 }
272
273 // Returns whether the record has a value of the given name that can be returned
274 // via getValueAsString.
275 static inline bool hasStringAttribute(const Record &record,
276                                       StringRef fieldName) {
277   auto valueInit = record.getValueInit(fieldName);
278   return isa<StringInit>(valueInit);
279 }
280
281 static std::string getArgumentName(const Operator &op, int index) {
282   const auto &operand = op.getOperand(index);
283   if (!operand.name.empty())
284     return std::string(operand.name);
285   else
286     return std::string(formatv("{0}_{1}", generatedArgName, index));
287 }
288
289 // Returns true if we can use unwrapped value for the given `attr` in builders.
290 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
291   return attr.getReturnType() != attr.getStorageType() &&
292          // We need to wrap the raw value into an attribute in the builder impl
293          // so we need to make sure that the attribute specifies how to do that.
294          !attr.getConstBuilderTemplate().empty();
295 }
296
297 //===----------------------------------------------------------------------===//
298 // Op emitter
299 //===----------------------------------------------------------------------===//
300
301 namespace {
302 // Helper class to emit a record into the given output stream.
303 class OpEmitter {
304 public:
305   static void
306   emitDecl(const Operator &op, raw_ostream &os,
307            const StaticVerifierFunctionEmitter &staticVerifierEmitter);
308   static void
309   emitDef(const Operator &op, raw_ostream &os,
310           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
311
312 private:
313   OpEmitter(const Operator &op,
314             const StaticVerifierFunctionEmitter &staticVerifierEmitter);
315
316   void emitDecl(raw_ostream &os);
317   void emitDef(raw_ostream &os);
318
319   // Generates the OpAsmOpInterface for this operation if possible.
320   void genOpAsmInterface();
321
322   // Generates the `getOperationName` method for this op.
323   void genOpNameGetter();
324
325   // Generates getters for the attributes.
326   void genAttrGetters();
327
328   // Generates setter for the attributes.
329   void genAttrSetters();
330
331   // Generates removers for optional attributes.
332   void genOptionalAttrRemovers();
333
334   // Generates getters for named operands.
335   void genNamedOperandGetters();
336
337   // Generates setters for named operands.
338   void genNamedOperandSetters();
339
340   // Generates getters for named results.
341   void genNamedResultGetters();
342
343   // Generates getters for named regions.
344   void genNamedRegionGetters();
345
346   // Generates getters for named successors.
347   void genNamedSuccessorGetters();
348
349   // Generates builder methods for the operation.
350   void genBuilder();
351
352   // Generates the build() method that takes each operand/attribute
353   // as a stand-alone parameter.
354   void genSeparateArgParamBuilder();
355
356   // Generates the build() method that takes each operand/attribute as a
357   // stand-alone parameter. The generated build() method uses first operand's
358   // type as all results' types.
359   void genUseOperandAsResultTypeSeparateParamBuilder();
360
361   // Generates the build() method that takes all operands/attributes
362   // collectively as one parameter. The generated build() method uses first
363   // operand's type as all results' types.
364   void genUseOperandAsResultTypeCollectiveParamBuilder();
365
366   // Generates the build() method that takes aggregate operands/attributes
367   // parameters. This build() method uses inferred types as result types.
368   // Requires: The type needs to be inferable via InferTypeOpInterface.
369   void genInferredTypeCollectiveParamBuilder();
370
371   // Generates the build() method that takes each operand/attribute as a
372   // stand-alone parameter. The generated build() method uses first attribute's
373   // type as all result's types.
374   void genUseAttrAsResultTypeBuilder();
375
376   // Generates the build() method that takes all result types collectively as
377   // one parameter. Similarly for operands and attributes.
378   void genCollectiveParamBuilder();
379
380   // The kind of parameter to generate for result types in builders.
381   enum class TypeParamKind {
382     None,       // No result type in parameter list.
383     Separate,   // A separate parameter for each result type.
384     Collective, // An ArrayRef<Type> for all result types.
385   };
386
387   // The kind of parameter to generate for attributes in builders.
388   enum class AttrParamKind {
389     WrappedAttr,    // A wrapped MLIR Attribute instance.
390     UnwrappedValue, // A raw value without MLIR Attribute wrapper.
391   };
392
393   // Builds the parameter list for build() method of this op. This method writes
394   // to `paramList` the comma-separated parameter list and updates
395   // `resultTypeNames` with the names for parameters for specifying result
396   // types. The given `typeParamKind` and `attrParamKind` controls how result
397   // types and attributes are placed in the parameter list.
398   void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
399                       SmallVectorImpl<std::string> &resultTypeNames,
400                       TypeParamKind typeParamKind,
401                       AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
402
403   // Adds op arguments and regions into operation state for build() methods.
404   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
405                                               bool isRawValueAttr = false);
406
407   // Generates canonicalizer declaration for the operation.
408   void genCanonicalizerDecls();
409
410   // Generates the folder declaration for the operation.
411   void genFolderDecls();
412
413   // Generates the parser for the operation.
414   void genParser();
415
416   // Generates the printer for the operation.
417   void genPrinter();
418
419   // Generates verify method for the operation.
420   void genVerifier();
421
422   // Generates verify statements for operands and results in the operation.
423   // The generated code will be attached to `body`.
424   void genOperandResultVerifier(OpMethodBody &body,
425                                 Operator::value_range values,
426                                 StringRef valueKind);
427
428   // Generates verify statements for regions in the operation.
429   // The generated code will be attached to `body`.
430   void genRegionVerifier(OpMethodBody &body);
431
432   // Generates verify statements for successors in the operation.
433   // The generated code will be attached to `body`.
434   void genSuccessorVerifier(OpMethodBody &body);
435
436   // Generates the traits used by the object.
437   void genTraits();
438
439   // Generate the OpInterface methods for all interfaces.
440   void genOpInterfaceMethods();
441
442   // Generate op interface methods for the given interface.
443   void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
444
445   // Generate op interface method for the given interface method. If
446   // 'declaration' is true, generates a declaration, else a definition.
447   OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
448                                  bool declaration = true);
449
450   // Generate the side effect interface methods.
451   void genSideEffectInterfaceMethods();
452
453   // Generate the type inference interface methods.
454   void genTypeInterfaceMethods();
455
456 private:
457   // The TableGen record for this op.
458   // TODO: OpEmitter should not have a Record directly,
459   // it should rather go through the Operator for better abstraction.
460   const Record &def;
461
462   // The wrapper operator class for querying information from this op.
463   Operator op;
464
465   // The C++ code builder for this op
466   OpClass opClass;
467
468   // The format context for verification code generation.
469   FmtContext verifyCtx;
470
471   // The emitter containing all of the locally emitted verification functions.
472   const StaticVerifierFunctionEmitter &staticVerifierEmitter;
473 };
474 } // end anonymous namespace
475
476 // Populate the format context `ctx` with substitutions of attributes, operands
477 // and results.
478 // - attrGet corresponds to the name of the function to call to get value of
479 //   attribute (the generated function call returns an Attribute);
480 // - operandGet corresponds to the name of the function with which to retrieve
481 //   an operand (the generated function call returns an OperandRange);
482 // - resultGet corresponds to the name of the function to get an result (the
483 //   generated function call returns a ValueRange);
484 static void populateSubstitutions(const Operator &op, const char *attrGet,
485                                   const char *operandGet, const char *resultGet,
486                                   FmtContext &ctx) {
487   // Populate substitutions for attributes and named operands.
488   for (const auto &namedAttr : op.getAttributes())
489     ctx.addSubst(namedAttr.name,
490                  formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
491   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
492     auto &value = op.getOperand(i);
493     if (value.name.empty())
494       continue;
495
496     if (value.isVariadic())
497       ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
498     else
499       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
500   }
501
502   // Populate substitutions for results.
503   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
504     auto &value = op.getResult(i);
505     if (value.name.empty())
506       continue;
507
508     if (value.isVariadic())
509       ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
510     else
511       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
512   }
513 }
514
515 // Generate attribute verification. If emitVerificationRequiringOp is set then
516 // only verification for attributes whose value depend on op being known are
517 // emitted, else only verification that doesn't depend on the op being known are
518 // generated.
519 // - emitErrorPrefix is the prefix for the error emitting call which consists
520 //   of the entire function call up to start of error message fragment;
521 // - emitVerificationRequiringOp specifies whether verification should be
522 //   emitted for verification that require the op to exist;
523 static void genAttributeVerifier(const Operator &op, const char *attrGet,
524                                  const Twine &emitErrorPrefix,
525                                  bool emitVerificationRequiringOp,
526                                  FmtContext &ctx, OpMethodBody &body) {
527   for (const auto &namedAttr : op.getAttributes()) {
528     const auto &attr = namedAttr.attr;
529     if (attr.isDerivedAttr())
530       continue;
531
532     auto attrName = namedAttr.name;
533     bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
534     auto attrPred = attr.getPredicate();
535     auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
536     // There is a condition to emit only if the use of $_op and whether to
537     // emit verifications for op matches.
538     bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
539                                emitVerificationRequiringOp);
540
541     // Prefix with `tblgen_` to avoid hiding the attribute accessor.
542     auto varName = tblgenNamePrefix + attrName;
543
544     // If the attribute is
545     //  1. Required (not allowed missing) and not in op verification, or
546     //  2. Has a condition that will get verified
547     // then the variable will be used.
548     //
549     // Therefore, for optional attributes whose verification requires that an
550     // op already exists for verification/emitVerificationRequiringOp is set
551     // has nothing that can be verified here.
552     if ((allowMissingAttr || emitVerificationRequiringOp) &&
553         !hasConditionToEmit)
554       continue;
555
556     body << formatv("  {\n  auto {0} = {1}(\"{2}\");\n", varName, attrGet,
557                     attrName);
558
559     if (!emitVerificationRequiringOp && !allowMissingAttr) {
560       body << "  if (!" << varName << ") return " << emitErrorPrefix
561            << "\"requires attribute '" << attrName << "'\");\n";
562     }
563
564     if (!hasConditionToEmit) {
565       body << "  }\n";
566       continue;
567     }
568
569     if (allowMissingAttr) {
570       // If the attribute has a default value, then only verify the predicate if
571       // set. This does effectively assume that the default value is valid.
572       // TODO: verify the debug value is valid (perhaps in debug mode only).
573       body << "  if (" << varName << ") {\n";
574     }
575
576     body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
577                   "failed to satisfy constraint: $3\");\n",
578                   /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
579                   emitErrorPrefix, attrName, attr.getDescription());
580     if (allowMissingAttr)
581       body << "  }\n";
582     body << "  }\n";
583   }
584 }
585
586 OpEmitter::OpEmitter(const Operator &op,
587                      const StaticVerifierFunctionEmitter &staticVerifierEmitter)
588     : def(op.getDef()), op(op),
589       opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
590       staticVerifierEmitter(staticVerifierEmitter) {
591   verifyCtx.withOp("(*this->getOperation())");
592
593   genTraits();
594
595   // Generate C++ code for various op methods. The order here determines the
596   // methods in the generated file.
597   genOpAsmInterface();
598   genOpNameGetter();
599   genNamedOperandGetters();
600   genNamedOperandSetters();
601   genNamedResultGetters();
602   genNamedRegionGetters();
603   genNamedSuccessorGetters();
604   genAttrGetters();
605   genAttrSetters();
606   genOptionalAttrRemovers();
607   genBuilder();
608   genParser();
609   genPrinter();
610   genVerifier();
611   genCanonicalizerDecls();
612   genFolderDecls();
613   genTypeInterfaceMethods();
614   genOpInterfaceMethods();
615   generateOpFormat(op, opClass);
616   genSideEffectInterfaceMethods();
617 }
618
619 void OpEmitter::emitDecl(
620     const Operator &op, raw_ostream &os,
621     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
622   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
623 }
624
625 void OpEmitter::emitDef(
626     const Operator &op, raw_ostream &os,
627     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
628   OpEmitter(op, staticVerifierEmitter).emitDef(os);
629 }
630
631 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
632
633 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
634
635 void OpEmitter::genAttrGetters() {
636   FmtContext fctx;
637   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
638
639   Dialect opDialect = op.getDialect();
640   // Emit the derived attribute body.
641   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
642     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
643     if (!method)
644       return;
645     auto &body = method->body();
646     body << "  " << attr.getDerivedCodeBody() << "\n";
647   };
648
649   // Emit with return type specified.
650   auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
651     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
652     auto &body = method->body();
653     body << "  auto attr = " << name << "Attr();\n";
654     if (attr.hasDefaultValue()) {
655       // Returns the default value if not set.
656       // TODO: this is inefficient, we are recreating the attribute for every
657       // call. This should be set instead.
658       std::string defaultValue = std::string(
659           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
660       body << "    if (!attr)\n      return "
661            << tgfmt(attr.getConvertFromStorageCall(),
662                     &fctx.withSelf(defaultValue))
663            << ";\n";
664     }
665     body << "  return "
666          << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
667          << ";\n";
668   };
669
670   // Generate raw named accessor type. This is a wrapper class that allows
671   // referring to the attributes via accessors instead of having to use
672   // the string interface for better compile time verification.
673   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
674     auto *method =
675         opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
676     if (!method)
677       return;
678     auto &body = method->body();
679     body << "  return (*this)->getAttr(\"" << name << "\").template ";
680     if (attr.isOptional() || attr.hasDefaultValue())
681       body << "dyn_cast_or_null<";
682     else
683       body << "cast<";
684     body << attr.getStorageType() << ">();";
685   };
686
687   for (auto &namedAttr : op.getAttributes()) {
688     const auto &name = namedAttr.name;
689     const auto &attr = namedAttr.attr;
690     if (attr.isDerivedAttr()) {
691       emitDerivedAttr(name, attr);
692     } else {
693       emitAttrWithStorageType(name, attr);
694       emitAttrWithReturnType(name, attr);
695     }
696   }
697
698   auto derivedAttrs = make_filter_range(op.getAttributes(),
699                                         [](const NamedAttribute &namedAttr) {
700                                           return namedAttr.attr.isDerivedAttr();
701                                         });
702   if (!derivedAttrs.empty()) {
703     opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
704     // Generate helper method to query whether a named attribute is a derived
705     // attribute. This enables, for example, avoiding adding an attribute that
706     // overlaps with a derived attribute.
707     {
708       auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
709                                                OpMethod::MP_Static,
710                                                "::llvm::StringRef", "name");
711       auto &body = method->body();
712       for (auto namedAttr : derivedAttrs)
713         body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
714       body << " return false;";
715     }
716     // Generate method to materialize derived attributes as a DictionaryAttr.
717     {
718       auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
719                                                "materializeDerivedAttributes");
720       auto &body = method->body();
721
722       auto nonMaterializable =
723           make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
724             return namedAttr.attr.getConvertFromStorageCall().empty();
725           });
726       if (!nonMaterializable.empty()) {
727         std::string attrs;
728         llvm::raw_string_ostream os(attrs);
729         interleaveComma(nonMaterializable, os,
730                         [&](const NamedAttribute &attr) { os << attr.name; });
731         PrintWarning(
732             op.getLoc(),
733             formatv(
734                 "op has non-materializable derived attributes '{0}', skipping",
735                 os.str()));
736         body << formatv("  emitOpError(\"op has non-materializable derived "
737                         "attributes '{0}'\");\n",
738                         attrs);
739         body << "  return nullptr;";
740         return;
741       }
742
743       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
744       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
745       body << "  return ::mlir::DictionaryAttr::get({\n";
746       interleave(
747           derivedAttrs, body,
748           [&](const NamedAttribute &namedAttr) {
749             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
750             body << "    {::mlir::Identifier::get(\"" << namedAttr.name
751                  << "\", ctx),\n"
752                  << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
753                                      .withBuilder("odsBuilder")
754                                      .addSubst("_ctx", "ctx"))
755                  << "}";
756           },
757           ",\n");
758       body << "\n    }, ctx);";
759     }
760   }
761 }
762
763 void OpEmitter::genAttrSetters() {
764   // Generate raw named setter type. This is a wrapper class that allows setting
765   // to the attributes via setters instead of having to use the string interface
766   // for better compile time verification.
767   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
768     auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
769                                              attr.getStorageType(), "attr");
770     if (!method)
771       return;
772     auto &body = method->body();
773     body << "  (*this)->setAttr(\"" << name << "\", attr);";
774   };
775
776   for (auto &namedAttr : op.getAttributes()) {
777     const auto &name = namedAttr.name;
778     const auto &attr = namedAttr.attr;
779     if (!attr.isDerivedAttr())
780       emitAttrWithStorageType(name, attr);
781   }
782 }
783
784 void OpEmitter::genOptionalAttrRemovers() {
785   // Generate methods for removing optional attributes, instead of having to
786   // use the string interface. Enables better compile time verification.
787   auto emitRemoveAttr = [&](StringRef name) {
788     auto upperInitial = name.take_front().upper();
789     auto suffix = name.drop_front();
790     auto *method = opClass.addMethodAndPrune(
791         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
792     if (!method)
793       return;
794     auto &body = method->body();
795     body << "  return (*this)->removeAttr(\"" << name << "\");";
796   };
797
798   for (const auto &namedAttr : op.getAttributes()) {
799     const auto &name = namedAttr.name;
800     const auto &attr = namedAttr.attr;
801     if (attr.isOptional())
802       emitRemoveAttr(name);
803   }
804 }
805
806 // Generates the code to compute the start and end index of an operand or result
807 // range.
808 template <typename RangeT>
809 static void
810 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
811                               int numVariadic, int numNonVariadic,
812                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
813                               StringRef sizeAttrInit, RangeT &&odsValues) {
814   auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
815                                            methodName, "unsigned", "index");
816   if (!method)
817     return;
818   auto &body = method->body();
819   if (numVariadic == 0) {
820     body << "  return {index, 1};\n";
821   } else if (hasAttrSegmentSize) {
822     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
823   } else {
824     // Because the op can have arbitrarily interleaved variadic and non-variadic
825     // operands, we need to embed a list in the "sink" getter method for
826     // calculation at run-time.
827     llvm::SmallVector<StringRef, 4> isVariadic;
828     isVariadic.reserve(llvm::size(odsValues));
829     for (auto &it : odsValues)
830       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
831     std::string isVariadicList = llvm::join(isVariadic, ", ");
832     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
833                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
834   }
835 }
836
837 // Generates the named operand getter methods for the given Operator `op` and
838 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
839 // return a range of operands (individual operands are `Value ` and each
840 // element in the range must also be `Value `); use `rangeBeginCall` to get
841 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
842 // obtain the number of operands. `getOperandCallPattern` contains the code
843 // necessary to obtain a single operand whose position will be substituted
844 // instead of
845 // "{0}" marker in the pattern.  Note that the pattern should work for any kind
846 // of ops, in particular for one-operand ops that may not have the
847 // `getOperand(unsigned)` method.
848 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
849                                         StringRef sizeAttrInit,
850                                         StringRef rangeType,
851                                         StringRef rangeBeginCall,
852                                         StringRef rangeSizeCall,
853                                         StringRef getOperandCallPattern) {
854   const int numOperands = op.getNumOperands();
855   const int numVariadicOperands = op.getNumVariableLengthOperands();
856   const int numNormalOperands = numOperands - numVariadicOperands;
857
858   const auto *sameVariadicSize =
859       op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
860   const auto *attrSizedOperands =
861       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
862
863   if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
864     PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
865                                  "specification over their sizes");
866   }
867
868   if (numVariadicOperands < 2 && attrSizedOperands) {
869     PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
870                                  "to use 'AttrSizedOperandSegments' trait");
871   }
872
873   if (attrSizedOperands && sameVariadicSize) {
874     PrintFatalError(op.getLoc(),
875                     "op cannot have both 'AttrSizedOperandSegments' and "
876                     "'SameVariadicOperandSize' traits");
877   }
878
879   // First emit a few "sink" getter methods upon which we layer all nicer named
880   // getter methods.
881   generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
882                                 numVariadicOperands, numNormalOperands,
883                                 rangeSizeCall, attrSizedOperands, sizeAttrInit,
884                                 const_cast<Operator &>(op).getOperands());
885
886   auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
887                                       "index");
888   auto &body = m->body();
889   body << formatv(valueRangeReturnCode, rangeBeginCall,
890                   "getODSOperandIndexAndLength(index)");
891
892   // Then we emit nicer named getter methods by redirecting to the "sink" getter
893   // method.
894   for (int i = 0; i != numOperands; ++i) {
895     const auto &operand = op.getOperand(i);
896     if (operand.name.empty())
897       continue;
898
899     if (operand.isOptional()) {
900       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
901       m->body() << "  auto operands = getODSOperands(" << i << ");\n"
902                 << "  return operands.empty() ? Value() : *operands.begin();";
903     } else if (operand.isVariadic()) {
904       m = opClass.addMethodAndPrune(rangeType, operand.name);
905       m->body() << "  return getODSOperands(" << i << ");";
906     } else {
907       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
908       m->body() << "  return *getODSOperands(" << i << ").begin();";
909     }
910   }
911 }
912
913 void OpEmitter::genNamedOperandGetters() {
914   generateNamedOperandGetters(
915       op, opClass,
916       /*sizeAttrInit=*/
917       formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
918       /*rangeType=*/"::mlir::Operation::operand_range",
919       /*rangeBeginCall=*/"getOperation()->operand_begin()",
920       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
921       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
922 }
923
924 void OpEmitter::genNamedOperandSetters() {
925   auto *attrSizedOperands =
926       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
927   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
928     const auto &operand = op.getOperand(i);
929     if (operand.name.empty())
930       continue;
931     auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
932                                         (operand.name + "Mutable").str());
933     auto &body = m->body();
934     body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
935          << "  return ::mlir::MutableOperandRange(getOperation(), "
936             "range.first, range.second";
937     if (attrSizedOperands)
938       body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
939            << "u, *getOperation()->getAttrDictionary().getNamed("
940               "\"operand_segment_sizes\"))";
941     body << ");\n";
942   }
943 }
944
945 void OpEmitter::genNamedResultGetters() {
946   const int numResults = op.getNumResults();
947   const int numVariadicResults = op.getNumVariableLengthResults();
948   const int numNormalResults = numResults - numVariadicResults;
949
950   // If we have more than one variadic results, we need more complicated logic
951   // to calculate the value range for each result.
952
953   const auto *sameVariadicSize =
954       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
955   const auto *attrSizedResults =
956       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
957
958   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
959     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
960                                  "specification over their sizes");
961   }
962
963   if (numVariadicResults < 2 && attrSizedResults) {
964     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
965                                  "to use 'AttrSizedResultSegments' trait");
966   }
967
968   if (attrSizedResults && sameVariadicSize) {
969     PrintFatalError(op.getLoc(),
970                     "op cannot have both 'AttrSizedResultSegments' and "
971                     "'SameVariadicResultSize' traits");
972   }
973
974   generateValueRangeStartAndEnd(
975       opClass, "getODSResultIndexAndLength", numVariadicResults,
976       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
977       formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
978       op.getResults());
979
980   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
981                                       "getODSResults", "unsigned", "index");
982   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
983                        "getODSResultIndexAndLength(index)");
984
985   for (int i = 0; i != numResults; ++i) {
986     const auto &result = op.getResult(i);
987     if (result.name.empty())
988       continue;
989
990     if (result.isOptional()) {
991       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
992       m->body()
993           << "  auto results = getODSResults(" << i << ");\n"
994           << "  return results.empty() ? ::mlir::Value() : *results.begin();";
995     } else if (result.isVariadic()) {
996       m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
997                                     result.name);
998       m->body() << "  return getODSResults(" << i << ");";
999     } else {
1000       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1001       m->body() << "  return *getODSResults(" << i << ").begin();";
1002     }
1003   }
1004 }
1005
1006 void OpEmitter::genNamedRegionGetters() {
1007   unsigned numRegions = op.getNumRegions();
1008   for (unsigned i = 0; i < numRegions; ++i) {
1009     const auto &region = op.getRegion(i);
1010     if (region.name.empty())
1011       continue;
1012
1013     // Generate the accessors for a variadic region.
1014     if (region.isVariadic()) {
1015       auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef<Region>",
1016                                           region.name);
1017       m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1018                            i);
1019       continue;
1020     }
1021
1022     auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
1023     m->body() << formatv("  return (*this)->getRegion({0});", i);
1024   }
1025 }
1026
1027 void OpEmitter::genNamedSuccessorGetters() {
1028   unsigned numSuccessors = op.getNumSuccessors();
1029   for (unsigned i = 0; i < numSuccessors; ++i) {
1030     const NamedSuccessor &successor = op.getSuccessor(i);
1031     if (successor.name.empty())
1032       continue;
1033
1034     // Generate the accessors for a variadic successor list.
1035     if (successor.isVariadic()) {
1036       auto *m =
1037           opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
1038       m->body() << formatv(
1039           "  return {std::next((*this)->successor_begin(), {0}), "
1040           "(*this)->successor_end()};",
1041           i);
1042       continue;
1043     }
1044
1045     auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
1046     m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1047   }
1048 }
1049
1050 static bool canGenerateUnwrappedBuilder(Operator &op) {
1051   // If this op does not have native attributes at all, return directly to avoid
1052   // redefining builders.
1053   if (op.getNumNativeAttributes() == 0)
1054     return false;
1055
1056   bool canGenerate = false;
1057   // We are generating builders that take raw values for attributes. We need to
1058   // make sure the native attributes have a meaningful "unwrapped" value type
1059   // different from the wrapped mlir::Attribute type to avoid redefining
1060   // builders. This checks for the op has at least one such native attribute.
1061   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1062     NamedAttribute &namedAttr = op.getAttribute(i);
1063     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1064       canGenerate = true;
1065       break;
1066     }
1067   }
1068   return canGenerate;
1069 }
1070
1071 static bool canInferType(Operator &op) {
1072   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
1073          op.getNumRegions() == 0;
1074 }
1075
1076 void OpEmitter::genSeparateArgParamBuilder() {
1077   SmallVector<AttrParamKind, 2> attrBuilderType;
1078   attrBuilderType.push_back(AttrParamKind::WrappedAttr);
1079   if (canGenerateUnwrappedBuilder(op))
1080     attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
1081
1082   // Emit with separate builders with or without unwrapped attributes and/or
1083   // inferring result type.
1084   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
1085                   bool inferType) {
1086     llvm::SmallVector<OpMethodParameter, 4> paramList;
1087     llvm::SmallVector<std::string, 4> resultNames;
1088     buildParamList(paramList, resultNames, paramKind, attrType);
1089
1090     auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1091                                         std::move(paramList));
1092     // If the builder is redundant, skip generating the method.
1093     if (!m)
1094       return;
1095     auto &body = m->body();
1096     genCodeForAddingArgAndRegionForBuilder(
1097         body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
1098
1099     // Push all result types to the operation state
1100
1101     if (inferType) {
1102       // Generate builder that infers type too.
1103       // TODO: Subsume this with general checking if type can be
1104       // inferred automatically.
1105       // TODO: Expand to handle regions.
1106       body << formatv(R"(
1107         ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1108         if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1109                       {1}.location, {1}.operands,
1110                       {1}.attributes.getDictionary({1}.getContext()),
1111                       /*regions=*/{{}, inferredReturnTypes)))
1112           {1}.addTypes(inferredReturnTypes);
1113         else
1114           ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1115                       opClass.getClassName(), builderOpState);
1116       return;
1117     }
1118
1119     switch (paramKind) {
1120     case TypeParamKind::None:
1121       return;
1122     case TypeParamKind::Separate:
1123       for (int i = 0, e = op.getNumResults(); i < e; ++i) {
1124         if (op.getResult(i).isOptional())
1125           body << "  if (" << resultNames[i] << ")\n  ";
1126         body << "  " << builderOpState << ".addTypes(" << resultNames[i]
1127              << ");\n";
1128       }
1129       return;
1130     case TypeParamKind::Collective: {
1131       int numResults = op.getNumResults();
1132       int numVariadicResults = op.getNumVariableLengthResults();
1133       int numNonVariadicResults = numResults - numVariadicResults;
1134       bool hasVariadicResult = numVariadicResults != 0;
1135
1136       // Avoid emitting "resultTypes.size() >= 0u" which is always true.
1137       if (!(hasVariadicResult && numNonVariadicResults == 0))
1138         body << "  "
1139              << "assert(resultTypes.size() "
1140              << (hasVariadicResult ? ">=" : "==") << " "
1141              << numNonVariadicResults
1142              << "u && \"mismatched number of results\");\n";
1143       body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1144     }
1145       return;
1146     }
1147     llvm_unreachable("unhandled TypeParamKind");
1148   };
1149
1150   // Some of the build methods generated here may be ambiguous, but TableGen's
1151   // ambiguous function detection will elide those ones.
1152   for (auto attrType : attrBuilderType) {
1153     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
1154     if (canInferType(op))
1155       emit(attrType, TypeParamKind::None, /*inferType=*/true);
1156     emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
1157   }
1158 }
1159
1160 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1161   int numResults = op.getNumResults();
1162
1163   // Signature
1164   llvm::SmallVector<OpMethodParameter, 4> paramList;
1165   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1166   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1167   paramList.emplace_back("::mlir::ValueRange", "operands");
1168   // Provide default value for `attributes` when its the last parameter
1169   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1170   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1171                          "attributes", attributesDefaultValue);
1172   if (op.getNumVariadicRegions())
1173     paramList.emplace_back("unsigned", "numRegions");
1174
1175   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1176                                       std::move(paramList));
1177   // If the builder is redundant, skip generating the method
1178   if (!m)
1179     return;
1180   auto &body = m->body();
1181
1182   // Operands
1183   body << "  " << builderOpState << ".addOperands(operands);\n";
1184
1185   // Attributes
1186   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1187
1188   // Create the correct number of regions
1189   if (int numRegions = op.getNumRegions()) {
1190     body << llvm::formatv(
1191         "  for (unsigned i = 0; i != {0}; ++i)\n",
1192         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1193     body << "    (void)" << builderOpState << ".addRegion();\n";
1194   }
1195
1196   // Result types
1197   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1198   body << "  " << builderOpState << ".addTypes({"
1199        << llvm::join(resultTypes, ", ") << "});\n\n";
1200 }
1201
1202 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1203   // TODO: Expand to support regions.
1204   SmallVector<OpMethodParameter, 4> paramList;
1205   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1206   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1207   paramList.emplace_back("::mlir::ValueRange", "operands");
1208   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1209                          "attributes", "{}");
1210   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1211                                       std::move(paramList));
1212   // If the builder is redundant, skip generating the method
1213   if (!m)
1214     return;
1215   auto &body = m->body();
1216
1217   int numResults = op.getNumResults();
1218   int numVariadicResults = op.getNumVariableLengthResults();
1219   int numNonVariadicResults = numResults - numVariadicResults;
1220
1221   int numOperands = op.getNumOperands();
1222   int numVariadicOperands = op.getNumVariableLengthOperands();
1223   int numNonVariadicOperands = numOperands - numVariadicOperands;
1224
1225   // Operands
1226   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1227     body << "  assert(operands.size()"
1228          << (numVariadicOperands != 0 ? " >= " : " == ")
1229          << numNonVariadicOperands
1230          << "u && \"mismatched number of parameters\");\n";
1231   body << "  " << builderOpState << ".addOperands(operands);\n";
1232   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1233
1234   // Create the correct number of regions
1235   if (int numRegions = op.getNumRegions()) {
1236     body << llvm::formatv(
1237         "  for (unsigned i = 0; i != {0}; ++i)\n",
1238         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1239     body << "    (void)" << builderOpState << ".addRegion();\n";
1240   }
1241
1242   // Result types
1243   body << formatv(R"(
1244     ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1245     if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1246                   {1}.location, operands,
1247                   {1}.attributes.getDictionary({1}.getContext()),
1248                   /*regions=*/{{}, inferredReturnTypes))) {{)",
1249                   opClass.getClassName(), builderOpState);
1250   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1251     body << "  assert(inferredReturnTypes.size()"
1252          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1253          << "u && \"mismatched number of return types\");\n";
1254   body << "      " << builderOpState << ".addTypes(inferredReturnTypes);";
1255
1256   body << formatv(R"(
1257     } else
1258       ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1259                   opClass.getClassName(), builderOpState);
1260 }
1261
1262 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1263   llvm::SmallVector<OpMethodParameter, 4> paramList;
1264   llvm::SmallVector<std::string, 4> resultNames;
1265   buildParamList(paramList, resultNames, TypeParamKind::None);
1266
1267   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1268                                       std::move(paramList));
1269   // If the builder is redundant, skip generating the method
1270   if (!m)
1271     return;
1272   auto &body = m->body();
1273   genCodeForAddingArgAndRegionForBuilder(body);
1274
1275   auto numResults = op.getNumResults();
1276   if (numResults == 0)
1277     return;
1278
1279   // Push all result types to the operation state
1280   const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1281   std::string resultType =
1282       formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1283   body << "  " << builderOpState << ".addTypes({" << resultType;
1284   for (int i = 1; i != numResults; ++i)
1285     body << ", " << resultType;
1286   body << "});\n\n";
1287 }
1288
1289 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1290   SmallVector<OpMethodParameter, 4> paramList;
1291   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1292   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1293   paramList.emplace_back("::mlir::ValueRange", "operands");
1294   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1295                          "attributes", "{}");
1296   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1297                                       std::move(paramList));
1298   // If the builder is redundant, skip generating the method
1299   if (!m)
1300     return;
1301
1302   auto &body = m->body();
1303
1304   // Push all result types to the operation state
1305   std::string resultType;
1306   const auto &namedAttr = op.getAttribute(0);
1307
1308   body << "  for (auto attr : attributes) {\n";
1309   body << "    if (attr.first != \"" << namedAttr.name << "\") continue;\n";
1310   if (namedAttr.attr.isTypeAttr()) {
1311     resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1312   } else {
1313     resultType = "attr.second.getType()";
1314   }
1315
1316   // Operands
1317   body << "  " << builderOpState << ".addOperands(operands);\n";
1318
1319   // Attributes
1320   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1321
1322   // Result types
1323   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1324   body << "    " << builderOpState << ".addTypes({"
1325        << llvm::join(resultTypes, ", ") << "});\n";
1326   body << "  }\n";
1327 }
1328
1329 /// Returns a signature of the builder as defined by a dag-typed initializer.
1330 /// Updates the context `fctx` to enable replacement of $_builder and $_state
1331 /// in the body. Reports errors at `loc`.
1332 static std::string builderSignatureFromDAG(const DagInit *init,
1333                                            ArrayRef<llvm::SMLoc> loc) {
1334   auto *defInit = dyn_cast<DefInit>(init->getOperator());
1335   if (!defInit || !defInit->getDef()->getName().equals("ins"))
1336     PrintFatalError(loc, "expected 'ins' in builders");
1337
1338   // Inject builder and state arguments.
1339   llvm::SmallVector<std::string, 8> arguments;
1340   arguments.reserve(init->getNumArgs() + 2);
1341   arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str());
1342   arguments.push_back(
1343       llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1344
1345   // Accept either a StringInit or a DefInit with two string values as dag
1346   // arguments. The former corresponds to the type, the latter to the type and
1347   // the default value. Similarly to C++, once an argument with a default value
1348   // is detected, the following arguments must have default values as well.
1349   bool seenDefaultValue = false;
1350   for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) {
1351     // If no name is provided, generate one.
1352     StringInit *argName = init->getArgName(i);
1353     std::string name =
1354         argName ? argName->getValue().str() : "odsArg" + std::to_string(i);
1355
1356     Init *argInit = init->getArg(i);
1357     StringRef type;
1358     std::string defaultValue;
1359     if (StringInit *strType = dyn_cast<StringInit>(argInit)) {
1360       type = strType->getValue();
1361     } else {
1362       const Record *typeAndDefaultValue = cast<DefInit>(argInit)->getDef();
1363       type = typeAndDefaultValue->getValueAsString("type");
1364       StringRef defaultValueRef =
1365           typeAndDefaultValue->getValueAsString("defaultValue");
1366       if (!defaultValueRef.empty()) {
1367         seenDefaultValue = true;
1368         defaultValue = llvm::formatv(" = {0}", defaultValueRef).str();
1369       }
1370     }
1371     if (seenDefaultValue && defaultValue.empty())
1372       PrintFatalError(loc,
1373                       "expected an argument with default value after other "
1374                       "arguments with default values");
1375     arguments.push_back(
1376         llvm::formatv("{0} {1}{2}", type, name, defaultValue).str());
1377   }
1378
1379   return llvm::join(arguments, ", ");
1380 }
1381
1382 void OpEmitter::genBuilder() {
1383   // Handle custom builders if provided.
1384   // TODO: Create wrapper class for OpBuilder to hide the native
1385   // TableGen API calls here.
1386   {
1387     auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
1388     if (listInit) {
1389       for (Init *init : listInit->getValues()) {
1390         Record *builderDef = cast<DefInit>(init)->getDef();
1391         std::string paramStr = builderSignatureFromDAG(
1392             builderDef->getValueAsDag("dagParams"), op.getLoc());
1393
1394         StringRef body = builderDef->getValueAsString("body");
1395         bool hasBody = !body.empty();
1396         OpMethod::Property properties =
1397             hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1398         auto *method =
1399             opClass.addMethodAndPrune("void", "build", properties, paramStr);
1400
1401         FmtContext fctx;
1402         fctx.withBuilder(builder);
1403         fctx.addSubst("_state", builderOpState);
1404         if (hasBody)
1405           method->body() << tgfmt(body, &fctx);
1406       }
1407     }
1408     if (op.skipDefaultBuilders()) {
1409       if (!listInit || listInit->empty())
1410         PrintFatalError(
1411             op.getLoc(),
1412             "default builders are skipped and no custom builders provided");
1413       return;
1414     }
1415   }
1416
1417   // Generate default builders that requires all result type, operands, and
1418   // attributes as parameters.
1419
1420   // We generate three classes of builders here:
1421   // 1. one having a stand-alone parameter for each operand / attribute, and
1422   genSeparateArgParamBuilder();
1423   // 2. one having an aggregated parameter for all result types / operands /
1424   //    attributes, and
1425   genCollectiveParamBuilder();
1426   // 3. one having a stand-alone parameter for each operand and attribute,
1427   //    use the first operand or attribute's type as all result types
1428   //    to facilitate different call patterns.
1429   if (op.getNumVariableLengthResults() == 0) {
1430     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1431       genUseOperandAsResultTypeSeparateParamBuilder();
1432       genUseOperandAsResultTypeCollectiveParamBuilder();
1433     }
1434     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1435       genUseAttrAsResultTypeBuilder();
1436   }
1437 }
1438
1439 void OpEmitter::genCollectiveParamBuilder() {
1440   int numResults = op.getNumResults();
1441   int numVariadicResults = op.getNumVariableLengthResults();
1442   int numNonVariadicResults = numResults - numVariadicResults;
1443
1444   int numOperands = op.getNumOperands();
1445   int numVariadicOperands = op.getNumVariableLengthOperands();
1446   int numNonVariadicOperands = numOperands - numVariadicOperands;
1447
1448   SmallVector<OpMethodParameter, 4> paramList;
1449   paramList.emplace_back("::mlir::OpBuilder &", "");
1450   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1451   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1452   paramList.emplace_back("::mlir::ValueRange", "operands");
1453   // Provide default value for `attributes` when its the last parameter
1454   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1455   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1456                          "attributes", attributesDefaultValue);
1457   if (op.getNumVariadicRegions())
1458     paramList.emplace_back("unsigned", "numRegions");
1459
1460   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1461                                       std::move(paramList));
1462   // If the builder is redundant, skip generating the method
1463   if (!m)
1464     return;
1465   auto &body = m->body();
1466
1467   // Operands
1468   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1469     body << "  assert(operands.size()"
1470          << (numVariadicOperands != 0 ? " >= " : " == ")
1471          << numNonVariadicOperands
1472          << "u && \"mismatched number of parameters\");\n";
1473   body << "  " << builderOpState << ".addOperands(operands);\n";
1474
1475   // Attributes
1476   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1477
1478   // Create the correct number of regions
1479   if (int numRegions = op.getNumRegions()) {
1480     body << llvm::formatv(
1481         "  for (unsigned i = 0; i != {0}; ++i)\n",
1482         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1483     body << "    (void)" << builderOpState << ".addRegion();\n";
1484   }
1485
1486   // Result types
1487   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1488     body << "  assert(resultTypes.size()"
1489          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1490          << "u && \"mismatched number of return types\");\n";
1491   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1492
1493   // Generate builder that infers type too.
1494   // TODO: Expand to handle regions and successors.
1495   if (canInferType(op) && op.getNumSuccessors() == 0)
1496     genInferredTypeCollectiveParamBuilder();
1497 }
1498
1499 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
1500                                SmallVectorImpl<std::string> &resultTypeNames,
1501                                TypeParamKind typeParamKind,
1502                                AttrParamKind attrParamKind) {
1503   resultTypeNames.clear();
1504   auto numResults = op.getNumResults();
1505   resultTypeNames.reserve(numResults);
1506
1507   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1508   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1509
1510   switch (typeParamKind) {
1511   case TypeParamKind::None:
1512     break;
1513   case TypeParamKind::Separate: {
1514     // Add parameters for all return types
1515     for (int i = 0; i < numResults; ++i) {
1516       const auto &result = op.getResult(i);
1517       std::string resultName = std::string(result.name);
1518       if (resultName.empty())
1519         resultName = std::string(formatv("resultType{0}", i));
1520
1521       StringRef type =
1522           result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1523       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1524       if (result.isOptional())
1525         properties = OpMethodParameter::PP_Optional;
1526
1527       paramList.emplace_back(type, resultName, properties);
1528       resultTypeNames.emplace_back(std::move(resultName));
1529     }
1530   } break;
1531   case TypeParamKind::Collective: {
1532     paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1533     resultTypeNames.push_back("resultTypes");
1534   } break;
1535   }
1536
1537   // Add parameters for all arguments (operands and attributes).
1538
1539   int numOperands = 0;
1540   int numAttrs = 0;
1541
1542   int defaultValuedAttrStartIndex = op.getNumArgs();
1543   if (attrParamKind == AttrParamKind::UnwrappedValue) {
1544     // Calculate the start index from which we can attach default values in the
1545     // builder declaration.
1546     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1547       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1548       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1549         break;
1550
1551       if (!canUseUnwrappedRawValue(namedAttr->attr))
1552         break;
1553
1554       // Creating an APInt requires us to provide bitwidth, value, and
1555       // signedness, which is complicated compared to others. Similarly
1556       // for APFloat.
1557       // TODO: Adjust the 'returnType' field of such attributes
1558       // to support them.
1559       StringRef retType = namedAttr->attr.getReturnType();
1560       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1561         break;
1562
1563       defaultValuedAttrStartIndex = i;
1564     }
1565   }
1566
1567   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1568     auto argument = op.getArg(i);
1569     if (argument.is<tblgen::NamedTypeConstraint *>()) {
1570       const auto &operand = op.getOperand(numOperands);
1571       StringRef type =
1572           operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1573       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1574       if (operand.isOptional())
1575         properties = OpMethodParameter::PP_Optional;
1576
1577       paramList.emplace_back(type, getArgumentName(op, numOperands),
1578                              properties);
1579       ++numOperands;
1580     } else {
1581       const auto &namedAttr = op.getAttribute(numAttrs);
1582       const auto &attr = namedAttr.attr;
1583
1584       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1585       if (attr.isOptional())
1586         properties = OpMethodParameter::PP_Optional;
1587
1588       StringRef type;
1589       switch (attrParamKind) {
1590       case AttrParamKind::WrappedAttr:
1591         type = attr.getStorageType();
1592         break;
1593       case AttrParamKind::UnwrappedValue:
1594         if (canUseUnwrappedRawValue(attr))
1595           type = attr.getReturnType();
1596         else
1597           type = attr.getStorageType();
1598         break;
1599       }
1600
1601       std::string defaultValue;
1602       // Attach default value if requested and possible.
1603       if (attrParamKind == AttrParamKind::UnwrappedValue &&
1604           i >= defaultValuedAttrStartIndex) {
1605         bool isString = attr.getReturnType() == "::llvm::StringRef";
1606         if (isString)
1607           defaultValue.append("\"");
1608         defaultValue += attr.getDefaultValue();
1609         if (isString)
1610           defaultValue.append("\"");
1611       }
1612       paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1613       ++numAttrs;
1614     }
1615   }
1616
1617   /// Insert parameters for each successor.
1618   for (const NamedSuccessor &succ : op.getSuccessors()) {
1619     StringRef type =
1620         succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1621     paramList.emplace_back(type, succ.name);
1622   }
1623
1624   /// Insert parameters for variadic regions.
1625   for (const NamedRegion &region : op.getRegions())
1626     if (region.isVariadic())
1627       paramList.emplace_back("unsigned",
1628                              llvm::formatv("{0}Count", region.name).str());
1629 }
1630
1631 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1632                                                        bool isRawValueAttr) {
1633   // Push all operands to the result.
1634   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1635     std::string argName = getArgumentName(op, i);
1636     if (op.getOperand(i).isOptional())
1637       body << "  if (" << argName << ")\n  ";
1638     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
1639   }
1640
1641   // If the operation has the operand segment size attribute, add it here.
1642   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1643     body << "  " << builderOpState
1644          << ".addAttribute(\"operand_segment_sizes\", "
1645             "odsBuilder.getI32VectorAttr({";
1646     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1647       if (op.getOperand(i).isOptional())
1648         body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1649       else if (op.getOperand(i).isVariadic())
1650         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1651       else
1652         body << "1";
1653     });
1654     body << "}));\n";
1655   }
1656
1657   // Push all attributes to the result.
1658   for (const auto &namedAttr : op.getAttributes()) {
1659     auto &attr = namedAttr.attr;
1660     if (!attr.isDerivedAttr()) {
1661       bool emitNotNullCheck = attr.isOptional();
1662       if (emitNotNullCheck) {
1663         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
1664       }
1665       if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1666         // If this is a raw value, then we need to wrap it in an Attribute
1667         // instance.
1668         FmtContext fctx;
1669         fctx.withBuilder("odsBuilder");
1670
1671         std::string builderTemplate =
1672             std::string(attr.getConstBuilderTemplate());
1673
1674         // For StringAttr, its constant builder call will wrap the input in
1675         // quotes, which is correct for normal string literals, but incorrect
1676         // here given we use function arguments. So we need to strip the
1677         // wrapping quotes.
1678         if (StringRef(builderTemplate).contains("\"$0\""))
1679           builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1680
1681         std::string value =
1682             std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1683         body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
1684                         namedAttr.name, value);
1685       } else {
1686         body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
1687                         namedAttr.name);
1688       }
1689       if (emitNotNullCheck) {
1690         body << "  }\n";
1691       }
1692     }
1693   }
1694
1695   // Create the correct number of regions.
1696   for (const NamedRegion &region : op.getRegions()) {
1697     if (region.isVariadic())
1698       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
1699                       region.name);
1700
1701     body << "  (void)" << builderOpState << ".addRegion();\n";
1702   }
1703
1704   // Push all successors to the result.
1705   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1706     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
1707                     namedSuccessor.name);
1708   }
1709 }
1710
1711 void OpEmitter::genCanonicalizerDecls() {
1712   if (!def.getValueAsBit("hasCanonicalizer"))
1713     return;
1714
1715   SmallVector<OpMethodParameter, 2> paramList;
1716   paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
1717   paramList.emplace_back("::mlir::MLIRContext *", "context");
1718   opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
1719                             OpMethod::MP_StaticDeclaration,
1720                             std::move(paramList));
1721 }
1722
1723 void OpEmitter::genFolderDecls() {
1724   bool hasSingleResult =
1725       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1726
1727   if (def.getValueAsBit("hasFolder")) {
1728     if (hasSingleResult) {
1729       opClass.addMethodAndPrune(
1730           "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1731           "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1732     } else {
1733       SmallVector<OpMethodParameter, 2> paramList;
1734       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1735       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1736                              "results");
1737       opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1738                                 OpMethod::MP_Declaration, std::move(paramList));
1739     }
1740   }
1741 }
1742
1743 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
1744   auto interface = opTrait->getOpInterface();
1745
1746   // Get the set of methods that should always be declared.
1747   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1748   llvm::StringSet<> alwaysDeclaredMethods;
1749   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1750                                alwaysDeclaredMethodsVec.end());
1751
1752   for (const InterfaceMethod &method : interface.getMethods()) {
1753     // Don't declare if the method has a body.
1754     if (method.getBody())
1755       continue;
1756     // Don't declare if the method has a default implementation and the op
1757     // didn't request that it always be declared.
1758     if (method.getDefaultImplementation() &&
1759         !alwaysDeclaredMethods.count(method.getName()))
1760       continue;
1761     genOpInterfaceMethod(method);
1762   }
1763 }
1764
1765 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1766                                           bool declaration) {
1767   SmallVector<OpMethodParameter, 4> paramList;
1768   for (const InterfaceMethod::Argument &arg : method.getArguments())
1769     paramList.emplace_back(arg.type, arg.name);
1770
1771   auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1772   if (declaration)
1773     properties =
1774         static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1775   return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1776                                    properties, std::move(paramList));
1777 }
1778
1779 void OpEmitter::genOpInterfaceMethods() {
1780   for (const auto &trait : op.getTraits()) {
1781     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1782       if (opTrait->shouldDeclareMethods())
1783         genOpInterfaceMethods(opTrait);
1784   }
1785 }
1786
1787 void OpEmitter::genSideEffectInterfaceMethods() {
1788   enum EffectKind { Operand, Result, Symbol, Static };
1789   struct EffectLocation {
1790     /// The effect applied.
1791     SideEffect effect;
1792
1793     /// The index if the kind is not static.
1794     unsigned index : 30;
1795
1796     /// The kind of the location.
1797     unsigned kind : 2;
1798   };
1799
1800   StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
1801   auto resolveDecorators = [&](Operator::var_decorator_range decorators,
1802                                unsigned index, unsigned kind) {
1803     for (auto decorator : decorators)
1804       if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
1805         opClass.addTrait(effect->getInterfaceTrait());
1806         interfaceEffects[effect->getBaseEffectName()].push_back(
1807             EffectLocation{*effect, index, kind});
1808       }
1809   };
1810
1811   // Collect effects that were specified via:
1812   /// Traits.
1813   for (const auto &trait : op.getTraits()) {
1814     const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
1815     if (!opTrait)
1816       continue;
1817     auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
1818     for (auto decorator : opTrait->getEffects())
1819       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
1820                                        /*index=*/0, EffectKind::Static});
1821   }
1822   /// Attributes and Operands.
1823   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
1824     Argument arg = op.getArg(i);
1825     if (arg.is<NamedTypeConstraint *>()) {
1826       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
1827       ++operandIt;
1828       continue;
1829     }
1830     const NamedAttribute *attr = arg.get<NamedAttribute *>();
1831     if (attr->attr.getBaseAttr().isSymbolRefAttr())
1832       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
1833   }
1834   /// Results.
1835   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
1836     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
1837
1838   // The code used to add an effect instance.
1839   // {0}: The effect class.
1840   // {1}: Optional value or symbol reference.
1841   // {1}: The resource class.
1842   const char *addEffectCode =
1843       "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
1844
1845   for (auto &it : interfaceEffects) {
1846     // Generate the 'getEffects' method.
1847     std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
1848                                      "SideEffects::EffectInstance<{0}>> &",
1849                                      it.first())
1850                            .str();
1851     auto *getEffects =
1852         opClass.addMethodAndPrune("void", "getEffects", type, "effects");
1853     auto &body = getEffects->body();
1854
1855     // Add effect instances for each of the locations marked on the operation.
1856     for (auto &location : it.second) {
1857       StringRef effect = location.effect.getName();
1858       StringRef resource = location.effect.getResource();
1859       if (location.kind == EffectKind::Static) {
1860         // A static instance has no attached value.
1861         body << llvm::formatv(addEffectCode, effect, "", resource).str();
1862       } else if (location.kind == EffectKind::Symbol) {
1863         // A symbol reference requires adding the proper attribute.
1864         const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
1865         if (attr->attr.isOptional()) {
1866           body << "  if (auto symbolRef = " << attr->name << "Attr())\n  "
1867                << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
1868                       .str();
1869         } else {
1870           body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
1871                                 resource)
1872                       .str();
1873         }
1874       } else {
1875         // Otherwise this is an operand/result, so we need to attach the Value.
1876         body << "  for (::mlir::Value value : getODS"
1877              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
1878              << "(" << location.index << "))\n  "
1879              << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
1880       }
1881     }
1882   }
1883 }
1884
1885 void OpEmitter::genTypeInterfaceMethods() {
1886   if (!op.allResultTypesKnown())
1887     return;
1888   // Generate 'inferReturnTypes' method declaration using the interface method
1889   // declared in 'InferTypeOpInterface' op interface.
1890   const auto *trait = dyn_cast<InterfaceOpTrait>(
1891       op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1892   auto interface = trait->getOpInterface();
1893   OpMethod *method = [&]() -> OpMethod * {
1894     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1895       if (interfaceMethod.getName() == "inferReturnTypes") {
1896         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1897       }
1898     }
1899     assert(0 && "unable to find inferReturnTypes interface method");
1900     return nullptr;
1901   }();
1902   auto &body = method->body();
1903   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
1904
1905   FmtContext fctx;
1906   fctx.withBuilder("odsBuilder");
1907   body << "  ::mlir::Builder odsBuilder(context);\n";
1908
1909   auto emitType =
1910       [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
1911     if (type.isArg()) {
1912       auto argIndex = type.getArg();
1913       assert(!op.getArg(argIndex).is<NamedAttribute *>());
1914       auto arg = op.getArgToOperandOrAttribute(argIndex);
1915       if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
1916         return body << "operands[" << arg.operandOrAttributeIndex()
1917                     << "].getType()";
1918       return body << "attributes[" << arg.operandOrAttributeIndex()
1919                   << "].getType()";
1920     } else {
1921       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
1922     }
1923   };
1924
1925   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
1926     body << "  inferredReturnTypes[" << i << "] = ";
1927     auto types = op.getSameTypeAsResult(i);
1928     emitType(types[0]) << ";\n";
1929     if (types.size() == 1)
1930       continue;
1931     // TODO: We could verify equality here, but skipping that for verification.
1932   }
1933   body << "  return ::mlir::success();";
1934 }
1935
1936 void OpEmitter::genParser() {
1937   if (!hasStringAttribute(def, "parser") ||
1938       hasStringAttribute(def, "assemblyFormat"))
1939     return;
1940
1941   SmallVector<OpMethodParameter, 2> paramList;
1942   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1943   paramList.emplace_back("::mlir::OperationState &", "result");
1944   auto *method =
1945       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1946                                 OpMethod::MP_Static, std::move(paramList));
1947
1948   FmtContext fctx;
1949   fctx.addSubst("cppClass", opClass.getClassName());
1950   auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
1951   method->body() << "  " << tgfmt(parser, &fctx);
1952 }
1953
1954 void OpEmitter::genPrinter() {
1955   if (hasStringAttribute(def, "assemblyFormat"))
1956     return;
1957
1958   auto valueInit = def.getValueInit("printer");
1959   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1960   if (!stringInit)
1961     return;
1962
1963   auto *method =
1964       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
1965   FmtContext fctx;
1966   fctx.addSubst("cppClass", opClass.getClassName());
1967   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1968   method->body() << "  " << tgfmt(printer, &fctx);
1969 }
1970
1971 void OpEmitter::genVerifier() {
1972   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
1973   auto &body = method->body();
1974   body << "  if (failed(" << op.getAdaptorName()
1975        << "(*this).verify((*this)->getLoc()))) "
1976        << "return ::mlir::failure();\n";
1977
1978   auto *valueInit = def.getValueInit("verifier");
1979   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1980   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
1981   populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
1982                         "this->getODSResults", verifyCtx);
1983
1984   genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
1985                        /*emitVerificationRequiringOp=*/true, verifyCtx, body);
1986   genOperandResultVerifier(body, op.getOperands(), "operand");
1987   genOperandResultVerifier(body, op.getResults(), "result");
1988
1989   for (auto &trait : op.getTraits()) {
1990     if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
1991       body << tgfmt("  if (!($0))\n    "
1992                     "return emitOpError(\"failed to verify that $1\");\n",
1993                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
1994                     t->getDescription());
1995     }
1996   }
1997
1998   genRegionVerifier(body);
1999   genSuccessorVerifier(body);
2000
2001   if (hasCustomVerify) {
2002     FmtContext fctx;
2003     fctx.addSubst("cppClass", opClass.getClassName());
2004     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2005     body << "  " << tgfmt(printer, &fctx);
2006   } else {
2007     body << "  return ::mlir::success();\n";
2008   }
2009 }
2010
2011 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
2012                                          Operator::value_range values,
2013                                          StringRef valueKind) {
2014   FmtContext fctx;
2015
2016   body << "  {\n";
2017   body << "    unsigned index = 0; (void)index;\n";
2018
2019   for (auto staticValue : llvm::enumerate(values)) {
2020     bool hasPredicate = staticValue.value().hasPredicate();
2021     bool isOptional = staticValue.value().isOptional();
2022     if (!hasPredicate && !isOptional)
2023       continue;
2024     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
2025                     // Capitalize the first letter to match the function name
2026                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
2027                     staticValue.index());
2028
2029     // If the constraint is optional check that the value group has at most 1
2030     // value.
2031     if (isOptional) {
2032       body << formatv("    if (valueGroup{0}.size() > 1)\n"
2033                       "      return emitOpError(\"{1} group starting at #\") "
2034                       "<< index << \" requires 0 or 1 element, but found \" << "
2035                       "valueGroup{0}.size();\n",
2036                       staticValue.index(), valueKind);
2037     }
2038
2039     // Otherwise, if there is no predicate there is nothing left to do.
2040     if (!hasPredicate)
2041       continue;
2042     // Emit a loop to check all the dynamic values in the pack.
2043     StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2044         staticValue.value().constraint);
2045     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
2046          << ") {\n"
2047          << "      if (::mlir::failed(" << constraintFn
2048          << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2049          << "        return ::mlir::failure();\n"
2050          << "      ++index;\n"
2051          << "    }\n";
2052   }
2053
2054   body << "  }\n";
2055 }
2056
2057 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
2058   // If we have no regions, there is nothing more to do.
2059   unsigned numRegions = op.getNumRegions();
2060   if (numRegions == 0)
2061     return;
2062
2063   body << "{\n";
2064   body << "    unsigned index = 0; (void)index;\n";
2065
2066   for (unsigned i = 0; i < numRegions; ++i) {
2067     const auto &region = op.getRegion(i);
2068     if (region.constraint.getPredicate().isNull())
2069       continue;
2070
2071     body << "    for (::mlir::Region &region : ";
2072     body << formatv(region.isVariadic()
2073                         ? "{0}()"
2074                         : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
2075                           "->getRegion({1}))",
2076                     region.name, i);
2077     body << ") {\n";
2078     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
2079                             &verifyCtx.withSelf("region"))
2080                           .str();
2081
2082     body << formatv("      (void)region;\n"
2083                     "      if (!({0})) {\n        "
2084                     "return emitOpError(\"region #\") << index << \" {1}"
2085                     "failed to "
2086                     "verify constraint: {2}\";\n      }\n",
2087                     constraint,
2088                     region.name.empty() ? "" : "('" + region.name + "') ",
2089                     region.constraint.getDescription())
2090          << "      ++index;\n"
2091          << "    }\n";
2092   }
2093   body << "  }\n";
2094 }
2095
2096 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
2097   // If we have no successors, there is nothing more to do.
2098   unsigned numSuccessors = op.getNumSuccessors();
2099   if (numSuccessors == 0)
2100     return;
2101
2102   body << "{\n";
2103   body << "    unsigned index = 0; (void)index;\n";
2104
2105   for (unsigned i = 0; i < numSuccessors; ++i) {
2106     const auto &successor = op.getSuccessor(i);
2107     if (successor.constraint.getPredicate().isNull())
2108       continue;
2109
2110     if (successor.isVariadic()) {
2111       body << formatv("    for (::mlir::Block *successor : {0}()) {\n",
2112                       successor.name);
2113     } else {
2114       body << "    {\n";
2115       body << formatv("      ::mlir::Block *successor = {0}();\n",
2116                       successor.name);
2117     }
2118     auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
2119                             &verifyCtx.withSelf("successor"))
2120                           .str();
2121
2122     body << formatv("      (void)successor;\n"
2123                     "      if (!({0})) {\n        "
2124                     "return emitOpError(\"successor #\") << index << \"('{1}') "
2125                     "failed to "
2126                     "verify constraint: {2}\";\n      }\n",
2127                     constraint, successor.name,
2128                     successor.constraint.getDescription())
2129          << "      ++index;\n"
2130          << "    }\n";
2131   }
2132   body << "  }\n";
2133 }
2134
2135 /// Add a size count trait to the given operation class.
2136 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2137                               int numTotal, int numVariadic) {
2138   if (numVariadic != 0) {
2139     if (numTotal == numVariadic)
2140       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2141     else
2142       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2143                        Twine(numTotal - numVariadic) + ">::Impl");
2144     return;
2145   }
2146   switch (numTotal) {
2147   case 0:
2148     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2149     break;
2150   case 1:
2151     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2152     break;
2153   default:
2154     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2155                      ">::Impl");
2156     break;
2157   }
2158 }
2159
2160 void OpEmitter::genTraits() {
2161   // Add region size trait.
2162   unsigned numRegions = op.getNumRegions();
2163   unsigned numVariadicRegions = op.getNumVariadicRegions();
2164   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2165
2166   // Add result size traits.
2167   int numResults = op.getNumResults();
2168   int numVariadicResults = op.getNumVariableLengthResults();
2169   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2170
2171   // For single result ops with a known specific type, generate a OneTypedResult
2172   // trait.
2173   if (numResults == 1 && numVariadicResults == 0) {
2174     auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2175     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2176   }
2177
2178   // Add successor size trait.
2179   unsigned numSuccessors = op.getNumSuccessors();
2180   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2181   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2182
2183   // Add variadic size trait and normal op traits.
2184   int numOperands = op.getNumOperands();
2185   int numVariadicOperands = op.getNumVariableLengthOperands();
2186
2187   // Add operand size trait.
2188   if (numVariadicOperands != 0) {
2189     if (numOperands == numVariadicOperands)
2190       opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2191     else
2192       opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2193                        Twine(numOperands - numVariadicOperands) + ">::Impl");
2194   } else {
2195     switch (numOperands) {
2196     case 0:
2197       opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2198       break;
2199     case 1:
2200       opClass.addTrait("::mlir::OpTrait::OneOperand");
2201       break;
2202     default:
2203       opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2204                        ">::Impl");
2205       break;
2206     }
2207   }
2208
2209   // Add the native and interface traits.
2210   for (const auto &trait : op.getTraits()) {
2211     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
2212       opClass.addTrait(opTrait->getTrait());
2213     else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
2214       opClass.addTrait(opTrait->getTrait());
2215   }
2216 }
2217
2218 void OpEmitter::genOpNameGetter() {
2219   auto *method = opClass.addMethodAndPrune(
2220       "::llvm::StringRef", "getOperationName", OpMethod::MP_Static);
2221   method->body() << "  return \"" << op.getOperationName() << "\";\n";
2222 }
2223
2224 void OpEmitter::genOpAsmInterface() {
2225   // If the user only has one results or specifically added the Asm trait,
2226   // then don't generate it for them. We specifically only handle multi result
2227   // operations, because the name of a single result in the common case is not
2228   // interesting(generally 'result'/'output'/etc.).
2229   // TODO: We could also add a flag to allow operations to opt in to this
2230   // generation, even if they only have a single operation.
2231   int numResults = op.getNumResults();
2232   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2233     return;
2234
2235   SmallVector<StringRef, 4> resultNames(numResults);
2236   for (int i = 0; i != numResults; ++i)
2237     resultNames[i] = op.getResultName(i);
2238
2239   // Don't add the trait if none of the results have a valid name.
2240   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2241     return;
2242   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2243
2244   // Generate the right accessor for the number of results.
2245   auto *method = opClass.addMethodAndPrune(
2246       "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2247   auto &body = method->body();
2248   for (int i = 0; i != numResults; ++i) {
2249     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2250          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2251          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2252          << resultNames[i] << "\");\n";
2253   }
2254 }
2255
2256 //===----------------------------------------------------------------------===//
2257 // OpOperandAdaptor emitter
2258 //===----------------------------------------------------------------------===//
2259
2260 namespace {
2261 // Helper class to emit Op operand adaptors to an output stream.  Operand
2262 // adaptors are wrappers around ArrayRef<Value> that provide named operand
2263 // getters identical to those defined in the Op.
2264 class OpOperandAdaptorEmitter {
2265 public:
2266   static void emitDecl(const Operator &op, raw_ostream &os);
2267   static void emitDef(const Operator &op, raw_ostream &os);
2268
2269 private:
2270   explicit OpOperandAdaptorEmitter(const Operator &op);
2271
2272   // Add verification function. This generates a verify method for the adaptor
2273   // which verifies all the op-independent attribute constraints.
2274   void addVerification();
2275
2276   const Operator &op;
2277   Class adaptor;
2278 };
2279 } // end namespace
2280
2281 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
2282     : op(op), adaptor(op.getAdaptorName()) {
2283   adaptor.newField("::mlir::ValueRange", "odsOperands");
2284   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
2285   const auto *attrSizedOperands =
2286       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2287   {
2288     SmallVector<OpMethodParameter, 2> paramList;
2289     paramList.emplace_back("::mlir::ValueRange", "values");
2290     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2291                            attrSizedOperands ? "" : "nullptr");
2292     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
2293
2294     constructor->addMemberInitializer("odsOperands", "values");
2295     constructor->addMemberInitializer("odsAttrs", "attrs");
2296   }
2297
2298   {
2299     auto *constructor = adaptor.addConstructorAndPrune(
2300         llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
2301     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2302     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2303   }
2304
2305   std::string sizeAttrInit =
2306       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
2307   generateNamedOperandGetters(op, adaptor, sizeAttrInit,
2308                               /*rangeType=*/"::mlir::ValueRange",
2309                               /*rangeBeginCall=*/"odsOperands.begin()",
2310                               /*rangeSizeCall=*/"odsOperands.size()",
2311                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2312
2313   FmtContext fctx;
2314   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2315
2316   auto emitAttr = [&](StringRef name, Attribute attr) {
2317     auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
2318     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
2319          << "\n  " << attr.getStorageType() << " attr = "
2320          << "odsAttrs.get(\"" << name << "\").";
2321     if (attr.hasDefaultValue() || attr.isOptional())
2322       body << "dyn_cast_or_null<";
2323     else
2324       body << "cast<";
2325     body << attr.getStorageType() << ">();\n";
2326
2327     if (attr.hasDefaultValue()) {
2328       // Use the default value if attribute is not set.
2329       // TODO: this is inefficient, we are recreating the attribute for every
2330       // call. This should be set instead.
2331       std::string defaultValue = std::string(
2332           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2333       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2334     }
2335     body << "  return attr;\n";
2336   };
2337
2338   for (auto &namedAttr : op.getAttributes()) {
2339     const auto &name = namedAttr.name;
2340     const auto &attr = namedAttr.attr;
2341     if (!attr.isDerivedAttr())
2342       emitAttr(name, attr);
2343   }
2344
2345   // Add verification function.
2346   addVerification();
2347 }
2348
2349 void OpOperandAdaptorEmitter::addVerification() {
2350   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2351                                            "::mlir::Location", "loc");
2352   auto &body = method->body();
2353
2354   const char *checkAttrSizedValueSegmentsCode = R"(
2355   {
2356     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2357     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2358     if (numElements != {1})
2359       return emitError(loc, "'{0}' attribute for specifying {2} segments "
2360                        "must have {1} elements, but got ") << numElements;
2361   }
2362   )";
2363
2364   // Verify a few traits first so that we can use
2365   // getODSOperands()/getODSResults() in the rest of the verifier.
2366   for (auto &trait : op.getTraits()) {
2367     if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
2368       if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
2369         body << formatv(checkAttrSizedValueSegmentsCode,
2370                         "operand_segment_sizes", op.getNumOperands(),
2371                         "operand");
2372       } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
2373         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2374                         op.getNumResults(), "result");
2375       }
2376     }
2377   }
2378
2379   FmtContext verifyCtx;
2380   populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2381                         "<no results should be generated>", verifyCtx);
2382   genAttributeVerifier(op, "odsAttrs.get",
2383                        Twine("emitError(loc, \"'") + op.getOperationName() +
2384                            "' op \"",
2385                        /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2386
2387   body << "  return ::mlir::success();";
2388 }
2389
2390 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2391   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2392 }
2393
2394 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2395   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2396 }
2397
2398 // Emits the opcode enum and op classes.
2399 static void emitOpClasses(const RecordKeeper &recordKeeper,
2400                           const std::vector<Record *> &defs, raw_ostream &os,
2401                           bool emitDecl) {
2402   // First emit forward declaration for each class, this allows them to refer
2403   // to each others in traits for example.
2404   if (emitDecl) {
2405     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2406     os << "#undef GET_OP_FWD_DEFINES\n";
2407     for (auto *def : defs) {
2408       Operator op(*def);
2409       NamespaceEmitter emitter(os, op.getDialect());
2410       os << "class " << op.getCppClassName() << ";\n";
2411     }
2412     os << "#endif\n\n";
2413   }
2414
2415   IfDefScope scope("GET_OP_CLASSES", os);
2416   if (defs.empty())
2417     return;
2418
2419   // Generate all of the locally instantiated methods first.
2420   StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
2421                                                       emitDecl);
2422   for (auto *def : defs) {
2423     Operator op(*def);
2424     NamespaceEmitter emitter(os, op.getDialect());
2425     if (emitDecl) {
2426       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2427       OpOperandAdaptorEmitter::emitDecl(op, os);
2428       OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2429     } else {
2430       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2431       OpOperandAdaptorEmitter::emitDef(op, os);
2432       OpEmitter::emitDef(op, os, staticVerifierEmitter);
2433     }
2434   }
2435 }
2436
2437 // Emits a comma-separated list of the ops.
2438 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2439   IfDefScope scope("GET_OP_LIST", os);
2440
2441   interleave(
2442       // TODO: We are constructing the Operator wrapper instance just for
2443       // getting it's qualified class name here. Reduce the overhead by having a
2444       // lightweight version of Operator class just for that purpose.
2445       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2446       [&os]() { os << ",\n"; });
2447 }
2448
2449 static std::string getOperationName(const Record &def) {
2450   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2451   auto opName = def.getValueAsString("opName");
2452   if (prefix.empty())
2453     return std::string(opName);
2454   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2455 }
2456
2457 static std::vector<Record *>
2458 getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2459                          StringRef className) {
2460   Record *classDef = recordKeeper.getClass(className);
2461   if (!classDef)
2462     PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2463
2464   llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2465   std::vector<Record *> defs;
2466   for (const auto &def : recordKeeper.getDefs()) {
2467     if (!def.second->isSubClassOf(classDef))
2468       continue;
2469     // Include if no include filter or include filter matches.
2470     if (!opIncFilter.empty() &&
2471         !includeRegex.match(getOperationName(*def.second)))
2472       continue;
2473     // Unless there is an exclude filter and it matches.
2474     if (!opExcFilter.empty() &&
2475         excludeRegex.match(getOperationName(*def.second)))
2476       continue;
2477     defs.push_back(def.second.get());
2478   }
2479
2480   return defs;
2481 }
2482
2483 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2484   emitSourceFileHeader("Op Declarations", os);
2485
2486   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2487   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2488
2489   return false;
2490 }
2491
2492 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2493   emitSourceFileHeader("Op Definitions", os);
2494
2495   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2496   emitOpList(defs, os);
2497   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2498
2499   return false;
2500 }
2501
2502 static mlir::GenRegistration
2503     genOpDecls("gen-op-decls", "Generate op declarations",
2504                [](const RecordKeeper &records, raw_ostream &os) {
2505                  return emitOpDecls(records, os);
2506                });
2507
2508 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2509                                        [](const RecordKeeper &records,
2510                                           raw_ostream &os) {
2511                                          return emitOpDefs(records, os);
2512                                        });