mpi4py 中的共享内存

Shared memory in mpi4py

我使用 MPI (mpi4py) 脚本(在单个节点上),它处理非常大的对象。为了让所有进程都能访问对象,我通过comm.bcast()进行分发。这会将对象复制到所有进程并消耗大量内存,尤其是在复制过程中。因此,我想分享一些类似指针的东西,而不是对象本身。我发现 memoryview 中的一些功能有助于促进进程内对象的工作。此外,对象的实际内存地址可通过 memoryview 对象字符串表示形式访问,并且可以像这样分配:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank:
    content_pointer = comm.bcast(root = 0)
    print(rank, content_pointer)
else:
    content = ''.join(['a' for i in range(100000000)]).encode()
    mv = memoryview(content)
    print(mv)
    comm.bcast(str(mv).split()[-1][: -1], root = 0)

这会打印:

<memory at 0x7f362a405048>
1 0x7f362a405048
2 0x7f362a405048
...

这就是为什么我相信必须有一种方法可以在另一个进程中重构对象。但是,我无法在文档中找到有关如何执行此操作的线索。

简而言之,我的问题是:是否可以在mpi4py中的同一节点上的进程之间共享一个对象?

我对mpi4py不是很了解,但是从MPI的角度来看这应该是不可能的。 MPI 代表消息传递接口,这意味着:在进程之间传递消息。您可以尝试使用 MPI One-sided communication 来类似于全局可访问内存,但进程内存对其他进程不可用。

如果您需要依赖大块共享内存,则需要使用 OpenMP 或线程之类的东西,您绝对可以在单个节点上使用它们。使用 MPI 和一些共享内存并行化的混合并行化将允许每个节点有一个共享内存块,但仍然可以选择使用多个节点。

这是一个使用 MPI 的共享内存的简单示例,对 https://groups.google.com/d/msg/mpi4py/Fme1n9niNwQ/lk3VJ54WAQAJ

稍作修改

您可以 运行 使用:mpirun -n 2 python3 shared_memory_test.py(假设您将其保存为 shared_memory_test.py)

from mpi4py import MPI 
import numpy as np 

comm = MPI.COMM_WORLD 

# create a shared array of size 1000 elements of type double
size = 1000 
itemsize = MPI.DOUBLE.Get_size() 
if comm.Get_rank() == 0: 
    nbytes = size * itemsize 
else: 
    nbytes = 0

# on rank 0, create the shared block
# on rank 1 get a handle to it (known as a window in MPI speak)
win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm) 

# create a numpy array whose data points to the shared mem
buf, itemsize = win.Shared_query(0) 
assert itemsize == MPI.DOUBLE.Get_size() 
ary = np.ndarray(buffer=buf, dtype='d', shape=(size,)) 

# in process rank 1:
# write the numbers 0.0,1.0,..,4.0 to the first 5 elements of the array
if comm.rank == 1: 
    ary[:5] = np.arange(5)

# wait in process rank 0 until process 1 has written to the array
comm.Barrier() 

# check that the array is actually shared and process 0 can see
# the changes made in the array by process 1
if comm.rank == 0: 
    print(ary[:10])

应该输出这个(从进程等级 0 打印):

[0. 1. 2. 3. 4. 0. 0. 0. 0. 0.]