592e6cb774fbf6151d255ef83d3ad0d0326fa18a
[lldb.git] / mlir / tools / mlir-linalg-ods-gen / mlir-linalg-ods-gen.cpp
1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation for the Tensor Comprehension-inspired
10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a
11 // mathematical form.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/ToolOutputFile.h"
27
28 #define DEBUG_TYPE "linalg-ods-gen"
29
30 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
31
32 // Commandline options
33 static llvm::cl::opt<std::string>
34     inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
35                   llvm::cl::init("-"), llvm::cl::value_desc("filename"));
36
37 static llvm::cl::opt<std::string>
38     outputFilename("o", llvm::cl::desc("Output filename"),
39                    llvm::cl::value_desc("filename"), llvm::cl::init("-"));
40
41 static llvm::cl::opt<bool>
42     genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
43                llvm::cl::cat(ODSGenCat));
44
45 static llvm::cl::opt<bool>
46     genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
47                llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
48
49 static llvm::cl::opt<bool> testEmitIncludeTdHeader(
50     "test-emit-include-td-header",
51     llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
52                    "tblgen testing."),
53     llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
54
55 using llvm::SetVector;
56 using llvm::SMLoc;
57 using llvm::StringRef;
58 using llvm::Twine;
59
60 using namespace mlir;
61
62 //===----------------------------------------------------------------------===//
63 // Lexer
64 //===----------------------------------------------------------------------===//
65
66 namespace {
67 /// This class represents a specific token in the input format.
68 class Token {
69 public:
70   enum class Kind {
71     // Markers.
72     eof,
73     error,
74
75     // Tokens with no info.
76     colon,
77     comma,
78     equal,
79     gt,
80     l_brace,
81     l_paren,
82     lt,
83     minus,
84     plus,
85     r_brace,
86     r_paren,
87     semicolon,
88     star,
89
90     // Keywords.
91     kw_def,
92     FIRST_KEYWORD = kw_def,
93     kw_ods_def,
94     kw_floordiv,
95     kw_ceildiv,
96     kw_mod,
97     LAST_KEYWORD = kw_mod,
98
99     // String valued tokens.
100     id,
101     integer,
102   };
103
104   Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
105
106   /// Return the bytes that make up this token.
107   StringRef getSpelling() const { return spelling; }
108
109   /// Return the kind of this token.
110   Kind getKind() const { return kind; }
111
112   /// Return a location for this token.
113   llvm::SMLoc getLoc() const {
114     return llvm::SMLoc::getFromPointer(spelling.data());
115   }
116
117   /// Return if this token is a keyword.
118   bool isKeyword() const {
119     return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
120   }
121   bool is(Kind k) const { return kind == k; }
122   bool isNot(Kind k) const { return kind != k; }
123
124   Optional<uint64_t> getUInt64IntegerValue() const {
125     bool isHex = spelling.size() > 1 && spelling[1] == 'x';
126
127     uint64_t result = 0;
128     if (spelling.getAsInteger(isHex ? 0 : 10, result))
129       return None;
130     return result;
131   }
132
133 private:
134   /// Discriminator that indicates the kind of token this is.
135   Kind kind;
136
137   /// A reference to the entire token contents; this is always a pointer into
138   /// a memory buffer owned by the source manager.
139   StringRef spelling;
140 };
141
142 /// This class implements a simple lexer.
143 class Lexer {
144 public:
145   Lexer(llvm::SourceMgr &mgr);
146
147   /// Lex the next token and return it.
148   Token lexToken();
149
150   /// Emit an error to the lexer with the given location and message.
151   Token emitError(llvm::SMLoc loc, const Twine &msg);
152   Token emitError(const char *loc, const Twine &msg);
153
154 private:
155   Token formToken(Token::Kind kind, const char *tokStart) {
156     return Token(kind, StringRef(tokStart, curPtr - tokStart));
157   }
158
159   /// Return the next character in the stream.
160   int getNextChar();
161
162   /// Lex an identifier.
163   Token lexIdentifier(const char *tokStart);
164
165   // Lex an integer.
166   Token lexInteger(const char *tokStart);
167
168   // Skip a comment line, starting with a '//'.
169   void skipComment();
170
171   llvm::SourceMgr &srcMgr;
172   StringRef curBuffer;
173   const char *curPtr;
174 };
175 } // end anonymous namespace
176
177 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
178   curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
179   curPtr = curBuffer.begin();
180 }
181
182 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
183   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
184   return formToken(Token::Kind::error, loc.getPointer());
185 }
186 Token Lexer::emitError(const char *loc, const Twine &msg) {
187   return emitError(llvm::SMLoc::getFromPointer(loc), msg);
188 }
189
190 int Lexer::getNextChar() {
191   char curChar = *curPtr++;
192   switch (curChar) {
193   default:
194     return (unsigned char)curChar;
195   case 0: {
196     // A nul character in the stream is either the end of the current buffer
197     // or a random nul in the file. Disambiguate that here.
198     if (curPtr - 1 != curBuffer.end())
199       return 0;
200
201     // Otherwise, return end of file.
202     --curPtr;
203     return EOF;
204   }
205   case '\n':
206   case '\r':
207     // Handle the newline character by ignoring it and incrementing the line
208     // count. However, be careful about 'dos style' files with \n\r in them.
209     // Only treat a \n\r or \r\n as a single line.
210     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
211       ++curPtr;
212     return '\n';
213   }
214 }
215
216 Token Lexer::lexToken() {
217   while (true) {
218     const char *tokStart = curPtr;
219
220     // This always consumes at least one character.
221     int curChar = getNextChar();
222     switch (curChar) {
223     default:
224       // Handle identifiers: [a-zA-Z_]
225       if (isalpha(curChar) || curChar == '_')
226         return lexIdentifier(tokStart);
227
228       // Handle integers: [0-9]
229       if (isdigit(curChar))
230         return lexInteger(tokStart);
231
232       // Unknown character, emit an error.
233       return emitError(tokStart, "unexpected character");
234
235     case EOF:
236       // Return EOF denoting the end of lexing.
237       return formToken(Token::Kind::eof, tokStart);
238
239     // Lex punctuation.
240     case ':':
241       return formToken(Token::Kind::colon, tokStart);
242     case ',':
243       return formToken(Token::Kind::comma, tokStart);
244     case '=':
245       return formToken(Token::Kind::equal, tokStart);
246     case '{':
247       return formToken(Token::Kind::l_brace, tokStart);
248     case '(':
249       return formToken(Token::Kind::l_paren, tokStart);
250     case '}':
251       return formToken(Token::Kind::r_brace, tokStart);
252     case ')':
253       return formToken(Token::Kind::r_paren, tokStart);
254     case '<':
255       return formToken(Token::Kind::lt, tokStart);
256     case '>':
257       return formToken(Token::Kind::gt, tokStart);
258     case '+':
259       return formToken(Token::Kind::plus, tokStart);
260     case '-':
261       return formToken(Token::Kind::minus, tokStart);
262     case ';':
263       return formToken(Token::Kind::semicolon, tokStart);
264     case '*':
265       return formToken(Token::Kind::star, tokStart);
266     case '/':
267       if (*curPtr == '/') {
268         skipComment();
269         continue;
270       }
271       // Unknown character, emit an error.
272       return emitError(tokStart, "unexpected character: not a comment");
273
274     // Ignore whitespace characters.
275     case 0:
276     case ' ':
277     case '\t':
278     case '\n':
279       return lexToken();
280     }
281   }
282 }
283
284 Token Lexer::lexIdentifier(const char *tokStart) {
285   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
286   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
287     ++curPtr;
288
289   // Check to see if this identifier is a keyword.
290   StringRef str(tokStart, curPtr - tokStart);
291   Token::Kind kind = StringSwitch<Token::Kind>(str)
292                          .Case("def", Token::Kind::kw_def)
293                          .Case("ods_def", Token::Kind::kw_ods_def)
294                          .Case("floordiv", Token::Kind::kw_floordiv)
295                          .Case("ceildiv", Token::Kind::kw_ceildiv)
296                          .Case("mod", Token::Kind::kw_mod)
297                          .Default(Token::Kind::id);
298
299   return Token(kind, str);
300 }
301
302 Token Lexer::lexInteger(const char *tokStart) {
303   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
304   while (isdigit(*curPtr))
305     ++curPtr;
306
307   StringRef str(tokStart, curPtr - tokStart);
308   return Token(Token::Kind::integer, str);
309 }
310
311 /// Skip a comment line, starting with a '//'.
312 void Lexer::skipComment() {
313   // Advance over the second '/' in a '//' comment.
314   assert(*curPtr == '/');
315   ++curPtr;
316
317   while (true) {
318     switch (*curPtr++) {
319     case '\n':
320     case '\r':
321       // Newline is end of comment.
322       return;
323     case 0:
324       // If this is the end of the buffer, end the comment.
325       if (curPtr - 1 == curBuffer.end()) {
326         --curPtr;
327         return;
328       }
329       LLVM_FALLTHROUGH;
330     default:
331       // Skip over other characters.
332       break;
333     }
334   }
335 }
336
337 namespace {
338
339 class Parser {
340 public:
341   Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
342       : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
343
344   //===--------------------------------------------------------------------===//
345   // Lexer Utilities
346   //===--------------------------------------------------------------------===//
347
348   /// Advance the current lexer onto the next token.
349   void consumeToken() {
350     assert(curToken.getKind() != Token::Kind::eof &&
351            curToken.getKind() != Token::Kind::error &&
352            "shouldn't advance past EOF or errors");
353     curToken = lexer.lexToken();
354   }
355   void consumeToken(Token::Kind kind) {
356     assert(curToken.getKind() == kind && "unexpected token");
357     curToken = lexer.lexToken();
358   }
359   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
360     if (curToken.getKind() != kind)
361       return emitError(curToken.getLoc(), msg);
362     consumeToken();
363     return success();
364   }
365   LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
366     lexer.emitError(loc, msg);
367     return failure();
368   }
369   LogicalResult emitError(const Twine &msg) {
370     return emitError(curToken.getLoc(), msg);
371   }
372   bool consumeIf(Token::Kind kind) {
373     if (curToken.isNot(kind))
374       return false;
375     consumeToken(kind);
376     return true;
377   }
378   LogicalResult
379   parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
380     // Non-empty case starts with an element.
381     if (parseElement())
382       return failure();
383
384     // Otherwise we have a list of comma separated elements.
385     while (consumeIf(Token::Kind::comma)) {
386       if (parseElement())
387         return failure();
388     }
389     return success();
390   }
391   LogicalResult
392   parseCommaSeparatedListUntil(Token::Kind rightToken,
393                                llvm::function_ref<ParseResult()> parseElement,
394                                bool allowEmptyList) {
395     // Handle the empty case.
396     if (curToken.is(rightToken)) {
397       if (!allowEmptyList)
398         return emitError("expected list element");
399       consumeToken(rightToken);
400       return success();
401     }
402
403     if (failed(parseCommaSeparatedList(parseElement)) ||
404         failed(
405             parseToken(rightToken, "expected ',' or right-terminating token")))
406       return failure();
407
408     return success();
409   }
410
411   Lexer lexer;
412   Token curToken;
413   MLIRContext *context;
414 };
415 } // namespace
416
417 //===----------------------------------------------------------------------===//
418 // Affine parsing.
419 //===----------------------------------------------------------------------===//
420
421 namespace {
422
423 /// Lower precedence ops (all at the same precedence level). LNoOp is false in
424 /// the boolean sense.
425 enum AffineLowPrecOp {
426   /// Null value.
427   LNoOp,
428   Add,
429   Sub
430 };
431
432 /// Higher precedence ops - all at the same precedence level. HNoOp is false
433 /// in the boolean sense.
434 enum AffineHighPrecOp {
435   /// Null value.
436   HNoOp,
437   Mul,
438   FloorDiv,
439   CeilDiv,
440   Mod
441 };
442
443 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
444 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
445
446 /// This is a specialized parser for affine expressions.
447 class AffineParser {
448 public:
449   explicit AffineParser(Parser &p,
450                         std::function<AffineExpr(StringRef)> bareIdParsingHook,
451                         AffineDimList &dimList, AffineSymbolList &symbolList)
452       : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
453         symbols(symbolList) {}
454
455   /// Parse a comma-separated list of affine exprs.
456   SmallVector<AffineExpr, 4>
457   parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
458                    Token::Kind rDelim = Token::Kind::r_paren);
459
460   /// Parse a single affine expr.`.
461   AffineExpr parseAffineExpr();
462
463 private:
464   // Binary affine op parsing.
465   AffineLowPrecOp consumeIfLowPrecOp();
466   AffineHighPrecOp consumeIfHighPrecOp();
467
468   // AffineExpr parsing.
469   AffineExpr parseParentheticalExpr();
470   AffineExpr parseNegateExpression(AffineExpr lhs);
471   AffineExpr parseIntegerExpr();
472   AffineExpr parseBareIdExpr();
473
474   AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
475                                    AffineExpr rhs, SMLoc opLoc);
476   AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
477                                    AffineExpr rhs);
478   AffineExpr parseAffineOperandExpr(AffineExpr lhs);
479   AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
480   AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
481                                        SMLoc llhsOpLoc);
482
483   Parser &parser;
484   std::function<AffineExpr(StringRef)> bareIdFallback;
485   AffineDimList &dims;
486   AffineSymbolList &symbols;
487 };
488 } // end anonymous namespace
489
490 /// Create an affine binary high precedence op expression (mul's, div's, mod).
491 /// opLoc is the location of the op token to be used to report errors
492 /// for non-conforming expressions.
493 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
494                                                AffineExpr lhs, AffineExpr rhs,
495                                                SMLoc opLoc) {
496   switch (op) {
497   case Mul:
498     if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
499       parser.emitError(opLoc,
500                        "non-affine expression: at least one of the multiply "
501                        "operands has to be either a constant or symbolic");
502       return nullptr;
503     }
504     return lhs * rhs;
505   case FloorDiv:
506     if (!rhs.isSymbolicOrConstant()) {
507       parser.emitError(opLoc,
508                        "non-affine expression: right operand of floordiv "
509                        "has to be either a constant or symbolic");
510       return nullptr;
511     }
512     return lhs.floorDiv(rhs);
513   case CeilDiv:
514     if (!rhs.isSymbolicOrConstant()) {
515       parser.emitError(opLoc, "non-affine expression: right operand of ceildiv "
516                               "has to be either a constant or symbolic");
517       return nullptr;
518     }
519     return lhs.ceilDiv(rhs);
520   case Mod:
521     if (!rhs.isSymbolicOrConstant()) {
522       parser.emitError(opLoc, "non-affine expression: right operand of mod "
523                               "has to be either a constant or symbolic");
524       return nullptr;
525     }
526     return lhs % rhs;
527   case HNoOp:
528     llvm_unreachable("can't create affine expression for null high prec op");
529     return nullptr;
530   }
531   llvm_unreachable("Unknown AffineHighPrecOp");
532 }
533
534 /// Create an affine binary low precedence op expression (add, sub).
535 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
536                                                AffineExpr lhs, AffineExpr rhs) {
537   switch (op) {
538   case AffineLowPrecOp::Add:
539     return lhs + rhs;
540   case AffineLowPrecOp::Sub:
541     return lhs - rhs;
542   case AffineLowPrecOp::LNoOp:
543     llvm_unreachable("can't create affine expression for null low prec op");
544     return nullptr;
545   }
546   llvm_unreachable("Unknown AffineLowPrecOp");
547 }
548
549 /// Consume this token if it is a lower precedence affine op (there are only
550 /// two precedence levels).
551 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
552   switch (parser.curToken.getKind()) {
553   case Token::Kind::plus:
554     parser.consumeToken();
555     return AffineLowPrecOp::Add;
556   case Token::Kind::minus:
557     parser.consumeToken();
558     return AffineLowPrecOp::Sub;
559   default:
560     return AffineLowPrecOp::LNoOp;
561   }
562 }
563
564 /// Consume this token if it is a higher precedence affine op (there are only
565 /// two precedence levels)
566 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
567   switch (parser.curToken.getKind()) {
568   case Token::Kind::star:
569     parser.consumeToken(Token::Kind::star);
570     return Mul;
571   case Token::Kind::kw_floordiv:
572     parser.consumeToken(Token::Kind::kw_floordiv);
573     return FloorDiv;
574   case Token::Kind::kw_ceildiv:
575     parser.consumeToken(Token::Kind::kw_ceildiv);
576     return CeilDiv;
577   case Token::Kind::kw_mod:
578     parser.consumeToken(Token::Kind::kw_mod);
579     return Mod;
580   default:
581     return HNoOp;
582   }
583 }
584
585 /// Parse a high precedence op expression list: mul, div, and mod are high
586 /// precedence binary ops, i.e., parse a
587 ///   expr_1 op_1 expr_2 op_2 ... expr_n
588 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
589 /// All affine binary ops are left associative.
590 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
591 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
592 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
593 /// report an error for non-conforming expressions.
594 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
595                                                    AffineHighPrecOp llhsOp,
596                                                    SMLoc llhsOpLoc) {
597   AffineExpr lhs = parseAffineOperandExpr(llhs);
598   if (!lhs)
599     return nullptr;
600
601   // Found an LHS. Parse the remaining expression.
602   auto opLoc = parser.curToken.getLoc();
603   if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
604     if (llhs) {
605       AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
606       if (!expr)
607         return nullptr;
608       return parseAffineHighPrecOpExpr(expr, op, opLoc);
609     }
610     // No LLHS, get RHS
611     return parseAffineHighPrecOpExpr(lhs, op, opLoc);
612   }
613
614   // This is the last operand in this expression.
615   if (llhs)
616     return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
617
618   // No llhs, 'lhs' itself is the expression.
619   return lhs;
620 }
621
622 /// Parse an affine expression inside parentheses.
623 ///
624 ///   affine-expr ::= `(` affine-expr `)`
625 AffineExpr AffineParser::parseParentheticalExpr() {
626   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
627     return nullptr;
628   if (parser.curToken.is(Token::Kind::r_paren))
629     return (parser.emitError("no expression inside parentheses"), nullptr);
630
631   auto expr = parseAffineExpr();
632   if (!expr)
633     return nullptr;
634   if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
635     return nullptr;
636
637   return expr;
638 }
639
640 /// Parse the negation expression.
641 ///
642 ///   affine-expr ::= `-` affine-expr
643 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
644   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
645     return nullptr;
646
647   AffineExpr operand = parseAffineOperandExpr(lhs);
648   // Since negation has the highest precedence of all ops (including high
649   // precedence ops) but lower than parentheses, we are only going to use
650   // parseAffineOperandExpr instead of parseAffineExpr here.
651   if (!operand)
652     // Extra error message although parseAffineOperandExpr would have
653     // complained. Leads to a better diagnostic.
654     return (parser.emitError("missing operand of negation"), nullptr);
655   return (-1) * operand;
656 }
657
658 /// Parse a bare id that may appear in an affine expression.
659 ///
660 ///   affine-expr ::= bare-id
661 AffineExpr AffineParser::parseBareIdExpr() {
662   if (parser.curToken.isNot(Token::Kind::id))
663     return (parser.emitError("expected id"), nullptr);
664
665   StringRef sRef = parser.curToken.getSpelling();
666   for (auto &list : {dims, symbols}) {
667     for (auto entry : list) {
668       if (entry.first == sRef) {
669         parser.consumeToken(Token::Kind::id);
670         return entry.second;
671       }
672     }
673   }
674
675   // Not found, check fallback path.
676   AffineExpr expr = bareIdFallback(sRef);
677   if (expr) {
678     parser.consumeToken(Token::Kind::id);
679     return expr;
680   }
681
682   return (parser.emitError("use of undeclared id"), nullptr);
683 }
684
685 /// Parse a positive integral constant appearing in an affine expression.
686 ///
687 ///   affine-expr ::= integer-literal
688 AffineExpr AffineParser::parseIntegerExpr() {
689   auto val = parser.curToken.getUInt64IntegerValue();
690   if (!val.hasValue() || (int64_t)val.getValue() < 0)
691     return (parser.emitError("constant too large for index"), nullptr);
692
693   parser.consumeToken(Token::Kind::integer);
694   return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
695 }
696
697 /// Parses an expression that can be a valid operand of an affine expression.
698 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
699 /// operator, the rhs of which is being parsed. This is used to determine
700 /// whether an error should be emitted for a missing right operand.
701 //  Eg: for an expression without parentheses (like i + j + k + l), each
702 //  of the four identifiers is an operand. For i + j*k + l, j*k is not an
703 //  operand expression, it's an op expression and will be parsed via
704 //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
705 //  -l are valid operands that will be parsed by this function.
706 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
707   switch (parser.curToken.getKind()) {
708   case Token::Kind::id:
709     return parseBareIdExpr();
710   case Token::Kind::integer:
711     return parseIntegerExpr();
712   case Token::Kind::l_paren:
713     return parseParentheticalExpr();
714   case Token::Kind::minus:
715     return parseNegateExpression(lhs);
716   case Token::Kind::kw_ceildiv:
717   case Token::Kind::kw_floordiv:
718   case Token::Kind::kw_mod:
719   case Token::Kind::plus:
720   case Token::Kind::star:
721     if (lhs)
722       parser.emitError("missing right operand of binary operator");
723     else
724       parser.emitError("missing left operand of binary operator");
725     return nullptr;
726   default:
727     if (lhs)
728       parser.emitError("missing right operand of binary operator");
729     else
730       parser.emitError("expected affine expression");
731     return nullptr;
732   }
733 }
734
735 /// Parse affine expressions that are bare-id's, integer constants,
736 /// parenthetical affine expressions, and affine op expressions that are a
737 /// composition of those.
738 ///
739 /// All binary op's associate from left to right.
740 ///
741 /// {add, sub} have lower precedence than {mul, div, and mod}.
742 ///
743 /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
744 /// ceildiv, and mod are at the same higher precedence level. Negation has
745 /// higher precedence than any binary op.
746 ///
747 /// llhs: the affine expression appearing on the left of the one being parsed.
748 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
749 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
750 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
751 /// associativity.
752 ///
753 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
754 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
755 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
756 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
757                                                   AffineLowPrecOp llhsOp) {
758   AffineExpr lhs;
759   if (!(lhs = parseAffineOperandExpr(llhs)))
760     return nullptr;
761
762   // Found an LHS. Deal with the ops.
763   if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
764     if (llhs) {
765       AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
766       return parseAffineLowPrecOpExpr(sum, lOp);
767     }
768     // No LLHS, get RHS and form the expression.
769     return parseAffineLowPrecOpExpr(lhs, lOp);
770   }
771   auto opLoc = parser.curToken.getLoc();
772   if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
773     // We have a higher precedence op here. Get the rhs operand for the llhs
774     // through parseAffineHighPrecOpExpr.
775     AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
776     if (!highRes)
777       return nullptr;
778
779     // If llhs is null, the product forms the first operand of the yet to be
780     // found expression. If non-null, the op to associate with llhs is llhsOp.
781     AffineExpr expr =
782         llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
783
784     // Recurse for subsequent low prec op's after the affine high prec op
785     // expression.
786     if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
787       return parseAffineLowPrecOpExpr(expr, nextOp);
788     return expr;
789   }
790   // Last operand in the expression list.
791   if (llhs)
792     return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
793   // No llhs, 'lhs' itself is the expression.
794   return lhs;
795 }
796
797 /// Parse an affine expression.
798 ///  affine-expr ::= `(` affine-expr `)`
799 ///                | `-` affine-expr
800 ///                | affine-expr `+` affine-expr
801 ///                | affine-expr `-` affine-expr
802 ///                | affine-expr `*` affine-expr
803 ///                | affine-expr `floordiv` affine-expr
804 ///                | affine-expr `ceildiv` affine-expr
805 ///                | affine-expr `mod` affine-expr
806 ///                | bare-id
807 ///                | integer-literal
808 ///
809 /// Additional conditions are checked depending on the production. For eg.,
810 /// one of the operands for `*` has to be either constant/symbolic; the second
811 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
812 AffineExpr AffineParser::parseAffineExpr() {
813   return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
814 }
815
816 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
817                                                           Token::Kind rDelim) {
818   parser.parseToken(lDelim, "expected lDelim at start of affine expr list");
819
820   SmallVector<AffineExpr, 4> exprs;
821   auto parseElt = [&]() -> LogicalResult {
822     auto elt = parseAffineExpr();
823     exprs.push_back(elt);
824     return elt ? success() : failure();
825   };
826
827   if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
828                                                  /*allowEmptyList=*/true)))
829     llvm_unreachable("Failed AffineExpr parsing");
830
831   return exprs;
832 }
833
834 //===----------------------------------------------------------------------===//
835 // TC parsing.
836 //===----------------------------------------------------------------------===//
837
838 namespace {
839
840 /// Base class for expressions involved in TC parsing.
841 struct Expression {
842   enum class Kind {
843     Uninitialized = 0,
844     TensorExpr = 1,
845     TensorUse = 2,
846   };
847
848   explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
849   virtual ~Expression() = default;
850
851   operator bool() const { return kind != Kind::Uninitialized; }
852
853   Kind kind;
854 };
855
856 /// Encodes a tensor use of the form:
857 ///
858 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
859 ///   tensor-use ::= bare-id `(` `)`
860 ///                | bare-id `(` affine-expr-list `)`
861 ///
862 /// The affine-expr-list is stored as an AffineMap.
863 struct TensorUse : public Expression {
864   TensorUse() : TensorUse("", AffineMap()) {}
865   TensorUse(StringRef name, AffineMap map)
866       : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
867   TensorUse(const TensorUse &use) = default;
868
869   static bool classof(const Expression *e) {
870     return e->kind == Kind::TensorUse;
871   }
872
873   bool operator==(const TensorUse &other) const {
874     return tensorId == other.tensorId && indexingMap == other.indexingMap;
875   }
876
877   /// Visitation function. Performs preorder or postorder traversal depending on
878   /// `PreOrder` and applies `callback` on each node.
879   template <typename Lambda, bool PreOrder>
880   void visit(Lambda callback) const;
881
882   StringRef tensorId;
883   AffineMap indexingMap;
884 };
885
886 /// Encodes a tensor expression of the form:
887 ///
888 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
889 ///             | bare-id
890 ///   op-arg ::= tensor-expr
891 ///            | tensor-use
892 ///   op-arg-list ::= op-arg (`,` op-arg)*
893 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
894 ///
895 /// Underlying op-arg are stored by unique_ptr to base class.
896 struct TensorExpr : public Expression {
897   TensorExpr(StringRef name,
898              SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
899              ArrayRef<unsigned> reductionDims)
900       : Expression(Kind::TensorExpr), operationName(name),
901         expressions(std::move(exprs)),
902         reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
903
904   static bool classof(const Expression *e) {
905     return e->kind == Kind::TensorExpr;
906   }
907
908   bool operator==(const TensorExpr &other) const {
909     if (operationName != other.operationName)
910       return false;
911     if (expressions.size() != other.expressions.size())
912       return false;
913     for (unsigned i = 0, e = expressions.size(); i < e; ++i)
914       if (*expressions[i] != *other.expressions[i])
915         return false;
916     for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
917       if (reductionDimensions[i] != other.reductionDimensions[i])
918         return false;
919     return true;
920   }
921
922   /// Visitation function. Performs preorder or postorder traversal depending on
923   /// `PreOrder` and applies `callback` on each node.
924   template <typename Lambda, bool PreOrder>
925   void visit(Lambda callback) const;
926
927   StringRef operationName;
928   SmallVector<std::unique_ptr<Expression>, 4> expressions;
929   SetVector<unsigned> reductionDimensions;
930 };
931
932 /// This is a specialized parser for a TCDef.
933 /// This maintains the dims it finds in an eager fashion.
934 class TCParser {
935   enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
936
937 public:
938   explicit TCParser(Parser &p);
939
940   /// Uses the AffineParser to parse the affine exprs used in a tensor
941   /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
942   /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
943   /// emitted on new identifiers.
944   SmallVector<AffineExpr, 4>
945   parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
946                    Token::Kind lDelim = Token::Kind::l_paren,
947                    Token::Kind rDelim = Token::Kind::r_paren);
948
949   /// Parse the information for a tensor def.
950   /// All the affine-expr must be dimensionless (i.e. contain only expressions
951   /// involving symbols and constants), but can otherwise contain arbitrary
952   /// affine expressions.
953   LogicalResult parseTensorDef(bool isOutput);
954
955   /// Parses a tensor use.
956   struct ComprehensionParsingState {
957     AffineDimList dims;
958     SmallVector<std::unique_ptr<Expression>, 4> expressions;
959     llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
960   };
961   LogicalResult parseTensorUse(TensorUse &result,
962                                ComprehensionParsingState &state);
963
964   /// Parses a tensor expression.
965   LogicalResult parseExpression(TensorUse currentDefinition,
966                                 std::unique_ptr<Expression> &result,
967                                 ComprehensionParsingState &state);
968
969   /// Parse a single comprehension.
970   LogicalResult parseOneComprehension(StringRef cppOpName,
971                                       StringRef linalgOpName,
972                                       ComprehensionParsingState &state);
973
974   /// Parse and print the information for a TC def.
975   /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
976   /// When `gen-impl` is used, this prints the C++ implementation for the extra
977   /// methods defined in ODS (`iterator_types`, `indexing_maps` and
978   /// `regionBuilder`).
979   LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
980
981   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
982   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
983                 StringRef linalgOpName, ComprehensionParsingState &state);
984
985   /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
986   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
987                                ComprehensionParsingState &state);
988
989   /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
990   void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
991                                   ComprehensionParsingState &state);
992
993   /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
994   void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
995                           ComprehensionParsingState &state);
996
997   /// Print the C++ impl for named ops canonicalizers and folders.
998   void printCanonicalizersAndFolders(llvm::raw_ostream &os,
999                                      StringRef cppOpName);
1000
1001 private:
1002   //===--------------------------------------------------------------------===//
1003   // Internal bookkeeping of tensors.
1004   //===--------------------------------------------------------------------===//
1005   struct RegisteredTensor {
1006     StringRef type;
1007     AffineMap shape;
1008     bool isOutput;
1009     AffineMap indexingMap;
1010     unsigned index;
1011   };
1012
1013   //===--------------------------------------------------------------------===//
1014   // Per-TC def state.
1015   //===--------------------------------------------------------------------===//
1016   /// Symbols are per TC def.
1017   AffineSymbolList symbols;
1018   /// Tensors are per TC def.
1019   llvm::StringMap<RegisteredTensor> registeredTensors;
1020   unsigned nextRegisteredTensorIndex;
1021
1022   Parser &parser;
1023 };
1024 } // namespace
1025
1026 namespace llvm {
1027
1028 template <>
1029 struct DenseMapInfo<TensorUse> {
1030   static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
1031   static TensorUse getTombstoneKey() {
1032     return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
1033                      DenseMapInfo<AffineMap>::getTombstoneKey());
1034   }
1035   static unsigned getHashValue(const TensorUse &val) {
1036     return ::llvm::hash_value(val.tensorId); // don't care about collisions.
1037   }
1038   static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
1039     return LHS == RHS;
1040   }
1041 };
1042
1043 } // namespace llvm
1044
1045 //===----------------------------------------------------------------------===//
1046 // Visitation functions.
1047 //===----------------------------------------------------------------------===//
1048
1049 template <typename Lambda, bool PreOrder>
1050 void visit(const Expression &expr, Lambda callback) {
1051   switch (expr.kind) {
1052   default:
1053     llvm_unreachable("Unexpected kind");
1054   case Expression::Kind::TensorExpr:
1055     static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
1056     break;
1057   case Expression::Kind::TensorUse:
1058     static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
1059     break;
1060   }
1061 }
1062
1063 template <typename Lambda>
1064 void visitPreorder(const Expression &expr, Lambda callback) {
1065   visit<Lambda, false>(expr, callback);
1066 }
1067
1068 template <typename Lambda>
1069 void visitPostorder(Expression &expr, Lambda callback) {
1070   visit<Lambda, true>(expr, callback);
1071 }
1072
1073 template <typename Lambda, bool PreOrder>
1074 void TensorExpr::visit(Lambda callback) const {
1075   if (!PreOrder)
1076     callback(*this);
1077   for (auto &e : expressions)
1078     ::visit<Lambda, PreOrder>(*e, callback);
1079   if (PreOrder)
1080     callback(*this);
1081 }
1082
1083 template <typename Lambda, bool PreOrder>
1084 void TensorUse::visit(Lambda callback) const {
1085   callback(*this);
1086 }
1087
1088 //===----------------------------------------------------------------------===//
1089 // TC parsing functions.
1090 //===----------------------------------------------------------------------===//
1091 TCParser::TCParser(Parser &p)
1092     : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
1093
1094 /// Uses the AffineParser to parse the affine exprs used in a tensor
1095 /// definition. All identifiers are interpreted as symbols, new symbols are
1096 /// added eagerly.
1097 SmallVector<AffineExpr, 4>
1098 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
1099                            AffineDimList &dims, Token::Kind lDelim,
1100                            Token::Kind rDelim) {
1101   AffineParser affineParser(
1102       parser,
1103       [&](StringRef sRef) {
1104         AffineExpr expr;
1105         if (discoveryMode == EagerDiscoveryMode::Symbols) {
1106           expr = getAffineSymbolExpr(symbols.size(), parser.context);
1107           symbols.emplace_back(sRef, expr);
1108         } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
1109           expr = getAffineDimExpr(dims.size(), parser.context);
1110           dims.emplace_back(sRef, expr);
1111         }
1112         return expr;
1113       },
1114       dims, symbols);
1115   return affineParser.parseAffineExprs(lDelim, rDelim);
1116 }
1117
1118 /// Parse the information for a tensor def of the form:
1119 ///
1120 ///   affine-expr-list ::= affine-expr (`,` affine-expr )*
1121 ///   tensor-typedef ::= type `(` `)`
1122 ///                    | type `(` affine-expr-list `)`
1123 ///   tensor-def ::= bare-id `:` tensor-typedef
1124 LogicalResult TCParser::parseTensorDef(bool isOutput) {
1125   StringRef tensorId = parser.curToken.getSpelling();
1126   if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
1127       failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1128     return failure();
1129
1130   StringRef tensorType = parser.curToken.getSpelling();
1131   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1132     return failure();
1133
1134   AffineDimList emptyDims;
1135   auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
1136   assert(emptyDims.empty() && "Unexpected dimension in tensor def");
1137   AffineMap map =
1138       AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
1139
1140   auto iterBoolPair = registeredTensors.try_emplace(
1141       tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
1142                                  nextRegisteredTensorIndex++});
1143   (void)iterBoolPair;
1144   assert(iterBoolPair.second && "Could not emplace tensor registration");
1145   LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
1146                           << "with typeString: " << tensorType << " "
1147                           << "and shape: " << map << "\n");
1148
1149   return success();
1150 }
1151
1152 /// Parses a tensor use of the form:
1153 ///
1154 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
1155 ///   tensor-use ::= bare-id `(` `)`
1156 ///                | bare-id `(` affine-expr-list `)`
1157 LogicalResult TCParser::parseTensorUse(TensorUse &result,
1158                                        ComprehensionParsingState &state) {
1159   StringRef tensorId = parser.curToken.getSpelling();
1160   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1161     return failure();
1162
1163   auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
1164   AffineMap map =
1165       AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
1166   LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
1167                           << "\n");
1168
1169   result = TensorUse(tensorId, map);
1170   return success();
1171 }
1172
1173 /// Parses a tensor expression of the form:
1174 ///
1175 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
1176 ///             | bare-id
1177 ///   op-arg ::= tensor-expr
1178 ///            | tensor-use
1179 ///   op-arg-list ::= op-arg (`,` op-arg)*
1180 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
1181 LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
1182                                         std::unique_ptr<Expression> &result,
1183                                         ComprehensionParsingState &state) {
1184   StringRef opOrTensor = parser.curToken.getSpelling();
1185   if (registeredTensors.count(opOrTensor) > 0) {
1186     TensorUse use;
1187     auto res = parseTensorUse(use, state);
1188     if (failed(res))
1189       return res;
1190     result = std::make_unique<TensorUse>(use);
1191     return success();
1192   }
1193
1194   if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
1195     return failure();
1196
1197   // This is an op.
1198   SmallVector<unsigned, 4> reductionDims;
1199   SmallVector<std::unique_ptr<Expression>, 4> expressions;
1200
1201   // Check if it has a reduction set, discover dimensions eagerly.
1202   if (parser.curToken.is(Token::Kind::lt)) {
1203     auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
1204                                   Token::Kind::lt, Token::Kind::gt);
1205     for (auto iter : iters)
1206       reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
1207   }
1208
1209   // If this op is a reduction, it's first argument is the `currentDefinition`
1210   // tensor use.
1211   if (!reductionDims.empty())
1212     expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
1213   LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
1214
1215   auto parseExpr = [&]() -> LogicalResult {
1216     std::unique_ptr<Expression> e;
1217     if (failed(parseExpression(currentDefinition, e, state)))
1218       return failure();
1219     expressions.push_back(std::move(e));
1220     return success();
1221   };
1222   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
1223       failed(parser.parseCommaSeparatedListUntil(
1224           Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
1225     return failure();
1226
1227   result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
1228                                         reductionDims);
1229
1230   return success();
1231 }
1232
1233 //===----------------------------------------------------------------------===//
1234 // Parse and Emit functions.
1235 //===----------------------------------------------------------------------===//
1236
1237 /// Parse the information for a single comprehension.
1238 ///
1239 ///   tensor-def-list ::= tensor-def (`,` tensor-def)*
1240 ///   tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
1241 ///   comprehension ::= tensor-def-list `=` tensor-expr-list `;`
1242 LogicalResult
1243 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1244                                 ComprehensionParsingState &state) {
1245   // 1. Parse LHS of `=`, these become the definitions that appear as the output
1246   // tensors or read/write buffers.
1247   SmallVector<TensorUse, 4> definitions;
1248   auto parseUse = [&]() -> LogicalResult {
1249     TensorUse use;
1250     if (failed(parseTensorUse(use, state)))
1251       return failure();
1252     definitions.push_back(use);
1253     return success();
1254   };
1255   if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
1256                                                  /*allowEmptyList=*/true)))
1257     return failure();
1258
1259   // 2. Parse RHS of `=`, this becomes the expressions from which we emit
1260   // computations.
1261   unsigned idx = 0;
1262   auto parseExpr = [&]() -> LogicalResult {
1263     std::unique_ptr<Expression> expr;
1264     if (idx >= definitions.size()) {
1265       parser.emitError("Fewer LHS definitions than RHS expressions");
1266       return failure();
1267     }
1268     if (failed(parseExpression(definitions[idx++], expr, state)))
1269       return failure();
1270     state.expressions.push_back(std::move(expr));
1271     return success();
1272   };
1273   if (failed(parser.parseCommaSeparatedListUntil(
1274           Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
1275     return failure();
1276   if (idx != definitions.size()) {
1277     parser.emitError("Fewer RHS expressions than LHS definitions");
1278     return failure();
1279   }
1280
1281   // 3. Postprocess.
1282   // 3.a. Normalize all maps to the proper state.dims and symbols counts.
1283   SmallVector<TensorUse, 4> allUses;
1284   allUses.reserve(registeredTensors.size());
1285   for (auto &def : definitions)
1286     allUses.push_back(def);
1287   for (auto &pExpr : state.expressions)
1288     visitPostorder(*pExpr, [&](const Expression &e) {
1289       if (auto *use = dyn_cast<TensorUse>(&e))
1290         allUses.push_back(*use);
1291     });
1292   for (auto &use : allUses)
1293     use.indexingMap =
1294         AffineMap::get(state.dims.size(), symbols.size(),
1295                        use.indexingMap.getResults(), parser.context);
1296
1297   // 3.b. Traverse definitions
1298   llvm::DenseSet<StringRef> seenDefs;
1299   for (auto &def : definitions) {
1300     if (seenDefs.count(def.tensorId) > 0) {
1301       parser.emitError("Unexpected multi-write to a single tensor");
1302       return failure();
1303     }
1304     seenDefs.insert(def.tensorId);
1305     auto tensorIter = registeredTensors.find(def.tensorId);
1306     assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1307     auto &tensor = tensorIter->getValue();
1308     tensor.indexingMap = def.indexingMap;
1309     state.orderedTensorArgs[def] = tensor.index;
1310   }
1311
1312   bool failed = false;
1313   for (auto &pExpr : state.expressions)
1314     visitPostorder(*pExpr, [&](const Expression &e) {
1315       auto *pUse = dyn_cast<TensorUse>(&e);
1316       if (failed || !pUse)
1317         return;
1318       auto &use = *pUse;
1319       LLVM_DEBUG(llvm::dbgs()
1320                  << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
1321       auto tensorIter = registeredTensors.find(use.tensorId);
1322       assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1323       auto &tensor = tensorIter->getValue();
1324       if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
1325         LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
1326         parser.emitError(
1327             "Unexpected multi-read of a tensor with different accesses");
1328         failed = true;
1329         return;
1330       }
1331       seenDefs.insert(use.tensorId);
1332       tensor.indexingMap = use.indexingMap;
1333       state.orderedTensorArgs[use] = tensor.index;
1334     });
1335   if (failed)
1336     return failure();
1337
1338   return success();
1339 }
1340
1341 /// Parse and print the information for a ODS def.
1342 ///
1343 ///   tensor-def-list ::= tensor-def (`,` tensor-def )*
1344 ///
1345 ///   comprehension-list ::= comprehension comprehension*
1346 ///
1347 ///   tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1348 ///     `{` comprehension-list `}`
1349 ///
1350 ///   ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
1351 ///
1352 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
1353 /// contain only expressions involving symbols and constants), but can
1354 /// otherwise contain arbitrary affine expressions.
1355 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1356   if (failed(parser.parseToken(Token::Kind::kw_ods_def,
1357                                "expected 'ods_def' to define a TC ODS")) ||
1358       failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
1359     return failure();
1360   StringRef cppOpName = parser.curToken.getSpelling();
1361   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
1362
1363   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1364       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1365       failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
1366     return failure();
1367   if (failed(parser.parseToken(Token::Kind::kw_def,
1368                                "expected 'def' to define a TC")))
1369     return failure();
1370
1371   StringRef tcName = parser.curToken.getSpelling();
1372   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1373   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1374       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1375     return failure();
1376
1377   auto parseInputDef = [&]() -> LogicalResult {
1378     return parseTensorDef(/*isOutput=*/false);
1379   };
1380   if (failed(parser.parseCommaSeparatedListUntil(
1381           Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
1382     return failure();
1383
1384   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
1385       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1386       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1387     return failure();
1388   auto parseOutputDef = [&]() -> LogicalResult {
1389     return parseTensorDef(/*isOutput=*/true);
1390   };
1391   if (failed(parser.parseCommaSeparatedListUntil(
1392           Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
1393     return failure();
1394
1395   // Since we don't declare symbols separately, we discover them eagerly: each
1396   // newly encountered id in a tensor shape expression is treated as a new
1397   // symbolic. At this point, all tensors have been parsed and all the symbols
1398   // that could be discovered eagerly are now known. Resize all AffineMaps to
1399   // normalize the number of eagerly discovered symbols.
1400   for (auto &tensor : registeredTensors) {
1401     auto &map = tensor.getValue().shape;
1402     map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
1403                          parser.context);
1404   }
1405
1406   if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
1407     return failure();
1408
1409   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
1410   while (parser.curToken.isNot(Token::Kind::r_brace)) {
1411     perComprehensionStates.push_back(ComprehensionParsingState());
1412     if (failed(parseOneComprehension(cppOpName, tcName,
1413                                      perComprehensionStates.back())))
1414       return failure();
1415   };
1416   parser.parseToken(Token::Kind::r_brace, "expected '}'");
1417
1418   // Print.
1419   auto nComprehensions = perComprehensionStates.size();
1420   if (nComprehensions != 1) {
1421     parser.emitError("only 1 comprehension supported for now, got: " +
1422                      llvm::Twine(nComprehensions));
1423     return failure();
1424   }
1425   if (genODSDecl) {
1426     auto &state = perComprehensionStates.back();
1427     printODS(os, cppOpName, tcName, state);
1428     os << "\n";
1429   }
1430   if (genODSImpl) {
1431     auto &state = perComprehensionStates.back();
1432     std::string extraMethods;
1433     llvm::raw_string_ostream ss(extraMethods);
1434     printReferenceIterators(ss, cppOpName, state);
1435     printReferenceIndexingMaps(ss, cppOpName, state);
1436     printRegionBuilder(ss, cppOpName, state);
1437     printCanonicalizersAndFolders(ss, cppOpName);
1438     ss.flush();
1439     os << extraMethods << "\n";
1440   }
1441
1442   return success();
1443 }
1444
1445 //===----------------------------------------------------------------------===//
1446 // Printing functions
1447 //===----------------------------------------------------------------------===//
1448
1449 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
1450 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1451                         StringRef linalgOpName,
1452                         ComprehensionParsingState &state) {
1453   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
1454     AttrSizedOperandSegments,
1455     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1456     SingleBlockImplicitTerminator<"YieldOp">]> {
1457       let arguments = (ins Variadic<AnyShaped>:$inputs,
1458                            Variadic<AnyShaped>:$outputs);
1459       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1460       let regions = (region AnyRegion:$region);
1461
1462       let skipDefaultBuilders = 1;
1463       let builders = [ OpBuilderDAG<
1464         (ins "ValueRange":$inputs, "ValueRange":$outputs),
1465         [{{
1466           $_state.addOperands(inputs);
1467           $_state.addOperands(outputs);
1468           $_state.addAttribute(
1469             "operand_segment_sizes",
1470             $_builder.getI32VectorAttr({{
1471               static_cast<int32_t>(inputs.size()),
1472               static_cast<int32_t>(outputs.size())}));
1473           buildNamedStructuredOpRegionAndAttributes<{0}>(
1474             $_builder,
1475             $_state,
1476             TypeRange(inputs),
1477             TypeRange(outputs));
1478         }]>, OpBuilderDAG<
1479         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1480              "ValueRange":$outputs),
1481         [{{
1482           $_state.addOperands(inputs);
1483           $_state.addOperands(outputs);
1484           $_state.addTypes(resultTensorTypes);
1485           $_state.addAttribute(
1486             "operand_segment_sizes",
1487             $_builder.getI32VectorAttr({{
1488               static_cast<int32_t>(inputs.size()),
1489               static_cast<int32_t>(outputs.size())}));
1490           buildNamedStructuredOpRegionAndAttributes<{0}>(
1491             $_builder,
1492             $_state,
1493             TypeRange(inputs),
1494             TypeRange(outputs));
1495         }]>, OpBuilderDAG<
1496         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1497              CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1498         [{{
1499           $_state.addOperands(operands);
1500           $_state.addAttributes(attributes);
1501           $_state.addTypes(resultTensorTypes);
1502           (void)$_state.addRegion();
1503         }]>
1504       ];
1505       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1506       let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
1507       let hasFolder = 1;
1508       let hasCanonicalizer = 1;
1509
1510       let extraClassDeclaration = [{{
1511         // Auto-generated.
1512         ArrayAttr iterator_types();
1513         ArrayAttr indexing_maps();
1514         static void regionBuilder(Block &block);
1515         static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
1516
1517         // Generic methods.
1518         static unsigned getNumRegionArgs() {{ return {4}; }
1519         std::string getLibraryCallName() {{
1520           return generateLibraryCallName(getOperation());
1521         }
1522       }];
1523   })FMT";
1524
1525   unsigned nInputs = 0, nOutputs = 0;
1526   for (auto &t : registeredTensors) {
1527     if (t.getValue().isOutput)
1528       nOutputs++;
1529     else
1530       nInputs++;
1531   }
1532
1533   os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
1534                       state.orderedTensorArgs.size());
1535 }
1536
1537 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
1538 void TCParser::printReferenceIterators(llvm::raw_ostream &os,
1539                                        StringRef cppOpName,
1540                                        ComprehensionParsingState &state) {
1541   const char *referenceReferenceIteratorsFmt =
1542       R"FMT(
1543     ArrayAttr {0}::iterator_types() {
1544       return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
1545     })FMT";
1546
1547   std::string iteratorsStr;
1548   llvm::raw_string_ostream ss(iteratorsStr);
1549   unsigned pos = 0;
1550   llvm::interleaveComma(
1551       state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
1552         bool reduction = false;
1553         for (auto &expr : state.expressions) {
1554           visitPostorder(*expr, [&](const Expression &e) {
1555             if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
1556               if (pTensorExpr->reductionDimensions.count(pos) > 0)
1557                 reduction = true;
1558             }
1559           });
1560           if (reduction)
1561             break;
1562         }
1563         ss << (reduction ? "getReductionIteratorTypeName()"
1564                          : "getParallelIteratorTypeName()");
1565         pos++;
1566       });
1567   ss.flush();
1568
1569   os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
1570 }
1571
1572 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
1573                                              StringRef cppOpName) {
1574   const char *canonicalizersAndFoldersFmt = R"FMT(
1575     void {0}::getCanonicalizationPatterns(
1576         OwningRewritePatternList &results,
1577         MLIRContext *context) {{
1578       results.insert<EraseDeadLinalgOp>();
1579       results.insert<FoldTensorCastOp>();
1580     }
1581     LogicalResult {0}::fold(ArrayRef<Attribute>,
1582                             SmallVectorImpl<OpFoldResult> &) {{
1583       return foldMemRefCast(*this);
1584     }
1585     void {0}::getEffects(SmallVectorImpl<
1586         SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
1587       getGenericEffectsImpl(effects,
1588         getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
1589     })FMT";
1590   os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
1591 }
1592
1593 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
1594 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
1595                                           StringRef cppOpName,
1596                                           ComprehensionParsingState &state) {
1597   // 1. Generic string template for specifying reference indexing maps.
1598   const char *referenceIndexingMapsFmt =
1599       R"FMT(
1600   // This is temporary until we transition out of manually specified ops that
1601   // should be auto-generated with linalg-ods-gen.
1602   ArrayAttr {0}::indexing_maps() {
1603     MLIRContext *context = getContext();
1604     AffineExpr {1};
1605     bindDims(context, {1});
1606     return Builder(context).getAffineMapArrayAttr({ {2} });
1607   })FMT";
1608
1609   // 2. Print a comma-separated list of identifiers for the AffineExpr in
1610   // `state.dims`. These will replace the `{1}` placeholder in both
1611   // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
1612   // identifiers are bound in the right order to the proper AffineDimExpr.
1613   std::string dimsStr;
1614   llvm::raw_string_ostream ss(dimsStr);
1615   llvm::interleaveComma(
1616       state.dims, ss,
1617       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
1618   ss.flush();
1619
1620   // 3. Print a comma-separated list of AffineMap constructors that use the
1621   // identifiers from 1. The AffineExpr use the common arithmetic operators on
1622   // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
1623   // in return `SmallVector<AffineMap, 8>{{ {2} };`.
1624   std::string mapsStr;
1625   llvm::raw_string_ostream mapsStringStream(mapsStr);
1626   SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
1627   for (const auto &it : state.orderedTensorArgs)
1628     orderedUses[it.second] = it.first;
1629   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
1630     assert(u.indexingMap);
1631     const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
1632     if (u.indexingMap.isEmpty()) {
1633       mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
1634       return;
1635     }
1636
1637     std::string exprsStr;
1638     llvm::raw_string_ostream exprsStringStream(exprsStr);
1639     exprsStringStream << "{";
1640     llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
1641     exprsStringStream << "}";
1642     exprsStringStream.flush();
1643
1644     mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
1645   });
1646   mapsStringStream.flush();
1647
1648   // 4. Apply format to 1. using 2. and 3.
1649   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
1650 }
1651
1652 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
1653 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
1654                                   ComprehensionParsingState &state) {
1655   unsigned count = state.orderedTensorArgs.size();
1656   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
1657   std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
1658   printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
1659     if (auto *pUse = dyn_cast<TensorUse>(&e)) {
1660       os << "_" << state.orderedTensorArgs.find(*pUse)->second;
1661       return;
1662     }
1663     auto *pTensorExpr = cast<TensorExpr>(&e);
1664     if (subExprsMap.count(pTensorExpr) > 0) {
1665       os << "_" << subExprsMap[pTensorExpr];
1666     } else {
1667       std::string subExprs;
1668       llvm::raw_string_ostream subExprsStringStream(subExprs);
1669       llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
1670                             [&](const std::unique_ptr<Expression> &e) {
1671                               printExpr(subExprsStringStream, *e);
1672                             });
1673       subExprsStringStream.flush();
1674       const char *tensorExprFmt = "\n    Value _{0} = {1}({2});";
1675       os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
1676                           subExprs);
1677       subExprsMap[pTensorExpr] = count;
1678     }
1679   };
1680
1681   const char *regionBuilderFmt = R"FMT(
1682   void {0}::regionBuilder(Block &block) {
1683     using namespace edsc;
1684     using namespace intrinsics;
1685     auto args = block.getArguments();
1686     Value {1};
1687     {2}
1688     (linalg_yield(ValueRange{ {3} }));
1689   })FMT";
1690
1691   unsigned idx = 0;
1692   std::string valueHandleStr;
1693   llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
1694   llvm::interleaveComma(
1695       state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
1696         valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
1697         idx++;
1698       });
1699
1700   std::string expressionsStr;
1701   llvm::raw_string_ostream expressionStringStream(expressionsStr);
1702   for (auto &expr : state.expressions)
1703     visitPostorder(*expr, [&](const Expression &e) {
1704       if (e.kind == Expression::Kind::TensorExpr)
1705         printExpr(expressionStringStream, e);
1706     });
1707
1708   std::string yieldStr;
1709   llvm::raw_string_ostream yieldStringStream(yieldStr);
1710   llvm::interleaveComma(state.expressions, yieldStringStream,
1711                         [&](const std::unique_ptr<Expression> &e) {
1712                           printExpr(yieldStringStream, *e);
1713                         });
1714
1715   valueHandleStringStream.flush();
1716   expressionStringStream.flush();
1717   yieldStringStream.flush();
1718
1719   os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
1720                       expressionsStr, yieldStr);
1721 }
1722
1723 /// Iterate over each Tensor Comprehension def.
1724 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
1725                                                   Parser &parser) {
1726   while (parser.curToken.getKind() != Token::Kind::eof) {
1727     TCParser tcParser(parser);
1728     if (failed(tcParser.parseAndEmitODSDef(os)))
1729       return failure();
1730   }
1731   return success();
1732 }
1733
1734 int main(int argc, char **argv) {
1735   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
1736
1737   // Set up the input file.
1738   std::string errorMessage;
1739   std::unique_ptr<llvm::MemoryBuffer> file =
1740       mlir::openInputFile(inputFilename, &errorMessage);
1741   if (!file) {
1742     llvm::errs() << errorMessage << "\n";
1743     return 1;
1744   }
1745
1746   std::unique_ptr<llvm::ToolOutputFile> output =
1747       openOutputFile(outputFilename, &errorMessage);
1748   if (!output) {
1749     llvm::errs() << errorMessage << "\n";
1750     exit(1);
1751   }
1752
1753   // Include the proper Linalg header for end-to-end tblgen testing without
1754   // resorting to non-portable shell manipulations.
1755   if (testEmitIncludeTdHeader)
1756     output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
1757
1758   MLIRContext context;
1759   llvm::SourceMgr mgr;
1760   mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
1761   Parser parser(mgr, &context);
1762   parseAndEmitAllTensorComprehensions(output->os(), parser);
1763   output->keep();
1764
1765   return 0;
1766 }