无法在分页结果处理类中使用 BatchQuery

Cannot use BatchQuery in a paged result handler class

Python 驱动程序提供了以事件/回调方式处理大结果集的方法：

https://datastax.github.io/python-driver/query_paging.html

此外，还有一个可以与 ORM 一起使用的 BatchQuery 类，非常方便：

https://datastax.github.io/python-driver/cqlengine/batches.html?highlight=batchquery

现在，我需要在分页结果（Paged Result）对象的回调处理程序中执行 BatchQuery，但脚本一直卡在当前页面的迭代上。

我猜这是因为 Cassandra 会话无法在线程之间共享，而 BatchQuery 和“分页结果”方法都使用线程来管理事件设置和回调调用。

有什么办法可以巧妙地解决这个问题吗？代码如下：

# paged.py
class PagedQuery:
    """
    Class to manage paged results.

    The query is executed asynchronously on a new session, and ``handler`` is
    called once per result page with that page's rows (a sized sequence, as
    delivered by the driver). ``finished_event`` is set when the last page has
    been handled, or as soon as an error occurs.

    Attributes:
        count: number of rows handed to ``handler`` so far.
        error: the exception that stopped paging, or ``None``.
        finished_event: set once paging has completed or failed.

    ``handler`` runs inside the driver's result callback. The driver documents
    that such callbacks may run on its IO event thread and must not block or
    issue synchronous requests (for example ``BatchQuery.execute()``). Hand
    rows off to another thread instead, e.g. through a ``queue.Queue``.

    >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
    >>> def handler(page):  # define result page handler function
    ...     for t in page:
    ...         print(t)
    >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
    >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
    >>> if pq.error:
    ...     raise pq.error
    """
    def __init__(self, query, handler=None):
        session = new_cassandra_session()
        session.row_factory = named_tuple_factory
        statement = SimpleStatement(query, fetch_size=500)
        future = session.execute_async(statement)
        self.count = 0  # rows passed to the handler; updated in handle_page
        self.error = None
        self.finished_event = Event()
        self.query = query
        self.session = session
        self.handler = handler
        self.future = future
        # Every attribute above must exist before this call: the driver may
        # invoke the callback immediately if the first page is already available.
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )

    def handle_page(self, page):
        """Pass one page to the handler, then fetch the next page or finish.

        Any exception, including a missing handler, is routed to
        ``handle_error``. If it were raised out of the driver callback instead,
        ``finished_event`` would never be set and ``finished_event.wait()``
        would block forever.
        """
        try:
            if not self.handler:
                raise RuntimeError('A page handler function was not defined for the query')
            self.handler(page)
            self.count += len(page)

            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished_event.set()
        except Exception as exc:
            self.handle_error(exc)

    def handle_error(self, exc):
        """Record ``exc`` as the paging error and release any waiters."""
        self.error = exc
        self.finished_event.set()

# main.py
# script using class above
def main():
    """Re-save every row matched by ``query`` via a cqlengine BatchQuery per page.

    This is the version that hangs: each page's batch is executed from inside
    the ``PagedQuery`` page handler.
    """

    query = 'SELECT * FROM ks.my_table WHERE collectionid=10 AND ttype=\'collected\''

    def handle_page(page):
        """Process each row of ``page`` and save all of them in one unlogged batch."""

        b = BatchQuery(batch_type=BatchType.Unlogged)
        for obj in page:
            process(obj)  # some updates on obj...
            obj.batch(b).save()

        # NOTE(review): this runs a synchronous query from inside the driver's
        # page callback. The python-driver docs say callbacks may run on the IO
        # event thread and must not issue synchronous requests, which is
        # presumably why the script stalls here.
        b.execute()

    pq = PagedQuery(query, handle_page)
    pq.finished_event.wait()

    # NOTE(review): pq.error is never checked, so a failed query is not
    # reported. This check also relies on PagedQuery keeping pq.count up to date.
    if not pq.count:
        print('Empty queryset. Please, check parameters')

if __name__ == '__main__':
    main()

