Revert "[mlir][Affine] Add support for multi-store producer fusion"
[lldb.git] / mlir / lib / Transforms / LoopFusion.cpp
index 6c56368..6fe112b 100644 (file)
@@ -30,7 +30,6 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <iomanip>
-#include <set>
 #include <sstream>
 #define DEBUG_TYPE "affine-loop-fusion"
 
@@ -271,6 +270,64 @@ public:
     return false;
   }
 
+  // Returns the unique AffineWriteOpInterface in `node` that meets all the
+  // following:
+  //   *) store is the only one that writes to a function-local memref live out
+  //      of `node`,
+  //   *) store is not the source of a self-dependence on `node`.
+  // Otherwise, returns a null AffineWriteOpInterface.
+  AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+    AffineWriteOpInterface uniqueStore;
+
+    // Return null if `node` doesn't have any outgoing edges.
+    auto outEdgeIt = outEdges.find(node->id);
+    if (outEdgeIt == outEdges.end())
+      return nullptr;
+
+    const auto &nodeOutEdges = outEdgeIt->second;
+    for (auto *op : node->stores) {
+      auto storeOp = cast<AffineWriteOpInterface>(op);
+      auto memref = storeOp.getMemRef();
+      // Skip this store if there are no dependences on its memref. This means
+      // that store either:
+      // *) writes to a memref that is only read within the same loop nest
+      //    (self-dependence edges are not represented in graph at the moment),
+      // *) writes to a function live out memref (function parameter), or
+      // *) is dead.
+      if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+            return (edge.value != memref);
+          }))
+        continue;
+
+      if (uniqueStore)
+        // Found multiple stores to function-local live-out memrefs.
+        return nullptr;
+      // Found first store to function-local live-out memref.
+      uniqueStore = storeOp;
+    }
+
+    return uniqueStore;
+  }
+
+  // Returns true if node 'id' can be removed from the graph. Returns false
+  // otherwise. A node can be removed from the graph iff the following
+  // conditions are met:
+  // *) The node does not write to any memref which escapes (or is a
+  //    function/block argument).
+  // *) The node has no successors in the dependence graph.
+  bool canRemoveNode(unsigned id) {
+    if (writesToLiveInOrEscapingMemrefs(id))
+      return false;
+    Node *node = getNode(id);
+    for (auto *storeOpInst : node->stores) {
+      // Return false if there exist out edges from 'id' on 'memref'.
+      auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+      if (getOutEdgeCount(id, storeMemref) > 0)
+        return false;
+    }
+    return true;
+  }
+
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
@@ -438,49 +495,42 @@ public:
     return dstNodeInst;
   }
 
-  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
-  // taking into account that:
-  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
-  //   *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
-  //      private memref.
-  void updateEdges(unsigned srcId, unsigned dstId,
-                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
+  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
+  // has been replaced in node at 'dstId' by a private memref depending
+  // on the value of 'createPrivateMemRef'.
+  void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
+                   bool createPrivateMemRef) {
     // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
       for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-        if (privateMemRefs.count(inEdge.value) == 0)
+        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
+        if (inEdge.value != oldMemRef)
           addEdge(inEdge.id, dstId, inEdge.value);
       }
     }
     // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
-    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
     if (outEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
       for (auto &outEdge : oldOutEdges) {
         // Remove any out edges from 'srcId' to 'dstId' across memrefs.
         if (outEdge.id == dstId)
           removeEdge(srcId, outEdge.id, outEdge.value);
-        else if (removeSrcId) {
-          addEdge(dstId, outEdge.id, outEdge.value);
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        }
       }
     }
     // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
     // replaced by a private memref). These edges could come from nodes
     // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
+    if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
       SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
       for (auto &inEdge : oldInEdges)
-        if (privateMemRefs.count(inEdge.value) > 0)
+        if (inEdge.value == oldMemRef)
           removeEdge(inEdge.id, dstId, inEdge.value);
     }
   }
 
   // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sibId' into node 'dstId'.
