Refactor Reduction Tree Pass
[lldb.git] / mlir / include / mlir / Reducer / ReductionTreePass.h
1 //===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
18 #define MLIR_REDUCER_REDUCTIONTREEPASS_H
19
20 #include <vector>
21
22 #include "PassDetail.h"
23 #include "ReductionNode.h"
24 #include "mlir/Reducer/Passes/OpReducer.h"
25 #include "mlir/Reducer/ReductionTreeUtils.h"
26 #include "mlir/Reducer/Tester.h"
27
28 #define DEBUG_TYPE "mlir-reduce"
29
30 namespace mlir {
31
32 // Defines the traversal method options to be used in the reduction tree
33 /// traversal.
34 enum TraversalMode { SinglePath, Backtrack, MultiPath };
35
36 /// This class defines the Reduction Tree Pass. It provides a framework to
37 /// to implement a reduction pass using a tree structure to keep track of the
38 /// generated reduced variants.
39 template <typename Reducer, TraversalMode mode>
40 class ReductionTreePass
41     : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
42 public:
43   ReductionTreePass(const ReductionTreePass &pass)
44       : ReductionTreeBase<ReductionTreePass<Reducer, mode>>(pass),
45         root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
46         test(pass.test) {}
47
48   ReductionTreePass(const Tester &test) : test(test) {}
49
50   /// Runs the pass instance in the pass pipeline.
51   void runOnOperation() override {
52     ModuleOp module = this->getOperation();
53     Reducer reducer;
54     std::vector<bool> transformSpace = reducer.initTransformSpace(module);
55     ReductionNode *reduced;
56
57     this->root =
58         std::make_unique<ReductionNode>(module, nullptr, transformSpace);
59
60     root->measureAndTest(test);
61
62     LLVM_DEBUG(llvm::dbgs() << "\nReduction Tree Pass: " << reducer.getName(););
63     switch (mode) {
64     case SinglePath:
65       LLVM_DEBUG(llvm::dbgs() << " (Single Path)\n";);
66       reduced = singlePathTraversal();
67       break;
68     default:
69       llvm::report_fatal_error("Traversal method not currently supported.");
70     }
71
72     ReductionTreeUtils::updateGoldenModule(module,
73                                            reduced->getModule().clone());
74   }
75
76 private:
77   // Points to the root node in this reduction tree.
78   std::unique_ptr<ReductionNode> root;
79
80   // This object defines the variant generation at each level of the reduction
81   // tree.
82   Reducer reducer;
83
84   // This is used to test the interesting behavior of the reduction nodes in the
85   // tree.
86   const Tester &test;
87
88   /// Traverse the most reduced path in the reduction tree by generating the
89   /// variants at each level using the Reducer parameter's generateVariants
90   /// function. Stops when no new successful variants can be created at the
91   /// current level.
92   ReductionNode *singlePathTraversal() {
93     ReductionNode *currNode = root.get();
94     ReductionNode *smallestNode = currNode;
95     int tSpaceSize = currNode->transformSpaceSize();
96     std::vector<int> path;
97
98     ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
99
100     LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variant: applying the ");
101     LLVM_DEBUG(llvm::dbgs() << "transformation to the entire module\n");
102
103     reducer.generateVariants(currNode, test, 1);
104     LLVM_DEBUG(llvm::dbgs() << "Testing\n");
105     currNode->organizeVariants(test);
106
107     if (!currNode->variantsEmpty())
108       return currNode->getVariant(0);
109
110     while (tSpaceSize != 1) {
111       ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
112
113       LLVM_DEBUG(llvm::dbgs() << "\nGenerating 2 variants: applying the ");
114       LLVM_DEBUG(llvm::dbgs() << "transformation to two different sections ");
115       LLVM_DEBUG(llvm::dbgs() << "of transformable indices\n");
116
117       reducer.generateVariants(currNode, test, 2);
118       LLVM_DEBUG(llvm::dbgs() << "Testing\n");
119       currNode->organizeVariants(test);
120
121       if (currNode->variantsEmpty())
122         break;
123
124       currNode = currNode->getVariant(0);
125       tSpaceSize = currNode->transformSpaceSize();
126       path.push_back(0);
127     }
128
129     if (tSpaceSize == 1) {
130       ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
131
132       LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variants: applying the ");
133       LLVM_DEBUG(llvm::dbgs() << "transformation to the only transformable");
134       LLVM_DEBUG(llvm::dbgs() << "index\n");
135
136       reducer.generateVariants(currNode, test, 1);
137       LLVM_DEBUG(llvm::dbgs() << "Testing\n");
138       currNode->organizeVariants(test);
139
140       if (!currNode->variantsEmpty()) {
141         currNode = currNode->getVariant(0);
142         path.push_back(0);
143
144         ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
145       }
146     }
147
148     return currNode;
149   }
150 };
151
152 } // end namespace mlir
153
154 #endif