由于不能在 ResponseFuture 的事件循环（即分页回调）中执行查询，只能在回调中迭代结果并把对象放入队列，再由单独的工作线程执行批量写入。我们实际上使用 Kafka 队列来持久化对象，但在这种场景下，线程安全的 Python 队列（queue.Queue）同样可以很好地工作。

import datetime
import logging
import queue
import sys
import threading
import time

from cassandra.connection import Event
from cassandra.cluster import Cluster, default_lbp_factory, NoHostAvailable
from cassandra.cqlengine.connection import (Connection, DEFAULT_CONNECTION, _connections)
from cassandra.query import BatchType, named_tuple_factory, PreparedStatement, SimpleStatement
from cassandra.auth import PlainTextAuthProvider
from cassandra.util import OrderedMapSerializedKey
from cassandra.cqlengine.query import BatchQuery
from smfrcore.models.cassandra import Tweet

# Unique sentinel: a worker that receives it executes its last batch and exits.
STOP_QUEUE = object()
logging.basicConfig(
    format='[%(levelname)s] (%(threadName)-9s) %(message)s',
    level=logging.DEBUG,
)


def new_cassandra_session():
    """
    Connect to Cassandra and register the session as cqlengine's default connection.

    Up to six connection attempts are made (a retry budget of 5 plus the first
    try). After each failed attempt except the last, the function sleeps 10
    seconds.

    Returns:
        The connected session, configured with no default timeout, a default
        fetch size of 500 and ``named_tuple_factory`` rows. It is registered in
        cqlengine's ``_connections`` under ``DEFAULT_CONNECTION`` and under
        ``str(session)``.

    Raises:
        Exception: whatever the final attempt raised (e.g. ``NoHostAvailable``)
            once every attempt has failed.
    """
    retries = 5
    # NOTE(review): credentials are hard-coded placeholders; load them from config.
    _cassandra_user = 'user'
    _cassandra_password = 'xxxx'
    for attempt in range(retries + 1):
        cassandra_cluster = None
        try:
            cluster_kwargs = {'compression': True,
                              'load_balancing_policy': default_lbp_factory(),
                              'executor_threads': 10,
                              'idle_heartbeat_interval': 10,
                              'idle_heartbeat_timeout': 30,
                              'auth_provider': PlainTextAuthProvider(username=_cassandra_user, password=_cassandra_password)}

            cassandra_cluster = Cluster(**cluster_kwargs)
            cassandra_session = cassandra_cluster.connect()
            cassandra_session.default_timeout = None
            cassandra_session.default_fetch_size = 500
            cassandra_session.row_factory = named_tuple_factory
            cassandra_default_connection = Connection.from_session(DEFAULT_CONNECTION, session=cassandra_session)
            _connections[DEFAULT_CONNECTION] = cassandra_default_connection
            _connections[str(cassandra_session)] = cassandra_default_connection
        except Exception as e:
            # Release the failed cluster's threads/connections before retrying.
            if cassandra_cluster is not None:
                cassandra_cluster.shutdown()
            if attempt == retries:
                # Out of attempts: surface the error instead of returning None.
                raise
            print('Cassandra host not available yet...sleeping 10 secs: ', str(e))
            time.sleep(10)
        else:
            return cassandra_session


