Python 3 concurrent.futures and per-thread initialization

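In Python 3, is it possible to use a subclass of Thread in the context of a concurrent.futures.ThreadPoolExecutor, so that each thread can be individually initialized before processing (presumably many) work items?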

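I'd like to use the convenient concurrent.futures API for a piece of code that syncs files and S3 objects (a file needs syncing if the corresponding S3 object does not exist or is out of sync). I would like each worker thread to do some initialization first, such as setting up a boto3.session.Session. That pool of worker threads would then be ready to process potentially thousands of work items (files to sync).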

By the way, if a thread dies for some reason, is it reasonable to expect a new thread to be created automatically and added back to the pool?

(Disclaimer: I am more familiar with Java's multithreading framework than with Python's.)

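So, it seems that a simple solution to my problem is to use threading.local to store a per-thread "session" (in the mockup below, just a random integer). Perhaps not the cleanest approach, but it will do for now. Here is a mockup (Python 3.5.1):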

import time
import threading
import concurrent.futures
import random
import logging

logging.basicConfig(level=logging.DEBUG, format='(%(threadName)-0s) %(relativeCreated)d - %(message)s')

x = [0.1, 0.1, 0.2, 0.4, 1.0, 0.1, 0.0]

mydata = threading.local()

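def do_work(secs):
    # threading.local gives each thread its own attribute namespace, so this
    # check is true only if the current thread already created its session.
    if hasattr(mydata, 'session'):
        logging.debug('re-using session "{}"'.format(mydata.session))
    else:
        # First work item on this thread: do the one-time initialization.
        mydata.session = random.randint(0, 1000)
        logging.debug('created new session: "{}"'.format(mydata.session))
    time.sleep(secs)  # simulate the actual work
    logging.debug('slept for {} seconds'.format(secs))
    return secs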

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    y = executor.map(do_work, x)

print(list(y))

This produces the following output, showing that the "sessions" are indeed local to each thread and are reused:

(Thread-1) 29 - created new session: "855"
(Thread-2) 29 - created new session: "58"
(Thread-3) 30 - created new session: "210"
(Thread-1) 129 - slept for 0.1 seconds
(Thread-1) 130 - re-using session "855"
(Thread-2) 130 - slept for 0.1 seconds
(Thread-2) 130 - re-using session "58"
(Thread-3) 230 - slept for 0.2 seconds
(Thread-3) 230 - re-using session "210"
(Thread-3) 331 - slept for 0.1 seconds
(Thread-3) 331 - re-using session "210"
(Thread-3) 331 - slept for 0.0 seconds
(Thread-1) 530 - slept for 0.4 seconds
(Thread-2) 1131 - slept for 1.0 seconds
[0.1, 0.1, 0.2, 0.4, 1.0, 0.1, 0.0]
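
For what it's worth, on Python 3.7 and later ThreadPoolExecutor also accepts initializer and initargs arguments, which run a callable once in each worker thread as it starts. Below is a minimal sketch of the same mockup using that hook instead of the lazy check inside do_work. The names init_worker and run are made up for illustration, and the random integer again stands in for a real session:

```python
import concurrent.futures
import random
import threading

mydata = threading.local()

def init_worker(low, high):
    # Runs once in each worker thread, before that thread takes any work item.
    # The random int is a stand-in for a real per-thread session object.
    mydata.session = random.randint(low, high)

def do_work(secs):
    # Every call on a given thread sees the session created by init_worker.
    return (threading.current_thread().name, mydata.session, secs)

def run(items, workers=3):
    # initializer/initargs require Python 3.7 or later.
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=workers,
            initializer=init_worker,
            initargs=(0, 1000)) as executor:
        return list(executor.map(do_work, items))

if __name__ == '__main__':
    for name, session, secs in run([0.1, 0.1, 0.2, 0.4, 1.0, 0.1, 0.0]):
        print(name, session, secs)
```

One caveat, as I understand the documentation: if the initializer raises, the pool is marked as broken and pending work items fail, so the lazy threading.local check may still be the safer choice when session setup can fail transiently.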

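A small note about logging: to use this in an IPython notebook, the logging setup needs to be modified slightly (since IPython has already set up a root logger). A more robust logging setup would be: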

IN_IPYNB = 'get_ipython' in vars()

if IN_IPYNB:
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    for h in logger.handlers:
        h.setFormatter(logging.Formatter(
                '(%(threadName)-0s) %(relativeCreated)d - %(message)s'))
else:
    logging.basicConfig(level=logging.DEBUG, format='(%(threadName)-0s) %(relativeCreated)d - %(message)s')
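
As for the side question about a thread dying: as far as I can tell, an exception raised by a work item does not kill the worker thread. The executor stores the exception on the Future and the thread moves on to the next item, so there is normally nothing to recreate. The sketch below is only an illustration (risky and run_all are hypothetical names); it uses a single worker so that any item completed after the failure must have run on the same, surviving thread:

```python
import concurrent.futures
import threading

def risky(n):
    # Hypothetical work item: fails for one input to simulate a sync error.
    if n == 2:
        raise ValueError('item {} failed'.format(n))
    return n, threading.current_thread().name

def run_all(items):
    ok, failed = [], []
    # One worker only: if it died on the failing item, later items could not run.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        futures = {executor.submit(risky, n): n for n in items}
        for fut in concurrent.futures.as_completed(futures):
            n = futures[fut]
            try:
                ok.append(fut.result())  # re-raises the work item's exception here
            except ValueError as exc:
                failed.append((n, str(exc)))
    return sorted(ok), failed

if __name__ == '__main__':
    ok, failed = run_all(range(5))
    print('succeeded:', ok)
    print('failed:', failed)
```

Note that with executor.map, as in the mockup above, the exception is instead re-raised while iterating over the results, so submit plus as_completed is handier when each failure should be handled individually.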