[mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatte...
authorRiver Riddle <riddleriver@gmail.com>
Tue, 1 Dec 2020 22:30:18 +0000 (14:30 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 1 Dec 2020 23:05:50 +0000 (15:05 -0800)
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.

The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array(for opcodes/memory locations/etc) and a memory buffer(for storing attributes/operations/values/any other data necessary). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.

The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`,  for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.

Differential Revision: https://reviews.llvm.org/D89107

23 files changed:
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/BlockSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
mlir/include/mlir/Rewrite/PatternApplicator.h
mlir/lib/IR/Block.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp [new file with mode: 0644]
mlir/lib/Rewrite/ByteCode.h [new file with mode: 0644]
mlir/lib/Rewrite/CMakeLists.txt
mlir/lib/Rewrite/FrozenRewritePatternList.cpp
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/test/Rewrite/pdl-bytecode.mlir [new file with mode: 0644]
mlir/test/lib/CMakeLists.txt
mlir/test/lib/Rewrite/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Rewrite/TestPDLByteCode.cpp [new file with mode: 0644]
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp

index df49eb3..6b11c0d 100644 (file)
@@ -108,7 +108,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
     ```mlir
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and
     // `op`.
-    pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
     ```
   }];
 
@@ -316,7 +316,7 @@ def PDLInterp_CheckTypeOp
     Example:
 
     ```mlir
-    pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest
+    pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest
     ```
   }];
 
@@ -338,7 +338,7 @@ def PDLInterp_CreateAttributeOp
     Example:
 
     ```mlir
-    pdl_interp.create_attribute 10 : i64
+    %attr = pdl_interp.create_attribute 10 : i64
     ```
   }];
 
@@ -369,7 +369,7 @@ def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
     Example:
 
     ```mlir
-    %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
+    %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
     ```
   }];
 
@@ -772,7 +772,7 @@ def PDLInterp_SwitchAttributeOp
     Example:
 
     ```mlir
-    pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest
+    pdl_interp.switch_attribute %attr to [10, true](^10Dest, ^trueDest) -> ^defaultDest
     ```
   }];
   let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues);
@@ -837,7 +837,7 @@ def PDLInterp_SwitchOperationNameOp
     Example:
 
     ```mlir
-    pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest
+    pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest
     ```
   }];
 
@@ -874,7 +874,7 @@ def PDLInterp_SwitchResultCountOp
     Example:
 
     ```mlir
-    pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest
+    pdl_interp.switch_result_count of %op to [0, 2](^0Dest, ^2Dest) -> ^defaultDest
     ```
   }];
 
index fc16eff..6cf2df9 100644 (file)
@@ -58,6 +58,7 @@ class SuccessorRange final
           SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
 public:
   using RangeBaseT::RangeBaseT;
+  SuccessorRange();
   SuccessorRange(Block *block);
   SuccessorRange(Operation *term);
 
index 5b3c448..3d5bc66 100644 (file)
@@ -69,6 +69,9 @@ public:
   /// Remove this operation from its parent block and delete it.
   void erase();
 
+  /// Remove the operation from its parent block, but don't delete it.
+  void remove();
+
   /// Create a deep copy of this operation, remapping any operands that use
   /// values outside of the operation using the map that is provided (leaving
   /// them alone if no entry is present).  Replaces references to cloned
index 96d6d11..74899c9 100644 (file)
@@ -349,7 +349,7 @@ public:
   void *getAsOpaquePointer() const {
     return static_cast<void *>(representation.getOpaqueValue());
   }
-  static OperationName getFromOpaquePointer(void *pointer);
+  static OperationName getFromOpaquePointer(const void *pointer);
 
 private:
   RepresentationUnion representation;
index 2158f09..4fdc087 100644 (file)
@@ -10,6 +10,7 @@
 #define MLIR_PATTERNMATCHER_H
 
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
 
 namespace mlir {
 
@@ -226,6 +227,189 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
 };
 
 //===----------------------------------------------------------------------===//
+// PDLPatternModule
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+
+/// Storage type of byte-code interpreter values. These are passed to constraint
+/// functions as arguments.
+class PDLValue {
+  /// The internal implementation type when the value is an Attribute,
+  /// Operation*, or Type. See `impl` below for more details.
+  using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;
+
+public:
+  PDLValue(const PDLValue &other) : impl(other.impl) {}
+  PDLValue(std::nullptr_t = nullptr) : impl() {}
+  PDLValue(Attribute value) : impl(value) {}
+  PDLValue(Operation *value) : impl(value) {}
+  PDLValue(Type value) : impl(value) {}
+  PDLValue(Value value) : impl(value) {}
+
+  /// Returns true if the type of the held value is `T`.
+  template <typename T>
+  std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
+    return impl.is<Value>();
+  }
+  template <typename T>
+  std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
+    auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
+    return attrOpTypeImpl && attrOpTypeImpl.is<T>();
+  }
+
+  /// Attempt to dynamically cast this value to type `T`, returns null if this
+  /// value is not an instance of `T`.
+  template <typename T>
+  std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
+    return impl.dyn_cast<T>();
+  }
+  template <typename T>
+  std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
+    auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
+    return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>();
+  }
+
+  /// Cast this value to type `T`, asserts if this value is not an instance of
+  /// `T`.
+  template <typename T>
+  std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
+    return impl.get<T>();
+  }
+  template <typename T>
+  std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
+    return impl.get<AttrOpTypeImplT>().get<T>();
+  }
+
+  /// Get an opaque pointer to the value.
+  void *getAsOpaquePointer() { return impl.getOpaqueValue(); }
+
+  /// Print this value to the provided output stream.
+  void print(raw_ostream &os);
+
+private:
+  /// The internal opaque representation of a PDLValue. We use a nested
+  /// PointerUnion structure here because `Value` only has 1 low bit
+  /// available, where as the remaining types all have 3.
+  llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
+  value.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given set of opaque PDLValue entities. The second parameter
+/// is a set of constant value parameters specified in Attribute form. Returns
+/// success if the constraint successfully held, failure otherwise.
+using PDLConstraintFunction = std::function<LogicalResult(
+    ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
+/// A native PDL creation function. This function creates a new PDLValue given
+/// a set of existing PDL values, a set of constant parameters specified in
+/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
+using PDLCreateFunction =
+    std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
+/// A native PDL rewrite function. This function rewrites the given root
+/// operation using the provided PatternRewriter. This method is only invoked
+/// when the corresponding match was successful.
+using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
+                                              ArrayAttr, PatternRewriter &)>;
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given opaque PDLValue entity. The second parameter is a set
+/// of constant value parameters specified in Attribute form. Returns success if
+/// the constraint successfully held, failure otherwise.
+using PDLSingleEntityConstraintFunction =
+    std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>;
+
+/// This class contains all of the necessary data for a set of PDL patterns, or
+/// pattern rewrites specified in the form of the PDL dialect. This PDL module
+/// contained by this pattern may contain any number of `pdl.pattern`
+/// operations.
+class PDLPatternModule {
+public:
+  PDLPatternModule() = default;
+
+  /// Construct a PDL pattern with the given module.
+  PDLPatternModule(OwningModuleRef pdlModule)
+      : pdlModule(std::move(pdlModule)) {}
+
+  /// Merge the state in `other` into this pattern module.
+  void mergeIn(PDLPatternModule &&other);
+
+  /// Return the internal PDL module of this pattern.
+  ModuleOp getModule() { return pdlModule.get(); }
+
+  //===--------------------------------------------------------------------===//
+  // Function Registry
+
+  /// Register a constraint function.
+  void registerConstraintFunction(StringRef name,
+                                  PDLConstraintFunction constraintFn);
+  /// Register a single entity constraint function.
+  template <typename SingleEntityFn>
+  std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
+                                       ArrayAttr, PatternRewriter &>::value>
+  registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
+    registerConstraintFunction(name, [=](ArrayRef<PDLValue> values,
+                                         ArrayAttr constantParams,
+                                         PatternRewriter &rewriter) {
+      assert(values.size() == 1 && "expected values to have a single entity");
+      return constraintFn(values[0], constantParams, rewriter);
+    });
+  }
+
+  /// Register a creation function.
+  void registerCreateFunction(StringRef name, PDLCreateFunction createFn);
+
+  /// Register a rewrite function.
+  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+
+  /// Return the set of the registered constraint functions.
+  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+    return constraintFunctions;
+  }
+  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
+    return constraintFunctions;
+  }
+  /// Return the set of the registered create functions.
+  const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
+    return createFunctions;
+  }
+  llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
+    return createFunctions;
+  }
+  /// Return the set of the registered rewrite functions.
+  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
+    return rewriteFunctions;
+  }
+  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
+    return rewriteFunctions;
+  }
+
+  /// Clear out the patterns and functions within this module.
+  void clear() {
+    pdlModule = nullptr;
+    constraintFunctions.clear();
+    createFunctions.clear();
+    rewriteFunctions.clear();
+  }
+
+private:
+  /// The module containing the `pdl.pattern` operations.
+  OwningModuleRef pdlModule;
+
+  /// The external functions referenced from within the PDL module.
+  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+  llvm::StringMap<PDLCreateFunction> createFunctions;
+  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
+};
+
+//===----------------------------------------------------------------------===//
 // PatternRewriter
 //===----------------------------------------------------------------------===//
 
@@ -384,28 +568,28 @@ private:
 //===----------------------------------------------------------------------===//
 
 class OwningRewritePatternList {
-  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
   OwningRewritePatternList() = default;
 
-  /// Construct a OwningRewritePatternList populated with the pattern `t` of
-  /// type `T`.
-  template <typename T>
-  OwningRewritePatternList(T &&t) {
-    patterns.emplace_back(std::make_unique<T>(std::forward<T>(t)));
+  /// Construct a OwningRewritePatternList populated with the given pattern.
+  OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
+    nativePatterns.emplace_back(std::move(pattern));
   }
+  OwningRewritePatternList(PDLPatternModule &&pattern)
+      : pdlPatterns(std::move(pattern)) {}
+
+  /// Return the native patterns held in this list.
+  NativePatternListT &getNativePatterns() { return nativePatterns; }
 
-  PatternListT::iterator begin() { return patterns.begin(); }
-  PatternListT::iterator end() { return patterns.end(); }
-  PatternListT::const_iterator begin() const { return patterns.begin(); }
-  PatternListT::const_iterator end() const { return patterns.end(); }
-  PatternListT::size_type size() const { return patterns.size(); }
-  void clear() { patterns.clear(); }
+  /// Return the PDL patterns held in this list.
+  PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
 
-  /// Take ownership of the patterns held by this list.
-  std::vector<std::unique_ptr<RewritePattern>> takePatterns() {
-    return std::move(patterns);
+  /// Clear out all of the held patterns in this list.
+  void clear() {
+    nativePatterns.clear();
+    pdlPatterns.clear();
   }
 
   //===--------------------------------------------------------------------===//
@@ -419,31 +603,53 @@ public:
             typename... ConstructorArgs,
             typename = std::enable_if_t<sizeof...(Ts) != 0>>
   OwningRewritePatternList &insert(ConstructorArg &&arg,
-                                   ConstructorArgs &&... args) {
+                                   ConstructorArgs &&...args) {
     // The following expands a call to emplace_back for each of the pattern
     // types 'Ts'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
+    (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...};
     return *this;
   }
 
   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
   /// `this` for chaining insertions.
   template <typename... Ts> OwningRewritePatternList &insert() {
-    (void)std::initializer_list<int>{
-        0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...};
+    (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
     return *this;
   }
 
-  /// Add the given pattern to the pattern list.
-  void insert(std::unique_ptr<RewritePattern> pattern) {
-    patterns.emplace_back(std::move(pattern));
+  /// Add the given native pattern to the pattern list. Return a reference to
+  /// `this` for chaining insertions.
+  OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) {
+    nativePatterns.emplace_back(std::move(pattern));
+    return *this;
+  }
+
+  /// Add the given PDL pattern to the pattern list. Return a reference to
+  /// `this` for chaining insertions.
+  OwningRewritePatternList &insert(PDLPatternModule &&pattern) {
+    pdlPatterns.mergeIn(std::move(pattern));
+    return *this;
   }
 
 private:
-  PatternListT patterns;
+  /// Add an instance of the pattern type 'T'. Return a reference to `this` for
+  /// chaining insertions.
+  template <typename T, typename... Args>
+  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
+  insertImpl(Args &&...args) {
+    nativePatterns.emplace_back(
+        std::make_unique<T>(std::forward<Args>(args)...));
+  }
+  template <typename T, typename... Args>
+  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
+  insertImpl(Args &&...args) {
+    pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
+  }
+
+  NativePatternListT nativePatterns;
+  PDLPatternModule pdlPatterns;
 };
 
 } // end namespace mlir
index c0096bb..719bb1a 100644 (file)
@@ -104,6 +104,12 @@ public:
     return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
   }
 
+  /// Get an instance of the concrete type from a void pointer.
+  static ConcreteT getFromOpaquePointer(const void *ptr) {
+    return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
+               : nullptr;
+  }
+
 protected:
   /// Mutate the current storage instance. This will not change the unique key.
   /// The arguments are forwarded to 'ConcreteT::mutate'.
index fb2657d..c2335b9 100644 (file)
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
+namespace detail {
+class PDLByteCode;
+} // end namespace detail
+
 /// This class represents a frozen set of patterns that can be processed by a
 /// pattern applicator. This class is designed to enable caching pattern lists
 /// such that they need not be continuously recomputed.
 class FrozenRewritePatternList {
-  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+  using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
   /// Freeze the patterns held in `patterns`, and take ownership.
   FrozenRewritePatternList(OwningRewritePatternList &&patterns);
+  FrozenRewritePatternList(FrozenRewritePatternList &&patterns);
+  ~FrozenRewritePatternList();
+
+  /// Return the native patterns held by this list.
+  iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
+  getNativePatterns() const {
+    return llvm::make_pointee_range(nativePatterns);
+  }
 
-  /// Return the patterns held by this list.
-  iterator_range<llvm::pointee_iterator<PatternListT::const_iterator>>
-  getPatterns() const {
-    return llvm::make_pointee_range(patterns);
+  /// Return the compiled PDL bytecode held by this list. Returns null if
+  /// there are no PDL patterns within the list.
+  const detail::PDLByteCode *getPDLByteCode() const {
+    return pdlByteCode.get();
   }
 
 private:
-  /// The patterns held by this list.
-  std::vector<std::unique_ptr<RewritePattern>> patterns;
+  /// The set of.
+  std::vector<std::unique_ptr<RewritePattern>> nativePatterns;
+
+  /// The bytecode containing the compiled PDL patterns.
+  std::unique_ptr<detail::PDLByteCode> pdlByteCode;
 };
 
 } // end namespace mlir
index cb7794b..9d19717 100644 (file)
 namespace mlir {
 class PatternRewriter;
 
+namespace detail {
+class PDLByteCodeMutableState;
+} // end namespace detail
+
 /// This class manages the application of a group of rewrite patterns, with a
 /// user-provided cost model.
 class PatternApplicator {
@@ -29,8 +33,8 @@ public:
   /// `impossibleToMatch`.
   using CostModel = function_ref<PatternBenefit(const Pattern &)>;
 
-  explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList)
-      : frozenPatternList(frozenPatternList) {}
+  explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList);
+  ~PatternApplicator();
 
   /// Attempt to match and rewrite the given op with any pattern, allowing a
   /// predicate to decide if a pattern can be applied or not, and hooks for if
@@ -60,16 +64,6 @@ public:
   void walkAllPatterns(function_ref<void(const Pattern &)> walk);
 
 private:
-  /// Attempt to match and rewrite the given op with the given pattern, allowing
-  /// a predicate to decide if a pattern can be applied or not, and hooks for if
-  /// the pattern match was a success or failure.
-  LogicalResult
-  matchAndRewrite(Operation *op, const RewritePattern &pattern,
-                  PatternRewriter &rewriter,
-                  function_ref<bool(const Pattern &)> canApply,
-                  function_ref<void(const Pattern &)> onFailure,
-                  function_ref<LogicalResult(const Pattern &)> onSuccess);
-
   /// The list that owns the patterns used within this applicator.
   const FrozenRewritePatternList &frozenPatternList;
   /// The set of patterns to match for each operation, stable sorted by benefit.
@@ -77,6 +71,8 @@ private:
   /// The set of patterns that may match against any operation type, stable
   /// sorted by benefit.
   SmallVector<const RewritePattern *, 1> anyOpPatterns;
+  /// The mutable state used during execution of the PDL bytecode.
+  std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
 };
 
 } // end namespace mlir
index b9ddabb..79e7daa 100644 (file)
@@ -302,13 +302,15 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
 // SuccessorRange
 //===----------------------------------------------------------------------===//
 
-SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
+SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
+
+SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
   if (Operation *term = block->getTerminator())
     if ((count = term->getNumSuccessors()))
       base = term->getBlockOperands().data();
 }
 
-SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
+SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
   if ((count = term->getNumSuccessors()))
     base = term->getBlockOperands().data();
 }
index e725dd8..3037bf0 100644 (file)
@@ -61,8 +61,9 @@ const AbstractOperation *OperationName::getAbstractOperation() const {
   return representation.dyn_cast<const AbstractOperation *>();
 }
 
-OperationName OperationName::getFromOpaquePointer(void *pointer) {
-  return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
+OperationName OperationName::getFromOpaquePointer(const void *pointer) {
+  return OperationName(
+      RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer)));
 }
 
 //===----------------------------------------------------------------------===//
@@ -484,6 +485,12 @@ void Operation::erase() {
     destroy();
 }
 
+/// Remove the operation from its parent block, but don't delete it.
+void Operation::remove() {
+  if (Block *parent = getBlock())
+    parent->getOperations().remove(this);
+}
+
 /// Unlink this operation from its current block and insert it right before
 /// `existingOp` which may be in the same or another block in the same
 /// function.
index edd5e7b..6558fcf 100644 (file)
@@ -70,6 +70,84 @@ LogicalResult RewritePattern::match(Operation *op) const {
 void RewritePattern::anchor() {}
 
 //===----------------------------------------------------------------------===//
+// PDLValue
+//===----------------------------------------------------------------------===//
+
+void PDLValue::print(raw_ostream &os) {
+  if (!impl) {
+    os << "<Null-PDLValue>";
+    return;
+  }
+  if (Value val = impl.dyn_cast<Value>()) {
+    os << val;
+    return;
+  }
+  AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
+  if (Attribute attr = aotImpl.dyn_cast<Attribute>())
+    os << attr;
+  else if (Operation *op = aotImpl.dyn_cast<Operation *>())
+    os << *op;
+  else
+    os << aotImpl.get<Type>();
+}
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+//===----------------------------------------------------------------------===//
+
+void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
+  // Ignore the other module if it has no patterns.
+  if (!other.pdlModule)
+    return;
+  // Steal the other state if we have no patterns.
+  if (!pdlModule) {
+    constraintFunctions = std::move(other.constraintFunctions);
+    createFunctions = std::move(other.createFunctions);
+    rewriteFunctions = std::move(other.rewriteFunctions);
+    pdlModule = std::move(other.pdlModule);
+    return;
+  }
+  // Steal the functions of the other module.
+  for (auto &it : constraintFunctions)
+    registerConstraintFunction(it.first(), std::move(it.second));
+  for (auto &it : createFunctions)
+    registerCreateFunction(it.first(), std::move(it.second));
+  for (auto &it : rewriteFunctions)
+    registerRewriteFunction(it.first(), std::move(it.second));
+
+  // Merge the pattern operations from the other module into this one.
+  Block *block = pdlModule->getBody();
+  block->getTerminator()->erase();
+  block->getOperations().splice(block->end(),
+                                other.pdlModule->getBody()->getOperations());
+}
+
+//===----------------------------------------------------------------------===//
+// Function Registry
+
+void PDLPatternModule::registerConstraintFunction(
+    StringRef name, PDLConstraintFunction constraintFn) {
+  auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
+  (void)it;
+  assert(it.second &&
+         "constraint with the given name has already been registered");
+}
+void PDLPatternModule::registerCreateFunction(StringRef name,
+                                              PDLCreateFunction createFn) {
+  auto it = createFunctions.try_emplace(name, std::move(createFn));
+  (void)it;
+  assert(it.second && "native create function with the given name has "
+                      "already been registered");
+}
+void PDLPatternModule::registerRewriteFunction(StringRef name,
+                                               PDLRewriteFunction rewriteFn) {
+  auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
+  (void)it;
+  assert(it.second && "native rewrite function with the given name has "
+                      "already been registered");
+}
+
+//===----------------------------------------------------------------------===//
 // PatternRewriter
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
new file mode 100644 (file)
index 0000000..ae5f322
--- /dev/null
@@ -0,0 +1,1262 @@
+//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements MLIR to byte-code generation and the interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ByteCode.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "pdl-bytecode"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodePattern
+//===----------------------------------------------------------------------===//
+
+PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+                                              ByteCodeAddr rewriterAddr) {
+  SmallVector<StringRef, 8> generatedOps;
+  if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
+    generatedOps =
+        llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
+
+  PatternBenefit benefit = matchOp.benefit();
+  MLIRContext *ctx = matchOp.getContext();
+
+  // Check to see if this is pattern matches a specific operation type.
+  if (Optional<StringRef> rootKind = matchOp.rootKind())
+    return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
+                              ctx);
+  return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
+                            MatchAnyOpTypeTag());
+}
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodeMutableState
+//===----------------------------------------------------------------------===//
+
+/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
+/// to the position of the pattern within the range returned by
+/// `PDLByteCode::getPatterns`.
+void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
+                                                   PatternBenefit benefit) {
+  currentPatternBenefits[patternIndex] = benefit;
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode OpCodes
+//===----------------------------------------------------------------------===//
+
+namespace {
+enum OpCode : ByteCodeField {
+  /// Apply an externally registered constraint.
+  ApplyConstraint,
+  /// Apply an externally registered rewrite.
+  ApplyRewrite,
+  /// Check if two generic values are equal.
+  AreEqual,
+  /// Unconditional branch.
+  Branch,
+  /// Compare the operand count of an operation with a constant.
+  CheckOperandCount,
+  /// Compare the name of an operation with a constant.
+  CheckOperationName,
+  /// Compare the result count of an operation with a constant.
+  CheckResultCount,
+  /// Invoke a native creation method.
+  CreateNative,
+  /// Create an operation.
+  CreateOperation,
+  /// Erase an operation.
+  EraseOp,
+  /// Terminate a matcher or rewrite sequence.
+  Finalize,
+  /// Get a specific attribute of an operation.
+  GetAttribute,
+  /// Get the type of an attribute.
+  GetAttributeType,
+  /// Get the defining operation of a value.
+  GetDefiningOp,
+  /// Get a specific operand of an operation.
+  GetOperand0,
+  GetOperand1,
+  GetOperand2,
+  GetOperand3,
+  GetOperandN,
+  /// Get a specific result of an operation.
+  GetResult0,
+  GetResult1,
+  GetResult2,
+  GetResult3,
+  GetResultN,
+  /// Get the type of a value.
+  GetValueType,
+  /// Check if a generic value is not null.
+  IsNotNull,
+  /// Record a successful pattern match.
+  RecordMatch,
+  /// Replace an operation.
+  ReplaceOp,
+  /// Compare an attribute with a set of constants.
+  SwitchAttribute,
+  /// Compare the operand count of an operation with a set of constants.
+  SwitchOperandCount,
+  /// Compare the name of an operation with a set of constants.
+  SwitchOperationName,
+  /// Compare the result count of an operation with a set of constants.
+  SwitchResultCount,
+  /// Compare a type with a set of constants.
+  SwitchType,
+};
+
+enum class PDLValueKind { Attribute, Operation, Type, Value };
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ByteCode Generation
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Generator
+
+namespace {
+struct ByteCodeWriter;
+
+/// This class represents the main generator for the pattern bytecode.
+class Generator {
+public:
+  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
+            SmallVectorImpl<ByteCodeField> &matcherByteCode,
+            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
+            SmallVectorImpl<PDLByteCodePattern> &patterns,
+            ByteCodeField &maxValueMemoryIndex,
+            llvm::StringMap<PDLConstraintFunction> &constraintFns,
+            llvm::StringMap<PDLCreateFunction> &createFns,
+            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
+        rewriterByteCode(rewriterByteCode), patterns(patterns),
+        maxValueMemoryIndex(maxValueMemoryIndex) {
+    for (auto it : llvm::enumerate(constraintFns))
+      constraintToMemIndex.try_emplace(it.value().first(), it.index());
+    for (auto it : llvm::enumerate(createFns))
+      nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
+    for (auto it : llvm::enumerate(rewriteFns))
+      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
+  }
+
+  /// Generate the bytecode for the given PDL interpreter module.
+  void generate(ModuleOp module);
+
+  /// Return the memory index to use for the given value.
+  ByteCodeField &getMemIndex(Value value) {
+    assert(valueToMemIndex.count(value) &&
+           "expected memory index to be assigned");
+    return valueToMemIndex[value];
+  }
+
+  /// Return an index to use when referring to the given data that is uniqued in
+  /// the MLIR context.
+  template <typename T>
+  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
+  getMemIndex(T val) {
+    const void *opaqueVal = val.getAsOpaquePointer();
+
+    // Get or insert a reference to this value.
+    auto it = uniquedDataToMemIndex.try_emplace(
+        opaqueVal, maxValueMemoryIndex + uniquedData.size());
+    if (it.second)
+      uniquedData.push_back(opaqueVal);
+    return it.first->second;
+  }
+
+private:
+  /// Allocate memory indices for the results of operations within the matcher
+  /// and rewriters.
+  void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
+
+  /// Generate the bytecode for the given operation.
+  void generate(Operation *op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
+
+  /// Mapping from value to its corresponding memory index.
+  DenseMap<Value, ByteCodeField> valueToMemIndex;
+
+  /// Mapping from the name of an externally registered rewrite to its index in
+  /// the bytecode registry.
+  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
+
+  /// Mapping from the name of an externally registered constraint to its index
+  /// in the bytecode registry.
+  llvm::StringMap<ByteCodeField> constraintToMemIndex;
+
+  /// Mapping from the name of an externally registered creation method to its
+  /// index in the bytecode registry.
+  llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
+
+  /// Mapping from rewriter function name to the bytecode address of the
+  /// rewriter function in byte.
+  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
+
+  /// Mapping from a uniqued storage object to its memory index within
+  /// `uniquedData`.
+  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
+
+  /// The current MLIR context.
+  MLIRContext *ctx;
+
+  /// Data of the ByteCode class to be populated.
+  std::vector<const void *> &uniquedData;
+  SmallVectorImpl<ByteCodeField> &matcherByteCode;
+  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
+  SmallVectorImpl<PDLByteCodePattern> &patterns;
+  ByteCodeField &maxValueMemoryIndex;
+};
+
+/// This class provides utilities for writing a bytecode stream.
+struct ByteCodeWriter {
+  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
+      : bytecode(bytecode), generator(generator) {}
+
+  /// Append a field to the bytecode.
+  void append(ByteCodeField field) { bytecode.push_back(field); }
+
+  /// Append an address to the bytecode.
+  void append(ByteCodeAddr field) {
+    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
+                  "unexpected ByteCode address size");
+
+    ByteCodeField fieldParts[2];
+    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
+    bytecode.append({fieldParts[0], fieldParts[1]});
+  }
+
+  /// Append a successor range to the bytecode, the exact address will need to
+  /// be resolved later.
+  void append(SuccessorRange successors) {
+    // Add back references to the any successors so that the address can be
+    // resolved later.
+    for (Block *successor : successors) {
+      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
+      append(ByteCodeAddr(0));
+    }
+  }
+
+  /// Append a range of values that will be read as generic PDLValues.
+  void appendPDLValueList(OperandRange values) {
+    bytecode.push_back(values.size());
+    for (Value value : values) {
+      // Append the type of the value in addition to the value itself.
+      PDLValueKind kind =
+          TypeSwitch<Type, PDLValueKind>(value.getType())
+              .Case<pdl::AttributeType>(
+                  [](Type) { return PDLValueKind::Attribute; })
+              .Case<pdl::OperationType>(
+                  [](Type) { return PDLValueKind::Operation; })
+              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
+              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
+      bytecode.push_back(static_cast<ByteCodeField>(kind));
+      append(value);
+    }
+  }
+
+  /// Check if the given class `T` has an iterator type.
+  template <typename T, typename... Args>
+  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
+
+  /// Append a value that will be stored in a memory slot and not inline within
+  /// the bytecode.
+  template <typename T>
+  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
+                   std::is_pointer<T>::value>
+  append(T value) {
+    bytecode.push_back(generator.getMemIndex(value));
+  }
+
+  /// Append a range of values.
+  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
+  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
+  append(T range) {
+    bytecode.push_back(llvm::size(range));
+    for (auto it : range)
+      append(it);
+  }
+
+  /// Append a variadic number of fields to the bytecode.
+  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
+  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
+    append(field);
+    append(field2, fields...);
+  }
+
+  /// Successor references in the bytecode that have yet to be resolved.
+  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
+
+  /// The underlying bytecode buffer.
+  SmallVectorImpl<ByteCodeField> &bytecode;
+
+  /// The main generator producing PDL.
+  Generator &generator;
+};
+} // end anonymous namespace
+
+void Generator::generate(ModuleOp module) {
+  FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
+      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
+  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
+      pdl_interp::PDLInterpDialect::getRewriterModuleName());
+  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
+
+  // Allocate memory indices for the results of operations within the matcher
+  // and rewriters.
+  allocateMemoryIndices(matcherFunc, rewriterModule);
+
+  // Generate code for the rewriter functions.
+  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
+    for (Operation &op : rewriterFunc.getOps())
+      generate(&op, rewriterByteCodeWriter);
+  }
+  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
+         "unexpected branches in rewriter function");
+
+  // Generate code for the matcher function.
+  DenseMap<Block *, ByteCodeAddr> blockToAddr;
+  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
+  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
+  for (Block *block : rpot) {
+    // Keep track of where this block begins within the matcher function.
+    blockToAddr.try_emplace(block, matcherByteCode.size());
+    for (Operation &op : *block)
+      generate(&op, matcherByteCodeWriter);
+  }
+
+  // Resolve successor references in the matcher.
+  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
+    ByteCodeAddr addr = blockToAddr[it.first];
+    for (unsigned offsetToFix : it.second)
+      std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
+  }
+}
+
+void Generator::allocateMemoryIndices(FuncOp matcherFunc,
+                                      ModuleOp rewriterModule) {
+  // Rewriters use simplistic allocation scheme that simply assigns an index to
+  // each result.
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    ByteCodeField index = 0;
+    for (BlockArgument arg : rewriterFunc.getArguments())
+      valueToMemIndex.try_emplace(arg, index++);
+    rewriterFunc.getBody().walk([&](Operation *op) {
+      for (Value result : op->getResults())
+        valueToMemIndex.try_emplace(result, index++);
+    });
+    if (index > maxValueMemoryIndex)
+      maxValueMemoryIndex = index;
+  }
+
+  // The matcher function uses a more sophisticated numbering that tries to
+  // minimize the number of memory indices assigned. This is done by determining
+  // a live range of the values within the matcher, then the allocation is just
+  // finding the minimal number of overlapping live ranges. This is essentially
+  // a simplified form of register allocation where we don't necessarily have a
+  // limited number of registers, but we still want to minimize the number used.
+  DenseMap<Operation *, ByteCodeField> opToIndex;
+  matcherFunc.getBody().walk([&](Operation *op) {
+    opToIndex.insert(std::make_pair(op, opToIndex.size()));
+  });
+
+  // Liveness info for each of the defs within the matcher.
+  using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
+  LivenessSet::Allocator allocator;
+  DenseMap<Value, LivenessSet> valueDefRanges;
+
+  // Assign the root operation being matched to slot 0.
+  BlockArgument rootOpArg = matcherFunc.getArgument(0);
+  valueToMemIndex[rootOpArg] = 0;
+
+  // Walk each of the blocks, computing the def interval that the value is used.
+  Liveness matcherLiveness(matcherFunc);
+  for (Block &block : matcherFunc.getBody()) {
+    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
+    assert(info && "expected liveness info for block");
+    auto processValue = [&](Value value, Operation *firstUseOrDef) {
+      // We don't need to process the root op argument, this value is always
+      // assigned to the first memory slot.
+      if (value == rootOpArg)
+        return;
+
+      // Set indices for the range of this block that the value is used.
+      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
+      defRangeIt->second.insert(
+          opToIndex[firstUseOrDef],
+          opToIndex[info->getEndOperation(value, firstUseOrDef)],
+          /*dummyValue*/ 0);
+    };
+
+    // Process the live-ins of this block.
+    for (Value liveIn : info->in())
+      processValue(liveIn, &block.front());
+
+    // Process any new defs within this block.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        processValue(result, &op);
+  }
+
+  // Greedily allocate memory slots using the computed def live ranges.
+  std::vector<LivenessSet> allocatedIndices;
+  for (auto &defIt : valueDefRanges) {
+    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
+    LivenessSet &defSet = defIt.second;
+
+    // Try to allocate to an existing index.
+    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
+      LivenessSet &existingIndex = existingIndexIt.value();
+      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
+          defIt.second, existingIndex);
+      if (overlaps.valid())
+        continue;
+      // Union the range of the def within the existing index.
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = existingIndexIt.index() + 1;
+    }
+
+    // If no existing index could be used, add a new one.
+    if (memIndex == 0) {
+      allocatedIndices.emplace_back(allocator);
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = allocatedIndices.size();
+    }
+  }
+
+  // Update the max number of indices.
+  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
+  if (numMatcherIndices > maxValueMemoryIndex)
+    maxValueMemoryIndex = numMatcherIndices;
+}
+
+void Generator::generate(Operation *op, ByteCodeWriter &writer) {
+  TypeSwitch<Operation *>(op)
+      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
+            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
+            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
+            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
+            pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
+            pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
+            pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
+            pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
+            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
+            pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
+            pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
+            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+            pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
+            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+          [&](auto interpOp) { this->generate(interpOp, writer); })
+      .Default([](Operation *) {
+        llvm_unreachable("unknown `pdl_interp` operation");
+      });
+}
+
+void Generator::generate(pdl_interp::ApplyConstraintOp op,
+                         ByteCodeWriter &writer) {
+  assert(constraintToMemIndex.count(op.name()) &&
+         "expected index for constraint function");
+  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
+                op.constParamsAttr());
+  writer.appendPDLValueList(op.args());
+  writer.append(op.getSuccessors());
+}
+void Generator::generate(pdl_interp::ApplyRewriteOp op,
+                         ByteCodeWriter &writer) {
+  assert(externalRewriterToMemIndex.count(op.name()) &&
+         "expected index for rewrite function");
+  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
+                op.constParamsAttr(), op.root());
+  writer.appendPDLValueList(op.args());
+}
+void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::Branch, SuccessorRange(op));
+}
+void Generator::generate(pdl_interp::CheckAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckOperandCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckOperationNameOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckOperationName, op.operation(),
+                OperationName(op.name(), ctx), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckResultCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CreateAttributeOp op,
+                         ByteCodeWriter &writer) {
+  // Simply repoint the memory index of the result to the constant.
+  getMemIndex(op.attribute()) = getMemIndex(op.value());
+}
+void Generator::generate(pdl_interp::CreateNativeOp op,
+                         ByteCodeWriter &writer) {
+  assert(nativeCreateToMemIndex.count(op.name()) &&
+         "expected index for creation function");
+  writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
+                op.result(), op.constParamsAttr());
+  writer.appendPDLValueList(op.args());
+}
+void Generator::generate(pdl_interp::CreateOperationOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CreateOperation, op.operation(),
+                OperationName(op.name(), ctx), op.operands());
+
+  // Add the attributes.
+  OperandRange attributes = op.attributes();
+  writer.append(static_cast<ByteCodeField>(attributes.size()));
+  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
+    writer.append(
+        Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
+        std::get<1>(it));
+  }
+  writer.append(op.types());
+}
+void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
+  // Simply repoint the memory index of the result to the constant.
+  getMemIndex(op.result()) = getMemIndex(op.value());
+}
+void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::EraseOp, op.operation());
+}
+void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::Finalize);
+}
+void Generator::generate(pdl_interp::GetAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
+                Identifier::get(op.name(), ctx));
+}
+void Generator::generate(pdl_interp::GetAttributeTypeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetAttributeType, op.result(), op.value());
+}
+void Generator::generate(pdl_interp::GetDefiningOpOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
+  uint32_t index = op.index();
+  if (index < 4)
+    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
+  else
+    writer.append(OpCode::GetOperandN, index);
+  writer.append(op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
+  uint32_t index = op.index();
+  if (index < 4)
+    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
+  else
+    writer.append(OpCode::GetResultN, index);
+  writer.append(op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetValueTypeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetValueType, op.result(), op.value());
+}
+void Generator::generate(pdl_interp::InferredTypeOp op,
+                         ByteCodeWriter &writer) {
+  // InferType maps to a null type as a marker for inferring a result type.
+  getMemIndex(op.type()) = getMemIndex(Type());
+}
+void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
+  ByteCodeField patternIndex = patterns.size();
+  patterns.emplace_back(PDLByteCodePattern::create(
+      op, rewriterToAddr[op.rewriter().getLeafReference()]));
+  writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op),
+                op.matchedOps(), op.inputs());
+}
+void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+}
+void Generator::generate(pdl_interp::SwitchAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchOperandCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchOperationNameOp op,
+                         ByteCodeWriter &writer) {
+  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
+    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
+  });
+  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchResultCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+
+//===----------------------------------------------------------------------===//
+// PDLByteCode
+//===----------------------------------------------------------------------===//
+
+PDLByteCode::PDLByteCode(ModuleOp module,
+                         llvm::StringMap<PDLConstraintFunction> constraintFns,
+                         llvm::StringMap<PDLCreateFunction> createFns,
+                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+  Generator generator(module.getContext(), uniquedData, matcherByteCode,
+                      rewriterByteCode, patterns, maxValueMemoryIndex,
+                      constraintFns, createFns, rewriteFns);
+  generator.generate(module);
+
+  // Initialize the external functions.
+  for (auto &it : constraintFns)
+    constraintFunctions.push_back(std::move(it.second));
+  for (auto &it : createFns)
+    createFunctions.push_back(std::move(it.second));
+  for (auto &it : rewriteFns)
+    rewriteFunctions.push_back(std::move(it.second));
+}
+
+/// Initialize the given state such that it can be used to execute the current
+/// bytecode.
+void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
+  state.memory.resize(maxValueMemoryIndex, nullptr);
+  state.currentPatternBenefits.reserve(patterns.size());
+  for (const PDLByteCodePattern &pattern : patterns)
+    state.currentPatternBenefits.push_back(pattern.getBenefit());
+}
+
+//===----------------------------------------------------------------------===//
+// ByteCode Execution
+
+namespace {
+/// This class provides support for executing a bytecode stream.
+class ByteCodeExecutor {
+public:
+  ByteCodeExecutor(const ByteCodeField *curCodeIt,
+                   MutableArrayRef<const void *> memory,
+                   ArrayRef<const void *> uniquedMemory,
+                   ArrayRef<ByteCodeField> code,
+                   ArrayRef<PatternBenefit> currentPatternBenefits,
+                   ArrayRef<PDLByteCodePattern> patterns,
+                   ArrayRef<PDLConstraintFunction> constraintFunctions,
+                   ArrayRef<PDLCreateFunction> createFunctions,
+                   ArrayRef<PDLRewriteFunction> rewriteFunctions)
+      : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
+        code(code), currentPatternBenefits(currentPatternBenefits),
+        patterns(patterns), constraintFunctions(constraintFunctions),
+        createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
+
+  /// Start executing the code at the current bytecode index. `matches` is an
+  /// optional field provided when this function is executed in a matching
+  /// context.
+  void execute(PatternRewriter &rewriter,
+               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+               Optional<Location> mainRewriteLoc = {});
+
+private:
+  /// Read a value from the bytecode buffer, optionally skipping a certain
+  /// number of prefix values. These methods always update the buffer to point
+  /// to the next field after the read data.
+  template <typename T = ByteCodeField>
+  T read(size_t skipN = 0) {
+    curCodeIt += skipN;
+    return readImpl<T>();
+  }
+  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
+
+  /// Read a list of values from the bytecode buffer.
+  template <typename ValueT, typename T>
+  void readList(SmallVectorImpl<T> &list) {
+    list.clear();
+    for (unsigned i = 0, e = read(); i != e; ++i)
+      list.push_back(read<ValueT>());
+  }
+
+  /// Jump to a specific successor based on a predicate value.
+  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
+  /// Jump to a specific successor based on a destination index.
+  void selectJump(size_t destIndex) {
+    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
+  }
+
+  /// Handle a switch operation with the provided value and cases.
+  template <typename T, typename RangeT>
+  void handleSwitch(const T &value, RangeT &&cases) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "  * Value: " << value << "\n"
+                   << "  * Cases: ";
+      llvm::interleaveComma(cases, llvm::dbgs());
+      llvm::dbgs() << "\n\n";
+    });
+
+    // Check to see if the attribute value is within the case list. Jump to
+    // the correct successor index based on the result.
+    auto it = llvm::find(cases, value);
+    selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1));
+  }
+
+  /// Internal implementation of reading various data types from the bytecode
+  /// stream.
+  template <typename T>
+  const void *readFromMemory() {
+    size_t index = *curCodeIt++;
+
+    // If this type is an SSA value, it can only be stored in non-const memory.
+    if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
+      return memory[index];
+
+    // Otherwise, if this index is not inbounds it is uniqued.
+    return uniquedMemory[index - memory.size()];
+  }
+  template <typename T>
+  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
+    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
+  }
+  template <typename T>
+  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
+                   T>
+  readImpl() {
+    return T(T::getFromOpaquePointer(readFromMemory<T>()));
+  }
+  template <typename T>
+  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
+    switch (static_cast<PDLValueKind>(read())) {
+    case PDLValueKind::Attribute:
+      return read<Attribute>();
+    case PDLValueKind::Operation:
+      return read<Operation *>();
+    case PDLValueKind::Type:
+      return read<Type>();
+    case PDLValueKind::Value:
+      return read<Value>();
+    }
+  }
+  template <typename T>
+  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
+    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
+                  "unexpected ByteCode address size");
+    ByteCodeAddr result;
+    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
+    curCodeIt += 2;
+    return result;
+  }
+  template <typename T>
+  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
+    return *curCodeIt++;
+  }
+
+  /// The underlying bytecode buffer.
+  const ByteCodeField *curCodeIt;
+
+  /// The current execution memory.
+  MutableArrayRef<const void *> memory;
+
+  /// References to ByteCode data necessary for execution.
+  ArrayRef<const void *> uniquedMemory;
+  ArrayRef<ByteCodeField> code;
+  ArrayRef<PatternBenefit> currentPatternBenefits;
+  ArrayRef<PDLByteCodePattern> patterns;
+  ArrayRef<PDLConstraintFunction> constraintFunctions;
+  ArrayRef<PDLCreateFunction> createFunctions;
+  ArrayRef<PDLRewriteFunction> rewriteFunctions;
+};
+} // end anonymous namespace
+
+void ByteCodeExecutor::execute(
+    PatternRewriter &rewriter,
+    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+    Optional<Location> mainRewriteLoc) {
+  while (true) {
+    OpCode opCode = static_cast<OpCode>(read());
+    switch (opCode) {
+    case ApplyConstraint: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
+      const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
+      ArrayAttr constParams = read<ArrayAttr>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+      LLVM_DEBUG({
+        llvm::dbgs() << "  * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
+      });
+
+      // Invoke the constraint and jump to the proper destination.
+      selectJump(succeeded(constraintFn(args, constParams, rewriter)));
+      break;
+    }
+    case ApplyRewrite: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
+      const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
+      ArrayAttr constParams = read<ArrayAttr>();
+      Operation *root = read<Operation *>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "  * Root: " << *root << "\n"
+                     << "  * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
+      });
+      rewriteFn(root, args, constParams, rewriter);
+      break;
+    }
+    case AreEqual: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+      const void *lhs = read<const void *>();
+      const void *rhs = read<const void *>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
+      selectJump(lhs == rhs);
+      break;
+    }
+    case Branch: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
+      curCodeIt = &code[read<ByteCodeAddr>()];
+      break;
+    }
+    case CheckOperandCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
+      Operation *op = read<Operation *>();
+      uint32_t expectedCount = read<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
+                              << "  * Expected: " << expectedCount << "\n\n");
+      selectJump(op->getNumOperands() == expectedCount);
+      break;
+    }
+    case CheckOperationName: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
+      Operation *op = read<Operation *>();
+      OperationName expectedName = read<OperationName>();
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  * Found: \"" << op->getName() << "\"\n"
+                 << "  * Expected: \"" << expectedName << "\"\n\n");
+      selectJump(op->getName() == expectedName);
+      break;
+    }
+    case CheckResultCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
+      Operation *op = read<Operation *>();
+      uint32_t expectedCount = read<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
+                              << "  * Expected: " << expectedCount << "\n\n");
+      selectJump(op->getNumResults() == expectedCount);
+      break;
+    }
+    case CreateNative: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
+      const PDLCreateFunction &createFn = createFunctions[read()];
+      ByteCodeField resultIndex = read();
+      ArrayAttr constParams = read<ArrayAttr>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "  * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
+      });
+
+      PDLValue result = createFn(args, constParams, rewriter);
+      memory[resultIndex] = result.getAsOpaquePointer();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n\n");
+      break;
+    }
+    case CreateOperation: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
+      assert(mainRewriteLoc && "expected rewrite loc to be provided when "
+                               "executing the rewriter bytecode");
+
+      unsigned memIndex = read();
+      OperationState state(*mainRewriteLoc, read<OperationName>());
+      readList<Value>(state.operands);
+      for (unsigned i = 0, e = read(); i != e; ++i) {
+        Identifier name = read<Identifier>();
+        if (Attribute attr = read<Attribute>())
+          state.addAttribute(name, attr);
+      }
+
+      bool hasInferredTypes = false;
+      for (unsigned i = 0, e = read(); i != e; ++i) {
+        Type resultType = read<Type>();
+        hasInferredTypes |= !resultType;
+        state.types.push_back(resultType);
+      }
+
+      // Handle the case where the operation has inferred types.
+      if (hasInferredTypes) {
+        InferTypeOpInterface::Concept *concept =
+            state.name.getAbstractOperation()
+                ->getInterface<InferTypeOpInterface>();
+
+        // TODO: Handle failure.
+        SmallVector<Type, 2> inferredTypes;
+        if (failed(concept->inferReturnTypes(
+                state.getContext(), state.location, state.operands,
+                state.attributes.getDictionary(state.getContext()),
+                state.regions, inferredTypes)))
+          return;
+
+        for (unsigned i = 0, e = state.types.size(); i != e; ++i)
+          if (!state.types[i])
+            state.types[i] = inferredTypes[i];
+      }
+      Operation *resultOp = rewriter.createOperation(state);
+      memory[memIndex] = resultOp;
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "  * Attributes: "
+                     << state.attributes.getDictionary(state.getContext())
+                     << "\n  * Operands: ";
+        llvm::interleaveComma(state.operands, llvm::dbgs());
+        llvm::dbgs() << "\n  * Result Types: ";
+        llvm::interleaveComma(state.types, llvm::dbgs());
+        llvm::dbgs() << "\n  * Result: " << *resultOp << "\n\n";
+      });
+      break;
+    }
+    case EraseOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
+      Operation *op = read<Operation *>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n\n");
+      rewriter.eraseOp(op);
+      break;
+    }
+    case Finalize: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
+      return;
+    }
+    case GetAttribute: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
+      unsigned memIndex = read();
+      Operation *op = read<Operation *>();
+      Identifier attrName = read<Identifier>();
+      Attribute attr = op->getAttr(attrName);
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
+                              << "  * Attribute: " << attrName << "\n"
+                              << "  * Result: " << attr << "\n\n");
+      memory[memIndex] = attr.getAsOpaquePointer();
+      break;
+    }
+    case GetAttributeType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
+      unsigned memIndex = read();
+      Attribute attr = read<Attribute>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
+                              << "  * Result: " << attr.getType() << "\n\n");
+      memory[memIndex] = attr.getType().getAsOpaquePointer();
+      break;
+    }
+    case GetDefiningOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
+      unsigned memIndex = read();
+      Value value = read<Value>();
+      Operation *op = value ? value.getDefiningOp() : nullptr;
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
+                              << "  * Result: " << *op << "\n\n");
+      memory[memIndex] = op;
+      break;
+    }
+    case GetOperand0:
+    case GetOperand1:
+    case GetOperand2:
+    case GetOperand3:
+    case GetOperandN: {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Executing GetOperand"
+                     << (opCode == GetOperandN ? Twine("N")
+                                               : Twine(opCode - GetOperand0))
+                     << ":\n";
+      });
+      unsigned index =
+          opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
+      Operation *op = read<Operation *>();
+      unsigned memIndex = read();
+      Value operand =
+          index < op->getNumOperands() ? op->getOperand(index) : Value();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
+                              << "  * Index: " << index << "\n"
+                              << "  * Result: " << operand << "\n\n");
+      memory[memIndex] = operand.getAsOpaquePointer();
+      break;
+    }
+    case GetResult0:
+    case GetResult1:
+    case GetResult2:
+    case GetResult3:
+    case GetResultN: {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Executing GetResult"
+                     << (opCode == GetResultN ? Twine("N")
+                                              : Twine(opCode - GetResult0))
+                     << ":\n";
+      });
+      unsigned index =
+          opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
+      Operation *op = read<Operation *>();
+      unsigned memIndex = read();
+      OpResult result =
+          index < op->getNumResults() ? op->getResult(index) : OpResult();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
+                              << "  * Index: " << index << "\n"
+                              << "  * Result: " << result << "\n\n");
+      memory[memIndex] = result.getAsOpaquePointer();
+      break;
+    }
+    case GetValueType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
+      unsigned memIndex = read();
+      Value value = read<Value>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
+                              << "  * Result: " << value.getType() << "\n\n");
+      memory[memIndex] = value.getType().getAsOpaquePointer();
+      break;
+    }
+    case IsNotNull: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
+      const void *value = read<const void *>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n\n");
+      selectJump(value != nullptr);
+      break;
+    }
+    case RecordMatch: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+      assert(matches &&
+             "expected matches to be provided when executing the matcher");
+      unsigned patternIndex = read();
+      PatternBenefit benefit = currentPatternBenefits[patternIndex];
+      const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
+
+      // If the benefit of the pattern is impossible, skip the processing of the
+      // rest of the pattern.
+      if (benefit.isImpossibleToMatch()) {
+        LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n\n");
+        curCodeIt = dest;
+        break;
+      }
+
+      // Create a fused location containing the locations of each of the
+      // operations used in the match. This will be used as the location for
+      // created operations during the rewrite that don't already have an
+      // explicit location set.
+      unsigned numMatchLocs = read();
+      SmallVector<Location, 4> matchLocs;
+      matchLocs.reserve(numMatchLocs);
+      for (unsigned i = 0; i != numMatchLocs; ++i)
+        matchLocs.push_back(read<Operation *>()->getLoc());
+      Location matchLoc = rewriter.getFusedLoc(matchLocs);
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
+                              << "  * Location: " << matchLoc << "\n\n");
+      matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
+      readList<const void *>(matches->back().values);
+      curCodeIt = dest;
+      break;
+    }
+    case ReplaceOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
+      Operation *op = read<Operation *>();
+      SmallVector<Value, 16> args;
+      readList<Value>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "  * Operation: " << *op << "\n"
+                     << "  * Values: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n\n";
+      });
+      rewriter.replaceOp(op, args);
+      break;
+    }
+    case SwitchAttribute: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
+      Attribute value = read<Attribute>();
+      ArrayAttr cases = read<ArrayAttr>();
+      handleSwitch(value, cases);
+      break;
+    }
+    case SwitchOperandCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
+      Operation *op = read<Operation *>();
+      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
+      handleSwitch(op->getNumOperands(), cases);
+      break;
+    }
+    case SwitchOperationName: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
+      OperationName value = read<Operation *>()->getName();
+      size_t caseCount = read();
+
+      // The operation names are stored in-line, so to print them out for
+      // debugging purposes we need to read the array before executing the
+      // switch so that we can display all of the possible values.
+      LLVM_DEBUG({
+        const ByteCodeField *prevCodeIt = curCodeIt;
+        llvm::dbgs() << "  * Value: " << value << "\n"
+                     << "  * Cases: ";
+        llvm::interleaveComma(
+            llvm::map_range(llvm::seq<size_t>(0, caseCount),
+                            [&](size_t i) { return read<OperationName>(); }),
+            llvm::dbgs());
+        llvm::dbgs() << "\n\n";
+        curCodeIt = prevCodeIt;
+      });
+
+      // Try to find the switch value within any of the cases.
+      size_t jumpDest = 0;
+      for (size_t i = 0; i != caseCount; ++i) {
+        if (read<OperationName>() == value) {
+          curCodeIt += (caseCount - i - 1);
+          jumpDest = i + 1;
+          break;
+        }
+      }
+      selectJump(jumpDest);
+      break;
+    }
+    case SwitchResultCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
+      Operation *op = read<Operation *>();
+      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
+      handleSwitch(op->getNumResults(), cases);
+      break;
+    }
+    case SwitchType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
+      Type value = read<Type>();
+      auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
+      handleSwitch(value, cases);
+      break;
+    }
+    }
+  }
+}
+
+/// Run the pattern matcher on the given root operation, collecting the matched
+/// patterns in `matches`.
+void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
+                        SmallVectorImpl<MatchResult> &matches,
+                        PDLByteCodeMutableState &state) const {
+  // The first memory slot is always the root operation.
+  state.memory[0] = op;
+
+  // The matcher function always starts at code address 0.
+  ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
+                            matcherByteCode, state.currentPatternBenefits,
+                            patterns, constraintFunctions, createFunctions,
+                            rewriteFunctions);
+  executor.execute(rewriter, &matches);
+
+  // Order the found matches by benefit.
+  std::stable_sort(matches.begin(), matches.end(),
+                   [](const MatchResult &lhs, const MatchResult &rhs) {
+                     return lhs.benefit > rhs.benefit;
+                   });
+}
+
+/// Run the rewriter of the given pattern on the root operation `op`.
+void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
+                          PDLByteCodeMutableState &state) const {
+  // The arguments of the rewrite function are stored at the start of the
+  // memory buffer.
+  llvm::copy(match.values, state.memory.begin());
+
+  ByteCodeExecutor executor(
+      &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
+      uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
+      constraintFunctions, createFunctions, rewriteFunctions);
+  executor.execute(rewriter, /*matches=*/nullptr, match.location);
+}
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
new file mode 100644 (file)
index 0000000..7126037
--- /dev/null
@@ -0,0 +1,173 @@
+//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
+// The byte-code is constructed from the PDL Interpreter dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REWRITE_BYTECODE_H_
+#define MLIR_REWRITE_BYTECODE_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace pdl_interp {
+class RecordMatchOp;
+} // end namespace pdl_interp
+
+namespace detail {
+class PDLByteCode;
+
+/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
+/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of
+/// indices into the bytecode. Correctness is checked with static asserts.
+using ByteCodeField = uint16_t;
+using ByteCodeAddr = uint32_t;
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodePattern
+//===----------------------------------------------------------------------===//
+
+/// All of the data pertaining to a specific pattern within the bytecode.
+class PDLByteCodePattern : public Pattern {
+public:
+  static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
+                                   ByteCodeAddr rewriterAddr);
+
+  /// Return the bytecode address of the rewriter for this pattern.
+  ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
+
+private:
+  template <typename... Args>
+  PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
+      : Pattern(std::forward<Args>(patternArgs)...),
+        rewriterAddr(rewriterAddr) {}
+
+  /// The address of the rewriter for this pattern.
+  ByteCodeAddr rewriterAddr;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodeMutableState
+//===----------------------------------------------------------------------===//
+
+/// This class contains the mutable state of a bytecode instance. This allows
+/// for a bytecode instance to be cached and reused across various different
+/// threads/drivers.
+class PDLByteCodeMutableState {
+public:
+  /// Initialize the state from a bytecode instance.
+  void initialize(PDLByteCode &bytecode);
+
+  /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
+  /// to the position of the pattern within the range returned by
+  /// `PDLByteCode::getPatterns`.
+  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
+
+private:
+  /// Allow access to data fields.
+  friend class PDLByteCode;
+
+  /// The mutable block of memory used during the matching and rewriting phases
+  /// of the bytecode.
+  std::vector<const void *> memory;
+
+  /// The up-to-date benefits of the patterns held by the bytecode. The order
+  /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
+  std::vector<PatternBenefit> currentPatternBenefits;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLByteCode
+//===----------------------------------------------------------------------===//
+
+/// The bytecode class is also the interpreter. Contains the bytecode itself,
+/// the static info, addresses of the rewriter functions, the interpreter
+/// memory buffer, and the execution context.
+class PDLByteCode {
+public:
+  /// Each successful match returns a MatchResult, which contains information
+  /// necessary to execute the rewriter and indicates the originating pattern.
+  struct MatchResult {
+    MatchResult(Location loc, const PDLByteCodePattern &pattern,
+                PatternBenefit benefit)
+        : location(loc), pattern(&pattern), benefit(benefit) {}
+
+    /// The location of operations to be replaced.
+    Location location;
+    /// Memory values defined in the matcher that are passed to the rewriter.
+    SmallVector<const void *, 4> values;
+    /// The originating pattern that was matched. This is always non-null, but
+    /// represented with a pointer to allow for assignment.
+    const PDLByteCodePattern *pattern;
+    /// The current benefit of the pattern that was matched.
+    PatternBenefit benefit;
+  };
+
+  /// Create a ByteCode instance from the given module containing operations in
+  /// the PDL interpreter dialect.
+  PDLByteCode(ModuleOp module,
+              llvm::StringMap<PDLConstraintFunction> constraintFns,
+              llvm::StringMap<PDLCreateFunction> createFns,
+              llvm::StringMap<PDLRewriteFunction> rewriteFns);
+
+  /// Return the patterns held by the bytecode.
+  ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
+
+  /// Initialize the given state such that it can be used to execute the current
+  /// bytecode.
+  void initializeMutableState(PDLByteCodeMutableState &state) const;
+
+  /// Run the pattern matcher on the given root operation, collecting the
+  /// matched patterns in `matches`.
+  void match(Operation *op, PatternRewriter &rewriter,
+             SmallVectorImpl<MatchResult> &matches,
+             PDLByteCodeMutableState &state) const;
+
+  /// Run the rewriter of the given pattern that was previously matched in
+  /// `match`.
+  void rewrite(PatternRewriter &rewriter, const MatchResult &match,
+               PDLByteCodeMutableState &state) const;
+
+private:
+  /// Execute the given byte code starting at the provided instruction `inst`.
+  /// `matches` is an optional field provided when this function is executed in
+  /// a matching context.
+  void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
+                       PDLByteCodeMutableState &state,
+                       SmallVectorImpl<MatchResult> *matches) const;
+
+  /// A vector containing pointers to unqiued data. The storage is intentionally
+  /// opaque such that we can store a wide range of data types. The types of
+  /// data stored here include:
+  ///  * Attribute, Identifier, OperationName, Type
+  std::vector<const void *> uniquedData;
+
+  /// A vector containing the generated bytecode for the matcher.
+  SmallVector<ByteCodeField, 64> matcherByteCode;
+
+  /// A vector containing the generated bytecode for all of the rewriters.
+  SmallVector<ByteCodeField, 64> rewriterByteCode;
+
+  /// The set of patterns contained within the bytecode.
+  SmallVector<PDLByteCodePattern, 32> patterns;
+
+  /// A set of user defined functions invoked via PDL.
+  std::vector<PDLConstraintFunction> constraintFunctions;
+  std::vector<PDLCreateFunction> createFunctions;
+  std::vector<PDLRewriteFunction> rewriteFunctions;
+
+  /// The maximum memory index used by a value.
+  ByteCodeField maxValueMemoryIndex = 0;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_REWRITE_BYTECODE_H_
index e37b9c3..5822789 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRRewrite
+  ByteCode.cpp
   FrozenRewritePatternList.cpp
   PatternApplicator.cpp
 
@@ -10,4 +11,8 @@ add_mlir_library(MLIRRewrite
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRPDL
+  MLIRPDLInterp
+  MLIRPDLToPDLInterp
+  MLIRSideEffectInterfaces
   )
index d0e4518..60f6dce 100644 (file)
@@ -7,13 +7,71 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "ByteCode.h"
+#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 
 using namespace mlir;
 
+static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
+  // Skip the conversion if the module doesn't contain pdl.
+  if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
+    return success();
+
+  // Simplify the provided PDL module. Note that we can't use the canonicalizer
+  // here because it would create a cyclic dependency.
+  auto simplifyFn = [](Operation *op) {
+    // TODO: Add folding here if ever necessary.
+    if (isOpTriviallyDead(op))
+      op->erase();
+  };
+  pdlModule.getBody()->walk(simplifyFn);
+
+  /// Lower the PDL pattern module to the interpreter dialect.
+  PassManager pdlPipeline(pdlModule.getContext());
+#ifdef NDEBUG
+  // We don't want to incur the hit of running the verifier when in release
+  // mode.
+  pdlPipeline.enableVerifier(false);
+#endif
+  pdlPipeline.addPass(createPDLToPDLInterpPass());
+  if (failed(pdlPipeline.run(pdlModule)))
+    return failure();
+
+  // Simplify again after running the lowering pipeline.
+  pdlModule.getBody()->walk(simplifyFn);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FrozenRewritePatternList
 //===----------------------------------------------------------------------===//
 
 FrozenRewritePatternList::FrozenRewritePatternList(
     OwningRewritePatternList &&patterns)
-    : patterns(patterns.takePatterns()) {}
+    : nativePatterns(std::move(patterns.getNativePatterns())) {
+  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
+
+  // Generate the bytecode for the PDL patterns if any were provided.
+  ModuleOp pdlModule = pdlPatterns.getModule();
+  if (!pdlModule)
+    return;
+  if (failed(convertPDLToPDLInterp(pdlModule)))
+    llvm::report_fatal_error(
+        "failed to lower PDL pattern module to the PDL Interpreter");
+
+  // Generate the pdl bytecode.
+  pdlByteCode = std::make_unique<detail::PDLByteCode>(
+      pdlModule, pdlPatterns.takeConstraintFunctions(),
+      pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
+}
+
+FrozenRewritePatternList::FrozenRewritePatternList(
+    FrozenRewritePatternList &&patterns)
+    : nativePatterns(std::move(patterns.nativePatterns)),
+      pdlByteCode(std::move(patterns.pdlByteCode)) {}
+
+FrozenRewritePatternList::~FrozenRewritePatternList() {}
index 5d6ae51..6f5e1f2 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "ByteCode.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
+using namespace mlir::detail;
+
+PatternApplicator::PatternApplicator(
+    const FrozenRewritePatternList &frozenPatternList)
+    : frozenPatternList(frozenPatternList) {
+  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
+    mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
+    bytecode->initializeMutableState(*mutableByteCodeState);
+  }
+}
+PatternApplicator::~PatternApplicator() {}
 
 #define DEBUG_TYPE "pattern-match"
 
 void PatternApplicator::applyCostModel(CostModel model) {
+  // Apply the cost model to the bytecode patterns first, and then the native
+  // patterns.
+  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
+    for (auto it : llvm::enumerate(bytecode->getPatterns()))
+      mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
+  }
+
   // Separate patterns by root kind to simplify lookup later on.
   patterns.clear();
   anyOpPatterns.clear();
-  for (const auto &pat : frozenPatternList.getPatterns()) {
+  for (const auto &pat : frozenPatternList.getNativePatterns()) {
     // If the pattern is always impossible to match, just ignore it.
     if (pat.getBenefit().isImpossibleToMatch()) {
       LLVM_DEBUG({
@@ -81,8 +100,12 @@ void PatternApplicator::applyCostModel(CostModel model) {
 
 void PatternApplicator::walkAllPatterns(
     function_ref<void(const Pattern &)> walk) {
-  for (auto &it : frozenPatternList.getPatterns())
+  for (const Pattern &it : frozenPatternList.getNativePatterns())
     walk(it);
+  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
+    for (const Pattern &it : bytecode->getPatterns())
+      walk(it);
+  }
 }
 
 LogicalResult PatternApplicator::matchAndRewrite(
@@ -90,6 +113,14 @@ LogicalResult PatternApplicator::matchAndRewrite(
     function_ref<bool(const Pattern &)> canApply,
     function_ref<void(const Pattern &)> onFailure,
     function_ref<LogicalResult(const Pattern &)> onSuccess) {
+  // Before checking native patterns, first match against the bytecode. This
+  // won't automatically perform any rewrites so there is no need to worry about
+  // conflicts.
+  SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
+  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
+  if (bytecode)
+    bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
+
   // Check to see if there are patterns matching this specific operation type.
   MutableArrayRef<const RewritePattern *> opPatterns;
   auto patternIt = patterns.find(op->getName());
@@ -98,51 +129,50 @@ LogicalResult PatternApplicator::matchAndRewrite(
 
   // Process the patterns for that match the specific operation type, and any
   // operation type in an interleaved fashion.
-  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
-  // and pass in a comparison function. That would make this code trivial.
   auto opIt = opPatterns.begin(), opE = opPatterns.end();
   auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
-  while (opIt != opE && anyIt != anyE) {
-    // Try to match the pattern providing the most benefit.
-    const RewritePattern *pattern;
-    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
-      pattern = *(opIt++);
-    else
-      pattern = *(anyIt++);
+  auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
+  while (true) {
+    // Find the next pattern with the highest benefit.
+    const Pattern *bestPattern = nullptr;
+    const PDLByteCode::MatchResult *pdlMatch = nullptr;
+    /// Operation specific patterns.
+    if (opIt != opE)
+      bestPattern = *(opIt++);
+    /// Operation agnostic patterns.
+    if (anyIt != anyE &&
+        (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
+      bestPattern = *(anyIt++);
+    /// PDL patterns.
+    if (pdlIt != pdlE &&
+        (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
+      pdlMatch = pdlIt;
+      bestPattern = (pdlIt++)->pattern;
+    }
+    if (!bestPattern)
+      break;
 
-    // Otherwise, try to match the generic pattern.
-    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
-                                  onSuccess)))
-      return success();
-  }
-  // If we break from the loop, then only one of the ranges can still have
-  // elements. Loop over both without checking given that we don't need to
-  // interleave anymore.
-  for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>(
-           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
-    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
-                                  onSuccess)))
+    // Check that the pattern can be applied.
+    if (canApply && !canApply(*bestPattern))
+      continue;
+
+    // Try to match and rewrite this pattern. The patterns are sorted by
+    // benefit, so if we match we can immediately rewrite. For PDL patterns, the
+    // match has already been performed, we just need to rewrite.
+    rewriter.setInsertionPoint(op);
+    LogicalResult result = success();
+    if (pdlMatch) {
+      bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+    } else {
+      result = static_cast<const RewritePattern *>(bestPattern)
+                   ->matchAndRewrite(op, rewriter);
+    }
+    if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
       return success();
-  }
-  return failure();
-}
 
-LogicalResult PatternApplicator::matchAndRewrite(
-    Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
-    function_ref<bool(const Pattern &)> canApply,
-    function_ref<void(const Pattern &)> onFailure,
-    function_ref<LogicalResult(const Pattern &)> onSuccess) {
-  // Check that the pattern can be applied.
-  if (canApply && !canApply(pattern))
-    return failure();
-
-  // Try to match and rewrite this pattern. The patterns are sorted by
-  // benefit, so if we match we can immediately rewrite.
-  rewriter.setInsertionPoint(op);
-  if (succeeded(pattern.matchAndRewrite(op, rewriter)))
-    return success(!onSuccess || succeeded(onSuccess(pattern)));
-
-  if (onFailure)
-    onFailure(pattern);
+    // Perform any necessary cleanups.
+    if (onFailure)
+      onFailure(*bestPattern);
+  }
   return failure();
 }
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
new file mode 100644 (file)
index 0000000..b2a22d0
--- /dev/null
@@ -0,0 +1,785 @@
+// RUN: mlir-opt %s -test-pdl-bytecode-pass -split-input-file | FileCheck %s
+
+// Note: Tests here are written using the PDL Interpreter dialect to avoid
+// unnecessarily testing unnecessary aspects of the pattern compilation
+// pipeline. These tests are written such that we can focus solely on the
+// lowering/execution of the bytecode itself.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyConstraintOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.apply_constraint "multi_entity_constraint"(%root, %root : !pdl.operation, !pdl.operation) -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_1
+// CHECK: "test.replaced_by_pattern"
+module @ir attributes { test.apply_constraint_1 } {
+  "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyRewriteOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %operand = pdl_interp.get_operand 0 of %root
+      pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_rewrite_1
+// CHECK: %[[INPUT:.*]] = "test.op_input"
+// CHECK-NOT: "test.op"
+// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]}
+module @ir attributes { test.apply_rewrite_1 } {
+  %input = "test.op_input"() : () -> i32
+  "test.op"(%input) : (i32) -> ()
+}
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::AreEqualOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %test_attr = pdl_interp.create_attribute unit
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.are_equal_1
+// CHECK: "test.success"
+module @ir attributes { test.are_equal_1 } {
+  "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::BranchOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+  ^pat1:
+    pdl_interp.branch ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.branch_1
+// CHECK: "test.success"
+module @ir attributes { test.branch_1 } {
+  "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckAttributeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.check_attribute %attr is unit -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_attribute_1
+// CHECK: "test.success"
+module @ir attributes { test.check_attribute_1 } {
+  "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperandCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_operand_count_1
+// CHECK: "test.op"() : () -> i32
+// CHECK: "test.success"
+module @ir attributes { test.check_operand_count_1 } {
+  %operand = "test.op"() : () -> i32
+  "test.op"(%operand) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperationNameOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_operation_name_1
+// CHECK: "test.success"
+module @ir attributes { test.check_operation_name_1 } {
+  "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckResultCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_result_count of %root is 1 -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_result_count_1
+// CHECK: "test.success"() : () -> ()
+module @ir attributes { test.check_result_count_1 } {
+  "test.op"() : () -> i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+  ^pat1:
+    %type = pdl_interp.get_attribute_type of %attr
+    pdl_interp.check_type %type is i32 -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_type_1
+// CHECK: "test.success"
+module @ir attributes { test.check_type_1 } {
+  "test.op"() { test_attr = 10 : i32 } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateAttributeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateNativeOp
+//===----------------------------------------------------------------------===//
+
+// -----
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.create_native_1
+// CHECK: "test.success"
+module @ir attributes { test.create_native_1 } {
+  "test.op"() : () -> ()
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+  ^pat1:
+    %test_type = pdl_interp.create_type i32
+    %type = pdl_interp.get_attribute_type of %attr
+    pdl_interp.are_equal %type, %test_type : !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.create_type_1
+// CHECK: "test.success"
+module @ir attributes { test.create_type_1 } {
+  "test.op"() { test_attr = 0 : i32 } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::EraseOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::FinalizeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetDefiningOpOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operand_count of %root is 5 -> ^pat1, ^end
+
+  ^pat1:
+    %operand0 = pdl_interp.get_operand 0 of %root
+    %operand4 = pdl_interp.get_operand 4 of %root
+    %defOp0 = pdl_interp.get_defining_op of %operand0
+    %defOp4 = pdl_interp.get_defining_op of %operand4
+    pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_defining_op_1
+// CHECK: %[[OPERAND0:.*]] = "test.op"
+// CHECK: %[[OPERAND1:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]])
+module @ir attributes { test.get_defining_op_1 } {
+  %operand = "test.op"() : () -> i32
+  %other_operand = "test.op"() : () -> i32
+  "test.op"(%operand, %operand, %operand, %operand, %operand) : (i32, i32, i32, i32, i32) -> ()
+  "test.op"(%operand, %operand, %operand, %operand, %other_operand) : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end
+
+  ^pat1:
+    %result0 = pdl_interp.get_result 0 of %root
+    %result4 = pdl_interp.get_result 4 of %root
+    %result0_type = pdl_interp.get_value_type of %result0
+    %result4_type = pdl_interp.get_value_type of %result4
+    pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_result_1
+// CHECK: "test.success"
+// CHECK: "test.op"() : () -> (i32, i32, i32, i32, i64)
+module @ir attributes { test.get_result_1 } {
+  %a:5 = "test.op"() : () -> (i32, i32, i32, i32, i32)
+  %b:5 = "test.op"() : () -> (i32, i32, i32, i32, i64)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetValueTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::InferredTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::IsNotNullOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::RecordMatchOp
+//===----------------------------------------------------------------------===//
+
+// Check that the highest benefit pattern is selected.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+  ^pat1:
+    pdl_interp.record_match @rewriters::@failure(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @failure(%root : !pdl.operation) {
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.record_match_1
+// CHECK: "test.success"
+module @ir attributes { test.record_match_1 } {
+  "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ReplaceOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %operand = pdl_interp.get_operand 0 of %root
+      pdl_interp.replace %root with (%operand)
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.replace_op_1
+// CHECK: %[[INPUT:.*]] = "test.op_input"
+// CHECK-NOT: "test.op"
+// CHECK: "test.op_consumer"(%[[INPUT]])
+module @ir attributes { test.replace_op_1 } {
+  %input = "test.op_input"() : () -> i32
+  %result = "test.op"(%input) : (i32) -> i32
+  "test.op_consumer"(%result) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchAttributeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.switch_attribute %attr to [0, unit](^end, ^pat) -> ^end
+
+  ^pat:
+    %attr_2 = pdl_interp.get_attribute "test_attr_2" of %root
+    pdl_interp.switch_attribute %attr_2 to [0, unit](^end, ^end) -> ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_attribute_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_attribute_1 } {
+  "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperandCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.switch_operand_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end
+
+  ^pat:
+    pdl_interp.switch_operand_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_operand_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_operand_1 } {
+  %input = "test.op_input"() : () -> i32
+  "test.op"(%input) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.switch_operation_name of %root to ["foo.op", "test.op"](^end, ^pat1) -> ^end
+
+  ^pat1:
+    pdl_interp.switch_operation_name of %root to ["foo.op", "bar.op"](^end, ^end) -> ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_operation_name_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_operation_name_1 } {
+  "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchResultCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.switch_result_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end
+
+  ^pat:
+    pdl_interp.switch_result_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_result_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_result_1 } {
+  "test.op"() : () -> i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %attr = pdl_interp.get_attribute "test_attr" of %root
+    pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+  ^pat1:
+    %type = pdl_interp.get_attribute_type of %attr
+    pdl_interp.switch_type %type to [i32, i64](^pat2, ^end) -> ^end
+
+  ^pat2:
+    pdl_interp.switch_type %type to [i16, i64](^end, ^end) -> ^pat3
+
+  ^pat3:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"() -> ()
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_type_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_type_1 } {
+  "test.op"() { test_attr = 10 : i32 } : () -> ()
+}
index 0df357c..9b15686 100644 (file)
@@ -2,4 +2,5 @@ add_subdirectory(Dialect)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
+add_subdirectory(Rewrite)
 add_subdirectory(Transforms)
diff --git a/mlir/test/lib/Rewrite/CMakeLists.txt b/mlir/test/lib/Rewrite/CMakeLists.txt
new file mode 100644 (file)
index 0000000..fd5d5d5
--- /dev/null
@@ -0,0 +1,16 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestRewrite
+  TestPDLByteCode.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRSupport
+  MLIRTransformUtils
+  )
+
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
new file mode 100644 (file)
index 0000000..3b23cb1
--- /dev/null
@@ -0,0 +1,85 @@
+//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Custom constraint invoked from PDL.
+static LogicalResult customSingleEntityConstraint(PDLValue value,
+                                                  ArrayAttr constantParams,
+                                                  PatternRewriter &rewriter) {
+  Operation *rootOp = value.cast<Operation *>();
+  return success(rootOp->getName().getStringRef() == "test.op");
+}
+static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
+                                                 ArrayAttr constantParams,
+                                                 PatternRewriter &rewriter) {
+  return customSingleEntityConstraint(values[1], constantParams, rewriter);
+}
+
+// Custom creator invoked from PDL.
+static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                             PatternRewriter &rewriter) {
+  return rewriter.createOperation(
+      OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
+}
+
+/// Custom rewriter invoked from PDL.
+static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
+                           ArrayAttr constantParams,
+                           PatternRewriter &rewriter) {
+  OperationState successOpState(root->getLoc(), "test.success");
+  successOpState.addOperands(args[0].cast<Value>());
+  successOpState.addAttribute("constantParams", constantParams);
+  rewriter.createOperation(successOpState);
+  rewriter.eraseOp(root);
+}
+
+namespace {
+struct TestPDLByteCodePass
+    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
+  void runOnOperation() final {
+    ModuleOp module = getOperation();
+
+    // The test cases are encompassed via two modules, one containing the
+    // patterns and one containing the operations to rewrite.
+    ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
+    ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
+    if (!patternModule || !irModule)
+      return;
+
+    // Process the pattern module.
+    patternModule.getOperation()->remove();
+    PDLPatternModule pdlPattern(patternModule);
+    pdlPattern.registerConstraintFunction("multi_entity_constraint",
+                                          customMultiEntityConstraint);
+    pdlPattern.registerConstraintFunction("single_entity_constraint",
+                                          customSingleEntityConstraint);
+    pdlPattern.registerCreateFunction("creator", customCreate);
+    pdlPattern.registerRewriteFunction("rewriter", customRewriter);
+
+    OwningRewritePatternList patternList(std::move(pdlPattern));
+
+    // Invoke the pattern driver with the provided patterns.
+    (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
+                                       std::move(patternList));
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestPDLByteCodePass() {
+  PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
+                                        "Test PDL ByteCode functionality");
+}
+} // namespace test
+} // namespace mlir
index 8857bbe..52e96dc 100644 (file)
@@ -220,18 +220,21 @@ static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
   MLIRContext *ctx = funcOp.getContext();
-  patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
+  patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
       ctx,
       LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
       LinalgMarker(Identifier::get(startMarker, ctx),
                    Identifier::get("L1", ctx))));
 
-  patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
-      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
-      LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
+  patternsVector.emplace_back(
+      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
+          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+          LinalgMarker(Identifier::get("L1", ctx),
+                       Identifier::get("VEC", ctx))));
 
-  patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
-      ctx, LinalgMarker(Identifier::get("VEC", ctx))));
+  patternsVector.emplace_back(
+      std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
+          ctx, LinalgMarker(Identifier::get("VEC", ctx))));
   patternsVector.back()
       .insert<LinalgVectorizationPattern<FillOp>,
               LinalgVectorizationPattern<CopyOp>>(ctx);
@@ -437,7 +440,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                           stage1Patterns);
   } else if (testMatmulToVectorPatterns2dTiling) {
-    stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
+    stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
         ctx,
         LinalgTilingOptions()
             .setTileSizes({768, 264, 768})
index e8b0842..8bee2f5 100644 (file)
@@ -19,6 +19,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestIR
     MLIRTestPass
     MLIRTestReducer
+    MLIRTestRewrite
     MLIRTestTransforms
     )
 endif()
index 4095cc2..67aa855 100644 (file)
@@ -86,6 +86,7 @@ void registerTestMemRefStrideCalculation();
 void registerTestNumberOfBlockExecutionsPass();
 void registerTestNumberOfOperationExecutionsPass();
 void registerTestOpaqueLoc();
+void registerTestPDLByteCodePass();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
@@ -155,6 +156,7 @@ void registerTestPasses() {
   test::registerTestNumberOfBlockExecutionsPass();
   test::registerTestNumberOfOperationExecutionsPass();
   test::registerTestOpaqueLoc();
+  test::registerTestPDLByteCodePass();
   test::registerTestRecursiveTypesPass();
   test::registerTestSCFUtilsPass();
   test::registerTestSparsification();