class PagedQuery:
    """
    Class to manage paged results.

    The query is executed asynchronously on a new session, and ``handler`` is
    called once per result page with that page's rows (a sized sequence, as
    delivered by the driver). ``finished_event`` is set when the last page has
    been handled, or as soon as an error occurs.

    Attributes:
        count: number of rows handed to ``handler`` so far.
        error: the exception that stopped paging, or ``None``.
        finished_event: set once paging has completed or failed.

    ``handler`` runs inside the driver's result callback. The driver documents
    that such callbacks may run on its IO event thread and must not block or
    issue synchronous requests (for example ``BatchQuery.execute()``). Hand
    rows off to another thread instead, e.g. through a ``queue.Queue``.

    >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
    >>> def handler(page):  # define result page handler function
    ...     for t in page:
    ...         print(t)
    >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
    >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
    >>> if pq.error:
    ...     raise pq.error
    """
    def __init__(self, query, handler=None):
        session = new_cassandra_session()
        session.row_factory = named_tuple_factory
        statement = SimpleStatement(query, fetch_size=500)
        future = session.execute_async(statement)
        self.count = 0  # rows passed to the handler; updated in handle_page
        self.error = None
        self.finished_event = Event()
        self.query = query
        self.session = session
        self.handler = handler
        self.future = future
        # Every attribute above must exist before this call: the driver may
        # invoke the callback immediately if the first page is already available.
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )

    def handle_page(self, page):
        """Pass one page to the handler, then fetch the next page or finish.

        Any exception, including a missing handler, is routed to
        ``handle_error``. If it were raised out of the driver callback instead,
        ``finished_event`` would never be set and ``finished_event.wait()``
        would block forever.
        """
        try:
            if not self.handler:
                raise RuntimeError('A page handler function was not defined for the query')
            self.handler(page)
            self.count += len(page)

            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished_event.set()
        except Exception as exc:
            self.handle_error(exc)

    def handle_error(self, exc):
        """Record ``exc`` as the paging error and release any waiters."""
        self.error = exc
        self.finished_event.set()



def main(num_workers=4):
    """
    Read every row matched by the query and re-save it through cqlengine batches.

    ``PagedQuery`` delivers rows page by page from the driver's callback. The
    page handler only processes rows and puts them on a thread-safe queue.
    ``num_workers`` worker threads consume that queue and save objects in
    unlogged batches of 500, so no synchronous query runs inside the driver
    callback.

    Args:
        num_workers: number of worker threads that save objects (at least 1).

    Raises:
        ValueError: if ``num_workers`` is smaller than 1.
        Exception: the error recorded by ``PagedQuery`` if paging failed.
    """
    if num_workers < 1:
        raise ValueError('num_workers must be at least 1')

    query = 'SELECT * FROM ks.my_table WHERE collectionid=1 AND ttype=\'collected\''
    batch_size = 500

    q = queue.Queue()

    def worker():
        """Save queued objects in batches until STOP_QUEUE is received."""
        local_counter = 0
        b = BatchQuery(batch_type=BatchType.Unlogged)
        while True:
            tweet = q.get()
            try:
                if tweet is STOP_QUEUE:
                    b.execute()
                    logging.info(' >>>>>>>>>>>>>>>> Executed last batch for this worker!!!!')
                    return

                tweet.batch(b).save()
                local_counter += 1
                if not local_counter % batch_size:
                    b.execute()
                    logging.info('>>>>>>>>>>>>>>>> Batch executed in this worker: geotagged so far: %s', local_counter)
                    b = BatchQuery(batch_type=BatchType.Unlogged)  # reset batch
            except Exception:
                # A dead worker would stop draining the queue and q.join() would
                # hang. Log the error and continue with a fresh batch; whatever
                # was pending in the current batch is dropped.
                logging.exception('Worker failed to save objects')
                if tweet is STOP_QUEUE:
                    return
                b = BatchQuery(batch_type=BatchType.Unlogged)
            finally:
                # Acknowledge every item (sentinel and failures included) so
                # q.join() can return.
                q.task_done()

    def handle_page(page):
        """Process each row of ``page`` and hand it to the worker threads."""
        for obj in page:
            process(obj)  # some updates on obj...
            q.put(obj)

    threads = [
        threading.Thread(target=worker, name='worker-%d' % i, daemon=True)
        for i in range(num_workers)
    ]
    for t in threads:
        t.start()

    try:
        pq = PagedQuery(query, handle_page)
        pq.finished_event.wait()
        # block until all queued objects have been handled
        q.join()
    finally:
        # stop workers by sending one STOP_QUEUE sentinel per worker
        for _ in threads:
            q.put(STOP_QUEUE)
        for t in threads:
            t.join()

    if pq.error:
        raise pq.error

    if not pq.count:
        print('Empty queryset. Please, check parameters')

if __name__ == '__main__':
    sys.exit(main())