[mlir] Remove the use of "kinds" from Attributes and Types
authorRiver Riddle <riddleriver@gmail.com>
Tue, 18 Aug 2020 22:59:53 +0000 (15:59 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 18 Aug 2020 23:20:14 +0000 (16:20 -0700)
This greatly simplifies a large portion of the underlying infrastructure, allows for lookups of singleton classes to be much more efficient and always thread-safe(no locking). As a result of this, the dialect symbol registry has been removed as it is no longer necessary.

For users broken by this change, an alert was sent out(https://llvm.discourse.group/t/removing-kinds-from-attributes-and-types) that helps prevent a majority of the breakage surface area. All that should be necessary, if the advice in that alert was followed, is removing the kind passed to the ::get methods.

Differential Revision: https://reviews.llvm.org/D86121

42 files changed:
flang/include/flang/Optimizer/Dialect/FIRAttr.h
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Dialect/FIRAttr.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/include/toy/Dialect.h
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/Quant/QuantTypes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectSymbolRegistry.def [deleted file]
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
mlir/lib/Dialect/SDBM/SDBMDialect.cpp
mlir/lib/Dialect/SDBM/SDBMExpr.cpp
mlir/lib/Dialect/SDBM/SDBMExprDetail.h
mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/IR/AffineExpr.cpp
mlir/lib/IR/AffineExprDetail.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Location.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Support/StorageUniquer.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/lib/IR/TestTypes.cpp

index e9b1690..e000816 100644 (file)
@@ -25,17 +25,6 @@ struct RealAttributeStorage;
 struct TypeAttributeStorage;
 } // namespace detail
 
-enum AttributeKind {
-  FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR,
-  FIR_EXACTTYPE, // instance_of, precise type relation
-  FIR_SUBCLASS,  // subsumed_by, is-a (subclass) relation
-  FIR_POINT,
-  FIR_CLOSEDCLOSED_INTERVAL,
-  FIR_OPENCLOSED_INTERVAL,
-  FIR_CLOSEDOPEN_INTERVAL,
-  FIR_REAL_ATTR,
-};
-
 class ExactTypeAttr
     : public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
                                        detail::TypeAttributeStorage> {
@@ -47,8 +36,6 @@ public:
   static ExactTypeAttr get(mlir::Type value);
 
   mlir::Type getType() const;
-
-  static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
 };
 
 class SubclassAttr
@@ -62,8 +49,6 @@ public:
   static SubclassAttr get(mlir::Type value);
 
   mlir::Type getType() const;
-
-  static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
 };
 
 // Attributes for building SELECT CASE multiway branches
@@ -80,9 +65,6 @@ public:
 
   static constexpr llvm::StringRef getAttrName() { return "interval"; }
   static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
-  static constexpr unsigned getId() {
-    return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
-  }
 };
 
 /// An upper bound is an open interval (including the bound value) as given as
@@ -97,9 +79,6 @@ public:
 
   static constexpr llvm::StringRef getAttrName() { return "upper"; }
   static UpperBoundAttr get(mlir::MLIRContext *ctxt);
-  static constexpr unsigned getId() {
-    return AttributeKind::FIR_OPENCLOSED_INTERVAL;
-  }
 };
 
 /// A lower bound is an open interval (including the bound value) as given as
@@ -114,9 +93,6 @@ public:
 
   static constexpr llvm::StringRef getAttrName() { return "lower"; }
   static LowerBoundAttr get(mlir::MLIRContext *ctxt);
-  static constexpr unsigned getId() {
-    return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
-  }
 };
 
 /// A pointer interval is a closed interval as given as an ssa-value. The
@@ -131,7 +107,6 @@ public:
 
   static constexpr llvm::StringRef getAttrName() { return "point"; }
   static PointIntervalAttr get(mlir::MLIRContext *ctxt);
-  static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
 };
 
 /// A real attribute is used to workaround MLIR's default parsing of a real
@@ -150,8 +125,6 @@ public:
 
   int getFKind() const;
   llvm::APFloat getValue() const;
-
-  static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
 };
 
 mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
index 3d3125c..6d2aec2 100644 (file)
@@ -54,29 +54,6 @@ struct SequenceTypeStorage;
 struct TypeDescTypeStorage;
 } // namespace detail
 
-/// Integral identifier for all the types comprising the FIR type system
-enum TypeKind {
-  // The enum starts at the range reserved for this dialect.
-  FIR_TYPE = mlir::Type::FIRST_FIR_TYPE,
-  FIR_BOX,       // (static) descriptor
-  FIR_BOXCHAR,   // CHARACTER pointer and length
-  FIR_BOXPROC,   // procedure with host association
-  FIR_CHARACTER, // intrinsic type
-  FIR_COMPLEX,   // intrinsic type
-  FIR_DERIVED,   // derived
-  FIR_DIMS,
-  FIR_FIELD,
-  FIR_HEAP,
-  FIR_INT, // intrinsic type
-  FIR_LEN,
-  FIR_LOGICAL, // intrinsic type
-  FIR_POINTER, // POINTER attr
-  FIR_REAL,    // intrinsic type
-  FIR_REFERENCE,
-  FIR_SEQUENCE, // DIMENSION attr
-  FIR_TYPEDESC,
-};
-
 // These isa_ routines follow the precedent of llvm::isa_or_null<>
 
 /// Is `t` any of the FIR dialect types?
@@ -111,12 +88,6 @@ bool isa_aggregate(mlir::Type t);
 /// not a memory reference type, then returns a null `Type`.
 mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
 
-/// Boilerplate mixin template
-template <typename A, unsigned Id>
-struct IntrinsicTypeMixin {
-  static constexpr unsigned getId() { return Id; }
-};
-
 // Intrinsic types
 
 /// Model of the Fortran CHARACTER intrinsic type, including the KIND type
@@ -124,8 +95,7 @@ struct IntrinsicTypeMixin {
 /// is thus the type of a single character value.
 class CharacterType
     : public mlir::Type::TypeBase<CharacterType, mlir::Type,
-                                  detail::CharacterTypeStorage>,
-      public IntrinsicTypeMixin<CharacterType, TypeKind::FIR_CHARACTER> {
+                                  detail::CharacterTypeStorage> {
 public:
   using Base::Base;
   static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -136,8 +106,7 @@ public:
 /// parameter. COMPLEX is a floating point type with a real and imaginary
 /// member.
 class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
-                                             detail::CplxTypeStorage>,
-                 public IntrinsicTypeMixin<CplxType, TypeKind::FIR_COMPLEX> {
+                                             detail::CplxTypeStorage> {
 public:
   using Base::Base;
   static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -151,8 +120,7 @@ public:
 /// Model of a Fortran INTEGER intrinsic type, including the KIND type
 /// parameter.
 class IntType
-    : public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage>,
-      public IntrinsicTypeMixin<IntType, TypeKind::FIR_INT> {
+    : public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
 public:
   using Base::Base;
   static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -163,8 +131,7 @@ public:
 /// parameter.
 class LogicalType
     : public mlir::Type::TypeBase<LogicalType, mlir::Type,
-                                  detail::LogicalTypeStorage>,
-      public IntrinsicTypeMixin<LogicalType, TypeKind::FIR_LOGICAL> {
+                                  detail::LogicalTypeStorage> {
 public:
   using Base::Base;
   static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -174,8 +141,7 @@ public:
 /// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
 /// KIND type parameter.
 class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
-                                             detail::RealTypeStorage>,
-                 public IntrinsicTypeMixin<RealType, TypeKind::FIR_REAL> {
+                                             detail::RealTypeStorage> {
 public:
   using Base::Base;
   static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -400,7 +366,6 @@ public:
   static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
   void finalize(llvm::ArrayRef<TypePair> lenPList,
                 llvm::ArrayRef<TypePair> typeList);
-  static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
 
   detail::RecordTypeStorage const *uniqueKey() const;
 
index 09780d3..0a219d1 100644 (file)
@@ -74,13 +74,13 @@ private:
 } // namespace detail
 
 ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
-  return Base::get(value.getContext(), FIR_EXACTTYPE, value);
+  return Base::get(value.getContext(), value);
 }
 
 mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
 
 SubclassAttr SubclassAttr::get(mlir::Type value) {
-  return Base::get(value.getContext(), FIR_SUBCLASS, value);
+  return Base::get(value.getContext(), value);
 }
 
 mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
@@ -88,26 +88,26 @@ mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
 using AttributeUniquer = mlir::detail::AttributeUniquer;
 
 ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
-  return AttributeUniquer::get<ClosedIntervalAttr>(ctxt, getId());
+  return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
 }
 
 UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
-  return AttributeUniquer::get<UpperBoundAttr>(ctxt, getId());
+  return AttributeUniquer::get<UpperBoundAttr>(ctxt);
 }
 
 LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
-  return AttributeUniquer::get<LowerBoundAttr>(ctxt, getId());
+  return AttributeUniquer::get<LowerBoundAttr>(ctxt);
 }
 
 PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
-  return AttributeUniquer::get<PointIntervalAttr>(ctxt, getId());
+  return AttributeUniquer::get<PointIntervalAttr>(ctxt);
 }
 
 // RealAttr
 
 RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
                        const RealAttr::ValueType &key) {
-  return Base::get(ctxt, getId(), key);
+  return Base::get(ctxt, key);
 }
 
 int RealAttr::getFKind() const { return getImpl()->getFKind(); }
index c29412b..e4c5480 100644 (file)
@@ -824,13 +824,11 @@ bool inbounds(A v, B lb, B ub) {
 }
 
 bool isa_fir_type(mlir::Type t) {
-  return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE,
-                  mlir::Type::LAST_FIR_TYPE);
+  return llvm::isa<FIROpsDialect>(t.getDialect());
 }
 
 bool isa_std_type(mlir::Type t) {
-  return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE,
-                  mlir::Type::LAST_STANDARD_TYPE);
+  return t.getDialect().getNamespace().empty();
 }
 
 bool isa_fir_or_std_type(mlir::Type t) {
@@ -868,7 +866,7 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
 // CHARACTER
 
 CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_CHARACTER, kind);
+  return Base::get(ctxt, kind);
 }
 
 int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
@@ -876,7 +874,7 @@ int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
 // Dims
 
 DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
-  return Base::get(ctxt, FIR_DIMS, rank);
+  return Base::get(ctxt, rank);
 }
 
 unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
@@ -884,19 +882,19 @@ unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
 // Field
 
 FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
-  return Base::get(ctxt, FIR_FIELD, 0);
+  return Base::get(ctxt, 0);
 }
 
 // Len
 
 LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
-  return Base::get(ctxt, FIR_LEN, 0);
+  return Base::get(ctxt, 0);
 }
 
 // LOGICAL
 
 LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_LOGICAL, kind);
+  return Base::get(ctxt, kind);
 }
 
 int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
@@ -904,7 +902,7 @@ int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
 // INTEGER
 
 IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_INT, kind);
+  return Base::get(ctxt, kind);
 }
 
 int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
@@ -912,7 +910,7 @@ int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
 // COMPLEX
 
 CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_COMPLEX, kind);
+  return Base::get(ctxt, kind);
 }
 
 mlir::Type fir::CplxType::getElementType() const {
@@ -924,7 +922,7 @@ KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); }
 // REAL
 
 RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_REAL, kind);
+  return Base::get(ctxt, kind);
 }
 
 int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
@@ -932,7 +930,7 @@ int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
 // Box<T>
 
 BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
