d07a475e4f99481bfca7d27de2cb00057815780d
[lldb.git] / mlir / include / mlir / Reducer / ReductionTreePass.h
1 //===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
18 #define MLIR_REDUCER_REDUCTIONTREEPASS_H
19
20 #include <vector>
21
22 #include "PassDetail.h"
23 #include "ReductionNode.h"
24 #include "mlir/Reducer/Passes/FunctionReducer.h"
25 #include "mlir/Reducer/Tester.h"
26
27 namespace mlir {
28
29 /// Defines the traversal method options to be used in the reduction tree
30 /// traversal.
31 enum TraversalMode { SinglePath, MultiPath, Concurrent, Backtrack };
32
33 // This class defines the non- templated utilities used by the ReductionTreePass
34 // class.
35 class ReductionTreeUtils {
36 public:
37   static void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
38 };
39
40 /// This class defines the Reduction Tree Pass. It provides a framework to
41 /// to implement a reduction pass using a tree structure to keep track of the
42 /// generated reduced variants.
43 template <typename Reducer, TraversalMode mode>
44 class ReductionTreePass
45     : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
46 public:
47   ReductionTreePass(const Tester *test) : test(test) {}
48
49   ReductionTreePass(const ReductionTreePass &pass)
50       : ReductionTreeBase<ReductionTreePass<Reducer, mode>>(pass),
51         root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
52         test(pass.test) {}
53
54   /// Runs the pass instance in the pass pipeline.
55   void runOnOperation() override {
56     ModuleOp module = this->getOperation();
57     this->root = std::make_unique<ReductionNode>(module, nullptr);
58     ReductionNode *reduced;
59
60     switch (mode) {
61     case SinglePath:
62       reduced = singlePathTraversal();
63       break;
64     default:
65       llvm::report_fatal_error("Traversal method not currently supported.");
66     }
67
68     ReductionTreeUtils utils;
69     utils.updateGoldenModule(module, reduced->getModule());
70   }
71
72 private:
73   // Points to the root node in this reduction tree.
74   std::unique_ptr<ReductionNode> root;
75
76   // This object defines the variant generation at each level of the reduction
77   // tree.
78   Reducer reducer;
79
80   // This is used to test the interesting behavior of the reduction nodes in the
81   // tree.
82   const Tester *test;
83
84   /// Traverse the most reduced path in the reduction tree by generating the
85   /// variants at each level using the Reducer parameter's generateVariants
86   /// function. Stops when no new successful variants can be created at the
87   /// current level.
88   ReductionNode *singlePathTraversal() {
89     ReductionNode *currLevel = root.get();
90
91     while (true) {
92       reducer.generateVariants(currLevel, test);
93       currLevel->organizeVariants(test);
94
95       if (currLevel->variantsEmpty())
96         break;
97
98       currLevel = currLevel->getVariant(0);
99     }
100
101     return currLevel;
102   }
103 };
104
105 } // end namespace mlir
106
107 #endif