How do I best parallelize Grakn queries with Python?



I'm running Windows 10 with Python 3.7 on a 6-core CPU. A single Python thread on my machine submits 1,000 inserts per second to Grakn. I'd like to parallelize my code so that inserts and matches run faster. How do people do this?

My only parallelization experience is from another project, where I submit a custom function to a dask distributed client to spawn thousands of tasks. Here, that approach fails whenever the custom function receives or generates a Grakn transaction object/handle. I get the following error:

Traceback (most recent call last):
  File "C:\Users\dvyd\.conda\envs\activefiction\lib\site-packages\distributed\protocol\pickle.py", line 41, in dumps
    return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
  ...
  File "stringsource", line 2, in grpc._cython.cygrpc.Channel.__reduce_cython__
TypeError: no default __reduce__ due to non-trivial __cinit__
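
Reading the traceback, the problem seems to be that dask must pickle whatever the task closes over, and a gRPC Channel has no serializable form. I assume the workaround is to construct the client inside each task, so that nothing gRPC-backed ever crosses a process boundary; here is a minimal sketch of what I have in mind (the URI, keyspace, and list_of_query_batches are placeholders):

from dask.distributed import Client as DaskClient
from grakn.client import GraknClient

def insert_batch(batch):
    # Build the Grakn client inside the task: gRPC channels hold
    # non-picklable state, so they must never be shipped between processes.
    with GraknClient(uri="localhost:48555") as client:
        with client.session(keyspace="grakn") as session:
            tx = session.transaction().write()
            for query in batch:
                tx.query(query)
            tx.commit()

dask_client = DaskClient()  # spins up a local scheduler and workers
futures = dask_client.map(insert_batch, list_of_query_batches)
dask_client.gather(futures)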

I have never used Python's multiprocessing module directly. What are other people doing to parallelize their queries to Grakn?

The easiest way I've found to execute a batch of queries is to pass a Grakn session to each thread in a ThreadPool. Within each thread you can manage transactions and, of course, perform somewhat more complex logic:

from grakn.client import GraknClient
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial

def write_query_batch(session, batch):
    tx = session.transaction().write()
    for query in batch:
        tx.query(query)
    tx.commit()

def multi_thread_write_query_batches(session, query_batches, num_threads=8):
    pool = ThreadPool(num_threads)
    pool.map(partial(write_query_batch, session), query_batches)
    pool.close()
    pool.join()

def generate_query_batches(my_data_entries_list, batch_size):
    batch = []
    for index, data_entry in enumerate(my_data_entries_list):
        batch.append(data_entry)
        if (index + 1) % batch_size == 0:  # yield a full batch of exactly batch_size entries
            yield batch
            batch = []
    if batch:
        yield batch

# (Part 2) Somewhere in your application open a client and a session
client = GraknClient(uri="localhost:48555")
session = client.session(keyspace="grakn")
query_batches_iterator = generate_query_batches(my_data_entries_list, batch_size)
multi_thread_write_query_batches(session, query_batches_iterator, num_threads=8)
session.close()
client.close()

The above is the general approach. As a concrete example, you can use it (omitting Part 2) to parallelize batches of insert statements from two files. Appending the following to the code above should work:

import time

files = [
    {
        "file_path": "/path/to/your/file.gql",
    },
    {
        "file_path": "/path/to/your/file2.gql",
    }
]

KEYSPACE = "grakn"
URI = "localhost:48555"
BATCH_SIZE = 10
NUM_BATCHES = 1000

# Entry point where migration starts
def migrate_graql_files():
    start_time = time.time()

    for file in files:
        print('==================================================')
        print(f'Loading from {file["file_path"]}')
        print('==================================================')

        with open(file["file_path"], "r") as open_file:  # Here we are assuming you have 1 Graql query per line!
            lines = open_file.readlines()
        batches = generate_query_batches(lines, BATCH_SIZE)

        with GraknClient(uri=URI) as client:  # Using `with` auto-closes the client
            with client.session(KEYSPACE) as session:  # Using `with` auto-closes the session
                multi_thread_write_query_batches(session, batches, num_threads=16)  # Pick `num_threads` according to your machine

        elapsed = time.time() - start_time
        print(f'Time elapsed {elapsed:.1f} seconds')

    elapsed = time.time() - start_time
    print(f'Time elapsed {elapsed:.1f} seconds')

if __name__ == "__main__":
    migrate_graql_files()

You should also be able to see how to load from a CSV, or any other file type, in this way: take the values found in the file and substitute them into Graql query-string templates. See the migration example in the docs for more on this.
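
To make the templating idea concrete, here is a minimal sketch, assuming a CSV with name and age columns and a matching person entity in your schema (the template and column names are hypothetical):

import csv

def csv_to_queries(csv_path):
    # Hypothetical template -- adjust entity and attribute names to your schema.
    template = 'insert $p isa person, has name "{name}", has age {age};'
    with open(csv_path) as csv_file:
        for row in csv.DictReader(csv_file):
            yield template.format(name=row["name"], age=row["age"])

# The resulting query strings can be batched and written exactly as above:
# batches = generate_query_batches(list(csv_to_queries("people.csv")), BATCH_SIZE)
# multi_thread_write_query_batches(session, batches, num_threads=16)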

Below is an alternative approach that uses multiprocessing instead of multithreading.

We found empirically that multithreading does not yield particularly large performance gains compared to multiprocessing, probably because of Python's GIL.

This code assumes a file enumerating TypeQL queries that are independent of one another, so that they can be freely parallelized.

from typedb.client import TypeDB, TypeDBClient, SessionType, TransactionType
import multiprocessing as mp
import queue

def batch_writer(database, kill_event, batch_queue):
    client = TypeDB.core_client("localhost:1729")
    session = client.session(database, SessionType.DATA)
    while True:
        try:
            batch = batch_queue.get(block=True, timeout=1)
        except queue.Empty:
            # Only exit once the kill event is set *and* the queue has drained,
            # otherwise batches still in flight would be dropped.
            if kill_event.is_set():
                break
            continue
        with session.transaction(TransactionType.WRITE) as tx:
            for query in batch:
                tx.query().insert(query)
            tx.commit()
    session.close()
    client.close()
    print("Received kill event, exiting worker.")

def start_writers(database, kill_event, batch_queue, parallelism=4):
    processes = []
    for _ in range(parallelism):
        proc = mp.Process(target=batch_writer, args=(database, kill_event, batch_queue))
        processes.append(proc)
        proc.start()
    return processes

def batch(iterable, n=1000):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

if __name__ == '__main__':
    batch_size = 100
    parallelism = 1
    database = "<database name>"
    file_path = "<PATH TO QUERIES FILE - ONE QUERY PER NEW LINE>"

    with open(file_path, "r") as file:
        statements = file.read().splitlines()

    batch_statements = batch(statements, n=batch_size)
    total_batches = len(statements) // batch_size
    if len(statements) % batch_size > 0:
        total_batches += 1

    batch_queue = mp.Queue(parallelism * 4)  # bounded, so the reader cannot run far ahead of the writers
    kill_event = mp.Event()
    writers = start_writers(database, kill_event, batch_queue, parallelism=parallelism)

    for i, statement_batch in enumerate(batch_statements):
        batch_queue.put(statement_batch, block=True)
        if i * batch_size % 10000 == 0:
            print("Loaded: {0}/{1}".format(i * batch_size, total_batches * batch_size))

    kill_event.set()
    batch_queue.close()
    batch_queue.join_thread()
    for proc in writers:
        proc.join()
    print("Done loading")
