[mlir] make the bitwidth of device side index computations configurable (reland)
authorTobias Gysi <tobias.gysi@inf.ethz.ch>
Mon, 29 Jun 2020 09:32:46 +0000 (11:32 +0200)
committerTobias Gysi <tobias.gysi@inf.ethz.ch>
Mon, 29 Jun 2020 10:22:39 +0000 (12:22 +0200)
Summary:
The patch makes the index type lowering of the GPU to NVVM/ROCDL conversion configurable. It introduces a pass option that controls the bitwidth used when lowering index computations and uses the LowerToLLVMOptions structure to control the Standard to LLVM lowering.

This commit fixes a use-after-free bug introduced by the reverted commit d10b1a3. It implements the following changes:
- Added a getDefaultOptions method to the LowerToLLVMOptions struct that returns a reference to statically allocated default options.
- Use the getDefaultOptions method to provide default LowerToLLVMOptions (instead of an initializer list).
- Added comments to clarify the required lifetime of the LowerToLLVMOptions

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D82475

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

index 5dbfce9..1af1305 100644 (file)
@@ -8,6 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
 #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
 
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include <memory>
 
 namespace mlir {
@@ -24,9 +25,11 @@ class GPUModuleOp;
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
 
-/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToNVVMOpsPass();
+/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
+/// index bitwidth used for the lowering of the device side index computations
+/// is configurable.
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
 
 } // namespace mlir
 
index 1722ae6..677782b 100644 (file)
@@ -8,6 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
 #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
 
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include <memory>
 
 namespace mlir {
@@ -25,9 +26,12 @@ class GPUModuleOp;
 void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                           OwningRewritePatternList &patterns);
 
-/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
+/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
+/// index bitwidth used for the lowering of the device side index computations
+/// is configurable.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToROCDLOpsPass();
+createLowerGpuOpsToROCDLOpsPass(
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
 
 } // namespace mlir
 
index 89b63e8..1c3b776 100644 (file)
@@ -100,6 +100,11 @@ def ConvertGpuLaunchFuncToGpuRuntimeCalls : Pass<"launch-func-to-gpu-runtime",
 def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
   let summary = "Generate NVVM operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -109,6 +114,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
 def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
   let summary = "Generate ROCDL operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
index c963410..72c6079 100644 (file)
@@ -15,6 +15,7 @@
 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
 
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace llvm {
@@ -35,22 +36,6 @@ class LLVMDialect;
 class LLVMType;
 } // namespace LLVM
 
-/// Set of callbacks that allows the customization of LLVMTypeConverter.
-struct LLVMTypeConverterCustomization {
-  using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
-                                                     SmallVectorImpl<Type> &)>;
-
-  /// Customize the type conversion of function arguments.
-  CustomCallback funcArgConverter;
-
-  /// Used to determine the bitwidth of the LLVM integer type that the index
-  /// type gets lowered to. Defaults to deriving the size from the data layout.
-  unsigned indexBitwidth;
-
-  /// Initialize customization to default callbacks.
-  LLVMTypeConverterCustomization();
-};
-
 /// Callback to convert function argument types. It converts a MemRef function
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
@@ -75,13 +60,11 @@ class LLVMTypeConverter : public TypeConverter {
 public:
   using TypeConverter::convertType;
 
-  /// Create an LLVMTypeConverter using the default
-  /// LLVMTypeConverterCustomization.
+  /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
   LLVMTypeConverter(MLIRContext *ctx);
 
-  /// Create an LLVMTypeConverter using 'custom' customizations.
-  LLVMTypeConverter(MLIRContext *ctx,
-                    const LLVMTypeConverterCustomization &custom);
+  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
+  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
 
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -127,7 +110,7 @@ public:
   LLVM::LLVMType getIndexType();
 
   /// Gets the bitwidth of the index type when converted to LLVM.
-  unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; }
+  unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
 
   /// Gets the pointer bitwidth.
   unsigned getPointerBitwidth(unsigned addressSpace = 0);
@@ -196,8 +179,8 @@ private:
   // Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);
 
-  /// Callbacks for customizing the type conversion.
-  LLVMTypeConverterCustomization customizations;
+  /// Options for customizing the llvm lowering.
+  LowerToLLVMOptions options;
 };
 
 /// Helper class to produce LLVM dialect operations extracting or inserting
