使用 pool.map 进行多处理时无法腌制本地对象

Can't pickle local object when multiprocessing with pool.map

我正在尝试将 multiprocessing 与 python Pool 函数一起使用,使用 functools.partial 将几个具有常量值的参数输入到 pool.map 命令中(即第一个参数是唯一不同的)。

问题是当我 运行 代码时出现以下错误,我不知道为什么或如何解决它:

AttributeError: Can't pickle local object 
    'get_MAX_SNR_for_eventdata_file.<locals>.get_SNR_multiprocess'

我不知道为什么它不能 pickle 对象。这是代码(在顶级函数中):

def get_SNR_multiprocess(binning, event_data, energy_interval, tstart, tstop, trigger_time):
    """ This function just changes the order of arguments to be able to use partial"""
    SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, trigger_time)
    return SNR

pool=multiprocessing.Pool(processes=4)

for i in range(len(energybands)-1):
    energy_interval=[energybands[i],energybands[i+1]]
    partial_func=partial(get_SNR_multiprocess, event_data=event_data, 
                         energy_interval=energy_interval, tstart=tstart, tstop=tstop, 
                         trigger_time=trigger_time)
    SNRlist=pool.map(partial_func,timescales)
pool.close()

我得到一个提示,根据 What can be pickled?,问题可能与这样一个事实有关,即只有定义在模块顶层的函数才能被 pickle。但是,我无法弄清楚我的代码中的确切问题,或者如何解决它。

代码中的函数get_max_SNR_est是同一个脚本中定义的函数和returns一个值。此函数依赖于同一脚本的其他函数(依赖于另一个等等...)。

仅供参考,代码无需使用 for 循环进行多处理即可运行,例如:

SNRlist=[]
for i in range(len(energybands)-1):
    energy_interval=[energybands[i],energybands[i+1]]
    for binning in timescales:
        SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, 
                            trigger_time)
        SNRlist.append(SNR)

编辑:我忘了说我在这里展示的代码已经在函数中了。根据@martineau 的评论,我从上述函数中提取了函数 get_SNR_multiprocessing,这解决了酸洗问题(参见答案)。

感谢@martineau 的评论,我找到了解决此问题的方法。正如我稍后在问题的 edit 中提到的,我在这里显示的代码已经在一个函数中。我从上述函数中取出函数 get_SNR_multiprocessing ,它解决了 pickling 的问题。新代码(我在这里展示包含上面代码的函数)如下所示:

def get_SNR_multiprocess(binning, event_data, energy_interval, tstart, tstop, trigger_time):
    """ This function just changes the order of arguments to be able to use functools.partial for multiprocessing"""
    SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, trigger_time)
    return SNR

def get_MAX_SNR_for_eventdata_file(event_data, energybands, timescales, tstart, tstop, trigger_time):
    """
    Gives the maximum SNR of all timescales and energybands given for a given event_data file
    """ 
    SNRlist=[]
    for i in range(len(energybands)-1):
        energy_interval=[energybands[i],energybands[i+1]]

        with multiprocessing.Pool(processes=4) as pool:
            partial_func=partial(get_SNR_multiprocess, event_data=event_data, energy_interval=energy_interval, tstart=tstart, tstop=tstop, trigger_time=trigger_time)
            SNRlist=pool.map(partial_func,timescales)

不幸的是,此方法与带有 for 循环的原始方法花费的时间相同。