[mlir][vulkan-runner] Use C-compatible wrapper emission.
authorDenis Khalikov <khalikov.denis@huawei.com>
Tue, 17 Mar 2020 11:49:00 +0000 (07:49 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 17 Mar 2020 11:54:41 +0000 (07:54 -0400)
A memref argument is converted into a pointer-to-struct argument
of type `{T*, T*, i64, i64[N], i64[N]}*` in the wrapper function,
where T is the converted element type and N is the memref rank.

Differential Revision: https://reviews.llvm.org/D76059

mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

index fcfae45..26833fd 100644 (file)
@@ -33,10 +33,10 @@ static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
 
-// A pass to convert gpu launch op to vulkan launch call op, by creating a
-// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
-// function and attaching binary data and entry point name as an attributes to
-// created vulkan launch call op.
+/// A pass to convert gpu launch op to vulkan launch call op, by creating a
+/// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
+/// function and attaching binary data and entry point name as an attributes to
+/// created vulkan launch call op.
 class ConvertGpuLaunchFuncToVulkanLaunchFunc
     : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
 public:
index 92c376e..f1dc52e 100644 (file)
@@ -15,6 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -26,7 +27,9 @@
 
 using namespace mlir;
 
-static constexpr const char *kBindResource = "bindResource";
+static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
+static constexpr const char *kCInterfaceVulkanLaunch =
+    "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
 static constexpr const char *kRunOnVulkan = "runOnVulkan";
 static constexpr const char *kInitVulkan = "initVulkan";
@@ -40,11 +43,11 @@ static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
 
-/// A pass to convert vulkan launch func into a sequence of Vulkan
+/// A pass to convert vulkan launch call op into a sequence of Vulkan
 /// runtime calls in the following order:
 ///
 /// * initVulkan           -- initializes vulkan runtime
-/// * bindResource         -- binds resource
+/// * bindMemRef           -- binds memref
 /// * setBinaryShader      -- sets the binary shader data
 /// * setEntryPoint        -- sets the entry point name
 /// * setNumWorkGroups     -- sets the number of a local workgroups
@@ -67,6 +70,29 @@ private:
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    initializeMemRefTypes();
+  }
+
+  void initializeMemRefTypes() {
+    // According to the MLIR doc memref argument is converted into a
+    // pointer-to-struct argument of type:
+    // template <typename Elem, size_t Rank>
+    // struct {
+    //   Elem *allocated;
+    //   Elem *aligned;
+    //   int64_t offset;
+    //   int64_t sizes[Rank]; // omitted when rank == 0
+    //   int64_t strides[Rank]; // omitted when rank == 0
+    // };
+    auto llvmPtrToFloatType = getFloatType().getPointerTo();
+    auto llvmArrayOneElementSizeType =
+        LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
+
+    // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
+    llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
+        llvmDialect,
+        {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+         llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
   }
 
   LLVM::LLVMType getFloatType() { return llvmFloatType; }
@@ -74,6 +100,7 @@ private:
   LLVM::LLVMType getPointerType() { return llvmPointerType; }
   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+  LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -85,16 +112,27 @@ private:
   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
-            callOp.getNumOperands() >= 6);
+            callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
+  }
+
+  /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
+  /// op.
+  bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
+    return (callOp.callee() &&
+            callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
+            callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
   }
 
   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
   /// runtime calls.
   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
 
-  /// Creates call to `bindResource` for each resource operand.
-  void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp,
-                               Value vulkanRuntiem);
+  /// Creates call to `bindMemRef` for each memref operand.
+  void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
+                             Value vulkanRuntime);
+
+  /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
+  void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
 public:
   void runOnModule() override;
@@ -106,89 +144,81 @@ private:
   LLVM::LLVMType llvmPointerType;
   LLVM::LLVMType llvmInt32Type;
   LLVM::LLVMType llvmInt64Type;
-};
-
-/// Represents operand adaptor for vulkan launch call operation, to simplify an
-/// access to the lowered memref.
-// TODO: We should use 'emit-c-wrappers' option to lower memref type:
-// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission.
-struct VulkanLaunchOpOperandAdaptor {
-  VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; }
-  VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete;
-  VulkanLaunchOpOperandAdaptor
-  operator=(const VulkanLaunchOpOperandAdaptor &) = delete;
-
-  /// Returns a tuple with a pointer to the memory and the size for the index-th
-  /// resource.
-  std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) {
-    assert(index < getResourceCount1D());
-    // 1D memref calling convention according to "ConversionToLLVMDialect.md":
-    // 0. Allocated pointer.
-    // 1. Aligned pointer.
-    // 2. Offset.
-    // 3. Size in dim 0.
-    // 4. Stride in dim 0.
-    auto offset = numConfigOps + index * loweredMemRefNumOps1D;
-    return std::make_tuple(operands[offset], operands[offset + 3]);
-  }
+  LLVM::LLVMType llvmMemRef1DFloat;
 
