5f72ad35a6701dff6cdd795b72e2a60d3823b775
[lldb.git] / mlir / include / mlir / Dialect / LLVMIR / NVVMOps.td
1 //===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the NVVM IR operation definition file.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef NVVMIR_OPS
14 #define NVVMIR_OPS
15
16 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
17 include "mlir/Interfaces/SideEffectInterfaces.td"
18
19 //===----------------------------------------------------------------------===//
20 // NVVM dialect definitions
21 //===----------------------------------------------------------------------===//
22
// The NVVM dialect: ops are spelled with the "nvvm" prefix in IR and their
// generated C++ classes live in ::mlir::NVVM. The dialect declares the LLVM
// dialect as a dependency, since its ops operate on LLVM dialect types.
def NVVM_Dialect : Dialect {
  let cppNamespace = "::mlir::NVVM";
  let name = "nvvm";
  let dependentDialects = ["LLVM::LLVMDialect"];
}
28
29 //===----------------------------------------------------------------------===//
30 // NVVM op definitions
31 //===----------------------------------------------------------------------===//
32
// Common base for every NVVM dialect op: an LLVM_OpBase bound to NVVM_Dialect.
// It adds no fields of its own; derived ops provide operands, results, traits
// and LLVM IR builders.
class NVVM_Op<string mnemonic, list<OpTrait> traits = []>
    : LLVM_OpBase<NVVM_Dialect, mnemonic, traits>;
36
37 //===----------------------------------------------------------------------===//
38 // NVVM intrinsic operations
39 //===----------------------------------------------------------------------===//
40
// Base for NVVM ops that correspond directly to an LLVM NVVM intrinsic.
// The intrinsic name handed to LLVM_IntrOpBase is derived from the mnemonic:
// every '.' is replaced with '_' and the result is prefixed with "nvvm_".
// For example, "read.ptx.sreg.tid.x" becomes "nvvm_read_ptx_sreg_tid_x".
// All other parameters are forwarded to LLVM_IntrOpBase unchanged.
class NVVM_IntrOp<string mnem,
                  list<int> overloadedResults,
                  list<int> overloadedOperands,
                  list<OpTrait> traits,
                  bit hasResult>
    : LLVM_IntrOpBase<NVVM_Dialect,
                      mnem,
                      "nvvm_" # !subst(".", "_", mnem),
                      overloadedResults,
                      overloadedOperands,
                      traits,
                      hasResult>;
46
47
48 //===----------------------------------------------------------------------===//
49 // NVVM special register op definitions
50 //===----------------------------------------------------------------------===//
51
// Base for operand-less ops that read a PTX special register through the
// matching NVVM intrinsic (see NVVM_IntrOp for the name mapping).
// Each op produces exactly one, non-overloaded result and always carries
// NoSideEffect in addition to any caller-supplied traits.
// Textual form: `nvvm.<mnemonic> {attrs} : <result type>`.
class NVVM_SpecialRegisterOp<string mnemonic, list<OpTrait> traits = []>
    : NVVM_IntrOp<mnemonic,
                  /*overloadedResults=*/[],
                  /*overloadedOperands=*/[],
                  /*traits=*/!listconcat(traits, [NoSideEffect]),
                  /*hasResult=*/1>,
      Arguments<(ins)> {
  let assemblyFormat = "attr-dict `:` type($res)";
}
58
//===----------------------------------------------------------------------===//
// Lane index within the warp, and warp size.
def NVVM_LaneIdOp   : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">;

//===----------------------------------------------------------------------===//
// Thread index (tid.*) and block dimensions (ntid.*), one op per x/y/z axis.
def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;

//===----------------------------------------------------------------------===//
// Block index (ctaid.*) and grid dimensions (nctaid.*), one op per x/y/z axis.
def NVVM_BlockIdXOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_GridDimXOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp  : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
81
82 //===----------------------------------------------------------------------===//
83 // NVVM synchronization op definitions
84 //===----------------------------------------------------------------------===//
85
// nvvm.barrier0: no operands and no results. It is printed as the bare op name
// plus its attribute dictionary, and is translated into a call to the
// nvvm_barrier0 LLVM intrinsic.
def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
  let assemblyFormat = "attr-dict";
  string llvmBuilder = [{
      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
  }];
}
92
// nvvm.shfl.sync.bfly: butterfly shuffle, lowered to an NVVM intrinsic call.
//
// The four operands are passed to the intrinsic in declaration order:
// ($dst, $val, $offset, $mask_and_clamp). getShflBflyIntrinsicId (defined
// elsewhere) selects the exact intrinsic from the result type and from whether
// the `return_value_and_is_valid` unit attribute is present.
//
// When that attribute is set, the verifier requires the result to be a
// two-element LLVM struct whose second element is i1. Presumably that element
// is the "is valid" flag returned by the intrinsic; confirm against the
// intrinsic definition. Without the attribute, the result type is not checked.
// Parsing and printing use custom C++ hooks defined elsewhere in the dialect.
def NVVM_ShflBflyOp :
  NVVM_Op<"shfl.sync.bfly">,
  Results<(outs LLVM_Type:$res)>,
  Arguments<(ins LLVM_Type:$dst,
                 LLVM_Type:$val,
                 LLVM_Type:$offset,
                 LLVM_Type:$mask_and_clamp,
                 OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
  let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
  string llvmBuilder = [{
      bool returnsValidFlag = static_cast<bool>($return_value_and_is_valid);
      auto shflIntrinsicId =
          getShflBflyIntrinsicId($_resultType, returnsValidFlag);
      $res = createIntrinsicCall(
          builder, shflIntrinsicId, {$dst, $val, $offset, $mask_and_clamp});
  }];
  let verifier = [{
    // Without the attribute there is no constraint on the result type.
    if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
      return success();
    auto resultType = getType().cast<LLVM::LLVMType>();
    bool isValueAndFlagPair =
        resultType.isStructTy() && resultType.getStructNumElements() == 2 &&
        resultType.getStructElementType(1).isIntegerTy(/*Bitwidth=*/1);
    if (isValueAndFlagPair)
      return success();
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  }];
}
121
// nvvm.vote.ballot.sync: translated into a call to the nvvm_vote_ballot_sync
// LLVM intrinsic, with operands passed as ($mask, $pred) in that order.
// Parsing and printing use custom C++ hooks defined elsewhere in the dialect.
def NVVM_VoteBallotOp :
  NVVM_Op<"vote.ballot.sync">,
  Results<(outs LLVM_Type:$res)>,
  Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
  let parser = [{ return parseNVVMVoteBallotOp(parser, result); }];
  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
  string llvmBuilder = [{
      $res = createIntrinsicCall(
          builder, llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
  }];
}
133
// nvvm.mma.sync: takes a variadic list of LLVM-typed operands and one result.
// All operands are forwarded unchanged to a single, fixed intrinsic,
// nvvm_mma_m8n8k4_row_col_f32_f32, regardless of their types. Operand and
// result checking is delegated to a C++ `verify` overload defined elsewhere
// in the dialect.
def NVVM_MmaOp :
  NVVM_Op<"mma.sync">,
  Results<(outs LLVM_Type:$res)>,
  Arguments<(ins Variadic<LLVM_Type>:$args)> {
  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
  let verifier = [{ return ::verify(*this); }];
  string llvmBuilder = [{
    $res = createIntrinsicCall(builder,
                               llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32,
                               $args);
  }];
}
145
146 #endif // NVVMIR_OPS