@@ -398,12 +381,17 @@ public:
                            SmallVectorImpl<Value> &sizes);
 };
 
-/// Base class for operation conversions targeting the LLVM IR dialect. Provides
-/// conversion patterns with access to an LLVMTypeConverter.
+/// Base class for operation conversions targeting the LLVM IR dialect. It
+/// provides the conversion patterns with access to the LLVMTypeConverter and
+/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
+/// LowerToLLVMOptions by reference meaning the references have to remain alive
+/// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        LLVMTypeConverter &typeConverter,
+                       const LowerToLLVMOptions &options =
+                           LowerToLLVMOptions::getDefaultOptions(),
                        PatternBenefit benefit = 1);
 
   /// Returns the LLVM dialect.
@@ -455,6 +443,9 @@ public:
 protected:
   /// Reference to the type converter, with potential extensions.
   LLVMTypeConverter &typeConverter;
+
+  /// Reference to the llvm lowering options.
+  const LowerToLLVMOptions &options;
 };
 
 /// Utility class for operation conversions targeting the LLVM dialect that
@@ -463,10 +454,11 @@ template <typename OpTy>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                         const LowerToLLVMOptions &options,
                          PatternBenefit benefit = 1)
       : ConvertToLLVMPattern(OpTy::getOperationName(),
                              &typeConverter.getContext(), typeConverter,
-                             benefit) {}
+                             options, benefit) {}
 };
 
 namespace LLVM {
index 5479f18..1f76b92 100644 (file)
 namespace mlir {
 class LLVMTypeConverter;
 class ModuleOp;
-template <typename T> class OperationPass;
+template <typename T>
+class OperationPass;
 class OwningRewritePatternList;
 
+/// Value to pass as bitwidth for the index type when the converter is expected
+/// to derive the bitwidth from the LLVM data layout.
+static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
+
+/// Options to control the Standard dialect to LLVM lowering. The struct is used
+/// to share lowering options between passes, patterns, and type converter.
+struct LowerToLLVMOptions {
+  bool useBarePtrCallConv = false;
+  bool emitCWrappers = false;
+  unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+  /// Use aligned_alloc for heap allocations.
+  bool useAlignedAlloc = false;
+
+  /// Get a statically allocated copy of the default LowerToLLVMOptions.
+  static const LowerToLLVMOptions &getDefaultOptions() {
+    static LowerToLLVMOptions options;
+    return options;
+  }
+};
+
 /// Collect a set of patterns to convert memory-related operations from the
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
 void populateStdToLLVMMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc);
+    const LowerToLLVMOptions &options);
 
 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
 void populateStdToLLVMNonMemoryConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options);
 
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
 /// that pass memref descriptors by pointer-to-structure in addition to the
 /// default unpacked form.
-void populateStdToLLVMDefaultFuncOpConversionPattern(
+void populateStdToLLVMFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers = false);
+    const LowerToLLVMOptions &options);
 
-/// Collect a set of default patterns to convert from the Standard dialect to
-/// LLVM.
-void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false,
-                                         bool useAlignedAlloc = false);
-
-/// Collect a set of patterns to convert from the Standard dialect to
-/// LLVM using the bare pointer calling convention for MemRef function
-/// arguments.
-void populateStdToLLVMBarePtrConversionPatterns(
+/// Collect the patterns to convert from the Standard dialect to LLVM. The
+/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
+/// by reference meaning the references have to remain alive during the entire
+/// pattern lifetime.
+void populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc);
-
-/// Value to pass as bitwidth for the index type when the converter is expected
-/// to derive the bitwidth from the LLVM data layout.
-static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
-
-struct LowerToLLVMOptions {
-  bool useBarePtrCallConv = false;
-  bool emitCWrappers = false;
-  unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
-  /// Use aligned_alloc for heap allocations.
-  bool useAlignedAlloc = false;
-};
+    const LowerToLLVMOptions &options =
+        LowerToLLVMOptions::getDefaultOptions());
 
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// stdlib malloc/free is used by default for allocating memrefs allocated with
 /// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
 std::unique_ptr<OperationPass<ModuleOp>>