-  /// Returns the number of resources assuming all operands lowered from
-  /// 1D memref.
-  uint32_t getResourceCount1D() {
-    return (operands.size() - numConfigOps) / loweredMemRefNumOps1D;
-  }
-
-private:
-  /// The number of operands of lowered 1D memref.
-  static constexpr const uint32_t loweredMemRefNumOps1D = 5;
-  /// The number of the first config operands.
-  static constexpr const uint32_t numConfigOps = 6;
-  ArrayRef<Value> operands;
+  // TODO: Use an associative array to support multiple vulkan launch calls.
+  std::pair<StringAttr, StringAttr> spirvAttributes;
 };
 
 } // anonymous namespace
 
 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
   initializeCachedTypes();
+
+  // Collect SPIR-V attributes such as `spirv_blob` and
+  // `spirv_entry_point_name`.
   getModule().walk([this](LLVM::CallOp op) {
     if (isVulkanLaunchCallOp(op))
+      collectSPIRVAttributes(op);
+  });
+
+  // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
+  getModule().walk([this](LLVM::CallOp op) {
+    if (isCInterfaceVulkanLaunchCallOp(op))
       translateVulkanLaunchCall(op);
   });
 }
 
-void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls(
-    LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) {
-  if (vulkanLaunchCallOp.getNumOperands() == 6)
+void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
+    LLVM::CallOp vulkanLaunchCallOp) {
+  // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
+  // for the given vulkan launch call.
+  auto spirvBlobAttr =
+      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
+  if (!spirvBlobAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVBlobAttrName << " attribute";
+    return signalPassFailure();
+  }
+
+  auto spirvEntryPointNameAttr =
+      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
+  if (!spirvEntryPointNameAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVEntryPointAttrName << " attribute";
+    return signalPassFailure();
+  }
+
+  spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+}
+
+void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
+    LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
+  if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
+      gpu::LaunchOp::kNumConfigOperands)
     return;
-  OpBuilder builder(vulkanLaunchCallOp);
-  Location loc = vulkanLaunchCallOp.getLoc();
+  OpBuilder builder(cInterfaceVulkanLaunchCallOp);
+  Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
 
   // Create LLVM constant for the descriptor set index.
-  // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV`
+  // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
   // pass does.
   Value descriptorSet = builder.create<LLVM::ConstantOp>(
       loc, getInt32Type(), builder.getI32IntegerAttr(0));
 
-  auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()};
-  VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands);
-
-  for (auto resourceIdx :
-       llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) {
+  for (auto en :
+       llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
+           gpu::LaunchOp::kNumConfigOperands))) {
     // Create LLVM constant for the descriptor binding index.
     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
-        loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx));
-    // Get a pointer to the memory and size of that memory.
-    auto resourceDescriptor =
-        vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx);
-    // Create call to `bindResource`.
+        loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
+    // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, ArrayRef<Type>{getVoidType()},
-        builder.getSymbolRefAttr(kBindResource),
+        // TODO: Add support for memref with other ranks.
+        builder.getSymbolRefAttr(kBindMemRef1DFloat),
         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
-                        // Pointer to the memory.
-                        std::get<0>(resourceDescriptor),
-                        // Size of the memory.
-                        std::get<1>(resourceDescriptor)});
+                        en.value()});
   }
 }
 
@@ -228,14 +258,14 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
                                       /*isVarArg=*/false));
   }
 
-  if (!module.lookupSymbol(kBindResource)) {
+  if (!module.lookupSymbol(kBindMemRef1DFloat)) {
     builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindResource,
-        LLVM::LLVMType::getFunctionTy(
-            getVoidType(),
-            {getPointerType(), getInt32Type(), getInt32Type(),
-             getFloatType().getPointerTo(), getInt64Type()},
-            /*isVarArg=*/false));
+        loc, kBindMemRef1DFloat,
+        LLVM::LLVMType::getFunctionTy(getVoidType(),
+                                      {getPointerType(), getInt32Type(),
+                                       getInt32Type(),
+                                       getMemRef1DFloat().getPointerTo()},
+                                      /*isVarArg=*/false));
   }
 
   if (!module.lookupSymbol(kInitVulkan)) {
@@ -267,28 +297,9 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
-    LLVM::CallOp vulkanLaunchCallOp) {
-  OpBuilder builder(vulkanLaunchCallOp);
-  Location loc = vulkanLaunchCallOp.getLoc();
-
-  // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
-  // for the given vulkan launch call.
-  auto spirvBlobAttr =
-      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
-  if (!spirvBlobAttr) {
-    vulkanLaunchCallOp.emitError()
-        << "missing " << kSPIRVBlobAttrName << " attribute";
-    return signalPassFailure();
-  }
-
-  auto entryPointNameAttr =
-      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
-  if (!entryPointNameAttr) {
-    vulkanLaunchCallOp.emitError()
-        << "missing " << kSPIRVEntryPointAttrName << " attribute";
-    return signalPassFailure();
-  }
-
+    LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
+  OpBuilder builder(cInterfaceVulkanLaunchCallOp);
+  Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
   // Create call to `initVulkan`.
   auto initVulkanCall = builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getPointerType()},
