1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains the implementation for the Tensor Comprehension-inspired
10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a
13 //===----------------------------------------------------------------------===//
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/ToolOutputFile.h"
28 #define DEBUG_TYPE "linalg-ods-gen"
30 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
32 // Commandline options
33 static llvm::cl::opt<std::string>
34 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
35 llvm::cl::init("-"), llvm::cl::value_desc("filename"));
37 static llvm::cl::opt<std::string>
38 outputFilename("o", llvm::cl::desc("Output filename"),
39 llvm::cl::value_desc("filename"), llvm::cl::init("-"));
41 static llvm::cl::opt<bool>
42 genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
43 llvm::cl::cat(ODSGenCat));
45 static llvm::cl::opt<bool>
46 genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
47 llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
49 static llvm::cl::opt<bool> testEmitIncludeTdHeader(
50 "test-emit-include-td-header",
51 llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
53 llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
55 using llvm::SetVector;
57 using llvm::StringRef;
62 //===----------------------------------------------------------------------===//
64 //===----------------------------------------------------------------------===//
67 /// This class represents a specific token in the input format.
75 // Tokens with no info.
92 FIRST_KEYWORD = kw_def,
97 LAST_KEYWORD = kw_mod,
99 // String valued tokens.
104 Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
106 /// Return the bytes that make up this token.
107 StringRef getSpelling() const { return spelling; }
109 /// Return the kind of this token.
110 Kind getKind() const { return kind; }
112 /// Return a location for this token.
113 llvm::SMLoc getLoc() const {
114 return llvm::SMLoc::getFromPointer(spelling.data());
117 /// Return if this token is a keyword.
118 bool isKeyword() const {
119 return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
121 bool is(Kind k) const { return kind == k; }
122 bool isNot(Kind k) const { return kind != k; }
124 Optional<uint64_t> getUInt64IntegerValue() const {
125 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
128 if (spelling.getAsInteger(isHex ? 0 : 10, result))
134 /// Discriminator that indicates the kind of token this is.
137 /// A reference to the entire token contents; this is always a pointer into
138 /// a memory buffer owned by the source manager.
142 /// This class implements a simple lexer.
145 Lexer(llvm::SourceMgr &mgr);
147 /// Lex the next token and return it.
150 /// Emit an error to the lexer with the given location and message.
151 Token emitError(llvm::SMLoc loc, const Twine &msg);
152 Token emitError(const char *loc, const Twine &msg);
155 Token formToken(Token::Kind kind, const char *tokStart) {
156 return Token(kind, StringRef(tokStart, curPtr - tokStart));
159 /// Return the next character in the stream.
162 /// Lex an identifier.
163 Token lexIdentifier(const char *tokStart);
166 Token lexInteger(const char *tokStart);
168 // Skip a comment line, starting with a '//'.
171 llvm::SourceMgr &srcMgr;
175 } // end anonymous namespace
177 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
178 curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
179 curPtr = curBuffer.begin();
182 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
183 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
184 return formToken(Token::Kind::error, loc.getPointer());
186 Token Lexer::emitError(const char *loc, const Twine &msg) {
187 return emitError(llvm::SMLoc::getFromPointer(loc), msg);
190 int Lexer::getNextChar() {
191 char curChar = *curPtr++;
194 return (unsigned char)curChar;
196 // A nul character in the stream is either the end of the current buffer
197 // or a random nul in the file. Disambiguate that here.
198 if (curPtr - 1 != curBuffer.end())
201 // Otherwise, return end of file.
207 // Handle the newline character by ignoring it and incrementing the line
208 // count. However, be careful about 'dos style' files with \n\r in them.
209 // Only treat a \n\r or \r\n as a single line.
210 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
216 Token Lexer::lexToken() {
218 const char *tokStart = curPtr;
220 // This always consumes at least one character.
221 int curChar = getNextChar();
224 // Handle identifiers: [a-zA-Z_]
225 if (isalpha(curChar) || curChar == '_')
226 return lexIdentifier(tokStart);
228 // Handle integers: [0-9]
229 if (isdigit(curChar))
230 return lexInteger(tokStart);
232 // Unknown character, emit an error.
233 return emitError(tokStart, "unexpected character");
236 // Return EOF denoting the end of lexing.
237 return formToken(Token::Kind::eof, tokStart);
241 return formToken(Token::Kind::colon, tokStart);
243 return formToken(Token::Kind::comma, tokStart);
245 return formToken(Token::Kind::equal, tokStart);
247 return formToken(Token::Kind::l_brace, tokStart);
249 return formToken(Token::Kind::l_paren, tokStart);
251 return formToken(Token::Kind::r_brace, tokStart);
253 return formToken(Token::Kind::r_paren, tokStart);
255 return formToken(Token::Kind::lt, tokStart);
257 return formToken(Token::Kind::gt, tokStart);
259 return formToken(Token::Kind::plus, tokStart);
261 return formToken(Token::Kind::minus, tokStart);
263 return formToken(Token::Kind::semicolon, tokStart);
265 return formToken(Token::Kind::star, tokStart);
267 if (*curPtr == '/') {
271 // Unknown character, emit an error.
272 return emitError(tokStart, "unexpected character: not a comment");
274 // Ignore whitespace characters.
284 Token Lexer::lexIdentifier(const char *tokStart) {
285 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
286 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
289 // Check to see if this identifier is a keyword.
290 StringRef str(tokStart, curPtr - tokStart);
291 Token::Kind kind = StringSwitch<Token::Kind>(str)
292 .Case("def", Token::Kind::kw_def)
293 .Case("ods_def", Token::Kind::kw_ods_def)
294 .Case("floordiv", Token::Kind::kw_floordiv)
295 .Case("ceildiv", Token::Kind::kw_ceildiv)
296 .Case("mod", Token::Kind::kw_mod)
297 .Default(Token::Kind::id);
299 return Token(kind, str);
302 Token Lexer::lexInteger(const char *tokStart) {
303 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
304 while (isdigit(*curPtr))
307 StringRef str(tokStart, curPtr - tokStart);
308 return Token(Token::Kind::integer, str);
311 /// Skip a comment line, starting with a '//'.
312 void Lexer::skipComment() {
313 // Advance over the second '/' in a '//' comment.
314 assert(*curPtr == '/');
321 // Newline is end of comment.
324 // If this is the end of the buffer, end the comment.
325 if (curPtr - 1 == curBuffer.end()) {
331 // Skip over other characters.
341 Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
342 : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
344 //===--------------------------------------------------------------------===//
346 //===--------------------------------------------------------------------===//
348 /// Advance the current lexer onto the next token.
349 void consumeToken() {
350 assert(curToken.getKind() != Token::Kind::eof &&
351 curToken.getKind() != Token::Kind::error &&
352 "shouldn't advance past EOF or errors");
353 curToken = lexer.lexToken();
355 void consumeToken(Token::Kind kind) {
356 assert(curToken.getKind() == kind && "unexpected token");
357 curToken = lexer.lexToken();
359 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
360 if (curToken.getKind() != kind)
361 return emitError(curToken.getLoc(), msg);
365 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
366 lexer.emitError(loc, msg);
369 LogicalResult emitError(const Twine &msg) {
370 return emitError(curToken.getLoc(), msg);
372 bool consumeIf(Token::Kind kind) {
373 if (curToken.isNot(kind))
379 parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
380 // Non-empty case starts with an element.
384 // Otherwise we have a list of comma separated elements.
385 while (consumeIf(Token::Kind::comma)) {
392 parseCommaSeparatedListUntil(Token::Kind rightToken,
393 llvm::function_ref<ParseResult()> parseElement,
394 bool allowEmptyList) {
395 // Handle the empty case.
396 if (curToken.is(rightToken)) {
398 return emitError("expected list element");
399 consumeToken(rightToken);
403 if (failed(parseCommaSeparatedList(parseElement)) ||
405 parseToken(rightToken, "expected ',' or right-terminating token")))
413 MLIRContext *context;
417 //===----------------------------------------------------------------------===//
419 //===----------------------------------------------------------------------===//
423 /// Lower precedence ops (all at the same precedence level). LNoOp is false in
424 /// the boolean sense.
425 enum AffineLowPrecOp {
432 /// Higher precedence ops - all at the same precedence level. HNoOp is false
433 /// in the boolean sense.
434 enum AffineHighPrecOp {
443 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
444 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
446 /// This is a specialized parser for affine expressions.
449 explicit AffineParser(Parser &p,
450 std::function<AffineExpr(StringRef)> bareIdParsingHook,
451 AffineDimList &dimList, AffineSymbolList &symbolList)
452 : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
453 symbols(symbolList) {}
455 /// Parse a comma-separated list of affine exprs.
456 SmallVector<AffineExpr, 4>
457 parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
458 Token::Kind rDelim = Token::Kind::r_paren);
460 /// Parse a single affine expr.`.
461 AffineExpr parseAffineExpr();
464 // Binary affine op parsing.
465 AffineLowPrecOp consumeIfLowPrecOp();
466 AffineHighPrecOp consumeIfHighPrecOp();
468 // AffineExpr parsing.
469 AffineExpr parseParentheticalExpr();
470 AffineExpr parseNegateExpression(AffineExpr lhs);
471 AffineExpr parseIntegerExpr();
472 AffineExpr parseBareIdExpr();
474 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
475 AffineExpr rhs, SMLoc opLoc);
476 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
478 AffineExpr parseAffineOperandExpr(AffineExpr lhs);
479 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
480 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
484 std::function<AffineExpr(StringRef)> bareIdFallback;
486 AffineSymbolList &symbols;
488 } // end anonymous namespace
490 /// Create an affine binary high precedence op expression (mul's, div's, mod).
491 /// opLoc is the location of the op token to be used to report errors
492 /// for non-conforming expressions.
493 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
494 AffineExpr lhs, AffineExpr rhs,
498 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
499 parser.emitError(opLoc,
500 "non-affine expression: at least one of the multiply "
501 "operands has to be either a constant or symbolic");
506 if (!rhs.isSymbolicOrConstant()) {
507 parser.emitError(opLoc,
508 "non-affine expression: right operand of floordiv "
509 "has to be either a constant or symbolic");
512 return lhs.floorDiv(rhs);
514 if (!rhs.isSymbolicOrConstant()) {
515 parser.emitError(opLoc, "non-affine expression: right operand of ceildiv "
516 "has to be either a constant or symbolic");
519 return lhs.ceilDiv(rhs);
521 if (!rhs.isSymbolicOrConstant()) {
522 parser.emitError(opLoc, "non-affine expression: right operand of mod "
523 "has to be either a constant or symbolic");
528 llvm_unreachable("can't create affine expression for null high prec op");
531 llvm_unreachable("Unknown AffineHighPrecOp");
534 /// Create an affine binary low precedence op expression (add, sub).
535 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
536 AffineExpr lhs, AffineExpr rhs) {
538 case AffineLowPrecOp::Add:
540 case AffineLowPrecOp::Sub:
542 case AffineLowPrecOp::LNoOp:
543 llvm_unreachable("can't create affine expression for null low prec op");
546 llvm_unreachable("Unknown AffineLowPrecOp");
549 /// Consume this token if it is a lower precedence affine op (there are only
550 /// two precedence levels).
551 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
552 switch (parser.curToken.getKind()) {
553 case Token::Kind::plus:
554 parser.consumeToken();
555 return AffineLowPrecOp::Add;
556 case Token::Kind::minus:
557 parser.consumeToken();
558 return AffineLowPrecOp::Sub;
560 return AffineLowPrecOp::LNoOp;
564 /// Consume this token if it is a higher precedence affine op (there are only
565 /// two precedence levels)
566 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
567 switch (parser.curToken.getKind()) {
568 case Token::Kind::star:
569 parser.consumeToken(Token::Kind::star);
571 case Token::Kind::kw_floordiv:
572 parser.consumeToken(Token::Kind::kw_floordiv);
574 case Token::Kind::kw_ceildiv:
575 parser.consumeToken(Token::Kind::kw_ceildiv);
577 case Token::Kind::kw_mod:
578 parser.consumeToken(Token::Kind::kw_mod);
585 /// Parse a high precedence op expression list: mul, div, and mod are high
586 /// precedence binary ops, i.e., parse a
587 /// expr_1 op_1 expr_2 op_2 ... expr_n
588 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
589 /// All affine binary ops are left associative.
590 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
591 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
592 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
593 /// report an error for non-conforming expressions.
594 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
595 AffineHighPrecOp llhsOp,
597 AffineExpr lhs = parseAffineOperandExpr(llhs);
601 // Found an LHS. Parse the remaining expression.
602 auto opLoc = parser.curToken.getLoc();
603 if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
605 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
608 return parseAffineHighPrecOpExpr(expr, op, opLoc);
611 return parseAffineHighPrecOpExpr(lhs, op, opLoc);
614 // This is the last operand in this expression.
616 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
618 // No llhs, 'lhs' itself is the expression.
622 /// Parse an affine expression inside parentheses.
624 /// affine-expr ::= `(` affine-expr `)`
625 AffineExpr AffineParser::parseParentheticalExpr() {
626 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
628 if (parser.curToken.is(Token::Kind::r_paren))
629 return (parser.emitError("no expression inside parentheses"), nullptr);
631 auto expr = parseAffineExpr();
634 if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
640 /// Parse the negation expression.
642 /// affine-expr ::= `-` affine-expr
643 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
644 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
647 AffineExpr operand = parseAffineOperandExpr(lhs);
648 // Since negation has the highest precedence of all ops (including high
649 // precedence ops) but lower than parentheses, we are only going to use
650 // parseAffineOperandExpr instead of parseAffineExpr here.
652 // Extra error message although parseAffineOperandExpr would have
653 // complained. Leads to a better diagnostic.
654 return (parser.emitError("missing operand of negation"), nullptr);
655 return (-1) * operand;
658 /// Parse a bare id that may appear in an affine expression.
660 /// affine-expr ::= bare-id
661 AffineExpr AffineParser::parseBareIdExpr() {
662 if (parser.curToken.isNot(Token::Kind::id))
663 return (parser.emitError("expected id"), nullptr);
665 StringRef sRef = parser.curToken.getSpelling();
666 for (auto &list : {dims, symbols}) {
667 for (auto entry : list) {
668 if (entry.first == sRef) {
669 parser.consumeToken(Token::Kind::id);
675 // Not found, check fallback path.
676 AffineExpr expr = bareIdFallback(sRef);
678 parser.consumeToken(Token::Kind::id);
682 return (parser.emitError("use of undeclared id"), nullptr);
685 /// Parse a positive integral constant appearing in an affine expression.
687 /// affine-expr ::= integer-literal
688 AffineExpr AffineParser::parseIntegerExpr() {
689 auto val = parser.curToken.getUInt64IntegerValue();
690 if (!val.hasValue() || (int64_t)val.getValue() < 0)
691 return (parser.emitError("constant too large for index"), nullptr);
693 parser.consumeToken(Token::Kind::integer);
694 return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
697 /// Parses an expression that can be a valid operand of an affine expression.
698 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
699 /// operator, the rhs of which is being parsed. This is used to determine
700 /// whether an error should be emitted for a missing right operand.
701 // Eg: for an expression without parentheses (like i + j + k + l), each
702 // of the four identifiers is an operand. For i + j*k + l, j*k is not an
703 // operand expression, it's an op expression and will be parsed via
704 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
705 // -l are valid operands that will be parsed by this function.
706 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
707 switch (parser.curToken.getKind()) {
708 case Token::Kind::id:
709 return parseBareIdExpr();
710 case Token::Kind::integer:
711 return parseIntegerExpr();
712 case Token::Kind::l_paren:
713 return parseParentheticalExpr();
714 case Token::Kind::minus:
715 return parseNegateExpression(lhs);
716 case Token::Kind::kw_ceildiv:
717 case Token::Kind::kw_floordiv:
718 case Token::Kind::kw_mod:
719 case Token::Kind::plus:
720 case Token::Kind::star:
722 parser.emitError("missing right operand of binary operator");
724 parser.emitError("missing left operand of binary operator");
728 parser.emitError("missing right operand of binary operator");
730 parser.emitError("expected affine expression");
735 /// Parse affine expressions that are bare-id's, integer constants,
736 /// parenthetical affine expressions, and affine op expressions that are a
737 /// composition of those.
739 /// All binary op's associate from left to right.
741 /// {add, sub} have lower precedence than {mul, div, and mod}.
743 /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
744 /// ceildiv, and mod are at the same higher precedence level. Negation has
745 /// higher precedence than any binary op.
747 /// llhs: the affine expression appearing on the left of the one being parsed.
748 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
749 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
750 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
753 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
754 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
755 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
756 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
757 AffineLowPrecOp llhsOp) {
759 if (!(lhs = parseAffineOperandExpr(llhs)))
762 // Found an LHS. Deal with the ops.
763 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
765 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
766 return parseAffineLowPrecOpExpr(sum, lOp);
768 // No LLHS, get RHS and form the expression.
769 return parseAffineLowPrecOpExpr(lhs, lOp);
771 auto opLoc = parser.curToken.getLoc();
772 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
773 // We have a higher precedence op here. Get the rhs operand for the llhs
774 // through parseAffineHighPrecOpExpr.
775 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
779 // If llhs is null, the product forms the first operand of the yet to be
780 // found expression. If non-null, the op to associate with llhs is llhsOp.
782 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
784 // Recurse for subsequent low prec op's after the affine high prec op
786 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
787 return parseAffineLowPrecOpExpr(expr, nextOp);
790 // Last operand in the expression list.
792 return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
793 // No llhs, 'lhs' itself is the expression.
797 /// Parse an affine expression.
798 /// affine-expr ::= `(` affine-expr `)`
799 /// | `-` affine-expr
800 /// | affine-expr `+` affine-expr
801 /// | affine-expr `-` affine-expr
802 /// | affine-expr `*` affine-expr
803 /// | affine-expr `floordiv` affine-expr
804 /// | affine-expr `ceildiv` affine-expr
805 /// | affine-expr `mod` affine-expr
807 /// | integer-literal
809 /// Additional conditions are checked depending on the production. For eg.,
810 /// one of the operands for `*` has to be either constant/symbolic; the second
811 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
812 AffineExpr AffineParser::parseAffineExpr() {
813 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
816 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
817 Token::Kind rDelim) {
818 parser.parseToken(lDelim, "expected lDelim at start of affine expr list");
820 SmallVector<AffineExpr, 4> exprs;
821 auto parseElt = [&]() -> LogicalResult {
822 auto elt = parseAffineExpr();
823 exprs.push_back(elt);
824 return elt ? success() : failure();
827 if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
828 /*allowEmptyList=*/true)))
829 llvm_unreachable("Failed AffineExpr parsing");
834 //===----------------------------------------------------------------------===//
836 //===----------------------------------------------------------------------===//
840 /// Base class for expressions involved in TC parsing.
848 explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
849 virtual ~Expression() = default;
851 operator bool() const { return kind != Kind::Uninitialized; }
856 /// Encodes a tensor use of the form:
858 /// affine-expr-list ::= affine-expr (`,` affine-expr)*
859 /// tensor-use ::= bare-id `(` `)`
860 /// | bare-id `(` affine-expr-list `)`
862 /// The affine-expr-list is stored as an AffineMap.
863 struct TensorUse : public Expression {
864 TensorUse() : TensorUse("", AffineMap()) {}
865 TensorUse(StringRef name, AffineMap map)
866 : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
867 TensorUse(const TensorUse &use) = default;
869 static bool classof(const Expression *e) {
870 return e->kind == Kind::TensorUse;
873 bool operator==(const TensorUse &other) const {
874 return tensorId == other.tensorId && indexingMap == other.indexingMap;
877 /// Visitation function. Performs preorder or postorder traversal depending on
878 /// `PreOrder` and applies `callback` on each node.
879 template <typename Lambda, bool PreOrder>
880 void visit(Lambda callback) const;
883 AffineMap indexingMap;
886 /// Encodes a tensor expression of the form:
888 /// op-spec ::= bare-id `<` reduction-dims-list `>`
890 /// op-arg ::= tensor-expr
892 /// op-arg-list ::= op-arg (`,` op-arg)*
893 /// tensor-expr ::= op-spec `(` op-arg-list `)`
895 /// Underlying op-arg are stored by unique_ptr to base class.
896 struct TensorExpr : public Expression {
897 TensorExpr(StringRef name,
898 SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
899 ArrayRef<unsigned> reductionDims)
900 : Expression(Kind::TensorExpr), operationName(name),
901 expressions(std::move(exprs)),
902 reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
904 static bool classof(const Expression *e) {
905 return e->kind == Kind::TensorExpr;
908 bool operator==(const TensorExpr &other) const {
909 if (operationName != other.operationName)
911 if (expressions.size() != other.expressions.size())
913 for (unsigned i = 0, e = expressions.size(); i < e; ++i)
914 if (*expressions[i] != *other.expressions[i])
916 for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
917 if (reductionDimensions[i] != other.reductionDimensions[i])
922 /// Visitation function. Performs preorder or postorder traversal depending on
923 /// `PreOrder` and applies `callback` on each node.
924 template <typename Lambda, bool PreOrder>
925 void visit(Lambda callback) const;
927 StringRef operationName;
928 SmallVector<std::unique_ptr<Expression>, 4> expressions;
929 SetVector<unsigned> reductionDimensions;
932 /// This is a specialized parser for a TCDef.
933 /// This maintains the dims it finds in an eager fashion.
935 enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
938 explicit TCParser(Parser &p);
940 /// Uses the AffineParser to parse the affine exprs used in a tensor
941 /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
942 /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
943 /// emitted on new identifiers.
944 SmallVector<AffineExpr, 4>
945 parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
946 Token::Kind lDelim = Token::Kind::l_paren,
947 Token::Kind rDelim = Token::Kind::r_paren);
949 /// Parse the information for a tensor def.
950 /// All the affine-expr must be dimensionless (i.e. contain only expressions
951 /// involving symbols and constants), but can otherwise contain arbitrary
952 /// affine expressions.
953 LogicalResult parseTensorDef(bool isOutput);
955 /// Parses a tensor use.
956 struct ComprehensionParsingState {
958 SmallVector<std::unique_ptr<Expression>, 4> expressions;
959 llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
961 LogicalResult parseTensorUse(TensorUse &result,
962 ComprehensionParsingState &state);
964 /// Parses a tensor expression.
965 LogicalResult parseExpression(TensorUse currentDefinition,
966 std::unique_ptr<Expression> &result,
967 ComprehensionParsingState &state);
969 /// Parse a single comprehension.
970 LogicalResult parseOneComprehension(StringRef cppOpName,
971 StringRef linalgOpName,
972 ComprehensionParsingState &state);
974 /// Parse and print the information for a TC def.
975 /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
976 /// When `gen-impl` is used, this prints the C++ implementation for the extra
977 /// methods defined in ODS (`iterator_types`, `indexing_maps` and
978 /// `regionBuilder`).
979 LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
981 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
982 void printODS(llvm::raw_ostream &os, StringRef cppOpName,
983 StringRef linalgOpName, ComprehensionParsingState &state);
985 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
986 void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
987 ComprehensionParsingState &state);
989 /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
990 void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
991 ComprehensionParsingState &state);
993 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
994 void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
995 ComprehensionParsingState &state);
997 /// Print the C++ impl for named ops canonicalizers and folders.
998 void printCanonicalizersAndFolders(llvm::raw_ostream &os,
999 StringRef cppOpName);
1002 //===--------------------------------------------------------------------===//
1003 // Internal bookkeeping of tensors.
1004 //===--------------------------------------------------------------------===//
1005 struct RegisteredTensor {
1009 AffineMap indexingMap;
1013 //===--------------------------------------------------------------------===//
1014 // Per-TC def state.
1015 //===--------------------------------------------------------------------===//
1016 /// Symbols are per TC def.
1017 AffineSymbolList symbols;
1018 /// Tensors are per TC def.
1019 llvm::StringMap<RegisteredTensor> registeredTensors;
1020 unsigned nextRegisteredTensorIndex;
1029 struct DenseMapInfo<TensorUse> {
1030 static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
1031 static TensorUse getTombstoneKey() {
1032 return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
1033 DenseMapInfo<AffineMap>::getTombstoneKey());
1035 static unsigned getHashValue(const TensorUse &val) {
1036 return ::llvm::hash_value(val.tensorId); // don't care about collisions.
1038 static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
1045 //===----------------------------------------------------------------------===//
1046 // Visitation functions.
1047 //===----------------------------------------------------------------------===//
1049 template <typename Lambda, bool PreOrder>
1050 void visit(const Expression &expr, Lambda callback) {
1051 switch (expr.kind) {
1053 llvm_unreachable("Unexpected kind");
1054 case Expression::Kind::TensorExpr:
1055 static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
1057 case Expression::Kind::TensorUse:
1058 static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
1063 template <typename Lambda>
1064 void visitPreorder(const Expression &expr, Lambda callback) {
1065 visit<Lambda, false>(expr, callback);
1068 template <typename Lambda>
1069 void visitPostorder(Expression &expr, Lambda callback) {
1070 visit<Lambda, true>(expr, callback);
1073 template <typename Lambda, bool PreOrder>
1074 void TensorExpr::visit(Lambda callback) const {
1077 for (auto &e : expressions)
1078 ::visit<Lambda, PreOrder>(*e, callback);
1083 template <typename Lambda, bool PreOrder>
1084 void TensorUse::visit(Lambda callback) const {
1088 //===----------------------------------------------------------------------===//
1089 // TC parsing functions.
1090 //===----------------------------------------------------------------------===//
1091 TCParser::TCParser(Parser &p)
1092 : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
1094 /// Uses the AffineParser to parse the affine exprs used in a tensor
1095 /// definition. All identifiers are interpreted as symbols, new symbols are
1097 SmallVector<AffineExpr, 4>
1098 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
1099 AffineDimList &dims, Token::Kind lDelim,
1100 Token::Kind rDelim) {
1101 AffineParser affineParser(
1103 [&](StringRef sRef) {
1105 if (discoveryMode == EagerDiscoveryMode::Symbols) {
1106 expr = getAffineSymbolExpr(symbols.size(), parser.context);
1107 symbols.emplace_back(sRef, expr);
1108 } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
1109 expr = getAffineDimExpr(dims.size(), parser.context);
1110 dims.emplace_back(sRef, expr);
1115 return affineParser.parseAffineExprs(lDelim, rDelim);
1118 /// Parse the information for a tensor def of the form:
1120 /// affine-expr-list ::= affine-expr (`,` affine-expr )*
1121 /// tensor-typedef ::= type `(` `)`
1122 /// | type `(` affine-expr-list `)`
1123 /// tensor-def ::= bare-id `:` tensor-typedef
1124 LogicalResult TCParser::parseTensorDef(bool isOutput) {
1125 StringRef tensorId = parser.curToken.getSpelling();
1126 if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
1127 failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1130 StringRef tensorType = parser.curToken.getSpelling();
1131 if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1134 AffineDimList emptyDims;
1135 auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
1136 assert(emptyDims.empty() && "Unexpected dimension in tensor def");
1138 AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
1140 auto iterBoolPair = registeredTensors.try_emplace(
1141 tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
1142 nextRegisteredTensorIndex++});
1144 assert(iterBoolPair.second && "Could not emplace tensor registration");
1145 LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
1146 << "with typeString: " << tensorType << " "
1147 << "and shape: " << map << "\n");
1152 /// Parses a tensor use of the form:
1154 /// affine-expr-list ::= affine-expr (`,` affine-expr)*
1155 /// tensor-use ::= bare-id `(` `)`
1156 /// | bare-id `(` affine-expr-list `)`
1157 LogicalResult TCParser::parseTensorUse(TensorUse &result,
1158 ComprehensionParsingState &state) {
1159 StringRef tensorId = parser.curToken.getSpelling();
1160 if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1163 auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
1165 AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
1166 LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
1169 result = TensorUse(tensorId, map);
1173 /// Parses a tensor expression of the form:
1175 /// op-spec ::= bare-id `<` reduction-dims-list `>`
1177 /// op-arg ::= tensor-expr
1179 /// op-arg-list ::= op-arg (`,` op-arg)*
1180 /// tensor-expr ::= op-spec `(` op-arg-list `)`
1181 LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
1182 std::unique_ptr<Expression> &result,
1183 ComprehensionParsingState &state) {
1184 StringRef opOrTensor = parser.curToken.getSpelling();
1185 if (registeredTensors.count(opOrTensor) > 0) {
1187 auto res = parseTensorUse(use, state);
1190 result = std::make_unique<TensorUse>(use);
1194 if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
1198 SmallVector<unsigned, 4> reductionDims;
1199 SmallVector<std::unique_ptr<Expression>, 4> expressions;
1201 // Check if it has a reduction set, discover dimensions eagerly.
1202 if (parser.curToken.is(Token::Kind::lt)) {
1203 auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
1204 Token::Kind::lt, Token::Kind::gt);
1205 for (auto iter : iters)
1206 reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
1209 // If this op is a reduction, it's first argument is the `currentDefinition`
1211 if (!reductionDims.empty())
1212 expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
1213 LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
1215 auto parseExpr = [&]() -> LogicalResult {
1216 std::unique_ptr<Expression> e;
1217 if (failed(parseExpression(currentDefinition, e, state)))
1219 expressions.push_back(std::move(e));
1222 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
1223 failed(parser.parseCommaSeparatedListUntil(
1224 Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
1227 result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
1233 //===----------------------------------------------------------------------===//
1234 // Parse and Emit functions.
1235 //===----------------------------------------------------------------------===//
1237 /// Parse the information for a single comprehension.
1239 /// tensor-def-list ::= tensor-def (`,` tensor-def)*
1240 /// tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
1241 /// comprehension ::= tensor-def-list `=` tensor-expr-list `;`
1243 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1244 ComprehensionParsingState &state) {
1245 // 1. Parse LHS of `=`, these become the definitions that appear as the output
1246 // tensors or read/write buffers.
1247 SmallVector<TensorUse, 4> definitions;
1248 auto parseUse = [&]() -> LogicalResult {
1250 if (failed(parseTensorUse(use, state)))
1252 definitions.push_back(use);
1255 if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
1256 /*allowEmptyList=*/true)))
1259 // 2. Parse RHS of `=`, this becomes the expressions from which we emit
1262 auto parseExpr = [&]() -> LogicalResult {
1263 std::unique_ptr<Expression> expr;
1264 if (idx >= definitions.size()) {
1265 parser.emitError("Fewer LHS definitions than RHS expressions");
1268 if (failed(parseExpression(definitions[idx++], expr, state)))
1270 state.expressions.push_back(std::move(expr));
1273 if (failed(parser.parseCommaSeparatedListUntil(
1274 Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
1276 if (idx != definitions.size()) {
1277 parser.emitError("Fewer RHS expressions than LHS definitions");
1282 // 3.a. Normalize all maps to the proper state.dims and symbols counts.
1283 SmallVector<TensorUse, 4> allUses;
1284 allUses.reserve(registeredTensors.size());
1285 for (auto &def : definitions)
1286 allUses.push_back(def);
1287 for (auto &pExpr : state.expressions)
1288 visitPostorder(*pExpr, [&](const Expression &e) {
1289 if (auto *use = dyn_cast<TensorUse>(&e))
1290 allUses.push_back(*use);
1292 for (auto &use : allUses)
1294 AffineMap::get(state.dims.size(), symbols.size(),
1295 use.indexingMap.getResults(), parser.context);
1297 // 3.b. Traverse definitions
1298 llvm::DenseSet<StringRef> seenDefs;
1299 for (auto &def : definitions) {
1300 if (seenDefs.count(def.tensorId) > 0) {
1301 parser.emitError("Unexpected multi-write to a single tensor");
1304 seenDefs.insert(def.tensorId);
1305 auto tensorIter = registeredTensors.find(def.tensorId);
1306 assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1307 auto &tensor = tensorIter->getValue();
1308 tensor.indexingMap = def.indexingMap;
1309 state.orderedTensorArgs[def] = tensor.index;
1312 bool failed = false;
1313 for (auto &pExpr : state.expressions)
1314 visitPostorder(*pExpr, [&](const Expression &e) {
1315 auto *pUse = dyn_cast<TensorUse>(&e);
1316 if (failed || !pUse)
1319 LLVM_DEBUG(llvm::dbgs()
1320 << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
1321 auto tensorIter = registeredTensors.find(use.tensorId);
1322 assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1323 auto &tensor = tensorIter->getValue();
1324 if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
1325 LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
1327 "Unexpected multi-read of a tensor with different accesses");
1331 seenDefs.insert(use.tensorId);
1332 tensor.indexingMap = use.indexingMap;
1333 state.orderedTensorArgs[use] = tensor.index;
1341 /// Parse and print the information for a ODS def.
1343 /// tensor-def-list ::= tensor-def (`,` tensor-def )*
1345 /// comprehension-list ::= comprehension comprehension*
1347 /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1348 /// `{` comprehension-list `}`
1350 /// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
1352 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
1353 /// contain only expressions involving symbols and constants), but can
1354 /// otherwise contain arbitrary affine expressions.
1355 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1356 if (failed(parser.parseToken(Token::Kind::kw_ods_def,
1357 "expected 'ods_def' to define a TC ODS")) ||
1358 failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
1360 StringRef cppOpName = parser.curToken.getSpelling();
1361 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
1363 if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1364 failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1365 failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
1367 if (failed(parser.parseToken(Token::Kind::kw_def,
1368 "expected 'def' to define a TC")))
1371 StringRef tcName = parser.curToken.getSpelling();
1372 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1373 if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1374 failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1377 auto parseInputDef = [&]() -> LogicalResult {
1378 return parseTensorDef(/*isOutput=*/false);
1380 if (failed(parser.parseCommaSeparatedListUntil(
1381 Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
1384 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
1385 failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1386 failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1388 auto parseOutputDef = [&]() -> LogicalResult {
1389 return parseTensorDef(/*isOutput=*/true);
1391 if (failed(parser.parseCommaSeparatedListUntil(
1392 Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
1395 // Since we don't declare symbols separately, we discover them eagerly: each
1396 // newly encountered id in a tensor shape expression is treated as a new
1397 // symbolic. At this point, all tensors have been parsed and all the symbols
1398 // that could be discovered eagerly are now known. Resize all AffineMaps to
1399 // normalize the number of eagerly discovered symbols.
1400 for (auto &tensor : registeredTensors) {
1401 auto &map = tensor.getValue().shape;
1402 map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
1406 if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
1409 SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
1410 while (parser.curToken.isNot(Token::Kind::r_brace)) {
1411 perComprehensionStates.push_back(ComprehensionParsingState());
1412 if (failed(parseOneComprehension(cppOpName, tcName,
1413 perComprehensionStates.back())))
1416 parser.parseToken(Token::Kind::r_brace, "expected '}'");
1419 auto nComprehensions = perComprehensionStates.size();
1420 if (nComprehensions != 1) {
1421 parser.emitError("only 1 comprehension supported for now, got: " +
1422 llvm::Twine(nComprehensions));
1426 auto &state = perComprehensionStates.back();
1427 printODS(os, cppOpName, tcName, state);
1431 auto &state = perComprehensionStates.back();
1432 std::string extraMethods;
1433 llvm::raw_string_ostream ss(extraMethods);
1434 printReferenceIterators(ss, cppOpName, state);
1435 printReferenceIndexingMaps(ss, cppOpName, state);
1436 printRegionBuilder(ss, cppOpName, state);
1437 printCanonicalizersAndFolders(ss, cppOpName);
1439 os << extraMethods << "\n";
1445 //===----------------------------------------------------------------------===//
1446 // Printing functions
1447 //===----------------------------------------------------------------------===//
1449 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
1450 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1451 StringRef linalgOpName,
1452 ComprehensionParsingState &state) {
1453 const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
1454 AttrSizedOperandSegments,
1455 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1456 SingleBlockImplicitTerminator<"YieldOp">]> {
1457 let arguments = (ins Variadic<AnyShaped>:$inputs,
1458 Variadic<AnyShaped>:$outputs);
1459 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1460 let regions = (region AnyRegion:$region);
1462 let skipDefaultBuilders = 1;
1463 let builders = [ OpBuilderDAG<
1464 (ins "ValueRange":$inputs, "ValueRange":$outputs),
1466 $_state.addOperands(inputs);
1467 $_state.addOperands(outputs);
1468 $_state.addAttribute(
1469 "operand_segment_sizes",
1470 $_builder.getI32VectorAttr({{
1471 static_cast<int32_t>(inputs.size()),
1472 static_cast<int32_t>(outputs.size())}));
1473 buildNamedStructuredOpRegionAndAttributes<{0}>(
1477 TypeRange(outputs));
1479 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1480 "ValueRange":$outputs),
1482 $_state.addOperands(inputs);
1483 $_state.addOperands(outputs);
1484 $_state.addTypes(resultTensorTypes);
1485 $_state.addAttribute(
1486 "operand_segment_sizes",
1487 $_builder.getI32VectorAttr({{
1488 static_cast<int32_t>(inputs.size()),
1489 static_cast<int32_t>(outputs.size())}));
1490 buildNamedStructuredOpRegionAndAttributes<{0}>(
1494 TypeRange(outputs));
1496 (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1497 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1499 $_state.addOperands(operands);
1500 $_state.addAttributes(attributes);
1501 $_state.addTypes(resultTensorTypes);
1502 (void)$_state.addRegion();
1505 let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1506 let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
1508 let hasCanonicalizer = 1;
1510 let extraClassDeclaration = [{{
1512 ArrayAttr iterator_types();
1513 ArrayAttr indexing_maps();
1514 static void regionBuilder(Block &block);
1515 static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
1518 static unsigned getNumRegionArgs() {{ return {4}; }
1519 std::string getLibraryCallName() {{
1520 return generateLibraryCallName(getOperation());
1525 unsigned nInputs = 0, nOutputs = 0;
1526 for (auto &t : registeredTensors) {
1527 if (t.getValue().isOutput)
1533 os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
1534 state.orderedTensorArgs.size());
1537 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
1538 void TCParser::printReferenceIterators(llvm::raw_ostream &os,
1539 StringRef cppOpName,
1540 ComprehensionParsingState &state) {
1541 const char *referenceReferenceIteratorsFmt =
1543 ArrayAttr {0}::iterator_types() {
1544 return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
1547 std::string iteratorsStr;
1548 llvm::raw_string_ostream ss(iteratorsStr);
1550 llvm::interleaveComma(
1551 state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
1552 bool reduction = false;
1553 for (auto &expr : state.expressions) {
1554 visitPostorder(*expr, [&](const Expression &e) {
1555 if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
1556 if (pTensorExpr->reductionDimensions.count(pos) > 0)
1563 ss << (reduction ? "getReductionIteratorTypeName()"
1564 : "getParallelIteratorTypeName()");
1569 os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
1572 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
1573 StringRef cppOpName) {
1574 const char *canonicalizersAndFoldersFmt = R"FMT(
1575 void {0}::getCanonicalizationPatterns(
1576 OwningRewritePatternList &results,
1577 MLIRContext *context) {{
1578 results.insert<EraseDeadLinalgOp>();
1579 results.insert<FoldTensorCastOp>();
1581 LogicalResult {0}::fold(ArrayRef<Attribute>,
1582 SmallVectorImpl<OpFoldResult> &) {{
1583 return foldMemRefCast(*this);
1585 void {0}::getEffects(SmallVectorImpl<
1586 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
1587 getGenericEffectsImpl(effects,
1588 getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
1590 os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
1593 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
1594 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
1595 StringRef cppOpName,
1596 ComprehensionParsingState &state) {
1597 // 1. Generic string template for specifying reference indexing maps.
1598 const char *referenceIndexingMapsFmt =
1600 // This is temporary until we transition out of manually specified ops that
1601 // should be auto-generated with linalg-ods-gen.
1602 ArrayAttr {0}::indexing_maps() {
1603 MLIRContext *context = getContext();
1605 bindDims(context, {1});
1606 return Builder(context).getAffineMapArrayAttr({ {2} });
1609 // 2. Print a comma-separated list of identifiers for the AffineExpr in
1610 // `state.dims`. These will replace the `{1}` placeholder in both
1611 // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
1612 // identifiers are bound in the right order to the proper AffineDimExpr.
1613 std::string dimsStr;
1614 llvm::raw_string_ostream ss(dimsStr);
1615 llvm::interleaveComma(
1617 [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
1620 // 3. Print a comma-separated list of AffineMap constructors that use the
1621 // identifiers from 1. The AffineExpr use the common arithmetic operators on
1622 // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
1623 // in return `SmallVector<AffineMap, 8>{{ {2} };`.
1624 std::string mapsStr;
1625 llvm::raw_string_ostream mapsStringStream(mapsStr);
1626 SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
1627 for (const auto &it : state.orderedTensorArgs)
1628 orderedUses[it.second] = it.first;
1629 llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
1630 assert(u.indexingMap);
1631 const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
1632 if (u.indexingMap.isEmpty()) {
1633 mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
1637 std::string exprsStr;
1638 llvm::raw_string_ostream exprsStringStream(exprsStr);
1639 exprsStringStream << "{";
1640 llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
1641 exprsStringStream << "}";
1642 exprsStringStream.flush();
1644 mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
1646 mapsStringStream.flush();
1648 // 4. Apply format to 1. using 2. and 3.
1649 os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
1652 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
1653 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
1654 ComprehensionParsingState &state) {
1655 unsigned count = state.orderedTensorArgs.size();
1656 llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
1657 std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
1658 printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
1659 if (auto *pUse = dyn_cast<TensorUse>(&e)) {
1660 os << "_" << state.orderedTensorArgs.find(*pUse)->second;
1663 auto *pTensorExpr = cast<TensorExpr>(&e);
1664 if (subExprsMap.count(pTensorExpr) > 0) {
1665 os << "_" << subExprsMap[pTensorExpr];
1667 std::string subExprs;
1668 llvm::raw_string_ostream subExprsStringStream(subExprs);
1669 llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
1670 [&](const std::unique_ptr<Expression> &e) {
1671 printExpr(subExprsStringStream, *e);
1673 subExprsStringStream.flush();
1674 const char *tensorExprFmt = "\n Value _{0} = {1}({2});";
1675 os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
1677 subExprsMap[pTensorExpr] = count;
1681 const char *regionBuilderFmt = R"FMT(
1682 void {0}::regionBuilder(Block &block) {
1683 using namespace edsc;
1684 using namespace intrinsics;
1685 auto args = block.getArguments();
1688 (linalg_yield(ValueRange{ {3} }));
1692 std::string valueHandleStr;
1693 llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
1694 llvm::interleaveComma(
1695 state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
1696 valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
1700 std::string expressionsStr;
1701 llvm::raw_string_ostream expressionStringStream(expressionsStr);
1702 for (auto &expr : state.expressions)
1703 visitPostorder(*expr, [&](const Expression &e) {
1704 if (e.kind == Expression::Kind::TensorExpr)
1705 printExpr(expressionStringStream, e);
1708 std::string yieldStr;
1709 llvm::raw_string_ostream yieldStringStream(yieldStr);
1710 llvm::interleaveComma(state.expressions, yieldStringStream,
1711 [&](const std::unique_ptr<Expression> &e) {
1712 printExpr(yieldStringStream, *e);
1715 valueHandleStringStream.flush();
1716 expressionStringStream.flush();
1717 yieldStringStream.flush();
1719 os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
1720 expressionsStr, yieldStr);
1723 /// Iterate over each Tensor Comprehension def.
1724 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
1726 while (parser.curToken.getKind() != Token::Kind::eof) {
1727 TCParser tcParser(parser);
1728 if (failed(tcParser.parseAndEmitODSDef(os)))
1734 int main(int argc, char **argv) {
1735 llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
1737 // Set up the input file.
1738 std::string errorMessage;
1739 std::unique_ptr<llvm::MemoryBuffer> file =
1740 mlir::openInputFile(inputFilename, &errorMessage);
1742 llvm::errs() << errorMessage << "\n";
1746 std::unique_ptr<llvm::ToolOutputFile> output =
1747 openOutputFile(outputFilename, &errorMessage);
1749 llvm::errs() << errorMessage << "\n";
1753 // Include the proper Linalg header for end-to-end tblgen testing without
1754 // resorting to non-portable shell manipulations.
1755 if (testEmitIncludeTdHeader)
1756 output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
1758 MLIRContext context;
1759 llvm::SourceMgr mgr;
1760 mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
1761 Parser parser(mgr, &context);
1762 parseAndEmitAllTensorComprehensions(output->os(), parser);