如何在 Python 中使用多线程安全地修改全局 numpy 数组?
How to modify global numpy array safely with multithreading in Python?
我正在尝试 运行 在线程池中进行模拟,并将每次重复的结果存储在全局 numpy 数组中。但是,我在这样做时遇到了问题,我正在使用以下(简化的)代码(python 3.7)观察到一个非常有趣的行为:
import numpy as np
from multiprocessing import Pool, Lock
log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)
def record_results(repetition_index, data_array, log_mutex):
log_mutex.acquire()
print("Start record {}".format(repetition_index))
# Do some stuff and modify data_array, e.g.:
data_array[repetition_index, 0, 53] = 12.34
print("Finish record {}".format(repetition_index))
log_mutex.release()
def run(repetition_index):
global log_mutex
global data_array
# do some simulation
record_results(repetition_index, data_array, log_mutex)
if __name__ == "__main__":
random.seed()
with Pool(thread_count) as p:
print(p.map(run, range(repetition_count)))
问题是:我得到了正确的“开始记录和完成记录”输出,例如Start record 1... Finish record 1. 但是,每个线程修改的numpy数组的不同切片并没有保存在全局变量中。也就是说,线程1修改过的元素还是0,线程4覆盖了数组的不同部分。
补充一点,全局数组的地址,我通过
print(hex(id(data_array)))
对于所有线程都相同,在它们的 log_mutex.acquire() ... log_mutex.release()
行内。
我漏掉了一点吗?比如,每个线程存储了多个全局 data_array 副本?我正在观察这样的一些行为,但是当我使用 global 关键字时不应该是这种情况,我错了吗?
看起来您 运行 run
函数使用了多个进程,而不是多个线程。尝试这样的事情:
import numpy as np
from threading import Thread, Lock
log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)
def record_results(repetition_index, data_array, log_mutex):
log_mutex.acquire()
print("Start record {}".format(repetition_index))
# Do some stuff and modify data_array, e.g.:
data_array[repetition_index, 0, 53] = 12.34
print("Finish record {}".format(repetition_index))
log_mutex.release()
def run(repetition_index):
global log_mutex
global data_array
record_results(repetition_index, data_array, log_mutex)
if __name__ == "__main__":
threads = []
for i in range(repetition_count):
t = Thread(target=run, args=[i])
t.start()
threads.append(t)
for t in threads:
t.join()
更新:
要对多个进程执行此操作,您需要使用 multiprocessing.RawArray
来实例化您的数组;数组的大小是乘积 repetition_count * 3 * 200
。在每个进程中,使用 np.frombuffer
在数组上创建一个视图,并相应地对其进行整形。虽然这会非常快,但我不鼓励这种编程风格,因为它依赖于全局共享内存对象,这在较大的程序中容易出错。
如果可能,我建议删除全局 data_array
,而是在每次调用 record_results
时实例化一个数组,您将在 run
中 return。 p.map
调用将 return 数组列表,您可以将其转换为 numpy 数组并恢复原始实现中全局 data_array
的形状和内容。这将产生通信成本,但它是一种更简洁的并发管理方法,并且无需锁。
通常尽量减少进程间通信是个好主意,但除非性能至关重要,否则我认为共享内存不是正确的解决方案。使用 p.map
,您会希望避免 return 大对象,但您的代码段中的对象大小非常小(600*8 字节)。
我正在尝试 运行 在线程池中进行模拟,并将每次重复的结果存储在全局 numpy 数组中。但是,我在这样做时遇到了问题,我正在使用以下(简化的)代码(python 3.7)观察到一个非常有趣的行为:
import numpy as np
from multiprocessing import Pool, Lock
log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)
def record_results(repetition_index, data_array, log_mutex):
log_mutex.acquire()
print("Start record {}".format(repetition_index))
# Do some stuff and modify data_array, e.g.:
data_array[repetition_index, 0, 53] = 12.34
print("Finish record {}".format(repetition_index))
log_mutex.release()
def run(repetition_index):
global log_mutex
global data_array
# do some simulation
record_results(repetition_index, data_array, log_mutex)
if __name__ == "__main__":
random.seed()
with Pool(thread_count) as p:
print(p.map(run, range(repetition_count)))
问题是:我得到了正确的“开始记录和完成记录”输出,例如Start record 1... Finish record 1. 但是,每个线程修改的numpy数组的不同切片并没有保存在全局变量中。也就是说,线程1修改过的元素还是0,线程4覆盖了数组的不同部分。
补充一点,全局数组的地址,我通过
print(hex(id(data_array)))
对于所有线程都相同,在它们的 log_mutex.acquire() ... log_mutex.release()
行内。
我漏掉了一点吗?比如,每个线程存储了多个全局 data_array 副本?我正在观察这样的一些行为,但是当我使用 global 关键字时不应该是这种情况,我错了吗?
看起来您 运行 run
函数使用了多个进程,而不是多个线程。尝试这样的事情:
import numpy as np
from threading import Thread, Lock
log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)
def record_results(repetition_index, data_array, log_mutex):
log_mutex.acquire()
print("Start record {}".format(repetition_index))
# Do some stuff and modify data_array, e.g.:
data_array[repetition_index, 0, 53] = 12.34
print("Finish record {}".format(repetition_index))
log_mutex.release()
def run(repetition_index):
global log_mutex
global data_array
record_results(repetition_index, data_array, log_mutex)
if __name__ == "__main__":
threads = []
for i in range(repetition_count):
t = Thread(target=run, args=[i])
t.start()
threads.append(t)
for t in threads:
t.join()
更新:
要对多个进程执行此操作,您需要使用 multiprocessing.RawArray
来实例化您的数组;数组的大小是乘积 repetition_count * 3 * 200
。在每个进程中,使用 np.frombuffer
在数组上创建一个视图,并相应地对其进行整形。虽然这会非常快,但我不鼓励这种编程风格,因为它依赖于全局共享内存对象,这在较大的程序中容易出错。
如果可能,我建议删除全局 data_array
,而是在每次调用 record_results
时实例化一个数组,您将在 run
中 return。 p.map
调用将 return 数组列表,您可以将其转换为 numpy 数组并恢复原始实现中全局 data_array
的形状和内容。这将产生通信成本,但它是一种更简洁的并发管理方法,并且无需锁。
通常尽量减少进程间通信是个好主意,但除非性能至关重要,否则我认为共享内存不是正确的解决方案。使用 p.map
,您会希望避免 return 大对象,但您的代码段中的对象大小非常小(600*8 字节)。