[mlir] fix the types used during the generation of the kernel param array
[lldb.git] / mlir / lib / Conversion / GPUCommon / ConvertLaunchFuncToRuntimeCalls.cpp
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert gpu.launch_func op into a sequence of
10 // GPU runtime calls. As most of GPU runtimes does not have a stable published
11 // ABI, this pass uses a slim runtime layer that builds on top of the public
12 // API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17
18 #include "../PassDetail.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
20 #include "mlir/Dialect/GPU/GPUDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Function.h"
25 #include "mlir/IR/Module.h"
26 #include "mlir/IR/StandardTypes.h"
27
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/DerivedTypes.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/Support/Error.h"
34 #include "llvm/Support/FormatVariadic.h"
35
36 using namespace mlir;
37
38 static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
39
40 namespace {
41
42 class GpuToLLVMConversionPass
43     : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
44 public:
45   GpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
46     if (!gpuBinaryAnnotation.empty())
47       this->gpuBinaryAnnotation = gpuBinaryAnnotation.str();
48   }
49
50   // Run the dialect converter on the module.
51   void runOnOperation() override;
52 };
53
54 class FunctionCallBuilder {
55 public:
56   FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
57                       ArrayRef<LLVM::LLVMType> argumentTypes)
58       : functionName(functionName),
59         functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes,
60                                                    /*isVarArg=*/false)) {}
61   LLVM::CallOp create(Location loc, OpBuilder &builder,
62                       ArrayRef<Value> arguments) const;
63
64 private:
65   StringRef functionName;
66   LLVM::LLVMType functionType;
67 };
68
69 template <typename OpTy>
70 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
71 public:
72   explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
73       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
74
75 protected:
76   MLIRContext *context = &this->typeConverter.getContext();
77
78   LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
79   LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
80   LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo();
81   LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
82   LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
83   LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
84   LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy(
85       context, this->typeConverter.getPointerBitwidth(0));
86
87   FunctionCallBuilder moduleLoadCallBuilder = {
88       "mgpuModuleLoad",
89       llvmPointerType /* void *module */,
90       {llvmPointerType /* void *cubin */}};
91   FunctionCallBuilder moduleGetFunctionCallBuilder = {
92       "mgpuModuleGetFunction",
93       llvmPointerType /* void *function */,
94       {
95           llvmPointerType, /* void *module */
96           llvmPointerType  /* char *name   */
97       }};
98   FunctionCallBuilder launchKernelCallBuilder = {
99       "mgpuLaunchKernel",
100       llvmVoidType,
101       {
102           llvmPointerType,        /* void* f */
103           llvmIntPtrType,         /* intptr_t gridXDim */
104           llvmIntPtrType,         /* intptr_t gridyDim */
105           llvmIntPtrType,         /* intptr_t gridZDim */
106           llvmIntPtrType,         /* intptr_t blockXDim */
107           llvmIntPtrType,         /* intptr_t blockYDim */
108           llvmIntPtrType,         /* intptr_t blockZDim */
109           llvmInt32Type,          /* unsigned int sharedMemBytes */
110           llvmPointerType,        /* void *hstream */
111           llvmPointerPointerType, /* void **kernelParams */
112           llvmPointerPointerType  /* void **extra */
113       }};
114   FunctionCallBuilder streamCreateCallBuilder = {
115       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
116   FunctionCallBuilder streamSynchronizeCallBuilder = {
117       "mgpuStreamSynchronize",
118       llvmVoidType,
119       {llvmPointerType /* void *stream */}};
120   FunctionCallBuilder hostRegisterCallBuilder = {
121       "mgpuMemHostRegisterMemRef",
122       llvmVoidType,
123       {llvmIntPtrType /* intptr_t rank */,
124        llvmPointerType /* void *memrefDesc */,
125        llvmIntPtrType /* intptr_t elementSizeBytes */}};
126 };
127
128 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
129 /// call. Currently it supports CUDA and ROCm (HIP).
130 class ConvertHostRegisterOpToGpuRuntimeCallPattern
131     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
132 public:
133   ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
134       : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
135
136 private:
137   LogicalResult
138   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
139                   ConversionPatternRewriter &rewriter) const override;
140 };
141
142 /// A rewrite patter to convert gpu.launch_func operations into a sequence of
143 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
144 ///
145 /// In essence, a gpu.launch_func operations gets compiled into the following
146 /// sequence of runtime calls:
147 ///
148 /// * moduleLoad        -- loads the module given the cubin / hsaco data
149 /// * moduleGetFunction -- gets a handle to the actual kernel function
150 /// * getStreamHelper   -- initializes a new compute stream on GPU
151 /// * launchKernel      -- launches the kernel on a stream
152 /// * streamSynchronize -- waits for operations on the stream to finish
153 ///
154 /// Intermediate data structures are allocated on the stack.
155 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
156     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
157 public:
158   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
159                                              StringRef gpuBinaryAnnotation)
160       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
161         gpuBinaryAnnotation(gpuBinaryAnnotation) {}
162
163 private:
164   Value generateParamsArray(gpu::LaunchFuncOp launchOp,
165                             ArrayRef<Value> operands, OpBuilder &builder) const;
166   Value generateKernelNameConstant(StringRef moduleName, StringRef name,
167                                    Location loc, OpBuilder &builder) const;
168
169   LogicalResult
170   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171                   ConversionPatternRewriter &rewriter) const override;
172
173   llvm::SmallString<32> gpuBinaryAnnotation;
174 };
175
176 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
177   using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
178
179   LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
180                                 PatternRewriter &rewriter) const override {
181     // GPU kernel modules are no longer necessary since we have a global
182     // constant with the CUBIN, or HSACO data.
183     rewriter.eraseOp(op);
184     return success();
185   }
186 };
187
188 } // namespace
189
190 void GpuToLLVMConversionPass::runOnOperation() {
191   LLVMTypeConverter converter(&getContext());
192   OwningRewritePatternList patterns;
193   populateStdToLLVMConversionPatterns(converter, patterns);
194   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
195
196   LLVMConversionTarget target(getContext());
197   if (failed(applyPartialConversion(getOperation(), target, patterns)))
198     signalPassFailure();
199 }
200
201 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
202                                          ArrayRef<Value> arguments) const {
203   auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
204   auto function = [&] {
205     if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
206       return function;
207     return OpBuilder(module.getBody()->getTerminator())
208         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
209   }();
210   return builder.create<LLVM::CallOp>(
211       loc, const_cast<LLVM::LLVMType &>(functionType).getFunctionResultType(),
212       builder.getSymbolRefAttr(function), arguments);
213 }
214
215 // Returns whether value is of LLVM type.
216 static bool isLLVMType(Value value) {
217   return value.getType().isa<LLVM::LLVMType>();
218 }
219
220 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
221     Operation *op, ArrayRef<Value> operands,
222     ConversionPatternRewriter &rewriter) const {
223   if (!llvm::all_of(operands, isLLVMType))
224     return rewriter.notifyMatchFailure(
225         op, "Cannot convert if operands aren't of LLVM type.");
226
227   Location loc = op->getLoc();
228
229   auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
230   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
231   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
232
233   auto arguments =
234       typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
235   arguments.push_back(elementSize);
236   hostRegisterCallBuilder.create(loc, rewriter, arguments);
237
238   rewriter.eraseOp(op);
239   return success();
240 }
241
242 // Creates a struct containing all kernel parameters on the stack and returns
243 // an array of type-erased pointers to the fields of the struct. The array can
244 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
245 // The generated code is essentially as follows:
246 //
247 // %struct = alloca(sizeof(struct { Parameters... }))
248 // %array = alloca(NumParameters * sizeof(void *))
249 // for (i : [0, NumParameters))
250 //   %fieldPtr = llvm.getelementptr %struct[0, i]
251 //   llvm.store parameters[i], %fieldPtr
252 //   %elementPtr = llvm.getelementptr %array[i]
253 //   llvm.store %fieldPtr, %elementPtr
254 // return %array
255 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
256     gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
257     OpBuilder &builder) const {
258   auto loc = launchOp.getLoc();
259   auto numKernelOperands = launchOp.getNumKernelOperands();
260   auto arguments = typeConverter.promoteOperands(
261       loc, launchOp.getOperands().take_back(numKernelOperands),
262       operands.take_back(numKernelOperands), builder);
263   auto numArguments = arguments.size();
264   SmallVector<LLVM::LLVMType, 4> argumentTypes;
265   argumentTypes.reserve(numArguments);
266   for (auto argument : arguments)
267     argumentTypes.push_back(argument.getType().cast<LLVM::LLVMType>());
268   auto structType = LLVM::LLVMType::createStructTy(argumentTypes, StringRef());
269   auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
270                                               builder.getI32IntegerAttr(1));
271   auto structPtr = builder.create<LLVM::AllocaOp>(
272       loc, structType.getPointerTo(), one, /*alignment=*/0);
273   auto arraySize = builder.create<LLVM::ConstantOp>(
274       loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
275   auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
276                                                  arraySize, /*alignment=*/0);
277   auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
278                                                builder.getI32IntegerAttr(0));
279   for (auto en : llvm::enumerate(arguments)) {
280     auto index = builder.create<LLVM::ConstantOp>(
281         loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
282     auto fieldPtr = builder.create<LLVM::GEPOp>(
283         loc, argumentTypes[en.index()].getPointerTo(), structPtr,
284         ArrayRef<Value>{zero, index.getResult()});
285     builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
286     auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
287                                                   arrayPtr, index.getResult());
288     auto casted =
289         builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
290     builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
291   }
292   return arrayPtr;
293 }
294
295 // Generates an LLVM IR dialect global that contains the name of the given
296 // kernel function as a C string, and returns a pointer to its beginning.
297 // The code is essentially:
298 //
299 // llvm.global constant @kernel_name("function_name\00")
300 // func(...) {
301 //   %0 = llvm.addressof @kernel_name
302 //   %1 = llvm.constant (0 : index)
303 //   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
304 // }
305 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
306     StringRef moduleName, StringRef name, Location loc,
307     OpBuilder &builder) const {
308   // Make sure the trailing zero is included in the constant.
309   std::vector<char> kernelName(name.begin(), name.end());
310   kernelName.push_back('\0');
311
312   std::string globalName =
313       std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
314   return LLVM::createGlobalString(
315       loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
316       LLVM::Linkage::Internal);
317 }
318
319 // Emits LLVM IR to launch a kernel function. Expects the module that contains
320 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
321 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
322 //
323 // %0 = call %binarygetter
324 // %1 = call %moduleLoad(%0)
325 // %2 = <see generateKernelNameConstant>
326 // %3 = call %moduleGetFunction(%1, %2)
327 // %4 = call %streamCreate()
328 // %5 = <see generateParamsArray>
329 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
330 // call %streamSynchronize(%4)
331 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
332     Operation *op, ArrayRef<Value> operands,
333     ConversionPatternRewriter &rewriter) const {
334   if (!llvm::all_of(operands, isLLVMType))
335     return rewriter.notifyMatchFailure(
336         op, "Cannot convert if operands aren't of LLVM type.");
337
338   auto launchOp = cast<gpu::LaunchFuncOp>(op);
339   Location loc = launchOp.getLoc();
340
341   // Create an LLVM global with CUBIN extracted from the kernel annotation and
342   // obtain a pointer to the first byte in it.
343   auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
344       launchOp, launchOp.getKernelModuleName());
345   assert(kernelModule && "expected a kernel module");
346
347   auto binaryAttr = kernelModule.getAttrOfType<StringAttr>(gpuBinaryAnnotation);
348   if (!binaryAttr) {
349     kernelModule.emitOpError()
350         << "missing " << gpuBinaryAnnotation << " attribute";
351     return failure();
352   }
353
354   SmallString<128> nameBuffer(kernelModule.getName());
355   nameBuffer.append(kGpuBinaryStorageSuffix);
356   Value data =
357       LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
358                                binaryAttr.getValue(), LLVM::Linkage::Internal);
359
360   auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
361   // Get the function from the module. The name corresponds to the name of
362   // the kernel function.
363   auto kernelName = generateKernelNameConstant(
364       launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
365   auto function = moduleGetFunctionCallBuilder.create(
366       loc, rewriter, {module.getResult(0), kernelName});
367   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
368                                                 rewriter.getI32IntegerAttr(0));
369   // Grab the global stream needed for execution.
370   auto stream = streamCreateCallBuilder.create(loc, rewriter, {});
371   // Create array of pointers to kernel arguments.
372   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
373   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
374   launchKernelCallBuilder.create(
375       loc, rewriter,
376       {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
377        launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
378        launchOp.blockSizeZ(), zero, /* sharedMemBytes */
379        stream.getResult(0),         /* stream */
380        kernelParams,                /* kernel params */
381        nullpointer /* extra */});
382   streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0));
383
384   rewriter.eraseOp(op);
385   return success();
386 }
387
388 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
389 mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
390   return std::make_unique<GpuToLLVMConversionPass>(gpuBinaryAnnotation);
391 }
392
393 void mlir::populateGpuToLLVMConversionPatterns(
394     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
395     StringRef gpuBinaryAnnotation) {
396   patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
397   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
398       converter, gpuBinaryAnnotation);
399   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
400 }