如何写入 python joblib 中的共享变量

How to write to a shared variable in python joblib

以下代码并行化了一个 for 循环。

import networkx as nx;
import numpy as np;
from joblib import Parallel, delayed;
import multiprocessing;

def core_func(repeat_index, G, numpy_arrary_2D):
  for u in G.nodes():
    numpy_arrary_2D[repeat_index][u] = 2;
  return;

if __name__ == "__main__":
  G = nx.erdos_renyi_graph(100000,0.99);
  nRepeat = 5000;
  numpy_array = np.zeros([nRepeat,G.number_of_nodes()]);
  Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
  print(np.mean(numpy_array));

可以看出,要打印的预期值为 2。但是,当我 运行 我的代码在集群(多核,共享内存)上时,它 returns 0.0。

我认为问题是每个worker都创建了自己的numpy_array对象副本,而在main函数中创建的那个没有更新。我如何修改代码以便更新 numpy 数组 numpy_array

joblib默认使用进程的多处理池,如its manual所说:

Under the hood, the Parallel object create a multiprocessing pool that forks the Python interpreter in multiple processes to execute each of the items of the list. The delayed function is a simple trick to be able to create a tuple (function, args, kwargs) with a function-call syntax.

这意味着,每个进程都继承了数组的原始状态,但是无论它向其中写入什么,都会在进程退出时丢失。只有函数结果被传递回调用(主)进程。但是你没有 return 任何东西,所以 None 是 returned.

要使共享数组可修改,有两种方法:使用线程和使用共享内存。


与进程不同,线程共享内存。所以你可以写入数组,每个作业都会看到这个变化。根据joblib手册,是这样做的:

  Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));

当你运行它时:

$ python r1.py 
2.0

但是,当您将复杂的东西写入数组时,请确保正确处理数据或数据片段周围的锁,否则您将遇到竞争条件(google)。

还要仔细阅读 GIL,因为 Python 中的计算多线程是有限的(不同于 I/O 多线程)。


如果您仍然需要这些进程(例如,因为 GIL),您可以将该数组放入共享内存中。

这是一个有点复杂的主题,但是 joblib + numpy shared memory examplejoblib 手册中也有说明。

正如谢尔盖在他的回答中所写,进程不共享状态和内存。这就是您看不到预期答案的原因。

线程 共享状态和内存 space,因为它们 运行 在同一进程下。如果您有许多 I/O 操作,这将很有用。由于 GIL

,它不会为您提供更多处理能力(更多 CPU)

进程间通信的一种技术是使用管理器的代理对象。您创建一个管理器对象,它在进程之间同步资源。

A manager object returned by Manager() controls a server process which holds Python objects and allows other processes to manipulate them using proxies.

我没有测试过这段代码(我没有你使用的所有模块),它可能需要对代码进行更多修改,但是使用 Manager 对象它应该看起来像这样

if __name__ == "__main__":
    G = nx.erdos_renyi_graph(100000,0.99);
    nRepeat = 5000;

    manager = multiprocessing.Manager()
    numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()])

    Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat));
    print(np.mean(numpys));