+  // of sibling node 'sidId' into node 'dstId'.
   void updateEdges(unsigned sibId, unsigned dstId) {
     // For each edge in 'inEdges[sibId]':
     // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
@@ -574,132 +624,6 @@ public:
   void dump() const { print(llvm::errs()); }
 };
 
-/// Returns true if node 'srcId' can be removed after fusing it with node
-/// 'dstId'. The node can be removed if any of the following conditions are met:
-///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
-///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
-///       and the fusion slice is maximal.
-///   3. 'srcId' has output dependences after fusion, the fusion slice is
-///      maximal and the fusion insertion point dominates all the dependences.
-static bool canRemoveSrcNodeAfterFusion(
-    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
-    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
-    MemRefDependenceGraph *mdg) {
-
-  Operation *dstNodeOp = mdg->getNode(dstId)->op;
-  bool hasOutDepsAfterFusion = false;
-
-  for (auto &outEdge : mdg->outEdges[srcId]) {
-    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
-    // Skip dependence with dstOp since it will be removed after fusion.
-    if (depNodeOp == dstNodeOp)
-      continue;
-
-    // Only fusion within the same block is supported. Use domination analysis
-    // when needed.
-    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
-      return false;
-
-    // Check if the insertion point of the fused loop dominates the dependence.
-    // Otherwise, the src loop can't be removed.
-    if (fusedLoopInsPoint != depNodeOp &&
-        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
-      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
-                                 "dominate dependence\n");
-      return false;
-    }
-
-    hasOutDepsAfterFusion = true;
-  }
-
-  // If src loop has dependences after fusion or it writes to an live-out or
-  // escaping memref, we can only remove it if the fusion slice is maximal so
-  // that all the dependences are preserved.
-  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
-    Optional<bool> isMaximal = fusionSlice.isMaximal();
-    if (!isMaximal.hasValue()) {
-      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
-                                 "if fusion is maximal\n");
-      return false;
-    }
-
-    if (!isMaximal.getValue()) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Src loop can't be removed: fusion is not maximal\n");
-      return false;
-    }
-  }
-
-  return true;
-}
-
-/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
-/// 'dstId'.
-// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
-static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
-                                  DenseSet<unsigned> &srcIdCandidates) {
-  // Skip if no input edges along which to fuse.
-  if (mdg->inEdges.count(dstId) == 0)
-    return;
-
-  // Gather memrefs from loads in 'dstId'.
-  auto *dstNode = mdg->getNode(dstId);
-  DenseSet<Value> consumedMemrefs;
-  for (Operation *load : dstNode->loads)
-    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
-
-  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
-  // to one of the consumed memrefs.
-  for (auto &srcEdge : mdg->inEdges[dstId]) {
-    auto *srcNode = mdg->getNode(srcEdge.id);
-    // Skip if 'srcNode' is not a loop nest.
-    if (!isa<AffineForOp>(srcNode->op))
-      continue;
-
-    if (any_of(srcNode->stores, [&](Operation *op) {
-          auto storeOp = cast<AffineWriteOpInterface>(op);
-          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
-        }))
-      srcIdCandidates.insert(srcNode->id);
-  }
-}
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between 'srcId' and 'dstId'.
-static void
-gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
-                              MemRefDependenceGraph *mdg,
-                              DenseSet<Value> &producerConsumerMemrefs) {
-  auto *dstNode = mdg->getNode(dstId);
-  auto *srcNode = mdg->getNode(srcId);
-  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
-                                producerConsumerMemrefs);
-}
-
-/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-///   1. It's a function argument, or
-///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
-void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
-                           DenseSet<Value> &escapingMemRefs) {
-  auto *node = mdg->getNode(id);
-  for (auto *storeOpInst : node->stores) {
-    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
-    if (escapingMemRefs.count(memref))
-      continue;
-    // Check if 'memref' escapes because it's a block argument.
-    if (memref.isa<BlockArgument>()) {
-      escapingMemRefs.insert(memref);
-      continue;
-    }
-    // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
-    // call op, etc.).
-    for (Operation *user : memref.getUsers())
-      if (!isMemRefDereferencingOp(*user))
-        escapingMemRefs.insert(memref);
-  }
-}
-
 } // end anonymous namespace
 
 // Initializes the data dependence graph by walking operations in 'f'.
