[mlir][TableGen] Support intrinsics with multiple returns and overloaded operands.
authorJi Kim <jikimjikim@google.com>
Thu, 19 Nov 2020 08:54:31 +0000 (09:54 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 19 Nov 2020 08:59:42 +0000 (09:59 +0100)
For intrinsics with multiple returns where one or more operands are overloaded, the overloaded type is inferred from the corresponding field of the resulting struct, instead of accessing the result directly.

As such, the hasResult parameter of LLVM_IntrOpBase (and derived classes) is replaced with numResults. TableGen for intrinsics also updated to populate this field with the total number of results.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D91680

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/test/Target/llvmir-intrinsics.mlir
mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp

index 734dd45..124adb7 100644 (file)
@@ -226,36 +226,46 @@ def LLVM_IntrPatterns {
     [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
   string result =
     [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
+  string structResult =
+    [{convertType(opInst.getResult(0).getType().cast<LLVM::LLVMStructType>()
+                                                   .getBody()[$0])}];
 }
 
 
 // Base class for LLVM intrinsics operation. It is similar to LLVM_Op, but
-// provides the "llvmBuilder" field for constructing the intrinsic. The builder
-// relies on the contents on "overloadedResults" and "overloadedOperands" lists
-// that contain the positions of intrinsic results and operands that are
-// overloadable in the LLVM sense, that is their types must be passed in during
-// the construction of the intrinsic declaration to differentiate between
-// differently-typed versions of the intrinsic. "opName" contains the name of
-// the operation to be associated with the intrinsic and "enumName" contains the
-// name of the intrinsic as appears in `llvm::Intrinsic` enum; one usually wants
-// these to be related.
+// provides the "llvmBuilder" field for constructing the intrinsic.
+// The builder relies on the contents of "overloadedResults" and
+// "overloadedOperands" lists that contain the positions of intrinsic results
+// and operands that are overloadable in the LLVM sense, that is their types
+// must be passed in during the construction of the intrinsic declaration to
+// differentiate between differently-typed versions of the intrinsic.
+// If the intrinsic has multiple results, this will eventually be packed into a
+// single struct result. In this case, the types of any overloaded results need
+// to be accessed via the LLVMStructType, instead of directly via the result.
+// "opName" contains the name of the operation to be associated with the
+// intrinsic and "enumName" contains the name of the intrinsic as appears in
+// `llvm::Intrinsic` enum; one usually wants these to be related.
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
-                      list<OpTrait> traits, bit hasResult>
+                      list<OpTrait> traits, int numResults>
     : LLVM_OpBase<dialect, opName, traits>,
-      Results<!if(hasResult, (outs LLVM_Type:$res), (outs))> {
+      Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
+  string resultPattern = !if(!gt(numResults, 1),
+                             LLVM_IntrPatterns.structResult,
+                             LLVM_IntrPatterns.result);
   let llvmBuilder = [{
     llvm::Module *module = builder.GetInsertBlock()->getModule();
     llvm::Function *fn = llvm::Intrinsic::getDeclaration(
         module,
         llvm::Intrinsic::}] # enumName # [{,
         { }] # StrJoin<!listconcat(
-            ListIntSubst<LLVM_IntrPatterns.result, overloadedResults>.lst,
+            ListIntSubst<resultPattern, overloadedResults>.lst,
             ListIntSubst<LLVM_IntrPatterns.operand,
                          overloadedOperands>.lst)>.result # [{
         });
     auto operands = lookupValues(opInst.getOperands());
-    }] # !if(hasResult, "$res = ", "") # [{builder.CreateCall(fn, operands);
+    }] # !if(!gt(numResults, 0), "$res = ", "")
+       # [{builder.CreateCall(fn, operands);
   }];
 }
 
@@ -263,9 +273,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 // the intrinsic into the LLVM dialect and prefixes its name with "intr.".
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<OpTrait> traits,
-                  bit hasResult>
+                  int numResults>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
-                      overloadedResults, overloadedOperands, traits, hasResult>;
+                      overloadedResults, overloadedOperands, traits,
+                      numResults>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
index 075df49..542afaa 100644 (file)
@@ -890,6 +890,33 @@ def LLVM_MemcpyInlineOp : LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2]> {
                    LLVM_Type:$isVolatile);
 }
 
