Create Reduction Tree Pass
authorMauricio Sifontes <sifontes@google.com>
Fri, 7 Aug 2020 23:17:27 +0000 (23:17 +0000)
committerMauricio Sifontes <sifontes@google.com>
Fri, 7 Aug 2020 23:17:31 +0000 (23:17 +0000)
Implement the Reduction Tree Pass framework as part of the MLIR Reduce tool. This is a parametarizable pass that allows for the implementation of custom reductions passes in the tool.
Implement the FunctionReducer class as an example of a Reducer class parameter for the instantiation of a Reduction Tree Pass.
Create a pass pipeline with a Reduction Tree Pass with the FunctionReducer class specified as parameter.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D83969

16 files changed:
mlir/include/mlir/CMakeLists.txt
mlir/include/mlir/Reducer/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Reducer/PassDetail.h [new file with mode: 0644]
mlir/include/mlir/Reducer/Passes.td [new file with mode: 0644]
mlir/include/mlir/Reducer/Passes/FunctionReducer.h [new file with mode: 0644]
mlir/include/mlir/Reducer/ReductionNode.h [new file with mode: 0644]
mlir/include/mlir/Reducer/ReductionTreePass.h [new file with mode: 0644]
mlir/include/mlir/Reducer/Tester.h
mlir/lib/Reducer/Tester.cpp
mlir/test/mlir-reduce/failure-test.sh [new file with mode: 0755]
mlir/test/mlir-reduce/reduction-tree-pass.mlir [new file with mode: 0644]
mlir/tools/mlir-reduce/CMakeLists.txt
mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp [new file with mode: 0644]
mlir/tools/mlir-reduce/ReductionNode.cpp [new file with mode: 0644]
mlir/tools/mlir-reduce/ReductionTreePass.cpp [new file with mode: 0644]
mlir/tools/mlir-reduce/mlir-reduce.cpp

index d16d148..1e31d7c 100644 (file)
@@ -2,4 +2,5 @@ add_subdirectory(Conversion)
 add_subdirectory(Dialect)
 add_subdirectory(IR)
 add_subdirectory(Interfaces)