@@ -707,7 +631,6 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
 // TODO: Add support for taking a Block arg to construct the
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(FuncOp f) {
-  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   // TODO: support multi-block functions.
@@ -763,12 +686,6 @@ bool MemRefDependenceGraph::init(FuncOp f) {
     }
   }
 
-#ifndef NDEBUG
-  for (auto &idAndNode : nodes)
-    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
-                            << *(idAndNode.second.op) << "\n");
-#endif
-
   // Add dependence edges between nodes which produce SSA values and their
   // users.
   for (auto &idAndNode : nodes) {
@@ -808,6 +725,22 @@ bool MemRefDependenceGraph::init(FuncOp f) {
   return true;
 }
 
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void moveLoadsAccessingMemrefTo(Value memref,
+                                       SmallVectorImpl<Operation *> *srcLoads,
+                                       SmallVectorImpl<Operation *> *dstLoads) {
+  dstLoads->clear();
+  SmallVector<Operation *, 4> srcLoadsToKeep;
+  for (auto *load : *srcLoads) {
+    if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
+      dstLoads->push_back(load);
+    else
+      srcLoadsToKeep.push_back(load);
+  }
+  srcLoads->swap(srcLoadsToKeep);
+}
+
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
 // outermost (while again preserving relative order among them).
@@ -999,6 +932,75 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
   return false;
 }
 
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', has output edges.
+// Returns true if 'dstNode's read/write region to 'memref' is a super set of
+// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
+// TODO: Generalize this to handle more live in/out cases.
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+                               AffineWriteOpInterface srcLiveOutStoreOp,
+                               MemRefDependenceGraph *mdg) {
+  assert(srcLiveOutStoreOp && "Expected a valid store op");
+  auto *dstNode = mdg->getNode(dstId);
+  Value memref = srcLiveOutStoreOp.getMemRef();
+  // Return false if 'srcNode' has more than one output edge on 'memref'.
+  if (mdg->getOutEdgeCount(srcId, memref) > 1)
+    return false;
+
+  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+  MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+  if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> srcShape;
+  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
+  // by 'srcStoreOp' at depth 'dstLoopDepth'.
+  Optional<int64_t> srcNumElements =
+      srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+  if (!srcNumElements.hasValue())
+    return false;
+
+  // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
+  // TODO: Compute 'unionboundingbox' of all write regions (one for
+  // each store op in 'dstStoreOps').
+  SmallVector<Operation *, 2> dstStoreOps;
+  dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+  SmallVector<Operation *, 2> dstLoadOps;
+  dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
+
+  auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
+  MemRefRegion dstRegion(dstOpInst->getLoc());
+  if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for dest operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> dstShape;
+  // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
+  // by 'dstOpInst' at depth 'dstLoopDepth'.
+  Optional<int64_t> dstNumElements =
+      dstRegion.getConstantBoundingSizeAndShape(&dstShape);
+  if (!dstNumElements.hasValue())
+    return false;
+
+  // Return false if write region is not a superset of 'srcNodes' write
+  // region to 'memref'.
+  // TODO: Check the shape and lower bounds here too.
+  if (srcNumElements != dstNumElements)
+    return false;
+
+  // Return false if 'memref' is used by a non-affine operation that is
+  // between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
+  return true;
+}
+
 // Checks the profitability of fusing a backwards slice of the loop nest
 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1027,6 +1029,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
 //    the largest computation slice at the maximal dst loop depth (closest to
 //    the load) to minimize reuse distance and potentially enable subsequent
 //    load/store forwarding.
