Skip to content

Commit

Permalink
Fix allreduce bug (#197)
Browse files Browse the repository at this point in the history
Fix allreduce correctness issue
  • Loading branch information
Binyang2014 authored Oct 18, 2023
1 parent 85e8017 commit 6f43282
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/mscclpp-test/allreduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,7 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
DeviceHandle<mscclpp::SmChannel> smChan = constSmOutOfPlaceChans[peerIdx];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
Expand All @@ -892,9 +893,9 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = rank * nelemsPerRank * sizeof(int);
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + srcOffset);
uint2* dst = (uint2*)((char*)resultBuff + srcOffset);
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));

// step 1: write to scratch buffer
smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
Expand All @@ -918,7 +919,6 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
}
}
// step 3: get data result from scratch buffer
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
Expand Down

0 comments on commit 6f43282

Please sign in to comment.