-createLowerToLLVMPass(const LowerToLLVMOptions &options = {
-                          /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
-                          /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
-                          /*useAlignedAlloc=*/false});
+createLowerToLLVMPass(const LowerToLLVMOptions &options =
+                          LowerToLLVMOptions::getDefaultOptions());
 
 } // namespace mlir
 
index e4fabe4..feaa382 100644 (file)
@@ -30,7 +30,6 @@ using namespace mlir;
 
 namespace {
 
-
 struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
   explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
       : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
@@ -97,17 +96,27 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
 ///
 /// This pass only handles device code and is not meant to be run on GPU host
 /// code.
-class LowerGpuOpsToNVVMOpsPass
+struct LowerGpuOpsToNVVMOpsPass
     : public ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
-public:
+  LowerGpuOpsToNVVMOpsPass() = default;
+  LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
+    this->indexBitwidth = indexBitwidth;
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
 
+    /// Customize the bitwidth used for the device side index computations.
+    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
+                                  /*emitCWrappers =*/true,
+                                  /*indexBitwidth =*/indexBitwidth,
+                                  /*useAlignedAlloc =*/false};
+
     /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
     /// space 5 for private memory attributions, but NVVM represents private
     /// memory allocations as local `alloca`s in the default address space. This
     /// converter drops the private memory space to support the use case above.
-    LLVMTypeConverter converter(m.getContext());
+    LLVMTypeConverter converter(m.getContext(), options);
     converter.addConversion([&](MemRefType type) -> Optional<Type> {
       if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
         return llvm::None;
@@ -176,6 +185,6 @@ void mlir::populateGpuToNVVMConversionPatterns(
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToNVVMOpsPass() {
-  return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
+mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
+  return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth);
 }
index 2381d61..8a1d10f 100644 (file)
@@ -41,13 +41,22 @@ namespace {
 //
 // This pass only handles device code and is not meant to be run on GPU host
 // code.
-class LowerGpuOpsToROCDLOpsPass
+struct LowerGpuOpsToROCDLOpsPass
     : public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
-public:
+  LowerGpuOpsToROCDLOpsPass() = default;
+  LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
+    this->indexBitwidth = indexBitwidth;
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
 
-    LLVMTypeConverter converter(m.getContext());
+    /// Customize the bitwidth used for the device side index computations.
+    LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
+                                  /*emitCWrappers =*/true,
+                                  /*indexBitwidth =*/indexBitwidth,
+                                  /*useAlignedAlloc =*/false};
+    LLVMTypeConverter converter(m.getContext(), options);
 
     OwningRewritePatternList patterns;
 
@@ -106,6 +115,6 @@ void mlir::populateGpuToROCDLConversionPatterns(
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass() {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
+mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth);
 }
index 164816d..eb49794 100644 (file)
@@ -52,11 +52,6 @@ static LLVM::LLVMType unwrap(Type type) {
   return wrappedLLVMType;
 }
 
-/// Initialize customization to default callbacks.
-LLVMTypeConverterCustomization::LLVMTypeConverterCustomization()
-    : funcArgConverter(structFuncArgTypeConverter),
-      indexBitwidth(kDeriveIndexBitwidthFromDataLayout) {}
-
 /// Callback to convert function argument types. It converts a MemRef function
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
@@ -123,19 +118,19 @@ LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
   return success();
 }
 
-/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
+/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {}
+    : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {}
 
-/// Create an LLVMTypeConverter using 'custom' customizations.
-LLVMTypeConverter::LLVMTypeConverter(
-    MLIRContext *ctx, const LLVMTypeConverterCustomization &customs)
+/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
+                                     const LowerToLLVMOptions &options)
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
-      customizations(customs) {
+      options(options) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
   module = &llvmDialect->getLLVMModule();
-  if (customizations.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
-    customizations.indexBitwidth =
+  if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
+    this->options.indexBitwidth =
         module->getDataLayout().getPointerSizeInBits();
 
   // Register conversions for the standard types.
@@ -267,11 +262,15 @@ SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
 LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
     FunctionType type, bool isVariadic,
     LLVMTypeConverter::SignatureConversion &result) {
+  // Select the argument converter depending on the calling convetion.
+  auto funcArgConverter = options.useBarePtrCallConv
+                              ? barePtrFuncArgTypeConverter
+                              : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
   for (auto &en : llvm::enumerate(type.getInputs())) {
     Type type = en.value();
     SmallVector<Type, 8> converted;
-    if (failed(customizations.funcArgConverter(*this, type, converted)))
+    if (failed(funcArgConverter(*this, type, converted)))
       return {};
     result.addInputs(en.index(), converted);
   }
@@ -401,10 +400,11 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
 
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
-                                           LLVMTypeConverter &typeConverter_,
+                                           LLVMTypeConverter &typeConverter,
+                                           const LowerToLLVMOptions &options,
                                            PatternBenefit benefit)
-    : ConversionPattern(rootOpName, benefit, typeConverter_, context),
-      typeConverter(typeConverter_) {}
+    : ConversionPattern(rootOpName, benefit, typeConverter, context),
+      typeConverter(typeConverter), options(options) {}
 
 /*============================================================================*/
 /* StructBuilder implementation                                               */
@@ -1101,8 +1101,10 @@ protected:
 /// information.
 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
 struct FuncOpConversion : public FuncOpConversionBase {
-  FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers)
-      : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {}
+  FuncOpConversion(LLVMTypeConverter &converter,
+                   const LowerToLLVMOptions &options)
+      : FuncOpConversionBase(converter, options) {}
+  using ConvertOpToLLVMPattern<FuncOp>::options;
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1113,7 +1115,8 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!newFuncOp)
       return failure();
 
