不能在分页结果处理程序类中使用批处理查询



Python 驱动程序为大型结果提供了事件/回调方法:

https://datastax.github.io/python-driver/query_paging.html

此外,还有一个BatchQuery类可以与ORM一起使用,它非常方便:

https://datastax.github.io/python-driver/cqlengine/batches.html?highlight=batchquery

现在,我需要在分页结果对象的回调处理程序中执行 BatchQuery,但脚本只是停留在当前页面上迭代。

我想这是由于无法在线程之间共享 cassandra 会话,而 BatchQuery 和"分页结果"方法正在使用线程来管理事件设置和回调调用。

关于如何神奇地解决这种情况的任何想法?您可以在下面找到一些代码:

# paged.py
class PagedQuery:
    """
    Class to manage paged results.
    >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
    >>> def handler(page):  # define result page handler function
    ...     for t in page:
    ...         print(t)
    >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
    >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
    >>> if pq.error:
    ...     raise pq.error
    """
    def __init__(self, query, handler=None):
        # Each PagedQuery owns its own session so concurrent queries do not
        # share paging state.
        session = new_cassandra_session()
        session.row_factory = named_tuple_factory
        statement = SimpleStatement(query, fetch_size=500)
        self.count = 0           # rows handled so far (lets callers detect empty result sets)
        self.error = None        # first exception raised while paging or handling
        self.finished_event = Event()
        self.query = query
        self.session = session
        self.handler = handler
        self.future = session.execute_async(statement)
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )
    def handle_page(self, page):
        """Driver callback: process one page of rows, then fetch the next.

        Bug fix: the original raised RuntimeError from inside the driver's
        event thread when no handler was set, and let handler exceptions
        escape -- in both cases ``finished_event`` was never set and the
        caller blocked forever.  Errors are now routed to ``handle_error``.
        """
        try:
            if not self.handler:
                raise RuntimeError('A page handler function was not defined for the query')
            self.handler(page)
            # Keep a running row count; the original never updated self.count.
            self.count += len(page)
        except Exception as exc:
            self.handle_error(exc)
            return
        if self.future.has_more_pages:
            self.future.start_fetching_next_page()
        else:
            self.finished_event.set()
    def handle_error(self, exc):
        """Record the failure and unblock anyone waiting on finished_event."""
        self.error = exc
        self.finished_event.set()
# main.py
# script using class above
def main():
    # Bug fix: the original nested unescaped single quotes around 'collected'
    # inside a single-quoted string, which is a SyntaxError.
    query = "SELECT * FROM ks.my_table WHERE collectionid=10 AND ttype='collected'"
    def handle_page(page):
        # One unlogged batch per page keeps each batch small and bounded.
        b = BatchQuery(batch_type=BatchType.Unlogged)
        for obj in page:
            process(obj)  # some updates on obj...
            obj.batch(b).save()
        b.execute()
    pq = PagedQuery(query, handle_page)
    pq.finished_event.wait()
    # Surface paging/handler failures instead of silently swallowing them
    # (matches the contract shown in PagedQuery's docstring).
    if pq.error:
        raise pq.error
    if not pq.count:
        print('Empty queryset. Please, check parameters')
if __name__ == '__main__':
    main()

由于您无法在 ResponseFuture 的事件循环中执行查询,因此您可以迭代并将对象发送到队列。我们确实有 kafka 队列来持久化对象,但在这种情况下,线程安全的 Python 队列运行良好。

import datetime
import logging
import queue
import sys
import threading
import time

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster, default_lbp_factory, NoHostAvailable
from cassandra.connection import Event
from cassandra.cqlengine.connection import (Connection, DEFAULT_CONNECTION, _connections)
from cassandra.cqlengine.query import BatchQuery
from cassandra.query import (BatchType, named_tuple_factory, PreparedStatement, SimpleStatement)
from cassandra.util import OrderedMapSerializedKey

from smfrcore.models.cassandra import Tweet
STOP_QUEUE = object()
logging.basicConfig(level=logging.DEBUG, format='[%(levelname)s] (%(threadName)-9s) %(message)s',)

