我有一个进程,它读取位于云存储中的CSV文件的URI,序列化数据(其中一个文件是tensorflow
语言中的"示例"),并将它们写入同一个TFRecord文件。
这个过程非常缓慢,我想使用python多处理并行化编写。我到处搜索,尝试了多种实现,但都无济于事。这个问题和我的非常相似,但这个问题从未得到真正的回答。
这是我最接近的一次(不幸的是,由于从云存储读取,我无法真正提供一个可复制的示例):
import pandas as pd
import multiprocessing
import tensorflow as TF
TFR_PATH = "./tfr.tfrecord"
BANDS = ["B2", "B3","B4","B5","B6","B7","B8","B8A","B11","B12"]
def write_tfrecord(tfr_path, df_list, bands):
with tf.io.TFRecordWriter(tfr_path) as writer:
for _, grp in df_list:
band_data = {b: [] for b in bands}
for i, row in grp.iterrows():
try:
df = pd.read_csv(row['uri'])
except FileNotFoundError:
continue
df = prepare_df(df, bands)
label = row['FS_crop'].encode()
for b in bands:
band_data[b].append(list(df[b].astype('Int64')))
# pad to same length and flatten
mlen = max([len(j) for j in band_data[list(band_data.keys())[0]]])
npx = len(band_data[list(band_data.keys())[0]])
flat_band_data = {k: [] for k in band_data}
for k,v in band_data.items(): # for each band
for b in v:
flat_band_data[k].extend(b + [0] * int(mlen - len(b)))
example_proto = serialize_example(npx, flat_band_data, label)
writer.write(example_proto)
# List of grouped DF object, may be 1000's long
gqdf = list(qdf.groupby("field_centroid_str"))
n = 100 #Groups of files to write
processes = [multiprocessing.Process(target=write_tfrecord, args=(TFR_PATH, gqdf[i:i+n], BANDS)) for i in range(0, len(gqdf), n)]
for p in processes:
p.start()
for p in processes:
p.join()
p.close()
这个过程会结束,但当我去读记录时,我喜欢这样:
raw_dataset = tf.data.TFRecordDataset(TFR_PATH)
for raw_record in raw_dataset.take(10):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
我总是以损坏的数据错误DataLossError: corrupted record at 7462 [Op:IteratorGetNext]
结束
对做这样的事情的正确方法有什么想法吗?我试过用Pool
代替Process
,但tf.io.TFRecordWriter
不能腌制,所以它不起作用。
运行到类似的用例中。核心问题是记录写入程序不安全。有两个瓶颈——串行化数据和写入输出。我在这里的解决方案是使用多处理(例如池)并行序列化数据。每个工作进程使用队列将序列化的数据传递给单个使用者进程。使用者只需拉出队列并按顺序写入即可。如果这是现在的瓶颈,您可以让多个记录写入程序写入不同的文件。