+//    NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
+//    the same memref as is written by 'srcOpInst', then the union of slice
+//    loop bounds is used to compute the slice and associated slice cost.
 //    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
 //    nest, at which the src computation slice is inserted/fused.
 //    NOTE: We attempt to maximize the dst loop depth, but there are cases
@@ -1036,18 +1041,18 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
 // *) Compares the total cost of the unfused loop nests to the min cost fused
 //    loop nest computed in the previous step, and returns true if the latter
 //    is lower.
-// TODO: Extend profitability analysis to support scenarios with multiple
-// stores.
 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
-                               AffineForOp dstForOp,
+                               ArrayRef<Operation *> dstLoadOpInsts,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
     llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
-    llvm::dbgs() << dstForOp << "\n";
+    llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n";
+    for (auto dstOpInst : dstLoadOpInsts) {
+      llvm::dbgs() << " " << *dstOpInst << "\n";
+    };
   });
 
   if (maxLegalFusionDepth == 0) {
@@ -1065,8 +1070,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     return false;
 
   // Compute cost of dst loop nest.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+
   LoopNestStats dstLoopNestStats;
-  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
+  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
     return false;
 
   // Search for min cost value for 'dstLoopDepth'. At each value of
@@ -1100,19 +1108,18 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
 
   // Compute op instance count for the src loop nest.
-  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
   for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
-    const ComputationSliceState &slice = depthSliceUnions[i - 1];
     // Skip slice union if it wasn't computed for this depth.
-    if (slice.isEmpty())
+    if (depthSliceUnions[i - 1].isEmpty())
       continue;
 
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
-                              dstLoopNestStats, slice,
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+                              dstLoopNestStats, depthSliceUnions[i - 1],
                               &fusedLoopNestComputeCost)) {
       LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
@@ -1124,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
         1;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
-    // nest slice 'slice' were to be inserted into the dst loop nest at loop
-    // depth 'i'.
+    // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
+    // loop nest at loop depth 'i'.
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
     if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                                        &slice))) {
+                                        &depthSliceUnions[i - 1]))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Failed to compute slice write region at loopDepth: " << i
                  << "\n");
@@ -1211,7 +1218,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << "\n  fused loop nest compute cost: "
                    << minFusedLoopNestComputeCost << "\n");
 
-  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
+  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
   auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
 
   Optional<double> storageReduction = None;
@@ -1315,6 +1322,8 @@ public:
   MemRefDependenceGraph *mdg;
   // Worklist of graph nodes visited during the fusion pass.
   SmallVector<unsigned, 8> worklist;
+  // Set of graph nodes which are present on the worklist.
+  llvm::SmallDenseSet<unsigned, 16> worklistSet;
   // Parameter for local buffer size threshold.
   unsigned localBufSizeThreshold;
   // Parameter for fast memory space.
@@ -1335,14 +1344,16 @@ public:
         fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
         computeToleranceThreshold(computeToleranceThreshold) {}
 
-  /// Initializes 'worklist' with nodes from 'mdg'.
+  // Initializes 'worklist' with nodes from 'mdg'
   void init() {
     // TODO: Add a priority queue for prioritizing nodes by different
     // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
     worklist.clear();
+    worklistSet.clear();
     for (auto &idAndNode : mdg->nodes) {
       const Node &node = idAndNode.second;
       worklist.push_back(node.id);
+      worklistSet.insert(node.id);
     }
   }
 
@@ -1361,11 +1372,11 @@ public:
   }
 
   void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
-    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1375,97 +1386,114 @@ public:
       // Skip if 'dstNode' is not a loop nest.
       if (!isa<AffineForOp>(dstNode->op))
         continue;
