Revert "[mlir][Affine] Add support for multi-store producer fusion"
[lldb.git] / mlir / lib / Analysis / AffineStructures.cpp
1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/Support/LLVM.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/raw_ostream.h"
27
28 #define DEBUG_TYPE "affine-structures"
29
30 using namespace mlir;
31 using llvm::SmallDenseMap;
32 using llvm::SmallDenseSet;
33
34 namespace {
35
36 // See comments for SimpleAffineExprFlattener.
37 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
38 // constraint information associated with mod's, floordiv's, and ceildiv's
39 // in FlatAffineConstraints 'localVarCst'.
40 struct AffineExprFlattener : public SimpleAffineExprFlattener {
41 public:
42   // Constraints connecting newly introduced local variables (for mod's and
43   // div's) to existing (dimensional and symbolic) ones. These are always
44   // inequalities.
45   FlatAffineConstraints localVarCst;
46
47   AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
48       : SimpleAffineExprFlattener(nDims, nSymbols) {
49     localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
50   }
51
52 private:
53   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
54   // The local identifier added is always a floordiv of a pure add/mul affine
55   // function of other identifiers, coefficients of which are specified in
56   // `dividend' and with respect to the positive constant `divisor'. localExpr
57   // is the simplified tree expression (AffineExpr) corresponding to the
58   // quantifier.
59   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
60                           AffineExpr localExpr) override {
61     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
62     // Update localVarCst.
63     localVarCst.addLocalFloorDiv(dividend, divisor);
64   }
65 };
66
67 } // end anonymous namespace
68
69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
70 // flattened (i.e., semi-affine expressions not handled yet).
71 static LogicalResult
72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
73                         unsigned numSymbols,
74                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
75                         FlatAffineConstraints *localVarCst) {
76   if (exprs.empty()) {
77     localVarCst->reset(numDims, numSymbols);
78     return success();
79   }
80
81   AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
82   // Use the same flattener to simplify each expression successively. This way
83   // local identifiers / expressions are shared.
84   for (auto expr : exprs) {
85     if (!expr.isPureAffine())
86       return failure();
87
88     flattener.walkPostOrder(expr);
89   }
90
91   assert(flattener.operandExprStack.size() == exprs.size());
92   flattenedExprs->clear();
93   flattenedExprs->assign(flattener.operandExprStack.begin(),
94                          flattener.operandExprStack.end());
95
96   if (localVarCst)
97     localVarCst->clearAndCopyFrom(flattener.localVarCst);
98
99   return success();
100 }
101
102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
103 // be flattened (semi-affine expressions not handled yet).
104 LogicalResult
105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
106                              unsigned numSymbols,
107                              SmallVectorImpl<int64_t> *flattenedExpr,
108                              FlatAffineConstraints *localVarCst) {
109   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
110   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
111                                                 &flattenedExprs, localVarCst);
112   *flattenedExpr = flattenedExprs[0];
113   return ret;
114 }
115
116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
117 /// flattened (i.e., semi-affine expressions not handled yet).
118 LogicalResult mlir::getFlattenedAffineExprs(
119     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
120     FlatAffineConstraints *localVarCst) {
121   if (map.getNumResults() == 0) {
122     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
123     return success();
124   }
125   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
126                                    map.getNumSymbols(), flattenedExprs,
127                                    localVarCst);
128 }
129
130 LogicalResult mlir::getFlattenedAffineExprs(
131     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
132     FlatAffineConstraints *localVarCst) {
133   if (set.getNumConstraints() == 0) {
134     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
135     return success();
136   }
137   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
138                                    set.getNumSymbols(), flattenedExprs,
139                                    localVarCst);
140 }
141
142 //===----------------------------------------------------------------------===//
143 // FlatAffineConstraints.
144 //===----------------------------------------------------------------------===//
145
146 // Copy constructor.
147 FlatAffineConstraints::FlatAffineConstraints(
148     const FlatAffineConstraints &other) {
149   numReservedCols = other.numReservedCols;
150   numDims = other.getNumDimIds();
151   numSymbols = other.getNumSymbolIds();
152   numIds = other.getNumIds();
153
154   auto otherIds = other.getIds();
155   ids.reserve(numReservedCols);
156   ids.append(otherIds.begin(), otherIds.end());
157
158   unsigned numReservedEqualities = other.getNumReservedEqualities();
159   unsigned numReservedInequalities = other.getNumReservedInequalities();
160
161   equalities.reserve(numReservedEqualities * numReservedCols);
162   inequalities.reserve(numReservedInequalities * numReservedCols);
163
164   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
165     addInequality(other.getInequality(r));
166   }
167   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
168     addEquality(other.getEquality(r));
169   }
170 }
171
172 // Clones this object.
173 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
174   return std::make_unique<FlatAffineConstraints>(*this);
175 }
176
177 // Construct from an IntegerSet.
178 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
179     : numReservedCols(set.getNumInputs() + 1),
180       numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
181       numSymbols(set.getNumSymbols()) {
182   equalities.reserve(set.getNumEqualities() * numReservedCols);
183   inequalities.reserve(set.getNumInequalities() * numReservedCols);
184   ids.resize(numIds, None);
185
186   // Flatten expressions and add them to the constraint system.
187   std::vector<SmallVector<int64_t, 8>> flatExprs;
188   FlatAffineConstraints localVarCst;
189   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
190     assert(false && "flattening unimplemented for semi-affine integer sets");
191     return;
192   }
193   assert(flatExprs.size() == set.getNumConstraints());
194   for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
195     addLocalId(getNumLocalIds());
196   }
197
198   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
199     const auto &flatExpr = flatExprs[i];
200     assert(flatExpr.size() == getNumCols());
201     if (set.getEqFlags()[i]) {
202       addEquality(flatExpr);
203     } else {
204       addInequality(flatExpr);
205     }
206   }
207   // Add the other constraints involving local id's from flattening.
208   append(localVarCst);
209 }
210
211 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
212                                   unsigned numReservedEqualities,
213                                   unsigned newNumReservedCols,
214                                   unsigned newNumDims, unsigned newNumSymbols,
215                                   unsigned newNumLocals,
216                                   ArrayRef<Value> idArgs) {
217   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
218          "minimum 1 column");
219   numReservedCols = newNumReservedCols;
220   numDims = newNumDims;
221   numSymbols = newNumSymbols;
222   numIds = numDims + numSymbols + newNumLocals;
223   assert(idArgs.empty() || idArgs.size() == numIds);
224
225   clearConstraints();
226   if (numReservedEqualities >= 1)
227     equalities.reserve(newNumReservedCols * numReservedEqualities);
228   if (numReservedInequalities >= 1)
229     inequalities.reserve(newNumReservedCols * numReservedInequalities);
230   if (idArgs.empty()) {
231     ids.resize(numIds, None);
232   } else {
233     ids.assign(idArgs.begin(), idArgs.end());
234   }
235 }
236
237 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
238                                   unsigned newNumLocals,
239                                   ArrayRef<Value> idArgs) {
240   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
241         newNumSymbols, newNumLocals, idArgs);
242 }
243
244 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
245   assert(other.getNumCols() == getNumCols());
246   assert(other.getNumDimIds() == getNumDimIds());
247   assert(other.getNumSymbolIds() == getNumSymbolIds());
248
249   inequalities.reserve(inequalities.size() +
250                        other.getNumInequalities() * numReservedCols);
251   equalities.reserve(equalities.size() +
252                      other.getNumEqualities() * numReservedCols);
253
254   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
255     addInequality(other.getInequality(r));
256   }
257   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
258     addEquality(other.getEquality(r));
259   }
260 }
261
262 void FlatAffineConstraints::addLocalId(unsigned pos) {
263   addId(IdKind::Local, pos);
264 }
265
266 void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
267   addId(IdKind::Dimension, pos, id);
268 }
269
270 void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
271   addId(IdKind::Symbol, pos, id);
272 }
273
274 /// Adds a dimensional identifier. The added column is initialized to
275 /// zero.
276 void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
277   if (kind == IdKind::Dimension)
278     assert(pos <= getNumDimIds());
279   else if (kind == IdKind::Symbol)
280     assert(pos <= getNumSymbolIds());
281   else
282     assert(pos <= getNumLocalIds());
283
284   unsigned oldNumReservedCols = numReservedCols;
285
286   // Check if a resize is necessary.
287   if (getNumCols() + 1 > numReservedCols) {
288     equalities.resize(getNumEqualities() * (getNumCols() + 1));
289     inequalities.resize(getNumInequalities() * (getNumCols() + 1));
290     numReservedCols++;
291   }
292
293   int absolutePos;
294
295   if (kind == IdKind::Dimension) {
296     absolutePos = pos;
297     numDims++;
298   } else if (kind == IdKind::Symbol) {
299     absolutePos = pos + getNumDimIds();
300     numSymbols++;
301   } else {
302     absolutePos = pos + getNumDimIds() + getNumSymbolIds();
303   }
304   numIds++;
305
306   // Note that getNumCols() now will already return the new size, which will be
307   // at least one.
308   int numInequalities = static_cast<int>(getNumInequalities());
309   int numEqualities = static_cast<int>(getNumEqualities());
310   int numCols = static_cast<int>(getNumCols());
311   for (int r = numInequalities - 1; r >= 0; r--) {
312     for (int c = numCols - 2; c >= 0; c--) {
313       if (c < absolutePos)
314         atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
315       else
316         atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
317     }
318     atIneq(r, absolutePos) = 0;
319   }
320
321   for (int r = numEqualities - 1; r >= 0; r--) {
322     for (int c = numCols - 2; c >= 0; c--) {
323       // All values in column absolutePositions < absolutePos have the same
324       // coordinates in the 2-d view of the coefficient buffer.
325       if (c < absolutePos)
326         atEq(r, c) = equalities[r * oldNumReservedCols + c];
327       else
328         // Those at absolutePosition >= absolutePos, get a shifted
329         // absolutePosition.
330         atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
331     }
332     // Initialize added dimension to zero.
333     atEq(r, absolutePos) = 0;
334   }
335
336   // If an 'id' is provided, insert it; otherwise use None.
337   if (id)
338     ids.insert(ids.begin() + absolutePos, id);
339   else
340     ids.insert(ids.begin() + absolutePos, None);
341   assert(ids.size() == getNumIds());
342 }
343
344 /// Checks if two constraint systems are in the same space, i.e., if they are
345 /// associated with the same set of identifiers, appearing in the same order.
346 static bool areIdsAligned(const FlatAffineConstraints &A,
347                           const FlatAffineConstraints &B) {
348   return A.getNumDimIds() == B.getNumDimIds() &&
349          A.getNumSymbolIds() == B.getNumSymbolIds() &&
350          A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
351 }
352
353 /// Calls areIdsAligned to check if two constraint systems have the same set
354 /// of identifiers in the same order.
355 bool FlatAffineConstraints::areIdsAlignedWithOther(
356     const FlatAffineConstraints &other) {
357   return areIdsAligned(*this, other);
358 }
359
360 /// Checks if the SSA values associated with `cst''s identifiers are unique.
361 static bool LLVM_ATTRIBUTE_UNUSED
362 areIdsUnique(const FlatAffineConstraints &cst) {
363   SmallPtrSet<Value, 8> uniqueIds;
364   for (auto id : cst.getIds()) {
365     if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
366       return false;
367   }
368   return true;
369 }
370
371 /// Merge and align the identifiers of A and B starting at 'offset', so that
372 /// both constraint systems get the union of the contained identifiers that is
373 /// dimension-wise and symbol-wise unique; both constraint systems are updated
374 /// so that they have the union of all identifiers, with A's original
375 /// identifiers appearing first followed by any of B's identifiers that didn't
376 /// appear in A. Local identifiers of each system are by design separate/local
377 /// and are placed one after other (A's followed by B's).
378 //  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
379 //      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
380 //
381 static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
382                              FlatAffineConstraints *B) {
383   assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
384   // A merge/align isn't meaningful if a cst's ids aren't distinct.
385   assert(areIdsUnique(*A) && "A's id values aren't unique");
386   assert(areIdsUnique(*B) && "B's id values aren't unique");
387
388   assert(std::all_of(A->getIds().begin() + offset,
389                      A->getIds().begin() + A->getNumDimAndSymbolIds(),
390                      [](Optional<Value> id) { return id.hasValue(); }));
391
392   assert(std::all_of(B->getIds().begin() + offset,
393                      B->getIds().begin() + B->getNumDimAndSymbolIds(),
394                      [](Optional<Value> id) { return id.hasValue(); }));
395
396   // Place local id's of A after local id's of B.
397   for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
398     B->addLocalId(0);
399   }
400   for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
401        t++) {
402     A->addLocalId(A->getNumLocalIds());
403   }
404
405   SmallVector<Value, 4> aDimValues, aSymValues;
406   A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
407   A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
408   {
409     // Merge dims from A into B.
410     unsigned d = offset;
411     for (auto aDimValue : aDimValues) {
412       unsigned loc;
413       if (B->findId(aDimValue, &loc)) {
414         assert(loc >= offset && "A's dim appears in B's aligned range");
415         assert(loc < B->getNumDimIds() &&
416                "A's dim appears in B's non-dim position");
417         B->swapId(d, loc);
418       } else {
419         B->addDimId(d);
420         B->setIdValue(d, aDimValue);
421       }
422       d++;
423     }
424
425     // Dimensions that are in B, but not in A, are added at the end.
426     for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
427       A->addDimId(A->getNumDimIds());
428       A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
429     }
430   }
431   {
432     // Merge symbols: merge A's symbols into B first.
433     unsigned s = B->getNumDimIds();
434     for (auto aSymValue : aSymValues) {
435       unsigned loc;
436       if (B->findId(aSymValue, &loc)) {
437         assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
438                "A's symbol appears in B's non-symbol position");
439         B->swapId(s, loc);
440       } else {
441         B->addSymbolId(s - B->getNumDimIds());
442         B->setIdValue(s, aSymValue);
443       }
444       s++;
445     }
446     // Symbols that are in B, but not in A, are added at the end.
447     for (unsigned t = A->getNumDimAndSymbolIds(),
448                   e = B->getNumDimAndSymbolIds();
449          t < e; t++) {
450       A->addSymbolId(A->getNumSymbolIds());
451       A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
452     }
453   }
454   assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
455 }
456
457 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
458 void FlatAffineConstraints::mergeAndAlignIdsWithOther(
459     unsigned offset, FlatAffineConstraints *other) {
460   mergeAndAlignIds(offset, this, other);
461 }
462
463 // This routine may add additional local variables if the flattened expression
464 // corresponding to the map has such variables due to mod's, ceildiv's, and
465 // floordiv's in it.
466 LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
467   std::vector<SmallVector<int64_t, 8>> flatExprs;
468   FlatAffineConstraints localCst;
469   if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
470                                      &localCst))) {
471     LLVM_DEBUG(llvm::dbgs()
472                << "composition unimplemented for semi-affine maps\n");
473     return failure();
474   }
475   assert(flatExprs.size() == vMap->getNumResults());
476
477   // Add localCst information.
478   if (localCst.getNumLocalIds() > 0) {
479     localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
480                          /*values=*/vMap->getOperands());
481     // Align localCst and this.
482     mergeAndAlignIds(/*offset=*/0, &localCst, this);
483     // Finally, append localCst to this constraint set.
484     append(localCst);
485   }
486
487   // Add dimensions corresponding to the map's results.
488   for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
489     // TODO: Consider using a batched version to add a range of IDs.
490     addDimId(0);
491   }
492
493   // We add one equality for each result connecting the result dim of the map to
494   // the other identifiers.
495   // For eg: if the expression is 16*i0 + i1, and this is the r^th
496   // iteration/result of the value map, we are adding the equality:
497   //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
498   //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
499   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
500     const auto &flatExpr = flatExprs[r];
501     assert(flatExpr.size() >= vMap->getNumOperands() + 1);
502
503     // eqToAdd is the equality corresponding to the flattened affine expression.
504     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
505     // Set the coefficient for this result to one.
506     eqToAdd[r] = 1;
507
508     // Dims and symbols.
509     for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
510       unsigned loc;
511       bool ret = findId(vMap->getOperand(i), &loc);
512       assert(ret && "value map's id can't be found");
513       (void)ret;
514       // Negate 'eq[r]' since the newly added dimension will be set to this one.
515       eqToAdd[loc] = -flatExpr[i];
516     }
517     // Local vars common to eq and localCst are at the beginning.
518     unsigned j = getNumDimIds() + getNumSymbolIds();
519     unsigned end = flatExpr.size() - 1;
520     for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
521       eqToAdd[j] = -flatExpr[i];
522     }
523
524     // Constant term.
525     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
526
527     // Add the equality connecting the result of the map to this constraint set.
528     addEquality(eqToAdd);
529   }
530
531   return success();
532 }
533
534 // Similar to composeMap except that no Value's need be associated with the
535 // constraint system nor are they looked at -- since the dimensions and
536 // symbols of 'other' are expected to correspond 1:1 to 'this' system. It
537 // is thus not convenient to share code with composeMap.
538 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
539   assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
540   assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
541
542   std::vector<SmallVector<int64_t, 8>> flatExprs;
543   FlatAffineConstraints localCst;
544   if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
545     LLVM_DEBUG(llvm::dbgs()
546                << "composition unimplemented for semi-affine maps\n");
547     return failure();
548   }
549   assert(flatExprs.size() == other.getNumResults());
550
551   // Add localCst information.
552   if (localCst.getNumLocalIds() > 0) {
553     // Place local id's of A after local id's of B.
554     for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
555       addLocalId(0);
556     }
557     // Finally, append localCst to this constraint set.
558     append(localCst);
559   }
560
561   // Add dimensions corresponding to the map's results.
562   for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
563     addDimId(0);
564   }
565
566   // We add one equality for each result connecting the result dim of the map to
567   // the other identifiers.
568   // For eg: if the expression is 16*i0 + i1, and this is the r^th
569   // iteration/result of the value map, we are adding the equality:
570   //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
571   //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
572   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
573     const auto &flatExpr = flatExprs[r];
574     assert(flatExpr.size() >= other.getNumInputs() + 1);
575
576     // eqToAdd is the equality corresponding to the flattened affine expression.
577     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
578     // Set the coefficient for this result to one.
579     eqToAdd[r] = 1;
580
581     // Dims and symbols.
582     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
583       // Negate 'eq[r]' since the newly added dimension will be set to this one.
584       eqToAdd[e + i] = -flatExpr[i];
585     }
586     // Local vars common to eq and localCst are at the beginning.
587     unsigned j = getNumDimIds() + getNumSymbolIds();
588     unsigned end = flatExpr.size() - 1;
589     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
590       eqToAdd[j] = -flatExpr[i];
591     }
592
593     // Constant term.
594     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
595
596     // Add the equality connecting the result of the map to this constraint set.
597     addEquality(eqToAdd);
598   }
599
600   return success();
601 }
602
603 // Turn a dimension into a symbol.
604 static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
605   unsigned pos;
606   if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
607     cst->swapId(pos, cst->getNumDimIds() - 1);
608     cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
609   }
610 }
611
612 // Turn a symbol into a dimension.
613 static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
614   unsigned pos;
615   if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
616       pos < cst->getNumDimAndSymbolIds()) {
617     cst->swapId(pos, cst->getNumDimIds());
618     cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
619   }
620 }
621
622 // Changes all symbol identifiers which are loop IVs to dim identifiers.
623 void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
624   // Gather all symbols which are loop IVs.
625   SmallVector<Value, 4> loopIVs;
626   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
627     if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
628       loopIVs.push_back(ids[i].getValue());
629   }
630   // Turn each symbol in 'loopIVs' into a dim identifier.
631   for (auto iv : loopIVs) {
632     turnSymbolIntoDim(this, iv);
633   }
634 }
635
636 void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
637   if (containsId(id))
638     return;
639
640   // Caller is expected to fully compose map/operands if necessary.
641   assert((isTopLevelValue(id) || isForInductionVar(id)) &&
642          "non-terminal symbol / loop IV expected");
643   // Outer loop IVs could be used in forOp's bounds.
644   if (auto loop = getForInductionVarOwner(id)) {
645     addDimId(getNumDimIds(), id);
646     if (failed(this->addAffineForOpDomain(loop)))
647       LLVM_DEBUG(
648           loop.emitWarning("failed to add domain info to constraint system"));
649     return;
650   }
651   // Add top level symbol.
652   addSymbolId(getNumSymbolIds(), id);
653   // Check if the symbol is a constant.
654   if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
655     setIdToConstant(id, constOp.getValue());
656 }
657
658 LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
659   unsigned pos;
660   // Pre-condition for this method.
661   if (!findId(forOp.getInductionVar(), &pos)) {
662     assert(false && "Value not found");
663     return failure();
664   }
665
666   int64_t step = forOp.getStep();
667   if (step != 1) {
668     if (!forOp.hasConstantLowerBound())
669       forOp.emitWarning("domain conservatively approximated");
670     else {
671       // Add constraints for the stride.
672       // (iv - lb) % step = 0 can be written as:
673       // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
674       // Add local variable 'q' and add the above equality.
675       // The first constraint is q = (iv - lb) floordiv step
676       SmallVector<int64_t, 8> dividend(getNumCols(), 0);
677       int64_t lb = forOp.getConstantLowerBound();
678       dividend[pos] = 1;
679       dividend.back() -= lb;
680       addLocalFloorDiv(dividend, step);
681       // Second constraint: (iv - lb) - step * q = 0.
682       SmallVector<int64_t, 8> eq(getNumCols(), 0);
683       eq[pos] = 1;
684       eq.back() -= lb;
685       // For the local var just added above.
686       eq[getNumCols() - 2] = -step;
687       addEquality(eq);
688     }
689   }
690
691   if (forOp.hasConstantLowerBound()) {
692     addConstantLowerBound(pos, forOp.getConstantLowerBound());
693   } else {
694     // Non-constant lower bound case.
695     if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
696                                     forOp.getLowerBoundOperands(),
697                                     /*eq=*/false, /*lower=*/true)))
698       return failure();
699   }
700
701   if (forOp.hasConstantUpperBound()) {
702     addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
703     return success();
704   }
705   // Non-constant upper bound case.
706   return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
707                               forOp.getUpperBoundOperands(),
708                               /*eq=*/false, /*lower=*/false);
709 }
710
711 void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
712   // Create the base constraints from the integer set attached to ifOp.
713   FlatAffineConstraints cst(ifOp.getIntegerSet());
714
715   // Bind ids in the constraints to ifOp operands.
716   SmallVector<Value, 4> operands = ifOp.getOperands();
717   cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
718
719   // Merge the constraints from ifOp to the current domain. We need first merge
720   // and align the IDs from both constraints, and then append the constraints
721   // from the ifOp into the current one.
722   mergeAndAlignIdsWithOther(0, &cst);
723   append(cst);
724 }
725
726 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
727 // equality (isEq=true) or inequality (isEq=false) constraints.
728 // Returns true and sets row found in search in 'rowIdx'.
729 // Returns false otherwise.
730 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
731                                         unsigned colIdx, bool isEq,
732                                         unsigned *rowIdx) {
733   assert(colIdx < cst.getNumCols() && "position out of bounds");
734   auto at = [&](unsigned rowIdx) -> int64_t {
735     return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
736   };
737   unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
738   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
739     if (at(*rowIdx) != 0) {
740       return true;
741     }
742   }
743   return false;
744 }
745
746 // Normalizes the coefficient values across all columns in 'rowIDx' by their
747 // GCD in equality or inequality constraints as specified by 'isEq'.
748 template <bool isEq>
749 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
750                                      unsigned rowIdx) {
751   auto at = [&](unsigned colIdx) -> int64_t {
752     return isEq ? constraints->atEq(rowIdx, colIdx)
753                 : constraints->atIneq(rowIdx, colIdx);
754   };
755   uint64_t gcd = std::abs(at(0));
756   for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
757     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
758   }
759   if (gcd > 0 && gcd != 1) {
760     for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
761       int64_t v = at(j) / static_cast<int64_t>(gcd);
762       isEq ? constraints->atEq(rowIdx, j) = v
763            : constraints->atIneq(rowIdx, j) = v;
764     }
765   }
766 }
767
768 void FlatAffineConstraints::normalizeConstraintsByGCD() {
769   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
770     normalizeConstraintByGCD</*isEq=*/true>(this, i);
771   }
772   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
773     normalizeConstraintByGCD</*isEq=*/false>(this, i);
774   }
775 }
776
777 bool FlatAffineConstraints::hasConsistentState() const {
778   if (inequalities.size() != getNumInequalities() * numReservedCols)
779     return false;
780   if (equalities.size() != getNumEqualities() * numReservedCols)
781     return false;
782   if (ids.size() != getNumIds())
783     return false;
784
785   // Catches errors where numDims, numSymbols, numIds aren't consistent.
786   if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
787     return false;
788
789   return true;
790 }
791
792 /// Checks all rows of equality/inequality constraints for trivial
793 /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
794 /// after elimination. Returns 'true' if an invalid constraint is found;
795 /// 'false' otherwise.
796 bool FlatAffineConstraints::hasInvalidConstraint() const {
797   assert(hasConsistentState());
798   auto check = [&](bool isEq) -> bool {
799     unsigned numCols = getNumCols();
800     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
801     for (unsigned i = 0, e = numRows; i < e; ++i) {
802       unsigned j;
803       for (j = 0; j < numCols - 1; ++j) {
804         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
805         // Skip rows with non-zero variable coefficients.
806         if (v != 0)
807           break;
808       }
809       if (j < numCols - 1) {
810         continue;
811       }
812       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
813       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
814       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
815       if ((isEq && v != 0) || (!isEq && v < 0)) {
816         return true;
817       }
818     }
819     return false;
820   };
821   if (check(/*isEq=*/true))
822     return true;
823   return check(/*isEq=*/false);
824 }
825
826 // Eliminate identifier from constraint at 'rowIdx' based on coefficient at
827 // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
828 // updated as they have already been eliminated.
829 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
830                                     unsigned rowIdx, unsigned pivotRow,
831                                     unsigned pivotCol, unsigned elimColStart,
832                                     bool isEq) {
833   // Skip if equality 'rowIdx' if same as 'pivotRow'.
834   if (isEq && rowIdx == pivotRow)
835     return;
836   auto at = [&](unsigned i, unsigned j) -> int64_t {
837     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
838   };
839   int64_t leadCoeff = at(rowIdx, pivotCol);
840   // Skip if leading coefficient at 'rowIdx' is already zero.
841   if (leadCoeff == 0)
842     return;
843   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
844   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
845   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
846   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
847   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
848
849   unsigned numCols = constraints->getNumCols();
850   for (unsigned j = 0; j < numCols; ++j) {
851     // Skip updating column 'j' if it was just eliminated.
852     if (j >= elimColStart && j < pivotCol)
853       continue;
854     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
855                 rowMultiplier * at(rowIdx, j);
856     isEq ? constraints->atEq(rowIdx, j) = v
857          : constraints->atIneq(rowIdx, j) = v;
858   }
859 }
860
861 // Remove coefficients in column range [colStart, colLimit) in place.
862 // This removes in data in the specified column range, and copies any
863 // remaining valid data into place.
864 static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
865                                unsigned colStart, unsigned colLimit,
866                                bool isEq) {
867   assert(colLimit <= constraints->getNumIds());
868   if (colLimit <= colStart)
869     return;
870
871   unsigned numCols = constraints->getNumCols();
872   unsigned numRows = isEq ? constraints->getNumEqualities()
873                           : constraints->getNumInequalities();
874   unsigned numToEliminate = colLimit - colStart;
875   for (unsigned r = 0, e = numRows; r < e; ++r) {
876     for (unsigned c = colLimit; c < numCols; ++c) {
877       if (isEq) {
878         constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
879       } else {
880         constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
881       }
882     }
883   }
884 }
885
886 // Removes identifiers in column range [idStart, idLimit), and copies any
887 // remaining valid data into place, and updates member variables.
888 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
889   assert(idLimit < getNumCols() && "invalid id limit");
890
891   if (idStart >= idLimit)
892     return;
893
894   // We are going to be removing one or more identifiers from the range.
895   assert(idStart < numIds && "invalid idStart position");
896
897   // TODO: Make 'removeIdRange' a lambda called from here.
898   // Remove eliminated identifiers from equalities.
899   shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
900
901   // Remove eliminated identifiers from inequalities.
902   shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
903
904   // Update members numDims, numSymbols and numIds.
905   unsigned numDimsEliminated = 0;
906   unsigned numLocalsEliminated = 0;
907   unsigned numColsEliminated = idLimit - idStart;
908   if (idStart < numDims) {
909     numDimsEliminated = std::min(numDims, idLimit) - idStart;
910   }
911   // Check how many local id's were removed. Note that our identifier order is
912   // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
913   if (idLimit > numDims + numSymbols) {
914     numLocalsEliminated = std::min(
915         idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
916   }
917   unsigned numSymbolsEliminated =
918       numColsEliminated - numDimsEliminated - numLocalsEliminated;
919
920   numDims -= numDimsEliminated;
921   numSymbols -= numSymbolsEliminated;
922   numIds = numIds - numColsEliminated;
923
924   ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
925
926   // No resize necessary. numReservedCols remains the same.
927 }
928
929 /// Returns the position of the identifier that has the minimum <number of lower
930 /// bounds> times <number of upper bounds> from the specified range of
931 /// identifiers [start, end). It is often best to eliminate in the increasing
932 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
933 /// that many new constraints.
934 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
935                                      unsigned start, unsigned end) {
936   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
937
938   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
939     unsigned numLb = 0;
940     unsigned numUb = 0;
941     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
942       if (cst.atIneq(r, pos) > 0) {
943         ++numLb;
944       } else if (cst.atIneq(r, pos) < 0) {
945         ++numUb;
946       }
947     }
948     return numLb * numUb;
949   };
950
951   unsigned minLoc = start;
952   unsigned min = getProductOfNumLowerUpperBounds(start);
953   for (unsigned c = start + 1; c < end; c++) {
954     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
955     if (numLbUbProduct < min) {
956       min = numLbUbProduct;
957       minLoc = c;
958     }
959   }
960   return minLoc;
961 }
962
963 // Checks for emptiness of the set by eliminating identifiers successively and
964 // using the GCD test (on all equality constraints) and checking for trivially
965 // invalid constraints. Returns 'true' if the constraint system is found to be
966 // empty; false otherwise.
967 bool FlatAffineConstraints::isEmpty() const {
968   if (isEmptyByGCDTest() || hasInvalidConstraint())
969     return true;
970
971   // First, eliminate as many identifiers as possible using Gaussian
972   // elimination.
973   FlatAffineConstraints tmpCst(*this);
974   unsigned currentPos = 0;
975   while (currentPos < tmpCst.getNumIds()) {
976     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
977     ++currentPos;
978     // We check emptiness through trivial checks after eliminating each ID to
979     // detect emptiness early. Since the checks isEmptyByGCDTest() and
980     // hasInvalidConstraint() are linear time and single sweep on the constraint
981     // buffer, this appears reasonable - but can optimize in the future.
982     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
983       return true;
984   }
985
986   // Eliminate the remaining using FM.
987   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
988     tmpCst.FourierMotzkinEliminate(
989         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
990     // Check for a constraint explosion. This rarely happens in practice, but
991     // this check exists as a safeguard against improperly constructed
992     // constraint systems or artificially created arbitrarily complex systems
993     // that aren't the intended use case for FlatAffineConstraints. This is
994     // needed since FM has a worst case exponential complexity in theory.
995     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
996       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
997       return false;
998     }
999
1000     // FM wouldn't have modified the equalities in any way. So no need to again
1001     // run GCD test. Check for trivial invalid constraints.
1002     if (tmpCst.hasInvalidConstraint())
1003       return true;
1004   }
1005   return false;
1006 }
1007
1008 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1009 // fails on any equality. Returns 'false' otherwise.
1010 // This test can be used to disprove the existence of a solution. If it returns
1011 // true, no integer solution to the equality constraints can exist.
1012 //
1013 // GCD test definition:
1014 //
1015 // The equality constraint:
1016 //
1017 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1018 //
1019 // has an integer solution iff:
1020 //
1021 //  GCD of c_1, c_2, ..., c_n divides c_0.
1022 //
1023 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1024   assert(hasConsistentState());
1025   unsigned numCols = getNumCols();
1026   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1027     uint64_t gcd = std::abs(atEq(i, 0));
1028     for (unsigned j = 1; j < numCols - 1; ++j) {
1029       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1030     }
1031     int64_t v = std::abs(atEq(i, numCols - 1));
1032     if (gcd > 0 && (v % gcd != 0)) {
1033       return true;
1034     }
1035   }
1036   return false;
1037 }
1038
1039 // Returns a matrix where each row is a vector along which the polytope is
1040 // bounded. The span of the returned vectors is guaranteed to contain all
1041 // such vectors. The returned vectors are NOT guaranteed to be linearly
1042 // independent. This function should not be called on empty sets.
1043 //
1044 // It is sufficient to check the perpendiculars of the constraints, as the set
1045 // of perpendiculars which are bounded must span all bounded directions.
1046 Matrix FlatAffineConstraints::getBoundedDirections() const {
1047   // Note that it is necessary to add the equalities too (which the constructor
1048   // does) even though we don't need to check if they are bounded; whether an
1049   // inequality is bounded or not depends on what other constraints, including
1050   // equalities, are present.
1051   Simplex simplex(*this);
1052
1053   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
1054                                "direction is bounded in an empty set.");
1055
1056   SmallVector<unsigned, 8> boundedIneqs;
1057   // The constructor adds the inequalities to the simplex first, so this
1058   // processes all the inequalities.
1059   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1060     if (simplex.isBoundedAlongConstraint(i))
1061       boundedIneqs.push_back(i);
1062   }
1063
1064   // The direction vector is given by the coefficients and does not include the
1065   // constant term, so the matrix has one fewer column.
1066   unsigned dirsNumCols = getNumCols() - 1;
1067   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
1068
1069   // Copy the bounded inequalities.
1070   unsigned row = 0;
1071   for (unsigned i : boundedIneqs) {
1072     for (unsigned col = 0; col < dirsNumCols; ++col)
1073       dirs(row, col) = atIneq(i, col);
1074     ++row;
1075   }
1076
1077   // Copy the equalities. All the equalities' perpendiculars are bounded.
1078   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1079     for (unsigned col = 0; col < dirsNumCols; ++col)
1080       dirs(row, col) = atEq(i, col);
1081     ++row;
1082   }
1083
1084   return dirs;
1085 }
1086
1087 bool eqInvolvesSuffixDims(const FlatAffineConstraints &fac, unsigned eqIndex,
1088                           unsigned numDims) {
1089   for (unsigned e = fac.getNumDimIds(), j = e - numDims; j < e; ++j)
1090     if (fac.atEq(eqIndex, j) != 0)
1091       return true;
1092   return false;
1093 }
1094 bool ineqInvolvesSuffixDims(const FlatAffineConstraints &fac,
1095                             unsigned ineqIndex, unsigned numDims) {
1096   for (unsigned e = fac.getNumDimIds(), j = e - numDims; j < e; ++j)
1097     if (fac.atIneq(ineqIndex, j) != 0)
1098       return true;
1099   return false;
1100 }
1101
1102 void removeConstraintsInvolvingSuffixDims(FlatAffineConstraints &fac,
1103                                           unsigned unboundedDims) {
1104   // We iterate backwards so that whether we remove constraint i - 1 or not, the
1105   // next constraint to be tested is always i - 2.
1106   for (unsigned i = fac.getNumEqualities(); i > 0; i--)
1107     if (eqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1108       fac.removeEquality(i - 1);
1109   for (unsigned i = fac.getNumInequalities(); i > 0; i--)
1110     if (ineqInvolvesSuffixDims(fac, i - 1, unboundedDims))
1111       fac.removeInequality(i - 1);
1112 }
1113
1114 /// Let this set be S. If S is bounded then we directly call into the GBR
1115 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
1116 /// vectors v such that S extends to infininty along v or -v. In this case we
1117 /// use an algorithm described in the integer set library (isl) manual and used
1118 /// by the isl_set_sample function in that library. The algorithm is:
1119 ///
1120 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all
1121 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the
1122 /// dimensions.
1123 ///
1124 /// 2) Construct a set transformedSet by removing all constraints that involve
1125 /// the unbounded dimensions and also deleting the unbounded dimensions. Note
1126 /// that this is a bounded set.
1127 ///
1128 /// 3) Check if transformedSet is empty using the GBR sampling algorithm.
1129 ///
1130 /// 4) return S is empty iff transformedSet is empty.
1131 ///
1132 /// Since T is unimodular, a vector v is a solution to S*T iff T*v is a
1133 /// solution to S. The following is a sketch of a proof that S*T is empty
1134 /// iff transformedSet is empty:
1135 ///
1136 /// If transformedSet is empty, then S*T is certainly empty since transformedSet
1137 /// was obtained by removing constraints and deleting dimensions from S*T.
1138 ///
1139 /// If transformedSet contains a sample, consider the set C obtained by
1140 /// substituting the sample for the bounded dimensions of S*T. All the
1141 /// constraints of S*T that did not involve unbounded dimensions are
1142 /// satisfied by this substitution.
1143 ///
1144 /// In step 1, all dimensions in the linear span of the dimensions outside the
1145 /// prefix are unbounded in S*T. Substituting values for the bounded dimensions
1146 /// cannot makes these dimensions bounded, and these are the only remaining
1147 /// dimensions in C, so C is unbounded along every vector. C is hence a
1148 /// full-dimensional cone and therefore always contains an integer point, which
1149 /// we can then substitute to get a full solution to S*T.
1150 bool FlatAffineConstraints::isIntegerEmpty() const {
1151   // First, try the GCD test heuristic.
1152   if (isEmptyByGCDTest())
1153     return true;
1154
1155   Simplex simplex(*this);
1156   if (simplex.isEmpty())
1157     return true;
1158
1159   // For a bounded set, we directly call into the GBR sampling algorithm.
1160   if (!simplex.isUnbounded())
1161     return !simplex.findIntegerSample().hasValue();
1162
1163   // The set is unbounded. We cannot directly use the GBR algorithm.
1164   //
1165   // m is a matrix containing, in each row, a vector in which S is
1166   // bounded, such that the linear span of all these dimensions contains all
1167   // bounded dimensions in S.
1168   Matrix m = getBoundedDirections();
1169   // In column echelon form, each row of m occupies only the first rank(m)
1170   // columns and has zeros on the other columns. The transform T that brings S
1171   // to column echelon form is unimodular as well, so this is a suitable
1172   // transform to use in step 1 of the algorithm.
1173   std::pair<unsigned, LinearTransform> result =
1174       LinearTransform::makeTransformToColumnEchelon(std::move(m));
1175   FlatAffineConstraints transformedSet = result.second.applyTo(*this);
1176
1177   unsigned numBoundedDims = result.first;
1178   unsigned numUnboundedDims = getNumIds() - numBoundedDims;
1179   removeConstraintsInvolvingSuffixDims(transformedSet, numUnboundedDims);
1180
1181   // Remove all the unbounded dimensions.
1182   transformedSet.removeIdRange(numBoundedDims, transformedSet.getNumIds());
1183
1184   return !Simplex(transformedSet).findIntegerSample().hasValue();
1185 }
1186
1187 Optional<SmallVector<int64_t, 8>>
1188 FlatAffineConstraints::findIntegerSample() const {
1189   return Simplex(*this).findIntegerSample();
1190 }
1191
1192 /// Helper to evaluate an affine expression at a point.
1193 /// The expression is a list of coefficients for the dimensions followed by the
1194 /// constant term.
1195 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
1196   assert(expr.size() == 1 + point.size() &&
1197          "Dimensionalities of point and expression don't match!");
1198   int64_t value = expr.back();
1199   for (unsigned i = 0; i < point.size(); ++i)
1200     value += expr[i] * point[i];
1201   return value;
1202 }
1203
1204 /// A point satisfies an equality iff the value of the equality at the
1205 /// expression is zero, and it satisfies an inequality iff the value of the
1206 /// inequality at that point is non-negative.
1207 bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1208   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1209     if (valueAt(getEquality(i), point) != 0)
1210       return false;
1211   }
1212   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1213     if (valueAt(getInequality(i), point) < 0)
1214       return false;
1215   }
1216   return true;
1217 }
1218
1219 /// Tightens inequalities given that we are dealing with integer spaces. This is
1220 /// analogous to the GCD test but applied to inequalities. The constant term can
1221 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1222 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
1223 /// fast method - linear in the number of coefficients.
1224 // Example on how this affects practical cases: consider the scenario:
1225 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1226 // j >= 100 instead of the tighter (exact) j >= 128.
1227 void FlatAffineConstraints::GCDTightenInequalities() {
1228   unsigned numCols = getNumCols();
1229   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1230     uint64_t gcd = std::abs(atIneq(i, 0));
1231     for (unsigned j = 1; j < numCols - 1; ++j) {
1232       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1233     }
1234     if (gcd > 0 && gcd != 1) {
1235       int64_t gcdI = static_cast<int64_t>(gcd);
1236       // Tighten the constant term and normalize the constraint by the GCD.
1237       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1238       for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1239         atIneq(i, j) /= gcdI;
1240     }
1241   }
1242 }
1243
1244 // Eliminates all identifier variables in column range [posStart, posLimit).
1245 // Returns the number of variables eliminated.
1246 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1247                                                      unsigned posLimit) {
1248   // Return if identifier positions to eliminate are out of range.
1249   assert(posLimit <= numIds);
1250   assert(hasConsistentState());
1251
1252   if (posStart >= posLimit)
1253     return 0;
1254
1255   GCDTightenInequalities();
1256
1257   unsigned pivotCol = 0;
1258   for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1259     // Find a row which has a non-zero coefficient in column 'j'.
1260     unsigned pivotRow;
1261     if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1262                                      &pivotRow)) {
1263       // No pivot row in equalities with non-zero at 'pivotCol'.
1264       if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1265                                        &pivotRow)) {
1266         // If inequalities are also non-zero in 'pivotCol', it can be
1267         // eliminated.
1268         continue;
1269       }
1270       break;
1271     }
1272
1273     // Eliminate identifier at 'pivotCol' from each equality row.
1274     for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1275       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1276                               /*isEq=*/true);
1277       normalizeConstraintByGCD</*isEq=*/true>(this, i);
1278     }
1279
1280     // Eliminate identifier at 'pivotCol' from each inequality row.
1281     for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1282       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1283                               /*isEq=*/false);
1284       normalizeConstraintByGCD</*isEq=*/false>(this, i);
1285     }
1286     removeEquality(pivotRow);
1287     GCDTightenInequalities();
1288   }
1289   // Update position limit based on number eliminated.
1290   posLimit = pivotCol;
1291   // Remove eliminated columns from all constraints.
1292   removeIdRange(posStart, posLimit);
1293   return posLimit - posStart;
1294 }
1295
1296 // Detect the identifier at 'pos' (say id_r) as modulo of another identifier
1297 // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
1298 // could be detected as the floordiv of n. For eg:
1299 // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
1300 //                          id_r = id_n mod 4, id_q = id_n floordiv 4.
1301 // lbConst and ubConst are the constant lower and upper bounds for 'pos' -
1302 // pre-detected at the caller.
1303 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1304                         int64_t lbConst, int64_t ubConst,
1305                         SmallVectorImpl<AffineExpr> *memo) {
1306   assert(pos < cst.getNumIds() && "invalid position");
1307
1308   // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
1309   // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
1310   // and id_q the quotient when dividing id_n by the divisor.
1311
1312   if (lbConst != 0 || ubConst < 1)
1313     return false;
1314
1315   int64_t divisor = ubConst + 1;
1316
1317   // Now check for: id_r =  id_n - divisor * id_q. As an example, we
1318   // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
1319   unsigned seenQuotient = 0, seenDividend = 0;
1320   int quotientPos = -1, dividendPos = -1;
1321   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1322     // id_n should have coeff 1 or -1.
1323     if (std::abs(cst.atEq(r, pos)) != 1)
1324       continue;
1325     // constant term should be 0.
1326     if (cst.atEq(r, cst.getNumCols() - 1) != 0)
1327       continue;
1328     unsigned c, f;
1329     int quotientSign = 1, dividendSign = 1;
1330     for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
1331       if (c == pos)
1332         continue;
1333       // The coefficient of the quotient should be +/-divisor.
1334       // TODO: could be extended to detect an affine function for the quotient
1335       // (i.e., the coeff could be a non-zero multiple of divisor).
1336       int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
1337       if (v == divisor || v == -divisor) {
1338         seenQuotient++;
1339         quotientPos = c;
1340         quotientSign = v > 0 ? 1 : -1;
1341       }
1342       // The coefficient of the dividend should be +/-1.
1343       // TODO: could be extended to detect an affine function of the other
1344       // identifiers as the dividend.
1345       else if (v == -1 || v == 1) {
1346         seenDividend++;
1347         dividendPos = c;
1348         dividendSign = v < 0 ? 1 : -1;
1349       } else if (cst.atEq(r, c) != 0) {
1350         // Cannot be inferred as a mod since the constraint has a coefficient
1351         // for an identifier that's neither a unit nor the divisor (see TODOs
1352         // above).
1353         break;
1354       }
1355     }
1356     if (c < f)
1357       // Cannot be inferred as a mod since the constraint has a coefficient for
1358       // an identifier that's neither a unit nor the divisor (see TODOs above).
1359       continue;
1360
1361     // We are looking for exactly one identifier as the dividend.
1362     if (seenDividend == 1 && seenQuotient >= 1) {
1363       if (!(*memo)[dividendPos])
1364         return false;
1365       // Successfully detected a mod.
1366       (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1367       auto ub = cst.getConstantUpperBound(dividendPos);
1368       if (ub.hasValue() && ub.getValue() < divisor)
1369         // The mod can be optimized away.
1370         (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
1371       else
1372         (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1373
1374       if (seenQuotient == 1 && !(*memo)[quotientPos])
1375         // Successfully detected a floordiv as well.
1376         (*memo)[quotientPos] =
1377             (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
1378       return true;
1379     }
1380   }
1381   return false;
1382 }
1383
1384 /// Gather all lower and upper bounds of the identifier at `pos`, and
1385 /// optionally any equalities on it. In addition, the bounds are to be
1386 /// independent of identifiers in position range [`offset`, `offset` + `num`).
1387 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1388     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1389     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1390     unsigned offset, unsigned num) const {
1391   assert(pos < getNumIds() && "invalid position");
1392   assert(offset + num < getNumCols() && "invalid range");
1393
1394   // Checks for a constraint that has a non-zero coeff for the identifiers in
1395   // the position range [offset, offset + num) while ignoring `pos`.
1396   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1397     unsigned c, f;
1398     auto cst = isEq ? getEquality(r) : getInequality(r);
1399     for (c = offset, f = offset + num; c < f; ++c) {
1400       if (c == pos)
1401         continue;
1402       if (cst[c] != 0)
1403         break;
1404     }
1405     return c < f;
1406   };
1407
1408   // Gather all lower bounds and upper bounds of the variable. Since the
1409   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1410   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1411   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1412     // The bounds are to be independent of [offset, offset + num) columns.
1413     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1414       continue;
1415     if (atIneq(r, pos) >= 1) {
1416       // Lower bound.
1417       lbIndices->push_back(r);
1418     } else if (atIneq(r, pos) <= -1) {
1419       // Upper bound.
1420       ubIndices->push_back(r);
1421     }
1422   }
1423
1424   // An equality is both a lower and upper bound. Record any equalities
1425   // involving the pos^th identifier.
1426   if (!eqIndices)
1427     return;
1428
1429   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1430     if (atEq(r, pos) == 0)
1431       continue;
1432     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1433       continue;
1434     eqIndices->push_back(r);
1435   }
1436 }
1437
1438 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1439 /// function of other identifiers (where the divisor is a positive constant)
1440 /// given the initial set of expressions in `exprs`. If it can be, the
1441 /// corresponding position in `exprs` is set as the detected affine expr. For
1442 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
1443 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1444 /// <= i <= 32q + 31 => q = i floordiv 32.
1445 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1446                              MLIRContext *context,
1447                              SmallVectorImpl<AffineExpr> &exprs) {
1448   assert(pos < cst.getNumIds() && "invalid position");
1449
1450   SmallVector<unsigned, 4> lbIndices, ubIndices;
1451   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1452
1453   // Check if any lower bound, upper bound pair is of the form:
1454   // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
1455   // divisor * id <=  expr                    <-- Upper bound for 'id'
1456   // Then, 'id' is equivalent to 'expr floordiv divisor'.  (where divisor > 1).
1457   //
1458   // For example, if -32*k + 16*i + j >= 0
1459   //                  32*k - 16*i - j + 31 >= 0   <=>
1460   //             k = ( 16*i + j ) floordiv 32
1461   unsigned seenDividends = 0;
1462   for (auto ubPos : ubIndices) {
1463     for (auto lbPos : lbIndices) {
1464       // Check if the lower bound's constant term is divisor - 1. The
1465       // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's
1466       // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'.
1467       int64_t divisor = cst.atIneq(lbPos, pos);
1468       int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1);
1469       if (lbConstTerm != divisor - 1)
1470         continue;
1471       // Check if upper bound's constant term is 0.
1472       if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
1473         continue;
1474       // For the remaining part, check if the lower bound expr's coeff's are
1475       // negations of corresponding upper bound ones'.
1476       unsigned c, f;
1477       for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1478         if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
1479           break;
1480         if (c != pos && cst.atIneq(lbPos, c) != 0)
1481           seenDividends++;
1482       }
1483       // Lb coeff's aren't negative of ub coeff's (for the non constant term
1484       // part).
1485       if (c < f)
1486         continue;
1487       if (seenDividends >= 1) {
1488         // Construct the dividend expression.
1489         auto dividendExpr = getAffineConstantExpr(0, context);
1490         unsigned c, f;
1491         for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1492           if (c == pos)
1493             continue;
1494           int64_t ubVal = cst.atIneq(ubPos, c);
1495           if (ubVal == 0)
1496             continue;
1497           if (!exprs[c])
1498             break;
1499           dividendExpr = dividendExpr + ubVal * exprs[c];
1500         }
1501         // Expression can't be constructed as it depends on a yet unknown
1502         // identifier.
1503         // TODO: Visit/compute the identifiers in an order so that this doesn't
1504         // happen. More complex but much more efficient.
1505         if (c < f)
1506           continue;
1507         // Successfully detected the floordiv.
1508         exprs[pos] = dividendExpr.floorDiv(divisor);
1509         return true;
1510       }
1511     }
1512   }
1513   return false;
1514 }
1515
1516 // Fills an inequality row with the value 'val'.
1517 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1518                                   int64_t val) {
1519   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1520     cst->atIneq(r, c) = val;
1521   }
1522 }
1523
1524 // Negates an inequality.
1525 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1526   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1527     cst->atIneq(r, c) = -cst->atIneq(r, c);
1528   }
1529 }
1530
1531 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1532 // to check if a constraint is redundant.
1533 void FlatAffineConstraints::removeRedundantInequalities() {
1534   SmallVector<bool, 32> redun(getNumInequalities(), false);
1535   // To check if an inequality is redundant, we replace the inequality by its
1536   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1537   // system is empty. If it is, the inequality is redundant.
1538   FlatAffineConstraints tmpCst(*this);
1539   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1540     // Change the inequality to its complement.
1541     negateInequality(&tmpCst, r);
1542     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1543     if (tmpCst.isEmpty()) {
1544       redun[r] = true;
1545       // Zero fill the redundant inequality.
1546       fillInequality(this, r, /*val=*/0);
1547       fillInequality(&tmpCst, r, /*val=*/0);
1548     } else {
1549       // Reverse the change (to avoid recreating tmpCst each time).
1550       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1551       negateInequality(&tmpCst, r);
1552     }
1553   }
1554
1555   // Scan to get rid of all rows marked redundant, in-place.
1556   auto copyRow = [&](unsigned src, unsigned dest) {
1557     if (src == dest)
1558       return;
1559     for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1560       atIneq(dest, c) = atIneq(src, c);
1561     }
1562   };
1563   unsigned pos = 0;
1564   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1565     if (!redun[r])
1566       copyRow(r, pos++);
1567   }
1568   inequalities.resize(numReservedCols * pos);
1569 }
1570
1571 // A more complex check to eliminate redundant inequalities and equalities. Uses
1572 // Simplex to check if a constraint is redundant.
1573 void FlatAffineConstraints::removeRedundantConstraints() {
1574   // First, we run GCDTightenInequalities. This allows us to catch some
1575   // constraints which are not redundant when considering rational solutions
1576   // but are redundant in terms of integer solutions.
1577   GCDTightenInequalities();
1578   Simplex simplex(*this);
1579   simplex.detectRedundant();
1580
1581   auto copyInequality = [&](unsigned src, unsigned dest) {
1582     if (src == dest)
1583       return;
1584     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1585       atIneq(dest, c) = atIneq(src, c);
1586   };
1587   unsigned pos = 0;
1588   unsigned numIneqs = getNumInequalities();
1589   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1590   // the first constraints added are the inequalities.
1591   for (unsigned r = 0; r < numIneqs; r++) {
1592     if (!simplex.isMarkedRedundant(r))
1593       copyInequality(r, pos++);
1594   }
1595   inequalities.resize(numReservedCols * pos);
1596
1597   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1598   // after the inequalities, a pair of constraints for each equality is added.
1599   // An equality is redundant if both the inequalities in its pair are
1600   // redundant.
1601   auto copyEquality = [&](unsigned src, unsigned dest) {
1602     if (src == dest)
1603       return;
1604     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1605       atEq(dest, c) = atEq(src, c);
1606   };
1607   pos = 0;
1608   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1609     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1610           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1611       copyEquality(r, pos++);
1612   }
1613   equalities.resize(numReservedCols * pos);
1614 }
1615
1616 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1617     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1618     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1619   assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1620   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1621   assert(getNumLocalIds() == localExprs.size() &&
1622          "incorrect local exprs count");
1623
1624   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1625   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1626                                offset, num);
1627
1628   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1629   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1630     b.clear();
1631     for (unsigned i = 0, e = a.size(); i < e; ++i) {
1632       if (i < offset || i >= offset + num)
1633         b.push_back(a[i]);
1634     }
1635   };
1636
1637   SmallVector<int64_t, 8> lb, ub;
1638   SmallVector<AffineExpr, 4> lbExprs;
1639   unsigned dimCount = symStartPos - num;
1640   unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1641   lbExprs.reserve(lbIndices.size() + eqIndices.size());
1642   // Lower bound expressions.
1643   for (auto idx : lbIndices) {
1644     auto ineq = getInequality(idx);
1645     // Extract the lower bound (in terms of other coeff's + const), i.e., if
1646     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1647     // - 1.
1648     addCoeffs(ineq, lb);
1649     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1650     auto expr =
1651         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1652     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1653     int64_t divisor = std::abs(ineq[pos + offset]);
1654     expr = (expr + divisor - 1).floorDiv(divisor);
1655     lbExprs.push_back(expr);
1656   }
1657
1658   SmallVector<AffineExpr, 4> ubExprs;
1659   ubExprs.reserve(ubIndices.size() + eqIndices.size());
1660   // Upper bound expressions.
1661   for (auto idx : ubIndices) {
1662     auto ineq = getInequality(idx);
1663     // Extract the upper bound (in terms of other coeff's + const).
1664     addCoeffs(ineq, ub);
1665     auto expr =
1666         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1667     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1668     // Upper bound is exclusive.
1669     ubExprs.push_back(expr + 1);
1670   }
1671
1672   // Equalities. It's both a lower and a upper bound.
1673   SmallVector<int64_t, 4> b;
1674   for (auto idx : eqIndices) {
1675     auto eq = getEquality(idx);
1676     addCoeffs(eq, b);
1677     if (eq[pos + offset] > 0)
1678       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1679
1680     // Extract the upper bound (in terms of other coeff's + const).
1681     auto expr =
1682         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1683     expr = expr.floorDiv(std::abs(eq[pos + offset]));
1684     // Upper bound is exclusive.
1685     ubExprs.push_back(expr + 1);
1686     // Lower bound.
1687     expr =
1688         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1689     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1690     lbExprs.push_back(expr);
1691   }
1692
1693   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1694   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1695
1696   return {lbMap, ubMap};
1697 }
1698
1699 /// Computes the lower and upper bounds of the first 'num' dimensional
1700 /// identifiers (starting at 'offset') as affine maps of the remaining
1701 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
1702 /// themselves explicitly computed as affine functions of other identifiers in
1703 /// this process if needed.
1704 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1705                                            MLIRContext *context,
1706                                            SmallVectorImpl<AffineMap> *lbMaps,
1707                                            SmallVectorImpl<AffineMap> *ubMaps) {
1708   assert(num < getNumDimIds() && "invalid range");
1709
1710   // Basic simplification.
1711   normalizeConstraintsByGCD();
1712
1713   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1714                           << " identifiers\n");
1715   LLVM_DEBUG(dump());
1716
1717   // Record computed/detected identifiers.
1718   SmallVector<AffineExpr, 8> memo(getNumIds());
1719   // Initialize dimensional and symbolic identifiers.
1720   for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
1721     if (i < offset)
1722       memo[i] = getAffineDimExpr(i, context);
1723     else if (i >= offset + num)
1724       memo[i] = getAffineDimExpr(i - num, context);
1725   }
1726   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1727     memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1728
1729   bool changed;
1730   do {
1731     changed = false;
1732     // Identify yet unknown identifiers as constants or mod's / floordiv's of
1733     // other identifiers if possible.
1734     for (unsigned pos = 0; pos < getNumIds(); pos++) {
1735       if (memo[pos])
1736         continue;
1737
1738       auto lbConst = getConstantLowerBound(pos);
1739       auto ubConst = getConstantUpperBound(pos);
1740       if (lbConst.hasValue() && ubConst.hasValue()) {
1741         // Detect equality to a constant.
1742         if (lbConst.getValue() == ubConst.getValue()) {
1743           memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1744           changed = true;
1745           continue;
1746         }
1747
1748         // Detect an identifier as modulo of another identifier w.r.t a
1749         // constant.
1750         if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1751                         &memo)) {
1752           changed = true;
1753           continue;
1754         }
1755       }
1756
1757       // Detect an identifier as a floordiv of an affine function of other
1758       // identifiers (divisor is a positive constant).
1759       if (detectAsFloorDiv(*this, pos, context, memo)) {
1760         changed = true;
1761         continue;
1762       }
1763
1764       // Detect an identifier as an expression of other identifiers.
1765       unsigned idx;
1766       if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
1767         continue;
1768       }
1769
1770       // Build AffineExpr solving for identifier 'pos' in terms of all others.
1771       auto expr = getAffineConstantExpr(0, context);
1772       unsigned j, e;
1773       for (j = 0, e = getNumIds(); j < e; ++j) {
1774         if (j == pos)
1775           continue;
1776         int64_t c = atEq(idx, j);
1777         if (c == 0)
1778           continue;
1779         // If any of the involved IDs hasn't been found yet, we can't proceed.
1780         if (!memo[j])
1781           break;
1782         expr = expr + memo[j] * c;
1783       }
1784       if (j < e)
1785         // Can't construct expression as it depends on a yet uncomputed
1786         // identifier.
1787         continue;
1788
1789       // Add constant term to AffineExpr.
1790       expr = expr + atEq(idx, getNumIds());
1791       int64_t vPos = atEq(idx, pos);
1792       assert(vPos != 0 && "expected non-zero here");
1793       if (vPos > 0)
1794         expr = (-expr).floorDiv(vPos);
1795       else
1796         // vPos < 0.
1797         expr = expr.floorDiv(-vPos);
1798       // Successfully constructed expression.
1799       memo[pos] = expr;
1800       changed = true;
1801     }
1802     // This loop is guaranteed to reach a fixed point - since once an
1803     // identifier's explicit form is computed (in memo[pos]), it's not updated
1804     // again.
1805   } while (changed);
1806
1807   // Set the lower and upper bound maps for all the identifiers that were
1808   // computed as affine expressions of the rest as the "detected expr" and
1809   // "detected expr + 1" respectively; set the undetected ones to null.
1810   Optional<FlatAffineConstraints> tmpClone;
1811   for (unsigned pos = 0; pos < num; pos++) {
1812     unsigned numMapDims = getNumDimIds() - num;
1813     unsigned numMapSymbols = getNumSymbolIds();
1814     AffineExpr expr = memo[pos + offset];
1815     if (expr)
1816       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1817
1818     AffineMap &lbMap = (*lbMaps)[pos];
1819     AffineMap &ubMap = (*ubMaps)[pos];
1820
1821     if (expr) {
1822       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1823       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1824     } else {
1825       // TODO: Whenever there are local identifiers in the dependence
1826       // constraints, we'll conservatively over-approximate, since we don't
1827       // always explicitly compute them above (in the while loop).
1828       if (getNumLocalIds() == 0) {
1829         // Work on a copy so that we don't update this constraint system.
1830         if (!tmpClone) {
1831           tmpClone.emplace(FlatAffineConstraints(*this));
1832           // Removing redundant inequalities is necessary so that we don't get
1833           // redundant loop bounds.
1834           tmpClone->removeRedundantInequalities();
1835         }
1836         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1837             pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1838       }
1839
1840       // If the above fails, we'll just use the constant lower bound and the
1841       // constant upper bound (if they exist) as the slice bounds.
1842       // TODO: being conservative for the moment in cases that
1843       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1844       // fixed (b/126426796).
1845       if (!lbMap || lbMap.getNumResults() > 1) {
1846         LLVM_DEBUG(llvm::dbgs()
1847                    << "WARNING: Potentially over-approximating slice lb\n");
1848         auto lbConst = getConstantLowerBound(pos + offset);
1849         if (lbConst.hasValue()) {
1850           lbMap = AffineMap::get(
1851               numMapDims, numMapSymbols,
1852               getAffineConstantExpr(lbConst.getValue(), context));
1853         }
1854       }
1855       if (!ubMap || ubMap.getNumResults() > 1) {
1856         LLVM_DEBUG(llvm::dbgs()
1857                    << "WARNING: Potentially over-approximating slice ub\n");
1858         auto ubConst = getConstantUpperBound(pos + offset);
1859         if (ubConst.hasValue()) {
1860           (ubMap) = AffineMap::get(
1861               numMapDims, numMapSymbols,
1862               getAffineConstantExpr(ubConst.getValue() + 1, context));
1863         }
1864       }
1865     }
1866     LLVM_DEBUG(llvm::dbgs()
1867                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1868     LLVM_DEBUG(lbMap.dump(););
1869     LLVM_DEBUG(llvm::dbgs()
1870                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1871     LLVM_DEBUG(ubMap.dump(););
1872   }
1873 }
1874
1875 LogicalResult
1876 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
1877                                             ValueRange boundOperands, bool eq,
1878                                             bool lower) {
1879   assert(pos < getNumDimAndSymbolIds() && "invalid position");
1880   // Equality follows the logic of lower bound except that we add an equality
1881   // instead of an inequality.
1882   assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1883   if (eq)
1884     lower = true;
1885
1886   // Fully compose map and operands; canonicalize and simplify so that we
1887   // transitively get to terminal symbols or loop IVs.
1888   auto map = boundMap;
1889   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1890   fullyComposeAffineMapAndOperands(&map, &operands);
1891   map = simplifyAffineMap(map);
1892   canonicalizeMapAndOperands(&map, &operands);
1893   for (auto operand : operands)
1894     addInductionVarOrTerminalSymbol(operand);
1895
1896   FlatAffineConstraints localVarCst;
1897   std::vector<SmallVector<int64_t, 8>> flatExprs;
1898   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
1899     LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
1900     return failure();
1901   }
1902
1903   // Merge and align with localVarCst.
1904   if (localVarCst.getNumLocalIds() > 0) {
1905     // Set values for localVarCst.
1906     localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
1907     for (auto operand : operands) {
1908       unsigned pos;
1909       if (findId(operand, &pos)) {
1910         if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
1911           // If the local var cst has this as a dim, turn it into its symbol.
1912           turnDimIntoSymbol(&localVarCst, operand);
1913         } else if (pos < getNumDimIds()) {
1914           // Or vice versa.
1915           turnSymbolIntoDim(&localVarCst, operand);
1916         }
1917       }
1918     }
1919     mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
1920     append(localVarCst);
1921   }
1922
1923   // Record positions of the operands in the constraint system. Need to do
1924   // this here since the constraint system changes after a bound is added.
1925   SmallVector<unsigned, 8> positions;
1926   unsigned numOperands = operands.size();
1927   for (auto operand : operands) {
1928     unsigned pos;
1929     if (!findId(operand, &pos))
1930       assert(0 && "expected to be found");
1931     positions.push_back(pos);
1932   }
1933
1934   for (const auto &flatExpr : flatExprs) {
1935     SmallVector<int64_t, 4> ineq(getNumCols(), 0);
1936     ineq[pos] = lower ? 1 : -1;
1937     // Dims and symbols.
1938     for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
1939       ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
1940     }
1941     // Copy over the local id coefficients.
1942     unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
1943     for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
1944          jj++, j++) {
1945       ineq[j] =
1946           lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
1947     }
1948     // Constant term.
1949     ineq[getNumCols() - 1] =
1950         lower ? -flatExpr[flatExpr.size() - 1]
1951               // Upper bound in flattenedExpr is an exclusive one.
1952               : flatExpr[flatExpr.size() - 1] - 1;
1953     eq ? addEquality(ineq) : addInequality(ineq);
1954   }
1955   return success();
1956 }
1957
1958 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1959 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1960 // system. Note that both lower/upper bounds share the same operand list
1961 // 'operands'.
1962 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1963 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1964 // Note that both lower/upper bounds use operands from 'operands'.
1965 // Returns failure for unimplemented cases such as semi-affine expressions or
1966 // expressions with mod/floordiv.
1967 LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
1968                                                     ArrayRef<AffineMap> lbMaps,
1969                                                     ArrayRef<AffineMap> ubMaps,
1970                                                     ArrayRef<Value> operands) {
1971   assert(values.size() == lbMaps.size());
1972   assert(lbMaps.size() == ubMaps.size());
1973
1974   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1975     unsigned pos;
1976     if (!findId(values[i], &pos))
1977       continue;
1978
1979     AffineMap lbMap = lbMaps[i];
1980     AffineMap ubMap = ubMaps[i];
1981     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1982     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1983
1984     // Check if this slice is just an equality along this dimension.
1985     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1986         ubMap.getNumResults() == 1 &&
1987         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1988       if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
1989                                       /*lower=*/true)))
1990         return failure();
1991       continue;
1992     }
1993
1994     if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
1995                                              /*lower=*/true)))
1996       return failure();
1997
1998     if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
1999                                              /*lower=*/false)))
2000       return failure();
2001   }
2002   return success();
2003 }
2004
2005 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
2006   assert(eq.size() == getNumCols());
2007   unsigned offset = equalities.size();
2008   equalities.resize(equalities.size() + numReservedCols);
2009   std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
2010 }
2011
2012 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
2013   assert(inEq.size() == getNumCols());
2014   unsigned offset = inequalities.size();
2015   inequalities.resize(inequalities.size() + numReservedCols);
2016   std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
2017 }
2018
2019 void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
2020   assert(pos < getNumCols());
2021   unsigned offset = inequalities.size();
2022   inequalities.resize(inequalities.size() + numReservedCols);
2023   std::fill(inequalities.begin() + offset,
2024             inequalities.begin() + offset + getNumCols(), 0);
2025   inequalities[offset + pos] = 1;
2026   inequalities[offset + getNumCols() - 1] = -lb;
2027 }
2028
2029 void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
2030   assert(pos < getNumCols());
2031   unsigned offset = inequalities.size();
2032   inequalities.resize(inequalities.size() + numReservedCols);
2033   std::fill(inequalities.begin() + offset,
2034             inequalities.begin() + offset + getNumCols(), 0);
2035   inequalities[offset + pos] = -1;
2036   inequalities[offset + getNumCols() - 1] = ub;
2037 }
2038
2039 void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
2040                                                   int64_t lb) {
2041   assert(expr.size() == getNumCols());
2042   unsigned offset = inequalities.size();
2043   inequalities.resize(inequalities.size() + numReservedCols);
2044   std::fill(inequalities.begin() + offset,
2045             inequalities.begin() + offset + getNumCols(), 0);
2046   std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
2047   inequalities[offset + getNumCols() - 1] += -lb;
2048 }
2049
2050 void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
2051                                                   int64_t ub) {
2052   assert(expr.size() == getNumCols());
2053   unsigned offset = inequalities.size();
2054   inequalities.resize(inequalities.size() + numReservedCols);
2055   std::fill(inequalities.begin() + offset,
2056             inequalities.begin() + offset + getNumCols(), 0);
2057   for (unsigned i = 0, e = getNumCols(); i < e; i++) {
2058     inequalities[offset + i] = -expr[i];
2059   }
2060   inequalities[offset + getNumCols() - 1] += ub;
2061 }
2062
2063 /// Adds a new local identifier as the floordiv of an affine function of other
2064 /// identifiers, the coefficients of which are provided in 'dividend' and with
2065 /// respect to a positive constant 'divisor'. Two constraints are added to the
2066 /// system to capture equivalence with the floordiv.
2067 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
2068 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
2069                                              int64_t divisor) {
2070   assert(dividend.size() == getNumCols() && "incorrect dividend size");
2071   assert(divisor > 0 && "positive divisor expected");
2072
2073   addLocalId(getNumLocalIds());
2074
2075   // Add two constraints for this new identifier 'q'.
2076   SmallVector<int64_t, 8> bound(dividend.size() + 1);
2077
2078   // dividend - q * divisor >= 0
2079   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
2080             bound.begin());
2081   bound.back() = dividend.back();
2082   bound[getNumIds() - 1] = -divisor;
2083   addInequality(bound);
2084
2085   // -dividend +qdivisor * q + divisor - 1 >= 0
2086   std::transform(bound.begin(), bound.end(), bound.begin(),
2087                  std::negate<int64_t>());
2088   bound[bound.size() - 1] += divisor - 1;
2089   addInequality(bound);
2090 }
2091
2092 bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
2093   unsigned i = 0;
2094   for (const auto &mayBeId : ids) {
2095     if (mayBeId.hasValue() && mayBeId.getValue() == id) {
2096       *pos = i;
2097       return true;
2098     }
2099     i++;
2100   }
2101   return false;
2102 }
2103
2104 bool FlatAffineConstraints::containsId(Value id) const {
2105   return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
2106     return mayBeId.hasValue() && mayBeId.getValue() == id;
2107   });
2108 }
2109
2110 void FlatAffineConstraints::swapId(unsigned posA, unsigned posB) {
2111   assert(posA < getNumIds() && "invalid position A");
2112   assert(posB < getNumIds() && "invalid position B");
2113
2114   if (posA == posB)
2115     return;
2116
2117   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
2118     std::swap(atIneq(r, posA), atIneq(r, posB));
2119   for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
2120     std::swap(atEq(r, posA), atEq(r, posB));
2121   std::swap(getId(posA), getId(posB));
2122 }
2123
2124 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
2125   assert(newSymbolCount <= numDims + numSymbols &&
2126          "invalid separation position");
2127   numDims = numDims + numSymbols - newSymbolCount;
2128   numSymbols = newSymbolCount;
2129 }
2130
2131 /// Sets the specified identifier to a constant value.
2132 void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
2133   unsigned offset = equalities.size();
2134   equalities.resize(equalities.size() + numReservedCols);
2135   std::fill(equalities.begin() + offset,
2136             equalities.begin() + offset + getNumCols(), 0);
2137   equalities[offset + pos] = 1;
2138   equalities[offset + getNumCols() - 1] = -val;
2139 }
2140
2141 /// Sets the specified identifier to a constant value; asserts if the id is not
2142 /// found.
2143 void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
2144   unsigned pos;
2145   if (!findId(id, &pos))
2146     // This is a pre-condition for this method.
2147     assert(0 && "id not found");
2148   setIdToConstant(pos, val);
2149 }
2150
2151 void FlatAffineConstraints::removeEquality(unsigned pos) {
2152   unsigned numEqualities = getNumEqualities();
2153   assert(pos < numEqualities);
2154   unsigned outputIndex = pos * numReservedCols;
2155   unsigned inputIndex = (pos + 1) * numReservedCols;
2156   unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
2157   std::copy(equalities.begin() + inputIndex,
2158             equalities.begin() + inputIndex + numElemsToCopy,
2159             equalities.begin() + outputIndex);
2160   assert(equalities.size() >= numReservedCols);
2161   equalities.resize(equalities.size() - numReservedCols);
2162 }
2163
2164 void FlatAffineConstraints::removeInequality(unsigned pos) {
2165   unsigned numInequalities = getNumInequalities();
2166   assert(pos < numInequalities && "invalid position");
2167   unsigned outputIndex = pos * numReservedCols;
2168   unsigned inputIndex = (pos + 1) * numReservedCols;
2169   unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
2170   std::copy(inequalities.begin() + inputIndex,
2171             inequalities.begin() + inputIndex + numElemsToCopy,
2172             inequalities.begin() + outputIndex);
2173   assert(inequalities.size() >= numReservedCols);
2174   inequalities.resize(inequalities.size() - numReservedCols);
2175 }
2176
2177 /// Finds an equality that equates the specified identifier to a constant.
2178 /// Returns the position of the equality row. If 'symbolic' is set to true,
2179 /// symbols are also treated like a constant, i.e., an affine function of the
2180 /// symbols is also treated like a constant. Returns -1 if such an equality
2181 /// could not be found.
2182 static int findEqualityToConstant(const FlatAffineConstraints &cst,
2183                                   unsigned pos, bool symbolic = false) {
2184   assert(pos < cst.getNumIds() && "invalid position");
2185   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2186     int64_t v = cst.atEq(r, pos);
2187     if (v * v != 1)
2188       continue;
2189     unsigned c;
2190     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
2191     // This checks for zeros in all positions other than 'pos' in [0, f)
2192     for (c = 0; c < f; c++) {
2193       if (c == pos)
2194         continue;
2195       if (cst.atEq(r, c) != 0) {
2196         // Dependent on another identifier.
2197         break;
2198       }
2199     }
2200     if (c == f)
2201       // Equality is free of other identifiers.
2202       return r;
2203   }
2204   return -1;
2205 }
2206
2207 void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
2208   assert(pos < getNumIds() && "invalid position");
2209   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2210     atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
2211   }
2212   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2213     atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
2214   }
2215   removeId(pos);
2216 }
2217
2218 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2219   assert(pos < getNumIds() && "invalid position");
2220   int rowIdx;
2221   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2222     return failure();
2223
2224   // atEq(rowIdx, pos) is either -1 or 1.
2225   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2226   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2227   setAndEliminate(pos, constVal);
2228   return success();
2229 }
2230
2231 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2232   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2233     if (failed(constantFoldId(t)))
2234       t++;
2235   }
2236 }
2237
2238 /// Returns the extent (upper bound - lower bound) of the specified
2239 /// identifier if it is found to be a constant; returns None if it's not a
2240 /// constant. This methods treats symbolic identifiers specially, i.e.,
2241 /// it looks for constant differences between affine expressions involving
2242 /// only the symbolic identifiers. See comments at function definition for
2243 /// example. 'lb', if provided, is set to the lower bound associated with the
2244 /// constant difference. Note that 'lb' is purely symbolic and thus will contain
2245 /// the coefficients of the symbolic identifiers and the constant coefficient.
2246 //  Egs: 0 <= i <= 15, return 16.
2247 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
2248 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
2249 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
2250 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
2251 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
2252     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
2253     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
2254     unsigned *minUbPos) const {
2255   assert(pos < getNumDimIds() && "Invalid identifier position");
2256
2257   // Find an equality for 'pos'^th identifier that equates it to some function
2258   // of the symbolic identifiers (+ constant).
2259   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
2260   if (eqPos != -1) {
2261     auto eq = getEquality(eqPos);
2262     // If the equality involves a local var, punt for now.
2263     // TODO: this can be handled in the future by using the explicit
2264     // representation of the local vars.
2265     if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
2266                      [](int64_t coeff) { return coeff == 0; }))
2267       return None;
2268
2269     // This identifier can only take a single value.
2270     if (lb) {
2271       // Set lb to that symbolic value.
2272       lb->resize(getNumSymbolIds() + 1);
2273       if (ub)
2274         ub->resize(getNumSymbolIds() + 1);
2275       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
2276         int64_t v = atEq(eqPos, pos);
2277         // atEq(eqRow, pos) is either -1 or 1.
2278         assert(v * v == 1);
2279         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
2280                          : -atEq(eqPos, getNumDimIds() + c) / v;
2281         // Since this is an equality, ub = lb.
2282         if (ub)
2283           (*ub)[c] = (*lb)[c];
2284       }
2285       assert(boundFloorDivisor &&
2286              "both lb and divisor or none should be provided");
2287       *boundFloorDivisor = 1;
2288     }
2289     if (minLbPos)
2290       *minLbPos = eqPos;
2291     if (minUbPos)
2292       *minUbPos = eqPos;
2293     return 1;
2294   }
2295
2296   // Check if the identifier appears at all in any of the inequalities.
2297   unsigned r, e;
2298   for (r = 0, e = getNumInequalities(); r < e; r++) {
2299     if (atIneq(r, pos) != 0)
2300       break;
2301   }
2302   if (r == e)
2303     // If it doesn't, there isn't a bound on it.
2304     return None;
2305
2306   // Positions of constraints that are lower/upper bounds on the variable.
2307   SmallVector<unsigned, 4> lbIndices, ubIndices;
2308
2309   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
2310   // the bounds can only involve symbolic (and local) identifiers. Since the
2311   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2312   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2313   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
2314                                /*eqIndices=*/nullptr, /*offset=*/0,
2315                                /*num=*/getNumDimIds());
2316
2317   Optional<int64_t> minDiff = None;
2318   unsigned minLbPosition = 0, minUbPosition = 0;
2319   for (auto ubPos : ubIndices) {
2320     for (auto lbPos : lbIndices) {
2321       // Look for a lower bound and an upper bound that only differ by a
2322       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
2323       // For example, if ii is the pos^th variable, we are looking for
2324       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
2325       // minimum among all such constant differences is kept since that's the
2326       // constant bounding the extent of the pos^th variable.
2327       unsigned j, e;
2328       for (j = 0, e = getNumCols() - 1; j < e; j++)
2329         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
2330           break;
2331         }
2332       if (j < getNumCols() - 1)
2333         continue;
2334       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
2335                                  atIneq(lbPos, getNumCols() - 1) + 1,
2336                              atIneq(lbPos, pos));
2337       if (minDiff == None || diff < minDiff) {
2338         minDiff = diff;
2339         minLbPosition = lbPos;
2340         minUbPosition = ubPos;
2341       }
2342     }
2343   }
2344   if (lb && minDiff.hasValue()) {
2345     // Set lb to the symbolic lower bound.
2346     lb->resize(getNumSymbolIds() + 1);
2347     if (ub)
2348       ub->resize(getNumSymbolIds() + 1);
2349     // The lower bound is the ceildiv of the lb constraint over the coefficient
2350     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
2351     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
2352     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
2353     *boundFloorDivisor = atIneq(minLbPosition, pos);
2354     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
2355     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
2356       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
2357     }
2358     if (ub) {
2359       for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
2360         (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
2361     }
2362     // The lower bound leads to a ceildiv while the upper bound is a floordiv
2363     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
2364     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
2365     // the constant term for the lower bound.
2366     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
2367   }
2368   if (minLbPos)
2369     *minLbPos = minLbPosition;
2370   if (minUbPos)
2371     *minUbPos = minUbPosition;
2372   return minDiff;
2373 }
2374
2375 template <bool isLower>
2376 Optional<int64_t>
2377 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2378   assert(pos < getNumIds() && "invalid position");
2379   // Project to 'pos'.
2380   projectOut(0, pos);
2381   projectOut(1, getNumIds() - 1);
2382   // Check if there's an equality equating the '0'^th identifier to a constant.
2383   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2384   if (eqRowIdx != -1)
2385     // atEq(rowIdx, 0) is either -1 or 1.
2386     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2387
2388   // Check if the identifier appears at all in any of the inequalities.
2389   unsigned r, e;
2390   for (r = 0, e = getNumInequalities(); r < e; r++) {
2391     if (atIneq(r, 0) != 0)
2392       break;
2393   }
2394   if (r == e)
2395     // If it doesn't, there isn't a bound on it.
2396     return None;
2397
2398   Optional<int64_t> minOrMaxConst = None;
2399
2400   // Take the max across all const lower bounds (or min across all constant
2401   // upper bounds).
2402   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2403     if (isLower) {
2404       if (atIneq(r, 0) <= 0)
2405         // Not a lower bound.
2406         continue;
2407     } else if (atIneq(r, 0) >= 0) {
2408       // Not an upper bound.
2409       continue;
2410     }
2411     unsigned c, f;
2412     for (c = 0, f = getNumCols() - 1; c < f; c++)
2413       if (c != 0 && atIneq(r, c) != 0)
2414         break;
2415     if (c < getNumCols() - 1)
2416       // Not a constant bound.
2417       continue;
2418
2419     int64_t boundConst =
2420         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2421                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2422     if (isLower) {
2423       if (minOrMaxConst == None || boundConst > minOrMaxConst)
2424         minOrMaxConst = boundConst;
2425     } else {
2426       if (minOrMaxConst == None || boundConst < minOrMaxConst)
2427         minOrMaxConst = boundConst;
2428     }
2429   }
2430   return minOrMaxConst;
2431 }
2432
2433 Optional<int64_t>
2434 FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
2435   FlatAffineConstraints tmpCst(*this);
2436   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2437 }
2438
2439 Optional<int64_t>
2440 FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
2441   FlatAffineConstraints tmpCst(*this);
2442   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2443 }
2444
2445 // A simple (naive and conservative) check for hyper-rectangularity.
2446 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2447                                                unsigned num) const {
2448   assert(pos < getNumCols() - 1);
2449   // Check for two non-zero coefficients in the range [pos, pos + sum).
2450   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2451     unsigned sum = 0;
2452     for (unsigned c = pos; c < pos + num; c++) {
2453       if (atIneq(r, c) != 0)
2454         sum++;
2455     }
2456     if (sum > 1)
2457       return false;
2458   }
2459   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2460     unsigned sum = 0;
2461     for (unsigned c = pos; c < pos + num; c++) {
2462       if (atEq(r, c) != 0)
2463         sum++;
2464     }
2465     if (sum > 1)
2466       return false;
2467   }
2468   return true;
2469 }
2470
2471 void FlatAffineConstraints::print(raw_ostream &os) const {
2472   assert(hasConsistentState());
2473   os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2474      << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2475      << " constraints)\n";
2476   os << "(";
2477   for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2478     if (ids[i] == None)
2479       os << "None ";
2480     else
2481       os << "Value ";
2482   }
2483   os << " const)\n";
2484   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2485     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2486       os << atEq(i, j) << " ";
2487     }
2488     os << "= 0\n";
2489   }
2490   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2491     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2492       os << atIneq(i, j) << " ";
2493     }
2494     os << ">= 0\n";
2495   }
2496   os << '\n';
2497 }
2498
2499 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2500
2501 /// Removes duplicate constraints, trivially true constraints, and constraints
2502 /// that can be detected as redundant as a result of differing only in their
2503 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
2504 /// considered trivially true.
2505 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2506 //  remove duplicates in place.
2507 void FlatAffineConstraints::removeTrivialRedundancy() {
2508   GCDTightenInequalities();
2509   normalizeConstraintsByGCD();
2510
2511   // A map used to detect redundancy stemming from constraints that only differ
2512   // in their constant term. The value stored is <row position, const term>
2513   // for a given row.
2514   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2515       rowsWithoutConstTerm;
2516   // To unique rows.
2517   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2518
2519   // Check if constraint is of the form <non-negative-constant> >= 0.
2520   auto isTriviallyValid = [&](unsigned r) -> bool {
2521     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2522       if (atIneq(r, c) != 0)
2523         return false;
2524     }
2525     return atIneq(r, getNumCols() - 1) >= 0;
2526   };
2527
2528   // Detect and mark redundant constraints.
2529   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2530   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2531     int64_t *rowStart = inequalities.data() + numReservedCols * r;
2532     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2533     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2534       redunIneq[r] = true;
2535       continue;
2536     }
2537
2538     // Among constraints that only differ in the constant term part, mark
2539     // everything other than the one with the smallest constant term redundant.
2540     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2541     // former two are redundant).
2542     int64_t constTerm = atIneq(r, getNumCols() - 1);
2543     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2544     const auto &ret =
2545         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2546     if (!ret.second) {
2547       // Check if the other constraint has a higher constant term.
2548       auto &val = ret.first->second;
2549       if (val.second > constTerm) {
2550         // The stored row is redundant. Mark it so, and update with this one.
2551         redunIneq[val.first] = true;
2552         val = {r, constTerm};
2553       } else {
2554         // The one stored makes this one redundant.
2555         redunIneq[r] = true;
2556       }
2557     }
2558   }
2559
2560   auto copyRow = [&](unsigned src, unsigned dest) {
2561     if (src == dest)
2562       return;
2563     for (unsigned c = 0, e = getNumCols(); c < e; c++) {
2564       atIneq(dest, c) = atIneq(src, c);
2565     }
2566   };
2567
2568   // Scan to get rid of all rows marked redundant, in-place.
2569   unsigned pos = 0;
2570   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2571     if (!redunIneq[r])
2572       copyRow(r, pos++);
2573   }
2574   inequalities.resize(numReservedCols * pos);
2575
2576   // TODO: consider doing this for equalities as well, but probably not worth
2577   // the savings.
2578 }
2579
2580 void FlatAffineConstraints::clearAndCopyFrom(
2581     const FlatAffineConstraints &other) {
2582   FlatAffineConstraints copy(other);
2583   std::swap(*this, copy);
2584   assert(copy.getNumIds() == copy.getIds().size());
2585 }
2586
2587 void FlatAffineConstraints::removeId(unsigned pos) {
2588   removeIdRange(pos, pos + 1);
2589 }
2590
2591 static std::pair<unsigned, unsigned>
2592 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2593   unsigned numDims = cst.getNumDimIds();
2594   unsigned numSymbols = cst.getNumSymbolIds();
2595   unsigned newNumDims, newNumSymbols;
2596   if (pos < numDims) {
2597     newNumDims = numDims - 1;
2598     newNumSymbols = numSymbols;
2599   } else if (pos < numDims + numSymbols) {
2600     assert(numSymbols >= 1);
2601     newNumDims = numDims;
2602     newNumSymbols = numSymbols - 1;
2603   } else {
2604     newNumDims = numDims;
2605     newNumSymbols = numSymbols;
2606   }
2607   return {newNumDims, newNumSymbols};
2608 }
2609
2610 #undef DEBUG_TYPE
2611 #define DEBUG_TYPE "fm"
2612
2613 /// Eliminates identifier at the specified position using Fourier-Motzkin
2614 /// variable elimination. This technique is exact for rational spaces but
2615 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2616 /// to a projection operation yielding the (convex) set of integer points
2617 /// contained in the rational shadow of the set. An emptiness test that relies
2618 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2619 /// a solution if it says it's empty.
2620 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2621 /// result is also integer exact. If it's set to false, the obtained solution
2622 /// *may* not be exact, i.e., it may contain integer points that do not have an
2623 /// integer pre-image in the original set.
2624 ///
2625 /// Eg:
2626 /// j >= 0, j <= i + 1
2627 /// i >= 0, i <= N + 1
2628 /// Eliminating i yields,
2629 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
2630 ///
2631 /// If darkShadow = true, this method computes the dark shadow on elimination;
2632 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2633 /// non-empty dark shadow proves the existence of an integer solution. The
2634 /// elimination in such a case could however be an under-approximation, and thus
2635 /// should not be used for scanning sets or used by itself for dependence
2636 /// checking.
2637 ///
2638 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2639 ///            ^
2640 ///            |
2641 ///            | * * * * o o
2642 ///         i  | * * o o o o
2643 ///            | o * * * * *
2644 ///            --------------->
2645 ///                 j ->
2646 ///
2647 /// Eliminating i from this system (projecting on the j dimension):
2648 /// rational shadow / integer light shadow:  1 <= j <= 6
2649 /// dark shadow:                             3 <= j <= 6
2650 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
2651 /// holes/splinters:                         j = 2
2652 ///
2653 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2654 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2655 // which can prove the existence of a solution if there is one.
2656 void FlatAffineConstraints::FourierMotzkinEliminate(
2657     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2658   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2659   LLVM_DEBUG(dump());
2660   assert(pos < getNumIds() && "invalid position");
2661   assert(hasConsistentState());
2662
2663   // Check if this identifier can be eliminated through a substitution.
2664   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2665     if (atEq(r, pos) != 0) {
2666       // Use Gaussian elimination here (since we have an equality).
2667       LogicalResult ret = gaussianEliminateId(pos);
2668       (void)ret;
2669       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2670       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2671       LLVM_DEBUG(dump());
2672       return;
2673     }
2674   }
2675
2676   // A fast linear time tightening.
2677   GCDTightenInequalities();
2678
2679   // Check if the identifier appears at all in any of the inequalities.
2680   unsigned r, e;
2681   for (r = 0, e = getNumInequalities(); r < e; r++) {
2682     if (atIneq(r, pos) != 0)
2683       break;
2684   }
2685   if (r == getNumInequalities()) {
2686     // If it doesn't appear, just remove the column and return.
2687     // TODO: refactor removeColumns to use it from here.
2688     removeId(pos);
2689     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2690     LLVM_DEBUG(dump());
2691     return;
2692   }
2693
2694   // Positions of constraints that are lower bounds on the variable.
2695   SmallVector<unsigned, 4> lbIndices;
2696   // Positions of constraints that are lower bounds on the variable.
2697   SmallVector<unsigned, 4> ubIndices;
2698   // Positions of constraints that do not involve the variable.
2699   std::vector<unsigned> nbIndices;
2700   nbIndices.reserve(getNumInequalities());
2701
2702   // Gather all lower bounds and upper bounds of the variable. Since the
2703   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2704   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2705   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2706     if (atIneq(r, pos) == 0) {
2707       // Id does not appear in bound.
2708       nbIndices.push_back(r);
2709     } else if (atIneq(r, pos) >= 1) {
2710       // Lower bound.
2711       lbIndices.push_back(r);
2712     } else {
2713       // Upper bound.
2714       ubIndices.push_back(r);
2715     }
2716   }
2717
2718   // Set the number of dimensions, symbols in the resulting system.
2719   const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
2720   unsigned newNumDims = dimsSymbols.first;
2721   unsigned newNumSymbols = dimsSymbols.second;
2722
2723   SmallVector<Optional<Value>, 8> newIds;
2724   newIds.reserve(numIds - 1);
2725   newIds.append(ids.begin(), ids.begin() + pos);
2726   newIds.append(ids.begin() + pos + 1, ids.end());
2727
2728   /// Create the new system which has one identifier less.
2729   FlatAffineConstraints newFac(
2730       lbIndices.size() * ubIndices.size() + nbIndices.size(),
2731       getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
2732       /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
2733
2734   assert(newFac.getIds().size() == newFac.getNumIds());
2735
2736   // This will be used to check if the elimination was integer exact.
2737   unsigned lcmProducts = 1;
2738
2739   // Let x be the variable we are eliminating.
2740   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
2741   // that c_l, c_u >= 1) we have:
2742   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
2743   // We thus generate a constraint:
2744   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
2745   // Note if c_l = c_u = 1, all integer points captured by the resulting
2746   // constraint correspond to integer points in the original system (i.e., they
2747   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
2748   // integer exact.
2749   for (auto ubPos : ubIndices) {
2750     for (auto lbPos : lbIndices) {
2751       SmallVector<int64_t, 4> ineq;
2752       ineq.reserve(newFac.getNumCols());
2753       int64_t lbCoeff = atIneq(lbPos, pos);
2754       // Note that in the comments above, ubCoeff is the negation of the
2755       // coefficient in the canonical form as the view taken here is that of the
2756       // term being moved to the other size of '>='.
2757       int64_t ubCoeff = -atIneq(ubPos, pos);
2758       // TODO: refactor this loop to avoid all branches inside.
2759       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2760         if (l == pos)
2761           continue;
2762         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
2763         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
2764         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
2765                        atIneq(lbPos, l) * (lcm / lbCoeff));
2766         lcmProducts *= lcm;
2767       }
2768       if (darkShadow) {
2769         // The dark shadow is a convex subset of the exact integer shadow. If
2770         // there is a point here, it proves the existence of a solution.
2771         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
2772       }
2773       // TODO: we need to have a way to add inequalities in-place in
2774       // FlatAffineConstraints instead of creating and copying over.
2775       newFac.addInequality(ineq);
2776     }
2777   }
2778
2779   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
2780                           << "\n");
2781   if (lcmProducts == 1 && isResultIntegerExact)
2782     *isResultIntegerExact = true;
2783
2784   // Copy over the constraints not involving this variable.
2785   for (auto nbPos : nbIndices) {
2786     SmallVector<int64_t, 4> ineq;
2787     ineq.reserve(getNumCols() - 1);
2788     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2789       if (l == pos)
2790         continue;
2791       ineq.push_back(atIneq(nbPos, l));
2792     }
2793     newFac.addInequality(ineq);
2794   }
2795
2796   assert(newFac.getNumConstraints() ==
2797          lbIndices.size() * ubIndices.size() + nbIndices.size());
2798
2799   // Copy over the equalities.
2800   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2801     SmallVector<int64_t, 4> eq;
2802     eq.reserve(newFac.getNumCols());
2803     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2804       if (l == pos)
2805         continue;
2806       eq.push_back(atEq(r, l));
2807     }
2808     newFac.addEquality(eq);
2809   }
2810
2811   // GCD tightening and normalization allows detection of more trivially
2812   // redundant constraints.
2813   newFac.GCDTightenInequalities();
2814   newFac.normalizeConstraintsByGCD();
2815   newFac.removeTrivialRedundancy();
2816   clearAndCopyFrom(newFac);
2817   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2818   LLVM_DEBUG(dump());
2819 }
2820
2821 #undef DEBUG_TYPE
2822 #define DEBUG_TYPE "affine-structures"
2823
2824 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
2825   if (num == 0)
2826     return;
2827
2828   // 'pos' can be at most getNumCols() - 2 if num > 0.
2829   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
2830   assert(pos + num < getNumCols() && "invalid range");
2831
2832   // Eliminate as many identifiers as possible using Gaussian elimination.
2833   unsigned currentPos = pos;
2834   unsigned numToEliminate = num;
2835   unsigned numGaussianEliminated = 0;
2836
2837   while (currentPos < getNumIds()) {
2838     unsigned curNumEliminated =
2839         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
2840     ++currentPos;
2841     numToEliminate -= curNumEliminated + 1;
2842     numGaussianEliminated += curNumEliminated;
2843   }
2844
2845   // Eliminate the remaining using Fourier-Motzkin.
2846   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
2847     unsigned numToEliminate = num - numGaussianEliminated - i;
2848     FourierMotzkinEliminate(
2849         getBestIdToEliminate(*this, pos, pos + numToEliminate));
2850   }
2851
2852   // Fast/trivial simplifications.
2853   GCDTightenInequalities();
2854   // Normalize constraints after tightening since the latter impacts this, but
2855   // not the other way round.
2856   normalizeConstraintsByGCD();
2857 }
2858
2859 void FlatAffineConstraints::projectOut(Value id) {
2860   unsigned pos;
2861   bool ret = findId(id, &pos);
2862   assert(ret);
2863   (void)ret;
2864   FourierMotzkinEliminate(pos);
2865 }
2866
2867 void FlatAffineConstraints::clearConstraints() {
2868   equalities.clear();
2869   inequalities.clear();
2870 }
2871
2872 namespace {
2873
2874 enum BoundCmpResult { Greater, Less, Equal, Unknown };
2875
2876 /// Compares two affine bounds whose coefficients are provided in 'first' and
2877 /// 'second'. The last coefficient is the constant term.
2878 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
2879   assert(a.size() == b.size());
2880
2881   // For the bounds to be comparable, their corresponding identifier
2882   // coefficients should be equal; the constant terms are then compared to
2883   // determine less/greater/equal.
2884
2885   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
2886     return Unknown;
2887
2888   if (a.back() == b.back())
2889     return Equal;
2890
2891   return a.back() < b.back() ? Less : Greater;
2892 }
2893 } // namespace
2894
2895 // Returns constraints that are common to both A & B.
2896 static void getCommonConstraints(const FlatAffineConstraints &A,
2897                                  const FlatAffineConstraints &B,
2898                                  FlatAffineConstraints &C) {
2899   C.reset(A.getNumDimIds(), A.getNumSymbolIds(), A.getNumLocalIds());
2900   // A naive O(n^2) check should be enough here given the input sizes.
2901   for (unsigned r = 0, e = A.getNumInequalities(); r < e; ++r) {
2902     for (unsigned s = 0, f = B.getNumInequalities(); s < f; ++s) {
2903       if (A.getInequality(r) == B.getInequality(s)) {
2904         C.addInequality(A.getInequality(r));
2905         break;
2906       }
2907     }
2908   }
2909   for (unsigned r = 0, e = A.getNumEqualities(); r < e; ++r) {
2910     for (unsigned s = 0, f = B.getNumEqualities(); s < f; ++s) {
2911       if (A.getEquality(r) == B.getEquality(s)) {
2912         C.addEquality(A.getEquality(r));
2913         break;
2914       }
2915     }
2916   }
2917 }
2918
2919 // Computes the bounding box with respect to 'other' by finding the min of the
2920 // lower bounds and the max of the upper bounds along each of the dimensions.
2921 LogicalResult
2922 FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
2923   assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
2924   assert(otherCst.getIds()
2925              .slice(0, getNumDimIds())
2926              .equals(getIds().slice(0, getNumDimIds())) &&
2927          "dim values mismatch");
2928   assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
2929   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
2930
2931   // Align `other` to this.
2932   Optional<FlatAffineConstraints> otherCopy;
2933   if (!areIdsAligned(*this, otherCst)) {
2934     otherCopy.emplace(FlatAffineConstraints(otherCst));
2935     mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
2936   }
2937
2938   const auto &otherAligned = otherCopy ? *otherCopy : otherCst;
2939
2940   // Get the constraints common to both systems; these will be added as is to
2941   // the union.
2942   FlatAffineConstraints commonCst;
2943   getCommonConstraints(*this, otherAligned, commonCst);
2944
2945   std::vector<SmallVector<int64_t, 8>> boundingLbs;
2946   std::vector<SmallVector<int64_t, 8>> boundingUbs;
2947   boundingLbs.reserve(2 * getNumDimIds());
2948   boundingUbs.reserve(2 * getNumDimIds());
2949
2950   // To hold lower and upper bounds for each dimension.
2951   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
2952   // To compute min of lower bounds and max of upper bounds for each dimension.
2953   SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
2954   SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
2955   // To compute final new lower and upper bounds for the union.
2956   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
2957
2958   int64_t lbFloorDivisor, otherLbFloorDivisor;
2959   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2960     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2961     if (!extent.hasValue())
2962       // TODO: symbolic extents when necessary.
2963       // TODO: handle union if a dimension is unbounded.
2964       return failure();
2965
2966     auto otherExtent = otherAligned.getConstantBoundOnDimSize(
2967         d, &otherLb, &otherLbFloorDivisor, &otherUb);
2968     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
2969       // TODO: symbolic extents when necessary.
2970       return failure();
2971
2972     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2973
2974     auto res = compareBounds(lb, otherLb);
2975     // Identify min.
2976     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2977       minLb = lb;
2978       // Since the divisor is for a floordiv, we need to convert to ceildiv,
2979       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2980       // div * i >= expr - div + 1.
2981       minLb.back() -= lbFloorDivisor - 1;
2982     } else if (res == BoundCmpResult::Greater) {
2983       minLb = otherLb;
2984       minLb.back() -= otherLbFloorDivisor - 1;
2985     } else {
2986       // Uncomparable - check for constant lower/upper bounds.
2987       auto constLb = getConstantLowerBound(d);
2988       auto constOtherLb = otherAligned.getConstantLowerBound(d);
2989       if (!constLb.hasValue() || !constOtherLb.hasValue())
2990         return failure();
2991       std::fill(minLb.begin(), minLb.end(), 0);
2992       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
2993     }
2994
2995     // Do the same for ub's but max of upper bounds. Identify max.
2996     auto uRes = compareBounds(ub, otherUb);
2997     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2998       maxUb = ub;
2999     } else if (uRes == BoundCmpResult::Less) {
3000       maxUb = otherUb;
3001     } else {
3002       // Uncomparable - check for constant lower/upper bounds.
3003       auto constUb = getConstantUpperBound(d);
3004       auto constOtherUb = otherAligned.getConstantUpperBound(d);
3005       if (!constUb.hasValue() || !constOtherUb.hasValue())
3006         return failure();
3007       std::fill(maxUb.begin(), maxUb.end(), 0);
3008       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
3009     }
3010
3011     std::fill(newLb.begin(), newLb.end(), 0);
3012     std::fill(newUb.begin(), newUb.end(), 0);
3013
3014     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
3015     // and so it's the divisor for newLb and newUb as well.
3016     newLb[d] = lbFloorDivisor;
3017     newUb[d] = -lbFloorDivisor;
3018     // Copy over the symbolic part + constant term.
3019     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
3020     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
3021                    newLb.begin() + getNumDimIds(), std::negate<int64_t>());
3022     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
3023
3024     boundingLbs.push_back(newLb);
3025     boundingUbs.push_back(newUb);
3026   }
3027
3028   // Clear all constraints and add the lower/upper bounds for the bounding box.
3029   clearConstraints();
3030   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
3031     addInequality(boundingLbs[d]);
3032     addInequality(boundingUbs[d]);
3033   }
3034
3035   // Add the constraints that were common to both systems.
3036   append(commonCst);
3037   removeTrivialRedundancy();
3038
3039   // TODO: copy over pure symbolic constraints from this and 'other' over to the
3040   // union (since the above are just the union along dimensions); we shouldn't
3041   // be discarding any other constraints on the symbols.
3042
3043   return success();
3044 }
3045
3046 /// Compute an explicit representation for local vars. For all systems coming
3047 /// from MLIR integer sets, maps, or expressions where local vars were
3048 /// introduced to model floordivs and mods, this always succeeds.
3049 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
3050                                       SmallVectorImpl<AffineExpr> &memo,
3051                                       MLIRContext *context) {
3052   unsigned numDims = cst.getNumDimIds();
3053   unsigned numSyms = cst.getNumSymbolIds();
3054
3055   // Initialize dimensional and symbolic identifiers.
3056   for (unsigned i = 0; i < numDims; i++)
3057     memo[i] = getAffineDimExpr(i, context);
3058   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
3059     memo[i] = getAffineSymbolExpr(i - numDims, context);
3060
3061   bool changed;
3062   do {
3063     // Each time `changed` is true at the end of this iteration, one or more
3064     // local vars would have been detected as floordivs and set in memo; so the
3065     // number of null entries in memo[...] strictly reduces; so this converges.
3066     changed = false;
3067     for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
3068       if (!memo[numDims + numSyms + i] &&
3069           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
3070         changed = true;
3071   } while (changed);
3072
3073   ArrayRef<AffineExpr> localExprs =
3074       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
3075   return success(
3076       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
3077 }
3078
3079 void FlatAffineConstraints::getIneqAsAffineValueMap(
3080     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
3081     MLIRContext *context) const {
3082   unsigned numDims = getNumDimIds();
3083   unsigned numSyms = getNumSymbolIds();
3084
3085   assert(pos < numDims && "invalid position");
3086   assert(ineqPos < getNumInequalities() && "invalid inequality position");
3087
3088   // Get expressions for local vars.
3089   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3090   if (failed(computeLocalVars(*this, memo, context)))
3091     assert(false &&
3092            "one or more local exprs do not have an explicit representation");
3093   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3094
3095   // Compute the AffineExpr lower/upper bound for this inequality.
3096   ArrayRef<int64_t> inequality = getInequality(ineqPos);
3097   SmallVector<int64_t, 8> bound;
3098   bound.reserve(getNumCols() - 1);
3099   // Everything other than the coefficient at `pos`.
3100   bound.append(inequality.begin(), inequality.begin() + pos);
3101   bound.append(inequality.begin() + pos + 1, inequality.end());
3102
3103   if (inequality[pos] > 0)
3104     // Lower bound.
3105     std::transform(bound.begin(), bound.end(), bound.begin(),
3106                    std::negate<int64_t>());
3107   else
3108     // Upper bound (which is exclusive).
3109     bound.back() += 1;
3110
3111   // Convert to AffineExpr (tree) form.
3112   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
3113                                              localExprs, context);
3114
3115   // Get the values to bind to this affine expr (all dims and symbols).
3116   SmallVector<Value, 4> operands;
3117   getIdValues(0, pos, &operands);
3118   SmallVector<Value, 4> trailingOperands;
3119   getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
3120   operands.append(trailingOperands.begin(), trailingOperands.end());
3121   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
3122 }
3123
3124 /// Returns true if the pos^th column is all zero for both inequalities and
3125 /// equalities..
3126 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
3127   unsigned rowPos;
3128   return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
3129          !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
3130 }
3131
3132 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3133   if (getNumConstraints() == 0)
3134     // Return universal set (always true): 0 == 0.
3135     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
3136                            getAffineConstantExpr(/*constant=*/0, context),
3137                            /*eqFlags=*/true);
3138
3139   // Construct local references.
3140   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3141
3142   if (failed(computeLocalVars(*this, memo, context))) {
3143     // Check if the local variables without an explicit representation have
3144     // zero coefficients everywhere.
3145     for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
3146       if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
3147         LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
3148                                    "explicit representation");
3149         return IntegerSet();
3150       }
3151     }
3152   }
3153
3154   ArrayRef<AffineExpr> localExprs =
3155       ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3156
3157   // Construct the IntegerSet from the equalities/inequalities.
3158   unsigned numDims = getNumDimIds();
3159   unsigned numSyms = getNumSymbolIds();
3160
3161   SmallVector<bool, 16> eqFlags(getNumConstraints());
3162   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
3163   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
3164
3165   SmallVector<AffineExpr, 8> exprs;
3166   exprs.reserve(getNumConstraints());
3167
3168   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
3169     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
3170                                               localExprs, context));
3171   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
3172     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
3173                                               numSyms, localExprs, context));
3174   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
3175 }
3176
3177 /// Find positions of inequalities and equalities that do not have a coefficient
3178 /// for [pos, pos + num) identifiers.
3179 static void getIndependentConstraints(const FlatAffineConstraints &cst,
3180                                       unsigned pos, unsigned num,
3181                                       SmallVectorImpl<unsigned> &nbIneqIndices,
3182                                       SmallVectorImpl<unsigned> &nbEqIndices) {
3183   assert(pos < cst.getNumIds() && "invalid start position");
3184   assert(pos + num <= cst.getNumIds() && "invalid limit");
3185
3186   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
3187     // The bounds are to be independent of [offset, offset + num) columns.
3188     unsigned c;
3189     for (c = pos; c < pos + num; ++c) {
3190       if (cst.atIneq(r, c) != 0)
3191         break;
3192     }
3193     if (c == pos + num)
3194       nbIneqIndices.push_back(r);
3195   }
3196
3197   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
3198     // The bounds are to be independent of [offset, offset + num) columns.
3199     unsigned c;
3200     for (c = pos; c < pos + num; ++c) {
3201       if (cst.atEq(r, c) != 0)
3202         break;
3203     }
3204     if (c == pos + num)
3205       nbEqIndices.push_back(r);
3206   }
3207 }
3208
3209 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
3210                                                          unsigned num) {
3211   assert(pos + num <= getNumIds() && "invalid range");
3212
3213   // Remove constraints that are independent of these identifiers.
3214   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
3215   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
3216
3217   // Iterate in reverse so that indices don't have to be updated.
3218   // TODO: This method can be made more efficient (because removal of each
3219   // inequality leads to much shifting/copying in the underlying buffer).
3220   for (auto nbIndex : llvm::reverse(nbIneqIndices))
3221     removeInequality(nbIndex);
3222   for (auto nbIndex : llvm::reverse(nbEqIndices))
3223     removeEquality(nbIndex);
3224 }