-  return Base::get(elementType.getContext(), FIR_BOX, elementType, map);
+  return Base::get(elementType.getContext(), elementType, map);
 }
 
 mlir::Type fir::BoxType::getEleTy() const {
@@ -953,7 +951,7 @@ fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
 // BoxChar<C>
 
 BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
-  return Base::get(ctxt, FIR_BOXCHAR, kind);
+  return Base::get(ctxt, kind);
 }
 
 CharacterType fir::BoxCharType::getEleTy() const {
@@ -963,7 +961,7 @@ CharacterType fir::BoxCharType::getEleTy() const {
 // BoxProc<T>
 
 BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
-  return Base::get(elementType.getContext(), FIR_BOXPROC, elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 mlir::Type fir::BoxProcType::getEleTy() const {
@@ -984,7 +982,7 @@ fir::BoxProcType::verifyConstructionInvariants(mlir::Location loc,
 // Reference<T>
 
 ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
-  return Base::get(elementType.getContext(), FIR_REFERENCE, elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 mlir::Type fir::ReferenceType::getEleTy() const {
@@ -1005,7 +1003,7 @@ fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
 
 PointerType fir::PointerType::get(mlir::Type elementType) {
   assert(singleIndirectionLevel(elementType) && "invalid element type");
-  return Base::get(elementType.getContext(), FIR_POINTER, elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 mlir::Type fir::PointerType::getEleTy() const {
@@ -1033,7 +1031,7 @@ fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
 
 HeapType fir::HeapType::get(mlir::Type elementType) {
   assert(singleIndirectionLevel(elementType) && "invalid element type");
-  return Base::get(elementType.getContext(), FIR_HEAP, elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 mlir::Type fir::HeapType::getEleTy() const {
@@ -1054,7 +1052,7 @@ fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
 SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
                                     mlir::AffineMapAttr map) {
   auto *ctxt = elementType.getContext();
-  return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map);
+  return Base::get(ctxt, shape, elementType, map);
 }
 
 mlir::Type fir::SequenceType::getEleTy() const {
@@ -1136,7 +1134,7 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
 /// This type captures a Fortran "derived type"
 
 RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
-  return Base::get(ctxt, FIR_DERIVED, name);
+  return Base::get(ctxt, name);
 }
 
 void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
@@ -1179,7 +1177,7 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
 
 TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
   assert(!ofType.isa<ReferenceType>());
-  return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType);
+  return Base::get(ofType.getContext(), ofType);
 }
 
 mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
@@ -1222,9 +1220,7 @@ void fir::verifyIntegralType(mlir::Type type) {
 void fir::printFirType(FIROpsDialect *, mlir::Type ty,
                        mlir::DialectAsmPrinter &p) {
   auto &os = p.getStream();
-  switch (ty.getKind()) {
-  case fir::FIR_BOX: {
-    auto type = ty.cast<BoxType>();
+  if (auto type = ty.dyn_cast<BoxType>()) {
     os << "box<";
     p.printType(type.getEleTy());
     if (auto map = type.getLayoutMap()) {
@@ -1232,24 +1228,28 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
       p.printAttribute(map);
     }
     os << '>';
-  } break;
-  case fir::FIR_BOXCHAR: {
-    auto type = ty.cast<BoxCharType>().getEleTy();
-    os << "boxchar<" << type.cast<fir::CharacterType>().getFKind() << '>';
-  } break;
-  case fir::FIR_BOXPROC:
+    return;
+  }
+  if (auto type = ty.dyn_cast<BoxCharType>()) {
+    os << "boxchar<" << type.getEleTy().cast<fir::CharacterType>().getFKind()
+       << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<BoxProcType>()) {
     os << "boxproc<";
-    p.printType(ty.cast<BoxProcType>().getEleTy());
+    p.printType(type.getEleTy());
     os << '>';
-    break;
-  case fir::FIR_CHARACTER: // intrinsic
-    os << "char<" << ty.cast<CharacterType>().getFKind() << '>';
-    break;
-  case fir::FIR_COMPLEX: // intrinsic
-    os << "complex<" << ty.cast<CplxType>().getFKind() << '>';
-    break;
-  case fir::FIR_DERIVED: { // derived
-    auto type = ty.cast<fir::RecordType>();
+    return;
+  }
+  if (auto type = ty.dyn_cast<CharacterType>()) {
+    os << "char<" << type.getFKind() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<CplxType>()) {
+    os << "complex<" << type.getFKind() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<RecordType>()) {
     os << "type<" << type.getName();
     if (!recordTypeVisited.count(type.uniqueKey())) {
       recordTypeVisited.insert(type.uniqueKey());
@@ -1274,43 +1274,52 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
       recordTypeVisited.erase(type.uniqueKey());
     }
     os << '>';
-  } break;
-  case fir::FIR_DIMS:
-    os << "dims<" << ty.cast<DimsType>().getRank() << '>';
-    break;
-  case fir::FIR_FIELD:
+    return;
+  }
+  if (auto type = ty.dyn_cast<DimsType>()) {
+    os << "dims<" << type.getRank() << '>';
+    return;
+  }
+  if (ty.isa<FieldType>()) {
     os << "field";
-    break;
-  case fir::FIR_HEAP:
+    return;
+  }
+  if (auto type = ty.dyn_cast<HeapType>()) {
     os << "heap<";
-    p.printType(ty.cast<HeapType>().getEleTy());
+    p.printType(type.getEleTy());
     os << '>';
-    break;
-  case fir::FIR_INT: // intrinsic
-    os << "int<" << ty.cast<fir::IntType>().getFKind() << '>';
-    break;
-  case fir::FIR_LEN:
+    return;
+  }
+  if (auto type = ty.dyn_cast<fir::IntType>()) {
+    os << "int<" << type.getFKind() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<LenType>()) {
     os << "len";
-    break;
-  case fir::FIR_LOGICAL: // intrinsic
-    os << "logical<" << ty.cast<LogicalType>().getFKind() << '>';
-    break;
-  case fir::FIR_POINTER:
+    return;
+  }
+  if (auto type = ty.dyn_cast<LogicalType>()) {
+    os << "logical<" << type.getFKind() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<PointerType>()) {
     os << "ptr<";
-    p.printType(ty.cast<PointerType>().getEleTy());
+    p.printType(type.getEleTy());
     os << '>';
-    break;
-  case fir::FIR_REAL: // intrinsic
-    os << "real<" << ty.cast<fir::RealType>().getFKind() << '>';
-    break;
-  case fir::FIR_REFERENCE:
+    return;
+  }
+  if (auto type = ty.dyn_cast<fir::RealType>()) {
+    os << "real<" << type.getFKind() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<ReferenceType>()) {
     os << "ref<";
-    p.printType(ty.cast<ReferenceType>().getEleTy());
+    p.printType(type.getEleTy());
     os << '>';
-    break;
-  case fir::FIR_SEQUENCE: {
+    return;
+  }
+  if (auto type = ty.dyn_cast<SequenceType>()) {
     os << "array";
-    auto type = ty.cast<SequenceType>();
     auto shape = type.getShape();
     if (shape.size()) {
       printBounds(os, shape);
@@ -1323,11 +1332,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
       map.print(os);
     }
     os << '>';
-  } break;
-  case fir::FIR_TYPEDESC:
+    return;
+  }
+  if (auto type = ty.dyn_cast<TypeDescType>()) {
     os << "tdesc<";
-    p.printType(ty.cast<TypeDescType>().getOfTy());
+    p.printType(type.getOfTy());
     os << '>';
-    break;
+    return;
   }
 }
index cbab1e1..c20b8d9 100644 (file)
@@ -190,11 +190,10 @@ public:
     assert(!elementTypes.empty() && "expected at least 1 element type");
 
     // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
-    // of this type. The first two parameters are the context to unique in and
-    // the kind of the type. The parameters after the type kind are forwarded to
-    // the storage instance.
+    // of this type. The first parameter is the context to unique in. The
+    // parameters after the type kind are forwarded to the storage instance.
     mlir::MLIRContext *ctx = elementTypes.front().getContext();
-    return Base::get(ctx, ToyTypes::Struct, elementTypes);
+    return Base::get(ctx, elementTypes);
   }
 
   /// Returns the element types of this struct type.
index b695169..4eceb42 100644 (file)
@@ -63,13 +63,6 @@ public:
 // Toy Types
 //===----------------------------------------------------------------------===//
 
-/// Create a local enumeration with all of the types that are defined by Toy.
-namespace ToyTypes {
-enum Types {
-  Struct = mlir::Type::FIRST_TOY_TYPE,
-};
-} // end namespace ToyTypes
-
 /// This class defines the Toy struct type. It represents a collection of
 /// element types. All derived types in MLIR must inherit from the CRTP class
 /// 'Type::TypeBase'. It takes as template parameters the concrete type
index e233a55..04c796c 100644 (file)
@@ -474,11 +474,10 @@ StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
   assert(!elementTypes.empty() && "expected at least 1 element type");
 
   // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
-  // of this type. The first two parameters are the context to unique in and the
-  // kind of the type. The parameters after the type kind are forwarded to the
-  // storage instance.
+  // of this type. The first parameter is the context to unique in. The
+  // parameters after the type kind are forwarded to the storage instance.
   mlir::MLIRContext *ctx = elementTypes.front().getContext();
-  return Base::get(ctx, ToyTypes::Struct, elementTypes);
+  return Base::get(ctx, elementTypes);
 }
 
 /// Returns the element types of this struct type.
index b71964b..e9a62cf 100644 (file)
@@ -64,34 +64,6 @@ class LLVMIntegerType;
 /// structs, the entire type is the identifier) and are thread-safe.
 class LLVMType : public Type {
 public:
-  enum Kind {
-    // Keep non-parametric types contiguous in the enum.
-    VoidType = FIRST_LLVM_TYPE + 1,
-    HalfType,
-    BFloatType,
-    FloatType,
-    DoubleType,
-    FP128Type,
-    X86FP80Type,
-    PPCFP128Type,
-    X86MMXType,
-    LabelType,
-    TokenType,
-    MetadataType,
-    // End of non-parametric types.
-    FunctionType,
-    IntegerType,
-    PointerType,
-    FixedVectorType,
-    ScalableVectorType,
-    ArrayType,
-    StructType,
-    FIRST_NEW_LLVM_TYPE = VoidType,
-    LAST_NEW_LLVM_TYPE = StructType,
-    FIRST_TRIVIAL_TYPE = VoidType,
-    LAST_TRIVIAL_TYPE = MetadataType
-  };
-
   /// Inherit base constructors.
   using Type::Type;
 
@@ -256,27 +228,24 @@ public:
 //===----------------------------------------------------------------------===//
 
 // Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind)                              \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName)                                    \
   class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> {  \
   public:                                                                      \
     using Base::Base;                                                          \
-    static ClassName get(MLIRContext *context) {                               \
-      return Base::get(context, Kind);                                         \
-    }                                                                          \
   }
 
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
index 17e803d..18b2c3a 100644 (file)
@@ -16,11 +16,6 @@ namespace mlir {
 class MLIRContext;
 
 namespace linalg {
-enum LinalgTypes {
-  Range = Type::FIRST_LINALG_TYPE,
-  LAST_USED_LINALG_TYPE = Range,
-};
-
 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
 
 /// A RangeType represents a minimal range abstraction (min, max, step).
@@ -36,11 +31,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
 public:
   // Used for generic hooks in TypeBase.
   using Base::Base;
-  /// Construction hook.
-  static RangeType get(MLIRContext *context) {
-    /// Custom, uniq'ed construction in the MLIRContext.
-    return Base::get(context, LinalgTypes::Range);
-  }
 };
 
 } // namespace linalg
index ccdc289..567b639 100644 (file)
@@ -31,15 +31,6 @@ struct UniformQuantizedPerAxisTypeStorage;
 
 } // namespace detail
 
-namespace QuantizationTypes {
-enum Kind {
-  Any = Type::FIRST_QUANTIZATION_TYPE,
-  UniformQuantized,
-  UniformQuantizedPerAxis,
-  LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
-};
-} // namespace QuantizationTypes
-
 /// Enumeration of bit-mapped flags related to quantized types.
 namespace QuantizationFlags {
 enum FlagValue {
index 6788d59..b1909b3 100644 (file)
@@ -32,15 +32,6 @@ struct TargetEnvAttributeStorage;
 struct VerCapExtAttributeStorage;
 } // namespace detail
 
-/// SPIR-V dialect-specific attribute kinds.
-namespace AttrKind {
-enum Kind {
-  InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
-  TargetEnv,                                     /// Target environment
-  VerCapExt, /// (version, extension, capability) triple
-};
-} // namespace AttrKind
-
 /// An attribute that specifies the information regarding the interface
 /// variable: descriptor set, binding, storage class.
 class InterfaceVarABIAttr
index a9d120b..2d224ef 100644 (file)
@@ -65,19 +65,6 @@ struct StructTypeStorage;
 
 } // namespace detail
 
-namespace TypeKind {
-enum Kind {
-  Array = Type::FIRST_SPIRV_TYPE,
-  CooperativeMatrix,
-  Image,
-  Matrix,
-  Pointer,
-  RuntimeArray,
-  Struct,
-  LAST_SPIRV_TYPE = Struct,
-};
-}
-
 // Base SPIR-V type for providing availability queries.
 class SPIRVType : public Type {
 public:
index 3168e87..cc601bd 100644 (file)
@@ -29,56 +29,28 @@ namespace shape {
 /// Alias type for extent tensors.
 RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
-namespace ShapeTypes {
-enum Kind {
-  Component = Type::FIRST_SHAPE_TYPE,
-  Element,
-  Shape,
-  Size,
-  ValueShape,
-  Witness,
-  LAST_SHAPE_TYPE = Witness
-};
-} // namespace ShapeTypes
-
 /// The component type corresponding to shape, element type and attribute.
 class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static ComponentType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::Component);
-  }
 };
 
 /// The element type of the shaped type.
 class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static ElementType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::Element);
-  }
 };
 
 /// The shape descriptor type represents rank and dimension sizes.
 class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static ShapeType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::Shape);
-  }
 };
 
 /// The type of a single dimension.
 class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static SizeType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::Size);
-  }
 };
 
 /// The ValueShape represents a (potentially unknown) runtime value and shape.