-
-      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
-
       // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
       // while preserving relative order. This can increase the maximum loop
       // depth at which we can fuse a slice of a producer loop nest into a
       // consumer loop nest.
       sinkSequentialLoops(dstNode);
-      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
-      // Try to fuse 'dstNode' with candidate producer loops until a fixed point
-      // is reached. Fusing two loops may expose new fusion opportunities.
-      bool dstNodeChanged;
-      do {
-        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
-        // reverse program order to minimize the number of iterations needed to
-        // reach the fixed point. Note that this is a best effort approach since
-        // 'getProducerCandidates' does not always guarantee that program order
-        // in 'srcIdCandidates'.
-        dstNodeChanged = false;
-        DenseSet<unsigned> srcIdCandidates;
-        getProducerCandidates(dstId, mdg, srcIdCandidates);
-
-        /// Visit candidates in reverse node id order. This order corresponds to
-        /// the reverse program order when the 'mdg' is created. However,
-        /// reverse program order is not guaranteed and must not be required.
-        /// Reverse program order won't be held if the 'mdg' is reused from a
-        /// previous fusion step or if the node creation order changes in the
-        /// future to support more advance cases.
-        SmallVector<unsigned, 16> sortedSrcIdCandidates;
-        sortedSrcIdCandidates.reserve(srcIdCandidates.size());
-        sortedSrcIdCandidates.append(srcIdCandidates.begin(),
-                                     srcIdCandidates.end());
-        llvm::sort(sortedSrcIdCandidates, std::greater<unsigned>());
-
-        for (unsigned srcId : sortedSrcIdCandidates) {
+
+      SmallVector<Operation *, 4> loads = dstNode->loads;
+      SmallVector<Operation *, 4> dstLoadOpInsts;
+      DenseSet<Value> visitedMemrefs;
+      while (!loads.empty()) {
+        // Get memref of load on top of the stack.
+        auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
+        if (visitedMemrefs.count(memref) > 0)
+          continue;
+        visitedMemrefs.insert(memref);
+        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
+        // Skip if no input edges along which to fuse.
+        if (mdg->inEdges.count(dstId) == 0)
+          continue;
+        // Iterate through in-edges for 'dstId' and src node id for any
+        // edges on 'memref'.
+        SmallVector<unsigned, 2> srcNodeIds;
+        for (auto &srcEdge : mdg->inEdges[dstId]) {
+          // Skip 'srcEdge' if not for 'memref'.
+          if (srcEdge.value != memref)
+            continue;
+          srcNodeIds.push_back(srcEdge.id);
+        }
+        for (unsigned srcId : srcNodeIds) {
+          // Skip if this node was removed (fused into another node).
+          if (mdg->nodes.count(srcId) == 0)
+            continue;
           // Get 'srcNode' from which to attempt fusion into 'dstNode'.
           auto *srcNode = mdg->getNode(srcId);
-          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
-          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
-                                  << " for dst loop " << dstId << "\n");
-
-          DenseSet<Value> producerConsumerMemrefs;
-          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
-                                        producerConsumerMemrefs);
-
-          // Skip if 'srcNode' out edge count on any memref is greater than
-          // 'maxSrcUserCount'.
-          if (any_of(producerConsumerMemrefs, [&](Value memref) {
-                return mdg->getOutEdgeCount(srcNode->id, memref) >
-                       maxSrcUserCount;
-              }))
+          // Skip if 'srcNode' is not a loop nest.
+          if (!isa<AffineForOp>(srcNode->op))
             continue;
+          // Skip if 'srcNode' has more than one live-out store to a
+          // function-local memref.
+          // TODO: Support more generic multi-output src loop nests
+          // fusion.
+          auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+          if (!srcStoreOp) {
+            // Get the src store op at the deepest loop depth.
+            // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+            // feasibility for loops with multiple stores.
+            unsigned maxLoopDepth = 0;
+            for (auto *op : srcNode->stores) {
+              auto storeOp = cast<AffineWriteOpInterface>(op);
+              if (storeOp.getMemRef() != memref) {
+                srcStoreOp = nullptr;
+                break;
+              }
+              unsigned loopDepth = getNestingDepth(storeOp);
+              if (loopDepth > maxLoopDepth) {
+                maxLoopDepth = loopDepth;
+                srcStoreOp = storeOp;
+              }
+            }
+            if (!srcStoreOp)
+              continue;
+          }
 
-          // Gather memrefs in 'srcNode' that are written and escape to the
-          // function (e.g., memref function arguments, returned memrefs,
-          // memrefs passed to function calls, etc.).
-          DenseSet<Value> srcEscapingMemRefs;
-          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
-
-          // Skip if there are non-affine operations in between the 'srcNode'
-          // and 'dstNode' using their memrefs. If so, we wouldn't be able to
-          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
-          // memrefs with non-affine operation users would be considered
-          // escaping memrefs so we can limit this check to only scenarios with
-          // escaping memrefs.
-          if (!srcEscapingMemRefs.empty() &&
-              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
-            LLVM_DEBUG(
-                llvm::dbgs()
-                << "Can't fuse: non-affine users in between the loops\n.");
+          // Unique outgoing store found must write to 'memref' since 'memref'
+          // is the one that established the producer-consumer relationship
+          // between 'srcNode' and 'dstNode'.
+          assert(srcStoreOp.getMemRef() == memref &&
+                 "Found store to unexpected memref");
+
+          // Skip if 'srcNode' writes to any live in or escaping memrefs,
+          // and cannot be fused.
+          bool writesToLiveInOrOut =
+              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+          if (writesToLiveInOrOut &&
+              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
             continue;
+
+          // Don't create a private memref if 'writesToLiveInOrOut'.
+          bool createPrivateMemref = !writesToLiveInOrOut;
+          // Don't create a private memref if 'srcNode' has in edges on
+          // 'memref', or if 'dstNode' has out edges on 'memref'.
+          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+              mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+            createPrivateMemref = false;
           }
 
+          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
+          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
+            continue;
+
           // Compute an operation list insertion point for the fused loop
           // nest which preserves dependences.
-          Operation *fusedLoopInsPoint =
+          Operation *insertPointInst =
               mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
-          if (fusedLoopInsPoint == nullptr)
+          if (insertPointInst == nullptr)
             continue;
 
-          // Compute the innermost common loop depth for dstNode
-          // producer-consumer loads/stores.
+          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+          auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
+          // Compute the innermost common loop depth for dstNode loads/stores.
           SmallVector<Operation *, 2> dstMemrefOps;
           for (Operation *op : dstNode->loads)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
+            if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
               dstMemrefOps.push_back(op);
           for (Operation *op : dstNode->stores)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineWriteOpInterface>(op).getMemRef()))
+            if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
               dstMemrefOps.push_back(op);
           unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
 
@@ -1474,7 +1502,7 @@ public:
           unsigned maxLegalFusionDepth = 0;
           SmallVector<ComputationSliceState, 8> depthSliceUnions;
           depthSliceUnions.resize(dstLoopDepthTest);
-          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
+          FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
           for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
             FusionResult result = mlir::canFuseLoops(
                 srcAffineForOp, dstAffineForOp,
@@ -1484,82 +1512,27 @@ public:
               maxLegalFusionDepth = i;
           }
 
-          if (maxLegalFusionDepth == 0) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Can't fuse: fusion is not legal at any depth\n");
+          // Skip if fusion is not feasible at any loop depths.
+          if (maxLegalFusionDepth == 0)
             continue;
-          }
 
           // Check if fusion would be profitable. We skip profitability analysis
           // for maximal fusion since we already know the maximal legal depth to
           // fuse.
           unsigned bestDstLoopDepth = maxLegalFusionDepth;
