Refactor Reduction Tree Pass
[lldb.git] / mlir / tools / mlir-reduce / ReductionTreeUtils.cpp
1 //===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Utilities. It defines pass independent
10 // methods that help in a reduction pass of the MLIR Reduce tool.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Reducer/ReductionTreeUtils.h"
15
16 #define DEBUG_TYPE "mlir-reduce"
17
18 using namespace mlir;
19
20 /// Update the golden module's content with that of the reduced module.
21 void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
22                                             ModuleOp reduced) {
23   golden.getBody()->clear();
24
25   golden.getBody()->getOperations().splice(golden.getBody()->begin(),
26                                            reduced.getBody()->getOperations());
27 }
28
29 /// Update the the smallest node traversed so far in the reduction tree and
30 /// print the debugging information for the currNode being traversed.
31 void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
32                                             ReductionNode *&smallestNode,
33                                             std::vector<int> path) {
34   LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
35   for (int nodeIndex : path)
36     LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
37
38   LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
39   if (currNode->getSize() < smallestNode->getSize()) {
40     LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
41     smallestNode = currNode;
42   }
43 }
44
45 /// Create a transform space index vector based on the specified number of
46 /// indices.
47 std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
48                                                            int numIndices) {
49   std::vector<bool> transformSpace;
50   for (int i = 0; i < numIndices; ++i)
51     transformSpace.push_back(false);
52
53   return transformSpace;
54 }
55
56 /// Translate section start and end into a vector of ranges specifying the
57 /// section in the non transformed indices in the transform space.
58 static std::vector<std::tuple<int, int>> getRanges(std::vector<bool> tSpace,
59                                                    int start, int end) {
60   std::vector<std::tuple<int, int>> ranges;
61   int rangeStart = 0;
62   int rangeEnd = 0;
63   bool inside = false;
64   int transformableCount = 0;
65
66   for (auto element : llvm::enumerate(tSpace)) {
67     int index = element.index();
68     bool value = element.value();
69
70     if (start <= transformableCount && transformableCount < end) {
71       if (!value && !inside) {
72         inside = true;
73         rangeStart = index;
74       }
75       if (value && inside) {
76         rangeEnd = index;
77         ranges.push_back(std::make_tuple(rangeStart, rangeEnd));
78         inside = false;
79       }
80     }
81
82     if (!value)
83       transformableCount++;
84
85     if (transformableCount == end && inside) {
86       ranges.push_back(std::make_tuple(rangeStart, index + 1));
87       inside = false;
88       break;
89     }
90   }
91
92   return ranges;
93 }
94
95 /// Create the specified number of variants by applying the transform method
96 /// to different ranges of indices in the parent module. The isDeletion bolean
97 /// specifies if the transformation is the deletion of indices.
98 void ReductionTreeUtils::createVariants(
99     ReductionNode *parent, const Tester &test, int numVariants,
100     llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
101   std::vector<bool> newTSpace;
102   ModuleOp module = parent->getModule();
103
104   std::vector<bool> parentTSpace = parent->getTransformSpace();
105   int indexCount = parent->transformSpaceSize();
106   std::vector<std::tuple<int, int>> ranges;
107
108   // No new variants can be created.
109   if (indexCount == 0)
110     return;
111
112   // Create a single variant by transforming the unique index.
113   if (indexCount == 1) {
114     ModuleOp variantModule = module.clone();
115     if (isDeletion) {
116       transform(variantModule, 0, 1);
117     } else {
118       ranges = getRanges(parentTSpace, 0, parentTSpace.size());
119       transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
120     }
121
122     new ReductionNode(variantModule, parent, newTSpace);
123
124     return;
125   }
126
127   // Create the specified number of variants.
128   for (int i = 0; i < numVariants; ++i) {
129     ModuleOp variantModule = module.clone();
130     newTSpace = parent->getTransformSpace();
131     int sectionSize = indexCount / numVariants;
132     int sectionStart = sectionSize * i;
133     int sectionEnd = sectionSize * (i + 1);
134
135     if (i == numVariants - 1)
136       sectionEnd = indexCount;
137
138     if (isDeletion)
139       transform(variantModule, sectionStart, sectionEnd);
140
141     ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
142
143     for (auto range : ranges) {
144       int rangeStart = std::get<0>(range);
145       int rangeEnd = std::get<1>(range);
146
147       for (int x = rangeStart; x < rangeEnd; ++x)
148         newTSpace[x] = true;
149
150       if (!isDeletion)
151         transform(variantModule, rangeStart, rangeEnd);
152     }
153
154     // Create Reduction Node in the Reduction tree
155     new ReductionNode(variantModule, parent, newTSpace);
156   }
157 }