r"""
Distributed SUMMA Matrix Multiplication
=======================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators._MPISummaMatrixMult`
operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}`
and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly,
import math
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt

import pylops_mpi
from pylops import Conj
from pylops_mpi.basicoperators.MatrixMult import (local_block_split, MPIMatrixMult, active_grid_comm)

plt.close("all")

###############################################################################
# We set the seed such that all processes can create the input matrices filled
# with the same random numbers. In practical applications, such matrices will
# be filled with data appropriate for the use case.
np.random.seed(42)

N, M, K = 6, 6, 6
A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)


base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")

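###############################################################################
# Note (an assumption about intended usage rather than something shown above):
# ranks flagged as inactive by ``active_grid_comm`` take no part in the
# computation that follows, so a production script might simply let them skip
# the remainder of the example, e.g. along the lines of::
#
#     if not is_active:
#         raise SystemExit(0)
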
###############################################################################
# We are now ready to create the input matrices for our distributed matrix
# multiplication example. We need to set up:
#
# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M`
#
# For distributed computation, we arrange processes in a square grid of size
# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
# number of MPI processes. Each process will own a block of each matrix
# according to this 2D grid layout.

p_prime = math.isqrt(comm.Get_size())
print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")

# Create global test matrices with sequential values for easy verification:
# Matrix A: each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
# Matrix X: each element :math:`X_{i,j} = i \cdot M + j`
A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)

print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
print(f"Expected global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")

###############################################################################
# Determine which block of each matrix this process should own. The 2D block
# distribution ensures that:
#
# - the process at grid position :math:`(i,j)` gets block
#   :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
# - block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`,
#   with edge processes handling the remainder
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Example: 2x2 Process Grid with 6x6 Matrices</b>
#
#    Matrix A (6x6):                Matrix X (6x6):
#    ┌───────────┬───────────┐      ┌───────────┬───────────┐
#    │  0  1  2  │  3  4  5  │      │  0  1  2  │  3  4  5  │
#    │  6  7  8  │  9 10 11  │      │  6  7  8  │  9 10 11  │
#    │ 12 13 14  │ 15 16 17  │      │ 12 13 14  │ 15 16 17  │
#    ├───────────┼───────────┤      ├───────────┼───────────┤
#    │ 18 19 20  │ 21 22 23  │      │ 18 19 20  │ 21 22 23  │
#    │ 24 25 26  │ 27 28 29  │      │ 24 25 26  │ 27 28 29  │
#    │ 30 31 32  │ 33 34 35  │      │ 30 31 32  │ 33 34 35  │
#    └───────────┴───────────┘      └───────────┴───────────┘
#
#    Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
#    Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
#    Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
#    Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
#    </div>
#

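###############################################################################
# As a rough illustration of the block sizes quoted above, here is a sketch of
# how such slices could be computed with a simple ceil-division layout. This is
# only an illustration: the slices actually used below come from
# :py:func:`local_block_split`, whose implementation may assign blocks and
# remainders differently.
def _block_slice_sketch(global_shape, grid_row, grid_col, p_prime):
    """Hypothetical helper: slices of the (grid_row, grid_col) block."""
    nr = math.ceil(global_shape[0] / p_prime)  # nominal block height
    nc = math.ceil(global_shape[1] / p_prime)  # nominal block width
    return (slice(grid_row * nr, min((grid_row + 1) * nr, global_shape[0])),
            slice(grid_col * nc, min((grid_col + 1) * nc, global_shape[1])))

print(f"Process {rank} (sketch): A block {_block_slice_sketch(A_shape, row_id, col_id, p_prime)}")
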
A_slice = local_block_split(A_shape, rank, comm)
x_slice = local_block_split(x_shape, rank, comm)

###############################################################################
# Extract the local portion of each matrix for this process
A_local = A_data[A_slice]
x_local = x_data[x_slice]

print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")

x_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
                                     local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
                                     base_comm=comm,
                                     partition=pylops_mpi.Partition.SCATTER,
                                     dtype=x_local.dtype)
x_dist[:] = x_local.flatten()

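# As a quick sanity check, the local block sizes gathered across the active
# ranks should add up to the :math:`K \times M` elements of the global matrix
# :math:`\mathbf{X}`.
assert sum(comm.allgather(x_local.size)) == K * M
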
###############################################################################
# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator. Given that we chose a block-block distribution of the data, we use
# the SUMMA variant of the operator (``kind="summa"``).
Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)

###############################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is distributed
# in the same way as the input :math:`\mathbf{X}`, in a block-block fashion.
y_dist = Aop @ x_dist

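###############################################################################
# Since every rank holds the full ``A_data`` and ``x_data`` in this toy
# example, we can sketch a local check of the forward result. Assuming the
# local block of ``y_dist`` corresponds to the ``(row_id, col_id)`` block of
# the global product (as stated above), it should match the same block of the
# densely computed :math:`\mathbf{A}\mathbf{X}`.
y_slice = local_block_split(y_shape, rank, comm)
y_expected_local = (A_data @ x_data)[y_slice]
print(f"Process {rank}: forward block matches dense product:",
      np.allclose(y_dist.local_array.reshape(y_expected_local.shape), y_expected_local))
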
###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed SUMMA matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}`, in a block-block fashion.
xadj_dist = Aop.H @ y_dist

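###############################################################################
# Similarly, a sketched local check for the adjoint: assuming ``xadj_dist`` is
# distributed like :math:`\mathbf{X}` (as stated above), its local block should
# match the corresponding block of the densely computed
# :math:`\mathbf{A}^H (\mathbf{A}\mathbf{X})`.
xadj_expected_local = (A_data.conj().T @ (A_data @ x_data))[x_slice]
print(f"Process {rank}: adjoint block matches dense result:",
      np.allclose(xadj_dist.local_array.reshape(xadj_expected_local.shape), xadj_expected_local))
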
###############################################################################
# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator can be combined with any other PyLops-MPI operator. We are going to
# apply here a conjugate operator to the output of the matrix multiplication.
Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
Op = DBop @ Aop
y1 = Op @ x_dist
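
###############################################################################
# As a final sketched check: ``Conj`` simply conjugates its input, so the local
# array of ``y1`` is expected to equal the conjugate of the local array of
# ``y_dist`` computed earlier.
print(f"Process {rank}: y1 == conj(y):",
      np.allclose(y1.local_array, np.conj(y_dist.local_array)))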