[mlir] Add support for referencing a SymbolRefAttr in a SideEffectInstance
authorRiver Riddle <riddleriver@gmail.com>
Thu, 19 Nov 2020 02:31:40 +0000 (18:31 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 19 Nov 2020 02:38:43 +0000 (18:38 -0800)
This allows for operations that exclusively affect symbol operations to better describe their side effects.

Differential Revision: https://reviews.llvm.org/D91581

mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
mlir/include/mlir/Interfaces/SideEffectInterfaces.h
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/TableGen/Attribute.cpp
mlir/test/IR/test-side-effects.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/IR/TestSideEffects.cpp
mlir/test/mlir-tblgen/op-side-effects.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 41b07bc..89318f7 100644 (file)
@@ -110,9 +110,20 @@ class EffectOpInterfaceBase<string name, string baseEffect>
       llvm::erase_if(effects, [&](auto &it) { return it.getValue() != value; });
     }
 
+    /// Collect all of the effect instances that operate on the provided symbol
+    /// reference and place them in 'effects'.
+    void getEffectsOnSymbol(::mlir::SymbolRefAttr value,
+              llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
+              }] # baseEffect # [{>> & effects) {
+      getEffects(effects);
+      llvm::erase_if(effects, [&](auto &it) {
+        return it.getSymbolRef() != value;
+      });
+    }
+
     /// Collect all of the effect instances that operate on the provided
     /// resource and place them in 'effects'.
-    void getEffectsOnValue(::mlir::SideEffects::Resource *resource,
+    void getEffectsOnResource(::mlir::SideEffects::Resource *resource,
               llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
               }] # baseEffect # [{>> & effects) {
       getEffects(effects);
index c19f7f4..33a6ba6 100644 (file)
@@ -131,9 +131,9 @@ struct AutomaticAllocationScopeResource
 
 /// This class represents a specific instance of an effect. It contains the
 /// effect being applied, a resource that corresponds to where the effect is
-/// applied, an optional value (either operand, result, or region entry
-/// argument) that the effect is applied to, and an optional parameters
-/// attribute further specifying the details of the effect.
+/// applied, and an optional symbol reference or value(either operand, result,
+/// or region entry argument) that the effect is applied to, and an optional
+/// parameters attribute further specifying the details of the effect.
 template <typename EffectT> class EffectInstance {
 public:
   EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
@@ -141,6 +141,9 @@ public:
   EffectInstance(EffectT *effect, Value value,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value) {}
+  EffectInstance(EffectT *effect, SymbolRefAttr symbol,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(symbol) {}
   EffectInstance(EffectT *effect, Attribute parameters,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), parameters(parameters) {}
@@ -148,13 +151,23 @@ public:
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value),
         parameters(parameters) {}
+  EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(symbol),
+        parameters(parameters) {}
 
   /// Return the effect being applied.
   EffectT *getEffect() const { return effect; }
 
   /// Return the value the effect is applied on, or nullptr if there isn't a
   /// known value being affected.
-  Value getValue() const { return value; }
+  Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
+
+  /// Return the symbol reference the effect is applied on, or nullptr if there
+  /// isn't a known smbol being affected.
+  SymbolRefAttr getSymbolRef() const {
+    return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
+  }
 
   /// Return the resource that the effect applies to.
   Resource *getResource() const { return resource; }
@@ -169,8 +182,8 @@ private:
   /// The resource that the given value resides in.
   Resource *resource;
 
-  /// The value that the effect applies to. This is optionally null.
-  Value value;
+  /// The Symbol or Value that the effect applies to. This is optionally null.
+  PointerUnion<SymbolRefAttr, Value> value;
 
   /// Additional parameters of the effect instance. An attribute is used for
   /// type-safe structured storage and context-based uniquing. Concrete effects
index 4571ca8..dc6c969 100644 (file)
@@ -94,6 +94,10 @@ public:
   // of `TypeAttrBase`).
   bool isTypeAttr() const;
 
+  // Returns true if this attribute is a symbol reference attribute (i.e., a
+  // subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`).
+  bool isSymbolRefAttr() const;
+
   // Returns true if this attribute is an enum attribute (i.e., a subclass of
   // `EnumAttrInfo`)
   bool isEnumAttr() const;
index f34d9c0..3377ec9 100644 (file)
@@ -55,6 +55,13 @@ bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
 
 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
 
+bool Attribute::isSymbolRefAttr() const {
+  StringRef defName = def->getName();
+  if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
+    return true;
+  return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
+}
+
 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
 
 StringRef Attribute::getStorageType() const {
index ca2e32c..db55414 100644 (file)
   {effect="allocate", on_result, test_resource}
 ]} : () -> i32
 
+// expected-remark@+1 {{found an instance of 'read' on a symbol '@foo_ref', on resource '<Test>'}}
+"test.side_effect_op"() {effects = [
+  {effect="read", on_reference = @foo_ref, test_resource}
+]} : () -> i32
+
 // No _memory_ effects, but a parametric test effect.
 // expected-remark@+2 {{operation has no memory effects}}
 // expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}}