-    if (emitWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
+    if (options.emitCWrappers ||
+        funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
       if (newFuncOp.isExternal())
         wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
                              newFuncOp);
@@ -1125,11 +1128,6 @@ struct FuncOpConversion : public FuncOpConversionBase {
     rewriter.eraseOp(op);
     return success();
   }
-
-private:
-  /// If true, also create the adaptor functions having signatures compatible
-  /// with those produced by clang.
-  const bool emitWrappers;
 };
 
 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@@ -1587,11 +1585,11 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
+  using ConvertOpToLLVMPattern<AllocLikeOp>::options;
 
   explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
-                               bool useAlignedAlloc = false)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
-        useAlignedAlloc(useAlignedAlloc) {}
+                               const LowerToLLVMOptions &options)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter, options) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1758,7 +1756,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   /// allocation size to be a multiple of alignment,
   Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
     // No alignment can be used for the 'malloc' call itself.
-    if (!useAlignedAlloc)
+    if (!options.useAlignedAlloc)
       return None;
 
     if (allocOp.alignment())
@@ -1932,16 +1930,14 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   }
 
 protected:
-  /// Use aligned_alloc instead of malloc for all heap allocations.
-  bool useAlignedAlloc;
   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
   uint64_t kMinAlignedAllocAlignment = 16UL;
 };
 
 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
   explicit AllocOpLowering(LLVMTypeConverter &converter,
-                           bool useAlignedAlloc = false)
-      : AllocLikeOpLowering<AllocOp>(converter, useAlignedAlloc) {}
+                           const LowerToLLVMOptions &options)
+      : AllocLikeOpLowering<AllocOp>(converter, options) {}
 };
 
 using AllocaOpLowering = AllocLikeOpLowering<AllocaOp>;
@@ -2113,8 +2109,9 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
 struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
   using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
 
-  explicit DeallocOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
+  explicit DeallocOpLowering(LLVMTypeConverter &converter,
+                             const LowerToLLVMOptions &options)
+      : ConvertOpToLLVMPattern<DeallocOp>(converter, options) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -3140,7 +3137,8 @@ private:
 
 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
 void mlir::populateStdToLLVMNonMemoryConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    const LowerToLLVMOptions &options) {
   // FIXME: this should be tablegen'ed
   // clang-format off
   patterns.insert<
@@ -3203,13 +3201,13 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       UnsignedRemIOpLowering,
       UnsignedShiftRightOpLowering,
       XOrOpLowering,
-      ZeroExtendIOpLowering>(converter);
+      ZeroExtendIOpLowering>(converter, options);
   // clang-format on
 }
 
 void mlir::populateStdToLLVMMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc) {
+    const LowerToLLVMOptions &options) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
@@ -3219,41 +3217,26 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       MemRefCastOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
-      ViewOpLowering>(converter);
-  patterns.insert<
-      AllocOpLowering
-      >(converter, useAlignedAlloc);
+      ViewOpLowering,
+      AllocOpLowering>(converter, options);
   // clang-format on
 }
 
