@@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.base_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
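For reference, a minimal standalone sketch of the buffered Allreduce pattern used in this path (assuming mpi4py and NumPy; the variable names are illustrative and not part of this PR). mpi4py infers the MPI datatype from the array's dtype for buffer-based calls, which is why the commented-out `MPI._typedict` lookup is safe to drop:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
# Each rank contributes its rank id; the reduction sums across all ranks.
send_buf = np.full(4, comm.Get_rank(), dtype=np.float64)
# Preallocate the receive buffer with the same size and dtype, as in _allreduce.
recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)
# Buffered (uppercase) Allreduce: the MPI datatype is inferred from the dtype.
comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)
```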
@@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.sub_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
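The sub-communicator variant follows the same pattern; the only difference is that the reduction runs on `self.sub_comm`. A hedged sketch of that variant (the `Split` setup below is illustrative only, not how the library builds its sub-communicators):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
# Illustrative sub-communicator: split ranks into two groups by parity.
sub_comm = comm.Split(color=rank % 2, key=rank)
send_buf = np.ones(4, dtype=np.float32)
recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)
# The reduction now only involves ranks in the same split group.
sub_comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)
```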
@@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray,
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
+                # TODO (tharitt): in the current implementation there seems to be a semantic difference between buffered MPI and NCCL:
+                # the (1, size) shape is collapsed to (size,) with buffered MPI, while NCCL retains it.
+                # There may be a way to unify this, possibly related to how we allocate the recv_buf.
             if self.base_comm_nccl:
                 recv_buf = ncp.squeeze(recv_buf, axis=axis)
         elif ord == -ncp.inf:
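To make the TODO concrete, here is a small NumPy-only illustration of the shape mismatch (a sketch under the assumption that the buffered MPI path fills a flat buffer allocated with `send_buf.size`, while a shape-preserving path keeps the original `(1, size)` shape; the plain copies below stand in for the actual reductions):

```python
import numpy as np

send_buf = np.arange(4.0).reshape(1, 4)            # (1, size) send buffer

# Flat allocation, as in the buffered MPI path: the result comes back as (size,).
flat_recv = np.zeros(send_buf.size, dtype=send_buf.dtype)
flat_recv[:] = send_buf.ravel()                    # stand-in for Allreduce into a flat buffer
print(flat_recv.shape)                             # (4,)

# Shape-preserving allocation: one possible way to unify the two backends.
shaped_recv = np.zeros(send_buf.shape, dtype=send_buf.dtype)
shaped_recv[:] = send_buf                          # stand-in for a shape-retaining reduce
print(shaped_recv.shape)                           # (1, 4)
```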