@@ -86,10 +58,6 @@ class ValueShapeType
     : public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static ValueShapeType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::ValueShape);
-  }
 };
 
 /// The Witness represents a runtime constraint, to be used as shape related
@@ -97,10 +65,6 @@ public:
 class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
 public:
   using Base::Base;
-
-  static WitnessType get(MLIRContext *context) {
-    return Base::get(context, ShapeTypes::Kind::Witness);
-  }
 };
 
 #define GET_OP_CLASSES
index 31e6285..35084a2 100644 (file)
@@ -137,15 +137,23 @@ namespace detail {
 // MLIRContext. This class manages all creation and uniquing of attributes.
 class AttributeUniquer {
 public:
-  /// Get an uniqued instance of attribute T.
+  /// Get an uniqued instance of a parametric attribute T.
   template <typename T, typename... Args>
-  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+  static typename std::enable_if_t<
+      !std::is_same<typename T::ImplType, AttributeStorage>::value, T>
+  get(MLIRContext *ctx, Args &&...args) {
     return ctx->getAttributeUniquer().get<typename T::ImplType>(
-        T::getTypeID(),
         [ctx](AttributeStorage *storage) {
           initializeAttributeStorage(storage, ctx, T::getTypeID());
         },
-        kind, std::forward<Args>(args)...);
+        T::getTypeID(), std::forward<Args>(args)...);
+  }
+  /// Get an uniqued instance of a singleton attribute T.
+  template <typename T>
+  static typename std::enable_if_t<
+      std::is_same<typename T::ImplType, AttributeStorage>::value, T>
+  get(MLIRContext *ctx) {
+    return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
   }
 
   template <typename T, typename... Args>
@@ -156,6 +164,26 @@ public:
                                              std::forward<Args>(args)...);
   }
 
