批量调用一个昂贵的函数

Batching multiple calls to an expensive function

def main_function(keys):
    """Fetch one value per key, issuing a separate remote call for each key.

    Each call must yield exactly one value; the single-element unpacking
    raises ValueError otherwise.
    """
    def _fetch_one(key):
        (single,) = get_from_remote_location([key])
        return single

    return [_fetch_one(key) for key in keys]



def get_from_remote_location(keys):
    """Stand-in for a remote lookup: map each key to its string form.

    Returns a lazy, one-shot ``map`` iterator rather than a list.
    """
    to_text = str
    return map(to_text, keys)

我有一个函数,它遍历一组键,为每个键调用一个昂贵的函数来获取对应的值,然后返回这些值组成的集合。我希望这个昂贵的函数只被调用一次。这是一个人为构造的例子,在这里我完全可以不迭代这些键,但它只是为了说明问题。约束是:昂贵的调用只能进行一次,而 main_function 必须迭代键集合,并对每个键进行看似远程的调用。

我曾考虑先用 asyncio 把所有键收集起来,但我想出的方案很糟糕,大致如下:

# Sketch only: shared buffer collecting the keys requested by every caller.
keys = []

...  # other code elided in this sketch

async def get_from_remote_location_batched(key):
    """Sketch of a per-key lookup that is served by one batched remote call.

    ``until_all_keys_have_been_appended_and_remote_call_has_been_done`` is not
    defined in this snippet; it stands for the missing synchronisation step.
    """
    keys.append(key)
    # Await the batched call first and only then index its result: the form
    # `await f(keys)[key]` would subscript the awaitable before awaiting it.
    batch = await until_all_keys_have_been_appended_and_remote_call_has_been_done(keys)
    return batch[key]

编辑 2020-01-28:

我想到了一个办法,明天会尝试把它发布出来。

from contextlib import asynccontextmanager, contextmanager
import asyncio
import pprint
import random
import uuid


class Transport:
    """Single-slot channel between one worker coroutine and a batch caller.

    The worker publishes the key it needs in ``key`` and then polls through
    ``get_awaitable`` until a value is delivered with ``set_data``.
    """

    def __init__(self):
        """Create an empty transport: no key, no data, no coroutine."""
        self.key = None  # key published by the worker
        self.data = None  # value delivered by set_data
        # Separate flag so that a delivered value of None still counts as set.
        self.data_set = False
        self.coroutine = None  # worker coroutine registered via __call__

    def __call__(self, coroutine):
        """Register the worker coroutine that uses this transport.

        Only one coroutine may be registered per transport.

        Raises:
            RuntimeError: if a coroutine has already been registered.
        """
        # Enforce single registration explicitly: rebinding self.__call__
        # would not work, because `transport(...)` looks __call__ up on the
        # type, not on the instance.
        if self.coroutine is not None:
            raise RuntimeError('a coroutine is already registered on this transport')
        self.coroutine = coroutine

    def set_data(self, value):
        """Deliver ``value`` to the worker and mark the data as available."""
        self.data = value
        self.data_set = True

    def get_awaitable(self):
        """Yield ``asyncio.sleep(0)`` awaitables until data has been set.

        The worker awaits each yielded awaitable, which hands control back to
        the event loop so other coroutines can act in the meantime.
        """
        while not self.data_set:
            yield asyncio.sleep(0)


class TransportGetter:
    """Registry that hands out Transport objects and remembers them in order.

    Handy for one-to-many communication: one batch caller, many workers.
    """

    def __init__(self):
        """Start with an empty list of transports."""
        self.transports = []

    @contextmanager
    def get(self):
        """Create a Transport, record it, and yield it to the with-block."""
        self.transports.append(Transport())
        yield self.transports[-1]

    @property
    def coroutines(self):
        """Registered coroutine of every transport, in creation order.

        Transports that never registered a coroutine contribute ``None``.
        """
        return [registered.coroutine for registered in self.transports]
       

def expensive_function(keys):
    """Stand-in for a costly batched lookup, e.g. a database query.

    Returns a dict mapping every key to ``key * 10``.
    """
    result = {}
    for key in keys:
        result[key] = key * 10
    return result