-          if (!maximalFusion) {
-            // Retrieve producer stores from the src loop.
-            SmallVector<Operation *, 2> producerStores;
-            for (Operation *op : srcNode->stores)
-              if (producerConsumerMemrefs.count(
-                      cast<AffineWriteOpInterface>(op).getMemRef()))
-                producerStores.push_back(op);
-
-            // TODO: Suppport multiple producer stores in profitability
-            // analysis. We limit profitability analysis to only scenarios with
-            // a single producer store for now. Note that some multi-store
-            // producer scenarios will still go through profitability analysis
-            // if only one of the stores is involved the producer-consumer
-            // relationship of the candidate loops.
-            assert(producerStores.size() > 0 && "Expected producer store");
-            if (producerStores.size() > 1)
-              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
-                                         "supported for this case\n");
-            else if (!isFusionProfitable(producerStores[0], producerStores[0],
-                                         dstAffineForOp, depthSliceUnions,
-                                         maxLegalFusionDepth, &bestDstLoopDepth,
-                                         computeToleranceThreshold))
-              continue;
-          }
+          if (!maximalFusion &&
+              !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+                                  depthSliceUnions, maxLegalFusionDepth,
+                                  &bestDstLoopDepth, computeToleranceThreshold))
+            continue;
 
           assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
