TensorFlow - Interleave multiple independently preprocessed TFRecord files



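The Waymo dataset consists of multiple TFRecord files. Each file holds a run of consecutive points, but the files are not contiguous with one another. I'm building an input pipeline that preprocesses data for time series prediction with the window() API, and I need to avoid windows that span more than one file.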
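To make the boundary problem concrete, here is a minimal sketch in which two hypothetical toy sequences stand in for two files. Reading both as one stream (a `TFRecordDataset` built from several filenames reads the files one after another) and then calling `window()` yields windows that mix the end of the first source with the start of the second. The values and `WINDOW = 3` are assumptions chosen for illustration only.

```python
import tensorflow as tf

WINDOW = 3  # hypothetical window length for this illustration

# Two toy sequences standing in for two TFRecord files (hypothetical data).
file_a = tf.data.Dataset.range(0, 4)      # 0, 1, 2, 3
file_b = tf.data.Dataset.range(100, 104)  # 100, 101, 102, 103

# Read both sources as a single stream, then slide a window over it.
stream = file_a.concatenate(file_b)
windows = stream.window(size=WINDOW, shift=1, drop_remainder=True)
# Each window is a small dataset; batch it into one tensor per window.
windows = windows.flat_map(lambda w: w.batch(WINDOW))

result = [w.numpy().tolist() for w in windows]
for w in result:
    print(w)
```

The windows `[2, 3, 100]` and `[3, 100, 101]` are exactly the cross-file windows the pipeline needs to avoid.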

To avoid this, I figured I should preprocess each file independently and then interleave the resulting datasets. Here is my attempt:

```python
import os

import tensorflow as tf
from waymo_open_dataset import dataset_pb2 as open_dataset  # for parsing Waymo frames

# DATASET_DIR, N_FEATURES and WINDOW are defined earlier in the script (not shown).
filenames = [os.path.join(DATASET_DIR, f) for f in os.listdir(DATASET_DIR)]
dataset = tf.data.TFRecordDataset(filenames, compression_type='')


def interleave_fn(filename):
    ds = filename.map(lambda x: tf.py_function(_parse_data, [x], [tf.float32]*N_FEATURES,),
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_concatenate_tensors).map(_set_x_shape)
    ds = build_x_dataset(ds)
    return ds


def _parse_data(data):
    # Parse features from a serialized Waymo Frame
    frame = open_dataset.Frame()
    frame.ParseFromString(bytearray(data.numpy()))
    av_v_x = frame.images[0].velocity.v_x
    av_v_y = frame.images[0].velocity.v_y
    return av_v_x, av_v_y


def _concatenate_tensors(*x):
    # Stack the tuple of scalar features into a single tensor
    return tf.stack(x)


def _set_x_shape(x):
    # Set the static shape of X; otherwise a ValueError about an undefined rank is raised
    x.set_shape((N_FEATURES,))
    return x


def build_x_dataset(ds_x, window=WINDOW):
    # Extract sequences for time series prediction training:
    # a sliding window of WINDOW samples, shifting by 1 sample at a time
    ds_x = ds_x.window(size=window, shift=1, drop_remainder=True)

    # Each element of `ds_x` is a nested dataset containing WINDOW consecutive examples
    ds_x = ds_x.map(lambda d: tf.data.experimental.get_single_element(d.batch(window)))
    return ds_x


dataset = dataset.interleave(interleave_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```

This raises:

```
AttributeError: in user code:
/tmp/xpython_26752/494049692.py:118 interleave_fn  *
ds = filename.map(lambda x: tf.py_function(_parse_data, [x], [tf.float32]*N_FEATURES,),
AttributeError: 'Tensor' object has no attribute 'map'
```

That makes sense, because `print(filename)` inside `interleave_fn` gives:

```
Tensor("args_0:0", shape=(), dtype=string)
```

I assumed `interleave_fn` would be applied to each TFRecordDataset, so `filename` would be a dataset itself rather than a tensor. What is wrong here? Thanks a lot.
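`Dataset.interleave` calls its function once per element of the dataset it is called on. Because `dataset` was built as `tf.data.TFRecordDataset(filenames)`, its elements are the serialized records of all files read back to back (scalar string tensors), not one dataset per file, so `filename` is really a single record and has no `map`. A common alternative pattern (not the approach taken in the answer below) is to interleave over a dataset of *filenames* and open each file inside the function, so each file is windowed on its own. The sketch below assumes toy files of serialized floats in place of Waymo `Frame` protos; `write_demo_files`, `per_file_windows`, `build_dataset` and `WINDOW = 3` are hypothetical names for illustration.

```python
import os
import tempfile

import tensorflow as tf

WINDOW = 3  # hypothetical window length for this sketch


def write_demo_files(directory):
    """Write two small TFRecord files, each holding a run of consecutive floats.

    Hypothetical stand-in for the Waymo files: each record is a serialized
    float tensor rather than a Waymo Frame proto.
    """
    paths = []
    for file_idx, start in enumerate((0.0, 100.0)):
        path = os.path.join(directory, f"part-{file_idx}.tfrecord")
        with tf.io.TFRecordWriter(path) as writer:
            for step in range(5):
                writer.write(tf.io.serialize_tensor(tf.constant(start + step)).numpy())
        paths.append(path)
    return paths


def per_file_windows(filename):
    # `filename` is a scalar string tensor: one element of the filenames dataset.
    ds = tf.data.TFRecordDataset(filename)
    # Decode each record and pin its static shape to a scalar.
    ds = ds.map(lambda rec: tf.ensure_shape(tf.io.parse_tensor(rec, out_type=tf.float32), []))
    # Windows are built per file, so none can cross a file boundary.
    ds = ds.window(size=WINDOW, shift=1, drop_remainder=True)
    return ds.flat_map(lambda w: w.batch(WINDOW))


def build_dataset(filenames):
    files = tf.data.Dataset.from_tensor_slices(filenames)
    # cycle_length = number of files, so windows alternate between files.
    return files.interleave(per_file_windows, cycle_length=len(filenames))


with tempfile.TemporaryDirectory() as tmp:
    paths = write_demo_files(tmp)
    # Materialize inside the block, before the temporary files are deleted.
    windows = [w.numpy().tolist() for w in build_dataset(paths)]
```

In the real pipeline, the `tf.py_function` parsing, `_concatenate_tensors`, `_set_x_shape` and `build_x_dataset` steps would go inside the per-file function in place of the toy decoding step.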

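I solved this by looping over all the TFRecord files and appending the corresponding dataset to a list of datasets. Then I interleaved all the preprocessed datasets following this trick.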
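One way to interleave a Python list of already-built datasets (not necessarily the trick referred to above) is to wrap each dataset as a single element with `tf.data.Dataset.from_tensors`, chain those into a dataset of datasets with `concatenate`, and call `interleave` with an identity function. The sketch below assumes small `range` datasets as hypothetical stand-ins for the per-file datasets; `make_windows` and `interleave_list` are illustrative names.

```python
import tensorflow as tf

WINDOW = 3  # hypothetical window length


def make_windows(ds, window=WINDOW):
    # Window a single, already-separate source so no window crosses sources.
    ds = ds.window(size=window, shift=1, drop_remainder=True)
    return ds.flat_map(lambda w: w.batch(window))


def interleave_list(datasets):
    # Wrap each dataset as one element and chain them into a dataset of datasets.
    nested = tf.data.Dataset.from_tensors(datasets[0])
    for ds in datasets[1:]:
        nested = nested.concatenate(tf.data.Dataset.from_tensors(ds))
    # Each element is itself a dataset, so the identity function suffices.
    return nested.interleave(lambda d: d, cycle_length=len(datasets), block_length=1)


# Hypothetical stand-ins for the per-file datasets built in a Python loop.
sources = [tf.data.Dataset.range(0, 5), tf.data.Dataset.range(100, 105)]
preprocessed = [make_windows(ds) for ds in sources]
result = [w.numpy().tolist() for w in interleave_list(preprocessed)]
```

With `cycle_length` equal to the number of sources and `block_length=1`, the output alternates one window from each source, and every window stays within a single source.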
