Create Reduction Tree Pass
[lldb.git] / mlir / include / mlir / Reducer / ReductionTreePass.h
1 //===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
18 #define MLIR_REDUCER_REDUCTIONTREEPASS_H
19
20 #include <vector>
21
22 #include "PassDetail.h"
23 #include "ReductionNode.h"
24 #include "mlir/Reducer/Passes/FunctionReducer.h"
25 #include "mlir/Reducer/Tester.h"
26
27 namespace mlir {
28
29 /// Defines the traversal method options to be used in the reduction tree
30 /// traversal.
31 enum TraversalMode { SinglePath, MultiPath, Concurrent, Backtrack };
32
33 // This class defines the non- templated utilities used by the ReductionTreePass
34 // class.
35 class ReductionTreeUtils {
36 public:
37   void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
38 };
39
40 /// This class defines the Reduction Tree Pass. It provides a framework to
41 /// to implement a reduction pass using a tree structure to keep track of the
42 /// generated reduced variants.
43 template <typename Reducer, TraversalMode mode>
44 class ReductionTreePass
45     : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
46 public:
47   ReductionTreePass(const Tester *test) : test(test) {}
48
49   ReductionTreePass(const ReductionTreePass &pass)
50       : root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
51         test(pass.test) {}
52
53   /// Runs the pass instance in the pass pipeline.
54   void runOnOperation() override {
55     ModuleOp module = this->getOperation();
56     this->root = std::make_unique<ReductionNode>(module, nullptr);
57     ReductionNode *reduced;
58
59     switch (mode) {
60     case SinglePath:
61       reduced = singlePathTraversal();
62       break;
63     default:
64       llvm::report_fatal_error("Traversal method not currently supported.");
65     }
66
67     ReductionTreeUtils utils;
68     utils.updateGoldenModule(module, reduced->getModule());
69   }
70
71 private:
72   // Points to the root node in this reduction tree.
73   std::unique_ptr<ReductionNode> root;
74
75   // This object defines the variant generation at each level of the reduction
76   // tree.
77   Reducer reducer;
78
79   // This is used to test the interesting behavior of the reduction nodes in the
80   // tree.
81   const Tester *test;
82
83   /// Traverse the most reduced path in the reduction tree by generating the
84   /// variants at each level using the Reducer parameter's generateVariants
85   /// function. Stops when no new successful variants can be created at the
86   /// current level.
87   ReductionNode *singlePathTraversal() {
88     ReductionNode *currLevel = root.get();
89
90     while (true) {
91       reducer.generateVariants(currLevel, test);
92       currLevel->organizeVariants(test);
93
94       if (currLevel->variantsEmpty())
95         break;
96
97       currLevel = currLevel->getVariant(0);
98     }
99
100     return currLevel;
101   }
102 };
103
104 } // end namespace mlir
105
106 #endif