Refactor Reduction Tree Pass
[lldb.git] / mlir / tools / mlir-reduce / Passes / OpReducer.cpp
1 //===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the OpReducer class. It defines a variant generator method
10 // with the purpose of producing different variants by eliminating a
11 // parametarizable type of operations from the  parent module.
12 //
13 //===----------------------------------------------------------------------===//
14 #include "mlir/Reducer/Passes/OpReducer.h"
15
16 using namespace mlir;
17
18 OpReducerImpl::OpReducerImpl(
19     llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps)
20     : getSpecificOps(getSpecificOps) {}
21
22 /// Return the name of this reducer class.
23 StringRef OpReducerImpl::getName() {
24   return StringRef("High Level Operation Reduction");
25 }
26
27 /// Return the initial transformSpace cointaing the transformable indices.
28 std::vector<bool> OpReducerImpl::initTransformSpace(ModuleOp module) {
29   auto ops = getSpecificOps(module);
30   int numOps = std::distance(ops.begin(), ops.end());
31   return ReductionTreeUtils::createTransformSpace(module, numOps);
32 }
33
34 /// Generate variants by removing opType operations from the module in the
35 /// parent and link the variants as childs in the Reduction Tree Pass.
36 void OpReducerImpl::generateVariants(
37     ReductionNode *parent, const Tester &test, int numVariants,
38     llvm::function_ref<void(ModuleOp, int, int)> transform) {
39   ReductionTreeUtils::createVariants(parent, test, numVariants, transform,
40                                      true);
41 }