如何在 Python 中使用多线程安全地修改全局 numpy 数组?

How to modify global numpy array safely with multithreading in Python?

我正在尝试 运行 在线程池中进行模拟,并将每次重复的结果存储在全局 numpy 数组中。但是,我在这样做时遇到了问题,我正在使用以下(简化的)代码(python 3.7)观察到一个非常有趣的行为:

import numpy as np
from multiprocessing import Pool, Lock

log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)

def record_results(repetition_index, data_array, log_mutex):
    log_mutex.acquire()
    print("Start record {}".format(repetition_index))
    # Do some stuff and modify data_array, e.g.:
    data_array[repetition_index, 0, 53] = 12.34
    
    print("Finish record {}".format(repetition_index))
    log_mutex.release()

def run(repetition_index):
    global log_mutex
    global data_array

    # do some simulation

    record_results(repetition_index, data_array, log_mutex)

if __name__ == "__main__":
    random.seed()
    with Pool(thread_count) as p:
        print(p.map(run, range(repetition_count)))



问题是:我得到了正确的“开始记录和完成记录”输出,例如Start record 1... Finish record 1. 但是,每个线程修改的numpy数组的不同切片并没有保存在全局变量中。也就是说,线程1修改过的元素还是0,线程4覆盖了数组的不同部分。

补充一点,全局数组的地址,我通过 print(hex(id(data_array))) 对于所有线程都相同,在它们的 log_mutex.acquire() ... log_mutex.release() 行内。

我漏掉了一点吗?比如,每个线程存储了多个全局 data_array 副本?我正在观察这样的一些行为,但是当我使用 global 关键字时不应该是这种情况,我错了吗?

看起来您 运行 run 函数使用了多个进程,而不是多个线程。尝试这样的事情:

import numpy as np
from threading import Thread, Lock

log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)

def record_results(repetition_index, data_array, log_mutex):
    log_mutex.acquire()
    print("Start record {}".format(repetition_index))
    # Do some stuff and modify data_array, e.g.:
    data_array[repetition_index, 0, 53] = 12.34
    print("Finish record {}".format(repetition_index))
    log_mutex.release()

def run(repetition_index):
    global log_mutex
    global data_array
    record_results(repetition_index, data_array, log_mutex)

if __name__ == "__main__":
    threads = []
    for i in range(repetition_count):
        t = Thread(target=run, args=[i])
        t.start()
        threads.append(t)
    for t in threads:
        t.join()

更新:

要对多个进程执行此操作,您需要使用 multiprocessing.RawArray 来实例化您的数组;数组的大小是乘积 repetition_count * 3 * 200。在每个进程中,使用 np.frombuffer 在数组上创建一个视图,并相应地对其进行整形。虽然这会非常快,但我不鼓励这种编程风格,因为它依赖于全局共享内存对象,这在较大的程序中容易出错。

如果可能,我建议删除全局 data_array,而是在每次调用 record_results 时实例化一个数组,您将在 run 中 return。 p.map 调用将 return 数组列表,您可以将其转换为 numpy 数组并恢复原始实现中全局 data_array 的形状和内容。这将产生通信成本,但它是一种更简洁的并发管理方法,并且无需锁。

通常尽量减少进程间通信是个好主意,但除非性能至关重要,否则我认为共享内存不是正确的解决方案。使用 p.map,您会希望避免 return 大对象,但您的代码段中的对象大小非常小(600*8 字节)。