[mlir] NFC: fix trivial typos
[lldb.git] / mlir / tools / mlir-reduce / ReductionTreeUtils.cpp
1 //===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Utilities. It defines pass independent
10 // methods that help in a reduction pass of the MLIR Reduce tool.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Reducer/ReductionTreeUtils.h"
15
16 #define DEBUG_TYPE "mlir-reduce"
17
18 using namespace mlir;
19
20 /// Update the golden module's content with that of the reduced module.
21 void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
22                                             ModuleOp reduced) {
23   golden.getBody()->clear();
24
25   golden.getBody()->getOperations().splice(golden.getBody()->begin(),
26                                            reduced.getBody()->getOperations());
27 }
28
29 /// Update the smallest node traversed so far in the reduction tree and
30 /// print the debugging information for the currNode being traversed.
31 void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
32                                             ReductionNode *&smallestNode,
33                                             std::vector<int> path) {
34   LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
35   #ifndef NDEBUG
36   for (int nodeIndex : path)
37     LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
38   #endif
39
40   LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
41   if (currNode->getSize() < smallestNode->getSize()) {
42     LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
43     smallestNode = currNode;
44   }
45 }
46
47 /// Create a transform space index vector based on the specified number of
48 /// indices.
49 std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
50                                                            int numIndices) {
51   std::vector<bool> transformSpace;
52   for (int i = 0; i < numIndices; ++i)
53     transformSpace.push_back(false);
54
55   return transformSpace;
56 }
57
58 /// Translate section start and end into a vector of ranges specifying the
59 /// section in the non transformed indices in the transform space.
60 static std::vector<std::tuple<int, int>> getRanges(std::vector<bool> tSpace,
61                                                    int start, int end) {
62   std::vector<std::tuple<int, int>> ranges;
63   int rangeStart = 0;
64   int rangeEnd = 0;
65   bool inside = false;
66   int transformableCount = 0;
67
68   for (auto element : llvm::enumerate(tSpace)) {
69     int index = element.index();
70     bool value = element.value();
71
72     if (start <= transformableCount && transformableCount < end) {
73       if (!value && !inside) {
74         inside = true;
75         rangeStart = index;
76       }
77       if (value && inside) {
78         rangeEnd = index;
79         ranges.push_back(std::make_tuple(rangeStart, rangeEnd));
80         inside = false;
81       }
82     }
83
84     if (!value)
85       transformableCount++;
86
87     if (transformableCount == end && inside) {
88       ranges.push_back(std::make_tuple(rangeStart, index + 1));
89       inside = false;
90       break;
91     }
92   }
93
94   return ranges;
95 }
96
97 /// Create the specified number of variants by applying the transform method
98 /// to different ranges of indices in the parent module. The isDeletion boolean
99 /// specifies if the transformation is the deletion of indices.
100 void ReductionTreeUtils::createVariants(
101     ReductionNode *parent, const Tester &test, int numVariants,
102     llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
103   std::vector<bool> newTSpace;
104   ModuleOp module = parent->getModule();
105
106   std::vector<bool> parentTSpace = parent->getTransformSpace();
107   int indexCount = parent->transformSpaceSize();
108   std::vector<std::tuple<int, int>> ranges;
109
110   // No new variants can be created.
111   if (indexCount == 0)
112     return;
113
114   // Create a single variant by transforming the unique index.
115   if (indexCount == 1) {
116     ModuleOp variantModule = module.clone();
117     if (isDeletion) {
118       transform(variantModule, 0, 1);
119     } else {
120       ranges = getRanges(parentTSpace, 0, parentTSpace.size());
121       transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
122     }
123
124     new ReductionNode(variantModule, parent, newTSpace);
125
126     return;
127   }
128
129   // Create the specified number of variants.
130   for (int i = 0; i < numVariants; ++i) {
131     ModuleOp variantModule = module.clone();
132     newTSpace = parent->getTransformSpace();
133     int sectionSize = indexCount / numVariants;
134     int sectionStart = sectionSize * i;
135     int sectionEnd = sectionSize * (i + 1);
136
137     if (i == numVariants - 1)
138       sectionEnd = indexCount;
139
140     if (isDeletion)
141       transform(variantModule, sectionStart, sectionEnd);
142
143     ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
144
145     for (auto range : ranges) {
146       int rangeStart = std::get<0>(range);
147       int rangeEnd = std::get<1>(range);
148
149       for (int x = rangeStart; x < rangeEnd; ++x)
150         newTSpace[x] = true;
151
152       if (!isDeletion)
153         transform(variantModule, rangeStart, rangeEnd);
154     }
155
156     // Create Reduction Node in the Reduction tree
157     new ReductionNode(variantModule, parent, newTSpace);
158   }
159 }