383a6587bbef2f649573d5fe03e88dc453984dbb
[lldb.git] / mlir / lib / Analysis / Utils.cpp
1 //===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous analysis routines for non-loop IR
10 // structures.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Analysis/Utils.h"
15 #include "mlir/Analysis/AffineAnalysis.h"
16 #include "mlir/Analysis/PresburgerSet.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/raw_ostream.h"
24
25 #define DEBUG_TYPE "analysis-utils"
26
27 using namespace mlir;
28
29 using llvm::SmallDenseMap;
30
31 /// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
32 /// the outermost 'affine.for' operation to the innermost one.
33 void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
34   auto *currOp = op.getParentOp();
35   AffineForOp currAffineForOp;
36   // Traverse up the hierarchy collecting all 'affine.for' operation while
37   // skipping over 'affine.if' operations.
38   while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) ||
39                     isa<AffineIfOp>(currOp))) {
40     if (currAffineForOp)
41       loops->push_back(currAffineForOp);
42     currOp = currOp->getParentOp();
43   }
44   std::reverse(loops->begin(), loops->end());
45 }
46
47 /// Populates 'ops' with IVs of the loops surrounding `op`, along with
48 /// `affine.if` operations interleaved between these loops, ordered from the
49 /// outermost `affine.for` operation to the innermost one.
50 void mlir::getEnclosingAffineForAndIfOps(Operation &op,
51                                          SmallVectorImpl<Operation *> *ops) {
52   ops->clear();
53   Operation *currOp = op.getParentOp();
54
55   // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
56   // operations.
57   while (currOp && (isa<AffineIfOp, AffineForOp>(currOp))) {
58     ops->push_back(currOp);
59     currOp = currOp->getParentOp();
60   }
61   std::reverse(ops->begin(), ops->end());
62 }
63
64 // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
65 LogicalResult
66 ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
67   assert(!lbOperands.empty());
68   // Adds src 'ivs' as dimension identifiers in 'cst'.
69   unsigned numDims = ivs.size();
70   // Adds operands (dst ivs and symbols) as symbols in 'cst'.
71   unsigned numSymbols = lbOperands[0].size();
72
73   SmallVector<Value, 4> values(ivs);
74   // Append 'ivs' then 'operands' to 'values'.
75   values.append(lbOperands[0].begin(), lbOperands[0].end());
76   cst->reset(numDims, numSymbols, 0, values);
77
78   // Add loop bound constraints for values which are loop IVs and equality
79   // constraints for symbols which are constants.
80   for (const auto &value : values) {
81     assert(cst->containsId(value) && "value expected to be present");
82     if (isValidSymbol(value)) {
83       // Check if the symbol is a constant.
84       if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
85         cst->setIdToConstant(value, cOp.getValue());
86     } else if (auto loop = getForInductionVarOwner(value)) {
87       if (failed(cst->addAffineForOpDomain(loop)))
88         return failure();
89     }
90   }
91
92   // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'
93   LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
94   assert(succeeded(ret) &&
95          "should not fail as we never have semi-affine slice maps");
96   (void)ret;
97   return success();
98 }
99
100 // Clears state bounds and operand state.
101 void ComputationSliceState::clearBounds() {
102   lbs.clear();
103   ubs.clear();
104   lbOperands.clear();
105   ubOperands.clear();
106 }
107
108 void ComputationSliceState::dump() const {
109   llvm::errs() << "\tIVs:\n";
110   for (Value iv : ivs)
111     llvm::errs() << "\t\t" << iv << "\n";
112
113   llvm::errs() << "\tLBs:\n";
114   for (auto &en : llvm::enumerate(lbs)) {
115     llvm::errs() << "\t\t" << en.value() << "\n";
116     llvm::errs() << "\t\tOperands:\n";
117     for (Value lbOp : lbOperands[en.index()])
118       llvm::errs() << "\t\t\t" << lbOp << "\n";
119   }
120
121   llvm::errs() << "\tUBs:\n";
122   for (auto &en : llvm::enumerate(ubs)) {
123     llvm::errs() << "\t\t" << en.value() << "\n";
124     llvm::errs() << "\t\tOperands:\n";
125     for (Value ubOp : ubOperands[en.index()])
126       llvm::errs() << "\t\t\t" << ubOp << "\n";
127   }
128 }
129
130 /// Fast check to determine if the computation slice is maximal. Returns true if
131 /// each slice dimension maps to an existing dst dimension and both the src
132 /// and the dst loops for those dimensions have the same bounds. Returns false
133 /// if both the src and the dst loops don't have the same bounds. Returns
134 /// llvm::None if none of the above can be proven.
135 Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
136   assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
137          "Unexpected number of lbs, ubs and ivs in slice");
138
139   for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
140     AffineMap lbMap = lbs[i];
141     AffineMap ubMap = ubs[i];
142
143     // Check if this slice is just an equality along this dimension.
144     if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
145         ubMap.getNumResults() != 1 ||
146         lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
147         // The condition above will be true for maps describing a single
148         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
149         // Make sure we skip those cases by checking that the lb result is not
150         // just a constant.
151         lbMap.getResult(0).isa<AffineConstantExpr>())
152       return llvm::None;
153
154     // Limited support: we expect the lb result to be just a loop dimension for
155     // now.
156     AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
157     if (!result)
158       return llvm::None;
159
160     // Retrieve dst loop bounds.
161     AffineForOp dstLoop =
162         getForInductionVarOwner(lbOperands[i][result.getPosition()]);
163     if (!dstLoop)
164       return llvm::None;
165     AffineMap dstLbMap = dstLoop.getLowerBoundMap();
166     AffineMap dstUbMap = dstLoop.getUpperBoundMap();
167
168     // Retrieve src loop bounds.
169     AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
170     assert(srcLoop && "Expected affine for");
171     AffineMap srcLbMap = srcLoop.getLowerBoundMap();
172     AffineMap srcUbMap = srcLoop.getUpperBoundMap();
173
174     // Limited support: we expect simple src and dst loops with a single
175     // constant component per bound for now.
176     if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
177         dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
178       return llvm::None;
179
180     AffineExpr srcLbResult = srcLbMap.getResult(0);
181     AffineExpr dstLbResult = dstLbMap.getResult(0);
182     AffineExpr srcUbResult = srcUbMap.getResult(0);
183     AffineExpr dstUbResult = dstUbMap.getResult(0);
184     if (!srcLbResult.isa<AffineConstantExpr>() ||
185         !srcUbResult.isa<AffineConstantExpr>() ||
186         !dstLbResult.isa<AffineConstantExpr>() ||
187         !dstUbResult.isa<AffineConstantExpr>())
188       return llvm::None;
189
190     // Check if src and dst loop bounds are the same. If not, we can guarantee
191     // that the slice is not maximal.
192     if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
193       return false;
194   }
195
196   return true;
197 }
198
199 /// Returns true if the computation slice encloses all the iterations of the
200 /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
201 /// cannot determine if the slice is maximal or not.
202 Optional<bool> ComputationSliceState::isMaximal() const {
203   // Fast check to determine if the computation slice is maximal. If the result
204   // is inconclusive, we proceed with a more expensive analysis.
205   Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
206   if (isMaximalFastCheck.hasValue())
207     return isMaximalFastCheck;
208
209   // Create constraints for the src loop nest being sliced.
210   FlatAffineConstraints srcConstraints;
211   srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
212                        /*numLocals=*/0, ivs);
213   for (Value iv : ivs) {
214     AffineForOp loop = getForInductionVarOwner(iv);
215     assert(loop && "Expected affine for");
216     if (failed(srcConstraints.addAffineForOpDomain(loop)))
217       return llvm::None;
218   }
219
220   // Create constraints for the slice using the dst loop nest information. We
221   // retrieve existing dst loops from the lbOperands.
222   SmallVector<Value, 8> consumerIVs;
223   for (Value lbOp : lbOperands[0])
224     if (getForInductionVarOwner(lbOp))
225       consumerIVs.push_back(lbOp);
226
227   // Add empty IV Values for those new loops that are not equalities and,
228   // therefore, are not yet materialized in the IR.
229   for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
230     consumerIVs.push_back(Value());
231
232   FlatAffineConstraints sliceConstraints;
233   sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
234                          /*numLocals=*/0, consumerIVs);
235
236   if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
237     return llvm::None;
238
239   if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
240     // Constraint dims are different. The integer set difference can't be
241     // computed so we don't know if the slice is maximal.
242     return llvm::None;
243
244   // Compute the difference between the src loop nest and the slice integer
245   // sets.
246   PresburgerSet srcSet(srcConstraints);
247   PresburgerSet sliceSet(sliceConstraints);
248   PresburgerSet diffSet = srcSet.subtract(sliceSet);
249   return diffSet.isIntegerEmpty();
250 }
251
252 unsigned MemRefRegion::getRank() const {
253   return memref.getType().cast<MemRefType>().getRank();
254 }
255
256 Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
257     SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
258     SmallVectorImpl<int64_t> *lbDivisors) const {
259   auto memRefType = memref.getType().cast<MemRefType>();
260   unsigned rank = memRefType.getRank();
261   if (shape)
262     shape->reserve(rank);
263
264   assert(rank == cst.getNumDimIds() && "inconsistent memref region");
265
266   // Use a copy of the region constraints that has upper/lower bounds for each
267   // memref dimension with static size added to guard against potential
268   // over-approximation from projection or union bounding box. We may not add
269   // this on the region itself since they might just be redundant constraints
270   // that will need non-trivials means to eliminate.
271   FlatAffineConstraints cstWithShapeBounds(cst);
272   for (unsigned r = 0; r < rank; r++) {
273     cstWithShapeBounds.addConstantLowerBound(r, 0);
274     int64_t dimSize = memRefType.getDimSize(r);
275     if (ShapedType::isDynamic(dimSize))
276       continue;
277     cstWithShapeBounds.addConstantUpperBound(r, dimSize - 1);
278   }
279
280   // Find a constant upper bound on the extent of this memref region along each
281   // dimension.
282   int64_t numElements = 1;
283   int64_t diffConstant;
284   int64_t lbDivisor;
285   for (unsigned d = 0; d < rank; d++) {
286     SmallVector<int64_t, 4> lb;
287     Optional<int64_t> diff =
288         cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
289     if (diff.hasValue()) {
290       diffConstant = diff.getValue();
291       assert(lbDivisor > 0);
292     } else {
293       // If no constant bound is found, then it can always be bound by the
294       // memref's dim size if the latter has a constant size along this dim.
295       auto dimSize = memRefType.getDimSize(d);
296       if (dimSize == -1)
297         return None;
298       diffConstant = dimSize;
299       // Lower bound becomes 0.
300       lb.resize(cstWithShapeBounds.getNumSymbolIds() + 1, 0);
301       lbDivisor = 1;
302     }
303     numElements *= diffConstant;
304     if (lbs) {
305       lbs->push_back(lb);
306       assert(lbDivisors && "both lbs and lbDivisor or none");
307       lbDivisors->push_back(lbDivisor);
308     }
309     if (shape) {
310       shape->push_back(diffConstant);
311     }
312   }
313   return numElements;
314 }
315
316 void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
317                                          AffineMap &ubMap) const {
318   assert(pos < cst.getNumDimIds() && "invalid position");
319   auto memRefType = memref.getType().cast<MemRefType>();
320   unsigned rank = memRefType.getRank();
321
322   assert(rank == cst.getNumDimIds() && "inconsistent memref region");
323
324   auto boundPairs = cst.getLowerAndUpperBound(
325       pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolIds(),
326       /*localExprs=*/{}, memRefType.getContext());
327   lbMap = boundPairs.first;
328   ubMap = boundPairs.second;
329   assert(lbMap && "lower bound for a region must exist");
330   assert(ubMap && "upper bound for a region must exist");
331   assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
332   assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
333 }
334
335 LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
336   assert(memref == other.memref);
337   return cst.unionBoundingBox(*other.getConstraints());
338 }
339
340 /// Computes the memory region accessed by this memref with the region
341 /// represented as constraints symbolic/parametric in 'loopDepth' loops
342 /// surrounding opInst and any additional Function symbols.
343 //  For example, the memref region for this load operation at loopDepth = 1 will
344 //  be as below:
345 //
346 //    affine.for %i = 0 to 32 {
347 //      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
348 //        load %A[%ii]
349 //      }
350 //    }
351 //
352 // region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
353 // The last field is a 2-d FlatAffineConstraints symbolic in %i.
354 //
355 // TODO: extend this to any other memref dereferencing ops
356 // (dma_start, dma_wait).
357 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
358                                     const ComputationSliceState *sliceState,
359                                     bool addMemRefDimBounds) {
360   assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
361          "affine read/write op expected");
362
363   MemRefAccess access(op);
364   memref = access.memref;
365   write = access.isStore();
366
367   unsigned rank = access.getRank();
368
369   LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
370                           << "depth: " << loopDepth << "\n";);
371
372   // 0-d memrefs.
373   if (rank == 0) {
374     SmallVector<AffineForOp, 4> ivs;
375     getLoopIVs(*op, &ivs);
376     assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
377     // The first 'loopDepth' IVs are symbols for this region.
378     ivs.resize(loopDepth);
379     SmallVector<Value, 4> regionSymbols;
380     extractForInductionVars(ivs, &regionSymbols);
381     // A 0-d memref has a 0-d region.
382     cst.reset(rank, loopDepth, /*numLocals=*/0, regionSymbols);
383     return success();
384   }
385
386   // Build the constraints for this region.
387   AffineValueMap accessValueMap;
388   access.getAccessMap(&accessValueMap);
389   AffineMap accessMap = accessValueMap.getAffineMap();
390
391   unsigned numDims = accessMap.getNumDims();
392   unsigned numSymbols = accessMap.getNumSymbols();
393   unsigned numOperands = accessValueMap.getNumOperands();
394   // Merge operands with slice operands.
395   SmallVector<Value, 4> operands;
396   operands.resize(numOperands);
397   for (unsigned i = 0; i < numOperands; ++i)
398     operands[i] = accessValueMap.getOperand(i);
399
400   if (sliceState != nullptr) {
401     operands.reserve(operands.size() + sliceState->lbOperands[0].size());
402     // Append slice operands to 'operands' as symbols.
403     for (auto extraOperand : sliceState->lbOperands[0]) {
404       if (!llvm::is_contained(operands, extraOperand)) {
405         operands.push_back(extraOperand);
406         numSymbols++;
407       }
408     }
409   }
410   // We'll first associate the dims and symbols of the access map to the dims
411   // and symbols resp. of cst. This will change below once cst is
412   // fully constructed out.
413   cst.reset(numDims, numSymbols, 0, operands);
414
415   // Add equality constraints.
416   // Add inequalities for loop lower/upper bounds.
417   for (unsigned i = 0; i < numDims + numSymbols; ++i) {
418     auto operand = operands[i];
419     if (auto loop = getForInductionVarOwner(operand)) {
420       // Note that cst can now have more dimensions than accessMap if the
421       // bounds expressions involve outer loops or other symbols.
422       // TODO: rewrite this to use getInstIndexSet; this way
423       // conditionals will be handled when the latter supports it.
424       if (failed(cst.addAffineForOpDomain(loop)))
425         return failure();
426     } else {
427       // Has to be a valid symbol.
428       auto symbol = operand;
429       assert(isValidSymbol(symbol));
430       // Check if the symbol is a constant.
431       if (auto *op = symbol.getDefiningOp()) {
432         if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
433           cst.setIdToConstant(symbol, constOp.getValue());
434         }
435       }
436     }
437   }
438
439   // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
440   if (sliceState != nullptr) {
441     // Add dim and symbol slice operands.
442     for (auto operand : sliceState->lbOperands[0]) {
443       cst.addInductionVarOrTerminalSymbol(operand);
444     }
445     // Add upper/lower bounds from 'sliceState' to 'cst'.
446     LogicalResult ret =
447         cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
448                            sliceState->lbOperands[0]);
449     assert(succeeded(ret) &&
450            "should not fail as we never have semi-affine slice maps");
451     (void)ret;
452   }
453
454   // Add access function equalities to connect loop IVs to data dimensions.
455   if (failed(cst.composeMap(&accessValueMap))) {
456     op->emitError("getMemRefRegion: compose affine map failed");
457     LLVM_DEBUG(accessValueMap.getAffineMap().dump());
458     return failure();
459   }
460
461   // Set all identifiers appearing after the first 'rank' identifiers as
462   // symbolic identifiers - so that the ones corresponding to the memref
463   // dimensions are the dimensional identifiers for the memref region.
464   cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
465
466   // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
467   // this memref region is symbolic.
468   SmallVector<AffineForOp, 4> enclosingIVs;
469   getLoopIVs(*op, &enclosingIVs);
470   assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
471   enclosingIVs.resize(loopDepth);
472   SmallVector<Value, 4> ids;
473   cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
474   for (auto id : ids) {
475     AffineForOp iv;
476     if ((iv = getForInductionVarOwner(id)) &&
477         llvm::is_contained(enclosingIVs, iv) == false) {
478       cst.projectOut(id);
479     }
480   }
481
482   // Project out any local variables (these would have been added for any
483   // mod/divs).
484   cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
485
486   // Constant fold any symbolic identifiers.
487   cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
488                           /*num=*/cst.getNumSymbolIds());
489
490   assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");
491
492   // Add upper/lower bounds for each memref dimension with static size
493   // to guard against potential over-approximation from projection.
494   // TODO: Support dynamic memref dimensions.
495   if (addMemRefDimBounds) {
496     auto memRefType = memref.getType().cast<MemRefType>();
497     for (unsigned r = 0; r < rank; r++) {
498       cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
499       if (memRefType.isDynamicDim(r))
500         continue;
501       cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
502     }
503   }
504   cst.removeTrivialRedundancy();
505
506   LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
507   LLVM_DEBUG(cst.dump());
508   return success();
509 }
510
511 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
512   auto elementType = memRefType.getElementType();
513
514   unsigned sizeInBits;
515   if (elementType.isIntOrFloat()) {
516     sizeInBits = elementType.getIntOrFloatBitWidth();
517   } else {
518     auto vectorType = elementType.cast<VectorType>();
519     sizeInBits =
520         vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
521   }
522   return llvm::divideCeil(sizeInBits, 8);
523 }
524
525 // Returns the size of the region.
526 Optional<int64_t> MemRefRegion::getRegionSize() {
527   auto memRefType = memref.getType().cast<MemRefType>();
528
529   auto layoutMaps = memRefType.getAffineMaps();
530   if (layoutMaps.size() > 1 ||
531       (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
532     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
533     return false;
534   }
535
536   // Indices to use for the DmaStart op.
537   // Indices for the original memref being DMAed from/to.
538   SmallVector<Value, 4> memIndices;
539   // Indices for the faster buffer being DMAed into/from.
540   SmallVector<Value, 4> bufIndices;
541
542   // Compute the extents of the buffer.
543   Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
544   if (!numElements.hasValue()) {
545     LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
546     return None;
547   }
548   return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
549 }
550
551 /// Returns the size of memref data in bytes if it's statically shaped, None
552 /// otherwise.  If the element of the memref has vector type, takes into account
553 /// size of the vector as well.
554 //  TODO: improve/complete this when we have target data.
555 Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
556   if (!memRefType.hasStaticShape())
557     return None;
558   auto elementType = memRefType.getElementType();
559   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
560     return None;
561
562   uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
563   for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
564     sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
565   }
566   return sizeInBytes;
567 }
568
569 template <typename LoadOrStoreOp>
570 LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
571                                             bool emitError) {
572   static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
573                                 AffineWriteOpInterface>::value,
574                 "argument should be either a AffineReadOpInterface or a "
575                 "AffineWriteOpInterface");
576
577   Operation *op = loadOrStoreOp.getOperation();
578   MemRefRegion region(op->getLoc());
579   if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
580                             /*addMemRefDimBounds=*/false)))
581     return success();
582
583   LLVM_DEBUG(llvm::dbgs() << "Memory region");
584   LLVM_DEBUG(region.getConstraints()->dump());
585
586   bool outOfBounds = false;
587   unsigned rank = loadOrStoreOp.getMemRefType().getRank();
588
589   // For each dimension, check for out of bounds.
590   for (unsigned r = 0; r < rank; r++) {
591     FlatAffineConstraints ucst(*region.getConstraints());
592
593     // Intersect memory region with constraint capturing out of bounds (both out
594     // of upper and out of lower), and check if the constraint system is
595     // feasible. If it is, there is at least one point out of bounds.
596     SmallVector<int64_t, 4> ineq(rank + 1, 0);
597     int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
598     // TODO: handle dynamic dim sizes.
599     if (dimSize == -1)
600       continue;
601
602     // Check for overflow: d_i >= memref dim size.
603     ucst.addConstantLowerBound(r, dimSize);
604     outOfBounds = !ucst.isEmpty();
605     if (outOfBounds && emitError) {
606       loadOrStoreOp.emitOpError()
607           << "memref out of upper bound access along dimension #" << (r + 1);
608     }
609
610     // Check for a negative index.
611     FlatAffineConstraints lcst(*region.getConstraints());
612     std::fill(ineq.begin(), ineq.end(), 0);
613     // d_i <= -1;
614     lcst.addConstantUpperBound(r, -1);
615     outOfBounds = !lcst.isEmpty();
616     if (outOfBounds && emitError) {
617       loadOrStoreOp.emitOpError()
618           << "memref out of lower bound access along dimension #" << (r + 1);
619     }
620   }
621   return failure(outOfBounds);
622 }
623
624 // Explicitly instantiate the template so that the compiler knows we need them!
625 template LogicalResult
626 mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
627 template LogicalResult
628 mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
629
630 // Returns in 'positions' the Block positions of 'op' in each ancestor
631 // Block from the Block containing operation, stopping at 'limitBlock'.
632 static void findInstPosition(Operation *op, Block *limitBlock,
633                              SmallVectorImpl<unsigned> *positions) {
634   Block *block = op->getBlock();
635   while (block != limitBlock) {
636     // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
637     // rely on linear scans.
638     int instPosInBlock = std::distance(block->begin(), op->getIterator());
639     positions->push_back(instPosInBlock);
640     op = block->getParentOp();
641     block = op->getBlock();
642   }
643   std::reverse(positions->begin(), positions->end());
644 }
645
646 // Returns the Operation in a possibly nested set of Blocks, where the
647 // position of the operation is represented by 'positions', which has a
648 // Block position for each level of nesting.
649 static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
650                                     unsigned level, Block *block) {
651   unsigned i = 0;
652   for (auto &op : *block) {
653     if (i != positions[level]) {
654       ++i;
655       continue;
656     }
657     if (level == positions.size() - 1)
658       return &op;
659     if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
660       return getInstAtPosition(positions, level + 1,
661                                childAffineForOp.getBody());
662
663     for (auto &region : op.getRegions()) {
664       for (auto &b : region)
665         if (auto *ret = getInstAtPosition(positions, level + 1, &b))
666           return ret;
667     }
668     return nullptr;
669   }
670   return nullptr;
671 }
672
673 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
674 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
675                                             FlatAffineConstraints *cst) {
676   for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
677     auto value = cst->getIdValue(i);
678     if (ivs.count(value) == 0) {
679       assert(isForInductionVar(value));
680       auto loop = getForInductionVarOwner(value);
681       if (failed(cst->addAffineForOpDomain(loop)))
682         return failure();
683     }
684   }
685   return success();
686 }
687
688 /// Returns the innermost common loop depth for the set of operations in 'ops'.
689 // TODO: Move this to LoopUtils.
690 unsigned mlir::getInnermostCommonLoopDepth(
691     ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
692   unsigned numOps = ops.size();
693   assert(numOps > 0 && "Expected at least one operation");
694
695   std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
696   unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
697   for (unsigned i = 0; i < numOps; ++i) {
698     getLoopIVs(*ops[i], &loops[i]);
699     loopDepthLimit =
700         std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
701   }
702
703   unsigned loopDepth = 0;
704   for (unsigned d = 0; d < loopDepthLimit; ++d) {
705     unsigned i;
706     for (i = 1; i < numOps; ++i) {
707       if (loops[i - 1][d] != loops[i][d])
708         return loopDepth;
709     }
710     if (surroundingLoops)
711       surroundingLoops->push_back(loops[i - 1][d]);
712     ++loopDepth;
713   }
714   return loopDepth;
715 }
716
717 /// Computes in 'sliceUnion' the union of all slice bounds computed at
718 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
719 /// Returns 'Success' if union was computed, 'failure' otherwise.
720 LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
721                                       ArrayRef<Operation *> opsB,
722                                       unsigned loopDepth,
723                                       unsigned numCommonLoops,
724                                       bool isBackwardSlice,
725                                       ComputationSliceState *sliceUnion) {
726   // Compute the union of slice bounds between all pairs in 'opsA' and
727   // 'opsB' in 'sliceUnionCst'.
728   FlatAffineConstraints sliceUnionCst;
729   assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
730   std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
731   for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
732     MemRefAccess srcAccess(opsA[i]);
733     for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
734       MemRefAccess dstAccess(opsB[j]);
735       if (srcAccess.memref != dstAccess.memref)
736         continue;
737       // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
738       if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
739           (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
740         LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
741         return failure();
742       }
743
744       bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
745                               isa<AffineReadOpInterface>(dstAccess.opInst);
746       FlatAffineConstraints dependenceConstraints;
747       // Check dependence between 'srcAccess' and 'dstAccess'.
748       DependenceResult result = checkMemrefAccessDependence(
749           srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
750           &dependenceConstraints, /*dependenceComponents=*/nullptr,
751           /*allowRAR=*/readReadAccesses);
752       if (result.value == DependenceResult::Failure) {
753         LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
754         return failure();
755       }
756       if (result.value == DependenceResult::NoDependence)
757         continue;
758       dependentOpPairs.push_back({opsA[i], opsB[j]});
759
760       // Compute slice bounds for 'srcAccess' and 'dstAccess'.
761       ComputationSliceState tmpSliceState;
762       mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
763                                      loopDepth, isBackwardSlice,
764                                      &tmpSliceState);
765
766       if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
767         // Initialize 'sliceUnionCst' with the bounds computed in previous step.
768         if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
769           LLVM_DEBUG(llvm::dbgs()
770                      << "Unable to compute slice bound constraints\n");
771           return failure();
772         }
773         assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
774         continue;
775       }
776
777       // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
778       FlatAffineConstraints tmpSliceCst;
779       if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
780         LLVM_DEBUG(llvm::dbgs()
781                    << "Unable to compute slice bound constraints\n");
782         return failure();
783       }
784
785       // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
786       if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
787
788         // Pre-constraint id alignment: record loop IVs used in each constraint
789         // system.
790         SmallPtrSet<Value, 8> sliceUnionIVs;
791         for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
792           sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
793         SmallPtrSet<Value, 8> tmpSliceIVs;
794         for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
795           tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
796
797         sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
798
799         // Post-constraint id alignment: add loop IV bounds missing after
800         // id alignment to constraint systems. This can occur if one constraint
801         // system uses an loop IV that is not used by the other. The call
802         // to unionBoundingBox below expects constraints for each Loop IV, even
803         // if they are the unsliced full loop bounds added here.
804         if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
805           return failure();
806         if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
807           return failure();
808       }
809       // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
810       if (sliceUnionCst.getNumLocalIds() > 0 ||
811           tmpSliceCst.getNumLocalIds() > 0 ||
812           failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
813         LLVM_DEBUG(llvm::dbgs()
814                    << "Unable to compute union bounding box of slice bounds\n");
815         return failure();
816       }
817     }
818   }
819
820   // Empty union.
821   if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
822     return failure();
823
824   // Gather loops surrounding ops from loop nest where slice will be inserted.
825   SmallVector<Operation *, 4> ops;
826   for (auto &dep : dependentOpPairs) {
827     ops.push_back(isBackwardSlice ? dep.second : dep.first);
828   }
829   SmallVector<AffineForOp, 4> surroundingLoops;
830   unsigned innermostCommonLoopDepth =
831       getInnermostCommonLoopDepth(ops, &surroundingLoops);
832   if (loopDepth > innermostCommonLoopDepth) {
833     LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
834     return failure();
835   }
836
837   // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
838   unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();
839
840   // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
841   sliceUnionCst.convertLoopIVSymbolsToDims();
842   sliceUnion->clearBounds();
843   sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
844   sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
845
846   // Get slice bounds from slice union constraints 'sliceUnionCst'.
847   sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
848                                opsA[0]->getContext(), &sliceUnion->lbs,
849                                &sliceUnion->ubs);
850
851   // Add slice bound operands of union.
852   SmallVector<Value, 4> sliceBoundOperands;
853   sliceUnionCst.getIdValues(numSliceLoopIVs,
854                             sliceUnionCst.getNumDimAndSymbolIds(),
855                             &sliceBoundOperands);
856
857   // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
858   sliceUnion->ivs.clear();
859   sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);
860
861   // Set loop nest insertion point to block start at 'loopDepth'.
862   sliceUnion->insertPoint =
863       isBackwardSlice
864           ? surroundingLoops[loopDepth - 1].getBody()->begin()
865           : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
866
867   // Give each bound its own copy of 'sliceBoundOperands' for subsequent
868   // canonicalization.
869   sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
870   sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
871   return success();
872 }
873
874 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
875 // Computes slice bounds by projecting out any loop IVs from
876 // 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
877 // bounds in 'sliceState' which represent the one loop nest's IVs in terms of
878 // the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
879 void mlir::getComputationSliceState(
880     Operation *depSourceOp, Operation *depSinkOp,
881     FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
882     bool isBackwardSlice, ComputationSliceState *sliceState) {
883   // Get loop nest surrounding src operation.
884   SmallVector<AffineForOp, 4> srcLoopIVs;
885   getLoopIVs(*depSourceOp, &srcLoopIVs);
886   unsigned numSrcLoopIVs = srcLoopIVs.size();
887
888   // Get loop nest surrounding dst operation.
889   SmallVector<AffineForOp, 4> dstLoopIVs;
890   getLoopIVs(*depSinkOp, &dstLoopIVs);
891   unsigned numDstLoopIVs = dstLoopIVs.size();
892
893   assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
894          (isBackwardSlice && loopDepth <= numDstLoopIVs));
895
896   // Project out dimensions other than those up to 'loopDepth'.
897   unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
898   unsigned num =
899       isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
900   dependenceConstraints->projectOut(pos, num);
901
902   // Add slice loop IV values to 'sliceState'.
903   unsigned offset = isBackwardSlice ? 0 : loopDepth;
904   unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
905   dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
906                                      &sliceState->ivs);
907
908   // Set up lower/upper bound affine maps for the slice.
909   sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
910   sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
911
912   // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
913   dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
914                                         depSourceOp->getContext(),
915                                         &sliceState->lbs, &sliceState->ubs);
916
917   // Set up bound operands for the slice's lower and upper bounds.
918   SmallVector<Value, 4> sliceBoundOperands;
919   unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
920   for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
921     if (i < offset || i >= offset + numSliceLoopIVs) {
922       sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
923     }
924   }
925
926   // Give each bound its own copy of 'sliceBoundOperands' for subsequent
927   // canonicalization.
928   sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
929   sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
930
931   // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
932   sliceState->insertPoint =
933       isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
934                       : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
935
936   llvm::SmallDenseSet<Value, 8> sequentialLoops;
937   if (isa<AffineReadOpInterface>(depSourceOp) &&
938       isa<AffineReadOpInterface>(depSinkOp)) {
939     // For read-read access pairs, clear any slice bounds on sequential loops.
940     // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
941     getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
942                        &sequentialLoops);
943   }
944   // Clear all sliced loop bounds beginning at the first sequential loop, or
945   // first loop with a slice fusion barrier attribute..
946   // TODO: Use MemRef read/write regions instead of
947   // using 'kSliceFusionBarrierAttrName'.
948   auto getSliceLoop = [&](unsigned i) {
949     return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
950   };
951   for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
952     Value iv = getSliceLoop(i).getInductionVar();
953     if (sequentialLoops.count(iv) == 0 &&
954         getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
955       continue;
956     for (unsigned j = i; j < numSliceLoopIVs; ++j) {
957       sliceState->lbs[j] = AffineMap();
958       sliceState->ubs[j] = AffineMap();
959     }
960     break;
961   }
962 }
963
964 /// Creates a computation slice of the loop nest surrounding 'srcOpInst',
965 /// updates the slice loop bounds with any non-null bound maps specified in
966 /// 'sliceState', and inserts this slice into the loop nest surrounding
967 /// 'dstOpInst' at loop depth 'dstLoopDepth'.
968 // TODO: extend the slicing utility to compute slices that
969 // aren't necessarily a one-to-one relation b/w the source and destination. The
970 // relation between the source and destination could be many-to-many in general.
971 // TODO: the slice computation is incorrect in the cases
972 // where the dependence from the source to the destination does not cover the
973 // entire destination index set. Subtract out the dependent destination
974 // iterations from destination index set and check for emptiness --- this is one
975 // solution.
976 AffineForOp
977 mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
978                                      unsigned dstLoopDepth,
979                                      ComputationSliceState *sliceState) {
980   // Get loop nest surrounding src operation.
981   SmallVector<AffineForOp, 4> srcLoopIVs;
982   getLoopIVs(*srcOpInst, &srcLoopIVs);
983   unsigned numSrcLoopIVs = srcLoopIVs.size();
984
985   // Get loop nest surrounding dst operation.
986   SmallVector<AffineForOp, 4> dstLoopIVs;
987   getLoopIVs(*dstOpInst, &dstLoopIVs);
988   unsigned dstLoopIVsSize = dstLoopIVs.size();
989   if (dstLoopDepth > dstLoopIVsSize) {
990     dstOpInst->emitError("invalid destination loop depth");
991     return AffineForOp();
992   }
993
994   // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
995   SmallVector<unsigned, 4> positions;
996   // TODO: This code is incorrect since srcLoopIVs can be 0-d.
997   findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
998
999   // Clone src loop nest and insert it a the beginning of the operation block
1000   // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
1001   auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
1002   OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
1003   auto sliceLoopNest =
1004       cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
1005
1006   Operation *sliceInst =
1007       getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
1008   // Get loop nest surrounding 'sliceInst'.
1009   SmallVector<AffineForOp, 4> sliceSurroundingLoops;
1010   getLoopIVs(*sliceInst, &sliceSurroundingLoops);
1011
1012   // Sanity check.
1013   unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
1014   (void)sliceSurroundingLoopsSize;
1015   assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
1016   unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
1017   (void)sliceLoopLimit;
1018   assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
1019
1020   // Update loop bounds for loops in 'sliceLoopNest'.
1021   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1022     auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
1023     if (AffineMap lbMap = sliceState->lbs[i])
1024       forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
1025     if (AffineMap ubMap = sliceState->ubs[i])
1026       forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
1027   }
1028   return sliceLoopNest;
1029 }
1030
1031 // Constructs  MemRefAccess populating it with the memref, its indices and
1032 // opinst from 'loadOrStoreOpInst'.
1033 MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1034   if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1035     memref = loadOp.getMemRef();
1036     opInst = loadOrStoreOpInst;
1037     auto loadMemrefType = loadOp.getMemRefType();
1038     indices.reserve(loadMemrefType.getRank());
1039     for (auto index : loadOp.getMapOperands()) {
1040       indices.push_back(index);
1041     }
1042   } else {
1043     assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1044            "Affine read/write op expected");
1045     auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1046     opInst = loadOrStoreOpInst;
1047     memref = storeOp.getMemRef();
1048     auto storeMemrefType = storeOp.getMemRefType();
1049     indices.reserve(storeMemrefType.getRank());
1050     for (auto index : storeOp.getMapOperands()) {
1051       indices.push_back(index);
1052     }
1053   }
1054 }
1055
1056 unsigned MemRefAccess::getRank() const {
1057   return memref.getType().cast<MemRefType>().getRank();
1058 }
1059
1060 bool MemRefAccess::isStore() const {
1061   return isa<AffineWriteOpInterface>(opInst);
1062 }
1063
1064 /// Returns the nesting depth of this statement, i.e., the number of loops
1065 /// surrounding this statement.
1066 unsigned mlir::getNestingDepth(Operation *op) {
1067   Operation *currOp = op;
1068   unsigned depth = 0;
1069   while ((currOp = currOp->getParentOp())) {
1070     if (isa<AffineForOp>(currOp))
1071       depth++;
1072   }
1073   return depth;
1074 }
1075
1076 /// Equal if both affine accesses are provably equivalent (at compile
1077 /// time) when considering the memref, the affine maps and their respective
1078 /// operands. The equality of access functions + operands is checked by
1079 /// subtracting fully composed value maps, and then simplifying the difference
1080 /// using the expression flattener.
1081 /// TODO: this does not account for aliasing of memrefs.
1082 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1083   if (memref != rhs.memref)
1084     return false;
1085
1086   AffineValueMap diff, thisMap, rhsMap;
1087   getAccessMap(&thisMap);
1088   rhs.getAccessMap(&rhsMap);
1089   AffineValueMap::difference(thisMap, rhsMap, &diff);
1090   return llvm::all_of(diff.getAffineMap().getResults(),
1091                       [](AffineExpr e) { return e == 0; });
1092 }
1093
1094 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
1095 /// where each lists loops from outer-most to inner-most in loop nest.
1096 unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
1097   SmallVector<AffineForOp, 4> loopsA, loopsB;
1098   getLoopIVs(A, &loopsA);
1099   getLoopIVs(B, &loopsB);
1100
1101   unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
1102   unsigned numCommonLoops = 0;
1103   for (unsigned i = 0; i < minNumLoops; ++i) {
1104     if (loopsA[i].getOperation() != loopsB[i].getOperation())
1105       break;
1106     ++numCommonLoops;
1107   }
1108   return numCommonLoops;
1109 }
1110
1111 static Optional<int64_t> getMemoryFootprintBytes(Block &block,
1112                                                  Block::iterator start,
1113                                                  Block::iterator end,
1114                                                  int memorySpace) {
1115   SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
1116
1117   // Walk this 'affine.for' operation to gather all memory regions.
1118   auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
1119     if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
1120       // Neither load nor a store op.
1121       return WalkResult::advance();
1122     }
1123
1124     // Compute the memref region symbolic in any IVs enclosing this block.
1125     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
1126     if (failed(
1127             region->compute(opInst,
1128                             /*loopDepth=*/getNestingDepth(&*block.begin())))) {
1129       return opInst->emitError("error obtaining memory region\n");
1130     }
1131
1132     auto it = regions.find(region->memref);
1133     if (it == regions.end()) {
1134       regions[region->memref] = std::move(region);
1135     } else if (failed(it->second->unionBoundingBox(*region))) {
1136       return opInst->emitWarning(
1137           "getMemoryFootprintBytes: unable to perform a union on a memory "
1138           "region");
1139     }
1140     return WalkResult::advance();
1141   });
1142   if (result.wasInterrupted())
1143     return None;
1144
1145   int64_t totalSizeInBytes = 0;
1146   for (const auto &region : regions) {
1147     Optional<int64_t> size = region.second->getRegionSize();
1148     if (!size.hasValue())
1149       return None;
1150     totalSizeInBytes += size.getValue();
1151   }
1152   return totalSizeInBytes;
1153 }
1154
1155 Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
1156                                                 int memorySpace) {
1157   auto *forInst = forOp.getOperation();
1158   return ::getMemoryFootprintBytes(
1159       *forInst->getBlock(), Block::iterator(forInst),
1160       std::next(Block::iterator(forInst)), memorySpace);
1161 }
1162
1163 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
1164 /// at 'forOp'.
1165 void mlir::getSequentialLoops(AffineForOp forOp,
1166                               llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1167   forOp->walk([&](Operation *op) {
1168     if (auto innerFor = dyn_cast<AffineForOp>(op))
1169       if (!isLoopParallel(innerFor))
1170         sequentialLoops->insert(innerFor.getInductionVar());
1171   });
1172 }
1173
1174 /// Returns true if 'forOp' is parallel.
1175 bool mlir::isLoopParallel(AffineForOp forOp) {
1176   // Collect all load and store ops in loop nest rooted at 'forOp'.
1177   SmallVector<Operation *, 8> loadAndStoreOpInsts;
1178   auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
1179     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
1180       loadAndStoreOpInsts.push_back(opInst);
1181     else if (!isa<AffineForOp, AffineYieldOp, AffineIfOp>(opInst) &&
1182              !MemoryEffectOpInterface::hasNoEffect(opInst))
1183       return WalkResult::interrupt();
1184
1185     return WalkResult::advance();
1186   });
1187
1188   // Stop early if the loop has unknown ops with side effects.
1189   if (walkResult.wasInterrupted())
1190     return false;
1191
1192   // Dep check depth would be number of enclosing loops + 1.
1193   unsigned depth = getNestingDepth(forOp) + 1;
1194
1195   // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
1196   for (auto *srcOpInst : loadAndStoreOpInsts) {
1197     MemRefAccess srcAccess(srcOpInst);
1198     for (auto *dstOpInst : loadAndStoreOpInsts) {
1199       MemRefAccess dstAccess(dstOpInst);
1200       FlatAffineConstraints dependenceConstraints;
1201       DependenceResult result = checkMemrefAccessDependence(
1202           srcAccess, dstAccess, depth, &dependenceConstraints,
1203           /*dependenceComponents=*/nullptr);
1204       if (result.value != DependenceResult::NoDependence)
1205         return false;
1206     }
1207   }
1208   return true;
1209 }
1210
1211 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
1212   FlatAffineConstraints fac(set);
1213   if (fac.isEmpty())
1214     return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
1215                                    set.getContext());
1216   fac.removeTrivialRedundancy();
1217
1218   auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
1219   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1220   return simplifiedSet;
1221 }