+  /// Register a parametric attribute instance T with the uniquer.
+  template <typename T>
+  static typename std::enable_if_t<
+      !std::is_same<typename T::ImplType, AttributeStorage>::value>
+  registerAttribute(MLIRContext *ctx) {
+    ctx->getAttributeUniquer()
+        .registerParametricStorageType<typename T::ImplType>(T::getTypeID());
+  }
+  /// Register a singleton attribute instance T with the uniquer.
+  template <typename T>
+  static typename std::enable_if_t<
+      std::is_same<typename T::ImplType, AttributeStorage>::value>
+  registerAttribute(MLIRContext *ctx) {
+    ctx->getAttributeUniquer()
+        .registerSingletonStorageType<typename T::ImplType>(
+            T::getTypeID(), [ctx](AttributeStorage *storage) {
+              initializeAttributeStorage(storage, ctx, T::getTypeID());
+            });
+  }
+
 private:
   /// Initialize the given attribute storage instance.
   static void initializeAttributeStorage(AttributeStorage *storage,
index 75ac2ad..aa8f2ea 100644 (file)
@@ -54,14 +54,6 @@ struct SparseElementsAttributeStorage;
 /// passed by value.
 class Attribute {
 public:
-  /// Integer identifier for all the concrete attribute kinds.
-  enum Kind {
-  // Reserve attribute kinds for dialect specific extensions.
-#define DEFINE_SYM_KIND_RANGE(Dialect)                                         \
-  FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
-#include "DialectSymbolRegistry.def"
-  };
-
   /// Utility class for implementing attributes.
   template <typename ConcreteType, typename BaseType, typename StorageType,
             template <typename T> class... Traits>
@@ -94,9 +86,6 @@ public:
   // Support dyn_cast'ing Attribute to itself.
   static bool classof(Attribute) { return true; }
 
-  /// Return the classification for this attribute.
-  unsigned getKind() const { return impl->getKind(); }
-
   /// Return a unique identifier for the concrete attribute type. This is used
   /// to support dynamic type casting.
   TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
@@ -174,54 +163,6 @@ private:
 };
 
 //===----------------------------------------------------------------------===//
-// StandardAttributes
-//===----------------------------------------------------------------------===//
-
-namespace StandardAttributes {
-enum Kind {
-  AffineMap = Attribute::FIRST_STANDARD_ATTR,
-  Array,
-  Dictionary,
-  Float,
-  Integer,
-  IntegerSet,
-  Opaque,
-  String,
-  SymbolRef,
-  Type,
-  Unit,
-
-  /// Elements Attributes.
-  DenseIntOrFPElements,
-  DenseStringElements,
-  OpaqueElements,
-  SparseElements,
-  FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
-  LAST_ELEMENTS_ATTR = SparseElements,
-
-  /// Locations.
-  CallSiteLocation,
-  FileLineColLocation,
-  FusedLocation,
-  NameLocation,
-  OpaqueLocation,
-  UnknownLocation,
-
-  // Represents a location as a 'void*' pointer to a front-end's opaque
-  // location information, which must live longer than the MLIR objects that
-  // refer to it.  OpaqueLocation's are never serialized.
-  //
-  // TODO: OpaqueLocation,
-
-  // Represents a value inlined through a function call.
-  // TODO: InlinedLocation,
-
-  FIRST_LOCATION_ATTR = CallSiteLocation,
-  LAST_LOCATION_ATTR = UnknownLocation,
-};
-} // namespace StandardAttributes
-
-//===----------------------------------------------------------------------===//
 // AffineMapAttr
 //===----------------------------------------------------------------------===//
 
index 4f9e4cb..12a19af 100644 (file)
@@ -154,21 +154,15 @@ protected:
 
   void addOperation(AbstractOperation opInfo);
 
-  /// This method is used by derived classes to add their types to the set.
+  /// Register a set of type classes with this dialect.
   template <typename... Args> void addTypes() {
-    (void)std::initializer_list<int>{
-        0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
+    (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
   }
-  void addType(TypeID typeID, AbstractType &&typeInfo);
 
-  /// This method is used by derived classes to add their attributes to the set.
+  /// Register a set of attribute classes with this dialect.
   template <typename... Args> void addAttributes() {
-    (void)std::initializer_list<int>{
-        0,
-        (addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
-         0)...};
+    (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
   }
-  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
 
   /// Enable support for unregistered operations.
   void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
@@ -189,6 +183,22 @@ private:
   Dialect(const Dialect &) = delete;
   void operator=(Dialect &) = delete;
 
+  /// Register an attribute instance with this dialect.
+  template <typename T> void addAttribute() {
+    // Add this attribute to the dialect and register it with the uniquer.
+    addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
+    detail::AttributeUniquer::registerAttribute<T>(context);
+  }
+  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
+
+  /// Register a type instance with this dialect.
+  template <typename T> void addType() {
+    // Add this type to the dialect and register it with the uniquer.
+    addType(T::getTypeID(), AbstractType::get<T>(*this));
+    detail::TypeUniquer::registerType<T>(context);
+  }
+  void addType(TypeID typeID, AbstractType &&typeInfo);
+
   /// The namespace of this dialect.
   StringRef name;
 
diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def
deleted file mode 100644 (file)
index acba383..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
-//
-// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file enumerates the different dialects that define custom classes
-// within the attribute or type system.
-//
-//===----------------------------------------------------------------------===//
-
-DEFINE_SYM_KIND_RANGE(STANDARD)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW)
-DEFINE_SYM_KIND_RANGE(LLVM)
-DEFINE_SYM_KIND_RANGE(QUANTIZATION)
-DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
-DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
-DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
-DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
-DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
-DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
-DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
-DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
-DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
-DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect
-
-// The following ranges are reserved for experimenting with MLIR dialects in a
-// private context without having to register them here.
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
-
-#undef DEFINE_SYM_KIND_RANGE
index 0124ef5..df54919 100644 (file)
@@ -756,7 +756,7 @@ public:
   /// all attributes of the given kind in the form : <alias>[0-9]+. These
   /// aliases must not contain `.`.
   virtual void getAttributeKindAliases(
-      SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
+      SmallVectorImpl<std::pair<TypeID, StringRef>> &aliases) const {}
   /// Hook for defining Attribute aliases. These aliases must not contain `.` or
   /// end with a numeric digit([0-9]+).
   virtual void getAttributeAliases(
index 6ceddec..e309595 100644 (file)
@@ -38,33 +38,6 @@ struct TupleTypeStorage;
 
 } // namespace detail
 
-namespace StandardTypes {
-enum Kind {
-  // Floating point.
-  BF16 = Type::Kind::FIRST_STANDARD_TYPE,
-  F16,
-  F32,
-  F64,
-  FIRST_FLOATING_POINT_TYPE = BF16,
-  LAST_FLOATING_POINT_TYPE = F64,
-
-  // Target pointer sized integer, used (e.g.) in affine mappings.
-  Index,
-
-  // Derived types.
-  Integer,
-  Vector,
-  RankedTensor,
-  UnrankedTensor,
-  MemRef,
-  UnrankedMemRef,
-  Complex,
-  Tuple,
-  None,
-};
-
-} // namespace StandardTypes
-
 //===----------------------------------------------------------------------===//
 // ComplexType
 //===----------------------------------------------------------------------===//
index 48026c2..75bc40a 100644 (file)
@@ -82,29 +82,29 @@ public:
     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
   }
 
-protected:
   /// Get or create a new ConcreteT instance within the ctx. This
   /// function is guaranteed to return a non null object and will assert if
   /// the arguments provided are invalid.
   template <typename... Args>
-  static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
+  static ConcreteT get(MLIRContext *ctx, Args... args) {
     // Ensure that the invariants are correct for construction.
     assert(succeeded(ConcreteT::verifyConstructionInvariants(
         generateUnknownStorageLocation(ctx), args...)));
-    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+    return UniquerT::template get<ConcreteT>(ctx, args...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
   /// the given, potentially unknown, location. If the arguments provided are
   /// invalid then emit errors and return a null object.
   template <typename LocationT, typename... Args>
-  static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
+  static ConcreteT getChecked(LocationT loc, Args... args) {
     // If the construction invariants fail then we return a null attribute.
     if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
       return ConcreteT();
-    return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
+    return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
   }
 
+protected:
   /// Mutate the current storage instance. This will not change the unique key.
   /// The arguments are forwarded to 'ConcreteT::mutate'.
   template <typename... Args> LogicalResult mutate(Args &&...args) {
index aa2daef..ace5eaa 100644 (file)
@@ -121,15 +121,23 @@ namespace detail {
 /// A utility class to get, or create, unique instances of types within an
 /// MLIRContext. This class manages all creation and uniquing of types.
 struct TypeUniquer {
-  /// Get an uniqued instance of a type T.
+  /// Get an uniqued instance of a parametric type T.
   template <typename T, typename... Args>
-  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+  static typename std::enable_if_t<
+      !std::is_same<typename T::ImplType, TypeStorage>::value, T>
+  get(MLIRContext *ctx, Args &&...args) {
     return ctx->getTypeUniquer().get<typename T::ImplType>(
-        T::getTypeID(),
         [&](TypeStorage *storage) {
           storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
         },
-        kind, std::forward<Args>(args)...);
+        T::getTypeID(), std::forward<Args>(args)...);
+  }
+  /// Get an uniqued instance of a singleton type T.
+  template <typename T>
+  static typename std::enable_if_t<
+      std::is_same<typename T::ImplType, TypeStorage>::value, T>
+  get(MLIRContext *ctx) {
+    return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
   }
 
   /// Change the mutable component of the given type instance in the provided
@@ -141,6 +149,25 @@ struct TypeUniquer {
     return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
                                         std::forward<Args>(args)...);
   }
+
+  /// Register a parametric type instance T with the uniquer.
+  template <typename T>
+  static typename std::enable_if_t<
+      !std::is_same<typename T::ImplType, TypeStorage>::value>
+  registerType(MLIRContext *ctx) {
+    ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
+        T::getTypeID());
+  }
+  /// Register a singleton type instance T with the uniquer.
+  template <typename T>
+  static typename std::enable_if_t<
+      std::is_same<typename T::ImplType, TypeStorage>::value>
+  registerType(MLIRContext *ctx) {
+    ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
+        T::getTypeID(), [&](TypeStorage *storage) {
+          storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
+        });
+  }
 };
 } // namespace detail
 
index 8101690..ad7e436 100644 (file)
@@ -34,11 +34,11 @@ struct OpaqueTypeStorage;
 ///
 /// Some types are "primitives" meaning they do not have any parameters, for
 /// example the Index type.  Parametric types have additional information that
-/// differentiates the types of the same kind between them, for example the
-/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
-/// different instances of the IntegerType.  Type parameters are part of the
-/// unique immutable key.  The mutable component of the type can be modified
-/// after the type is created, but cannot affect the identity of the type.
+/// differentiates the types of the same class, for example the Integer type has
+/// bitwidth, making i8 and i16 belong to the same kind by be different
+/// instances of the IntegerType. Type parameters are part of the unique
+/// immutable key.  The mutable component of the type can be modified after the
+/// type is created, but cannot affect the identity of the type.
 ///
 /// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
 ///
@@ -53,20 +53,19 @@ struct OpaqueTypeStorage;
 ///      * This method is expected to return failure if a type cannot be
 ///        constructed with 'args', success otherwise.
 ///      * 'args' must correspond with the arguments passed into the
-///        'TypeBase::get' call after the type kind.
+///        'TypeBase::get' call.
 ///
 ///
 /// Type storage objects inherit from TypeStorage and contain the following:
-///    - The type kind (for LLVM-style RTTI).
 ///    - The dialect that defined the type.
 ///    - Any parameters of the type.
 ///    - An optional mutable component.
 /// For non-parametric types, a convenience DefaultTypeStorage is provided.
 /// Parametric storage types must derive TypeStorage and respect the following:
 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
-///      instance of the type within its kind.
+///      instance of the type.
 ///      * The key type must be constructible from the values passed into the
-///        detail::TypeUniquer::get call after the type kind.
+///        detail::TypeUniquer::get call.
 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
 ///        storage class must define a hashing method:
 ///         'static unsigned hashKey(const KeyTy &)'
@@ -84,23 +83,6 @@ struct OpaqueTypeStorage;
 //       the key.
 class Type {
 public:
-  /// Integer identifier for all the concrete type kinds.
-  /// Note: This is not an enum class as each dialect will likely define a
-  /// separate enumeration for the specific types that they define. Not being an
-  /// enum class also simplifies the handling of type kinds by not requiring
-  /// casts for each use.
-  enum Kind {
-    // Builtin types.
-    Function,
-    Opaque,
-    LAST_BUILTIN_TYPE = Opaque,
-
-  // Reserve type kinds for dialect specific type system extensions.
-#define DEFINE_SYM_KIND_RANGE(Dialect)                                         \
-  FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
-#include "DialectSymbolRegistry.def"
-  };
-
   /// Utility class for implementing types.
   template <typename ConcreteType, typename BaseType, typename StorageType,
             template <typename T> class... Traits>
@@ -136,9 +118,6 @@ public:
   /// dynamic type casting.
   TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
 
-  /// Return the classification for this type.
-  unsigned getKind() const;
-
   /// Return the LLVMContext in which this type was uniqued.
   MLIRContext *getContext() const;
 
index 6c7c7b0..eb04688 100644 (file)
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Allocator.h"
 
 namespace mlir {
-class TypeID;
-
 namespace detail {
 struct StorageUniquerImpl;
 
@@ -29,22 +28,19 @@ template <typename ImplTy, typename T>
 using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
 } // namespace detail
 
-/// A utility class to get, or create instances of storage classes. These
-/// storage classes must respect the following constraints:
-///    - Derive from StorageUniquer::BaseStorage.
-///    - Provide an unsigned 'kind' value to be used as part of the unique'ing
-///      process.
+/// A utility class to get or create instances of "storage classes". These
+/// storage classes must derive from 'StorageUniquer::BaseStorage'.
 ///
-/// For non-parametric storage classes, i.e. those that are solely uniqued by
-/// their kind, nothing else is needed. Instances of these classes can be
-/// created by calling `get` without trailing arguments.
+/// For non-parametric storage classes, i.e. singleton classes, nothing else is
+/// needed. Instances of these classes can be created by calling `get` without
+/// trailing arguments.
 ///
 /// Otherwise, the parametric storage classes may be created with `get`,
 /// and must respect the following:
 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
-///      instance of the storage class within its kind.
+///      instance of the storage class.
 ///      * The key type must be constructible from the values passed into the
-///        getComplex call after the kind.
+///        getComplex call.
 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
 ///        storage class must define a hashing method:
 ///         'static unsigned hashKey(const KeyTy &)'
@@ -83,32 +79,11 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
 /// class.
 class StorageUniquer {
 public:
-  StorageUniquer();
-  ~StorageUniquer();
-
-  /// Set the flag specifying if multi-threading is disabled within the uniquer.
-  void disableMultithreading(bool disable = true);
-
-  /// Register a new storage object with this uniquer using the given unique
-  /// type id.
-  void registerStorageType(TypeID id);
-
   /// This class acts as the base storage that all storage classes must derived
   /// from.
   class BaseStorage {
-  public:
-    /// Get the kind classification of this storage.
-    unsigned getKind() const { return kind; }
-
   protected:
-    BaseStorage() : kind(0) {}
-
-  private:
-    /// Allow access to the kind field.
-    friend detail::StorageUniquerImpl;
-
-    /// Classification of the subclass, used for type checking.
-    unsigned kind;
+    BaseStorage() = default;
   };
 
   /// This is a utility allocator used to allocate memory for instances of
@@ -145,19 +120,61 @@ public:
     llvm::BumpPtrAllocator allocator;
   };
 
-  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
-  /// that can be used to initialize a newly inserted storage instance. This
-  /// function is used for derived types that have complex storage or uniquing
+  StorageUniquer();
+  ~StorageUniquer();
+
+  /// Set the flag specifying if multi-threading is disabled within the uniquer.
+  void disableMultithreading(bool disable = true);
+
+  /// Register a new parametric storage class, this is necessary to create
+  /// instances of this class type. `id` is the type identifier that will be
+  /// used to identify this type when creating instances of it via 'get'.
+  template <typename Storage> void registerParametricStorageType(TypeID id) {
+    registerParametricStorageTypeImpl(id);
+  }
+  /// Utility override when the storage type represents the type id.
+  template <typename Storage> void registerParametricStorageType() {
+    registerParametricStorageType<Storage>(TypeID::get<Storage>());
+  }
+  /// Register a new singleton storage class, this is necessary to get the
+  /// singletone instance. `id` is the type identifier that will be used to
+  /// access the singleton instance via 'get'. An optional initialization
+  /// function may also be provided to initialize the newly created storage
+  /// instance, and used when the singleton instance is created.
+  template <typename Storage>
+  void registerSingletonStorageType(TypeID id,
+                                    function_ref<void(Storage *)> initFn) {
+    auto ctorFn = [&](StorageAllocator &allocator) {
+      auto *storage = new (allocator.allocate<Storage>()) Storage();
+      if (initFn)
+        initFn(storage);
+      return storage;
+    };
+    registerSingletonImpl(id, ctorFn);
+  }
+  template <typename Storage> void registerSingletonStorageType(TypeID id) {
+    registerSingletonStorageType<Storage>(id, llvm::None);
+  }
+  /// Utility override when the storage type represents the type id.
+  template <typename Storage>
+  void registerSingletonStorageType(
+      function_ref<void(Storage *)> initFn = llvm::None) {
+    registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
+  }
+
+  /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
+  /// registering the storage instance. 'initFn' is an optional parameter that
+  /// can be used to initialize a newly inserted storage instance. This function
+  /// is used for derived types that have complex storage or uniquing
   /// constraints.
-  template <typename Storage, typename Arg, typename... Args>
-  Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
-               unsigned kind, Arg &&arg, Args &&...args) {
+  template <typename Storage, typename... Args>
+  Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
+               Args &&...args) {
     // Construct a value of the derived key type.
-    auto derivedKey =
-        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
+    auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
 
-    // Create a hash of the kind and the derived key.
-    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+    // Create a hash of the derived key.
+    unsigned hashValue = getHash<Storage>(derivedKey);
 
     // Generate an equality function for the derived storage.
     auto isEqual = [&derivedKey](const BaseStorage *existing) {
@@ -174,29 +191,29 @@ public:
 
     // Get an instance for the derived storage.
     return static_cast<Storage *>(
-        getImpl(id, kind, hashValue, isEqual, ctorFn));
+        getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
+  }
+  /// Utility override when the storage type represents the type id.
+  template <typename Storage, typename... Args>
+  Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
+    return get<Storage>(initFn, TypeID::get<Storage>(),
+                        std::forward<Args>(args)...);
   }
 
-  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
-  /// that can be used to initialize a newly inserted storage instance. This
-  /// function is used for derived types that use no additional storage or
-  /// uniquing outside of the kind.
-  template <typename Storage>
-  Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
-               unsigned kind) {
-    auto ctorFn = [&](StorageAllocator &allocator) {
-      auto *storage = new (allocator.allocate<Storage>()) Storage();
-      if (initFn)
-        initFn(storage);
-      return storage;
-    };
-    return static_cast<Storage *>(getImpl(id, kind, ctorFn));
+  /// Gets a uniqued instance of 'Storage' which is a singleton storage type.
+  /// 'id' is the type id used when registering the storage instance.
+  template <typename Storage> Storage *get(TypeID id) {
+    return static_cast<Storage *>(getSingletonImpl(id));
+  }
+  /// Utility override when the storage type represents the type id.
+  template <typename Storage> Storage *get() {
+    return get<Storage>(TypeID::get<Storage>());
   }
 
   /// Changes the mutable component of 'storage' by forwarding the trailing
   /// arguments to the 'mutate' function of the derived class.
   template <typename Storage, typename... Args>
-  LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
+  LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
     auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
       return static_cast<Storage &>(*storage).mutate(
           allocator, std::forward<Args>(args)...);
@@ -207,13 +224,13 @@ public:
   /// Erases a uniqued instance of 'Storage'. This function is used for derived
   /// types that have complex storage or uniquing constraints.
   template <typename Storage, typename Arg, typename... Args>
-  void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
+  void erase(TypeID id, Arg &&arg, Args &&...args) {
     // Construct a value of the derived key type.
     auto derivedKey =
         getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
 
-    // Create a hash of the kind and the derived key.
-    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+    // Create a hash of the derived key.
+    unsigned hashValue = getHash<Storage>(derivedKey);
 
     // Generate an equality function for the derived storage.
     auto isEqual = [&derivedKey](const BaseStorage *existing) {
@@ -221,32 +238,42 @@ public:
     };
 
     // Attempt to erase the storage instance.
-    eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
+    eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
       static_cast<Storage *>(storage)->cleanup();
     });
   }
 
 private:
   /// Implementation for getting/creating an instance of a derived type with