def new_cassandra_session():
    """Create a Cassandra session and register it as the cqlengine default.

    Retries the connection several times (sleeping 10s between attempts)
    before giving up, which is useful while the cluster is still starting.

    Returns:
        A connected session configured for named-tuple rows and a
        500-row default fetch size.

    Raises:
        NoHostAvailable: if every connection attempt fails.
    """
    retries = 5
    _cassandra_user = 'user'
    _cassandra_password = 'xxxx'
    while retries >= 0:
        try:
            cluster_kwargs = {
                'compression': True,
                'load_balancing_policy': default_lbp_factory(),
                'executor_threads': 10,
                'idle_heartbeat_interval': 10,
                'idle_heartbeat_timeout': 30,
                'auth_provider': PlainTextAuthProvider(username=_cassandra_user,
                                                       password=_cassandra_password),
            }
            cassandra_cluster = Cluster(**cluster_kwargs)
            cassandra_session = cassandra_cluster.connect()
            cassandra_session.default_timeout = None
            cassandra_session.default_fetch_size = 500
            cassandra_session.row_factory = named_tuple_factory
            # Register this session with cqlengine so ORM operations
            # (BatchQuery / Model.save()) use the very same connection.
            cassandra_default_connection = Connection.from_session(DEFAULT_CONNECTION, session=cassandra_session)
            _connections[DEFAULT_CONNECTION] = cassandra_default_connection
            _connections[str(cassandra_session)] = cassandra_default_connection
        except Exception as e:
            # NoHostAvailable is the common case here; the original
            # `except (NoHostAvailable, Exception)` was redundant.
            print('Cassandra host not available yet...sleeping 10 secs: ', str(e))
            retries -= 1
            time.sleep(10)  # `time` was used but never imported in the original file
        else:
            return cassandra_session
    # Bug fix: the original fell off the loop and implicitly returned None,
    # which made callers crash later with an opaque AttributeError.
    raise NoHostAvailable('Could not connect to Cassandra after several retries', None)

class PagedQuery:
    """
    Class to manage paged results.
    >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
    >>> def handler(page):  # define result page handler function
    ...     for t in page:
    ...         print(t)
    >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
    >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
    >>> if pq.error:
    ...     raise pq.error
    """
    def __init__(self, query, handler=None):
        # A dedicated session per query keeps paging state isolated.
        session = new_cassandra_session()
        session.row_factory = named_tuple_factory
        statement = SimpleStatement(query, fetch_size=500)
        self.count = 0           # rows handled so far; main() checks this for empty result sets
        self.error = None        # first exception raised while paging or handling
        # Bug fix: use threading.Event -- `cassandra.connection` does not
        # provide the synchronization primitive this class needs.
        self.finished_event = threading.Event()
        self.query = query
        self.session = session
        self.handler = handler
        self.future = session.execute_async(statement)
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )
    def handle_page(self, page):
        """Driver callback: process one page of rows, then fetch the next.

        Bug fix: the original raised RuntimeError and let handler exceptions
        escape inside the driver's event thread, so ``finished_event`` was
        never set and the caller blocked forever.  Errors now go through
        ``handle_error``.
        """
        try:
            if not self.handler:
                raise RuntimeError('A page handler function was not defined for the query')
            self.handler(page)
            # Track processed rows; the original never updated self.count.
            self.count += len(page)
        except Exception as exc:
            self.handle_error(exc)
            return
        if self.future.has_more_pages:
            self.future.start_fetching_next_page()
        else:
            self.finished_event.set()
    def handle_error(self, exc):
        """Record the failure and unblock anyone waiting on finished_event."""
        self.error = exc
        self.finished_event.set()

def main():
    # Bug fix: the original nested unescaped single quotes around 'collected'
    # inside a single-quoted string, which is a SyntaxError.
    query = "SELECT * FROM ks.my_table WHERE collectionid=1 AND ttype='collected'"
    q = queue.Queue()
    num_workers = 4

    def worker():
        """Consume rows from the queue and persist them in unlogged batches."""
        local_counter = 0
        b = BatchQuery(batch_type=BatchType.Unlogged)
        while True:
            tweet = q.get()
            if tweet is STOP_QUEUE:
                # Flush whatever is left in the current batch before exiting.
                b.execute()
                logging.info(' >>>>>>>>>>>>>>>> Executed last batch for this worker!!!!')
                q.task_done()  # bug fix: the sentinel path skipped task_done()
                break
            tweet.batch(b).save()
            local_counter += 1
            if not (local_counter % 500):
                b.execute()
                # Bug fix: logging takes a %-style format string, not a second
                # positional message argument.
                logging.info('>>>>>>>>>>>>>>>> Batch executed in this worker: geotagged so far: %d',
                             local_counter)
                b = BatchQuery(batch_type=BatchType.Unlogged)  # reset batch
            q.task_done()

    def handle_page(page):
        # Runs in the driver's event thread: never execute queries here,
        # just hand the rows over to the worker threads via the queue.
        for obj in page:
            process(obj)  # some updates on obj...
            q.put(obj)

    # Bug fix: the original never created or started any worker threads, so
    # q.task_done() was never called and q.join() deadlocked forever.
    threads = [threading.Thread(target=worker, name='batch-worker-%d' % i)
               for i in range(num_workers)]
    for t in threads:
        t.start()

    pq = PagedQuery(query, handle_page)
    pq.finished_event.wait()
    # Block until every queued row has been processed by a worker.
    q.join()
    # Stop workers by sending one STOP_QUEUE sentinel per worker
    # (the original comment wrongly said the sentinel was None).
    for _ in range(num_workers):
        q.put(STOP_QUEUE)
    for t in threads:
        t.join()
    if pq.error:
        raise pq.error
    if not pq.count:
        print('Empty queryset. Please, check parameters')
if __name__ == '__main__':
    sys.exit(main())

最新更新