MPI Parallelism with Python
Note
Before venturing into MPI-based parallelism, consider whether your work can be restructured to make use of dSQ or more "embarrassingly parallel" workflows. MPI can be thought of as a "last resort" for parallel programming.
Many computational problems can achieve increased performance by running pieces in parallel. These pieces often require communication between steps and need a way to send messages between processes.
Examples of this include simulations of galaxy formation and electric fields, analysis of a single large dataset, or complex search or sort algorithms.
MPI and mpi4py
There is a standard protocol, called MPI, that defines how messages are passed between processes, including one-to-one and broadcast communications. The Python module for this is called mpi4py:

Message Passing Interface implemented for Python. Supports point-to-point (sends, receives) and collective (broadcasts, scatters, gathers) communications of any picklable Python object, as well as optimized communications of Python objects exposing the single-segment buffer interface (NumPy arrays, builtin bytes/string/array objects).
We will go over a few simple examples here.
Definitions
COMM: The communication "world" defined by MPI
RANK: an ID number given to each internal process to define communication
SIZE: total number of processes allocated
BROADCAST: One-to-many communication
SCATTER: One-to-many data distribution
GATHER: Many-to-one data distribution
mpi4py on the clusters
On the clusters, the easiest way to start using mpi4py is to use the module-based software for OpenMPI and Python:
# toolchains 2020b and before
module load SciPy-bundle/2020.11-foss-2020b
# toolchains starting with 2022b
module load mpi4py/3.1.4-gompi-2022b
Warning
mpi4py installed via Conda is unaware of the cluster infrastructure and therefore will likely only work on a single compute node.
If you wish to get a conda environment working across multiple nodes, please reach out to hpc@yale.edu for assistance.
Cluster Resource Requests
MPI utilizes Slurm tasks as the individual parallel workers.
Therefore, when requesting resources (either interactively or in batch mode), the number of tasks will determine the number of parallel workers (or, to use MPI's language, the SIZE of the COMM world).
To request four tasks (each with a single CPU) interactively run the following:
salloc --cpus-per-task=1 --ntasks=4
This can also be achieved in batch-mode by including the following directives in your submission script:
#SBATCH --cpus-per-task=1
#SBATCH --ntasks=4
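Putting these directives together with the module load from above, a complete submission script might look like the sketch below (the job name and the Python script name are placeholders):

#!/bin/bash
#SBATCH --job-name=mpi4py_example
#SBATCH --ntasks=4
#SBATCH --cpus-per-task=1

module load mpi4py/3.1.4-gompi-2022b

mpirun -n 4 python mpi_simple.py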
A more detailed discussion of resource requests can be found here and further examples are available here.
Examples
Ex 1: Rank
This is a simple example where each worker reports its RANK and the process ID running that particular task.
from mpi4py import MPI
import os
# instantiate the communication world
comm = MPI.COMM_WORLD
# get the size of the communication world
size = comm.Get_size()
# get this particular process's `rank` ID
rank = comm.Get_rank()
PID = os.getpid()
print(f'rank: {rank} has PID: {PID}')
We then execute this code (named mpi_simple.py) by running the following on the command line:
mpirun -n 4 python mpi_simple.py
The mpirun command is a wrapper for the MPI interface. The -n 4 flag tells it to set up a COMM_WORLD with four workers, and python mpi_simple.py is then run on each of the four workers. This outputs the following:
rank: 0 has PID: 89134
rank: 1 has PID: 89135
rank: 2 has PID: 89136
rank: 3 has PID: 89137
Ex 2: Point to Point Communicators
The most basic communication operators are send and recv. These can be a bit tricky since they are "blocking" commands and can cause the program to hang.
comm.send(obj, dest, tag=0)
comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=None)
tag: can be used as a filter
dest: must be a rank in the current communicator
source: can be a rank or a wild-card (MPI.ANY_SOURCE)
status: used to retrieve information about the received message
We now create a file (mpi_comm.py) that contains the following:
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
if rank == 0:
    msg = 'Hello, world'
    comm.send(msg, dest=1)
elif rank == 1:
    s = comm.recv()
    print(f"rank {rank}: {s}")
When we run this on the command line (mpirun -n 4 python mpi_comm.py) we get the following:
rank 1: Hello, world
The RANK=0 process sends the message, and the RANK=1 process receives it. The other two processes are effectively bystanders in this example.
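Since send and recv block until the message is delivered, a mismatched pair can cause the program to hang. mpi4py also provides non-blocking variants, isend and irecv, which return request objects that are completed later with wait(). The following is just a minimal sketch of that pattern, not part of the original example:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # start the send without waiting for the receiver
    req = comm.isend('Hello, world', dest=1, tag=11)
    req.wait()  # complete the send
elif rank == 1:
    # post the receive, then block only when the message is actually needed
    req = comm.irecv(source=0, tag=11)
    msg = req.wait()
    print(f"rank {rank}: {msg}")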
Ex 3: Broadcast
Now we will try a slightly more complicated example that involves sending messages and data between processes.
# Import MPI
from mpi4py import MPI
# Define world
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# Create some data in the RANK_0 worker
if rank == 0:
    data = {'key1': [7, 2.72, 2+3j], 'key2': ('abc', 'xyz')}
else:
    data = None
# Broadcast the data from RANK_0 to all workers
data = comm.bcast(data, root=0)
# Append the RANK ID to the data
data['key1'].append(rank)
# Print the resulting data
print(f"Rank: {rank}, data: {data}")
We then execute this code (named mpi_message.py) by running the following on the command line:
mpirun -n 4 python mpi_message.py
Which outputs the following:
Rank: 0, data: {'key1': [7, 2.72, (2+3j), 0], 'key2': ('abc', 'xyz')}
Rank: 2, data: {'key1': [7, 2.72, (2+3j), 2], 'key2': ('abc', 'xyz')}
Rank: 3, data: {'key1': [7, 2.72, (2+3j), 3], 'key2': ('abc', 'xyz')}
Rank: 1, data: {'key1': [7, 2.72, (2+3j), 1], 'key2': ('abc', 'xyz')}
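The lowercase bcast method used above works on any picklable Python object. For NumPy arrays, the uppercase Bcast method sends the underlying buffer directly, which is typically faster for large arrays; the same lowercase/uppercase convention applies to the Scatter and Gather calls in the next example. A minimal sketch (the array size here is arbitrary):

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# every rank allocates a buffer with the same shape and dtype
array = np.empty(10, dtype='d')
if rank == 0:
    array[:] = np.arange(10, dtype='d')

# Bcast fills the buffer in place on all non-root ranks
comm.Bcast(array, root=0)
print(f"Rank: {rank}, first values: {array[:3]}")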
Ex 4: Scatter and Gather
An effective way of distributing computationally intensive tasks is to scatter pieces of a large dataset to each task. The separate tasks perform some analysis on their chunk of data, and then the results are gathered by RANK_0.
This example takes a large array of random numbers and splits it into pieces for each task. These smaller datasets are analyzed (taking an average in this example) and the results are returned to the main task with a Gather call.
# import libraries
from mpi4py import MPI
import numpy as np
# set up MPI world
comm = MPI.COMM_WORLD
size = comm.Get_size() # number of ranks in comm
rank = comm.Get_rank()
# generate a large array of data on RANK_0
numData = 100000000 # 100 million values
data = None
if rank == 0:
    data = np.random.normal(loc=10, scale=5, size=numData)
# initialize empty arrays to receive the partial data
partial = np.empty(int(numData/size), dtype='d')
# send data to the other workers
comm.Scatter(data, partial, root=0)
# prepare the reduced array to receive the processed data
reduced = None
if rank == 0:
    reduced = np.empty(size, dtype='d')
# Average the partial arrays, and then gather them to RANK_0
comm.Gather(np.average(partial), reduced, root=0)
if rank == 0:
    print('Full Average:', np.average(reduced))
This is executed on the command line:
mpirun -n 4 python mpi/mpi_scatter.py
Which prints:
Full Average: 10.00002060397186
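Since the gathered values are only used to compute an average, the same result can be obtained with a reduction instead of a Gather. The version below is only a sketch of that alternative (with a smaller array for brevity), not part of the original example:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

numData = 1000000  # smaller than above; must be divisible by the number of ranks
data = None
if rank == 0:
    data = np.random.normal(loc=10, scale=5, size=numData)

partial = np.empty(int(numData / size), dtype='d')
comm.Scatter(data, partial, root=0)

# sum the partial averages onto RANK_0 with a reduction, then divide by size
total = np.zeros(1, dtype='d')
comm.Reduce(np.array([np.average(partial)]), total, op=MPI.SUM, root=0)

if rank == 0:
    print('Full Average:', total[0] / size)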
Key Take-aways and Further Reading
MPI is a powerful tool to set up communication worlds and send data and messages between workers.
The mpi4py module provides tools for using MPI within Python.
This is just the beginning; mpi4py can be used for so much more...
To learn more, take a look at the mpi4py tutorial here.