-  /// complex storage.
-  BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
-                       function_ref<bool(const BaseStorage *)> isEqual,
-                       function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
+  /// parametric storage.
+  BaseStorage *getParametricStorageTypeImpl(
+      TypeID id, unsigned hashValue,
+      function_ref<bool(const BaseStorage *)> isEqual,
+      function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
-  /// Implementation for getting/creating an instance of a derived type with
-  /// default storage.
-  BaseStorage *getImpl(const TypeID &id, unsigned kind,
-                       function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
+  /// Implementation for registering an instance of a derived type with
+  /// parametric storage.
+  void registerParametricStorageTypeImpl(TypeID id);
+
+  /// Implementation for getting an instance of a derived type with default
+  /// storage.
+  BaseStorage *getSingletonImpl(TypeID id);
+
+  /// Implementation for registering an instance of a derived type with default
+  /// storage.
+  void
+  registerSingletonImpl(TypeID id,
+                        function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
   /// Implementation for erasing an instance of a derived type with complex
   /// storage.
-  void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
+  void eraseImpl(TypeID id, unsigned hashValue,
                  function_ref<bool(const BaseStorage *)> isEqual,
                  function_ref<void(BaseStorage *)> cleanupFn);
 
   /// Implementation for mutating an instance of a derived storage.
   LogicalResult
-  mutateImpl(const TypeID &id,
+  mutateImpl(TypeID id,
              function_ref<LogicalResult(StorageAllocator &)> mutationFn);
 
   /// The internal implementation class.
@@ -276,27 +303,26 @@ private:
   }
 
   //===--------------------------------------------------------------------===//
-  // Key and Kind Hashing
+  // Key Hashing
   //===--------------------------------------------------------------------===//
 
-  /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
-  /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
+  /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
+  /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
   template <typename ImplTy, typename DerivedKey>
   static typename std::enable_if<
       llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
       ::llvm::hash_code>::type
-  getHash(unsigned kind, const DerivedKey &derivedKey) {
-    return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
+  getHash(const DerivedKey &derivedKey) {
+    return ImplTy::hashKey(derivedKey);
   }
-  /// If there is no 'ImplTy::hashKey' default to using the
-  /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
+  /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
+  /// definition for 'DerivedKey' for generating a hash.
   template <typename ImplTy, typename DerivedKey>
   static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
                                                     ImplTy, DerivedKey>::value,
                                  ::llvm::hash_code>::type
-  getHash(unsigned kind, const DerivedKey &derivedKey) {
-    return llvm::hash_combine(
-        kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
+  getHash(const DerivedKey &derivedKey) {
+    return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
   }
 };
 } // end namespace mlir
index f4a278d..727efbb 100644 (file)
@@ -264,14 +264,13 @@ bool LLVMArrayType::isValidElementType(LLVMType type) {
 
 LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
-                   numElements);
+  return Base::get(elementType.getContext(), elementType, numElements);
 }
 
 LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
                                         unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
+  return Base::getChecked(loc, elementType, numElements);
 }
 
 LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
@@ -301,16 +300,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
                                        ArrayRef<LLVMType> arguments,
                                        bool isVarArg) {
   assert(result && "expected non-null result");
-  return Base::get(result.getContext(), LLVMType::FunctionType, result,
-                   arguments, isVarArg);
+  return Base::get(result.getContext(), result, arguments, isVarArg);
 }
 
 LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
                                               ArrayRef<LLVMType> arguments,
                                               bool isVarArg) {
   assert(result && "expected non-null result");
-  return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
-                          isVarArg);
+  return Base::getChecked(loc, result, arguments, isVarArg);
 }
 
 LLVMType LLVMFunctionType::getReturnType() {
@@ -347,11 +344,11 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
 // Integer type.
 
 LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
-  return Base::get(ctx, LLVMType::IntegerType, bitwidth);
+  return Base::get(ctx, bitwidth);
 }
 
 LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
-  return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
+  return Base::getChecked(loc, bitwidth);
 }
 
 unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
@@ -374,13 +371,12 @@ bool LLVMPointerType::isValidElementType(LLVMType type) {
 
 LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
   assert(pointee && "expected non-null subtype");
-  return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
-                   addressSpace);
+  return Base::get(pointee.getContext(), pointee, addressSpace);
 }
 
 LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
                                             unsigned addressSpace) {
-  return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
+  return Base::getChecked(loc, pointee, addressSpace);
 }
 
 LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
@@ -405,32 +401,32 @@ bool LLVMStructType::isValidElementType(LLVMType type) {
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
                                              StringRef name) {
-  return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
+  return Base::get(context, name, /*opaque=*/false);
 }
 
 LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
                                                     StringRef name) {
-  return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
+  return Base::getChecked(loc, name, /*opaque=*/false);
 }
 
 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
                                           ArrayRef<LLVMType> types,
                                           bool isPacked) {
-  return Base::get(context, LLVMType::StructType, types, isPacked);
+  return Base::get(context, types, isPacked);
 }
 
 LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
                                                  ArrayRef<LLVMType> types,
                                                  bool isPacked) {
-  return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
+  return Base::getChecked(loc, types, isPacked);
 }
 
 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
-  return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
+  return Base::get(context, name, /*opaque=*/true);
 }
 
 LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
-  return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
+  return Base::getChecked(loc, name, /*opaque=*/true);
 }
 
 LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
@@ -508,16 +504,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
 LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
                                              unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
-                   elementType, numElements);
+  return Base::get(elementType.getContext(), elementType, numElements);
 }
 
 LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
                                                     LLVMType elementType,
                                                     unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
-                          numElements);
+  return Base::getChecked(loc, elementType, numElements);
 }
 
 unsigned LLVMFixedVectorType::getNumElements() {
@@ -527,16 +521,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
 LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
                                                    unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
-                   elementType, minNumElements);
+  return Base::get(elementType.getContext(), elementType, minNumElements);
 }
 
 LLVMScalableVectorType
 LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
                                    unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
-                          minNumElements);
+  return Base::getChecked(loc, elementType, minNumElements);
 }
 
 unsigned LLVMScalableVectorType::getMinNumElements() {
index ef7d814..41e64d1 100644 (file)
@@ -204,8 +204,8 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
                                        Type expressedType,
                                        int64_t storageTypeMin,
                                        int64_t storageTypeMax) {
-  return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
-                   storageType, expressedType, storageTypeMin, storageTypeMax);
+  return Base::get(storageType.getContext(), flags, storageType, expressedType,
+                   storageTypeMin, storageTypeMax);
 }
 
 AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
@@ -213,8 +213,8 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
                                               int64_t storageTypeMin,
                                               int64_t storageTypeMax,
                                               Location location) {
-  return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
-                          expressedType, storageTypeMin, storageTypeMax);
+  return Base::getChecked(location, flags, storageType, expressedType,
+                          storageTypeMin, storageTypeMax);
 }
 
 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
@@ -240,10 +240,8 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
                                                int64_t zeroPoint,
                                                int64_t storageTypeMin,
                                                int64_t storageTypeMax) {
-  return Base::get(storageType.getContext(),
-                   QuantizationTypes::UniformQuantized, flags, storageType,
-                   expressedType, scale, zeroPoint, storageTypeMin,
-                   storageTypeMax);
+  return Base::get(storageType.getContext(), flags, storageType, expressedType,
+                   scale, zeroPoint, storageTypeMin, storageTypeMax);
 }
 
 UniformQuantizedType
@@ -251,9 +249,8 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
                                  Type expressedType, double scale,
                                  int64_t zeroPoint, int64_t storageTypeMin,
                                  int64_t storageTypeMax, Location location) {
-  return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
-                          storageType, expressedType, scale, zeroPoint,
-                          storageTypeMin, storageTypeMax);
+  return Base::getChecked(location, flags, storageType, expressedType, scale,
+                          zeroPoint, storageTypeMin, storageTypeMax);
 }
 
 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
@@ -295,10 +292,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
     int32_t quantizedDimension, int64_t storageTypeMin,
     int64_t storageTypeMax) {
-  return Base::get(storageType.getContext(),
-                   QuantizationTypes::UniformQuantizedPerAxis, flags,
-                   storageType, expressedType, scales, zeroPoints,
-                   quantizedDimension, storageTypeMin, storageTypeMax);
+  return Base::get(storageType.getContext(), flags, storageType, expressedType,
+                   scales, zeroPoints, quantizedDimension, storageTypeMin,
+                   storageTypeMax);
 }
 
 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
@@ -306,9 +302,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
     int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
     Location location) {
-  return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
-                          flags, storageType, expressedType, scales, zeroPoints,
-                          quantizedDimension, storageTypeMin, storageTypeMax);
+  return Base::getChecked(location, flags, storageType, expressedType, scales,
+                          zeroPoints, quantizedDimension, storageTypeMin,
+                          storageTypeMax);
 }
 
 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
index 09c9d1d..4e3e050 100644 (file)
@@ -13,11 +13,11 @@ using namespace mlir;
 
 SDBMDialect::SDBMDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
-  uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
-  uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
-  uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
-  uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
-  uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
+  uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
+  uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
+  uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
+  uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
+  uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
 }
 
 SDBMDialect::~SDBMDialect() = default;
index 435c7fe..8da6c40 100644 (file)
@@ -246,7 +246,6 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
 
   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
-      TypeID::get<detail::SDBMBinaryExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
 }
 
@@ -533,9 +532,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
   assert(rhs && "expected SDBM dimension");
 
   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
-  return uniquer.get<detail::SDBMDiffExprStorage>(
-      TypeID::get<detail::SDBMDiffExprStorage>(),
-      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
+  return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
 }
 
 SDBMDirectExpr SDBMDiffExpr::getLHS() const {
@@ -575,7 +572,6 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
 
   StorageUniquer &uniquer = var.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
-      TypeID::get<detail::SDBMBinaryExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
       stripeFactor);
 }
@@ -611,8 +607,7 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
 
   StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMTermExprStorage>(
-      TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
-      static_cast<unsigned>(SDBMExprKind::DimId), position);
+      assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
 }
 
 //===----------------------------------------------------------------------===//
@@ -628,8 +623,7 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
 
   StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMTermExprStorage>(
-      TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
-      static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+      assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
 }
 
 //===----------------------------------------------------------------------===//
@@ -644,9 +638,7 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
   };
 
   StorageUniquer &uniquer = dialect->getUniquer();
-  return uniquer.get<detail::SDBMConstantExprStorage>(
-      TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
-      static_cast<unsigned>(SDBMExprKind::Constant), value);
+  return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
 }
 
 int64_t SDBMConstantExpr::getValue() const {
@@ -661,9 +653,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
   assert(var && "expected non-null SDBM direct expression");
 
   StorageUniquer &uniquer = var.getDialect()->getUniquer();
-  return uniquer.get<detail::SDBMNegExprStorage>(
-      TypeID::get<detail::SDBMNegExprStorage>(),
-      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
+  return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
 }
 
 SDBMDirectExpr SDBMNegExpr::getVar() const {
index e344917..8d91334 100644 (file)
@@ -25,27 +25,28 @@ namespace detail {
 
 // Base storage class for SDBMExpr.
 struct SDBMExprStorage : public StorageUniquer::BaseStorage {
-  SDBMExprKind getKind() {
-    return static_cast<SDBMExprKind>(BaseStorage::getKind());
-  }
+  SDBMExprKind getKind() { return kind; }
 
   SDBMDialect *dialect;
+  SDBMExprKind kind;
 };
 
 // Storage class for SDBM sum and stripe expressions.
 struct SDBMBinaryExprStorage : public SDBMExprStorage {
-  using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
+  using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
 
   bool operator==(const KeyTy &key) const {
-    return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+    return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
+           std::get<1>(key) == lhs && std::get<2>(key) == rhs;
   }
 
   static SDBMBinaryExprStorage *
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<SDBMBinaryExprStorage>();
-    result->lhs = std::get<0>(key);
-    result->rhs = std::get<1>(key);
+    result->lhs = std::get<1>(key);
+    result->rhs = std::get<2>(key);
     result->dialect = result->lhs.getDialect();
+    result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
     return result;
   }
 
@@ -67,6 +68,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
     result->lhs = std::get<0>(key);
     result->rhs = std::get<1>(key);
     result->dialect = result->lhs.getDialect();
+    result->kind = SDBMExprKind::Diff;
     return result;
   }
 
@@ -84,6 +86,7 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<SDBMConstantExprStorage>();
     result->constant = key;
+    result->kind = SDBMExprKind::Constant;
     return result;
   }
 
@@ -92,14 +95,18 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
 
 // Storage class for SDBM dimension and symbol expressions.
 struct SDBMTermExprStorage : public SDBMExprStorage {
-  using KeyTy = unsigned;
+  using KeyTy = std::pair<unsigned, unsigned>;
 
-  bool operator==(const KeyTy &key) const { return position == key; }
+  bool operator==(const KeyTy &key) const {
+    return kind == static_cast<SDBMExprKind>(key.first) &&
+           position == key.second;
+  }
 
   static SDBMTermExprStorage *
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<SDBMTermExprStorage>();
-    result->position = key;
+    result->kind = static_cast<SDBMExprKind>(key.first);
+    result->position = key.second;
     return result;
   }
 
@@ -117,6 +124,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
     auto *result = allocator.allocate<SDBMNegExprStorage>();
     result->expr = key;
     result->dialect = key.getDialect();
+    result->kind = SDBMExprKind::Neg;
     return result;
   }
 
