Revert "[mlir][Affine] Add support for multi-store producer fusion"
[lldb.git] / mlir / lib / Transforms / Utils / LoopFusionUtils.cpp
index 9749a8d..9759300 100644 (file)
@@ -191,8 +191,11 @@ gatherLoadsAndStores(AffineForOp forOp,
 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
 // TODO: Generalize this check for sibling and more generic fusion scenarios.
 // TODO: Support forward slice fusion.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
-                                ArrayRef<Operation *> dstOps) {
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
+                                FusionStrategy fusionStrategy) {
+  assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
+         "Fusion strategy not supported");
+
   if (dstOps.empty())
     // Expected at least one memory operation.
     // TODO: Revisit this case with a specific example.
@@ -200,14 +203,15 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
 
   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
   // that they are not considered for analysis.
-  DenseSet<Value> producerConsumerMemrefs;
-  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
+  // TODO: Currently, we pass the producer-consumer memref through
+  // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
+  // generalize the algorithm.
   SmallVector<Operation *, 4> targetDstOps;
   for (Operation *dstOp : dstOps) {
     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
     Value memref = loadOp ? loadOp.getMemRef()
                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
-    if (producerConsumerMemrefs.count(memref) > 0)
+    if (memref == fusionStrategy.memref)
       targetDstOps.push_back(dstOp);
   }
 
@@ -304,10 +308,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // loop dependences.
   // TODO: Enable this check for sibling and more generic loop fusion
   // strategies.
-  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
+  if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
-    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
+    if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
       return FusionResult::FailFusionDependence;
     }
@@ -320,7 +324,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // Filter out ops in 'opsA' to compute the slice union based on the
   // assumptions made by the fusion strategy.
   SmallVector<Operation *, 4> strategyOpsA;
-  switch (fusionStrategy.getStrategy()) {
+  switch (fusionStrategy.strategy) {
   case FusionStrategy::Generic:
     // Generic fusion. Take into account all the memory operations to compute
     // the slice union.
@@ -328,9 +332,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     break;
   case FusionStrategy::ProducerConsumer:
     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
-    // account stores in 'srcForOp' to compute the slice union.
+    // account stores to 'memref' in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
-      if (isa<AffineWriteOpInterface>(op))
+      auto store = dyn_cast<AffineWriteOpInterface>(op);
+      if (store && store.getMemRef() == fusionStrategy.memref)
         strategyOpsA.push_back(op);
     }
     break;
@@ -339,7 +344,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     // to 'memref' in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
       auto load = dyn_cast<AffineReadOpInterface>(op);
-      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
+      if (load && load.getMemRef() == fusionStrategy.memref)
         strategyOpsA.push_back(op);
     }
     break;
@@ -623,23 +628,3 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   return true;
 }
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between write ops in 'srcOps' and read ops in
-/// 'dstOps'.
-void mlir::gatherProducerConsumerMemrefs(
-    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
-    DenseSet<Value> &producerConsumerMemrefs) {
-  // Gather memrefs from stores in 'srcOps'.
-  DenseSet<Value> srcStoreMemRefs;
-  for (Operation *op : srcOps)
-    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
-      srcStoreMemRefs.insert(storeOp.getMemRef());
-
-  // Compute the intersection between memrefs from stores in 'srcOps' and
-  // memrefs from loads in 'dstOps'.
-  for (Operation *op : dstOps)
-    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
-      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
-        producerConsumerMemrefs.insert(loadOp.getMemRef());
-}