//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
using namespace mlir;
-static constexpr const char *kBindResource = "bindResource";
+static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
+static constexpr const char *kCInterfaceVulkanLaunch =
+ "_mlir_ciface_vulkanLaunch";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kInitVulkan = "initVulkan";
namespace {
-/// A pass to convert vulkan launch func into a sequence of Vulkan
+/// A pass to convert vulkan launch call op into a sequence of Vulkan
/// runtime calls in the following order:
///
/// * initVulkan -- initializes vulkan runtime
-/// * bindResource -- binds resource
+/// * bindMemRef -- binds memref
/// * setBinaryShader -- sets the binary shader data
/// * setEntryPoint -- sets the entry point name
/// * setNumWorkGroups -- sets the number of a local workgroups
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ initializeMemRefTypes();
+ }
+
+ void initializeMemRefTypes() {
+ // According to the MLIR doc memref argument is converted into a
+ // pointer-to-struct argument of type:
+ // template <typename Elem, size_t Rank>
+ // struct {
+ // Elem *allocated;
+ // Elem *aligned;
+ // int64_t offset;
+ // int64_t sizes[Rank]; // omitted when rank == 0
+ // int64_t strides[Rank]; // omitted when rank == 0
+ // };
+ auto llvmPtrToFloatType = getFloatType().getPointerTo();
+ auto llvmArrayOneElementSizeType =
+ LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
+
+ // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
+ llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
+ llvmDialect,
+ {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+ llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
}
LLVM::LLVMType getFloatType() { return llvmFloatType; }
LLVM::LLVMType getPointerType() { return llvmPointerType; }
LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+ LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
/// Creates a LLVM global for the given `name`.
Value createEntryPointNameConstant(StringRef name, Location loc,
/// Checks whether the given LLVM::CallOp is a vulkan launch call op.
bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
- callOp.getNumOperands() >= 6);
+ callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
+ }
+
+ /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
+ /// op.
+ bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
+ return (callOp.callee() &&
+ callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
+ callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
}
/// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
/// runtime calls.
void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
- /// Creates call to `bindResource` for each resource operand.
- void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp,
- Value vulkanRuntiem);
+ /// Creates call to `bindMemRef` for each memref operand.
+ void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
+ Value vulkanRuntime);
+
+ /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
+ void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
public:
void runOnModule() override;
LLVM::LLVMType llvmPointerType;
LLVM::LLVMType llvmInt32Type;
LLVM::LLVMType llvmInt64Type;
-};
-
-/// Represents operand adaptor for vulkan launch call operation, to simplify an
-/// access to the lowered memref.
-// TODO: We should use 'emit-c-wrappers' option to lower memref type:
-// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission.
-struct VulkanLaunchOpOperandAdaptor {
- VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; }
- VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete;
- VulkanLaunchOpOperandAdaptor
- operator=(const VulkanLaunchOpOperandAdaptor &) = delete;
-
- /// Returns a tuple with a pointer to the memory and the size for the index-th
- /// resource.
- std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) {
- assert(index < getResourceCount1D());
- // 1D memref calling convention according to "ConversionToLLVMDialect.md":
- // 0. Allocated pointer.
- // 1. Aligned pointer.
- // 2. Offset.
- // 3. Size in dim 0.
- // 4. Stride in dim 0.
- auto offset = numConfigOps + index * loweredMemRefNumOps1D;
- return std::make_tuple(operands[offset], operands[offset + 3]);
- }
+ LLVM::LLVMType llvmMemRef1DFloat;
- /// Returns the number of resources assuming all operands lowered from
- /// 1D memref.
- uint32_t getResourceCount1D() {
- return (operands.size() - numConfigOps) / loweredMemRefNumOps1D;
- }
-
-private:
- /// The number of operands of lowered 1D memref.
- static constexpr const uint32_t loweredMemRefNumOps1D = 5;
- /// The number of the first config operands.
- static constexpr const uint32_t numConfigOps = 6;
- ArrayRef<Value> operands;
+ // TODO: Use an associative array to support multiple vulkan launch calls.
+ std::pair<StringAttr, StringAttr> spirvAttributes;
};
} // anonymous namespace
void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
initializeCachedTypes();
+
+ // Collect SPIR-V attributes such as `spirv_blob` and
+ // `spirv_entry_point_name`.
getModule().walk([this](LLVM::CallOp op) {
if (isVulkanLaunchCallOp(op))
+ collectSPIRVAttributes(op);
+ });
+
+ // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
+ getModule().walk([this](LLVM::CallOp op) {
+ if (isCInterfaceVulkanLaunchCallOp(op))
translateVulkanLaunchCall(op);
});
}
-void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls(
- LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) {
- if (vulkanLaunchCallOp.getNumOperands() == 6)
+void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
+ LLVM::CallOp vulkanLaunchCallOp) {
+ // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
+ // for the given vulkan launch call.
+ auto spirvBlobAttr =
+ vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
+ if (!spirvBlobAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVBlobAttrName << " attribute";
+ return signalPassFailure();
+ }
+
+ auto spirvEntryPointNameAttr =
+ vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
+ if (!spirvEntryPointNameAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVEntryPointAttrName << " attribute";
+ return signalPassFailure();
+ }
+
+ spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+}
+
+void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
+ LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
+ if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
+ gpu::LaunchOp::kNumConfigOperands)
return;
- OpBuilder builder(vulkanLaunchCallOp);
- Location loc = vulkanLaunchCallOp.getLoc();
+ OpBuilder builder(cInterfaceVulkanLaunchCallOp);
+ Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
// Create LLVM constant for the descriptor set index.
- // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV`
+ // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
// pass does.
Value descriptorSet = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(0));
- auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()};
- VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands);
-
- for (auto resourceIdx :
- llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) {
+ for (auto en :
+ llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
+ gpu::LaunchOp::kNumConfigOperands))) {
// Create LLVM constant for the descriptor binding index.
Value descriptorBinding = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx));
- // Get a pointer to the memory and size of that memory.
- auto resourceDescriptor =
- vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx);
- // Create call to `bindResource`.
+ loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
+ // Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getVoidType()},
- builder.getSymbolRefAttr(kBindResource),
+ // TODO: Add support for memref with other ranks.
+ builder.getSymbolRefAttr(kBindMemRef1DFloat),
ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
- // Pointer to the memory.
- std::get<0>(resourceDescriptor),
- // Size of the memory.
- std::get<1>(resourceDescriptor)});
+ en.value()});
}
}
/*isVarArg=*/false));
}
- if (!module.lookupSymbol(kBindResource)) {
+ if (!module.lookupSymbol(kBindMemRef1DFloat)) {
builder.create<LLVM::LLVMFuncOp>(
- loc, kBindResource,
- LLVM::LLVMType::getFunctionTy(
- getVoidType(),
- {getPointerType(), getInt32Type(), getInt32Type(),
- getFloatType().getPointerTo(), getInt64Type()},
- /*isVarArg=*/false));
+ loc, kBindMemRef1DFloat,
+ LLVM::LLVMType::getFunctionTy(getVoidType(),
+ {getPointerType(), getInt32Type(),
+ getInt32Type(),
+ getMemRef1DFloat().getPointerTo()},
+ /*isVarArg=*/false));
}
if (!module.lookupSymbol(kInitVulkan)) {
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
- LLVM::CallOp vulkanLaunchCallOp) {
- OpBuilder builder(vulkanLaunchCallOp);
- Location loc = vulkanLaunchCallOp.getLoc();
-
- // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
- // for the given vulkan launch call.
- auto spirvBlobAttr =
- vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
- if (!spirvBlobAttr) {
- vulkanLaunchCallOp.emitError()
- << "missing " << kSPIRVBlobAttrName << " attribute";
- return signalPassFailure();
- }
-
- auto entryPointNameAttr =
- vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
- if (!entryPointNameAttr) {
- vulkanLaunchCallOp.emitError()
- << "missing " << kSPIRVEntryPointAttrName << " attribute";
- return signalPassFailure();
- }
-
+ LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
+ OpBuilder builder(cInterfaceVulkanLaunchCallOp);
+ Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
// Create call to `initVulkan`.
auto initVulkanCall = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
- loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(),
+ loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
LLVM::Linkage::Internal, getLLVMDialect());
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
- builder.getI32IntegerAttr(spirvBlobAttr.getValue().size()));
+ builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
- // Create call to `bindResource` for each resource operand.
- createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime);
+ // Create call to `bindMemRef` for each memref operand.
+ createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
// Create call to `setBinaryShader` runtime function with the given pointer to
// SPIR-V binary and binary size.
builder.getSymbolRefAttr(kSetBinaryShader),
ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
- Value entryPointName =
- createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder);
+ Value entryPointName = createEntryPointNameConstant(
+ spirvAttributes.second.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getVoidType()},
builder.getSymbolRefAttr(kSetNumWorkGroups),
- ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0),
- vulkanLaunchCallOp.getOperand(1),
- vulkanLaunchCallOp.getOperand(2)});
+ ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
+ cInterfaceVulkanLaunchCallOp.getOperand(1),
+ cInterfaceVulkanLaunchCallOp.getOperand(2)});
// Create call to `runOnVulkan` runtime function.
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
// Declare runtime functions.
declareVulkanFunctions(loc);
- vulkanLaunchCallOp.erase();
+ cInterfaceVulkanLaunchCallOp.erase();
}
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindResource(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"float*">, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> !llvm.void
// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32) -> !llvm.void
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
: (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
llvm.return
}
- llvm.func @vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
+ llvm.func @vulkanLaunch(%arg0: !llvm.i64, %arg1: !llvm.i64, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64, %arg6: !llvm<"float*">, %arg7: !llvm<"float*">, %arg8: !llvm.i64, %arg9: !llvm.i64, %arg10: !llvm.i64) {
+ %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %1 = llvm.insertvalue %arg6, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %2 = llvm.insertvalue %arg7, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %3 = llvm.insertvalue %arg8, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %6 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %7 = llvm.alloca %6 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+ llvm.store %5, %7 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+ llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %7) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
+ llvm.return
+ }
+ llvm.func @_mlir_ciface_vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
}