index b2df52b..c2bf484 100644 (file)
@@ -120,8 +120,7 @@ spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
                                 IntegerAttr storageClass) {
   assert(descriptorSet && binding);
   MLIRContext *context = descriptorSet.getContext();
-  return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
-                   binding, storageClass);
+  return Base::get(context, descriptorSet, binding, storageClass);
 }
 
 StringRef spirv::InterfaceVarABIAttr::getKindName() {
@@ -195,8 +194,7 @@ spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
                                                ArrayAttr extensions) {
   assert(version && capabilities && extensions);
   MLIRContext *context = version.getContext();
-  return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities,
-                   extensions);
+  return Base::get(context, version, capabilities, extensions);
 }
 
 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
@@ -272,7 +270,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
                                                DictionaryAttr limits) {
   assert(triple && limits && "expected valid triple and limits");
   MLIRContext *context = triple.getContext();
-  return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits);
+  return Base::get(context, triple, limits);
 }
 
 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
index b52ea81..e9cb4b2 100644 (file)
@@ -124,15 +124,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
 
 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
   assert(elementCount && "ArrayType needs at least one element");
-  return Base::get(elementType.getContext(), TypeKind::Array, elementType,
-                   elementCount, /*stride=*/0);
+  return Base::get(elementType.getContext(), elementType, elementCount,
+                   /*stride=*/0);
 }
 
 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
                          unsigned stride) {
   assert(elementCount && "ArrayType needs at least one element");
-  return Base::get(elementType.getContext(), TypeKind::Array, elementType,
-                   elementCount, stride);
+  return Base::get(elementType.getContext(), elementType, elementCount, stride);
 }
 
 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
@@ -285,8 +284,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
                                                      Scope scope, unsigned rows,
                                                      unsigned columns) {
-  return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
-                   elementType, scope, rows, columns);
+  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
 }
 
 Type CooperativeMatrixNVType::getElementType() const {
@@ -389,7 +387,7 @@ ImageType
 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
                    value) {
-  return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
+  return Base::get(std::get<0>(value).getContext(), value);
 }
 
 Type ImageType::getElementType() const { return getImpl()->elementType; }
@@ -453,8 +451,7 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
 };
 
 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
-  return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
-                   storageClass);
+  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
 }
 
 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
@@ -511,13 +508,11 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
 };
 
 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
-  return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
-                   elementType, /*stride=*/0);
+  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
 }
 
 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
-  return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
-                   elementType, stride);
+  return Base::get(elementType.getContext(), elementType, stride);
 }
 
 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
@@ -846,12 +841,12 @@ StructType::get(ArrayRef<Type> memberTypes,
   SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
       memberDecorations.begin(), memberDecorations.end());
   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
-  return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
-                   memberTypes, offsetInfo, sortedDecorations);
+  return Base::get(memberTypes.vec().front().getContext(), memberTypes,
+                   offsetInfo, sortedDecorations);
 }
 
 StructType StructType::getEmpty(MLIRContext *context) {
-  return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
+  return Base::get(context, ArrayRef<Type>(),
                    ArrayRef<StructType::OffsetInfo>(),
                    ArrayRef<StructType::MemberDecorationInfo>());
 }
@@ -946,13 +941,12 @@ struct spirv::detail::MatrixTypeStorage : public TypeStorage {
 };
 
 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
-  return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
-                   columnCount);
+  return Base::get(columnType.getContext(), columnType, columnCount);
 }
 
 MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
                                   Location location) {
-  return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
+  return Base::getChecked(location, columnType, columnCount);
 }
 
 LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
index 83d080f..fdecdc6 100644 (file)
@@ -20,9 +20,7 @@ using namespace mlir::detail;
 
 MLIRContext *AffineExpr::getContext() const { return expr->context; }
 
-AffineExprKind AffineExpr::getKind() const {
-  return static_cast<AffineExprKind>(expr->getKind());
-}
+AffineExprKind AffineExpr::getKind() const { return expr->kind; }
 
 /// Walk all of the AffineExprs in this subgraph in postorder.
 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
@@ -449,8 +447,7 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
 
   StorageUniquer &uniquer = context->getAffineUniquer();
   return uniquer.get<AffineDimExprStorage>(
-      TypeID::get<AffineDimExprStorage>(), assignCtx,
-      static_cast<unsigned>(kind), position);
+      assignCtx, static_cast<unsigned>(kind), position);
 }
 
 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
@@ -484,9 +481,7 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
   };
 
   StorageUniquer &uniquer = context->getAffineUniquer();
-  return uniquer.get<AffineConstantExprStorage>(
-      TypeID::get<AffineConstantExprStorage>(), assignCtx,
-      static_cast<unsigned>(AffineExprKind::Constant), constant);
+  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
 }
 
 /// Simplify add expression. Return nullptr if it can't be simplified.
@@ -594,7 +589,6 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
 }
 
@@ -655,7 +649,6 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
 }
 
@@ -722,7 +715,6 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
       other);
 }
@@ -766,7 +758,6 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
       other);
 }
@@ -814,7 +805,6 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
 }
 
index ff47cd9..1c38e54 100644 (file)
@@ -27,21 +27,24 @@ namespace detail {
 /// Base storage class appearing in an affine expression.
 struct AffineExprStorage : public StorageUniquer::BaseStorage {
   MLIRContext *context;
+  AffineExprKind kind;
 };
 
 /// A binary operation appearing in an affine expression.
 struct AffineBinaryOpExprStorage : public AffineExprStorage {
-  using KeyTy = std::pair<AffineExpr, AffineExpr>;
+  using KeyTy = std::tuple<unsigned, AffineExpr, AffineExpr>;
 
   bool operator==(const KeyTy &key) const {
-    return key.first == lhs && key.second == rhs;
+    return static_cast<AffineExprKind>(std::get<0>(key)) == kind &&
+           std::get<1>(key) == lhs && std::get<2>(key) == rhs;
   }
 
   static AffineBinaryOpExprStorage *
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
-    result->lhs = key.first;
-    result->rhs = key.second;
+    result->kind = static_cast<AffineExprKind>(std::get<0>(key));
+    result->lhs = std::get<1>(key);
+    result->rhs = std::get<2>(key);
     result->context = result->lhs.getContext();
     return result;
   }
@@ -52,14 +55,18 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage {
 
 /// A dimensional or symbolic identifier appearing in an affine expression.
 struct AffineDimExprStorage : public AffineExprStorage {
-  using KeyTy = unsigned;
+  using KeyTy = std::pair<unsigned, unsigned>;
 
-  bool operator==(const KeyTy &key) const { return position == key; }
+  bool operator==(const KeyTy &key) const {
+    return kind == static_cast<AffineExprKind>(key.first) &&
+           position == key.second;
+  }
 
   static AffineDimExprStorage *
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<AffineDimExprStorage>();
-    result->position = key;
+    result->kind = static_cast<AffineExprKind>(key.first);
+    result->position = key.second;
     return result;
   }
 
@@ -76,6 +83,7 @@ struct AffineConstantExprStorage : public AffineExprStorage {
   static AffineConstantExprStorage *
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<AffineConstantExprStorage>();
+    result->kind = AffineExprKind::Constant;
     result->constant = key;
     return result;
   }
index 61eecb8..2247fe3 100644 (file)
@@ -271,7 +271,7 @@ private:
   /// Mapping between attribute kind and a pair comprised of a base alias name
   /// and a unique list of attributes belonging to this kind sorted by location
   /// seen in the module.
-  llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
+  llvm::MapVector<TypeID, std::pair<StringRef, std::vector<Attribute>>>
       attrKindToAlias;
 
   /// Set of types known to be used within the module.
@@ -301,13 +301,13 @@ void AliasState::initialize(
   llvm::StringSet<> usedAliases;
 
   // Collect the set of aliases from each dialect.
-  SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
+  SmallVector<std::pair<TypeID, StringRef>, 8> attributeKindAliases;
   SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
   SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
 
   // AffineMap/Integer set have specific kind aliases.
-  attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
-  attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
+  attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map");
+  attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set");
 
   for (auto &interface : interfaces) {
     interface.getAttributeKindAliases(attributeKindAliases);
@@ -317,7 +317,7 @@ void AliasState::initialize(
 
   // Setup the attribute kind aliases.
   StringRef alias;
-  unsigned attrKind;
+  TypeID attrKind;
   for (auto &attrAliasPair : attributeKindAliases) {
     std::tie(attrKind, alias) = attrAliasPair;
     assert(!alias.empty() && "expected non-empty alias string");
@@ -420,7 +420,7 @@ void AliasState::recordAttributeReference(Attribute attr) {
     return;
 
   // If this attribute kind has an alias, then record one for this attribute.
-  auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+  auto alias = attrKindToAlias.find(attr.getTypeID());
   if (alias == attrKindToAlias.end())
     return;
   std::pair<StringRef, int> attrAlias(alias->second.first,
index dba7872..ac51cba 100644 (file)
@@ -57,7 +57,7 @@ Dialect &Attribute::getDialect() const {
 //===----------------------------------------------------------------------===//
 
 AffineMapAttr AffineMapAttr::get(AffineMap value) {
-  return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
+  return Base::get(value.getContext(), value);
 }
 
 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
@@ -67,7 +67,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
 //===----------------------------------------------------------------------===//
 
 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
-  return Base::get(context, StandardAttributes::Array, value);
+  return Base::get(context, value);
 }
 
 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
@@ -156,7 +156,7 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
     value = storage;
 
-  return Base::get(context, StandardAttributes::Dictionary, value);
+  return Base::get(context, value);
 }
 /// Construct a dictionary with an array of values that is known to already be
 /// sorted by name and uniqued.
@@ -175,7 +175,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
                               return l.first == r.first;
                             }) == value.end() &&
          "DictionaryAttr element names must be unique");
-  return Base::get(context, StandardAttributes::Dictionary, value);
+  return Base::get(context, value);
 }
 
 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
@@ -219,19 +219,19 @@ size_t DictionaryAttr::size() const { return getValue().size(); }
 //===----------------------------------------------------------------------===//
 
 FloatAttr FloatAttr::get(Type type, double value) {
-  return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+  return Base::get(type.getContext(), type, value);
 }
 
 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
-  return Base::getChecked(loc, StandardAttributes::Float, type, value);
+  return Base::getChecked(loc, type, value);
 }
 
 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
-  return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+  return Base::get(type.getContext(), type, value);
 }
 
 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
-  return Base::getChecked(loc, StandardAttributes::Float, type, value);
+  return Base::getChecked(loc, type, value);
 }
 
 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@@ -279,14 +279,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
 //===----------------------------------------------------------------------===//
 
 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
-  return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
-      .cast<FlatSymbolRefAttr>();
+  return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
 }
 
 SymbolRefAttr SymbolRefAttr::get(StringRef value,
                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
                                  MLIRContext *ctx) {
-  return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+  return Base::get(ctx, value, nestedReferences);
 }
 
 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
@@ -307,7 +306,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
   if (type.isSignlessInteger(1))
     return BoolAttr::get(value.getBoolValue(), type.getContext());
