[mlir] fix the rocm runtime wrapper to account for cuda / rocm api differences
authorTobias Gysi <tobias.gysi@gmail.com>
Wed, 20 Jan 2021 17:35:51 +0000 (18:35 +0100)
committerTobias Gysi <tobias.gysi@gmail.com>
Wed, 20 Jan 2021 17:48:32 +0000 (18:48 +0100)
The patch adapts the rocm runtime wrapper due to subtle differences between the cuda and the rocm/hip runtime api.

Reviewed By: csigg

Differential Revision: https://reviews.llvm.org/D95027

mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

index 4f62f20..028e2e3 100644 (file)
@@ -36,7 +36,7 @@ static auto InitializeCtx = [] {
   HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
   hipDevice_t device;
   HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
-  hipContext_t context;
+  hipCtx_t context;
   HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
   return 0;
 }();
@@ -110,17 +110,18 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
 
 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
   void *ptr;
-  HIP_REPORT_IF_ERROR(hipMemAlloc(&ptr, sizeBytes));
+  HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
   return ptr;
 }
 
 extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
-  HIP_REPORT_IF_ERROR(hipMemFree(ptr));
+  HIP_REPORT_IF_ERROR(hipFree(ptr));
 }
 
 extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
                            hipStream_t stream) {
-  HIP_REPORT_IF_ERROR(hipMemcpyAsync(dst, src, sizeBytes, stream));
+  HIP_REPORT_IF_ERROR(
+      hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
 }
 
 /// Helper functions for writing mlir example code