@@ -300,16 +311,16 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
-      loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(),
+      loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
       LLVM::Linkage::Internal, getLLVMDialect());
 
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
       loc, getInt32Type(),
-      builder.getI32IntegerAttr(spirvBlobAttr.getValue().size()));
+      builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
 
-  // Create call to `bindResource` for each resource operand.
-  createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime);
+  // Create call to `bindMemRef` for each memref operand.
+  createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
 
   // Create call to `setBinaryShader` runtime function with the given pointer to
   // SPIR-V binary and binary size.
@@ -318,8 +329,8 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
       builder.getSymbolRefAttr(kSetBinaryShader),
       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
   // Create LLVM global with entry point name.
-  Value entryPointName =
-      createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder);
+  Value entryPointName = createEntryPointNameConstant(
+      spirvAttributes.second.getValue(), loc, builder);
   // Create call to `setEntryPoint` runtime function with the given pointer to
   // entry point name.
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
@@ -330,9 +341,9 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getVoidType()},
       builder.getSymbolRefAttr(kSetNumWorkGroups),
-      ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0),
-                      vulkanLaunchCallOp.getOperand(1),
-                      vulkanLaunchCallOp.getOperand(2)});
+      ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
+                      cInterfaceVulkanLaunchCallOp.getOperand(1),
+                      cInterfaceVulkanLaunchCallOp.getOperand(2)});
 
   // Create call to `runOnVulkan` runtime function.
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
@@ -347,7 +358,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // Declare runtime functions.
   declareVulkanFunctions(loc);
 
-  vulkanLaunchCallOp.erase();
+  cInterfaceVulkanLaunchCallOp.erase();
 }
 
 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
index 060e2b3..4313eb5 100644 (file)
@@ -6,7 +6,7 @@
 // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
 // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
 // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindResource(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"float*">, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> !llvm.void
 // CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32) -> !llvm.void
 // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
 // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
@@ -44,5 +44,18 @@ module attributes {gpu.container_module} {
     : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
     llvm.return
   }
-  llvm.func @vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
+  llvm.func @vulkanLaunch(%arg0: !llvm.i64, %arg1: !llvm.i64, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64, %arg6: !llvm<"float*">, %arg7: !llvm<"float*">, %arg8: !llvm.i64, %arg9: !llvm.i64, %arg10: !llvm.i64) {
+    %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %1 = llvm.insertvalue %arg6, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %2 = llvm.insertvalue %arg7, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %3 = llvm.insertvalue %arg8, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %6 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %7 = llvm.alloca %6 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+    llvm.store %5, %7 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %7) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
 }
index 91f2340..9c63714 100644 (file)
@@ -22,7 +22,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/ToolOutputFile.h"
 
-#include <vulkan/vulkan.h> // NOLINT
+#include <vulkan/vulkan.h>
 
 using namespace mlir;
 
index f91bc71..10987f6 100644 (file)
@@ -40,7 +40,9 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
-  passManager.addPass(createLowerToLLVMPass());
+  passManager.addPass(createLowerToLLVMPass(/*useAlloca=*/false,
+                                            /*useBarePtrCallConv=*/false,
+                                            /*emitCWrappers=*/true));
   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
   return passManager.run(module);
 }
index eb9a682..52c11da 100644 (file)
@@ -62,34 +62,28 @@ private:
 
 } // namespace
 
+template <typename T, int N>
+struct MemRefDescriptor {
+  T *allocated;
+  T *aligned;
+  int64_t offset;
+  int64_t sizes[N];
+  int64_t strides[N];
+};
+
 extern "C" {
-// Initializes `VulkanRuntimeManager` and returns a pointer to it.
+/// Initializes `VulkanRuntimeManager` and returns a pointer to it.
 void *initVulkan() { return new VulkanRuntimeManager(); }
 
-// Deinitializes `VulkanRuntimeManager` by the given pointer.
+/// Deinitializes `VulkanRuntimeManager` by the given pointer.
 void deinitVulkan(void *vkRuntimeManager) {
   delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
 }
 
-/// Binds the given memref to the given descriptor set and descriptor index.
-void bindResource(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                  BindingIndex bindIndex, float *ptr, int64_t size) {
-  VulkanHostMemoryBuffer memBuffer{ptr,
-                                   static_cast<uint32_t>(size * sizeof(float))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
 void runOnVulkan(void *vkRuntimeManager) {
   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
 }
 
-/// Fills the given 1D float memref with the given float value.
-void fillResource1DFloat(float *allocated, float *aligned, int64_t offset,
-                         int64_t size, int64_t stride, float value) {
-  std::fill_n(allocated, size, value);
-}
-
 void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
       ->setEntryPoint(entryPoint);
@@ -105,4 +99,21 @@ void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
       ->setShaderModule(shader, size);
 }
+
+/// Binds the given 1D float memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                       BindingIndex bindIndex,
+                       MemRefDescriptor<float, 1> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(float))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Fills the given 1D float memref with the given float value.
+void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
+                                      float value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0], value);
+}
 }