diff options
author | Shilei Tian <i@tianshilei.me> | 2022-03-09 14:55:07 -0500 |
---|---|---|
committer | Shilei Tian <i@tianshilei.me> | 2022-03-09 14:55:20 -0500 |
commit | 5105c7cd78751d50975e7af2d4d614ebb799cbd4 (patch) | |
tree | a864b94105184e1cc3dd7ee30a47a0f1829d23a8 | |
parent | 7e0b0e05af6349ff70b834b4de7a118233f60c37 (diff) |
[OpenMP][CUDA] Fix an issue that multiple `CUmodule` are could be overwritten
This patch fixes the issue introduced in 14de0820e87f and D120089, that
if dynamic libraries are used, the `CUmodule` array could be overwritten.
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D121308
-rw-r--r-- | openmp/libomptarget/plugins/cuda/src/rtl.cpp | 39 |
1 files changed, 22 insertions, 17 deletions
diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp index dab5b527f521..9c3488790bed 100644 --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -354,7 +354,7 @@ class DeviceRTLTy { std::vector<std::unique_ptr<EventPoolTy>> EventPool; std::vector<DeviceDataTy> DeviceData; - std::vector<CUmodule> Modules; + std::vector<std::vector<CUmodule>> Modules; /// Vector of flags indicating the initalization status of all associated /// devices. @@ -777,25 +777,30 @@ public: if (UseMemoryManager) MemoryManagers[DeviceId].release(); - // Close module - if (CUmodule &M = Modules[DeviceId]) - checkResult(cuModuleUnload(M), "Error returned from cuModuleUnload\n"); - StreamPool[DeviceId].reset(); EventPool[DeviceId].reset(); - // Destroy context DeviceDataTy &D = DeviceData[DeviceId]; - if (D.Context) { - if (checkResult(cuCtxSetCurrent(D.Context), - "Error returned from cuCtxSetCurrent\n")) { - CUdevice Device; - if (checkResult(cuCtxGetDevice(&Device), - "Error returned from cuCtxGetDevice\n")) - checkResult(cuDevicePrimaryCtxRelease(Device), - "Error returned from cuDevicePrimaryCtxRelease\n"); - } - } + if (!checkResult(cuCtxSetCurrent(D.Context), + "Error returned from cuCtxSetCurrent\n")) + return OFFLOAD_FAIL; + + // Unload all modules. + for (auto &M : Modules[DeviceId]) + if (!checkResult(cuModuleUnload(M), + "Error returned from cuModuleUnload\n")) + return OFFLOAD_FAIL; + + // Destroy context. + CUdevice Device; + if (!checkResult(cuCtxGetDevice(&Device), + "Error returned from cuCtxGetDevice\n")) + return OFFLOAD_FAIL; + + if (!checkResult(cuDevicePrimaryCtxRelease(Device), + "Error returned from cuDevicePrimaryCtxRelease\n")) + return OFFLOAD_FAIL; + return OFFLOAD_SUCCESS; } @@ -818,7 +823,7 @@ public: DP("CUDA module successfully loaded!\n"); - Modules[DeviceId] = Module; + Modules[DeviceId].push_back(Module); // Find the symbols in the module by name. const __tgt_offload_entry *HostBegin = Image->EntriesBegin; |