[mlir] NFC: fix trivial typos
[lldb.git] / mlir / tools / mlir-reduce / ReductionNode.cpp
1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the reduction nodes which are used to track of the
10 // metadata for a specific generated variant within a reduction pass and are the
11 // building blocks of the reduction tree structure. A reduction tree is used to
12 // keep track of the different generated variants throughout a reduction pass in
13 // the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Reducer/ReductionNode.h"
18
19 using namespace mlir;
20
21 /// Sets up the metadata and links the node to its parent.
22 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
23     : module(module), evaluated(false) {
24
25   if (parent != nullptr)
26     parent->linkVariant(this);
27 }
28
29 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
30                              std::vector<bool> transformSpace)
31     : module(module), evaluated(false), transformSpace(transformSpace) {
32
33   if (parent != nullptr)
34     parent->linkVariant(this);
35 }
36
37 /// Calculates and updates the size and interesting values of the module.
38 void ReductionNode::measureAndTest(const Tester &test) {
39   SmallString<128> filepath;
40   int fd;
41
42   // Print module to temporary file.
43   std::error_code ec =
44       llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
45
46   if (ec)
47     llvm::report_fatal_error("Error making unique filename: " + ec.message());
48
49   llvm::ToolOutputFile out(filepath, fd);
50   module.print(out.os());
51   out.os().close();
52
53   if (out.os().has_error())
54     llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
55
56   size = out.os().tell();
57   interesting = test.isInteresting(filepath);
58   evaluated = true;
59 }
60
61 /// Returns true if the size and interestingness have been calculated.
62 bool ReductionNode::isEvaluated() const { return evaluated; }
63
64 /// Returns the size in bytes of the module.
65 int ReductionNode::getSize() const { return size; }
66
67 /// Returns true if the module exhibits the interesting behavior.
68 bool ReductionNode::isInteresting() const { return interesting; }
69
70 /// Returns the pointers to the child variants.
71 ReductionNode *ReductionNode::getVariant(unsigned long index) const {
72   if (index < variants.size())
73     return variants[index].get();
74
75   return nullptr;
76 }
77
78 /// Returns the number of child variants.
79 int ReductionNode::variantsSize() const { return variants.size(); }
80
81 /// Returns true if the child variants vector is empty.
82 bool ReductionNode::variantsEmpty() const { return variants.empty(); }
83
84 /// Link a child variant node.
85 void ReductionNode::linkVariant(ReductionNode *newVariant) {
86   std::unique_ptr<ReductionNode> ptrVariant(newVariant);
87   variants.push_back(std::move(ptrVariant));
88 }
89
90 /// Sort the child variants and remove the uninteresting ones.
91 void ReductionNode::organizeVariants(const Tester &test) {
92   // Ensure all variants are evaluated.
93   for (auto &var : variants)
94     if (!var->isEvaluated())
95       var->measureAndTest(test);
96
97   // Sort variants by interestingness and size.
98   llvm::array_pod_sort(
99       variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
100         if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
101           return 0;
102
103         if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
104           return 1;
105
106         return (lhs->get()->getSize(), rhs->get()->getSize());
107       });
108
109   int interestingCount = 0;
110   for (auto &var : variants) {
111     if (var->isInteresting()) {
112       ++interestingCount;
113     } else {
114       break;
115     }
116   }
117
118   // Remove uninteresting variants.
119   variants.resize(interestingCount);
120 }
121
122 /// Returns the number of non transformed indices.
123 int ReductionNode::transformSpaceSize() {
124   return std::count(transformSpace.begin(), transformSpace.end(), false);
125 }
126
127 /// Returns a vector of the transformable indices in the Module.
128 const std::vector<bool> ReductionNode::getTransformSpace() {
129   return transformSpace;
130 }