-  return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
+  return Base::get(type.getContext(), type, value);
 }
 
 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
@@ -380,8 +379,7 @@ bool BoolAttr::classof(Attribute attr) {
 //===----------------------------------------------------------------------===//
 
 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
-  return Base::get(value.getConstraint(0).getContext(),
-                   StandardAttributes::IntegerSet, value);
+  return Base::get(value.getConstraint(0).getContext(), value);
 }
 
 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
@@ -392,14 +390,12 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 
 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
                            MLIRContext *context) {
-  return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
-                   type);
+  return Base::get(context, dialect, attrData, type);
 }
 
 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
                                   Type type, Location location) {
-  return Base::getChecked(location, StandardAttributes::Opaque, dialect,
-                          attrData, type);
+  return Base::getChecked(location, dialect, attrData, type);
 }
 
 /// Returns the dialect namespace of the opaque attribute.
@@ -430,7 +426,7 @@ StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
 
 /// Get an instance of a StringAttr with the given string and Type.
 StringAttr StringAttr::get(StringRef bytes, Type type) {
-  return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
+  return Base::get(type.getContext(), bytes, type);
 }
 
 StringRef StringAttr::getValue() const { return getImpl()->value; }
@@ -440,7 +436,7 @@ StringRef StringAttr::getValue() const { return getImpl()->value; }
 //===----------------------------------------------------------------------===//
 
 TypeAttr TypeAttr::get(Type value) {
-  return Base::get(value.getContext(), StandardAttributes::Type, value);
+  return Base::get(value.getContext(), value);
 }
 
 Type TypeAttr::getValue() const { return getImpl()->value; }
@@ -1036,8 +1032,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
 
 DenseStringElementsAttr
 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
-  return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
-                   type, values, (values.size() == 1));
+  return Base::get(type.getContext(), type, values, (values.size() == 1));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1088,8 +1083,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
   assert((type.isa<RankedTensorType, VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
-  return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
-                   type, data, isSplat);
+  return Base::get(type.getContext(), type, data, isSplat);
 }
 
 /// Overload of the raw 'get' method that asserts that the given type is of
@@ -1210,8 +1204,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
                                            StringRef bytes) {
   assert(TensorType::isValidElementType(type.getElementType()) &&
          "Input element type should be a valid tensor element type");
-  return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
-                   dialect, bytes);
+  return Base::get(type.getContext(), type, dialect, bytes);
 }
 
 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
@@ -1248,7 +1241,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
   assert((type.isa<RankedTensorType, VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
-  return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
+  return Base::get(type.getContext(), type,
                    indices.cast<DenseIntElementsAttr>(), values);
 }
 
index 48b05ba..151e2cf 100644 (file)
@@ -28,8 +28,7 @@ bool LocationAttr::classof(Attribute attr) {
 //===----------------------------------------------------------------------===//
 
 Location CallSiteLoc::get(Location callee, Location caller) {
-  return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation,
-                   callee, caller);
+  return Base::get(callee->getContext(), callee, caller);
 }
 
 Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
@@ -50,8 +49,7 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
 
 Location FileLineColLoc::get(Identifier filename, unsigned line,
                              unsigned column, MLIRContext *context) {
-  return Base::get(context, StandardAttributes::FileLineColLocation, filename,
-                   line, column);
+  return Base::get(context, filename, line, column);
 }
 
 Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
@@ -95,7 +93,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
     return UnknownLoc::get(context);
   if (locs.size() == 1)
     return locs.front();
-  return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
+  return Base::get(context, locs, metadata);
 }
 
 ArrayRef<Location> FusedLoc::getLocations() const {
@@ -111,8 +109,7 @@ Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
 Location NameLoc::get(Identifier name, Location child) {
   assert(!child.isa<NameLoc>() &&
          "a NameLoc cannot be used as a child of another NameLoc");
-  return Base::get(child->getContext(), StandardAttributes::NameLocation, name,
-                   child);
+  return Base::get(child->getContext(), name, child);
 }
 
 Location NameLoc::get(Identifier name, MLIRContext *context) {
@@ -131,9 +128,8 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
 
 Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
                         Location fallbackLocation) {
-  return Base::get(fallbackLocation->getContext(),
-                   StandardAttributes::OpaqueLocation, underlyingLocation,
-                   typeID, fallbackLocation);
+  return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID,
+                   fallbackLocation);
 }
 
 uintptr_t OpaqueLoc::getUnderlyingLocation() const {
index 0d66070..a86f27a 100644 (file)
@@ -87,6 +87,10 @@ namespace {
 struct BuiltinDialect : public Dialect {
   BuiltinDialect(MLIRContext *context)
       : Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
+    addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
+             FunctionType, IndexType, IntegerType, MemRefType,
+             UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
+             TupleType, UnrankedTensorType, VectorType>();
     addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                   DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                   SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
@@ -95,11 +99,6 @@ struct BuiltinDialect : public Dialect {
     addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
                   UnknownLoc>();
 
-    addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
-             FunctionType, IndexType, IntegerType, MemRefType,
-             UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
-             TupleType, UnrankedTensorType, VectorType>();
-
     // TODO: These operations should be moved to a different dialect when they
     // have been fully decoupled from the core.
     addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
@@ -363,56 +362,50 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
 
   //// Types.
   /// Floating-point Types.
-  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
-  impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
-  impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
-  impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
+  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
+  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
+  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
+  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
   /// Index Type.
-  impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
+  impl->indexTy = TypeUniquer::get<IndexType>(this);
   /// Integer Types.
-  impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
-                                               IntegerType::Signless);
-  impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
-                                               IntegerType::Signless);
-  impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
-                                                16, IntegerType::Signless);
-  impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
-                                                32, IntegerType::Signless);
-  impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
-                                                64, IntegerType::Signless);
-  impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
-                                                 128, IntegerType::Signless);
+  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
+  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
+  impl->int16Ty =
+      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
+  impl->int32Ty =
+      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
+  impl->int64Ty =
+      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
+  impl->int128Ty =
+      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
   /// None Type.
-  impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
+  impl->noneType = TypeUniquer::get<NoneType>(this);
 
   //// Attributes.
   //// Note: These must be registered after the types as they may generate one
   //// of the above types internally.
   /// Bool Attributes.
   impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
-                        this, StandardAttributes::Integer, impl->int1Ty,
-                        APInt(/*numBits=*/1, false))
+                        this, impl->int1Ty, APInt(/*numBits=*/1, false))
                         .cast<BoolAttr>();
   impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
-                       this, StandardAttributes::Integer, impl->int1Ty,
-                       APInt(/*numBits=*/1, true))
+                       this, impl->int1Ty, APInt(/*numBits=*/1, true))
                        .cast<BoolAttr>();
   /// Unit Attribute.
-  impl->unitAttr =
-      AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
+  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
   /// Unknown Location Attribute.
-  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
-      this, StandardAttributes::UnknownLocation);
+  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
   /// The empty dictionary attribute.
-  impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
-      this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
+  impl->emptyDictionaryAttr =
+      AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
 
   // Register the affine storage objects with the uniquer.
-  impl->affineUniquer.registerStorageType(
-      TypeID::get<AffineBinaryOpExprStorage>());
-  impl->affineUniquer.registerStorageType(
-      TypeID::get<AffineConstantExprStorage>());
-  impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
+  impl->affineUniquer
+      .registerParametricStorageType<AffineBinaryOpExprStorage>();
+  impl->affineUniquer
+      .registerParametricStorageType<AffineConstantExprStorage>();
+  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
 }
 
 MLIRContext::~MLIRContext() {}
@@ -582,7 +575,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
           AbstractType(std::move(typeInfo));
   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
     llvm::report_fatal_error("Dialect Type already registered.");
-  impl.typeUniquer.registerStorageType(typeID);
 }
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
@@ -592,7 +584,6 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
           AbstractAttribute(std::move(attrInfo));
   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
     llvm::report_fatal_error("Dialect Attribute already registered.");
-  impl.attributeUniquer.registerStorageType(typeID);
 }
 
 /// Get the dialect that registered the attribute with the provided typeid.
@@ -718,7 +709,7 @@ IntegerType IntegerType::get(unsigned width,
                              MLIRContext *context) {
   if (auto cached = getCachedIntegerType(width, signedness, context))
     return cached;
-  return Base::get(context, StandardTypes::Integer, width, signedness);
+  return Base::get(context, width, signedness);
 }
 
 IntegerType IntegerType::getChecked(unsigned width, Location location) {
@@ -731,12 +722,16 @@ IntegerType IntegerType::getChecked(unsigned width,
   if (auto cached =
           getCachedIntegerType(width, signedness, location->getContext()))
     return cached;
-  return Base::getChecked(location, StandardTypes::Integer, width, signedness);
+  return Base::getChecked(location, width, signedness);
 }
 
 /// Get an instance of the NoneType.
 NoneType NoneType::get(MLIRContext *context) {
-  return context->getImpl().noneType;
+  if (NoneType cachedInst = context->getImpl().noneType)
+    return cachedInst;
+  // Note: May happen when initializing the singleton attributes of the builtin
+  // dialect.
+  return Base::get(context);
 }
 
 //===----------------------------------------------------------------------===//
index f075324..8eb9025 100644 (file)
@@ -102,12 +102,11 @@ unsigned Type::getIntOrFloatBitWidth() {
 //===----------------------------------------------------------------------===//
 
 ComplexType ComplexType::get(Type elementType) {
-  return Base::get(elementType.getContext(), StandardTypes::Complex,
-                   elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 ComplexType ComplexType::getChecked(Type elementType, Location location) {
-  return Base::getChecked(location, StandardTypes::Complex, elementType);
+  return Base::getChecked(location, elementType);
 }
 
 /// Verify the construction of an integer type.
@@ -265,13 +264,12 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
 //===----------------------------------------------------------------------===//
 
 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
-  return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
-                   elementType);
+  return Base::get(elementType.getContext(), shape, elementType);
 }
 
 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                   Location location) {
-  return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
+  return Base::getChecked(location, shape, elementType);
 }
 
 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
@@ -320,15 +318,13 @@ bool TensorType::isValidElementType(Type type) {
 
 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                        Type elementType) {
-  return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
-                   elementType);
+  return Base::get(elementType.getContext(), shape, elementType);
 }
 
 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                               Type elementType,
                                               Location location) {
-  return Base::getChecked(location, StandardTypes::RankedTensor, shape,
-                          elementType);
+  return Base::getChecked(location, shape, elementType);
 }
 
 LogicalResult RankedTensorType::verifyConstructionInvariants(
@@ -349,13 +345,12 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
 //===----------------------------------------------------------------------===//
 
 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
-  return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
-                   elementType);
+  return Base::get(elementType.getContext(), elementType);
 }
 
 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                   Location location) {
-  return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
+  return Base::getChecked(location, elementType);
 }
 
 LogicalResult
@@ -444,8 +439,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
     cleanedAffineMapComposition.push_back(map);
   }
 
-  return Base::get(context, StandardTypes::MemRef, shape, elementType,
-                   cleanedAffineMapComposition, memorySpace);
+  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
+                   memorySpace);
 }
 
 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
@@ -462,15 +457,13 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
 
 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                            unsigned memorySpace) {
-  return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
-                   elementType, memorySpace);
+  return Base::get(elementType.getContext(), elementType, memorySpace);
 }
 
 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                   unsigned memorySpace,
                                                   Location location) {
-  return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
-                          memorySpace);
+  return Base::getChecked(location, elementType, memorySpace);
 }
 
 unsigned UnrankedMemRefType::getMemorySpace() const {
@@ -642,7 +635,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
 /// Get or create a new TupleType with the provided element types. Assumes the
 /// arguments define a well-formed type.
 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
-  return Base::get(context, StandardTypes::Tuple, elementTypes);
+  return Base::get(context, elementTypes);
 }
 
 /// Get or create an empty tuple type.
index ae2dd90..cdcd6a9 100644 (file)
@@ -19,8 +19,6 @@ using namespace mlir::detail;
 // Type
 //===----------------------------------------------------------------------===//
 