+add_subdirectory(Reducer)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Reducer/CMakeLists.txt b/mlir/include/mlir/Reducer/CMakeLists.txt
new file mode 100644 (file)
index 0000000..5cfaa09
--- /dev/null
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls)
+add_public_tablegen_target(MLIRReducerIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc ReducerPasses ./)
diff --git a/mlir/include/mlir/Reducer/PassDetail.h b/mlir/include/mlir/Reducer/PassDetail.h
new file mode 100644 (file)
index 0000000..3b6fa57
--- /dev/null
@@ -0,0 +1,21 @@
+//===- PassDetail.h - Reducer Pass class details ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_PASSDETAIL_H
+#define MLIR_REDUCER_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Reducer/Passes.h.inc"
+
+} // end namespace mlir
+
+#endif // MLIR_REDUCER_PASSDETAIL_H
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
new file mode 100644 (file)
index 0000000..4703dd7
--- /dev/null
@@ -0,0 +1,23 @@
+//===-- Passes.td - MLIR Reduce pass definition file -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the passes for the MLIR Reduce Tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_PASSES
+#define MLIR_REDUCER_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ReductionTree : Pass<"reduction-tree", "ModuleOp"> {
+  let summary = "A general reduction tree pass for the MLIR Reduce Tool";
+  let constructor = "mlir::createReductionTreePass()";
+}
+
+#endif // MLIR_REDUCE_PASSES
diff --git a/mlir/include/mlir/Reducer/Passes/FunctionReducer.h b/mlir/include/mlir/Reducer/Passes/FunctionReducer.h
new file mode 100644 (file)
index 0000000..f4b094b
--- /dev/null
@@ -0,0 +1,36 @@
+//===- FunctionReducer.h - MLIR Reduce Function Reducer ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the FunctionReducer class. It defines a variant generator
+// method with the purpose of producing different variants by eliminating
+// functions from the  parent module.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_PASSES_FUNCTIONREDUCER_H
+#define MLIR_REDUCER_PASSES_FUNCTIONREDUCER_H
+
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Reducer/Tester.h"
+
+namespace mlir {
+
+/// The FunctionReducer class defines a variant generator method that produces
+/// multiple variants by eliminating different operations from the
+/// parent module.
+class FunctionReducer {
+public:
+  /// Generate variants by removing functions from the module in the parent
+  /// Reduction Node and link the variants as children in the Reduction Tree
+  /// Pass.
+  void generateVariants(ReductionNode *parent, const Tester *test);
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
new file mode 100644 (file)
index 0000000..d07ddc5
--- /dev/null
@@ -0,0 +1,88 @@
+//===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the reduction nodes which are used to track of the metadata
+// for a specific generated variant within a reduction pass and are the building
+// blocks of the reduction tree structure. A reduction tree is used to keep
+// track of the different generated variants throughout a reduction pass in the
+// MLIR Reduce tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_REDUCTIONNODE_H
+#define MLIR_REDUCER_REDUCTIONNODE_H
+
+#include <vector>
+
+#include "mlir/Reducer/Tester.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+namespace mlir {
+
+/// This class defines the ReductionNode which is used to wrap the module of
+/// a generated variant and keep track of the necessary metadata for the
+/// reduction pass. The nodes are linked together in a reduction tree stucture
+/// which defines the relationship between all the different generated variants.
+class ReductionNode {
+public:
+  ReductionNode(ModuleOp module, ReductionNode *parent);
+
+  /// Calculates and initializes the size and interesting values of the node.
+  void measureAndTest(const Tester *test);
+
+  /// Returns the module.
+  ModuleOp getModule() const { return module; }
+
+  /// Returns true if the size and interestingness have been calculated.
+  bool isEvaluated() const;
+
+  /// Returns the size in bytes of the module.
+  int getSize() const;
+
+  /// Returns true if the module exhibits the interesting behavior.
+  bool isInteresting() const;
+
+  /// Returns the pointer to a child variant by index.
+  ReductionNode *getVariant(unsigned long index) const;
+
+  /// Returns true if the vector containing the child variants is empty.
+  bool variantsEmpty() const;
+
+  /// Sort the child variants and remove the uninteresting ones.
+  void organizeVariants(const Tester *test);
+
+private:
+  /// Link a child variant node.
+  void linkVariant(ReductionNode *newVariant);
+
+  // This is the MLIR module of this variant.
+  ModuleOp module;
+
+  // This is true if the module has been evaluated and it exhibits the
+  // interesting behavior.
+  bool interesting;
+
+  // This indicates the number of characters in the printed module if the module
+  // has been evaluated.
+  int size;
+
+  // This indicates if the module has been evalueated (measured and tested).
+  bool evaluated;
+
+  // This points to the ReductionNode that was used as a starting point to
+  // create this variant. It is null if the reduction node is the root.
+  ReductionNode *parent;
+
+  // This points to the child variants that were created using this node as a
+  // starting point.
+  std::vector<std::unique_ptr<ReductionNode>> variants;
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h
new file mode 100644 (file)
index 0000000..723ed6a
--- /dev/null
@@ -0,0 +1,106 @@
+//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass class. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
+#define MLIR_REDUCER_REDUCTIONTREEPASS_H
+
+#include <vector>
+
+#include "PassDetail.h"
+#include "ReductionNode.h"
+#include "mlir/Reducer/Passes/FunctionReducer.h"
+#include "mlir/Reducer/Tester.h"
+
+namespace mlir {
+
+/// Defines the traversal method options to be used in the reduction tree
+/// traversal.
+enum TraversalMode { SinglePath, MultiPath, Concurrent, Backtrack };
+
+// This class defines the non- templated utilities used by the ReductionTreePass
+// class.
+class ReductionTreeUtils {
+public:
+  void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
+};
+
+/// This class defines the Reduction Tree Pass. It provides a framework to
+/// to implement a reduction pass using a tree structure to keep track of the
+/// generated reduced variants.
+template <typename Reducer, TraversalMode mode>
+class ReductionTreePass
+    : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
+public:
+  ReductionTreePass(const Tester *test) : test(test) {}
+
+  ReductionTreePass(const ReductionTreePass &pass)
+      : root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
+        test(pass.test) {}
+
+  /// Runs the pass instance in the pass pipeline.
+  void runOnOperation() override {
+    ModuleOp module = this->getOperation();
+    this->root = std::make_unique<ReductionNode>(module, nullptr);
+    ReductionNode *reduced;
+
+    switch (mode) {
+    case SinglePath:
+      reduced = singlePathTraversal();
+      break;
+    default:
+      llvm::report_fatal_error("Traversal method not currently supported.");
+    }
+
+    ReductionTreeUtils utils;
+    utils.updateGoldenModule(module, reduced->getModule());
+  }
+
+private:
+  // Points to the root node in this reduction tree.
+  std::unique_ptr<ReductionNode> root;
+
+  // This object defines the variant generation at each level of the reduction
+  // tree.
+  Reducer reducer;
+
+  // This is used to test the interesting behavior of the reduction nodes in the
+  // tree.
+  const Tester *test;
+
+  /// Traverse the most reduced path in the reduction tree by generating the
+  /// variants at each level using the Reducer parameter's generateVariants
+  /// function. Stops when no new successful variants can be created at the
+  /// current level.
+  ReductionNode *singlePathTraversal() {
+    ReductionNode *currLevel = root.get();
+
+    while (true) {
+      reducer.generateVariants(currLevel, test);
+      currLevel->organizeVariants(test);
+
+      if (currLevel->variantsEmpty())
+        break;
+
+      currLevel = currLevel->getVariant(0);
+    }
+
+    return currLevel;
+  }
+};
+
+} // end namespace mlir
+
+#endif
index 8ca2a4a..004ca30 100644 (file)
@@ -9,8 +9,8 @@
 // This file defines the Tester class used in the MLIR Reduce tool.
 //
 // A Tester object is passed as an argument to the reduction passes and it is
-// used to keep track of the state of the reduction throughout the multiple
-// passes.
+// used to run the interestigness testing script on the different generated
+// reduced variants of the test case.
 //
 //===----------------------------------------------------------------------===//
 
@@ -27,9 +27,9 @@
 
 namespace mlir {
 
-/// This class is used to keep track of the state of the reduction. It contains
-/// a method to run the interestingness testing script on MLIR test case files
-/// and provides functionality to track the most reduced test case.
+/// This class is used to keep track of the testing environment of the tool. It
+/// contains a method to run the interestingness testing script on a MLIR test
+/// case file.
 class Tester {
 public:
   Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
@@ -37,23 +37,13 @@ public:
   /// Runs the interestingness testing script on a MLIR test case file. Returns
   /// true if the interesting behavior is present in the test case or false
   /// otherwise.
-  bool isInteresting(StringRef testCase);
-
-  /// Returns the most reduced MLIR test case module.
-  ModuleOp getMostReduced() const { return mostReduced; }
-
-  /// Updates the most reduced MLIR test case module. If a
-  /// generated variant is found to be successful and shorter than the
-  /// mostReduced module, the mostReduced module must be updated with the new
-  /// variant.
-  void setMostReduced(ModuleOp t) { mostReduced = t; }
+  bool isInteresting(StringRef testCase) const;
 
 private:
   StringRef testScript;
   ArrayRef<std::string> testScriptArgs;
-  ModuleOp mostReduced;
 };
 
 } // end namespace mlir
 
-#endif
\ No newline at end of file
+#endif
index dcfce32..065c777 100644 (file)
@@ -9,8 +9,8 @@
 // This file defines the Tester class used in the MLIR Reduce tool.
 //
 // A Tester object is passed as an argument to the reduction passes and it is
-// used to keep track of the state of the reduction throughout the multiple
-// passes.
+// used to run the interestigness testing script on the different generated
+// reduced variants of the test case.
 //
 //===----------------------------------------------------------------------===//
 
@@ -24,7 +24,7 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
 /// Runs the interestingness testing script on a MLIR test case file. Returns
 /// true if the interesting behavior is present in the test case or false
 /// otherwise.
-bool Tester::isInteresting(StringRef testCase) {
+bool Tester::isInteresting(StringRef testCase) const {
 
   std::vector<StringRef> testerArgs;
   testerArgs.push_back(testCase);
@@ -32,6 +32,8 @@ bool Tester::isInteresting(StringRef testCase) {
   for (const std::string &arg : testScriptArgs)
     testerArgs.push_back(arg);
 
+  testerArgs.push_back(testCase);
+
   std::string errMsg;
   int result = llvm::sys::ExecuteAndWait(
       testScript, testerArgs, /*Env=*/None, /*Redirects=*/None,
diff --git a/mlir/test/mlir-reduce/failure-test.sh b/mlir/test/mlir-reduce/failure-test.sh
new file mode 100755 (executable)
index 0000000..6a07743
--- /dev/null
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Tests for the keyword "failure" in the stderr of the optimization pass
+mlir-opt $1 -test-mlir-reducer > /tmp/stdout.$$ 2>/tmp/stderr.$$
+
+if [ $? -ne 0 ] && grep 'failure' /tmp/stderr.$$; then
+  exit 1
+  #Interesting behavior
+else 
+  exit 0
+fi
diff --git a/mlir/test/mlir-reduce/reduction-tree-pass.mlir b/mlir/test/mlir-reduce/reduction-tree-pass.mlir
new file mode 100644 (file)
index 0000000..dc04a62
--- /dev/null
@@ -0,0 +1,39 @@
+// UNSUPPORTED: -windows-
+// RUN: mlir-reduce %s -test %S/failure-test.sh | FileCheck %s
+// This input should be reduced by the pass pipeline so that only 
+// the @simple5 function remains as this is the shortest function 
+// containing the interesting behavior.
+
+// CHECK-NOT: func @simple1() {
+func @simple1() {
+  return
+}
+
+// CHECK-NOT: func @simple2() {
+func @simple2() {
+  return
+}
+
+// CHECK-LABEL: func @simple3() {
+func @simple3() {
+  "test.crashOp" () : () -> ()
+  return
+}
+
+// CHECK-NOT: func @simple4() {
+func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+  cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+  %0 = alloc() : memref<2xf32>
+  br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+  "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// CHECK-NOT: func @simple5() {
+func @simple5() {
+  return
+}
index 642c5d9..b3a7c36 100644 (file)
@@ -32,10 +32,19 @@ set(LIBS
   )
 
 add_llvm_tool(mlir-reduce
+  Passes/FunctionReducer.cpp
+  ReductionNode.cpp
+  ReductionTreePass.cpp
   mlir-reduce.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Reducer
+
+  DEPENDS
+  MLIRReducerIncGen
   )
 
 target_link_libraries(mlir-reduce PRIVATE ${LIBS})
 llvm_update_compile_flags(mlir-reduce)
 
-mlir_check_all_link_libraries(mlir-reduce)
\ No newline at end of file
+mlir_check_all_link_libraries(mlir-reduce)
diff --git a/mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp b/mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp
new file mode 100644 (file)
index 0000000..ac97848
--- /dev/null
@@ -0,0 +1,72 @@
+//===- FunctionReducer.cpp - MLIR Reduce Function Reducer -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the FunctionReducer class. It defines a variant generator
+// class to be used in a Reduction Tree Pass instantiation with the aim of
+// reducing the number of function operations in an MLIR Module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/Passes/FunctionReducer.h"
+#include "mlir/IR/Function.h"
+
+using namespace mlir;
+
+/// Return the number of function operations in the module's body.
+int countFunctions(ModuleOp module) {
+  auto ops = module.getOps<FuncOp>();
+  return std::distance(ops.begin(), ops.end());
+}
+
+/// Generate variants by removing function operations from the module in the
+/// parent and link the variants as children in the Reduction Tree Pass.
+void FunctionReducer::generateVariants(ReductionNode *parent,
+                                       const Tester *test) {
+  ModuleOp module = parent->getModule();
+  int opCount = countFunctions(module);
+  int sectionSize = opCount / 2;
+  std::vector<Operation *> opsToRemove;
+
+  if (opCount == 0)
+    return;
+
+  // Create a variant by deleting all ops.
+  if (opCount == 1) {
+    opsToRemove.clear();
+    ModuleOp moduleVariant = module.clone();
+
+    for (FuncOp op : moduleVariant.getOps<FuncOp>())
+      opsToRemove.push_back(op);
+
+    for (Operation *o : opsToRemove)
+      o->erase();
+
+    new ReductionNode(moduleVariant, parent);
+
+    return;
+  }
+
+  // Create two variants by bisecting the module.
+  for (int i = 0; i < 2; ++i) {
+    opsToRemove.clear();
+    ModuleOp moduleVariant = module.clone();
+
+    for (auto op : enumerate(moduleVariant.getOps<FuncOp>())) {
+      int index = op.index();
+      if (index >= sectionSize * i && index < sectionSize * (i + 1))
+        opsToRemove.push_back(op.value());
+    }
+
+    for (Operation *o : opsToRemove)
+      o->erase();
+
+    new ReductionNode(moduleVariant, parent);
+  }
+
+  return;
+}
diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/tools/mlir-reduce/ReductionNode.cpp
new file mode 100644 (file)
index 0000000..30b9b79
--- /dev/null
@@ -0,0 +1,109 @@
+//===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the reduction nodes which are used to track of the
+// metadata for a specific generated variant within a reduction pass and are the
+// building blocks of the reduction tree structure. A reduction tree is used to
+// keep track of the different generated variants throughout a reduction pass in
+// the MLIR Reduce tool.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/ReductionNode.h"
+
+using namespace mlir;
+
+/// Sets up the metadata and links the node to its parent.
+ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
+    : module(module), evaluated(false), parent(parent) {
+
+  if (parent != nullptr)
+    parent->linkVariant(this);
+}
+
+/// Calculates and updates the size and interesting values of the module.
+void ReductionNode::measureAndTest(const Tester *test) {
+  SmallString<128> filepath;
+  int fd;
+
+  // Print module to temprary file.
+  std::error_code ec =
+      llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
+
+  if (ec)
+    llvm::report_fatal_error("Error making unique filename: " + ec.message());
+
+  llvm::ToolOutputFile out(filepath, fd);
+  module.print(out.os());
+  out.os().close();
+
+  if (out.os().has_error())
+    llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
+
+  size = out.os().tell();
+  interesting = test->isInteresting(filepath);
+  evaluated = true;
+}
+
+/// Returns true if the size and interestingness have been calculated.
+bool ReductionNode::isEvaluated() const { return evaluated; }
+
+/// Returns the size in bytes of the module.
+int ReductionNode::getSize() const { return size; }
+
+/// Returns true if the module exhibits the interesting behavior.
+bool ReductionNode::isInteresting() const { return interesting; }
+
+/// Returns the pointers to the child variants.
+ReductionNode *ReductionNode::getVariant(unsigned long index) const {
+  if (index < variants.size())
+    return variants[index].get();
+
+  return nullptr;
+}
+
+/// Returns true if the child variants vector is empty.
+bool ReductionNode::variantsEmpty() const { return variants.empty(); }
+
+/// Link a child variant node.
+void ReductionNode::linkVariant(ReductionNode *newVariant) {
+  std::unique_ptr<ReductionNode> ptrVariant(newVariant);
+  variants.push_back(std::move(ptrVariant));
+}
+
+/// Sort the child variants and remove the uninteresting ones.
+void ReductionNode::organizeVariants(const Tester *test) {
+  // Ensure all variants are evaluated.
+  for (auto &var : variants)
+    if (!var->isEvaluated())
+      var->measureAndTest(test);
+
+  // Sort variants by interestingness and size.
+  llvm::array_pod_sort(
+      variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
+        if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
+          return 0;
+
+        if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
+          return 1;
+
+        return (lhs->get()->getSize(), rhs->get()->getSize());
+      });
+
+  int interestingCount = 0;
+  for (auto &var : variants) {
+    if (var->isInteresting()) {
+      ++interestingCount;
+    } else {
+      break;
+    }
+  }
+
+  // Remove uninteresting variants.
+  variants.resize(interestingCount);
+}
diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
new file mode 100644 (file)
index 0000000..d18c693
--- /dev/null
@@ -0,0 +1,28 @@
+//===- ReductionTreePass.cpp - Reduction Tree Pass Implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/ReductionTreePass.h"
+
+using namespace mlir;
+
+/// Update the golden module's content with that of the reduced module.
+void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
+                                            ModuleOp reduced) {
+  golden.getBody()->clear();
+
+  golden.getBody()->getOperations().splice(golden.getBody()->begin(),
+                                           reduced.getBody()->getOperations());
+}
index 3edef61..93de070 100644 (file)
@@ -19,6 +19,8 @@
 #include "mlir/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Reducer/ReductionTreePass.h"
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
@@ -83,15 +85,27 @@ int main(int argc, char **argv) {
     llvm::report_fatal_error("Input test case can't be parsed");
 
   // Initialize test environment.
-  Tester test(testFilename, testArguments);
-  test.setMostReduced(moduleRef.get());
+  const Tester test(testFilename, testArguments);
 
   if (!test.isInteresting(inputFilename))
     llvm::report_fatal_error(
         "Input test case does not exhibit interesting behavior");
 
-  test.getMostReduced().print(output->os());
+  // Reduction pass pipeline.
+  PassManager pm(&context);
+
+  // Reduction tree pass with OpReducer variant generation and single path
+  // traversal.
+  pm.addPass(
+      std::make_unique<ReductionTreePass<FunctionReducer, SinglePath>>(&test));
+
+  ModuleOp m = moduleRef.get().clone();
+
+  if (failed(pm.run(m)))
+    llvm::report_fatal_error("Error running the reduction pass pipeline");
+
+  m.print(output->os());
   output->keep();
 
   return 0;
-}
\ No newline at end of file
+}