Can't map a function to tarfile members in parallel

I have a tar file containing bz2-compressed files. I want to apply the function clean_file to each of the bz2 files and collect the results. In series this is easy with a loop:

import pandas as pd
import json
import os
import bz2
import itertools
import datetime
import tarfile
from multiprocessing import Pool

def clean_file(member):
    if '.bz2' in str(member):

        f = tr.extractfile(member)

        with bz2.open(f, "rt") as bzinput:
            dicts = []
            for i, line in enumerate(bzinput):
                line = line.replace('"name"}', '"name":" "}')
                dat = json.loads(line)
                dicts.append(dat)

        bzinput.close()
        f.close()
        del f, bzinput

        processed = dicts[0]
        return processed

    else:
        pass


# Open tar file and get contents (members)
tr = tarfile.open('data.tar')
members = tr.getmembers()
num_files = len(members)


# Apply the clean_file function in series
i=0
processed_files = []
for m in members:
    processed_files.append(clean_file(m))
    i+=1
    print('done '+str(i)+'/'+str(num_files))
    

However, I need to be able to do this in parallel. I'm trying to use Pool like so:

# Apply the clean_file function in parallel
if __name__ == '__main__':
   with Pool(2) as p:
      processed_files = list(p.map(clean_file, members))

But this returns an OSError:

Traceback (most recent call last):
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "parse_data.py", line 19, in clean_file
    for i, line in enumerate(bzinput):
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/bz2.py", line 195, in read1
    return self._buffer.read1(size)
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/_compression.py", line 68, in readinto
    data = self.read(len(byte_view))
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/_compression.py", line 103, in read
    data = self._decompressor.decompress(rawblock, size)
OSError: Invalid data stream
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "parse_data.py", line 53, in <module>
    processed_files = list(tqdm.tqdm(p.imap(clean_file, members), total=num_files))
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/site-packages/tqdm/std.py", line 1167, in __iter__
    for obj in iterable:
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/multiprocessing/pool.py", line 735, in next
    raise value
OSError: Invalid data stream

So I guess this approach isn't accessing the files inside data.tar correctly, or something along those lines. How can I apply the function in parallel?

I guess this would happen with any tar archive containing bz2 files, but here is my data for reproducing the error: https://github.com/johnf1004/reproduce_tar_error

It seems some kind of race condition was happening. Opening the tar file separately in each child process solves the issue:

import json
import bz2
import tarfile
import logging
from multiprocessing import Pool


def clean_file(member):
    if '.bz2' not in str(member):
        return
    try:
        with tarfile.open('data.tar') as tr:
            with tr.extractfile(member) as bz2_file:
                with bz2.open(bz2_file, "rt") as bzinput:
                    dicts = []
                    for i, line in enumerate(bzinput):
                        line = line.replace('"name"}', '"name":" "}')
                        dat = json.loads(line)
                        dicts.append(dat)
                        # return only the first record, as in the question
                        return dicts[0]
    except Exception:
        logging.exception(f"Error while processing {member}")


def process_serial():
    tr = tarfile.open('data.tar')
    members = tr.getmembers()
    processed_files = []
    for i, member in enumerate(members):
        processed_files.append(clean_file(member))
        print(f'done {i}/{len(members)}')


def process_parallel():
    tr = tarfile.open('data.tar')
    members = tr.getmembers()
    with Pool() as pool:
        processed_files = pool.map(clean_file, members)
        print(processed_files)


def main():
    process_parallel()


if __name__ == '__main__':
    main()

EDIT:

Note that another way to solve this problem is to just use the spawn start method:

multiprocessing.set_start_method('spawn')

By doing this, we instruct Python to "deep-copy" the file handles in the child processes. Under the default "fork" start method, the parent and the children share the same file handles.
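
For reference, here is a minimal sketch of how that might look while keeping the question's original layout (it assumes the same data.tar and the same '"name"}' fix-up as above; the key point is that set_start_method has to be called inside the __main__ guard, before the pool is created):

import bz2
import json
import tarfile
import multiprocessing
from multiprocessing import Pool

# Module-level setup: with 'spawn', each worker re-imports this module,
# so every process ends up with its own independent tar file handle.
tr = tarfile.open('data.tar')
members = tr.getmembers()

def clean_file(member):
    if '.bz2' not in str(member):
        return
    f = tr.extractfile(member)
    with bz2.open(f, "rt") as bzinput:
        for line in bzinput:
            line = line.replace('"name"}', '"name":" "}')
            return json.loads(line)  # first record only, as in the question

if __name__ == '__main__':
    # Must be called before the pool is created; 'spawn' prevents the
    # workers from inheriting (and sharing) the parent's file handle.
    multiprocessing.set_start_method('spawn')
    with Pool(2) as p:
        processed_files = p.map(clean_file, members)
    print(processed_files)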

You did not specify what platform you are running on, but I suspect it is Windows because you have ...

if __name__ == '__main__':
    main()

... which would be required for code that creates processes on platforms that use the OS function spawn to create new processes. But that also means that when a new process is created (e.g. all the processes in the process pool you are creating), each process begins by re-executing the source program from the very top. This means the following code is being executed by every pool process:

tr = tarfile.open('data.tar')
members = tr.getmembers()
num_files = len(members)

But I don't quite see why that in itself should cause an error, though I can't be sure. The problem may be, however, that this code executes after your worker function clean_file has already been called, so tr has not yet been set. If this code preceded clean_file it might work, but that is just a guess. It is certainly wasteful to have every pool process extract the members with members = tr.getmembers(); each process needs to open the tar file, ideally just once.

But what is clear is that the stack trace you posted does not match your code. You show:

Traceback (most recent call last):
  File "parse_data.py", line 53, in <module>
    processed_files = list(tqdm.tqdm(p.imap(clean_file, members), total=num_files))

Yet your code has no reference to tqdm and does not use the method imap. It becomes much harder to analyze your actual problem when the code you post does not exactly match the code that produced the exception.

If you are running on a Mac, it may be using fork to create new processes, which can be problematic when the main process has created multiple threads (which you don't necessarily see; the tarfile module might) and you then create a new process. I have therefore specified the code below so that spawn is used for creating new processes. In any case, the following code should work. It also introduces a few optimizations. If it does not work, please post a new stack trace.

import pandas as pd
import json
import os
import bz2
import itertools
import datetime
import tarfile
from multiprocessing import get_context

def open_tar():
    # open once for each process in the pool
    global tr
    tr = tarfile.open('data.tar')

def clean_file(member):
    f = tr.extractfile(member)

    with bz2.open(f, "rt") as bzinput:
        for line in bzinput:
            line = line.replace('"name"}', '"name":" "}')
            dat = json.loads(line)
            # since you are returning just the first occurrence:
            return dat

def main():
    with tarfile.open('data.tar') as tr:
        members = tr.getmembers()
    # just pick members where '.bz2' is in member:
    filtered_members = filter(lambda member: '.bz2' in str(member), members)
    ctx = get_context('spawn')
    # open tar file just once for each process in the pool:
    with ctx.Pool(initializer=open_tar) as pool:
        processed_files = pool.map(clean_file, filtered_members)
        print(processed_files)

# required for when processes are created using spawn:
if __name__ == '__main__':
    main()