如何从TFRecordDataset中删除除第一条记录之外的所有数据



下面的代码从test_filenames创建了一个TFRecordDataset,它包含10000条记录:

test_dataset = tf.data.TFRecordDataset([test_filenames])

为了进行测试,我想只保留test_dataset中的第一条记录,并删除其余所有记录。

伪代码(仅为示意,这些方法并不存在):

test_dataset = test_dataset.removeAllExceptFirst()
...
first_record = test_dataset.getItem(0)
test_dataset = test_dataset.removeAll()
test_dataset = test_dataset.add(first_record)

是否有任何现有的方法来实现这个特性?

下面是我使用 test_dataset.batch(1).take(1) 的尝试,但它没有按预期工作:

def test_function(record):
    """Parse one serialized record, print the parsed features, return (None, None).

    Args:
        record: Serialized example handed over by ``Dataset.map``; it is
            forwarded unchanged to ``tf.io.parse_single_example``.

    Returns:
        The tuple ``(None, None)``; the parsed features are only printed.
    """
    feature_spec = {}
    feature_spec["test1"] = tf.io.FixedLenFeature((), tf.string, default_value="")
    feature_spec["test2"] = tf.io.FixedLenFeature([], tf.string)
    feature_spec["test3"] = tf.io.FixedLenFeature((), tf.string)

    parsed = tf.io.parse_single_example(record, feature_spec)

    print("features: {}".format(parsed))
    return (None, None)
# Batch first, keep one batch, then map. NOTE: with this ordering the mapped
# function receives a batched record rather than a scalar one, which is what
# triggers the "Input serialized must be a scalar" error shown below.
test_dataset = tf.data.TFRecordDataset([test_filenames]).batch(1).take(1)
test_dataset = test_dataset.map(test_function)

错误信息:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_11776/3885954589.py in <cell line: 3>()
1 test_dataset = tf.data.TFRecordDataset([test_filenames])
2 test_dataset = test_dataset.batch(1).take(1)
----> 3 test_dataset = test_dataset.map(test_function)
/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic, name)
2014         warnings.warn("The `deterministic` argument has no effect unless the "
2015                       "`num_parallel_calls` argument is specified.")
-> 2016       return MapDataset(self, map_func, preserve_cardinality=True, name=name)
2017     else:
2018       return ParallelMapDataset(
/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function, name)
5189     self._use_inter_op_parallelism = use_inter_op_parallelism
5190     self._preserve_cardinality = preserve_cardinality
-> 5191     self._map_func = structured_function.StructuredFunctionWrapper(
5192         map_func,
5193         self._transformation_name(),
/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
269         fn_factory = trace_tf_function(defun_kwargs)
270 
--> 271     self._function = fn_factory()
272     # There is no graph to add in eager mode.
273     add_to_graph &= not context.executing_eagerly()
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
3068          or `tf.Tensor` or `tf.TensorSpec`.
3069     """
-> 3070     graph_function = self._get_concrete_function_garbage_collected(
3071         *args, **kwargs)
3072     graph_function._garbage_collector.release()  # pylint: disable=protected-access
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
3034       args, kwargs = None, None
3035     with self._lock:
-> 3036       graph_function, _ = self._maybe_define_function(args, kwargs)
3037       seen_names = set()
3038       captured = object_identity.ObjectIdentitySet(
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3290 
3291           self._function_cache.add_call_context(cache_key.call_context)
-> 3292           graph_function = self._create_graph_function(args, kwargs)
3293           self._function_cache.add(cache_key, cache_key_deletion_observer,
3294                                    graph_function)
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3128     arg_names = base_arg_names + missing_arg_names
3129     graph_function = ConcreteFunction(
-> 3130         func_graph_module.func_graph_from_py_func(
3131             self._name,
3132             self._python_function,
/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
1159         _, original_func = tf_decorator.unwrap(python_func)
1160 
-> 1161       func_outputs = python_func(*func_args, **func_kwargs)
1162 
1163       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py in wrapped_fn(*args)
246           attributes=defun_kwargs)
247       def wrapped_fn(*args):  # pylint: disable=missing-docstring
--> 248         ret = wrapper_helper(*args)
249         ret = structure.to_tensor_list(self._output_structure, ret)
250         return [ops.convert_to_tensor(t) for t in ret]
/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/structured_function.py in wrapper_helper(*args)
175       if not _should_unpack(nested_args):
176         nested_args = (nested_args,)
--> 177       ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
178       if _should_pack(ret):
179         ret = tuple(ret)
/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
690       except Exception as e:  # pylint:disable=broad-except
691         if hasattr(e, 'ag_error_metadata'):
--> 692           raise e.ag_error_metadata.to_exception(e)
693         else:
694           raise
ValueError: in user code:
File "/tmp/ipykernel_11776/3804092897.py", line 8, in test_function  *
features = tf.io.parse_single_example(record, keys_to_features)
ValueError: Input serialized must be a scalar

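出错的原因是:在调用.map()之前先调用了batch(1),此时传给test_function的每个元素是形状为[1]的批量张量,而tf.io.parse_single_example要求输入必须是标量。因此,您需要先用.map()逐条解析记录(同时修改test_function,使其返回解析后的特征),最后再使用batch(1).take(1),如下所示: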

def test_function(record):
    """Parse one serialized record and return its "test1" and "test2" features.

    Args:
        record: A single serialized example. It is passed straight to
            ``tf.io.parse_single_example``, which requires a scalar input,
            so map this function over the dataset before batching.

    Returns:
        A ``(test1, test2)`` tuple taken from the parsed feature dict.
    """
    keys_to_features = {
        "test1": tf.io.FixedLenFeature((), tf.string, default_value=""),
        "test2": tf.io.FixedLenFeature([], tf.string),
        "test3": tf.io.FixedLenFeature((), tf.string),
    }
    features = tf.io.parse_single_example(record, keys_to_features)
    # Index the parsed dict by the name it is actually bound to (`features`).
    return (features["test1"], features["test2"])
# Parse each record while it is still a scalar, then batch and keep only the
# first batch (of size 1).
test_dataset = tf.data.TFRecordDataset([test_filenames])
test_dataset = test_dataset.map(test_function).batch(1).take(1)

相关内容

  • 没有找到相关文章

最新更新