Skip to content

Commit 31068f9

Browse files
committed
minor clean up
1 parent 647ce65 commit 31068f9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
485485
else:
486486
if is_cuda_aware_mpi or self.engine == "numpy":
487487
ncp = get_module(self.engine)
488-
# mpi_type = MPI._typedict[send_buf.dtype.char]
489488
recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
490489
self.base_comm.Allreduce(send_buf, recv_buf, op)
491490
return recv_buf
@@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
505504
else:
506505
if is_cuda_aware_mpi or self.engine == "numpy":
507506
ncp = get_module(self.engine)
508-
# mpi_type = MPI._typedict[send_buf.dtype.char]
509507
recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
510508
self.sub_comm.Allreduce(send_buf, recv_buf, op)
511509
return recv_buf
@@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray,
743741
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
744742
else:
745743
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
744+
# TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL
745+
# the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
746+
# There may be a way to unify it - may be something to do with how we allocate the recv_buf.
746747
if self.base_comm_nccl:
747748
recv_buf = ncp.squeeze(recv_buf, axis=axis)
748749
elif ord == -ncp.inf:

0 commit comments

Comments
 (0)