@asynccontextmanager
async def remote_caller(fn):
    """Batch the keys of all registered coroutines into a single ``fn`` call.

    Yields a ``TransportGetter``. Inside the ``async with`` body the caller
    registers worker coroutines on transports obtained from it. When the body
    exits:

    1. Every registered coroutine is scheduled as a task and given a brief
       chance to run, so that it can publish ``transport.key``.
    2. ``fn`` is called once with the list of collected keys.
    3. Each transport receives ``data[transport.key]``.
    4. The tasks are awaited, so the workers finish before the context exits.

    If nothing was registered, ``fn`` is not called at all.

    Args:
        fn: Callable taking a list of keys and returning a mapping from each
            key to its value. Its ``__name__`` is used in the log output.

    Raises:
        KeyError: if the mapping returned by ``fn`` lacks a collected key.
        Any exception raised by ``fn`` or by a worker task propagates.
    """
    transport_getter = TransportGetter()

    yield transport_getter

    name = fn.__name__
    print(f'{name} caller is yielding transport getter')
    coroutines = transport_getter.coroutines
    if not coroutines:
        # Nothing registered: skip the expensive call instead of letting
        # asyncio.wait raise ValueError on an empty collection.
        return

    # Wrap the coroutines in tasks explicitly: passing bare coroutines to
    # asyncio.wait is deprecated since Python 3.8 and rejected since 3.11.
    tasks = [asyncio.ensure_future(coroutine) for coroutine in coroutines]
    # timeout=0 yields to the event loop only briefly, so this assumes every
    # worker publishes transport.key before its first await.
    await asyncio.wait(tasks, timeout=0)

    try:
        print(f'{name} caller is about to call and retrieve data')
        keys = [t.key for t in transport_getter.transports]
        data = fn(keys)

        print(f'{name} caller retrieved data\n'
              + f'{pprint.pformat(list(data.values()))}\n'
              + f'for keys\n{pprint.pformat(keys)}\n'
              + 'and is pushing individual responses back')
        for transport in transport_getter.transports:
            transport.set_data(data[transport.key])
    except BaseException:
        # Workers poll until their data is set; cancel them so they do not
        # keep spinning after a failed batch call.
        for task in tasks:
            task.cancel()
        raise
    print(f'{name} caller is done')

    # Let the workers consume their data and surface their exceptions before
    # the context exits, instead of leaving them as pending tasks.
    await asyncio.gather(*tasks)


async def wait_for_data(transport, name):
    """Await each polling awaitable the transport hands out, logging every wait.

    Returns once ``transport.get_awaitable()`` stops yielding; for a
    ``Transport`` that happens after its data has been set.
    """
    polls = transport.get_awaitable()
    for pending_step in polls:
        print(f'{name} is waiting and giving up control...')
        await pending_step


async def scalar_level_coroutine(transport, id_):
    """Worker that needs one key's value from the batched expensive call.

    It picks a random key, publishes it on ``transport``, waits until data
    has been delivered, and returns ``2 ** int(transport.data)``. Issuing one
    expensive call per worker like this would be costly without batching.
    """
    transport.key = my_key = random.random()
    label = f'Worker {id_}'
    print(f'{label} pushed its key ({my_key}) and is giving up control...')

    await wait_for_data(transport, label)

    result = 2 ** int(transport.data)
    print(f'{label} calculated value {result} and is done...')
    return result


def add_coroutine(transport_getter, *args, **kwargs):
    """Create a new transport and register a scalar worker coroutine on it.

    Extra positional and keyword arguments are forwarded to
    ``scalar_level_coroutine`` after the transport.
    """
    with transport_getter.get() as new_transport:
        worker = scalar_level_coroutine(new_transport, *args, **kwargs)
        new_transport(worker)


async def main():
    """Demo: register workers 1-4 so their keys go through one expensive call."""
    async with remote_caller(expensive_function) as getter:
        for worker_id in range(1, 5):
            add_coroutine(getter, worker_id)


if __name__ == '__main__':
    asyncio.run(main())

输出:

expensive_function caller is yielding transport getter
Worker 2 pushed its key (0.8549279924668386) and is giving up control...
Worker 2 is waiting and giving up control...
Worker 4 pushed its key (0.15663005612702752) and is giving up control...
Worker 4 is waiting and giving up control...
Worker 1 pushed its key (0.4688385351885074) and is giving up control...
Worker 1 is waiting and giving up control...
Worker 3 pushed its key (0.12254117937847231) and is giving up control...
Worker 3 is waiting and giving up control...
Worker 2 is waiting and giving up control...
Worker 4 is waiting and giving up control...
Worker 1 is waiting and giving up control...
Worker 3 is waiting and giving up control...
expensive_function caller is about to call and retrieve data
expensive_function caller retrieved data
[4.688385351885074, 8.549279924668387, 1.225411793784723, 1.5663005612702752]
for keys
[0.4688385351885074,
 0.8549279924668386,
 0.12254117937847231,
 0.15663005612702752]
and is pushing individual responses back
expensive_function caller is done
Worker 2 calculated value 256 and is done...
Worker 4 calculated value 2 and is done...
Worker 1 calculated value 16 and is done...
Worker 3 calculated value 2 and is done...