-          ComputationSliceState &bestSlice =
-              depthSliceUnions[bestDstLoopDepth - 1];
-          assert(!bestSlice.isEmpty() && "Missing slice union for depth");
-
-          // Determine if 'srcId' can be removed after fusion, taking into
-          // account remaining dependences, escaping memrefs and the fusion
-          // insertion point.
-          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
-              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
-              mdg);
-
-          DenseSet<Value> privateMemrefs;
-          for (Value memref : producerConsumerMemrefs) {
-            // Don't create a private memref if 'srcNode' writes to escaping
-            // memrefs.
-            if (srcEscapingMemRefs.count(memref) > 0)
-              continue;
-
-            // Don't create a private memref if 'srcNode' has in edges on
-            // 'memref' or 'dstNode' has out edges on 'memref'.
-            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
-                mdg->getOutEdgeCount(dstId, memref) > 0)
-              continue;
-
-            // If 'srcNode' will be removed but it has out edges on 'memref' to
-            // nodes other than 'dstNode', we have to preserve dependences and
-            // cannot create a private memref.
-            if (removeSrcNode &&
-                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
-                  return edge.value == memref && edge.id != dstId;
-                }))
-              continue;
-
-            // Create a private version of this memref.
-            privateMemrefs.insert(memref);
-          }
+          assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+                 "Missing slice union for depth");
 
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
-          dstNodeChanged = true;
+          fuseLoops(srcAffineForOp, dstAffineForOp,
+                    depthSliceUnions[bestDstLoopDepth - 1]);
 
           LLVM_DEBUG(llvm::dbgs()
                      << "Fused src loop " << srcId << " into dst loop " << dstId
@@ -1567,20 +1540,18 @@ public:
                      << dstAffineForOp << "\n");
 
           // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
-            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+          if (insertPointInst != dstAffineForOp.getOperation())
+            dstAffineForOp->moveBefore(insertPointInst);
 
           // Update edges between 'srcNode' and 'dstNode'.
-          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
-                           removeSrcNode);
+          mdg->updateEdges(srcNode->id, dstNode->id, memref,
+                           createPrivateMemref);
 
           // Collect slice loop stats.
           LoopNestStateCollector dstForCollector;
           dstForCollector.collect(dstAffineForOp);