+// Intrinsics with multiple returns.
+
+def LLVM_SAddWithOverflowOp
+    : LLVM_IntrOp<"sadd.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+def LLVM_UAddWithOverflowOp
+    : LLVM_IntrOp<"uadd.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+def LLVM_SSubWithOverflowOp
+    : LLVM_IntrOp<"ssub.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+def LLVM_USubWithOverflowOp
+    : LLVM_IntrOp<"usub.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+def LLVM_SMulWithOverflowOp
+    : LLVM_IntrOp<"smul.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+def LLVM_UMulWithOverflowOp
+    : LLVM_IntrOp<"umul.with.overflow", [0], [], [], 2> {
+  let arguments = (ins LLVM_Type, LLVM_Type);
+}
+
 //
 // Vector Reductions.
 //
index 5f72ad3..96b111f 100644 (file)
@@ -40,9 +40,9 @@ class NVVM_Op<string mnemonic, list<OpTrait> traits = []> :
 
 class NVVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<OpTrait> traits,
-                  bit hasResult>
+                  int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, hasResult>;
+                    overloadedResults, overloadedOperands, traits, numResults>;
 
 
 //===----------------------------------------------------------------------===//
index ef1ed5a..333ad05 100644 (file)
@@ -293,6 +293,59 @@ llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>,
   llvm.return
 }
 
+// CHECK-LABEL: @sadd_with_overflow_test
+llvm.func @sadd_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.sadd.with.overflow.i32
+  "llvm.intr.sadd.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.sadd.with.overflow.v8i32
+  "llvm.intr.sadd.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
+
+// CHECK-LABEL: @uadd_with_overflow_test
+llvm.func @uadd_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.uadd.with.overflow.i32
+  "llvm.intr.uadd.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.uadd.with.overflow.v8i32
+  "llvm.intr.uadd.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
+
+// CHECK-LABEL: @ssub_with_overflow_test
+llvm.func @ssub_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.ssub.with.overflow.i32
+  "llvm.intr.ssub.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.ssub.with.overflow.v8i32
+  "llvm.intr.ssub.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
+
+// CHECK-LABEL: @usub_with_overflow_test
+llvm.func @usub_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.usub.with.overflow.i32
+  "llvm.intr.usub.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.usub.with.overflow.v8i32
+  "llvm.intr.usub.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
+
+// CHECK-LABEL: @smul_with_overflow_test
+llvm.func @smul_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.smul.with.overflow.i32
+  "llvm.intr.smul.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.smul.with.overflow.v8i32
+  "llvm.intr.smul.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
+
+// CHECK-LABEL: @umul_with_overflow_test
+llvm.func @umul_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) {
+  // CHECK: call { i32, i1 } @llvm.umul.with.overflow.i32
+  "llvm.intr.umul.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)>
+  // CHECK: call { <8 x i32>, <8 x i1> } @llvm.umul.with.overflow.v8i32
+  "llvm.intr.umul.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)>
+  llvm.return
+}
 
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
@@ -330,3 +383,13 @@ llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>,
 // CHECK-DAG: declare void @llvm.masked.compressstore.v7f32(<7 x float>, float*, <7 x i1>)
 // CHECK-DAG: declare void @llvm.memcpy.p0i8.p0i8.i32(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i32, i1 immarg)
 // CHECK-DAG: declare void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64 immarg, i1 immarg)
+// CHECK-DAG: declare { i32, i1 } @llvm.sadd.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.sadd.with.overflow.v8i32(<8 x i32>, <8 x i32>) #0
+// CHECK-DAG: declare { i32, i1 } @llvm.uadd.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.uadd.with.overflow.v8i32(<8 x i32>, <8 x i32>) #0
+// CHECK-DAG: declare { i32, i1 } @llvm.ssub.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.ssub.with.overflow.v8i32(<8 x i32>, <8 x i32>) #0
+// CHECK-DAG: declare { i32, i1 } @llvm.usub.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.usub.with.overflow.v8i32(<8 x i32>, <8 x i32>) #0
+// CHECK-DAG: declare { i32, i1 } @llvm.umul.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.umul.with.overflow.v8i32(<8 x i32>, <8 x i32>) #0
index d57a517..e697f65 100644 (file)
@@ -210,7 +210,7 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
   os << ", ";
   printBracketedRange(traits, os);
-  os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins"
+  os << ", " << intr.getNumResults() << ">, Arguments<(ins"
      << (operands.empty() ? "" : " ");
   llvm::interleaveComma(operands, os);
   os << ")>;\n\n";