-void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
+void mlir::populateStdToLLVMFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
-  patterns.insert<FuncOpConversion>(converter, emitCWrappers);
+    const LowerToLLVMOptions &options) {
+  if (options.useBarePtrCallConv)
+    patterns.insert<BarePtrFuncOpConversion>(converter, options);
+  else
+    patterns.insert<FuncOpConversion>(converter, options);
 }
 
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers, bool useAlignedAlloc) {
-  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
-                                                  emitCWrappers);
-  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
-                                            useAlignedAlloc);
-}
-
-static void populateStdToLLVMBarePtrFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<BarePtrFuncOpConversion>(converter);
-}
-
-void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlignedAlloc) {
-  populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
-  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
-                                            useAlignedAlloc);
+    const LowerToLLVMOptions &options) {
+  populateStdToLLVMFuncOpConversionPattern(converter, patterns, options);
+  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns, options);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, options);
 }
 
 // Create an LLVM IR structure type if there is more than one result.
@@ -3343,19 +3326,12 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
 
     ModuleOp m = getOperation();
 
-    LLVMTypeConverterCustomization customs;
-    customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
-                                                  : structFuncArgTypeConverter;
-    customs.indexBitwidth = indexBitwidth;
-    LLVMTypeConverter typeConverter(&getContext(), customs);
+    LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
+                                  indexBitwidth, useAlignedAlloc};
+    LLVMTypeConverter typeConverter(&getContext(), options);
 
     OwningRewritePatternList patterns;
-    if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
-                                                 useAlignedAlloc);
-    else
-      populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers, useAlignedAlloc);
+    populateStdToLLVMConversionPatterns(typeConverter, patterns, options);
 
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, patterns)))
index 20d166b..273b227 100644 (file)
@@ -1,36 +1,52 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
+  // CHECK32-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
       -> (index, index, index, index, index, index,
           index, index, index, index, index, index) {
+    // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
+
     // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
 
     // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
 
     // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
 
     // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
 
     std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -43,6 +59,21 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_index_comp
+  // CHECK32-LABEL: func @gpu_index_comp
+  func @gpu_index_comp(%idx : index) -> index {
+    // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+    // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+    %0 = addi %idx, %idx : index
+    // CHECK: llvm.return %{{.*}} : !llvm.i64
+    // CHECK32: llvm.return %{{.*}} : !llvm.i32
+    std.return %0 : index
+  }
+}
+
+// -----
+
+gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_op()
   gpu.func @gpu_all_reduce_op() {
     %arg0 = constant 1.0 : f32
index 61becff..a7565bb 100644 (file)
@@ -1,36 +1,52 @@
 // RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
+  // CHECK32-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
       -> (index, index, index, index, index, index,
           index, index, index, index, index, index) {
+    // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
+
     // CHECK: rocdl.workitem.id.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workitem.id.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workitem.id.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
 
     // CHECK: rocdl.workgroup.dim.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workgroup.dim.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workgroup.dim.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
 
     // CHECK: rocdl.workgroup.id.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.workgroup.id.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.workgroup.id.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
 
     // CHECK: rocdl.grid.dim.x : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
     // CHECK: rocdl.grid.dim.y : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
     // CHECK: rocdl.grid.dim.z : !llvm.i32
+    // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
 
     std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
@@ -43,6 +59,21 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_index_comp
+  // CHECK32-LABEL: func @gpu_index_comp
+  func @gpu_index_comp(%idx : index) -> index {
+    // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+    // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+    %0 = addi %idx, %idx : index
+    // CHECK: llvm.return %{{.*}} : !llvm.i64
+    // CHECK32: llvm.return %{{.*}} : !llvm.i32
+    std.return %0 : index
+  }
+}
+
+// -----
+
+gpu.module @test_module {
   // CHECK-LABEL: func @gpu_sync()
   func @gpu_sync() {
     // CHECK: rocdl.barrier