[mlir][sparse] add asserts on reading in tensor data
[lldb.git] / mlir / lib / ExecutionEngine / SparseUtils.cpp
index 376b989..d196266 100644 (file)
@@ -48,9 +48,9 @@ namespace {
 /// and a rank-5 tensor element like
 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
 struct Element {
-  Element(const std::vector<int64_t> &ind, double val)
+  Element(const std::vector<uint64_t> &ind, double val)
       : indices(ind), value(val){};
-  std::vector<int64_t> indices;
+  std::vector<uint64_t> indices;
   double value;
 };
 
@@ -61,9 +61,15 @@ struct Element {
 /// formats require the elements to appear in lexicographic index order).
 struct SparseTensor {
 public:
-  SparseTensor(int64_t capacity) : pos(0) { elements.reserve(capacity); }
+  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
+      : sizes(szs), pos(0) {
+    elements.reserve(capacity);
+  }
   // Add element as indices and value.
-  void add(const std::vector<int64_t> &ind, double val) {
+  void add(const std::vector<uint64_t> &ind, double val) {
+    assert(sizes.size() == ind.size());
+    for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
+      assert(ind[r] < sizes[r]); // within bounds
     elements.emplace_back(Element(ind, val));
   }
   // Sort elements lexicographically by index.
@@ -82,6 +88,8 @@ private:
     }
     return false;
   }
+
+  std::vector<uint64_t> sizes; // per-rank dimension sizes
   std::vector<Element> elements;
   uint64_t pos;
 };
@@ -225,20 +233,24 @@ extern "C" void *openTensorC(char *filename, uint64_t *idata) {
     fprintf(stderr, "Unknown format %s\n", filename);
     exit(1);
   }
-  // Read all nonzero elements.
+  // Prepare sparse tensor object with per-rank dimension sizes
+  // and the number of nonzeros as initial capacity.
   uint64_t rank = idata[0];
   uint64_t nnz = idata[1];
-  SparseTensor *tensor = new SparseTensor(nnz);
-  std::vector<int64_t> indices(rank);
-  double value;
+  std::vector<uint64_t> indices(rank);
+  for (uint64_t r = 0; r < rank; r++)
+    indices[r] = idata[2 + r];
+  SparseTensor *tensor = new SparseTensor(indices, nnz);
+  // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
     for (uint64_t r = 0; r < rank; r++) {
-      if (fscanf(file, "%" PRId64, &indices[r]) != 1) {
+      if (fscanf(file, "%" PRIu64, &indices[r]) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
       }
       indices[r]--; // 0-based index
     }
+    double value;
     if (fscanf(file, "%lg\n", &value) != 1) {
       fprintf(stderr, "Cannot find next value in %s\n", filename);
       exit(1);