6a5ea1bb06c91ea9b0b139869f887c11ede91b53
[lldb.git] / mlir / lib / Conversion / SPIRVToLLVM / ConvertSPIRVToLLVM.cpp
1 //===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
18 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29
30 using namespace mlir;
31
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38   if (type.isSignedInteger())
39     return true;
40   if (auto vecType = type.dyn_cast<VectorType>())
41     return vecType.getElementType().isSignedInteger();
42   return false;
43 }
44
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47   if (type.isUnsignedInteger())
48     return true;
49   if (auto vecType = type.dyn_cast<VectorType>())
50     return vecType.getElementType().isUnsignedInteger();
51   return false;
52 }
53
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57          "bitwidth is not supported for this type");
58   if (type.isIntOrFloat())
59     return type.getIntOrFloatBitWidth();
60   auto vecType = type.dyn_cast<VectorType>();
61   auto elementType = vecType.getElementType();
62   assert(elementType.isIntOrFloat() &&
63          "only integers and floats have a bitwidth");
64   return elementType.getIntOrFloatBitWidth();
65 }
66
67 /// Returns the bit width of LLVMType integer or vector.
68 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
69   return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
70                            : type.getIntegerBitWidth();
71 }
72
73 /// Creates `IntegerAttribute` with all bits set for given type
74 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
75   if (auto vecType = type.dyn_cast<VectorType>()) {
76     auto integerType = vecType.getElementType().cast<IntegerType>();
77     return builder.getIntegerAttr(integerType, -1);
78   }
79   auto integerType = type.cast<IntegerType>();
80   return builder.getIntegerAttr(integerType, -1);
81 }
82
83 /// Creates `llvm.mlir.constant` with all bits set for the given type.
84 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
85                                       PatternRewriter &rewriter) {
86   if (srcType.isa<VectorType>()) {
87     return rewriter.create<LLVM::ConstantOp>(
88         loc, dstType,
89         SplatElementsAttr::get(srcType.cast<ShapedType>(),
90                                minusOneIntegerAttribute(srcType, rewriter)));
91   }
92   return rewriter.create<LLVM::ConstantOp>(
93       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
94 }
95
96 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
97 static Value createFPConstant(Location loc, Type srcType, Type dstType,
98                               PatternRewriter &rewriter, double value) {
99   if (auto vecType = srcType.dyn_cast<VectorType>()) {
100     auto floatType = vecType.getElementType().cast<FloatType>();
101     return rewriter.create<LLVM::ConstantOp>(
102         loc, dstType,
103         SplatElementsAttr::get(vecType,
104                                rewriter.getFloatAttr(floatType, value)));
105   }
106   auto floatType = srcType.cast<FloatType>();
107   return rewriter.create<LLVM::ConstantOp>(
108       loc, dstType, rewriter.getFloatAttr(floatType, value));
109 }
110
111 /// Utility function for bitfiled ops:
112 ///   - `BitFieldInsert`
113 ///   - `BitFieldSExtract`
114 ///   - `BitFieldUExtract`
115 /// Truncates or extends the value. If the bitwidth of the value is the same as
116 /// `dstType` bitwidth, the value remains unchanged.
117 static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
118                                         PatternRewriter &rewriter) {
119   auto srcType = value.getType();
120   auto llvmType = dstType.cast<LLVM::LLVMType>();
121   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
122   unsigned valueBitWidth =
123       srcType.isa<LLVM::LLVMType>()
124           ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
125           : getBitWidth(srcType);
126
127   if (valueBitWidth < targetBitWidth)
128     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129   // If the bit widths of `Count` and `Offset` are greater than the bit width
130   // of the target type, they are truncated. Truncation is safe since `Count`
131   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132   // both values can be expressed in 8 bits.
133   if (valueBitWidth > targetBitWidth)
134     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135   return value;
136 }
137
138 /// Broadcasts the value to vector with `numElements` number of elements.
139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140                        LLVMTypeConverter &typeConverter,
141                        ConversionPatternRewriter &rewriter) {
142   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143   auto llvmVectorType = typeConverter.convertType(vectorType);
144   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146   for (unsigned i = 0; i < numElements; ++i) {
147     auto index = rewriter.create<LLVM::ConstantOp>(
148         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149     broadcasted = rewriter.create<LLVM::InsertElementOp>(
150         loc, llvmVectorType, broadcasted, toBroadcast, index);
151   }
152   return broadcasted;
153 }
154
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
156 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
157                                  LLVMTypeConverter &typeConverter,
158                                  ConversionPatternRewriter &rewriter) {
159   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160     unsigned numElements = vectorType.getNumElements();
161     return broadcast(loc, value, numElements, typeConverter, rewriter);
162   }
163   return value;
164 }
165
166 /// Utility function for bitfiled ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 ///  - same number of elements as `Base`
171 ///  - each element has the type that is the same as the type of `Offset` or
172 ///    `Count`
173 ///  - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
176 static Value processCountOrOffset(Location loc, Value value, Type srcType,
177                                   Type dstType, LLVMTypeConverter &converter,
178                                   ConversionPatternRewriter &rewriter) {
179   Value broadcasted =
180       optionallyBroadcast(loc, value, srcType, converter, rewriter);
181   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
187 convertStructTypeWithOffset(spirv::StructType type,
188                             LLVMTypeConverter &converter) {
189   if (type != VulkanLayoutUtils::decorateType(type))
190     return llvm::None;
191
192   auto elementsVector = llvm::to_vector<8>(
193       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194         return converter.convertType(elementType).cast<LLVM::LLVMType>();
195       }));
196   return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
197                                      /*isPacked=*/false);
198 }
199
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
201 static Type convertStructTypePacked(spirv::StructType type,
202                                     LLVMTypeConverter &converter) {
203   auto elementsVector = llvm::to_vector<8>(
204       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205         return converter.convertType(elementType).cast<LLVM::LLVMType>();
206       }));
207   return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
208                                      /*isPacked=*/true);
209 }
210
211 /// Creates LLVM dialect constant with the given value.
212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
213                                  unsigned value) {
214   return rewriter.create<LLVM::ConstantOp>(
215       loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
216       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218
219 /// Utility for `spv.Load` and `spv.Store` conversion.
220 static LogicalResult replaceWithLoadOrStore(Operation *op,
221                                             ConversionPatternRewriter &rewriter,
222                                             LLVMTypeConverter &typeConverter,
223                                             unsigned alignment, bool isVolatile,
224                                             bool isNonTemporal) {
225   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226     auto dstType = typeConverter.convertType(loadOp.getType());
227     if (!dstType)
228       return failure();
229     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230         loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
231     return success();
232   }
233   auto storeOp = cast<spirv::StoreOp>(op);
234   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
235                                              storeOp.ptr(), alignment,
236                                              isVolatile, isNonTemporal);
237   return success();
238 }
239
240 //===----------------------------------------------------------------------===//
241 // Type conversion
242 //===----------------------------------------------------------------------===//
243
244 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
245 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
246 /// when converting ops that manipulate array types.
247 static Optional<Type> convertArrayType(spirv::ArrayType type,
248                                        TypeConverter &converter) {
249   unsigned stride = type.getArrayStride();
250   Type elementType = type.getElementType();
251   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
252   if (stride != 0 &&
253       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
254     return llvm::None;
255
256   auto llvmElementType =
257       converter.convertType(elementType).cast<LLVM::LLVMType>();
258   unsigned numElements = type.getNumElements();
259   return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
260 }
261
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
264 static Type convertPointerType(spirv::PointerType type,
265                                TypeConverter &converter) {
266   auto pointeeType =
267       converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
268   return pointeeType.getPointerTo();
269 }
270
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
273 /// no modelling of array stride at the moment.
274 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
275                                               TypeConverter &converter) {
276   if (type.getArrayStride() != 0)
277     return llvm::None;
278   auto elementType =
279       converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
280   return LLVM::LLVMType::getArrayTy(elementType, 0);
281 }
282
283 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
284 /// member decorations. Also, only natural offset is supported.
285 static Optional<Type> convertStructType(spirv::StructType type,
286                                         LLVMTypeConverter &converter) {
287   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
288   type.getMemberDecorations(memberDecorations);
289   if (!memberDecorations.empty())
290     return llvm::None;
291   if (type.hasOffset())
292     return convertStructTypeWithOffset(type, converter);
293   return convertStructTypePacked(type, converter);
294 }
295
296 //===----------------------------------------------------------------------===//
297 // Operation conversion
298 //===----------------------------------------------------------------------===//
299
300 namespace {
301
302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
303 public:
304   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
305
306   LogicalResult
307   matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto dstType = typeConverter.convertType(op.component_ptr().getType());
310     if (!dstType)
311       return failure();
312     // To use GEP we need to add a first 0 index to go through the pointer.
313     auto indices = llvm::to_vector<4>(op.indices());
314     Type indexType = op.indices().front().getType();
315     auto llvmIndexType = typeConverter.convertType(indexType);
316     if (!llvmIndexType)
317       return failure();
318     Value zero = rewriter.create<LLVM::ConstantOp>(
319         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
320     indices.insert(indices.begin(), zero);
321     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
322                                              indices);
323     return success();
324   }
325 };
326
327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
328 public:
329   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
330
331   LogicalResult
332   matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
333                   ConversionPatternRewriter &rewriter) const override {
334     auto dstType = typeConverter.convertType(op.pointer().getType());
335     if (!dstType)
336       return failure();
337     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
338         op, dstType.cast<LLVM::LLVMType>(), op.variable());
339     return success();
340   }
341 };
342
343 class BitFieldInsertPattern
344     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
345 public:
346   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
347
348   LogicalResult
349   matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
350                   ConversionPatternRewriter &rewriter) const override {
351     auto srcType = op.getType();
352     auto dstType = typeConverter.convertType(srcType);
353     if (!dstType)
354       return failure();
355     Location loc = op.getLoc();
356
357     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
358     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
359                                         typeConverter, rewriter);
360     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
361                                        typeConverter, rewriter);
362
363     // Create a mask with bits set outside [Offset, Offset + Count - 1].
364     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
365     Value maskShiftedByCount =
366         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
367     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
368                                                  maskShiftedByCount, minusOne);
369     Value maskShiftedByCountAndOffset =
370         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
371     Value mask = rewriter.create<LLVM::XOrOp>(
372         loc, dstType, maskShiftedByCountAndOffset, minusOne);
373
374     // Extract unchanged bits from the `Base`  that are outside of
375     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
376     Value baseAndMask =
377         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
378     Value insertShiftedByOffset =
379         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
380     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
381                                             insertShiftedByOffset);
382     return success();
383   }
384 };
385
386 /// Converts SPIR-V ConstantOp with scalar or vector type.
387 class ConstantScalarAndVectorPattern
388     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
389 public:
390   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
391
392   LogicalResult
393   matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
394                   ConversionPatternRewriter &rewriter) const override {
395     auto srcType = constOp.getType();
396     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
397       return failure();
398
399     auto dstType = typeConverter.convertType(srcType);
400     if (!dstType)
401       return failure();
402
403     // SPIR-V constant can be a signed/unsigned integer, which has to be
404     // casted to signless integer when converting to LLVM dialect. Removing the
405     // sign bit may have unexpected behaviour. However, it is better to handle
406     // it case-by-case, given that the purpose of the conversion is not to
407     // cover all possible corner cases.
408     if (isSignedIntegerOrVector(srcType) ||
409         isUnsignedIntegerOrVector(srcType)) {
410       auto *context = rewriter.getContext();
411       auto signlessType = IntegerType::get(getBitWidth(srcType), context);
412
413       if (srcType.isa<VectorType>()) {
414         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
415         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
416             constOp, dstType,
417             dstElementsAttr.mapValues(
418                 signlessType, [&](const APInt &value) { return value; }));
419         return success();
420       }
421       auto srcAttr = constOp.value().cast<IntegerAttr>();
422       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
423       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
424       return success();
425     }
426     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
427                                                   constOp.getAttrs());
428     return success();
429   }
430 };
431
432 class BitFieldSExtractPattern
433     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
434 public:
435   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
436
437   LogicalResult
438   matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
439                   ConversionPatternRewriter &rewriter) const override {
440     auto srcType = op.getType();
441     auto dstType = typeConverter.convertType(srcType);
442     if (!dstType)
443       return failure();
444     Location loc = op.getLoc();
445
446     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
447     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
448                                         typeConverter, rewriter);
449     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
450                                        typeConverter, rewriter);
451
452     // Create a constant that holds the size of the `Base`.
453     IntegerType integerType;
454     if (auto vecType = srcType.dyn_cast<VectorType>())
455       integerType = vecType.getElementType().cast<IntegerType>();
456     else
457       integerType = srcType.cast<IntegerType>();
458
459     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
460     Value size =
461         srcType.isa<VectorType>()
462             ? rewriter.create<LLVM::ConstantOp>(
463                   loc, dstType,
464                   SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
465             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
466
467     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
468     // at Offset + Count - 1 is the most significant bit now.
469     Value countPlusOffset =
470         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
471     Value amountToShiftLeft =
472         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
473     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
474         loc, dstType, op.base(), amountToShiftLeft);
475
476     // Shift the result right, filling the bits with the sign bit.
477     Value amountToShiftRight =
478         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
479     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
480                                               amountToShiftRight);
481     return success();
482   }
483 };
484
485 class BitFieldUExtractPattern
486     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
487 public:
488   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
489
490   LogicalResult
491   matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
492                   ConversionPatternRewriter &rewriter) const override {
493     auto srcType = op.getType();
494     auto dstType = typeConverter.convertType(srcType);
495     if (!dstType)
496       return failure();
497     Location loc = op.getLoc();
498
499     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
500     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
501                                         typeConverter, rewriter);
502     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
503                                        typeConverter, rewriter);
504
505     // Create a mask with bits set at [0, Count - 1].
506     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
507     Value maskShiftedByCount =
508         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
509     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
510                                               minusOne);
511
512     // Shift `Base` by `Offset` and apply the mask on it.
513     Value shiftedBase =
514         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
515     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
516     return success();
517   }
518 };
519
520 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
521 public:
522   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
523
524   LogicalResult
525   matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
526                   ConversionPatternRewriter &rewriter) const override {
527     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
528                                             branchOp.getTarget());
529     return success();
530   }
531 };
532
533 class BranchConditionalConversionPattern
534     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
535 public:
536   using SPIRVToLLVMConversion<
537       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
538
539   LogicalResult
540   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
541                   ConversionPatternRewriter &rewriter) const override {
542     // If branch weights exist, map them to 32-bit integer vector.
543     ElementsAttr branchWeights = nullptr;
544     if (auto weights = op.branch_weights()) {
545       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
546       branchWeights =
547           DenseElementsAttr::get(weightType, weights.getValue().getValue());
548     }
549
550     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
551         op, op.condition(), op.getTrueBlockArguments(),
552         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
553         op.getFalseBlock());
554     return success();
555   }
556 };
557
558 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
559 /// is an aggregate type (struct or array). Otherwise, converts to
560 /// `llvm.extractelement` that operates on vectors.
561 class CompositeExtractPattern
562     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
563 public:
564   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
565
566   LogicalResult
567   matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
568                   ConversionPatternRewriter &rewriter) const override {
569     auto dstType = this->typeConverter.convertType(op.getType());
570     if (!dstType)
571       return failure();
572
573     Type containerType = op.composite().getType();
574     if (containerType.isa<VectorType>()) {
575       Location loc = op.getLoc();
576       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
577       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
578       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
579           op, dstType, op.composite(), index);
580       return success();
581     }
582     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
583         op, dstType, op.composite(), op.indices());
584     return success();
585   }
586 };
587
588 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
589 /// is an aggregate type (struct or array). Otherwise, converts to
590 /// `llvm.insertelement` that operates on vectors.
591 class CompositeInsertPattern
592     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
593 public:
594   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
595
596   LogicalResult
597   matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
598                   ConversionPatternRewriter &rewriter) const override {
599     auto dstType = this->typeConverter.convertType(op.getType());
600     if (!dstType)
601       return failure();
602
603     Type containerType = op.composite().getType();
604     if (containerType.isa<VectorType>()) {
605       Location loc = op.getLoc();
606       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
607       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
608       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
609           op, dstType, op.composite(), op.object(), index);
610       return success();
611     }
612     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
613         op, dstType, op.composite(), op.object(), op.indices());
614     return success();
615   }
616 };
617
618 /// Converts SPIR-V operations that have straightforward LLVM equivalent
619 /// into LLVM dialect operations.
620 template <typename SPIRVOp, typename LLVMOp>
621 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
622 public:
623   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
624
625   LogicalResult
626   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
627                   ConversionPatternRewriter &rewriter) const override {
628     auto dstType = this->typeConverter.convertType(operation.getType());
629     if (!dstType)
630       return failure();
631     rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
632                                                  operation.getAttrs());
633     return success();
634   }
635 };
636
637 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
638 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
639 /// This difference is handled by `spv._address_of` and `llvm.mlir.addressof`ops
640 /// that both return a pointer.
641 class GlobalVariablePattern
642     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
643 public:
644   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
645
646   LogicalResult
647   matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
648                   ConversionPatternRewriter &rewriter) const override {
649     // Currently, there is no support of initialization with a constant value in
650     // SPIR-V dialect. Specialization constants are not considered as well.
651     if (op.initializer())
652       return failure();
653
654     auto srcType = op.type().cast<spirv::PointerType>();
655     auto dstType = typeConverter.convertType(srcType.getPointeeType());
656     if (!dstType)
657       return failure();
658
659     // Limit conversion to the current invocation only or `StorageBuffer`
660     // required by SPIR-V runner.
661     // This is okay because multiple invocations are not supported yet.
662     auto storageClass = srcType.getStorageClass();
663     if (storageClass != spirv::StorageClass::Input &&
664         storageClass != spirv::StorageClass::Private &&
665         storageClass != spirv::StorageClass::Output &&
666         storageClass != spirv::StorageClass::StorageBuffer) {
667       return failure();
668     }
669
670     // LLVM dialect spec: "If the global value is a constant, storing into it is
671     // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
672     // read-only.
673     bool isConstant = storageClass == spirv::StorageClass::Input;
674     // SPIR-V spec: "By default, functions and global variables are private to a
675     // module and cannot be accessed by other modules. However, a module may be
676     // written to export or import functions and global (module scope)
677     // variables.". Therefore, map 'Private' storage class to private linkage,
678     // 'Input' and 'Output' to external linkage.
679     auto linkage = storageClass == spirv::StorageClass::Private
680                        ? LLVM::Linkage::Private
681                        : LLVM::Linkage::External;
682     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
683         op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
684         Attribute());
685     return success();
686   }
687 };
688
689 /// Converts SPIR-V cast ops that do not have straightforward LLVM
690 /// equivalent in LLVM dialect.
691 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
692 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
693 public:
694   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
695
696   LogicalResult
697   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
698                   ConversionPatternRewriter &rewriter) const override {
699
700     Type fromType = operation.operand().getType();
701     Type toType = operation.getType();
702
703     auto dstType = this->typeConverter.convertType(toType);
704     if (!dstType)
705       return failure();
706
707     if (getBitWidth(fromType) < getBitWidth(toType)) {
708       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
709                                                       operands);
710       return success();
711     }
712     if (getBitWidth(fromType) > getBitWidth(toType)) {
713       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
714                                                         operands);
715       return success();
716     }
717     return failure();
718   }
719 };
720
721 class FunctionCallPattern
722     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
723 public:
724   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
725
726   LogicalResult
727   matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
728                   ConversionPatternRewriter &rewriter) const override {
729     if (callOp.getNumResults() == 0) {
730       rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
731                                                 callOp.getAttrs());
732       return success();
733     }
734
735     // Function returns a single result.
736     auto dstType = typeConverter.convertType(callOp.getType(0));
737     rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
738                                               callOp.getAttrs());
739     return success();
740   }
741 };
742
743 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
744 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
745 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
746 public:
747   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
748
749   LogicalResult
750   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
751                   ConversionPatternRewriter &rewriter) const override {
752
753     auto dstType = this->typeConverter.convertType(operation.getType());
754     if (!dstType)
755       return failure();
756
757     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
758         operation, dstType,
759         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
760         operation.operand1(), operation.operand2());
761     return success();
762   }
763 };
764
765 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
766 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
767 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
768 public:
769   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
770
771   LogicalResult
772   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
773                   ConversionPatternRewriter &rewriter) const override {
774
775     auto dstType = this->typeConverter.convertType(operation.getType());
776     if (!dstType)
777       return failure();
778
779     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
780         operation, dstType,
781         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
782         operation.operand1(), operation.operand2());
783     return success();
784   }
785 };
786
787 class InverseSqrtPattern
788     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
789 public:
790   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
791
792   LogicalResult
793   matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
794                   ConversionPatternRewriter &rewriter) const override {
795     auto srcType = op.getType();
796     auto dstType = typeConverter.convertType(srcType);
797     if (!dstType)
798       return failure();
799
800     Location loc = op.getLoc();
801     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
802     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
803     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
804     return success();
805   }
806 };
807
808 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
809 template <typename SPIRVop>
810 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
811 public:
812   using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
813
814   LogicalResult
815   matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
816                   ConversionPatternRewriter &rewriter) const override {
817
818     if (!op.memory_access().hasValue()) {
819       replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0,
820                              /*isVolatile=*/false, /*isNonTemporal=*/ false);
821       return success();
822     }
823     auto memoryAccess = op.memory_access().getValue();
824     switch (memoryAccess) {
825     case spirv::MemoryAccess::Aligned:
826     case spirv::MemoryAccess::None:
827     case spirv::MemoryAccess::Nontemporal:
828     case spirv::MemoryAccess::Volatile: {
829       unsigned alignment =
830           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
831       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
832       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
833       replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment,
834                              isVolatile, isNonTemporal);
835       return success();
836     }
837     default:
838       // There is no support of other memory access attributes.
839       return failure();
840     }
841   }
842 };
843
844 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
845 template <typename SPIRVOp>
846 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
847 public:
848   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
849
850   LogicalResult
851   matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
852                   ConversionPatternRewriter &rewriter) const override {
853
854     auto srcType = notOp.getType();
855     auto dstType = this->typeConverter.convertType(srcType);
856     if (!dstType)
857       return failure();
858
859     Location loc = notOp.getLoc();
860     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
861     auto mask = srcType.template isa<VectorType>()
862                     ? rewriter.create<LLVM::ConstantOp>(
863                           loc, dstType,
864                           SplatElementsAttr::get(
865                               srcType.template cast<VectorType>(), minusOne))
866                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
867     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
868                                                       notOp.operand(), mask);
869     return success();
870   }
871 };
872
873 /// A template pattern that erases the given `SPIRVOp`.
874 template <typename SPIRVOp>
875 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
876 public:
877   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
878
879   LogicalResult
880   matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
881                   ConversionPatternRewriter &rewriter) const override {
882     rewriter.eraseOp(op);
883     return success();
884   }
885 };
886
887 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
888 public:
889   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
890
891   LogicalResult
892   matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
893                   ConversionPatternRewriter &rewriter) const override {
894     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
895                                                 ArrayRef<Value>());
896     return success();
897   }
898 };
899
900 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
901 public:
902   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
903
904   LogicalResult
905   matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
906                   ConversionPatternRewriter &rewriter) const override {
907     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
908                                                 operands);
909     return success();
910   }
911 };
912
913 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be
914 /// reachable for conversion to succeed.
915 /// The structure of the loop in LLVM dialect will be the following:
916 ///
917 ///      +------------------------------------+
918 ///      | <code before spv.loop>             |
919 ///      | llvm.br ^header                    |
920 ///      +------------------------------------+
921 ///                           |
922 ///   +----------------+      |
923 ///   |                |      |
924 ///   |                V      V
925 ///   |  +------------------------------------+
926 ///   |  | ^header:                           |
927 ///   |  |   <header code>                    |
928 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
929 ///   |  +------------------------------------+
930 ///   |                    |
931 ///   |                    |----------------------+
932 ///   |                    |                      |
933 ///   |                    V                      |
934 ///   |  +------------------------------------+   |
935 ///   |  | ^body:                             |   |
936 ///   |  |   <body code>                      |   |
937 ///   |  |   llvm.br ^continue                |   |
938 ///   |  +------------------------------------+   |
939 ///   |                    |                      |
940 ///   |                    V                      |
941 ///   |  +------------------------------------+   |
942 ///   |  | ^continue:                         |   |
943 ///   |  |   <continue code>                  |   |
944 ///   |  |   llvm.br ^header                  |   |
945 ///   |  +------------------------------------+   |
946 ///   |               |                           |
947 ///   +---------------+    +----------------------+
948 ///                        |
949 ///                        V
950 ///      +------------------------------------+
951 ///      | ^exit:                             |
952 ///      |   llvm.br ^remaining               |
953 ///      +------------------------------------+
954 ///                        |
955 ///                        V
956 ///      +------------------------------------+
957 ///      | ^remaining:                        |
958 ///      |   <code after spv.loop>            |
959 ///      +------------------------------------+
960 ///
961 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
962 public:
963   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
964
965   LogicalResult
966   matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
967                   ConversionPatternRewriter &rewriter) const override {
968     // There is no support of loop control at the moment.
969     if (loopOp.loop_control() != spirv::LoopControl::None)
970       return failure();
971
972     Location loc = loopOp.getLoc();
973
974     // Split the current block after `spv.loop`. The remaing ops will be used in
975     // `endBlock`.
976     Block *currentBlock = rewriter.getBlock();
977     auto position = Block::iterator(loopOp);
978     Block *endBlock = rewriter.splitBlock(currentBlock, position);
979
980     // Remove entry block and create a branch in the current block going to the
981     // header block.
982     Block *entryBlock = loopOp.getEntryBlock();
983     assert(entryBlock->getOperations().size() == 1);
984     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
985     if (!brOp)
986       return failure();
987     Block *headerBlock = loopOp.getHeaderBlock();
988     rewriter.setInsertionPointToEnd(currentBlock);
989     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
990     rewriter.eraseBlock(entryBlock);
991
992     // Branch from merge block to end block.
993     Block *mergeBlock = loopOp.getMergeBlock();
994     Operation *terminator = mergeBlock->getTerminator();
995     ValueRange terminatorOperands = terminator->getOperands();
996     rewriter.setInsertionPointToEnd(mergeBlock);
997     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
998
999     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1000     rewriter.replaceOp(loopOp, endBlock->getArguments());
1001     return success();
1002   }
1003 };
1004
1005 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
1006 /// All blocks within selection should be reachable for conversion to succeed.
1007 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1008 public:
1009   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1010
1011   LogicalResult
1012   matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
1013                   ConversionPatternRewriter &rewriter) const override {
1014     // There is no support for `Flatten` or `DontFlatten` selection control at
1015     // the moment. This are just compiler hints and can be performed during the
1016     // optimization passes.
1017     if (op.selection_control() != spirv::SelectionControl::None)
1018       return failure();
1019
1020     // `spv.selection` should have at least two blocks: one selection header
1021     // block and one merge block. If no blocks are present, or control flow
1022     // branches straight to merge block (two blocks are present), the op is
1023     // redundant and it is erased.
1024     if (op.body().getBlocks().size() <= 2) {
1025       rewriter.eraseOp(op);
1026       return success();
1027     }
1028
1029     Location loc = op.getLoc();
1030
1031     // Split the current block after `spv.selection`. The remaing ops will be
1032     // used in `continueBlock`.
1033     auto *currentBlock = rewriter.getInsertionBlock();
1034     rewriter.setInsertionPointAfter(op);
1035     auto position = rewriter.getInsertionPoint();
1036     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1037
1038     // Extract conditional branch information from the header block. By SPIR-V
1039     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1040     // op. Note that `spv.Switch op` is not supported at the moment in the
1041     // SPIR-V dialect. Remove this block when finished.
1042     auto *headerBlock = op.getHeaderBlock();
1043     assert(headerBlock->getOperations().size() == 1);
1044     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1045         headerBlock->getOperations().front());
1046     if (!condBrOp)
1047       return failure();
1048     rewriter.eraseBlock(headerBlock);
1049
1050     // Branch from merge block to continue block.
1051     auto *mergeBlock = op.getMergeBlock();
1052     Operation *terminator = mergeBlock->getTerminator();
1053     ValueRange terminatorOperands = terminator->getOperands();
1054     rewriter.setInsertionPointToEnd(mergeBlock);
1055     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1056
1057     // Link current block to `true` and `false` blocks within the selection.
1058     Block *trueBlock = condBrOp.getTrueBlock();
1059     Block *falseBlock = condBrOp.getFalseBlock();
1060     rewriter.setInsertionPointToEnd(currentBlock);
1061     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1062                                     condBrOp.trueTargetOperands(), falseBlock,
1063                                     condBrOp.falseTargetOperands());
1064
1065     rewriter.inlineRegionBefore(op.body(), continueBlock);
1066     rewriter.replaceOp(op, continueBlock->getArguments());
1067     return success();
1068   }
1069 };
1070
1071 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1072 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1073 /// `Shift` is zero or sign extended to match this specification. Cases when
1074 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1075 template <typename SPIRVOp, typename LLVMOp>
1076 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1077 public:
1078   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1079
1080   LogicalResult
1081   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1082                   ConversionPatternRewriter &rewriter) const override {
1083
1084     auto dstType = this->typeConverter.convertType(operation.getType());
1085     if (!dstType)
1086       return failure();
1087
1088     Type op1Type = operation.operand1().getType();
1089     Type op2Type = operation.operand2().getType();
1090
1091     if (op1Type == op2Type) {
1092       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1093                                                    operands);
1094       return success();
1095     }
1096
1097     Location loc = operation.getLoc();
1098     Value extended;
1099     if (isUnsignedIntegerOrVector(op2Type)) {
1100       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1101                                                         operation.operand2());
1102     } else {
1103       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1104                                                         operation.operand2());
1105     }
1106     Value result = rewriter.template create<LLVMOp>(
1107         loc, dstType, operation.operand1(), extended);
1108     rewriter.replaceOp(operation, result);
1109     return success();
1110   }
1111 };
1112
1113 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1114 public:
1115   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1116
1117   LogicalResult
1118   matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1119                   ConversionPatternRewriter &rewriter) const override {
1120     auto dstType = typeConverter.convertType(tanOp.getType());
1121     if (!dstType)
1122       return failure();
1123
1124     Location loc = tanOp.getLoc();
1125     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1126     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1127     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1128     return success();
1129   }
1130 };
1131
1132 /// Convert `spv.Tanh` to
1133 ///
1134 ///   exp(2x) - 1
1135 ///   -----------
1136 ///   exp(2x) + 1
1137 ///
1138 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1139 public:
1140   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1141
1142   LogicalResult
1143   matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1144                   ConversionPatternRewriter &rewriter) const override {
1145     auto srcType = tanhOp.getType();
1146     auto dstType = typeConverter.convertType(srcType);
1147     if (!dstType)
1148       return failure();
1149
1150     Location loc = tanhOp.getLoc();
1151     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1152     Value multiplied =
1153         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1154     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1155     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1156     Value numerator =
1157         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1158     Value denominator =
1159         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1160     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1161                                               denominator);
1162     return success();
1163   }
1164 };
1165
1166 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1167 public:
1168   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1169
1170   LogicalResult
1171   matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1172                   ConversionPatternRewriter &rewriter) const override {
1173     auto srcType = varOp.getType();
1174     // Initialization is supported for scalars and vectors only.
1175     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1176     auto init = varOp.initializer();
1177     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1178       return failure();
1179
1180     auto dstType = typeConverter.convertType(srcType);
1181     if (!dstType)
1182       return failure();
1183
1184     Location loc = varOp.getLoc();
1185     Value size = createI32ConstantOf(loc, rewriter, 1);
1186     if (!init) {
1187       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1188       return success();
1189     }
1190     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1191     rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1192     rewriter.replaceOp(varOp, allocated);
1193     return success();
1194   }
1195 };
1196
1197 //===----------------------------------------------------------------------===//
1198 // FuncOp conversion
1199 //===----------------------------------------------------------------------===//
1200
1201 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1202 public:
1203   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1204
1205   LogicalResult
1206   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1207                   ConversionPatternRewriter &rewriter) const override {
1208
1209     // Convert function signature. At the moment LLVMType converter is enough
1210     // for currently supported types.
1211     auto funcType = funcOp.getType();
1212     TypeConverter::SignatureConversion signatureConverter(
1213         funcType.getNumInputs());
1214     auto llvmType = typeConverter.convertFunctionSignature(
1215         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1216     if (!llvmType)
1217       return failure();
1218
1219     // Create a new `LLVMFuncOp`
1220     Location loc = funcOp.getLoc();
1221     StringRef name = funcOp.getName();
1222     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1223
1224     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1225     MLIRContext *context = funcOp.getContext();
1226     switch (funcOp.function_control()) {
1227 #define DISPATCH(functionControl, llvmAttr)                                    \
1228   case functionControl:                                                        \
1229     newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));     \
1230     break;
1231
1232           DISPATCH(spirv::FunctionControl::Inline,
1233                    StringAttr::get("alwaysinline", context));
1234           DISPATCH(spirv::FunctionControl::DontInline,
1235                    StringAttr::get("noinline", context));
1236           DISPATCH(spirv::FunctionControl::Pure,
1237                    StringAttr::get("readonly", context));
1238           DISPATCH(spirv::FunctionControl::Const,
1239                    StringAttr::get("readnone", context));
1240
1241 #undef DISPATCH
1242
1243     // Default: if `spirv::FunctionControl::None`, then no attributes are
1244     // needed.
1245     default:
1246       break;
1247     }
1248
1249     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1250                                 newFuncOp.end());
1251     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1252                                            &signatureConverter))) {
1253       return failure();
1254     }
1255     rewriter.eraseOp(funcOp);
1256     return success();
1257   }
1258 };
1259
1260 //===----------------------------------------------------------------------===//
1261 // ModuleOp conversion
1262 //===----------------------------------------------------------------------===//
1263
1264 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1265 public:
1266   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1267
1268   LogicalResult
1269   matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1270                   ConversionPatternRewriter &rewriter) const override {
1271
1272     auto newModuleOp = rewriter.create<ModuleOp>(spvModuleOp.getLoc());
1273     rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1274
1275     // Remove the terminator block that was automatically added by builder
1276     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1277     rewriter.eraseOp(spvModuleOp);
1278     return success();
1279   }
1280 };
1281
1282 class ModuleEndConversionPattern
1283     : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1284 public:
1285   using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1286
1287   LogicalResult
1288   matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1289                   ConversionPatternRewriter &rewriter) const override {
1290
1291     rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1292     return success();
1293   }
1294 };
1295
1296 } // namespace
1297
1298 //===----------------------------------------------------------------------===//
1299 // Pattern population
1300 //===----------------------------------------------------------------------===//
1301
1302 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1303   typeConverter.addConversion([&](spirv::ArrayType type) {
1304     return convertArrayType(type, typeConverter);
1305   });
1306   typeConverter.addConversion([&](spirv::PointerType type) {
1307     return convertPointerType(type, typeConverter);
1308   });
1309   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1310     return convertRuntimeArrayType(type, typeConverter);
1311   });
1312   typeConverter.addConversion([&](spirv::StructType type) {
1313     return convertStructType(type, typeConverter);
1314   });
1315 }
1316
1317 void mlir::populateSPIRVToLLVMConversionPatterns(
1318     MLIRContext *context, LLVMTypeConverter &typeConverter,
1319     OwningRewritePatternList &patterns) {
1320   patterns.insert<
1321       // Arithmetic ops
1322       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1323       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1324       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1325       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1326       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1327       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1328       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1329       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1330       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1331       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1332       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1333       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1334       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1335
1336       // Bitwise ops
1337       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1338       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1339       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1340       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1341       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1342       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1343       NotPattern<spirv::NotOp>,
1344
1345       // Cast ops
1346       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1347       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1348       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1349       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1350       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1351       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1352       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1353       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1354
1355       // Comparison ops
1356       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1357       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1358       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1359       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1360       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1361       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1362       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1363       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1364       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1365       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1366       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1367                       LLVM::FCmpPredicate::uge>,
1368       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1369       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1370       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1371       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1372       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1373       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1374       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1375       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1376       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1377       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1378       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1379
1380       // Constant op
1381       ConstantScalarAndVectorPattern,
1382
1383       // Control Flow ops
1384       BranchConversionPattern, BranchConditionalConversionPattern,
1385       FunctionCallPattern, LoopPattern, SelectionPattern,
1386       ErasePattern<spirv::MergeOp>,
1387
1388       // Entry points and execution mode
1389       // Module generated from SPIR-V could have other "internal" functions, so
1390       // having entry point and execution mode metadat can be useful. For now,
1391       // simply remove them.
1392       // TODO: Support EntryPoint/ExecutionMode properly.
1393       ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
1394
1395       // GLSL extended instruction set ops
1396       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1397       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1398       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1399       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1400       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1401       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1402       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1403       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1404       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1405       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1406       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1407       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1408       InverseSqrtPattern, TanPattern, TanhPattern,
1409
1410       // Logical ops
1411       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1412       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1413       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1414       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1415       NotPattern<spirv::LogicalNotOp>,
1416
1417       // Memory ops
1418       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1419       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1420       VariablePattern,
1421
1422       // Miscellaneous ops
1423       CompositeExtractPattern, CompositeInsertPattern,
1424       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1425       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1426
1427       // Shift ops
1428       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1429       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1430       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1431
1432       // Return ops
1433       ReturnPattern, ReturnValuePattern>(context, typeConverter);
1434 }
1435
1436 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1437     MLIRContext *context, LLVMTypeConverter &typeConverter,
1438     OwningRewritePatternList &patterns) {
1439   patterns.insert<FuncConversionPattern>(context, typeConverter);
1440 }
1441
1442 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1443     MLIRContext *context, LLVMTypeConverter &typeConverter,
1444     OwningRewritePatternList &patterns) {
1445   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1446       context, typeConverter);
1447 }
1448
1449 //===----------------------------------------------------------------------===//
1450 // Pre-conversion hooks
1451 //===----------------------------------------------------------------------===//
1452
1453 /// Hook for descriptor set and binding number encoding.
1454 static constexpr StringRef kBinding = "binding";
1455 static constexpr StringRef kDescriptorSet = "descriptor_set";
1456 void mlir::encodeBindAttribute(ModuleOp module) {
1457   auto spvModules = module.getOps<spirv::ModuleOp>();
1458   for (auto spvModule : spvModules) {
1459     spvModule.walk([&](spirv::GlobalVariableOp op) {
1460       IntegerAttr descriptorSet = op.getAttrOfType<IntegerAttr>(kDescriptorSet);
1461       IntegerAttr binding = op.getAttrOfType<IntegerAttr>(kBinding);
1462       // For every global variable in the module, get the ones with descriptor
1463       // set and binding numbers.
1464       if (descriptorSet && binding) {
1465         // Encode these numbers into the variable's symbolic name. If the
1466         // SPIR-V module has a name, add it at the beginning.
1467         auto moduleAndName = spvModule.getName().hasValue()
1468                                  ? spvModule.getName().getValue().str() + "_" +
1469                                        op.sym_name().str()
1470                                  : op.sym_name().str();
1471         std::string name =
1472             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1473                           std::to_string(descriptorSet.getInt()),
1474                           std::to_string(binding.getInt()));
1475
1476         // Replace all symbol uses and set the new symbol name. Finally, remove
1477         // descriptor set and binding attributes.
1478         if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1479           op.emitError("unable to replace all symbol uses for ") << name;
1480         SymbolTable::setSymbolName(op, name);
1481         op.removeAttr(kDescriptorSet);
1482         op.removeAttr(kBinding);
1483       }
1484     });
1485   }
1486 }