
Commit 8851e05

Added documentation and example explanation
1 parent b053f5b commit 8851e05

File tree: 4 files changed (+116, -32 lines)

- docs/source/api/index.rst
- examples/plot_summamatrixmult.py
- pylops_mpi/LinearOperator.py
- pylops_mpi/basicoperators/MatrixMult.py

docs/source/api/index.rst

Lines changed: 11 additions & 1 deletion
@@ -42,7 +42,7 @@ Basic Operators
 .. autosummary::
    :toctree: generated/
 
-   MPIMatrixMult
+   MatrixMult.MPIMatrixMult
    MPIBlockDiag
    MPIStackedBlockDiag
    MPIVStack
@@ -118,6 +118,16 @@ Utils
    local_split
 
 
+.. currentmodule:: pylops_mpi.basicoperators.MatrixMult
+
+.. autosummary::
+   :toctree: generated/
+
+   block_gather
+   local_block_split
+   active_grid_comm
+
+
 .. currentmodule:: pylops_mpi.utils.dottest
 
 .. autosummary::

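The three utilities added to the API reference above (``block_gather``, ``local_block_split`` and ``active_grid_comm``) are the helpers the reworked example imports from ``pylops_mpi.basicoperators.MatrixMult``. As a quick orientation, here is a minimal sketch of how the first two fit together; the signatures are assumed from the example below, not re-derived from the module itself:

# Sketch only: active_grid_comm and local_block_split signatures are assumed
# from examples/plot_summamatrixmult.py shown further down.
from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, local_block_split

N, M, K = 6, 6, 6

# Carve a square grid of "active" ranks out of the world communicator.
comm, rank, row_id, col_id, is_active = active_grid_comm(MPI.COMM_WORLD, N, M)

if is_active:
    # Slice objects selecting this rank's 2D block of the (N, K) matrix A.
    A_slice = local_block_split((N, K), rank, comm)
    print(f"rank {rank} at grid position ({row_id}, {col_id}) owns A{A_slice}")

``block_gather`` goes the other way, reassembling a block-distributed ``DistributedArray`` into a full matrix; its call pattern can be seen in the lines the example removes further down.
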
examples/plot_summamatrixmult.py

Lines changed: 101 additions & 27 deletions
@@ -1,7 +1,7 @@
 r"""
 Distributed SUMMA Matrix Multiplication
 =======================================
-This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPISummaMatrixMult`
+This example shows how to use the :py:class:`pylops_mpi.basicoperators._MPISummaMatrixMult`
 operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
 distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}`
 and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly,
@@ -20,53 +20,127 @@
 import math
 import numpy as np
 from mpi4py import MPI
+from matplotlib import pyplot as plt
 
 import pylops_mpi
-from pylops_mpi.basicoperators.MatrixMult import (local_block_split, block_gather, MPIMatrixMult)
+from pylops import Conj
+from pylops_mpi.basicoperators.MatrixMult import (local_block_split, MPIMatrixMult, active_grid_comm)
 
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()
-size = comm.Get_size()
+plt.close("all")
 
-N = 9
-M = 9
-K = 9
+###############################################################################
+# We set the seed so that all processes create the same random input matrices.
+# In a practical application, such matrices will instead be filled with data
+# appropriate for the use case.
+np.random.seed(42)
 
-A_shape = (N, K)
-x_shape = (K, M)
-y_shape = (N, M)
 
-p_prime = math.isqrt(size)
+N, M, K = 6, 6, 6
+A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)
+
+
+base_comm = MPI.COMM_WORLD
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
+print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+
+
+###############################################################################
+# We are now ready to create the input matrices for our distributed matrix
+# multiplication example. We need to set up:
+#
+# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
+# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
+# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M`
+#
+# For distributed computation, we arrange processes in a square grid of size
+# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
+# number of MPI processes. Each process will own a block of each matrix
+# according to this 2D grid layout.
+
+p_prime = math.isqrt(comm.Get_size())
+print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")
+
+# Create global test matrices with sequential values for easy verification:
+# Matrix A: each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
+# Matrix X: each element :math:`X_{i,j} = i \cdot M + j`
 A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
 x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)
 
+print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
+print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
+print(f"Expected global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")
+
+###############################################################################
+# Determine which block of each matrix this process should own.
+# The 2D block distribution ensures:
+#
+# - Process at grid position :math:`(i,j)` gets block :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
+# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`, with edge processes handling the remainder
+#
+# .. raw:: html
+#
+#    <div style="text-align: left; font-family: monospace; white-space: pre;">
+#    <b>Example: 2x2 Process Grid with 6x6 Matrices</b>
+#
+#    Matrix A (6x6):                Matrix X (6x6):
+#    ┌───────────┬───────────┐      ┌───────────┬───────────┐
+#    │  0  1  2  │  3  4  5  │      │  0  1  2  │  3  4  5  │
+#    │  6  7  8  │  9 10 11  │      │  6  7  8  │  9 10 11  │
+#    │ 12 13 14  │ 15 16 17  │      │ 12 13 14  │ 15 16 17  │
+#    ├───────────┼───────────┤      ├───────────┼───────────┤
+#    │ 18 19 20  │ 21 22 23  │      │ 18 19 20  │ 21 22 23  │
+#    │ 24 25 26  │ 27 28 29  │      │ 24 25 26  │ 27 28 29  │
+#    │ 30 31 32  │ 33 34 35  │      │ 30 31 32  │ 33 34 35  │
+#    └───────────┴───────────┘      └───────────┴───────────┘
+#
+#    Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
+#    Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
+#    Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
+#    Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
+#    </div>
+#
+
 A_slice = local_block_split(A_shape, rank, comm)
 x_slice = local_block_split(x_shape, rank, comm)
+###############################################################################
+# Extract the local portion of each matrix for this process
 A_local = A_data[A_slice]
 x_local = x_data[x_slice]
 
+print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
+print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")
+
 x_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
                                      local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
                                      base_comm=comm,
                                      partition=pylops_mpi.Partition.SCATTER,
                                      dtype=x_local.dtype)
-x_dist.local_array[:] = x_local.flatten()
+x_dist[:] = x_local.flatten()
 
+###############################################################################
+# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator. Given that we chose a block-block distribution of the data, we
+# shall use the SUMMA algorithm (``kind="summa"``).
 Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)
+
+###############################################################################
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
+# effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{Y} = \mathbf{A}\mathbf{X}`). Note that :math:`\mathbf{Y}` is
+# distributed in the same way as the input :math:`\mathbf{X}` in a block-block
+# fashion.
 y_dist = Aop @ x_dist
+
+###############################################################################
+# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
+# (which effectively implements a distributed SUMMA matrix-matrix multiplication
+# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
+# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
+# :math:`\mathbf{X}` in a block-block fashion.
 xadj_dist = Aop.H @ y_dist
 
-y = block_gather(y_dist, (N, M), (N, M), comm)
-xadj = block_gather(xadj_dist, (K, M), (K, M), comm)
-if rank == 0:
-    y_correct = np.allclose(A_data @ x_data, y)
-    print("y expected: ", y_correct)
-    if not y_correct:
-        print("expected:\n", A_data @ x_data)
-        print("calculated:\n", y)
-
-    xadj_correct = np.allclose((A_data.T.dot((A_data @ x_data).conj())).conj(), xadj.astype(np.int32))
-    print("xadj expected: ", xadj_correct)
-    if not xadj_correct:
-        print("expected:\n", (A_data.T.dot((A_data @ x_data).conj())).conj())
-        print("calculated:\n", xadj.astype(np.int32))
+###############################################################################
+# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator can be combined with any other PyLops-MPI operator. We are going to
+# apply here a conjugate operator to the output of the matrix multiplication.
+Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
+DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
+Op = DBop @ Aop
+y1 = Op @ x_dist

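Note that the reworked example no longer gathers and verifies the results; the removed lines above show how this was previously done with ``block_gather``. A verification step along those lines could still be appended after ``y1 = Op @ x_dist`` when running the example (typically on a square number of ranks, e.g. ``mpirun -n 4 python examples/plot_summamatrixmult.py``). The following is only a sketch adapted from the removed code, assuming the same ``block_gather`` signature and that the gathered matrices are usable on rank 0:

# Optional check, adapted from the lines removed above.
from pylops_mpi.basicoperators.MatrixMult import block_gather

y = block_gather(y_dist, (N, M), (N, M), comm)
xadj = block_gather(xadj_dist, (K, M), (K, M), comm)
if rank == 0:
    print("Y matches A @ X:", np.allclose(A_data @ x_data, y))
    print("X_adj matches A^H (A X):",
          np.allclose(A_data.conj().T @ (A_data @ x_data), xadj))
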
pylops_mpi/LinearOperator.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def matvec(self, x: DistributedArray) -> DistributedArray:
 
         """
         M, N = self.shape
-
         if x.global_shape != (N,):
             raise ValueError("dimension mismatch")
 
pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 4 additions & 3 deletions
@@ -1,3 +1,4 @@
+
 import math
 import numpy as np
 from typing import Tuple, Union, Literal
@@ -14,7 +15,7 @@
 
 
 def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
-    r"""Configure active grid
+    r"""Configure active grid for distributed matrix multiplication.
 
     Configure a square process grid from a parent MPI communicator and
     select a subset of "active" processes. Each process in ``base_comm``
@@ -721,7 +722,6 @@ def MPIMatrixMult(
     ``kind`` parameter.
 
     The forward operation computes::
-
        :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}`
 
     where:
@@ -730,7 +730,6 @@ def MPIMatrixMult(
     - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
 
     The adjoint (conjugate-transpose) operation computes::
-
        :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`
 
     where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.
@@ -792,3 +791,5 @@ def MPIMatrixMult(
         return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
     else:
         raise NotImplementedError("kind must be summa or block")
+
+__all__ = ["active_grid_comm", "block_gather", "local_block_split", "MPIMatrixMult"]

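The docstring hunks above spell out the forward operation Y = A · X and the adjoint X_adj = A^H · Y. A plain-NumPy dot-test, independent of pylops_mpi and with arbitrarily chosen shapes and random data, illustrates why A^H is the correct adjoint: the inner products <A X, Y> and <X, A^H Y> must coincide.

# Dot-test for the forward/adjoint pair documented above, using NumPy only.
import numpy as np

rng = np.random.default_rng(0)
N, K, M = 6, 5, 4
A = rng.standard_normal((N, K)) + 1j * rng.standard_normal((N, K))
X = rng.standard_normal((K, M)) + 1j * rng.standard_normal((K, M))
Y = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))

lhs = np.vdot(A @ X, Y)           # <A X, Y>
rhs = np.vdot(X, A.conj().T @ Y)  # <X, A^H Y>
assert np.isclose(lhs, rhs)

This is the same identity that the ``pylops_mpi.utils.dottest`` utility (listed in the API index touched by this commit) is built around for distributed operators.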