index e71fceb..e815ade 100644 (file)
@@ -744,17 +744,18 @@ void SideEffectOp::getEffects(
             .Case("read", MemoryEffects::Read::get())
             .Case("write", MemoryEffects::Write::get());
 
-    // Check for a result to affect.
-    Value value;
-    if (effectElement.get("on_result"))
-      value = getResult();
-
     // Check for a non-default resource to use.
     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
     if (effectElement.get("test_resource"))
       resource = TestResource::get();
 
-    effects.emplace_back(effect, value, resource);
+    // Check for a result to affect.
+    if (effectElement.get("on_result"))
+      effects.emplace_back(effect, getResult(), resource);
+    else if (Attribute ref = effectElement.get("on_reference"))
+      effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
+    else
+      effects.emplace_back(effect, resource);
   }
 }
 
index d9d6aed..114c7f2 100644 (file)
@@ -43,6 +43,8 @@ struct SideEffectsPass
 
         if (instance.getValue())
           diag << " on a value,";
+        else if (SymbolRefAttr symbolRef = instance.getSymbolRef())
+          diag << " on a symbol '" << symbolRef << "',";
 
         diag << " on resource '" << instance.getResource()->getName() << "'";
       }
index 6bae35a..9e97e90 100644 (file)
@@ -11,7 +11,12 @@ class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
 def CustomResource : Resource<"CustomResource">;
 
 def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
-  let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
+  let arguments = (ins
+    Arg<Variadic<AnyMemRef>, "", [MemRead]>,
+    Arg<SymbolRefAttr, "", [MemRead]>:$symbol,
+    Arg<FlatSymbolRefAttr, "", [MemWrite]>:$flat_symbol,
+    Arg<OptionalAttr<SymbolRefAttr>, "", [MemRead]>:$optional_symbol
+  );
   let results = (outs Res<AnyMemRef, "", [MemAlloc<CustomResource>]>);
 }
 
@@ -21,6 +26,10 @@ def SideEffectOpB : TEST_Op<"side_effect_op_b",
 // CHECK: void SideEffectOpA::getEffects
 // CHECK:   for (::mlir::Value value : getODSOperands(0))
 // CHECK:     effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   effects.emplace_back(MemoryEffects::Read::get(), symbol(), ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   effects.emplace_back(MemoryEffects::Write::get(), flat_symbol(), ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   if (auto symbolRef = optional_symbolAttr())
+// CHECK:     effects.emplace_back(MemoryEffects::Read::get(), symbolRef, ::mlir::SideEffects::DefaultResource::get());
 // CHECK:   for (::mlir::Value value : getODSResults(0))
 // CHECK:     effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
 
index 737c36f..65ae32f 100644 (file)
@@ -1627,12 +1627,12 @@ void OpEmitter::genOpInterfaceMethods() {
 }
 
 void OpEmitter::genSideEffectInterfaceMethods() {
-  enum EffectKind { Operand, Result, Static };
+  enum EffectKind { Operand, Result, Symbol, Static };
   struct EffectLocation {
     /// The effect applied.
     SideEffect effect;
 
-    /// The index if the kind is either operand or result.
+    /// The index if the kind is not static.
     unsigned index : 30;
 
     /// The kind of the location.
@@ -1661,17 +1661,29 @@ void OpEmitter::genSideEffectInterfaceMethods() {
       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                        /*index=*/0, EffectKind::Static});
   }
-  /// Operands.
+  /// Attributes and Operands.
   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
-    if (op.getArg(i).is<NamedTypeConstraint *>()) {
+    Argument arg = op.getArg(i);
+    if (arg.is<NamedTypeConstraint *>()) {
       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
       ++operandIt;
+      continue;
     }
+    const NamedAttribute *attr = arg.get<NamedAttribute *>();
+    if (attr->attr.getBaseAttr().isSymbolRefAttr())
+      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
   }
   /// Results.
   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
 
+  // The code used to add an effect instance.
+  // {0}: The effect class.
+  // {1}: Optional value or symbol reference.
+  // {1}: The resource class.
+  const char *addEffectCode =
+      "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
+
   for (auto &it : interfaceEffects) {
     // Generate the 'getEffects' method.
     std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
@@ -1684,19 +1696,30 @@ void OpEmitter::genSideEffectInterfaceMethods() {
 
     // Add effect instances for each of the locations marked on the operation.
     for (auto &location : it.second) {
-      if (location.kind != EffectKind::Static) {
+      StringRef effect = location.effect.getName();
+      StringRef resource = location.effect.getResource();
+      if (location.kind == EffectKind::Static) {
+        // A static instance has no attached value.
+        body << llvm::formatv(addEffectCode, effect, "", resource).str();
+      } else if (location.kind == EffectKind::Symbol) {
+        // A symbol reference requires adding the proper attribute.
+        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
+        if (attr->attr.isOptional()) {
+          body << "  if (auto symbolRef = " << attr->name << "Attr())\n  "
+               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
+                      .str();
+        } else {
+          body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
+                                resource)
+                      .str();
+        }
+      } else {
+        // Otherwise this is an operand/result, so we need to attach the Value.
         body << "  for (::mlir::Value value : getODS"
              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
-             << "(" << location.index << "))\n  ";
+             << "(" << location.index << "))\n  "
+             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
       }
-
-      body << "  effects.emplace_back(" << location.effect.getName()
-           << "::get()";
-
-      // If the effect isn't static, it has a specific value attached to it.
-      if (location.kind != EffectKind::Static)
-        body << ", value";
-      body << ", " << location.effect.getResource() << "::get());\n";
     }
   }
 }