-unsigned Type::getKind() const { return impl->getKind(); }
-
 Dialect &Type::getDialect() const {
   return impl->getAbstractType().getDialect();
 }
@@ -33,7 +31,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
                                MLIRContext *context) {
-  return Base::get(context, Type::Kind::Function, inputs, results);
+  return Base::get(context, inputs, results);
 }
 
 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
@@ -54,12 +52,12 @@ ArrayRef<Type> FunctionType::getResults() const {
 
 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
                            MLIRContext *context) {
-  return Base::get(context, Type::Kind::Opaque, dialect, typeData);
+  return Base::get(context, dialect, typeData);
 }
 
 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
                                   MLIRContext *context, Location location) {
-  return Base::getChecked(location, Kind::Opaque, dialect, typeData);
+  return Base::getChecked(location, dialect, typeData);
 }
 
 /// Returns the dialect namespace of the opaque type.
index 49e7272..73578b5 100644 (file)
@@ -16,19 +16,17 @@ using namespace mlir;
 using namespace mlir::detail;
 
 namespace {
-/// This class represents a uniquer for storage instances of a specific type. It
-/// contains all of the necessary data to unique storage instances in a thread
-/// safe way. This allows for the main uniquer to bucket each of the individual
-/// sub-types removing the need to lock the main uniquer itself.
-struct InstSpecificUniquer {
+/// This class represents a uniquer for storage instances of a specific type
+/// that has parametric storage. It contains all of the necessary data to unique
+/// storage instances in a thread safe way. This allows for the main uniquer to
+/// bucket each of the individual sub-types removing the need to lock the main
+/// uniquer itself.
+struct ParametricStorageUniquer {
   using BaseStorage = StorageUniquer::BaseStorage;
   using StorageAllocator = StorageUniquer::StorageAllocator;
 
   /// A lookup key for derived instances of storage objects.
   struct LookupKey {
-    /// The known derived kind for the storage.
-    unsigned kind;
-
     /// The known hash value of the key.
     unsigned hashValue;
 
@@ -63,18 +61,14 @@ struct InstSpecificUniquer {
     static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
       if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
         return false;
-      // If the lookup kind matches the kind of the storage, then invoke the
-      // equality function on the lookup key.
-      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+      // Invoke the equality function on the lookup key.
+      return lhs.isEqual(rhs.storage);
     }
   };
 
-  /// Unique types with specific hashing or storage constraints.
+  /// The set containing the allocated storage instances.
   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
-  StorageTypeSet complexInstances;
-
-  /// Instances of this storage object.
-  llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
+  StorageTypeSet instances;
 
   /// Allocator to use when constructing derived instances.
   StorageAllocator allocator;
@@ -91,107 +85,79 @@ struct StorageUniquerImpl {
   using BaseStorage = StorageUniquer::BaseStorage;
   using StorageAllocator = StorageUniquer::StorageAllocator;
 
-  /// Get or create an instance of a complex derived type.
+  //===--------------------------------------------------------------------===//
+  // Parametric Storage
+  //===--------------------------------------------------------------------===//
+
+  /// Get or create an instance of a parametric type.
   BaseStorage *
-  getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
+  getOrCreate(TypeID id, unsigned hashValue,
               function_ref<bool(const BaseStorage *)> isEqual,
               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    assert(instUniquers.count(id) && "creating unregistered storage instance");
-    InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
-    InstSpecificUniquer &storageUniquer = *instUniquers[id];
+    assert(parametricUniquers.count(id) &&
+           "creating unregistered storage instance");
+    ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
+    ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
     if (!threadingIsEnabled)
-      return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+      return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
 
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
-      auto it = storageUniquer.complexInstances.find_as(lookupKey);
-      if (it != storageUniquer.complexInstances.end())
+      auto it = storageUniquer.instances.find_as(lookupKey);
+      if (it != storageUniquer.instances.end())
         return it->storage;
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+    return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
   }
   /// Get or create an instance of a complex derived type in an thread-unsafe
   /// fashion.
   BaseStorage *
-  getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
-                    InstSpecificUniquer::LookupKey &lookupKey,
+  getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
+                    ParametricStorageUniquer::LookupKey &lookupKey,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
+    auto existing = storageUniquer.instances.insert_as({}, lookupKey);
     if (!existing.second)
       return existing.first->storage;
 
     // Otherwise, construct and initialize the derived storage for this type
     // instance.
-    BaseStorage *storage =
-        initializeStorage(kind, storageUniquer.allocator, ctorFn);
+    BaseStorage *storage = ctorFn(storageUniquer.allocator);
     *existing.first =
-        InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
+        ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
     return storage;
   }
 
-  /// Get or create an instance of a simple derived type.
-  BaseStorage *
-  getOrCreate(TypeID id, unsigned kind,
-              function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    assert(instUniquers.count(id) && "creating unregistered storage instance");
-    InstSpecificUniquer &storageUniquer = *instUniquers[id];
-    if (!threadingIsEnabled)
-      return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
-
-    // Check for an existing instance in read-only mode.
-    {
-      llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
-      auto it = storageUniquer.simpleInstances.find(kind);
-      if (it != storageUniquer.simpleInstances.end())
-        return it->second;
-    }
-
-    // Acquire a writer-lock so that we can safely create the new type instance.
-    llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
-  }
-  /// Get or create an instance of a simple derived type in an thread-unsafe
-  /// fashion.
-  BaseStorage *
-  getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
-                    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto &result = storageUniquer.simpleInstances[kind];
-    if (result)
-      return result;
-
-    // Otherwise, create and return a new storage instance.
-    return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
-  }
-
-  /// Erase an instance of a complex derived type.
-  void erase(TypeID id, unsigned kind, unsigned hashValue,
+  /// Erase an instance of a parametric derived type.
+  void erase(TypeID id, unsigned hashValue,
              function_ref<bool(const BaseStorage *)> isEqual,
              function_ref<void(BaseStorage *)> cleanupFn) {
-    assert(instUniquers.count(id) && "erasing unregistered storage instance");
-    InstSpecificUniquer &storageUniquer = *instUniquers[id];
-    InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
+    assert(parametricUniquers.count(id) &&
+           "erasing unregistered storage instance");
+    ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
+    ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
 
     // Acquire a writer-lock so that we can safely erase the type instance.
     llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
-    auto existing = storageUniquer.complexInstances.find_as(lookupKey);
-    if (existing == storageUniquer.complexInstances.end())
+    auto existing = storageUniquer.instances.find_as(lookupKey);
+    if (existing == storageUniquer.instances.end())
       return;
 
     // Cleanup the storage and remove it from the map.
     cleanupFn(existing->storage);
-    storageUniquer.complexInstances.erase(existing);
+    storageUniquer.instances.erase(existing);
   }
 
   /// Mutates an instance of a derived storage in a thread-safe way.
   LogicalResult
   mutate(TypeID id,
          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
-    assert(instUniquers.count(id) && "mutating unregistered storage instance");
-    InstSpecificUniquer &storageUniquer = *instUniquers[id];
+    assert(parametricUniquers.count(id) &&
+           "mutating unregistered storage instance");
+    ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
     if (!threadingIsEnabled)
       return mutationFn(storageUniquer.allocator);
 
@@ -200,20 +166,30 @@ struct StorageUniquerImpl {
   }
 
   //===--------------------------------------------------------------------===//
-  // Instance Storage
+  // Singleton Storage
   //===--------------------------------------------------------------------===//
 
-  /// Utility to create and initialize a storage instance.
-  BaseStorage *
-  initializeStorage(unsigned kind, StorageAllocator &allocator,
-                    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    BaseStorage *storage = ctorFn(allocator);
-    storage->kind = kind;
-    return storage;
+  /// Get or create an instance of a singleton storage class.
+  BaseStorage *getSingleton(TypeID id) {
+    BaseStorage *singletonInstance = singletonInstances[id];
+    assert(singletonInstance && "expected singleton instance to exist");
+    return singletonInstance;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Instance Storage
+  //===--------------------------------------------------------------------===//
+
   /// Map of type ids to the storage uniquer to use for registered objects.
-  DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
+  DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
+      parametricUniquers;
+
+  /// Map of type ids to a singleton instance when the storage class is a
+  /// singleton.
+  DenseMap<TypeID, BaseStorage *> singletonInstances;
+
+  /// Allocator used for uniquing singleton instances.
+  StorageAllocator singletonAllocator;
 
   /// Flag specifying if multi-threading is enabled within the uniquer.
   bool threadingIsEnabled = true;
@@ -229,41 +205,47 @@ void StorageUniquer::disableMultithreading(bool disable) {
   impl->threadingIsEnabled = !disable;
 }
 
-/// Register a new storage object with this uniquer using the given unique type
-/// id.
-void StorageUniquer::registerStorageType(TypeID id) {
-  impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
-}
-
 /// Implementation for getting/creating an instance of a derived type with
-/// complex storage.
-auto StorageUniquer::getImpl(
-    const TypeID &id, unsigned kind, unsigned hashValue,
+/// parametric storage.
+auto StorageUniquer::getParametricStorageTypeImpl(
+    TypeID id, unsigned hashValue,
     function_ref<bool(const BaseStorage *)> isEqual,
     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
-  return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
+  return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
 }
 
-/// Implementation for getting/creating an instance of a derived type with
-/// default storage.
-auto StorageUniquer::getImpl(
-    const TypeID &id, unsigned kind,
-    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
-  return impl->getOrCreate(id, kind, ctorFn);
+/// Implementation for registering an instance of a derived type with
+/// parametric storage.
+void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
+  impl->parametricUniquers.try_emplace(
+      id, std::make_unique<ParametricStorageUniquer>());
+}
+
+/// Implementation for getting an instance of a derived type with default
+/// storage.
+auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
+  return impl->getSingleton(id);
+}
+
+/// Implementation for registering an instance of a derived type with default
+/// storage.
+void StorageUniquer::registerSingletonImpl(
+    TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+  assert(!impl->singletonInstances.count(id) &&
+         "storage class already registered");
+  impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
 }
 
-/// Implementation for erasing an instance of a derived type with complex
+/// Implementation for erasing an instance of a derived type with parametric
 /// storage.
-void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
-                               unsigned hashValue,
+void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
                                function_ref<bool(const BaseStorage *)> isEqual,
                                function_ref<void(BaseStorage *)> cleanupFn) {
-  impl->erase(id, kind, hashValue, isEqual, cleanupFn);
+  impl->erase(id, hashValue, isEqual, cleanupFn);
 }
 
 /// Implementation for mutating an instance of a derived storage.
 LogicalResult StorageUniquer::mutateImpl(
-    const TypeID &id,
-    function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+    TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
   return impl->mutate(id, mutationFn);
 }
index c873a00..7bea72d 100644 (file)
@@ -156,7 +156,7 @@ static Type parseTestType(DialectAsmParser &parser,
   StringRef name;
   if (parser.parseLess() || parser.parseKeyword(&name))
     return Type();
-  auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
+  auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
 
   // If this type already has been parsed above in the stack, expect just the
   // name.
index 1df1655..c7fd80e 100644 (file)
@@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
                                         TestTypeInterface::Trait> {
   using Base::Base;
 
-  static TestType get(MLIRContext *context) {
-    return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
-  }
-
   /// Provide a definition for the necessary interface methods.
   void printTypeC(Location loc) const {
     emitRemark(loc) << *this << " - TestC";
@@ -72,9 +68,8 @@ class TestRecursiveType
 public:
   using Base::Base;
 
-  static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
-    return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
-                     name);
+  static TestRecursiveType get(MLIRContext *ctx, StringRef name) {
+    return Base::get(ctx, name);
   }
 
   /// Body getter and setter.
index f62c06e..37e322c 100644 (file)
@@ -41,7 +41,7 @@ struct TestRecursiveTypesPass
 LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
   MLIRContext *ctx = &getContext();
   FuncOp func = getFunction();
-  auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+  auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
   if (failed(type.setBody(type)))
     return func.emitError("expected to be able to set the type body");
 
@@ -56,7 +56,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
         "not expected to be able to change function body more than once");
 
   // Expecting to get the same type for the same name.
-  auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+  auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
   if (type != other)
     return func.emitError("expected type name to be the uniquing key");