Commit: nccl also works
Saeed Maleki committed Oct 23, 2023
1 parent 573e5ec commit 95f5125
Showing 3 changed files with 82 additions and 8 deletions.
python/benchmark/allreduce1.cu (5 changes: 2 additions & 3 deletions)
@@ -54,7 +54,7 @@ __forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
}


__device__ void localReduceScatterSm2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, int rank,
__device__ void localReduceScatterSm2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank,
int nRanksPerNode, size_t chunkSize, size_t nelems, int nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;
@@ -122,8 +122,7 @@ __device__ void localRingAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, in
// be careful about using channels[my_rank]: it is invalid and is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks, int nelems) {
TYPE* scratch = buff + nelems;
localReduceScatterSm2(smChans, buff, scratch, rank, nranks, nelems / nranks, nelems / nranks, gridDim.x);
localReduceScatterSm2(smChans, buff, rank, nranks, nelems / nranks, nelems / nranks, gridDim.x);
deviceSyncer.sync(gridDim.x);
localRingAllGatherSm(smChans, rank, nranks, nelems / nranks * sizeof(TYPE), gridDim.x);
}
python/benchmark/allreduce_bench.py (79 changes: 76 additions & 3 deletions)
@@ -2,13 +2,14 @@
# Licensed under the MIT license.

import os
from test.mscclpp_group import MscclppGroup
import cupy as cp
from test.mscclpp_group import MscclppGroup
from test.mscclpp_mpi import MpiGroup
from test.utils import KernelBuilder, pack
from mscclpp import Transport
from mpi4py import MPI
from prettytable import PrettyTable
import cupy.cuda.nccl as nccl


def human_readable_size(size, decimal_places=1):
@@ -18,6 +19,32 @@ def human_readable_size(size, decimal_places=1):
size /= 1024.0
return f"{size:.{decimal_places}f} {unit}"

def bench_time(niter: int, func):
    # capture a CUDA graph of niter back-to-back kernel launches
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niter):
func(stream.ptr)
graph = stream.end_capture()


# now run a warm up round
graph.launch(stream)

# now run the benchmark and measure time
start = cp.cuda.Event()
end = cp.cuda.Event()

start.record(stream)
graph.launch(stream)
end.record(stream)
end.synchronize()

return cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
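
bench_time is added here but not yet called; benchmark below still inlines its own capture-and-time loop. A minimal usage sketch follows (not part of the commit; the names a, b, and launch are illustrative, and it assumes cp.add with a preallocated output is legal under CUDA stream capture). bench_time returns microseconds per iteration.

a = cp.ones(1 << 20, dtype=cp.float32)
b = cp.ones_like(a)

def launch(stream_ptr):
    # bench_time enters `with stream:` before calling this, so the elementwise add is
    # issued on the capturing stream; stream_ptr is provided for APIs that take a raw
    # stream handle (e.g. NCCL) and is unused here
    cp.add(a, b, out=a)

us_per_iter = bench_time(1000, launch)
print(f"{us_per_iter:.2f} us per iteration")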



def benchmark(table: PrettyTable, niter: int, nelem: int):
mpi_group = MpiGroup()
group = MscclppGroup(mpi_group)
@@ -28,7 +55,7 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):

# create a connection for each remote neighbor
connections = group.make_connection(remote_nghrs, Transport.CudaIpc)
memory = cp.zeros(2*nelem, dtype=cp.float32)
memory = cp.zeros(nelem, dtype=cp.float32)
type_str = ""
if memory.dtype == cp.float16:
type_str = "__half"
@@ -76,7 +103,7 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):
end.synchronize()

time_per_iter = cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
memory_nbytes = memory.nbytes // 2
memory_nbytes = memory.nbytes
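    # time_per_iter is in microseconds, so bytes / us / 1e3 below is GB/s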
algBw = memory_nbytes / time_per_iter / 1e3
if group.my_rank == 0:
table.add_row([human_readable_size(memory_nbytes), "{:.2f}".format(time_per_iter), "{:.2f}".format(algBw)])
@@ -91,6 +118,51 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):
# print(cp.nonzero(memory[0:nelem]-expected), memory[0:8])


# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
# Create a NCCL unique ID and communicator
if rank == 0:
uid = nccl.get_unique_id()
else:
uid = None
uid = comm.bcast(uid, root=0)
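# every rank builds its communicator from the same broadcast uid, passing its own rank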
nccl_comm = nccl.NcclCommunicator(size, uid, rank)
def nccl_benchmark(table: PrettyTable, niter: int, nelem: int):

memory = cp.zeros(nelem, dtype=cp.float32)

stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niter):
nccl_comm.allReduce(memory.data.ptr, memory.data.ptr, memory.size, nccl.NCCL_FLOAT32, nccl.NCCL_SUM, stream.ptr)
graph = stream.end_capture()


# now run a warm up round
graph.launch(stream)

# now run the benchmark and measure time
MPI.COMM_WORLD.barrier()
start = cp.cuda.Event()
end = cp.cuda.Event()

start.record(stream)
graph.launch(stream)
end.record(stream)
end.synchronize()

time_per_iter = cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
memory_nbytes = memory.nbytes
algBw = memory_nbytes / time_per_iter / 1e3
if rank == 0:
table.add_row([human_readable_size(memory_nbytes), "{:.2f}".format(time_per_iter), "{:.2f}".format(algBw)])
print(".", end="", flush=True)
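
nccl_benchmark above duplicates the capture-and-time logic that bench_time already wraps. As a hypothetical refactor (not part of this commit), it could be collapsed onto bench_time; nccl_comm, rank, and the helper functions are the module-level names defined earlier in this file.

def nccl_benchmark_compact(table: PrettyTable, niter: int, nelem: int):
    memory = cp.zeros(nelem, dtype=cp.float32)

    def launch(stream_ptr):
        # in-place sum allreduce issued on the capturing stream
        nccl_comm.allReduce(memory.data.ptr, memory.data.ptr, memory.size,
                            nccl.NCCL_FLOAT32, nccl.NCCL_SUM, stream_ptr)

    MPI.COMM_WORLD.barrier()
    time_per_iter = bench_time(niter, launch)    # microseconds per allreduce
    algBw = memory.nbytes / time_per_iter / 1e3  # GB/s
    if rank == 0:
        table.add_row([human_readable_size(memory.nbytes),
                       "{:.2f}".format(time_per_iter), "{:.2f}".format(algBw)])
        print(".", end="", flush=True)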



if __name__ == "__main__":

# Create a table
@@ -101,6 +173,7 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):
table.field_names = ["Size", "Time (us)", "AlgBW (GB/s)"]

for i in range(10,28):
# nccl_benchmark(table, 1000, 2**i)
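        # swap which line is commented to time NCCL's allreduce instead of the mscclpp kernel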
benchmark(table, 1000, 2**i)

if MPI.COMM_WORLD.rank == 0:
Expand Down
python/test/mscclpp_group.py (6 changes: 4 additions & 2 deletions)
@@ -28,7 +28,7 @@


class MscclppGroup:
def __init__(self, mpi_group: MpiGroup, interfaceIpPortTrio=""):
    def __init__(self, mpi_group: MpiGroup = None, interfaceIpPortTrio: str = ""):
self.bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
if interfaceIpPortTrio == "":
uniq_id = None
@@ -37,9 +37,11 @@ def __init__(self, mpi_group: MpiGroup, interfaceIpPortTrio=""):
uniq_id = self.bootstrap.create_unique_id()
uniq_id_global = mpi_group.comm.bcast(uniq_id, 0)
self.bootstrap.initialize(uniq_id_global)
else:
elif mpi_group:
# use this instead
self.bootstrap.initialize(interfaceIpPortTrio)
else:
            raise RuntimeError("Either the interface or mpi_group needs to be specified")
self.communicator = Communicator(self.bootstrap)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
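
Note that even with the new default of mpi_group=None, the first line of __init__ still calls TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size), so an MpiGroup is effectively required on both paths. A usage sketch under that constraint (the interface/ip/port string is a made-up placeholder, not a value from this repo):

# bootstrap via an MPI broadcast of the unique id, as the benchmark does
group = MscclppGroup(MpiGroup())

# bootstrap from an explicit interface/ip/port trio instead of the broadcast id;
# "eth0:10.0.0.1:50000" is hypothetical
group = MscclppGroup(MpiGroup(), "eth0:10.0.0.1:50000")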