diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp index 3e1730985..d0220a773 100644 --- a/include/mscclpp/semaphore_device.hpp +++ b/include/mscclpp/semaphore_device.hpp @@ -4,6 +4,8 @@ #ifndef MSCCLPP_SEMAPHORE_DEVICE_HPP_ #define MSCCLPP_SEMAPHORE_DEVICE_HPP_ +#include + #include "poll.hpp" namespace mscclpp { @@ -36,7 +38,8 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { /// Poll if the remote device has signaled. /// @return true if the remote device has signaled. __forceinline__ __device__ bool poll() { - bool signaled = ((*inboundSemaphoreId) > (*expectedInboundSemaphoreId)); + bool signaled = (cuda::atomic_ref{*inboundSemaphoreId}.load( + cuda::memory_order_acquire) > (*expectedInboundSemaphoreId)); if (signaled) (*expectedInboundSemaphoreId) += 1; return signaled; } @@ -44,7 +47,9 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { /// Wait for the remote device to signal. __forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) { (*expectedInboundSemaphoreId) += 1; - POLL_MAYBE_JAILBREAK((*inboundSemaphoreId) < (*expectedInboundSemaphoreId), maxSpinCount); + POLL_MAYBE_JAILBREAK((cuda::atomic_ref{*inboundSemaphoreId}.load( + cuda::memory_order_acquire) < (*expectedInboundSemaphoreId)), + maxSpinCount); } /// Signal the remote device. @@ -55,9 +60,9 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { __forceinline__ __device__ void signal() { // This fence ensures that preceding writes are visible on the peer GPU before the incremented // `outboundSemaphoreId` is visible. - __threadfence_system(); semaphoreIncrement(); - *remoteInboundSemaphoreId = semaphoreGetLocal(); + cuda::atomic_ref{*remoteInboundSemaphoreId}.store(semaphoreGetLocal(), + cuda::memory_order_release); } /// Signal the remote device for copied packets. @@ -78,9 +83,9 @@ struct SmDevice2DeviceSemaphoreDeviceHandle { __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; } #endif // __CUDACC__ - volatile uint64_t* inboundSemaphoreId; + uint64_t* inboundSemaphoreId; uint64_t* outboundSemaphoreId; - volatile uint64_t* remoteInboundSemaphoreId; + uint64_t* remoteInboundSemaphoreId; uint64_t* expectedInboundSemaphoreId; };