multiprocessing.Pool is losing state when calling a function on an object with an existing property in pathlib.Path subclass

I'm trying to debug this problem. Along the way, I tried to create a standalone function to reproduce the issue, but for some reason it works as expected in a miniature example, yet not in my production code.

I have a subclass of pathlib.Path:

from pathlib import Path
from typing import Iterable, Union

class WalkPath(Path):

    _flavour = type(Path())._flavour
   
    def __init__(self, *args, origin: 'WalkPath'=None, dirs: []=None, files: []=None):

        super().__init__()
        
        if type(args[0]) is str:
            self.origin = origin or self
        else:
            self.origin = origin or args[0].origin

        self._dirs: [WalkPath] = list(map(WalkPath, dirs)) if dirs else None
        self._files: [WalkPath] = list(map(WalkPath, files)) if files else None

        self._lazy_attr = None

    @staticmethod
    def sync(wp: Union[str, Path, 'WalkPath']):
        """Syncronize lazy-loaded attributes"""
        x = wp.lazy_attr
        return wp

    @property
    def lazy_attr(self):
        if self._lazy_attr:
            return self._lazy_attr
        # long running op
        self._lazy_attr = long_running_op(self)
        return self._lazy_attr

class Find:

    @staticmethod
    def shallow(path: Union[str, Path, 'WalkPath'],
                sort_key=lambda p: str(p).lower(),
                hide_sys_files=True) -> Iterable['WalkPath']:
        origin = WalkPath(path)
        if origin.is_file(): 
            return [origin]
        
        for p in sorted(origin.iterdir(), key=sort_key):
            if hide_sys_files and is_sys_file(p):
                continue
            yield WalkPath(p, origin=origin)

Using multiprocessing.Pool, I want to execute that long-running operation in a pool.

It looks like this:

_paths = ['/path1', '/path2']
found = list(itertools.chain.from_iterable(Find.shallow(p) for p in _paths))

Find.shallow (see above) basically just runs Path.iterdir on the origin, then maps the results to WalkPath objects, setting origin to the path it was called on. I know this works, because it prints the correct output:

for x in found:
    print(x.origin, x.name)

Then we dispatch to a pool:

with mp.Pool() as pool:
    done = [x for x in pool.map(WalkPath.sync, found) if x.origin]

But this fails, with 'WalkPath' has no attribute 'origin'.

Here is my attempt to reproduce it locally, but for some reason it succeeds! I can't see the difference.

#!/usr/bin/env python

import multiprocessing as mp
import time
from itertools import tee, chain

r = None

class P:
    
    def __init__(self, i, static=None):
        # self.static = static if not static is None else i
        self.static = static or i
        # print(static, self.static)
        self.i = i
        
        self._a_thing = None
    
    @property
    def a_thing(self):
        if self._a_thing:
            print('Already have thing', self.i, 'static:', self.static)
            return self._a_thing
        time.sleep(0.05)
        print('Did thing', self.i, 'static:', self.static)
        self._a_thing = True
        return self._a_thing
    
    @staticmethod
    def sync(x):
        x.a_thing
        x.another = 'done'
        return x if x.a_thing else None
    
class Load:
    
    @classmethod
    def go(cls):
        
        global r
        
        if r:
            return r
               
        paths = [iter(P(i, static='0') for i in range(10)),
                 iter(P(i, static='0') for i in range(11, 20)),
                 iter(P(i, static='0') for i in range(21, 30))]
        
        iternums, testnums = tee(chain.from_iterable(paths))
        
        for t in testnums:
            print('Want thing', t.i, 'to have static:', t.static)
            
        with mp.Pool() as pool:
            rex = [x for x in pool.map(P.sync, list(iternums)) if x.another]
         
        r = rex
            
        for done in rex:
            print(done.i, done.static, done.a_thing, done.another)

Load.go()

The crux of the problem is that your Path objects cannot be shared between interpreter processes.

Instead, when multiprocessing is used, Python serializes (pickles) all arguments and return values to/from the child processes.

It appears that pathlib.Path defines custom pickling/unpickling logic that is incompatible with your origin attribute:

import pathlib
import pickle


class WalkPath(pathlib.Path):

    _flavour = type(pathlib.Path())._flavour

    def __init__(self, *args, origin: 'WalkPath'=None, dirs: []=None, files: []=None):

        super().__init__()

        if type(args[0]) is str:
            self.origin = origin or self
        else:
            self.origin = origin or args[0].origin

        self._dirs: [WalkPath] = list(map(WalkPath, dirs)) if dirs else None
        self._files: [WalkPath] = list(map(WalkPath, files)) if files else None

        self._lazy_attr = None


path = WalkPath('/tmp', origin='far away')
print(vars(path))

reloaded = pickle.loads(pickle.dumps(path))
print(vars(reloaded))
$ python3.9 test.py 
{'origin': 'far away', '_dirs': None, '_files': None, '_lazy_attr': None}
{'origin': WalkPath('/tmp'), '_dirs': None, '_files': None, '_lazy_attr': None}
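For contrast, an ordinary class with no custom pickling logic keeps its instance `__dict__` through a pickle round trip, which is why the minimal `P` repro above never loses `static`. A quick sketch (the `Plain` class is illustrative, not from the original code):

```python
import pickle

class Plain:
    """An ordinary class: pickle's default protocol snapshots __dict__."""
    def __init__(self, origin):
        self.origin = origin

p = Plain('far away')
restored = pickle.loads(pickle.dumps(p))
print(restored.origin)  # the attribute survives the round trip
```

Only classes that override the pickling protocol (as PurePath does) lose attributes this way.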

For what it's worth, here's how I finally solved this.

What's happening here is that Path implements a __reduce__ function, which is invoked before __getstate__ and __setstate__ (the higher-level pickling hooks).

This is the __reduce__ function of PurePath, the base class of Path:

def __reduce__(self):
    # Using the parts tuple helps share interned path parts
    # when pickling related paths.
    return (self.__class__, tuple(self._parts))
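To make explicit how pickle consumes that return value: the first element is a callable and the second is a tuple of positional arguments, so unpickling effectively re-runs the constructor with only the path parts and nothing else. A minimal sketch (using PurePosixPath so it behaves the same on any OS):

```python
import pickle
from pathlib import PurePosixPath

p = PurePosixPath('/tmp/demo')

# pickle stores (callable, args) and rebuilds with callable(*args) on load;
# any extra instance attributes are simply not part of args.
cls, args = p.__reduce__()
rebuilt = cls(*args)

assert rebuilt == p
assert pickle.loads(pickle.dumps(p)) == p
```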

Oh no! Well, now we can see what happens: by design, it simply passes along a tuple of its parts, discarding the state entirely and forming a fresh copy of itself.

I didn't want to mess with that, but I also wanted to make sure my state was preserved. So I created a serializer that takes these attributes as tuple arguments (because... __reduce__, for some absurd reason, only takes a single tuple as an argument).
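The pattern is easiest to see on a plain class first (a sketch with illustrative names, not the actual WalkPath code): return a reconstructor plus the full attribute dict flattened into a tuple of (key, value) pairs, then rebuild the object and restore its state on load:

```python
import pickle

class Stateful:
    def __init__(self, part, origin=None):
        self.part = part
        self.origin = origin

    def __reduce__(self):
        # __reduce__'s second element must be a single tuple of
        # positional args, so flatten the attribute dict into pairs.
        return (self.__class__._from_kwargs, tuple(self.__dict__.items()))

    @classmethod
    def _from_kwargs(cls, *items):
        # Rebuild from the flattened pairs, then reapply all state.
        kwargs = dict(items)
        new = cls(kwargs.pop('part'))
        new.__dict__.update(kwargs)
        return new

s = Stateful('/tmp', origin='far away')
restored = pickle.loads(pickle.dumps(s))
assert restored.origin == 'far away'
```

The WalkPath version below does the same thing, with the extra wrinkle that the path parts have to be fed back through Path's own construction machinery.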

I also had to make sure that origin is now a Path object rather than a WalkPath object, otherwise I would end up in endless recursion. I added some type coercion and safety to __init__:
    if origin:
        self.origin = Path(origin)
    elif len(args) > 0:
        try:
            self.origin = Path(args[0].origin) or Path(args[0])
        except AttributeError:
            self.origin = Path(self)

    if not self.origin:
        raise AttributeError(f"Could not infer 'origin' property when initializing 'WalkPath', for path '{args[0]}'")

Then I added these two methods to WalkPath:

# @overrides(__reduce__)
def __reduce__(self):

    # From super()

    # Using the parts tuple helps share interned path parts
    # when pickling related paths.
    
    # return (self.__class__, tuple(self._parts))

    # This override passes its parts to a Path object (which
    # natively pickles), then serializes and applies 
    # its remaining attributes.
    
    args = {**{'_parts': self._parts}, **self.__dict__}
    return (self.__class__._from_kwargs, tuple(args.items()))

@classmethod
def _from_kwargs(cls, *args):
    kwargs = dict(args)
    new = cls(super().__new__(cls, *kwargs['_parts']),
              origin=kwargs['origin'])
    new.__dict__ = {**new.__dict__, **kwargs}
    return new