The Waymo dataset is split across multiple TFRecord files, each containing consecutive samples, but the sequence is not continuous from one file to the next. I am building an input pipeline that preprocesses the data for time-series prediction via the window() API, and I need to avoid windows that span multiple files. To do this, I figured I should preprocess each file independently and interleave the resulting datasets. Here is my attempt:
import os
import tensorflow as tf
from waymo_open_dataset import dataset_pb2 as open_dataset  # for parsing Waymo frames

filenames = [os.path.join(DATASET_DIR, f) for f in os.listdir(DATASET_DIR)]
dataset = tf.data.TFRecordDataset(filenames, compression_type='')

def interleave_fn(filename):
    ds = filename.map(lambda x: tf.py_function(_parse_data, [x], [tf.float32]*N_FEATURES,),
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_concatenate_tensors).map(_set_x_shape)
    ds = build_x_dataset(ds)
    return ds

def _parse_data(data):
    # Parse the features of interest from a Waymo frame
    frame = open_dataset.Frame()
    frame.ParseFromString(bytearray(data.numpy()))
    av_v_x = frame.images[0].velocity.v_x
    av_v_y = frame.images[0].velocity.v_y
    return av_v_x, av_v_y

def _concatenate_tensors(*x):
    # Concatenate the tuple of feature tensors into a single tensor
    return tf.stack(x)

def _set_x_shape(x):
    # Set the X dataset shape; without this, an UNDEFINED RANK ValueError is raised
    x.set_shape((N_FEATURES,))
    return x

def build_x_dataset(ds_x, window=WINDOW):
    # Extract sequences for time-series prediction training:
    # select a sliding window of WINDOW samples, shifting by 1 sample at a time
    ds_x = ds_x.window(size=window, shift=1, drop_remainder=True)
    # Each element of `ds_x` is a nested dataset containing WINDOW consecutive examples
    ds_x = ds_x.map(lambda d: tf.data.experimental.get_single_element(d.batch(window)))
    return ds_x

dataset = dataset.interleave(interleave_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
This returns:
AttributeError: in user code:
/tmp/xpython_26752/494049692.py:118 interleave_fn *
ds = filename.map(lambda x: tf.py_function(_parse_data, [x], [tf.float32]*N_FEATURES,),
AttributeError: 'Tensor' object has no attribute 'map'
This makes sense, because print(filename) inside interleave_fn gives:
Tensor("args_0:0", shape=(), dtype=string)
I thought interleave_fn would be applied to each per-file TFRecordDataset, so that filename would be the dataset itself rather than a tensor. What is going wrong here? Many thanks.
Solved by looping over all the TFRecord files, appending each file's dataset to a list of datasets, and then following this trick to interleave all the preprocessed datasets.
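For context on the error: the map function passed to interleave receives the elements of the dataset it is called on, and the elements of a TFRecordDataset are the serialized records themselves (scalar string tensors), not one dataset per file. Below is a minimal sketch of one way to get the per-file interleaving described above, starting from a dataset of filenames instead of a flat TFRecordDataset (preprocess_file is a hypothetical name; the helper functions and the DATASET_DIR / N_FEATURES constants are reused from the question):

import os
import tensorflow as tf

filenames = [os.path.join(DATASET_DIR, f) for f in os.listdir(DATASET_DIR)]

def preprocess_file(filename):
    # `filename` is a scalar string tensor, which TFRecordDataset accepts directly
    ds = tf.data.TFRecordDataset(filename, compression_type='')
    ds = ds.map(lambda x: tf.py_function(_parse_data, [x], [tf.float32] * N_FEATURES),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(_concatenate_tensors).map(_set_x_shape)
    # Windowing runs per file, so no window crosses a file boundary
    return build_x_dataset(ds)

dataset = tf.data.Dataset.from_tensor_slices(filenames).interleave(
    preprocess_file, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Building the per-file datasets in an explicit Python loop and interleaving the resulting list works too; the filename-dataset form just lets tf.data handle the per-file preprocessing in parallel.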