-          for (Value memref : privateMemrefs) {
+          if (createPrivateMemref) {
             // Create private memref for 'memref' in 'dstAffineForOp'.
-            // TODO: remove storesForMemref and move the code below to the
-            // loop-if.
             SmallVector<Operation *, 4> storesForMemref;
             for (auto *storeOpInst : dstForCollector.storeOpInsts) {
               if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
@@ -1592,6 +1563,7 @@ public:
             auto newMemRef = createPrivateMemRef(
                 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                 fastMemorySpace, localBufSizeThreshold);
+            visitedMemrefs.insert(newMemRef);
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
@@ -1602,21 +1574,58 @@ public:
           LoopNestStateCollector dstLoopCollector;
           dstLoopCollector.collect(dstAffineForOp.getOperation());
 
+          // Add new load ops to current Node load op list 'loads' to continue
+          // fusing based on new operands.
+          for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+            // NOTE: Change 'loads' to a hash set in case efficiency is an
+            // issue. We still use a vector since it's expected to be small.
+            if (!llvm::is_contained(loads, loadOpInst))
+              loads.push_back(loadOpInst);
+          }
+          // Clear visited memrefs after fusion so that previously visited src
+          // nodes are considered for fusion again in the context of the new
+          // fused node.
+          // TODO: This shouldn't be necessary if we visited candidates in the
+          // dependence graph in post-order or once we fully support multi-store
+          // producers. Currently, in a multi-store producer scenario such as
+          // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
+          // edges. However, after fusing B+C, A has a single outgoing edge and
+          // can be fused if we revisit it in the context of the new fused B+C
+          // node.
+          visitedMemrefs.clear();
+
           // Clear and add back loads and stores.
           mdg->clearNodeLoadAndStores(dstNode->id);
           mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                          dstLoopCollector.storeOpInsts);
-
-          if (removeSrcNode) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Removing src loop " << srcId << " after fusion\n");
-            // srcNode is no longer valid after it is removed from mdg.
-            srcAffineForOp.erase();
-            mdg->removeNode(srcId);
-            srcNode = nullptr;
+          // Remove old src loop nest if it no longer has outgoing dependence
+          // edges, and if it does not write to a memref which escapes the
+          // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
+          // fused into 'dstNode' and write region of 'dstNode' covers the write
+          // region of 'srcNode', and 'srcNode' has no other users so it is safe
+          // to remove.
+          if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+            mdg->removeNode(srcNode->id);
+            srcNode->op->erase();
+          } else {
+            // Add remaining users of 'oldMemRef' back on the worklist (if not
+            // already there), as its replacement with a local/private memref
+            // has reduced dependences on 'oldMemRef' which may have created new
+            // fusion opportunities.
+            if (mdg->outEdges.count(srcNode->id) > 0) {
+              SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                  mdg->outEdges[srcNode->id];
+              for (auto &outEdge : oldOutEdges) {
+                if (outEdge.value == memref &&
+                    worklistSet.count(outEdge.id) == 0) {
+                  worklist.push_back(outEdge.id);
+                  worklistSet.insert(outEdge.id);
+                }
+              }
+            }
           }
         }
-      } while (dstNodeChanged);
+      }
     }
   }
 
@@ -1627,6 +1636,7 @@ public:
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1688,7 +1698,7 @@ public:
       SmallVector<ComputationSliceState, 8> depthSliceUnions;
       depthSliceUnions.resize(dstLoopDepthTest);
       unsigned maxLegalFusionDepth = 0;
-      FusionStrategy strategy(memref);
+      FusionStrategy strategy(FusionStrategy::Sibling, memref);
       for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
         FusionResult result = mlir::canFuseLoops(
             sibAffineForOp, dstAffineForOp,
@@ -1702,10 +1712,10 @@ public:
       if (maxLegalFusionDepth == 0)
         continue;
 
-      unsigned bestDstLoopDepth = maxLegalFusionDepth;
+      unsigned bestDstLoopDepth = dstLoopDepthTest;
       if (!maximalFusion) {
         // Check if fusion would be profitable.
-        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
+        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
                                 depthSliceUnions, maxLegalFusionDepth,
                                 &bestDstLoopDepth, computeToleranceThreshold))
           continue;